Skip to main content

matrixcode_core/matrixrpc/callback/
tool.rs

//! Tool Callback Handler
//!
//! Handles tool callback requests from external services, allowing
//! external nodes to invoke MatrixCode's registered tools.
5
6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::security::SecurityValidator;
12use crate::matrixrpc::{
13    ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId, ToolRouter,
14};
15
16/// Tool callback request
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ToolCallbackRequest {
19    /// Request ID from original node execution
20    pub request_id: String,
21
22    /// Service ID making the callback
23    pub service_id: ServiceId,
24
25    /// Security token
26    pub token: String,
27
28    /// Tool name to execute
29    pub tool_name: String,
30
31    /// Tool parameters
32    #[serde(default)]
33    pub params: JsonValue,
34
35    /// Timeout in milliseconds
36    #[serde(default = "default_tool_timeout")]
37    pub timeout_ms: u64,
38
39    /// Whether to require approval
40    #[serde(default)]
41    pub require_approval: bool,
42}
43
/// Serde fallback for `ToolCallbackRequest::timeout_ms`: thirty seconds, in milliseconds.
fn default_tool_timeout() -> u64 {
    const SECONDS: u64 = 30;
    SECONDS * 1_000
}
47
48/// Tool callback result
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ToolCallbackResult {
51    /// Tool name that was executed
52    pub tool_name: String,
53
54    /// Result data
55    pub result: JsonValue,
56
57    /// Execution status
58    pub status: String,
59
60    /// Duration in milliseconds
61    pub duration_ms: u64,
62
63    /// Whether approval was required
64    pub approval_required: bool,
65
66    /// Additional metadata
67    #[serde(default)]
68    pub metadata: JsonValue,
69}
70
71/// Tool callback error
72#[derive(Debug, thiserror::Error)]
73pub enum ToolCallbackError {
74    /// Security validation failed
75    #[error("Security validation failed: {0}")]
76    SecurityFailed(String),
77
78    /// Tool not found
79    #[error("Tool '{0}' not found")]
80    ToolNotFound(String),
81
82    /// Tool execution failed
83    #[error("Tool '{tool}' execution failed: {reason}")]
84    ExecutionFailed { tool: String, reason: String },
85
86    /// Invalid parameters
87    #[error("Invalid parameters for tool '{tool}': {reason}")]
88    InvalidParams { tool: String, reason: String },
89
90    /// Timeout
91    #[error("Tool '{0}' timed out after {1}ms")]
92    Timeout(String, u64),
93
94    /// Tool not allowed
95    #[error("Tool '{0}' is not allowed for callback")]
96    ToolNotAllowed(String),
97
98    /// Internal error
99    #[error("Internal error: {0}")]
100    Internal(String),
101}
102
/// Policy describing which tools external services may invoke through a callback.
///
/// Entries are matched against tool names exactly.
#[derive(Debug, Clone)]
pub struct AllowedToolsConfig {
    /// Tool names permitted without approval.
    pub always_allowed: Vec<String>,
    /// Tool names permitted only with approval.
    pub requires_approval: Vec<String>,
    /// Tool names that are always rejected for callbacks.
    pub never_allowed: Vec<String>,
    /// When `true`, tools that are not explicitly listed are permitted as well
    /// (entries in `never_allowed` are still rejected). Dangerous.
    pub allow_all: bool,
}

