agentic_memory_mcp/protocol/
handler.rs1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use serde_json::Value;
8
9use crate::prompts::PromptRegistry;
10use crate::resources::ResourceRegistry;
11use crate::session::SessionManager;
12use crate::tools::ToolRegistry;
13use crate::types::*;
14
15use super::negotiation::NegotiatedCapabilities;
16use super::validator::validate_request;
17
18pub struct ProtocolHandler {
20 session: Arc<Mutex<SessionManager>>,
21 capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22 shutdown_requested: Arc<AtomicBool>,
23 memory_mode: MemoryMode,
24 auto_session_started: AtomicBool,
26}
27
28impl ProtocolHandler {
29 pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
31 Self {
32 session,
33 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
34 shutdown_requested: Arc::new(AtomicBool::new(false)),
35 memory_mode: MemoryMode::Smart,
36 auto_session_started: AtomicBool::new(false),
37 }
38 }
39
40 pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
42 Self {
43 session,
44 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
45 shutdown_requested: Arc::new(AtomicBool::new(false)),
46 memory_mode: mode,
47 auto_session_started: AtomicBool::new(false),
48 }
49 }
50
51 pub fn shutdown_requested(&self) -> bool {
53 self.shutdown_requested.load(Ordering::Relaxed)
54 }
55
56 pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
58 match msg {
59 JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
60 JsonRpcMessage::Notification(notif) => {
61 self.handle_notification(notif).await;
62 None
63 }
64 _ => {
65 tracing::warn!("Received unexpected message type from client");
67 None
68 }
69 }
70 }
71
72 pub async fn cleanup(&self) {
74 if !self.auto_session_started.load(Ordering::Relaxed) {
75 return;
76 }
77
78 let mut session = self.session.lock().await;
79 let sid = session.current_session_id();
80 match session.end_session_with_episode(sid, "Session ended: MCP connection closed") {
81 Ok(episode_id) => {
82 tracing::info!("Auto-ended session {sid} on EOF, episode node {episode_id}");
83 }
84 Err(e) => {
85 tracing::warn!("Failed to auto-end session on EOF: {e}");
86 if let Err(save_err) = session.save() {
87 tracing::error!("Failed to save on EOF cleanup: {save_err}");
88 }
89 }
90 }
91 self.auto_session_started.store(false, Ordering::Relaxed);
92 }
93
94 async fn handle_request(&self, request: JsonRpcRequest) -> Value {
95 if let Err(e) = validate_request(&request) {
97 return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
98 }
99
100 let id = request.id.clone();
101 let result = self.dispatch_request(&request).await;
102
103 match result {
104 Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
105 Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
106 }
107 }
108
109 async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
110 match request.method.as_str() {
111 "initialize" => self.handle_initialize(request.params.clone()).await,
113 "shutdown" => self.handle_shutdown().await,
114
115 "tools/list" => self.handle_tools_list().await,
117 "tools/call" => self.handle_tools_call(request.params.clone()).await,
118
119 "resources/list" => self.handle_resources_list().await,
121 "resources/templates/list" => self.handle_resource_templates_list().await,
122 "resources/read" => self.handle_resources_read(request.params.clone()).await,
123 "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
124 "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
125
126 "prompts/list" => self.handle_prompts_list().await,
128 "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
129
130 "ping" => Ok(Value::Object(serde_json::Map::new())),
132
133 _ => Err(McpError::MethodNotFound(request.method.clone())),
134 }
135 }
136
137 async fn handle_notification(&self, notification: JsonRpcNotification) {
138 match notification.method.as_str() {
139 "initialized" => {
140 let mut caps = self.capabilities.lock().await;
141 if let Err(e) = caps.mark_initialized() {
142 tracing::error!("Failed to mark initialized: {e}");
143 }
144
145 if self.memory_mode != MemoryMode::Minimal {
147 let mut session = self.session.lock().await;
148 match session.start_session(None) {
149 Ok(sid) => {
150 self.auto_session_started.store(true, Ordering::Relaxed);
151 tracing::info!(
152 "Auto-started session {sid} (mode={:?})",
153 self.memory_mode
154 );
155 }
156 Err(e) => {
157 tracing::error!("Failed to auto-start session: {e}");
158 }
159 }
160 }
161 }
162 "notifications/cancelled" | "$/cancelRequest" => {
163 tracing::info!("Received cancellation notification");
164 }
165 _ => {
166 tracing::debug!("Unknown notification: {}", notification.method);
167 }
168 }
169 }
170
171 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
172 let init_params: InitializeParams = params
173 .map(serde_json::from_value)
174 .transpose()
175 .map_err(|e| McpError::InvalidParams(e.to_string()))?
176 .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
177
178 let mut caps = self.capabilities.lock().await;
179 let result = caps.negotiate(init_params)?;
180
181 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
182 }
183
184 async fn handle_shutdown(&self) -> McpResult<Value> {
185 tracing::info!("Shutdown requested");
186
187 let mut session = self.session.lock().await;
188
189 if self.auto_session_started.swap(false, Ordering::Relaxed) {
191 let sid = session.current_session_id();
192 match session.end_session_with_episode(sid, "Session ended: MCP client shutdown") {
193 Ok(episode_id) => {
194 tracing::info!("Auto-ended session {sid}, episode node {episode_id}");
195 }
196 Err(e) => {
197 tracing::warn!("Failed to auto-end session on shutdown: {e}");
198 session.save()?;
199 }
200 }
201 } else {
202 session.save()?;
203 }
204
205 self.shutdown_requested.store(true, Ordering::Relaxed);
206 Ok(Value::Object(serde_json::Map::new()))
207 }
208
209 async fn handle_tools_list(&self) -> McpResult<Value> {
210 let result = ToolListResult {
211 tools: ToolRegistry::list_tools(),
212 next_cursor: None,
213 };
214 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
215 }
216
217 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
218 let call_params: ToolCallParams = params
219 .map(serde_json::from_value)
220 .transpose()
221 .map_err(|e| McpError::InvalidParams(e.to_string()))?
222 .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
223
224 {
225 let mut session = self.session.lock().await;
226 if let Err(e) =
227 session.capture_tool_call(&call_params.name, call_params.arguments.as_ref())
228 {
229 tracing::warn!(
230 "Auto-capture skipped for tool {} due to error: {}",
231 call_params.name,
232 e
233 );
234 }
235 }
236
237 let result =
238 ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
239
240 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
241 }
242
243 async fn handle_resources_list(&self) -> McpResult<Value> {
244 let result = ResourceListResult {
245 resources: ResourceRegistry::list_resources(),
246 next_cursor: None,
247 };
248 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
249 }
250
251 async fn handle_resource_templates_list(&self) -> McpResult<Value> {
252 let result = ResourceTemplateListResult {
253 resource_templates: ResourceRegistry::list_templates(),
254 next_cursor: None,
255 };
256 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
257 }
258
259 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
260 let read_params: ResourceReadParams = params
261 .map(serde_json::from_value)
262 .transpose()
263 .map_err(|e| McpError::InvalidParams(e.to_string()))?
264 .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
265
266 let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
267
268 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
269 }
270
271 async fn handle_prompts_list(&self) -> McpResult<Value> {
272 let result = PromptListResult {
273 prompts: PromptRegistry::list_prompts(),
274 next_cursor: None,
275 };
276 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
277 }
278
279 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
280 let get_params: PromptGetParams = params
281 .map(serde_json::from_value)
282 .transpose()
283 .map_err(|e| McpError::InvalidParams(e.to_string()))?
284 .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
285
286 {
287 let mut session = self.session.lock().await;
288 if let Err(e) =
289 session.capture_prompt_request(&get_params.name, get_params.arguments.as_ref())
290 {
291 tracing::warn!(
292 "Auto-capture skipped for prompt {} due to error: {}",
293 get_params.name,
294 e
295 );
296 }
297 }
298
299 let result =
300 PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
301
302 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
303 }
304}