agentic_memory_mcp/protocol/
handler.rs1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use serde_json::Value;
8
9use crate::prompts::PromptRegistry;
10use crate::resources::ResourceRegistry;
11use crate::session::SessionManager;
12use crate::tools::ToolRegistry;
13use crate::types::*;
14
15use super::negotiation::NegotiatedCapabilities;
16use super::validator::validate_request;
17
18pub struct ProtocolHandler {
20 session: Arc<Mutex<SessionManager>>,
21 capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22 shutdown_requested: Arc<AtomicBool>,
23}
24
25impl ProtocolHandler {
26 pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
28 Self {
29 session,
30 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
31 shutdown_requested: Arc::new(AtomicBool::new(false)),
32 }
33 }
34
35 pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
37 Self {
38 session,
39 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
40 shutdown_requested: Arc::new(AtomicBool::new(false)),
41 }
42 }
43
44 pub fn shutdown_requested(&self) -> bool {
46 self.shutdown_requested.load(Ordering::Relaxed)
47 }
48
49 pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
51 match msg {
52 JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
53 JsonRpcMessage::Notification(notif) => {
54 self.handle_notification(notif).await;
55 None
56 }
57 _ => {
58 tracing::warn!("Received unexpected message type from client");
60 None
61 }
62 }
63 }
64
65 async fn handle_request(&self, request: JsonRpcRequest) -> Value {
66 if let Err(e) = validate_request(&request) {
68 return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
69 }
70
71 let id = request.id.clone();
72 let result = self.dispatch_request(&request).await;
73
74 match result {
75 Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
76 Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
77 }
78 }
79
80 async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
81 match request.method.as_str() {
82 "initialize" => self.handle_initialize(request.params.clone()).await,
84 "shutdown" => self.handle_shutdown().await,
85
86 "tools/list" => self.handle_tools_list().await,
88 "tools/call" => self.handle_tools_call(request.params.clone()).await,
89
90 "resources/list" => self.handle_resources_list().await,
92 "resources/templates/list" => self.handle_resource_templates_list().await,
93 "resources/read" => self.handle_resources_read(request.params.clone()).await,
94 "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
95 "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
96
97 "prompts/list" => self.handle_prompts_list().await,
99 "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
100
101 "ping" => Ok(Value::Object(serde_json::Map::new())),
103
104 _ => Err(McpError::MethodNotFound(request.method.clone())),
105 }
106 }
107
108 async fn handle_notification(&self, notification: JsonRpcNotification) {
109 match notification.method.as_str() {
110 "initialized" => {
111 let mut caps = self.capabilities.lock().await;
112 if let Err(e) = caps.mark_initialized() {
113 tracing::error!("Failed to mark initialized: {e}");
114 }
115 }
116 "notifications/cancelled" | "$/cancelRequest" => {
117 tracing::info!("Received cancellation notification");
118 }
119 _ => {
120 tracing::debug!("Unknown notification: {}", notification.method);
121 }
122 }
123 }
124
125 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
126 let init_params: InitializeParams = params
127 .map(serde_json::from_value)
128 .transpose()
129 .map_err(|e| McpError::InvalidParams(e.to_string()))?
130 .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
131
132 let mut caps = self.capabilities.lock().await;
133 let result = caps.negotiate(init_params)?;
134
135 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
136 }
137
138 async fn handle_shutdown(&self) -> McpResult<Value> {
139 tracing::info!("Shutdown requested");
140 let mut session = self.session.lock().await;
141 session.save()?;
142 self.shutdown_requested.store(true, Ordering::Relaxed);
143 Ok(Value::Object(serde_json::Map::new()))
144 }
145
146 async fn handle_tools_list(&self) -> McpResult<Value> {
147 let result = ToolListResult {
148 tools: ToolRegistry::list_tools(),
149 next_cursor: None,
150 };
151 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
152 }
153
154 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
155 let call_params: ToolCallParams = params
156 .map(serde_json::from_value)
157 .transpose()
158 .map_err(|e| McpError::InvalidParams(e.to_string()))?
159 .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
160
161 {
162 let mut session = self.session.lock().await;
163 if let Err(e) =
164 session.capture_tool_call(&call_params.name, call_params.arguments.as_ref())
165 {
166 tracing::warn!(
167 "Auto-capture skipped for tool {} due to error: {}",
168 call_params.name,
169 e
170 );
171 }
172 }
173
174 let result =
175 ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
176
177 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
178 }
179
180 async fn handle_resources_list(&self) -> McpResult<Value> {
181 let result = ResourceListResult {
182 resources: ResourceRegistry::list_resources(),
183 next_cursor: None,
184 };
185 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
186 }
187
188 async fn handle_resource_templates_list(&self) -> McpResult<Value> {
189 let result = ResourceTemplateListResult {
190 resource_templates: ResourceRegistry::list_templates(),
191 next_cursor: None,
192 };
193 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
194 }
195
196 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
197 let read_params: ResourceReadParams = params
198 .map(serde_json::from_value)
199 .transpose()
200 .map_err(|e| McpError::InvalidParams(e.to_string()))?
201 .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
202
203 let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
204
205 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
206 }
207
208 async fn handle_prompts_list(&self) -> McpResult<Value> {
209 let result = PromptListResult {
210 prompts: PromptRegistry::list_prompts(),
211 next_cursor: None,
212 };
213 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
214 }
215
216 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
217 let get_params: PromptGetParams = params
218 .map(serde_json::from_value)
219 .transpose()
220 .map_err(|e| McpError::InvalidParams(e.to_string()))?
221 .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
222
223 {
224 let mut session = self.session.lock().await;
225 if let Err(e) =
226 session.capture_prompt_request(&get_params.name, get_params.arguments.as_ref())
227 {
228 tracing::warn!(
229 "Auto-capture skipped for prompt {} due to error: {}",
230 get_params.name,
231 e
232 );
233 }
234 }
235
236 let result =
237 PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
238
239 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
240 }
241}