Skip to main content

agentic_memory_mcp/protocol/
handler.rs

1//! Main request dispatcher — receives JSON-RPC messages, routes to handlers.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use serde_json::Value;
8
9use crate::prompts::PromptRegistry;
10use crate::resources::ResourceRegistry;
11use crate::session::SessionManager;
12use crate::tools::ToolRegistry;
13use crate::types::*;
14
15use super::negotiation::NegotiatedCapabilities;
16use super::validator::validate_request;
17
18/// The main protocol handler that dispatches incoming JSON-RPC messages.
19pub struct ProtocolHandler {
20    session: Arc<Mutex<SessionManager>>,
21    capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22    shutdown_requested: Arc<AtomicBool>,
23}
24
25impl ProtocolHandler {
26    /// Create a new protocol handler with the given session manager.
27    pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
28        Self {
29            session,
30            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
31            shutdown_requested: Arc::new(AtomicBool::new(false)),
32        }
33    }
34
35    /// Create a new protocol handler with a specific memory mode.
36    pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
37        Self {
38            session,
39            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
40            shutdown_requested: Arc::new(AtomicBool::new(false)),
41        }
42    }
43
44    /// Returns true once a shutdown request has been handled.
45    pub fn shutdown_requested(&self) -> bool {
46        self.shutdown_requested.load(Ordering::Relaxed)
47    }
48
49    /// Handle an incoming JSON-RPC message and optionally return a response.
50    pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
51        match msg {
52            JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
53            JsonRpcMessage::Notification(notif) => {
54                self.handle_notification(notif).await;
55                None
56            }
57            _ => {
58                // Responses and errors from the client are unexpected
59                tracing::warn!("Received unexpected message type from client");
60                None
61            }
62        }
63    }
64
65    async fn handle_request(&self, request: JsonRpcRequest) -> Value {
66        // Validate JSON-RPC structure
67        if let Err(e) = validate_request(&request) {
68            return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
69        }
70
71        let id = request.id.clone();
72        let result = self.dispatch_request(&request).await;
73
74        match result {
75            Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
76            Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
77        }
78    }
79
80    async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
81        match request.method.as_str() {
82            // Lifecycle
83            "initialize" => self.handle_initialize(request.params.clone()).await,
84            "shutdown" => self.handle_shutdown().await,
85
86            // Tools
87            "tools/list" => self.handle_tools_list().await,
88            "tools/call" => self.handle_tools_call(request.params.clone()).await,
89
90            // Resources
91            "resources/list" => self.handle_resources_list().await,
92            "resources/templates/list" => self.handle_resource_templates_list().await,
93            "resources/read" => self.handle_resources_read(request.params.clone()).await,
94            "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
95            "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
96
97            // Prompts
98            "prompts/list" => self.handle_prompts_list().await,
99            "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
100
101            // Ping
102            "ping" => Ok(Value::Object(serde_json::Map::new())),
103
104            _ => Err(McpError::MethodNotFound(request.method.clone())),
105        }
106    }
107
108    async fn handle_notification(&self, notification: JsonRpcNotification) {
109        match notification.method.as_str() {
110            "initialized" => {
111                let mut caps = self.capabilities.lock().await;
112                if let Err(e) = caps.mark_initialized() {
113                    tracing::error!("Failed to mark initialized: {e}");
114                }
115            }
116            "notifications/cancelled" | "$/cancelRequest" => {
117                tracing::info!("Received cancellation notification");
118            }
119            _ => {
120                tracing::debug!("Unknown notification: {}", notification.method);
121            }
122        }
123    }
124
125    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
126        let init_params: InitializeParams = params
127            .map(serde_json::from_value)
128            .transpose()
129            .map_err(|e| McpError::InvalidParams(e.to_string()))?
130            .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
131
132        let mut caps = self.capabilities.lock().await;
133        let result = caps.negotiate(init_params)?;
134
135        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
136    }
137
138    async fn handle_shutdown(&self) -> McpResult<Value> {
139        tracing::info!("Shutdown requested");
140        let mut session = self.session.lock().await;
141        session.save()?;
142        self.shutdown_requested.store(true, Ordering::Relaxed);
143        Ok(Value::Object(serde_json::Map::new()))
144    }
145
146    async fn handle_tools_list(&self) -> McpResult<Value> {
147        let result = ToolListResult {
148            tools: ToolRegistry::list_tools(),
149            next_cursor: None,
150        };
151        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
152    }
153
154    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
155        let call_params: ToolCallParams = params
156            .map(serde_json::from_value)
157            .transpose()
158            .map_err(|e| McpError::InvalidParams(e.to_string()))?
159            .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
160
161        {
162            let mut session = self.session.lock().await;
163            if let Err(e) =
164                session.capture_tool_call(&call_params.name, call_params.arguments.as_ref())
165            {
166                tracing::warn!(
167                    "Auto-capture skipped for tool {} due to error: {}",
168                    call_params.name,
169                    e
170                );
171            }
172        }
173
174        let result =
175            ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
176
177        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
178    }
179
180    async fn handle_resources_list(&self) -> McpResult<Value> {
181        let result = ResourceListResult {
182            resources: ResourceRegistry::list_resources(),
183            next_cursor: None,
184        };
185        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
186    }
187
188    async fn handle_resource_templates_list(&self) -> McpResult<Value> {
189        let result = ResourceTemplateListResult {
190            resource_templates: ResourceRegistry::list_templates(),
191            next_cursor: None,
192        };
193        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
194    }
195
196    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
197        let read_params: ResourceReadParams = params
198            .map(serde_json::from_value)
199            .transpose()
200            .map_err(|e| McpError::InvalidParams(e.to_string()))?
201            .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
202
203        let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
204
205        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
206    }
207
208    async fn handle_prompts_list(&self) -> McpResult<Value> {
209        let result = PromptListResult {
210            prompts: PromptRegistry::list_prompts(),
211            next_cursor: None,
212        };
213        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
214    }
215
216    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
217        let get_params: PromptGetParams = params
218            .map(serde_json::from_value)
219            .transpose()
220            .map_err(|e| McpError::InvalidParams(e.to_string()))?
221            .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
222
223        {
224            let mut session = self.session.lock().await;
225            if let Err(e) =
226                session.capture_prompt_request(&get_params.name, get_params.arguments.as_ref())
227            {
228                tracing::warn!(
229                    "Auto-capture skipped for prompt {} due to error: {}",
230                    get_params.name,
231                    e
232                );
233            }
234        }
235
236        let result =
237            PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
238
239        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
240    }
241}