claude_code_acp/mcp/registry.rs

1//! Tool registry for managing MCP tools
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use sacp::JrConnectionCx;
7use sacp::link::AgentToClient;
8use sacp::schema::{
9    SessionId, SessionNotification, SessionUpdate, Terminal, ToolCallContent, ToolCallId,
10    ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields,
11};
12use serde::{Deserialize, Serialize};
13
14use super::tools::Tool;
15use crate::session::BackgroundProcessManager;
16use crate::settings::PermissionChecker;
17use crate::terminal::TerminalClient;
18
19/// Tool execution result
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ToolResult {
22    /// Result status
23    pub status: ToolStatus,
24    /// Output content
25    pub content: String,
26    /// Whether this is an error
27    pub is_error: bool,
28    /// Additional metadata
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub metadata: Option<serde_json::Value>,
31}
32
33impl ToolResult {
34    /// Create a successful result
35    pub fn success(content: impl Into<String>) -> Self {
36        Self {
37            status: ToolStatus::Success,
38            content: content.into(),
39            is_error: false,
40            metadata: None,
41        }
42    }
43
44    /// Create an error result
45    pub fn error(message: impl Into<String>) -> Self {
46        Self {
47            status: ToolStatus::Error,
48            content: message.into(),
49            is_error: true,
50            metadata: None,
51        }
52    }
53
54    /// Create a result with metadata
55    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
56        self.metadata = Some(metadata);
57        self
58    }
59}
60
61/// Tool execution status
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
63#[serde(rename_all = "lowercase")]
64pub enum ToolStatus {
65    /// Tool executed successfully
66    Success,
67    /// Tool execution failed
68    Error,
69    /// Tool execution was cancelled
70    Cancelled,
71    /// Tool is still running (for async operations)
72    Running,
73}
74
75/// Tool execution context
76#[derive(Debug, Clone)]
77pub struct ToolContext {
78    /// Session ID
79    pub session_id: String,
80    /// Working directory
81    pub cwd: std::path::PathBuf,
82    /// Whether to allow dangerous operations
83    pub allow_dangerous: bool,
84    /// Background process manager
85    background_processes: Option<Arc<BackgroundProcessManager>>,
86    /// Terminal client for executing commands via Client PTY
87    terminal_client: Option<Arc<TerminalClient>>,
88    /// Current tool use ID (for sending mid-execution updates)
89    tool_use_id: Option<String>,
90    /// Connection context for sending notifications
91    connection_cx: Option<JrConnectionCx<AgentToClient>>,
92    /// Permission checker for tool-level permission checks
93    pub permission_checker: Option<Arc<tokio::sync::RwLock<PermissionChecker>>>,
94}
95
96impl ToolContext {
97    /// Create a new tool context
98    pub fn new(session_id: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
99        Self {
100            session_id: session_id.into(),
101            cwd: cwd.into(),
102            allow_dangerous: false,
103            background_processes: None,
104            terminal_client: None,
105            tool_use_id: None,
106            connection_cx: None,
107            permission_checker: None,
108        }
109    }
110
111    /// Set whether dangerous operations are allowed
112    pub fn with_dangerous(mut self, allow: bool) -> Self {
113        self.allow_dangerous = allow;
114        self
115    }
116
117    /// Set the background process manager
118    pub fn with_background_processes(mut self, manager: Arc<BackgroundProcessManager>) -> Self {
119        self.background_processes = Some(manager);
120        self
121    }
122
123    /// Set the terminal client
124    pub fn with_terminal_client(mut self, client: Arc<TerminalClient>) -> Self {
125        self.terminal_client = Some(client);
126        self
127    }
128
129    /// Set the current tool use ID
130    pub fn with_tool_use_id(mut self, id: impl Into<String>) -> Self {
131        self.tool_use_id = Some(id.into());
132        self
133    }
134
135    /// Set the connection context for sending notifications
136    pub fn with_connection_cx(mut self, cx: JrConnectionCx<AgentToClient>) -> Self {
137        self.connection_cx = Some(cx);
138        self
139    }
140
141    /// Set the permission checker for tool-level permission checks
142    pub fn with_permission_checker(
143        mut self,
144        checker: Arc<tokio::sync::RwLock<PermissionChecker>>,
145    ) -> Self {
146        self.permission_checker = Some(checker);
147        self
148    }
149
150    /// Get the background process manager
151    pub fn background_processes(&self) -> Option<&Arc<BackgroundProcessManager>> {
152        self.background_processes.as_ref()
153    }
154
155    /// Get the terminal client
156    ///
157    /// When available, tools can use this to execute commands via the Client's PTY
158    /// instead of directly spawning processes.
159    pub fn terminal_client(&self) -> Option<&Arc<TerminalClient>> {
160        self.terminal_client.as_ref()
161    }
162
163    /// Get the current tool use ID
164    pub fn tool_use_id(&self) -> Option<&str> {
165        self.tool_use_id.as_deref()
166    }
167
168    /// Send a ToolCallUpdate notification with Terminal content
169    ///
170    /// This is used by tools like Bash to send terminal ID immediately after
171    /// creating a terminal, so the client can start showing terminal output.
172    ///
173    /// # Arguments
174    ///
175    /// * `terminal_id` - The terminal ID from CreateTerminalResponse
176    /// * `status` - The tool call status (usually InProgress)
177    /// * `title` - Optional title/description for the tool call
178    ///
179    /// # Returns
180    ///
181    /// `Ok(())` if notification was sent, `Err` if context doesn't have connection
182    pub fn send_terminal_update(
183        &self,
184        terminal_id: impl Into<String>,
185        status: ToolCallStatus,
186        title: Option<&str>,
187    ) -> Result<(), String> {
188        let Some(connection_cx) = &self.connection_cx else {
189            return Err("No connection context available".to_string());
190        };
191
192        let Some(tool_use_id) = &self.tool_use_id else {
193            return Err("No tool use ID available".to_string());
194        };
195
196        // Build terminal content
197        let terminal = Terminal::new(terminal_id.into());
198        let content = vec![ToolCallContent::Terminal(terminal)];
199
200        // Build update fields
201        let mut update_fields = ToolCallUpdateFields::new().status(status).content(content);
202
203        if let Some(title) = title {
204            update_fields = update_fields.title(title);
205        }
206
207        // Build and send notification
208        let tool_call_id = ToolCallId::new(tool_use_id.clone());
209        let update = ToolCallUpdate::new(tool_call_id, update_fields);
210        let notification = SessionNotification::new(
211            SessionId::new(self.session_id.as_str()),
212            SessionUpdate::ToolCallUpdate(update),
213        );
214
215        connection_cx
216            .send_notification(notification)
217            .map_err(|e| format!("Failed to send notification: {}", e))
218    }
219}
220
221/// ACP tool prefix for compatibility with TypeScript implementation
222pub const ACP_TOOL_PREFIX: &str = "mcp__acp__";
223
224/// Tool registry for managing available tools
225#[derive(Debug, Default)]
226pub struct ToolRegistry {
227    /// Registered tools by name
228    tools: HashMap<String, Arc<dyn Tool>>,
229}
230
231impl ToolRegistry {
232    /// Create a new empty registry
233    pub fn new() -> Self {
234        Self {
235            tools: HashMap::new(),
236        }
237    }
238
239    /// Register a tool
240    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
241        let name = tool.name().to_string();
242        self.tools.insert(name, Arc::new(tool));
243    }
244
245    /// Register a tool as Arc
246    pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
247        let name = tool.name().to_string();
248        self.tools.insert(name, tool);
249    }
250
251    /// Get a tool by name, supporting ACP prefix
252    ///
253    /// If the tool name starts with `mcp__acp__`, it will try to find
254    /// the tool with the prefix stripped.
255    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
256        // Try direct lookup first
257        if let Some(tool) = self.tools.get(name) {
258            return Some(tool.clone());
259        }
260
261        // Try stripping ACP prefix
262        if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
263            if let Some(tool) = self.tools.get(stripped) {
264                return Some(tool.clone());
265            }
266        }
267
268        None
269    }
270
271    /// Check if a tool exists, supporting ACP prefix
272    pub fn contains(&self, name: &str) -> bool {
273        if self.tools.contains_key(name) {
274            return true;
275        }
276
277        // Try stripping ACP prefix
278        if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
279            return self.tools.contains_key(stripped);
280        }
281
282        false
283    }
284
285    /// Normalize a tool name by stripping ACP prefix if present
286    pub fn normalize_name(name: &str) -> &str {
287        name.strip_prefix(ACP_TOOL_PREFIX).unwrap_or(name)
288    }
289
290    /// Get all tool names
291    pub fn names(&self) -> Vec<&str> {
292        self.tools.keys().map(String::as_str).collect()
293    }
294
295    /// Get the number of registered tools
296    pub fn len(&self) -> usize {
297        self.tools.len()
298    }
299
300    /// Check if the registry is empty
301    pub fn is_empty(&self) -> bool {
302        self.tools.is_empty()
303    }
304
305    /// Execute a tool by name
306    pub async fn execute(
307        &self,
308        name: &str,
309        input: serde_json::Value,
310        context: &ToolContext,
311    ) -> ToolResult {
312        match self.get(name) {
313            Some(tool) => tool.execute(input, context).await,
314            None => ToolResult::error(format!("Tool not found: {}", name)),
315        }
316    }
317
318    /// Get tool schemas for all registered tools
319    pub fn schemas(&self) -> Vec<ToolSchema> {
320        self.tools
321            .values()
322            .map(|tool| ToolSchema {
323                name: tool.name().to_string(),
324                description: tool.description().to_string(),
325                input_schema: tool.input_schema(),
326            })
327            .collect()
328    }
329}
330
331/// Tool schema for registration
332#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct ToolSchema {
334    /// Tool name
335    pub name: String,
336    /// Tool description
337    pub description: String,
338    /// JSON Schema for input
339    pub input_schema: serde_json::Value,
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("Hello, World!");
        assert_eq!(result.status, ToolStatus::Success);
        assert_eq!(result.content, "Hello, World!");
        assert!(!result.is_error);
        assert!(result.metadata.is_none());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("Something went wrong");
        assert_eq!(result.status, ToolStatus::Error);
        assert_eq!(result.content, "Something went wrong");
        assert!(result.is_error);
    }

    #[test]
    fn test_tool_result_with_metadata() {
        let result = ToolResult::success("data").with_metadata(json!({"lines": 10}));
        assert_eq!(result.metadata, Some(json!({"lines": 10})));
        // Attaching metadata must not disturb the rest of the result.
        assert_eq!(result.content, "data");
        assert_eq!(result.status, ToolStatus::Success);
    }

    #[test]
    fn test_tool_result_serialization_omits_missing_metadata() {
        let value = serde_json::to_value(ToolResult::success("ok")).unwrap();
        assert_eq!(
            value,
            json!({"status": "success", "content": "ok", "is_error": false})
        );
    }

    #[test]
    fn test_tool_result_deserialization_defaults_metadata() {
        let parsed: ToolResult = serde_json::from_value(json!({
            "status": "cancelled",
            "content": "",
            "is_error": true
        }))
        .unwrap();
        assert_eq!(parsed.status, ToolStatus::Cancelled);
        assert!(parsed.metadata.is_none());
    }

    #[test]
    fn test_tool_status_serializes_lowercase() {
        assert_eq!(serde_json::to_value(ToolStatus::Running).unwrap(), json!("running"));
        assert_eq!(serde_json::to_value(ToolStatus::Error).unwrap(), json!("error"));
    }

    #[test]
    fn test_tool_context() {
        let ctx = ToolContext::new("session-1", "/tmp").with_dangerous(true);
        assert_eq!(ctx.session_id, "session-1");
        assert_eq!(ctx.cwd, std::path::PathBuf::from("/tmp"));
        assert!(ctx.allow_dangerous);
    }

    #[test]
    fn test_tool_context_defaults_and_tool_use_id() {
        let ctx = ToolContext::new("session-1", "/tmp");
        assert!(!ctx.allow_dangerous);
        assert!(ctx.background_processes().is_none());
        assert!(ctx.terminal_client().is_none());
        assert!(ctx.tool_use_id().is_none());
        assert!(ctx.permission_checker.is_none());

        let ctx = ctx.with_tool_use_id("toolu_1");
        assert_eq!(ctx.tool_use_id(), Some("toolu_1"));
    }

    #[test]
    fn test_send_terminal_update_requires_connection() {
        let ctx = ToolContext::new("session-1", "/tmp").with_tool_use_id("toolu_1");
        let err = ctx
            .send_terminal_update("term-1", ToolCallStatus::InProgress, None)
            .unwrap_err();
        assert_eq!(err, "No connection context available");
    }

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_missing_tool_lookups() {
        let registry = ToolRegistry::new();
        assert!(registry.get("Read").is_none());
        assert!(registry.get("mcp__acp__Read").is_none());
        assert!(!registry.contains("Read"));
        assert!(!registry.contains("mcp__acp__Read"));
        assert!(registry.names().is_empty());
        assert!(registry.schemas().is_empty());
    }

    #[test]
    fn test_acp_prefix_normalize() {
        // Without prefix
        assert_eq!(ToolRegistry::normalize_name("Read"), "Read");
        assert_eq!(ToolRegistry::normalize_name("Bash"), "Bash");

        // With prefix
        assert_eq!(ToolRegistry::normalize_name("mcp__acp__Read"), "Read");
        assert_eq!(ToolRegistry::normalize_name("mcp__acp__Bash"), "Bash");
        assert_eq!(
            ToolRegistry::normalize_name("mcp__acp__TodoWrite"),
            "TodoWrite"
        );

        // Only a leading prefix is stripped.
        assert_eq!(
            ToolRegistry::normalize_name("Readmcp__acp__"),
            "Readmcp__acp__"
        );
    }

    #[test]
    fn test_acp_prefix_constant() {
        assert_eq!(ACP_TOOL_PREFIX, "mcp__acp__");
    }
}