agentic_memory_mcp/protocol/
handler.rs1use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde_json::Value;
7
8use crate::prompts::PromptRegistry;
9use crate::resources::ResourceRegistry;
10use crate::session::SessionManager;
11use crate::tools::ToolRegistry;
12use crate::types::*;
13
14use super::negotiation::NegotiatedCapabilities;
15use super::validator::validate_request;
16
17pub struct ProtocolHandler {
19 session: Arc<Mutex<SessionManager>>,
20 capabilities: Arc<Mutex<NegotiatedCapabilities>>,
21}
22
23impl ProtocolHandler {
24 pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
26 Self {
27 session,
28 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
29 }
30 }
31
32 pub fn with_mode(session: Arc<Mutex<SessionManager>>, mode: MemoryMode) -> Self {
34 Self {
35 session,
36 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::with_mode(mode))),
37 }
38 }
39
40 pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
42 match msg {
43 JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
44 JsonRpcMessage::Notification(notif) => {
45 self.handle_notification(notif).await;
46 None
47 }
48 _ => {
49 tracing::warn!("Received unexpected message type from client");
51 None
52 }
53 }
54 }
55
56 async fn handle_request(&self, request: JsonRpcRequest) -> Value {
57 if let Err(e) = validate_request(&request) {
59 return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
60 }
61
62 let id = request.id.clone();
63 let result = self.dispatch_request(&request).await;
64
65 match result {
66 Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
67 Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
68 }
69 }
70
71 async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
72 match request.method.as_str() {
73 "initialize" => self.handle_initialize(request.params.clone()).await,
75 "shutdown" => self.handle_shutdown().await,
76
77 "tools/list" => self.handle_tools_list().await,
79 "tools/call" => self.handle_tools_call(request.params.clone()).await,
80
81 "resources/list" => self.handle_resources_list().await,
83 "resources/templates/list" => self.handle_resource_templates_list().await,
84 "resources/read" => self.handle_resources_read(request.params.clone()).await,
85 "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
86 "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
87
88 "prompts/list" => self.handle_prompts_list().await,
90 "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
91
92 "ping" => Ok(Value::Object(serde_json::Map::new())),
94
95 _ => Err(McpError::MethodNotFound(request.method.clone())),
96 }
97 }
98
99 async fn handle_notification(&self, notification: JsonRpcNotification) {
100 match notification.method.as_str() {
101 "initialized" => {
102 let mut caps = self.capabilities.lock().await;
103 if let Err(e) = caps.mark_initialized() {
104 tracing::error!("Failed to mark initialized: {e}");
105 }
106 }
107 "notifications/cancelled" | "$/cancelRequest" => {
108 tracing::info!("Received cancellation notification");
109 }
110 _ => {
111 tracing::debug!("Unknown notification: {}", notification.method);
112 }
113 }
114 }
115
116 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
117 let init_params: InitializeParams = params
118 .map(serde_json::from_value)
119 .transpose()
120 .map_err(|e| McpError::InvalidParams(e.to_string()))?
121 .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
122
123 let mut caps = self.capabilities.lock().await;
124 let result = caps.negotiate(init_params)?;
125
126 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
127 }
128
129 async fn handle_shutdown(&self) -> McpResult<Value> {
130 tracing::info!("Shutdown requested");
131 let mut session = self.session.lock().await;
132 session.save()?;
133 Ok(Value::Object(serde_json::Map::new()))
134 }
135
136 async fn handle_tools_list(&self) -> McpResult<Value> {
137 let result = ToolListResult {
138 tools: ToolRegistry::list_tools(),
139 next_cursor: None,
140 };
141 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
142 }
143
144 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
145 let call_params: ToolCallParams = params
146 .map(serde_json::from_value)
147 .transpose()
148 .map_err(|e| McpError::InvalidParams(e.to_string()))?
149 .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
150
151 let result =
152 ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
153
154 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
155 }
156
157 async fn handle_resources_list(&self) -> McpResult<Value> {
158 let result = ResourceListResult {
159 resources: ResourceRegistry::list_resources(),
160 next_cursor: None,
161 };
162 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
163 }
164
165 async fn handle_resource_templates_list(&self) -> McpResult<Value> {
166 let result = ResourceTemplateListResult {
167 resource_templates: ResourceRegistry::list_templates(),
168 next_cursor: None,
169 };
170 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
171 }
172
173 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
174 let read_params: ResourceReadParams = params
175 .map(serde_json::from_value)
176 .transpose()
177 .map_err(|e| McpError::InvalidParams(e.to_string()))?
178 .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
179
180 let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
181
182 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
183 }
184
185 async fn handle_prompts_list(&self) -> McpResult<Value> {
186 let result = PromptListResult {
187 prompts: PromptRegistry::list_prompts(),
188 next_cursor: None,
189 };
190 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
191 }
192
193 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
194 let get_params: PromptGetParams = params
195 .map(serde_json::from_value)
196 .transpose()
197 .map_err(|e| McpError::InvalidParams(e.to_string()))?
198 .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
199
200 let result =
201 PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
202
203 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
204 }
205}