Skip to main content

agentic_memory_mcp/protocol/
handler.rs

1//! Main request dispatcher — receives JSON-RPC messages, routes to handlers.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde_json::Value;
7
8use crate::prompts::PromptRegistry;
9use crate::resources::ResourceRegistry;
10use crate::session::SessionManager;
11use crate::tools::ToolRegistry;
12use crate::types::*;
13
14use super::negotiation::NegotiatedCapabilities;
15use super::validator::validate_request;
16
17/// The main protocol handler that dispatches incoming JSON-RPC messages.
18pub struct ProtocolHandler {
19    session: Arc<Mutex<SessionManager>>,
20    capabilities: Arc<Mutex<NegotiatedCapabilities>>,
21}
22
23impl ProtocolHandler {
24    /// Create a new protocol handler with the given session manager.
25    pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
26        Self {
27            session,
28            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
29        }
30    }
31
32    /// Create a new protocol handler with a specific memory mode.
33    pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
34        Self {
35            session,
36            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
37        }
38    }
39
40    /// Handle an incoming JSON-RPC message and optionally return a response.
41    pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
42        match msg {
43            JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
44            JsonRpcMessage::Notification(notif) => {
45                self.handle_notification(notif).await;
46                None
47            }
48            _ => {
49                // Responses and errors from the client are unexpected
50                tracing::warn!("Received unexpected message type from client");
51                None
52            }
53        }
54    }
55
56    async fn handle_request(&self, request: JsonRpcRequest) -> Value {
57        // Validate JSON-RPC structure
58        if let Err(e) = validate_request(&request) {
59            return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
60        }
61
62        let id = request.id.clone();
63        let result = self.dispatch_request(&request).await;
64
65        match result {
66            Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
67            Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
68        }
69    }
70
71    async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
72        match request.method.as_str() {
73            // Lifecycle
74            "initialize" => self.handle_initialize(request.params.clone()).await,
75            "shutdown" => self.handle_shutdown().await,
76
77            // Tools
78            "tools/list" => self.handle_tools_list().await,
79            "tools/call" => self.handle_tools_call(request.params.clone()).await,
80
81            // Resources
82            "resources/list" => self.handle_resources_list().await,
83            "resources/templates/list" => self.handle_resource_templates_list().await,
84            "resources/read" => self.handle_resources_read(request.params.clone()).await,
85            "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
86            "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
87
88            // Prompts
89            "prompts/list" => self.handle_prompts_list().await,
90            "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
91
92            // Ping
93            "ping" => Ok(Value::Object(serde_json::Map::new())),
94
95            _ => Err(McpError::MethodNotFound(request.method.clone())),
96        }
97    }
98
99    async fn handle_notification(&self, notification: JsonRpcNotification) {
100        match notification.method.as_str() {
101            "initialized" => {
102                let mut caps = self.capabilities.lock().await;
103                if let Err(e) = caps.mark_initialized() {
104                    tracing::error!("Failed to mark initialized: {e}");
105                }
106            }
107            "notifications/cancelled" | "$/cancelRequest" => {
108                tracing::info!("Received cancellation notification");
109            }
110            _ => {
111                tracing::debug!("Unknown notification: {}", notification.method);
112            }
113        }
114    }
115
116    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
117        let init_params: InitializeParams = params
118            .map(serde_json::from_value)
119            .transpose()
120            .map_err(|e| McpError::InvalidParams(e.to_string()))?
121            .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
122
123        let mut caps = self.capabilities.lock().await;
124        let result = caps.negotiate(init_params)?;
125
126        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
127    }
128
129    async fn handle_shutdown(&self) -> McpResult<Value> {
130        tracing::info!("Shutdown requested");
131        let mut session = self.session.lock().await;
132        session.save()?;
133        Ok(Value::Object(serde_json::Map::new()))
134    }
135
136    async fn handle_tools_list(&self) -> McpResult<Value> {
137        let result = ToolListResult {
138            tools: ToolRegistry::list_tools(),
139            next_cursor: None,
140        };
141        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
142    }
143
144    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
145        let call_params: ToolCallParams = params
146            .map(serde_json::from_value)
147            .transpose()
148            .map_err(|e| McpError::InvalidParams(e.to_string()))?
149            .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
150
151        let result =
152            ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
153
154        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
155    }
156
157    async fn handle_resources_list(&self) -> McpResult<Value> {
158        let result = ResourceListResult {
159            resources: ResourceRegistry::list_resources(),
160            next_cursor: None,
161        };
162        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
163    }
164
165    async fn handle_resource_templates_list(&self) -> McpResult<Value> {
166        let result = ResourceTemplateListResult {
167            resource_templates: ResourceRegistry::list_templates(),
168            next_cursor: None,
169        };
170        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
171    }
172
173    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
174        let read_params: ResourceReadParams = params
175            .map(serde_json::from_value)
176            .transpose()
177            .map_err(|e| McpError::InvalidParams(e.to_string()))?
178            .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
179
180        let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
181
182        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
183    }
184
185    async fn handle_prompts_list(&self) -> McpResult<Value> {
186        let result = PromptListResult {
187            prompts: PromptRegistry::list_prompts(),
188            next_cursor: None,
189        };
190        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
191    }
192
193    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
194        let get_params: PromptGetParams = params
195            .map(serde_json::from_value)
196            .transpose()
197            .map_err(|e| McpError::InvalidParams(e.to_string()))?
198            .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
199
200        let result =
201            PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
202
203        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
204    }
205}