impl Default for AllowedToolsConfig {
    /// Conservative defaults: read-only and code-graph query tools pass freely, write tools
    /// (plus `tool_search`) need approval, and dangerous tools are blocked outright.
    fn default() -> Self {
        // Turns a list of literals into owned tool names, preserving order.
        fn owned(names: &[&str]) -> Vec<String> {
            names.iter().map(|name| (*name).to_owned()).collect()
        }

        Self {
            always_allowed: owned(&[
                "read",
                "grep",
                "glob",
                "codegraph_search",
                "codegraph_node",
                "codegraph_context",
                "codegraph_callers",
                "codegraph_callees",
            ]),
            requires_approval: owned(&["write", "edit", "bash", "tool_search"]),
            never_allowed: owned(&["delete", "rm", "format", "sudo"]),
            allow_all: false,
        }
    }
}
143
144/// Tool Callback Handler
145///
146/// Handles tool callback requests from external extension services.
147pub struct ToolCallbackHandler {
148    /// Security validator
149    security: Arc<SecurityValidator>,
150
151    /// Tool router for routing tool calls
152    tool_router: Arc<ToolRouter>,
153
154    /// Allowed tools configuration
155    allowed_tools: AllowedToolsConfig,
156
157    /// Default timeout
158    default_timeout_ms: u64,
159}
160
161impl ToolCallbackHandler {
162    /// Create a new tool callback handler
163    pub fn new(security: Arc<SecurityValidator>, tool_router: Arc<ToolRouter>) -> Self {
164        Self {
165            security,
166            tool_router,
167            allowed_tools: AllowedToolsConfig::default(),
168            default_timeout_ms: 30_000,
169        }
170    }
171
172    /// Set allowed tools configuration
173    pub fn with_allowed_tools(mut self, config: AllowedToolsConfig) -> Self {
174        self.allowed_tools = config;
175        self
176    }
177
178    /// Set default timeout
179    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
180        self.default_timeout_ms = timeout_ms;
181        self
182    }
183
184    /// Handle a tool callback request
185    pub async fn handle(&self, request: ToolCallbackRequest) -> Result<ToolCallbackResult, ToolCallbackError> {
186        // Validate security
187        let validation = self
188            .security
189            .validate(&request.token, &request.service_id, &request.request_id, "tool")
190            .await;
191
192        if !validation.is_valid {
193            return Err(ToolCallbackError::SecurityFailed(
194                validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
195            ));
196        }
197
198        // Check if tool is allowed
199        let (approval_required, allowed) = self.check_tool_allowed(&request.tool_name);
200
201        if !allowed {
202            return Err(ToolCallbackError::ToolNotAllowed(request.tool_name));
203        }
204
205        // Route the tool
206        let route_result = self
207            .tool_router
208            .route(
209                &request.tool_name,
210                request.params.clone(),
211                JsonRpcId::generate(),
212            )
213            .await
214            .map_err(|e| match e {
215                crate::matrixrpc::ToolRouterError::ToolNotFound(tool) => {
216                    ToolCallbackError::ToolNotFound(tool)
217                }
218                _ => ToolCallbackError::Internal(e.to_string()),
219            })?;
220
221        // Build result
222        // Note: In real implementation, this would execute the actual tool
223        // Here we provide a mock implementation
224        let result = ToolCallbackResult {
225            tool_name: request.tool_name.clone(),
226            result: serde_json::json!({
227                "status": "success",
228                "message": format!("Tool '{}' executed successfully", request.tool_name),
229            }),
230            status: "success".to_string(),
231            duration_ms: 100,
232            approval_required,
233            metadata: serde_json::json!({
234                "request_id": request.request_id,
235                "service_id": request.service_id.to_string(),
236                "route": {
237                    "service_id": route_result.service_id.to_string(),
238                },
239            }),
240        };
241
242        Ok(result)
243    }
244
245    /// Check if a tool is allowed for callback
246    fn check_tool_allowed(&self, tool_name: &str) -> (bool, bool) {
247        // Check if never allowed
248        if self.allowed_tools.never_allowed.contains(&tool_name.to_string()) {
249            return (false, false);
250        }
251
252        // Check if allow all
253        if self.allowed_tools.allow_all {
254            return (false, true);
255        }
256
257        // Check if always allowed
258        if self.allowed_tools.always_allowed.contains(&tool_name.to_string()) {
259            return (false, true);
260        }
261
262        // Check if requires approval
263        if self.allowed_tools.requires_approval.contains(&tool_name.to_string()) {
264            return (true, true);
265        }
266
267        // Default: not allowed
268        (false, false)
269    }
270
271    /// Create a JSON-RPC error response for tool callback failures
272    pub fn create_error_response(&self, error: ToolCallbackError, id: JsonRpcId) -> JsonRpcResponse {
273        let (code, message, data) = match error {
274            ToolCallbackError::SecurityFailed(msg) => (
275                ErrorCode::PERMISSION_DENIED,
276                "Security validation failed".to_string(),
277                Some(serde_json::json!({ "reason": msg })),
278            ),
279            ToolCallbackError::ToolNotFound(tool) => (
280                ErrorCode::RESOURCE_NOT_FOUND,
281                format!("Tool '{}' not found", tool),
282                None,
283            ),
284            ToolCallbackError::ExecutionFailed { tool, reason } => (
285                ErrorCode::INTERNAL_ERROR,
286                "Tool execution failed".to_string(),
287                Some(serde_json::json!({ "tool": tool, "reason": reason })),
288            ),
289            ToolCallbackError::InvalidParams { tool, reason } => (
290                ErrorCode::INVALID_PARAMS,
291                "Invalid tool parameters".to_string(),
292                Some(serde_json::json!({ "tool": tool, "reason": reason })),
293            ),
294            ToolCallbackError::Timeout(tool, ms) => (
295                ErrorCode::TIMEOUT_ERROR,
296                "Tool timed out".to_string(),
297                Some(serde_json::json!({ "tool": tool, "timeout_ms": ms })),
298            ),
299            ToolCallbackError::ToolNotAllowed(tool) => (
300                ErrorCode::PERMISSION_DENIED,
301                format!("Tool '{}' is not allowed for callback", tool),
302                None,
303            ),
304            ToolCallbackError::Internal(msg) => (
305                ErrorCode::INTERNAL_ERROR,
306                msg,
307                None,
308            ),
309        };
310
311        JsonRpcResponse::error(
312            id,
313            JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
314        )
315    }
316
317    /// List allowed tools
318    pub fn list_allowed_tools(&self) -> Vec<String> {
319        let mut tools = self.allowed_tools.always_allowed.clone();
320        tools.extend(self.allowed_tools.requires_approval.clone());
321        tools
322    }
323
324    /// Check if tool exists in router
325    pub async fn tool_exists(&self, tool_name: &str) -> bool {
326        self.tool_router.has_tool(tool_name).await
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::RegistryService;

    /// Fresh security validator plus a tool router over an empty registry.
    fn fixtures() -> (Arc<SecurityValidator>, Arc<ToolRouter>) {
        let security = Arc::new(SecurityValidator::new());
        let registry = Arc::new(RegistryService::new());
        let router = Arc::new(ToolRouter::new(registry));
        (security, router)
    }

    /// Handler wired to fresh fixtures with the default policy.
    fn default_handler() -> ToolCallbackHandler {
        let (security, router) = fixtures();
        ToolCallbackHandler::new(security, router)
    }

    /// Request for `tool` with empty params, a 30 s timeout and no approval requested.
    fn request_for(
        request_id: String,
        service_id: ServiceId,
        token: String,
        tool: &str,
    ) -> ToolCallbackRequest {
        ToolCallbackRequest {
            request_id,
            service_id,
            token,
            tool_name: tool.to_string(),
            params: serde_json::json!({}),
            timeout_ms: 30_000,
            require_approval: false,
        }
    }

    #[tokio::test]
    async fn test_tool_callback_handler_creation() {
        assert_eq!(default_handler().default_timeout_ms, 30_000);
    }

    #[test]
    fn test_allowed_tools_config_default() {
        let cfg = AllowedToolsConfig::default();

        assert!(cfg.always_allowed.iter().any(|t| t == "read"));
        assert!(cfg.requires_approval.iter().any(|t| t == "write"));
        assert!(cfg.never_allowed.iter().any(|t| t == "delete"));
        assert!(!cfg.allow_all);
    }

    #[test]
    fn test_check_tool_allowed() {
        let handler = default_handler();
        // (tool, expected approval_required, expected allowed)
        let cases: [(&str, bool, bool); 4] = [
            ("read", false, true),
            ("write", true, true),
            ("delete", false, false),
            ("unknown", false, false),
        ];

        for &(tool, approval, allowed) in cases.iter() {
            assert_eq!(
                handler.check_tool_allowed(tool),
                (approval, allowed),
                "tool: {}",
                tool
            );
        }
    }

    #[tokio::test]
    async fn test_tool_callback_security_validation() {
        let (security, tool_router) = fixtures();

        let owner = ServiceId::new("test-service");
        tool_router
            .register_tool(
                owner.clone(),
                crate::matrixrpc::ToolDefinition {
                    name: "read".to_string(),
                    service_id: owner,
                    description: None,
                    risk_level: None,
                    timeout_ms: None,
                },
            )
            .await;

        let handler = ToolCallbackHandler::new(Arc::clone(&security), tool_router);

        let caller = ServiceId::new("callback-service");
        let request_id = String::from("req-001");
        let token = security
            .generate_token(caller.clone(), request_id.clone(), vec!["tool".to_string()])
            .await
            .unwrap();

        let outcome = handler
            .handle(request_for(request_id, caller, token, "read"))
            .await;
        // Security and policy must pass; routing may still miss the tool.
        assert!(outcome.is_ok() || matches!(outcome, Err(ToolCallbackError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_callback_invalid_token() {
        let handler = default_handler();
        let request = request_for(
            "req-001".to_string(),
            ServiceId::new("test-service"),
            "invalid_token".to_string(),
            "read",
        );

        let outcome = handler.handle(request).await;
        assert!(matches!(outcome, Err(ToolCallbackError::SecurityFailed(_))));
    }

    #[test]
    fn test_list_allowed_tools() {
        let tools = default_handler().list_allowed_tools();

        for name in ["read", "write"].iter().copied() {
            assert!(tools.iter().any(|t| t == name), "missing tool: {}", name);
        }
    }
}