Skip to main content

agentic_vision_mcp/protocol/
handler.rs

1//! Main request dispatcher — receives JSON-RPC messages, routes to handlers.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use serde_json::Value;
8
9use crate::prompts::PromptRegistry;
10use crate::resources::ResourceRegistry;
11use crate::session::VisionSessionManager;
12use crate::tools::ToolRegistry;
13use crate::types::*;
14
15use super::negotiation::NegotiatedCapabilities;
16use super::validator::validate_request;
17
18/// The main protocol handler that dispatches incoming JSON-RPC messages.
19pub struct ProtocolHandler {
20    session: Arc<Mutex<VisionSessionManager>>,
21    capabilities: Arc<Mutex<NegotiatedCapabilities>>,
22    shutdown_requested: Arc<AtomicBool>,
23    /// Tracks whether an auto-session was started so we can auto-end it.
24    auto_session_started: AtomicBool,
25}
26
27impl ProtocolHandler {
28    pub fn new(session: Arc<Mutex<VisionSessionManager>>) -> Self {
29        Self {
30            session,
31            capabilities: Arc::new(Mutex::new(NegotiatedCapabilities::default())),
32            shutdown_requested: Arc::new(AtomicBool::new(false)),
33            auto_session_started: AtomicBool::new(false),
34        }
35    }
36
37    /// Returns true once a shutdown request has been handled.
38    pub fn shutdown_requested(&self) -> bool {
39        self.shutdown_requested.load(Ordering::Relaxed)
40    }
41
42    pub async fn handle_message(&self, msg: JsonRpcMessage) -> Option<Value> {
43        match msg {
44            JsonRpcMessage::Request(req) => Some(self.handle_request(req).await),
45            JsonRpcMessage::Notification(notif) => {
46                self.handle_notification(notif).await;
47                None
48            }
49            _ => {
50                tracing::warn!("Received unexpected message type from client");
51                None
52            }
53        }
54    }
55
56    /// Cleanup on transport close (EOF). Auto-ends session if one was started.
57    pub async fn cleanup(&self) {
58        if !self.auto_session_started.load(Ordering::Relaxed) {
59            return;
60        }
61
62        let mut session = self.session.lock().await;
63        match session.end_session() {
64            Ok(sid) => {
65                tracing::info!("Auto-ended vision session {sid} on EOF");
66            }
67            Err(e) => {
68                tracing::warn!("Failed to auto-end vision session on EOF: {e}");
69                if let Err(save_err) = session.save() {
70                    tracing::error!("Failed to save vision on EOF cleanup: {save_err}");
71                }
72            }
73        }
74        self.auto_session_started.store(false, Ordering::Relaxed);
75    }
76
77    async fn handle_request(&self, request: JsonRpcRequest) -> Value {
78        if let Err(e) = validate_request(&request) {
79            return serde_json::to_value(e.to_json_rpc_error(request.id)).unwrap_or_default();
80        }
81
82        let id = request.id.clone();
83        let result = self.dispatch_request(&request).await;
84
85        match result {
86            Ok(value) => serde_json::to_value(JsonRpcResponse::new(id, value)).unwrap_or_default(),
87            Err(e) => serde_json::to_value(e.to_json_rpc_error(id)).unwrap_or_default(),
88        }
89    }
90
91    async fn dispatch_request(&self, request: &JsonRpcRequest) -> McpResult<Value> {
92        match request.method.as_str() {
93            "initialize" => self.handle_initialize(request.params.clone()).await,
94            "shutdown" => self.handle_shutdown().await,
95
96            "tools/list" => self.handle_tools_list().await,
97            "tools/call" => self.handle_tools_call(request.params.clone()).await,
98
99            "resources/list" => self.handle_resources_list().await,
100            "resources/templates/list" => self.handle_resource_templates_list().await,
101            "resources/read" => self.handle_resources_read(request.params.clone()).await,
102            "resources/subscribe" => Ok(Value::Object(serde_json::Map::new())),
103            "resources/unsubscribe" => Ok(Value::Object(serde_json::Map::new())),
104
105            "prompts/list" => self.handle_prompts_list().await,
106            "prompts/get" => self.handle_prompts_get(request.params.clone()).await,
107
108            "ping" => Ok(Value::Object(serde_json::Map::new())),
109
110            _ => Err(McpError::MethodNotFound(request.method.clone())),
111        }
112    }
113
114    async fn handle_notification(&self, notification: JsonRpcNotification) {
115        match notification.method.as_str() {
116            "initialized" => {
117                let mut caps = self.capabilities.lock().await;
118                if let Err(e) = caps.mark_initialized() {
119                    tracing::error!("Failed to mark initialized: {e}");
120                }
121
122                // Auto-start vision session when client confirms connection.
123                let mut session = self.session.lock().await;
124                match session.start_session(None) {
125                    Ok(sid) => {
126                        self.auto_session_started.store(true, Ordering::Relaxed);
127                        tracing::info!("Auto-started vision session {sid}");
128                    }
129                    Err(e) => {
130                        tracing::error!("Failed to auto-start vision session: {e}");
131                    }
132                }
133            }
134            "notifications/cancelled" | "$/cancelRequest" => {
135                tracing::info!("Received cancellation notification");
136            }
137            _ => {
138                tracing::debug!("Unknown notification: {}", notification.method);
139            }
140        }
141    }
142
143    async fn handle_initialize(&self, params: Option<Value>) -> McpResult<Value> {
144        let init_params: InitializeParams = params
145            .map(serde_json::from_value)
146            .transpose()
147            .map_err(|e| McpError::InvalidParams(e.to_string()))?
148            .ok_or_else(|| McpError::InvalidParams("Initialize params required".to_string()))?;
149
150        let mut caps = self.capabilities.lock().await;
151        let result = caps.negotiate(init_params)?;
152
153        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
154    }
155
156    async fn handle_shutdown(&self) -> McpResult<Value> {
157        tracing::info!("Shutdown requested");
158
159        let mut session = self.session.lock().await;
160
161        // Auto-end vision session if one was auto-started.
162        if self.auto_session_started.swap(false, Ordering::Relaxed) {
163            let sid = session.current_session_id();
164            match session.end_session() {
165                Ok(_) => {
166                    tracing::info!("Auto-ended vision session {sid}");
167                }
168                Err(e) => {
169                    tracing::warn!("Failed to auto-end vision session on shutdown: {e}");
170                    session.save()?;
171                }
172            }
173        } else {
174            session.save()?;
175        }
176
177        self.shutdown_requested.store(true, Ordering::Relaxed);
178        Ok(Value::Object(serde_json::Map::new()))
179    }
180
181    async fn handle_tools_list(&self) -> McpResult<Value> {
182        let result = ToolListResult {
183            tools: ToolRegistry::list_tools(),
184            next_cursor: None,
185        };
186        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
187    }
188
189    async fn handle_tools_call(&self, params: Option<Value>) -> McpResult<Value> {
190        let call_params: ToolCallParams = params
191            .map(serde_json::from_value)
192            .transpose()
193            .map_err(|e| McpError::InvalidParams(e.to_string()))?
194            .ok_or_else(|| McpError::InvalidParams("Tool call params required".to_string()))?;
195
196        let tool_name = call_params.name.clone();
197        let args_summary = call_params
198            .arguments
199            .as_ref()
200            .map(|a| truncate_json_summary(a, 200))
201            .unwrap_or_default();
202
203        // Classify errors: protocol errors (ToolNotFound etc.) become JSON-RPC errors;
204        // tool execution errors (CaptureNotFound, VisionError, etc.) become isError: true.
205        let result =
206            match ToolRegistry::call(&call_params.name, call_params.arguments, &self.session).await
207            {
208                Ok(r) => r,
209                Err(e) if e.is_protocol_error() => return Err(e),
210                Err(e) => ToolCallResult::error(e.to_string()),
211            };
212
213        // Auto-capture tool context into the session log.
214        // Skip logging observation_log calls to avoid recursion.
215        if tool_name != "observation_log" {
216            let now = std::time::SystemTime::now()
217                .duration_since(std::time::UNIX_EPOCH)
218                .unwrap_or_default()
219                .as_secs();
220            let capture_id = extract_capture_id(&result);
221            let record = crate::session::ToolCallRecord {
222                tool_name,
223                summary: args_summary,
224                timestamp: now,
225                capture_id,
226            };
227            let mut session = self.session.lock().await;
228            session.log_tool_call(record);
229        }
230
231        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
232    }
233
234    async fn handle_resources_list(&self) -> McpResult<Value> {
235        let result = ResourceListResult {
236            resources: ResourceRegistry::list_resources(),
237            next_cursor: None,
238        };
239        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
240    }
241
242    async fn handle_resource_templates_list(&self) -> McpResult<Value> {
243        let result = ResourceTemplateListResult {
244            resource_templates: ResourceRegistry::list_templates(),
245            next_cursor: None,
246        };
247        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
248    }
249
250    async fn handle_resources_read(&self, params: Option<Value>) -> McpResult<Value> {
251        let read_params: ResourceReadParams = params
252            .map(serde_json::from_value)
253            .transpose()
254            .map_err(|e| McpError::InvalidParams(e.to_string()))?
255            .ok_or_else(|| McpError::InvalidParams("Resource read params required".to_string()))?;
256
257        let result = ResourceRegistry::read(&read_params.uri, &self.session).await?;
258
259        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
260    }
261
262    async fn handle_prompts_list(&self) -> McpResult<Value> {
263        let result = PromptListResult {
264            prompts: PromptRegistry::list_prompts(),
265            next_cursor: None,
266        };
267        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
268    }
269
270    async fn handle_prompts_get(&self, params: Option<Value>) -> McpResult<Value> {
271        let get_params: PromptGetParams = params
272            .map(serde_json::from_value)
273            .transpose()
274            .map_err(|e| McpError::InvalidParams(e.to_string()))?
275            .ok_or_else(|| McpError::InvalidParams("Prompt get params required".to_string()))?;
276
277        let result = PromptRegistry::get(&get_params.name, get_params.arguments).await?;
278
279        serde_json::to_value(result).map_err(|e| McpError::InternalError(e.to_string()))
280    }
281}
282
283/// Truncate a JSON value to a short summary string.
284fn truncate_json_summary(value: &Value, max_len: usize) -> String {
285    let s = value.to_string();
286    if s.len() <= max_len {
287        s
288    } else {
289        format!("{}…", &s[..max_len])
290    }
291}
292
293/// Try to extract a capture_id from a tool call result.
294fn extract_capture_id(result: &crate::types::ToolCallResult) -> Option<u64> {
295    for content in &result.content {
296        if let crate::types::ToolContent::Text { text } = content {
297            if let Ok(v) = serde_json::from_str::<Value>(text) {
298                if let Some(id) = v.get("capture_id").and_then(|v| v.as_u64()) {
299                    return Some(id);
300                }
301            }
302        }
303    }
304    None
305}