Skip to main content

agentic_memory_mcp/protocol/
handler.rs

//! Main request dispatcher — receives JSON-RPC messages, routes to handlers.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use serde_json::Value;
8
9use crate::prompts::PromptRegistry;
10use crate::resources::ResourceRegistry;
11use crate::session::SessionManager;
12use crate::tools::ToolRegistry;
13use crate::types::*;
14
15use super::negotiation::NegotiatedCapabilities;
16use super::validator::validate_request;
17
18/// The main protocol handler that dispatches incoming JSON-RPC messages.
19pub struct ProtocolHandler {
20    session: Arc<Mutex<SessionManager>>,
21    capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22    shutdown_requested: Arc<AtomicBool>,
23    memory_mode: MemoryMode,
24    /// Tracks whether an auto-session was started so we can auto-end it.
25    auto_session_started: AtomicBool,
26}
27
28impl ProtocolHandler {
29    /// Create a new protocol handler with the given session manager.
30    pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
31        Self {
32            session,
33            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
34            shutdown_requested: Arc::new(AtomicBool::new(false)),
35            memory_mode: MemoryMode::Smart,
36            auto_session_started: AtomicBool::new(false),
37        }
38    }
39
40    /// Create a new protocol handler with a specific memory mode.
41    pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
42        Self {
43            session,
44            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
45            shutdown_requested: Arc::new(AtomicBool::new(false)),
46            memory_mode: mode,
47            auto_session_started: AtomicBool::new(false),
48        }
49    }
50
51    /// Returns true once a shutdown request has been handled.
52    pub fn shutdown_requested(&self) -> bool {
53        self.shutdown_requested.load(Ordering::Relaxed)
54    }
55
56    /// Handle an incoming JSON-RPC message and optionally return a response.
57    pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
58        match msg {
59            JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
60            JsonRpcMessage::Notification(notif) => {
61                self.handle_notification(notif).await;
62                None
63            }
64            _ => {
65                // Responses and errors from the client are unexpected
66                tracing::warn!("Received unexpected message type from client");
67                None
68            }
69        }
70    }
71
72    /// Cleanup on transport close (EOF). Auto-ends session if one was started.
73    pub async fn cleanup(&self) {
74        if !self.auto_session_started.load(Ordering::Relaxed) {
75            return;
76        }
77
78        let mut session = self.session.lock().await;
79        let sid = session.current_session_id();
80        match session.end_session_with_episode(sid, "Session ended: MCP connection closed") {
81            Ok(episode_id) => {
82                tracing::info!("Auto-ended session {sid} on EOF, episode node {episode_id}");
83            }
84            Err(e) => {
85                tracing::warn!("Failed to auto-end session on EOF: {e}");
86                if let Err(save_err) = session.save() {
87                    tracing::error!("Failed to save on EOF cleanup: {save_err}");
88                }
89            }
90        }
91        self.auto_session_started.store(false, Ordering::Relaxed);
92    }
93
94    async fn handle_request(&self, request: JsonRpcRequest) -> Value {
95        // Validate JSON-RPC structure
96        if let Err(e) = validate_request(&request) {
97            return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
98        }
99
100        let id = request.id.clone();
101        let result = self.dispatch_request(&request).await;
102
103        match result {
104            Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
105            Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
106        }
107    }
108
109    async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
110        match request.method.as_str() {
111            // Lifecycle
112            "initialize" => self.handle_initialize(request.params.clone()).await,
113            "shutdown" => self.handle_shutdown().await,
114
115            // Tools
116            "tools/list" => self.handle_tools_list().await,
117            "tools/call" => self.handle_tools_call(request.params.clone()).await,
118
119            // Resources
120            "resources/list" => self.handle_resources_list().await,
121            "resources/templates/list" => self.handle_resource_templates_list().await,
122            "resources/read" => self.handle_resources_read(request.params.clone()).await,
123            "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
124            "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
125
126            // Prompts
127            "prompts/list" => self.handle_prompts_list().await,
128            "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
129
130            // Ping
131            "ping" => Ok(Value::Object(serde_json::Map::new())),
132
133            _ => Err(McpError::MethodNotFound(request.method.clone())),
134        }
135    }
136
137    async fn handle_notification(&self, notification: JsonRpcNotification) {
138        match notification.method.as_str() {
139            "initialized" => {
140                let mut caps = self.capabilities.lock().await;
141                if let Err(e) = caps.mark_initialized() {
142                    tracing::error!("Failed to mark initialized: {e}");
143                }
144
145                // Auto-start session when client confirms connection (smart/full mode).
146                if self.memory_mode != MemoryMode::Minimal {
147                    let mut session = self.session.lock().await;
148                    match session.start_session(None) {
149                        Ok(sid) => {
150                            self.auto_session_started.store(true, Ordering::Relaxed);
151                            tracing::info!(
152                                "Auto-started session {sid} (mode={:?})",
153                                self.memory_mode
154                            );
155                        }
156                        Err(e) => {
157                            tracing::error!("Failed to auto-start session: {e}");
158                        }
159                    }
160                }
161            }
162            "notifications/cancelled" | "$/cancelRequest" => {
163                tracing::info!("Received cancellation notification");
164            }
165            _ => {
166                tracing::debug!("Unknown notification: {}", notification.method);
167            }
168        }
169    }
170
171    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
172        let init_params: InitializeParams = params
173            .map(serde_json::from_value)
174            .transpose()
175            .map_err(|e| McpError::InvalidParams(e.to_string()))?
176            .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
177
178        let mut caps = self.capabilities.lock().await;
179        let result = caps.negotiate(init_params)?;
180
181        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
182    }
183
184    async fn handle_shutdown(&self) -> McpResult<Value> {
185        tracing::info!("Shutdown requested");
186
187        let mut session = self.session.lock().await;
188
189        // Auto-end session with episode summary if one was auto-started.
190        if self.auto_session_started.swap(false, Ordering::Relaxed) {
191            let sid = session.current_session_id();
192            match session.end_session_with_episode(sid, "Session ended: MCP client shutdown") {
193                Ok(episode_id) => {
194                    tracing::info!("Auto-ended session {sid}, episode node {episode_id}");
195                }
196                Err(e) => {
197                    tracing::warn!("Failed to auto-end session on shutdown: {e}");
198                    session.save()?;
199                }
200            }
201        } else {
202            session.save()?;
203        }
204
205        self.shutdown_requested.store(true, Ordering::Relaxed);
206        Ok(Value::Object(serde_json::Map::new()))
207    }
208
209    async fn handle_tools_list(&self) -> McpResult<Value> {
210        let result = ToolListResult {
211            tools: ToolRegistry::list_tools(),
212            next_cursor: None,
213        };
214        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
215    }
216
217    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
218        let call_params: ToolCallParams = params
219            .map(serde_json::from_value)
220            .transpose()
221            .map_err(|e| McpError::InvalidParams(e.to_string()))?
222            .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
223
224        {
225            let mut session = self.session.lock().await;
226            if let Err(e) =
227                session.capture_tool_call(&call_params.name, call_params.arguments.as_ref())
228            {
229                tracing::warn!(
230                    "Auto-capture skipped for tool {} due to error: {}",
231                    call_params.name,
232                    e
233                );
234            }
235        }
236
237        let result =
238            ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
239
240        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
241    }
242
243    async fn handle_resources_list(&self) -> McpResult<Value> {
244        let result = ResourceListResult {
245            resources: ResourceRegistry::list_resources(),
246            next_cursor: None,
247        };
248        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
249    }
250
251    async fn handle_resource_templates_list(&self) -> McpResult<Value> {
252        let result = ResourceTemplateListResult {
253            resource_templates: ResourceRegistry::list_templates(),
254            next_cursor: None,
255        };
256        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
257    }
258
259    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
260        let read_params: ResourceReadParams = params
261            .map(serde_json::from_value)
262            .transpose()
263            .map_err(|e| McpError::InvalidParams(e.to_string()))?
264            .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
265
266        let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
267
268        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
269    }
270
271    async fn handle_prompts_list(&self) -> McpResult<Value> {
272        let result = PromptListResult {
273            prompts: PromptRegistry::list_prompts(),
274            next_cursor: None,
275        };
276        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
277    }
278
279    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
280        let get_params: PromptGetParams = params
281            .map(serde_json::from_value)
282            .transpose()
283            .map_err(|e| McpError::InvalidParams(e.to_string()))?
284            .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
285
286        {
287            let mut session = self.session.lock().await;
288            if let Err(e) =
289                session.capture_prompt_request(&get_params.name, get_params.arguments.as_ref())
290            {
291                tracing::warn!(
292                    "Auto-capture skipped for prompt {} due to error: {}",
293                    get_params.name,
294                    e
295                );
296            }
297        }
298
299        let result =
300            PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
301
302        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
303    }
304}