// agentic_vision_mcp/protocol/handler.rs
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use serde_json::Value;
use tokio::sync::Mutex;

use crate::prompts::PromptRegistry;
use crate::resources::ResourceRegistry;
use crate::session::VisionSessionManager;
use crate::tools::ToolRegistry;
use crate::types::*;

use super::negotiation::NegotiatedCapabilities;
use super::validator::validate_request;
18pub struct ProtocolHandler {
20 session: Arc<Mutex<VisionSessionManager>>,
21 capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22 shutdown_requested: Arc<AtomicBool>,
23 auto_session_started: AtomicBool,
25}
26
27impl ProtocolHandler {
28 pub fn new(session: Arc<Mutex<VisionSessionManager>>) -> Self {
29 Self {
30 session,
31 capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
32 shutdown_requested: Arc::new(AtomicBool::new(false)),
33 auto_session_started: AtomicBool::new(false),
34 }
35 }
36
37 pub fn shutdown_requested(&self) -> bool {
39 self.shutdown_requested.load(Ordering::Relaxed)
40 }
41
42 pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
43 match msg {
44 JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
45 JsonRpcMessage::Notification(notif) => {
46 self.handle_notification(notif).await;
47 None
48 }
49 _ => {
50 tracing::warn!("Received unexpected message type from client");
51 None
52 }
53 }
54 }
55
56 pub async fn cleanup(&self) {
58 if !self.auto_session_started.load(Ordering::Relaxed) {
59 return;
60 }
61
62 let mut session = self.session.lock().await;
63 match session.end_session() {
64 Ok(sid) => {
65 tracing::info!("Auto-ended vision session {sid} on EOF");
66 }
67 Err(e) => {
68 tracing::warn!("Failed to auto-end vision session on EOF: {e}");
69 if let Err(save_err) = session.save() {
70 tracing::error!("Failed to save vision on EOF cleanup: {save_err}");
71 }
72 }
73 }
74 self.auto_session_started.store(false, Ordering::Relaxed);
75 }
76
77 async fn handle_request(&self, request: JsonRpcRequest) -> Value {
78 if let Err(e) = validate_request(&request) {
79 return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
80 }
81
82 let id = request.id.clone();
83 let result = self.dispatch_request(&request).await;
84
85 match result {
86 Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
87 Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
88 }
89 }
90
91 async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
92 match request.method.as_str() {
93 "initialize" => self.handle_initialize(request.params.clone()).await,
94 "shutdown" => self.handle_shutdown().await,
95
96 "tools/list" => self.handle_tools_list().await,
97 "tools/call" => self.handle_tools_call(request.params.clone()).await,
98
99 "resources/list" => self.handle_resources_list().await,
100 "resources/templates/list" => self.handle_resource_templates_list().await,
101 "resources/read" => self.handle_resources_read(request.params.clone()).await,
102 "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
103 "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
104
105 "prompts/list" => self.handle_prompts_list().await,
106 "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
107
108 "ping" => Ok(Value::Object(serde_json::Map::new())),
109
110 _ => Err(McpError::MethodNotFound(request.method.clone())),
111 }
112 }
113
114 async fn handle_notification(&self, notification: JsonRpcNotification) {
115 match notification.method.as_str() {
116 "initialized" => {
117 let mut caps = self.capabilities.lock().await;
118 if let Err(e) = caps.mark_initialized() {
119 tracing::error!("Failed to mark initialized: {e}");
120 }
121
122 let mut session = self.session.lock().await;
124 match session.start_session(None) {
125 Ok(sid) => {
126 self.auto_session_started.store(true, Ordering::Relaxed);
127 tracing::info!("Auto-started vision session {sid}");
128 }
129 Err(e) => {
130 tracing::error!("Failed to auto-start vision session: {e}");
131 }
132 }
133 }
134 "notifications/cancelled" | "$/cancelRequest" => {
135 tracing::info!("Received cancellation notification");
136 }
137 _ => {
138 tracing::debug!("Unknown notification: {}", notification.method);
139 }
140 }
141 }
142
143 async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
144 let init_params: InitializeParams = params
145 .map(serde_json::from_value)
146 .transpose()
147 .map_err(|e| McpError::InvalidParams(e.to_string()))?
148 .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
149
150 let mut caps = self.capabilities.lock().await;
151 let result = caps.negotiate(init_params)?;
152
153 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
154 }
155
156 async fn handle_shutdown(&self) -> McpResult<Value> {
157 tracing::info!("Shutdown requested");
158
159 let mut session = self.session.lock().await;
160
161 if self.auto_session_started.swap(false, Ordering::Relaxed) {
163 let sid = session.current_session_id();
164 match session.end_session() {
165 Ok(_) => {
166 tracing::info!("Auto-ended vision session {sid}");
167 }
168 Err(e) => {
169 tracing::warn!("Failed to auto-end vision session on shutdown: {e}");
170 session.save()?;
171 }
172 }
173 } else {
174 session.save()?;
175 }
176
177 self.shutdown_requested.store(true, Ordering::Relaxed);
178 Ok(Value::Object(serde_json::Map::new()))
179 }
180
181 async fn handle_tools_list(&self) -> McpResult<Value> {
182 let result = ToolListResult {
183 tools: ToolRegistry::list_tools(),
184 next_cursor: None,
185 };
186 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
187 }
188
189 async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
190 let call_params: ToolCallParams = params
191 .map(serde_json::from_value)
192 .transpose()
193 .map_err(|e| McpError::InvalidParams(e.to_string()))?
194 .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
195
196 let tool_name = call_params.name.clone();
197 let args_summary = call_params
198 .arguments
199 .as_ref()
200 .map(|a| truncate_json_summary(a, 200))
201 .unwrap_or_default();
202
203 let result =
206 match ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await
207 {
208 Ok(r) => r,
209 Err(e) if e.is_protocol_error() => return Err(e),
210 Err(e) => ToolCallResult::error(e.to_string()),
211 };
212
213 if tool_name != "observation_log" {
216 let now = std::time::SystemTime::now()
217 .duration_since(std::time::UNIX_EPOCH)
218 .unwrap_or_default()
219 .as_secs();
220 let capture_id = extract_capture_id(&result);
221 let record = crate::session::ToolCallRecord {
222 tool_name,
223 summary: args_summary,
224 timestamp: now,
225 capture_id,
226 };
227 let mut session = self.session.lock().await;
228 session.log_tool_call(record);
229 }
230
231 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
232 }
233
234 async fn handle_resources_list(&self) -> McpResult<Value> {
235 let result = ResourceListResult {
236 resources: ResourceRegistry::list_resources(),
237 next_cursor: None,
238 };
239 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
240 }
241
242 async fn handle_resource_templates_list(&self) -> McpResult<Value> {
243 let result = ResourceTemplateListResult {
244 resource_templates: ResourceRegistry::list_templates(),
245 next_cursor: None,
246 };
247 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
248 }
249
250 async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
251 let read_params: ResourceReadParams = params
252 .map(serde_json::from_value)
253 .transpose()
254 .map_err(|e| McpError::InvalidParams(e.to_string()))?
255 .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
256
257 let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
258
259 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
260 }
261
262 async fn handle_prompts_list(&self) -> McpResult<Value> {
263 let result = PromptListResult {
264 prompts: PromptRegistry::list_prompts(),
265 next_cursor: None,
266 };
267 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
268 }
269
270 async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
271 let get_params: PromptGetParams = params
272 .map(serde_json::from_value)
273 .transpose()
274 .map_err(|e| McpError::InvalidParams(e.to_string()))?
275 .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
276
277 let result = PromptRegistry::get(&get_params.name, get_params.arguments).await?;
278
279 serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
280 }
281}
282
283fn truncate_json_summary(value: &Value, max_len: usize) -> String {
285 let s = value.to_string();
286 if s.len() <= max_len {
287 s
288 } else {
289 format!("{}…", &s[..max_len])
290 }
291}
292
293fn extract_capture_id(result: &crate::types::ToolCallResult) -> Option<u64> {
295 for content in &result.content {
296 if let crate::types::ToolContent::Text { text } = content {
297 if let Ok(v) = serde_json::from_str::<Value>(text) {
298 if let Some(id) = v.get("capture_id").and_then(|v| v.as_u64()) {
299 return Some(id);
300 }
301 }
302 }
303 }
304 None
305}