Skip to main content

agentic_memory_mcp/protocol/
handler.rs

1//! Main request dispatcher — receives JSON-RPC messages, routes to handlers.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde_json::Value;
7
8use crate::prompts::PromptRegistry;
9use crate::resources::ResourceRegistry;
10use crate::session::SessionManager;
11use crate::tools::ToolRegistry;
12use crate::types::*;
13
14use super::negotiation::NegotiatedCapabilities;
15use super::validator::validate_request;
16
17/// The main protocol handler that dispatches incoming JSON-RPC messages.
18pub struct ProtocolHandler {
19    session: Arc<Mutex<SessionManager>>,
20    capabilities: Arc<Mutex<NegotiatedCapabilities>>,
21}
22
23impl ProtocolHandler {
24    /// Create a new protocol handler with the given session manager.
25    pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
26        Self {
27            session,
28            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
29        }
30    }
31
32    /// Handle an incoming JSON-RPC message and optionally return a response.
33    pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
34        match msg {
35            JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
36            JsonRpcMessage::Notification(notif) => {
37                self.handle_notification(notif).await;
38                None
39            }
40            _ => {
41                // Responses and errors from the client are unexpected
42                tracing::warn!("Received unexpected message type from client");
43                None
44            }
45        }
46    }
47
48    async fn handle_request(&self, request: JsonRpcRequest) -> Value {
49        // Validate JSON-RPC structure
50        if let Err(e) = validate_request(&request) {
51            return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
52        }
53
54        let id = request.id.clone();
55        let result = self.dispatch_request(&request).await;
56
57        match result {
58            Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
59            Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
60        }
61    }
62
63    async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
64        match request.method.as_str() {
65            // Lifecycle
66            "initialize" => self.handle_initialize(request.params.clone()).await,
67            "shutdown" => self.handle_shutdown().await,
68
69            // Tools
70            "tools/list" => self.handle_tools_list().await,
71            "tools/call" => self.handle_tools_call(request.params.clone()).await,
72
73            // Resources
74            "resources/list" => self.handle_resources_list().await,
75            "resources/templates/list" => self.handle_resource_templates_list().await,
76            "resources/read" => self.handle_resources_read(request.params.clone()).await,
77            "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
78            "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
79
80            // Prompts
81            "prompts/list" => self.handle_prompts_list().await,
82            "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
83
84            // Ping
85            "ping" => Ok(Value::Object(serde_json::Map::new())),
86
87            _ => Err(McpError::MethodNotFound(request.method.clone())),
88        }
89    }
90
91    async fn handle_notification(&self, notification: JsonRpcNotification) {
92        match notification.method.as_str() {
93            "initialized" => {
94                let mut caps = self.capabilities.lock().await;
95                if let Err(e) = caps.mark_initialized() {
96                    tracing::error!("Failed to mark initialized: {e}");
97                }
98            }
99            "notifications/cancelled" | "$/cancelRequest" => {
100                tracing::info!("Received cancellation notification");
101            }
102            _ => {
103                tracing::debug!("Unknown notification: {}", notification.method);
104            }
105        }
106    }
107
108    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
109        let init_params: InitializeParams = params
110            .map(serde_json::from_value)
111            .transpose()
112            .map_err(|e| McpError::InvalidParams(e.to_string()))?
113            .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
114
115        let mut caps = self.capabilities.lock().await;
116        let result = caps.negotiate(init_params)?;
117
118        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
119    }
120
121    async fn handle_shutdown(&self) -> McpResult<Value> {
122        tracing::info!("Shutdown requested");
123        let mut session = self.session.lock().await;
124        session.save()?;
125        Ok(Value::Object(serde_json::Map::new()))
126    }
127
128    async fn handle_tools_list(&self) -> McpResult<Value> {
129        let result = ToolListResult {
130            tools: ToolRegistry::list_tools(),
131            next_cursor: None,
132        };
133        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
134    }
135
136    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
137        let call_params: ToolCallParams = params
138            .map(serde_json::from_value)
139            .transpose()
140            .map_err(|e| McpError::InvalidParams(e.to_string()))?
141            .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
142
143        let result =
144            ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
145
146        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
147    }
148
149    async fn handle_resources_list(&self) -> McpResult<Value> {
150        let result = ResourceListResult {
151            resources: ResourceRegistry::list_resources(),
152            next_cursor: None,
153        };
154        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
155    }
156
157    async fn handle_resource_templates_list(&self) -> McpResult<Value> {
158        let result = ResourceTemplateListResult {
159            resource_templates: ResourceRegistry::list_templates(),
160            next_cursor: None,
161        };
162        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
163    }
164
165    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
166        let read_params: ResourceReadParams = params
167            .map(serde_json::from_value)
168            .transpose()
169            .map_err(|e| McpError::InvalidParams(e.to_string()))?
170            .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
171
172        let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
173
174        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
175    }
176
177    async fn handle_prompts_list(&self) -> McpResult<Value> {
178        let result = PromptListResult {
179            prompts: PromptRegistry::list_prompts(),
180            next_cursor: None,
181        };
182        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
183    }
184
185    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
186        let get_params: PromptGetParams = params
187            .map(serde_json::from_value)
188            .transpose()
189            .map_err(|e| McpError::InvalidParams(e.to_string()))?
190            .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
191
192        let result =
193            PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
194
195        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
196    }
197}