agentic_memory_mcp/protocol/
handler.rs1use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde_json::Value;
7
8use crate::prompts::PromptRegistry;
9use crate::resources::ResourceRegistry;
10use crate::session::SessionManager;
11use crate::tools::ToolRegistry;
12use crate::types::*;
13
14use super::negotiation::NegotiatedCapabilities;
15use super::validator::validate_request;
16
17pub struct ProtocolHandler {
19 session: Arc<Mutex<SessionManager>>,
20 capabilities: Arc<Mutex<NegotiatedCapabilities>>,
21}
22
23impl ProtocolHandler {
24 pub fn new(session: Arc<Mutex<SessionManager>>) -> Self {
26 Self {
27 session,
28 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
29 }
30 }
31
32 pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
34 match msg {
35 JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
36 JsonRpcMessage::Notification(notif) => {
37 self.handle_notification(notif).await;
38 None
39 }
40 _ => {
41 tracing::warn!("Received unexpected message type from client");
43 None
44 }
45 }
46 }
47
48 async fn handle_request(&self, request: JsonRpcRequest) -> Value {
49 if let Err(e) = validate_request(&request) {
51 return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
52 }
53
54 let id = request.id.clone();
55 let result = self.dispatch_request(&request).await;
56
57 match result {
58 Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
59 Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
60 }
61 }
62
63 async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
64 match request.method.as_str() {
65 "initialize" => self.handle_initialize(request.params.clone()).await,
67 "shutdown" => self.handle_shutdown().await,
68
69 "tools/list" => self.handle_tools_list().await,
71 "tools/call" => self.handle_tools_call(request.params.clone()).await,
72
73 "resources/list" => self.handle_resources_list().await,
75 "resources/templates/list" => self.handle_resource_templates_list().await,
76 "resources/read" => self.handle_resources_read(request.params.clone()).await,
77 "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
78 "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
79
80 "prompts/list" => self.handle_prompts_list().await,
82 "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
83
84 "ping" => Ok(Value::Object(serde_json::Map::new())),
86
87 _ => Err(McpError::MethodNotFound(request.method.clone())),
88 }
89 }
90
91 async fn handle_notification(&self, notification: JsonRpcNotification) {
92 match notification.method.as_str() {
93 "initialized" => {
94 let mut caps = self.capabilities.lock().await;
95 if let Err(e) = caps.mark_initialized() {
96 tracing::error!("Failed to mark initialized: {e}");
97 }
98 }
99 "notifications/cancelled" | "$/cancelRequest" => {
100 tracing::info!("Received cancellation notification");
101 }
102 _ => {
103 tracing::debug!("Unknown notification: {}", notification.method);
104 }
105 }
106 }
107
108 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
109 let init_params: InitializeParams = params
110 .map(serde_json::from_value)
111 .transpose()
112 .map_err(|e| McpError::InvalidParams(e.to_string()))?
113 .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
114
115 let mut caps = self.capabilities.lock().await;
116 let result = caps.negotiate(init_params)?;
117
118 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
119 }
120
121 async fn handle_shutdown(&self) -> McpResult<Value> {
122 tracing::info!("Shutdown requested");
123 let mut session = self.session.lock().await;
124 session.save()?;
125 Ok(Value::Object(serde_json::Map::new()))
126 }
127
128 async fn handle_tools_list(&self) -> McpResult<Value> {
129 let result = ToolListResult {
130 tools: ToolRegistry::list_tools(),
131 next_cursor: None,
132 };
133 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
134 }
135
136 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
137 let call_params: ToolCallParams = params
138 .map(serde_json::from_value)
139 .transpose()
140 .map_err(|e| McpError::InvalidParams(e.to_string()))?
141 .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
142
143 let result =
144 ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await?;
145
146 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
147 }
148
149 async fn handle_resources_list(&self) -> McpResult<Value> {
150 let result = ResourceListResult {
151 resources: ResourceRegistry::list_resources(),
152 next_cursor: None,
153 };
154 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
155 }
156
157 async fn handle_resource_templates_list(&self) -> McpResult<Value> {
158 let result = ResourceTemplateListResult {
159 resource_templates: ResourceRegistry::list_templates(),
160 next_cursor: None,
161 };
162 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
163 }
164
165 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
166 let read_params: ResourceReadParams = params
167 .map(serde_json::from_value)
168 .transpose()
169 .map_err(|e| McpError::InvalidParams(e.to_string()))?
170 .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
171
172 let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
173
174 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
175 }
176
177 async fn handle_prompts_list(&self) -> McpResult<Value> {
178 let result = PromptListResult {
179 prompts: PromptRegistry::list_prompts(),
180 next_cursor: None,
181 };
182 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
183 }
184
185 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
186 let get_params: PromptGetParams = params
187 .map(serde_json::from_value)
188 .transpose()
189 .map_err(|e| McpError::InvalidParams(e.to_string()))?
190 .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
191
192 let result =
193 PromptRegistry::get(&get_params.name, get_params.arguments, &self.session).await?;
194
195 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
196 }
197}