Skip to main content

matrixcode_core/matrixrpc/router/tool_router.rs

1//! Tool Router Implementation
2//!
3//! Routes tool execution requests to the appropriate extension service.
4//! Maintains a mapping from tool names to service IDs for efficient routing.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use serde_json::Value as JsonValue;
11
12use crate::matrixrpc::{
13    ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
14    RegistryService, ServiceId, ServiceStatus,
15};
16
17/// Error type for tool routing operations
18#[derive(Debug, thiserror::Error)]
19pub enum ToolRouterError {
20    /// Tool not found in any registered service
21    #[error("Tool '{0}' not found in any registered service")]
22    ToolNotFound(String),
23
24    /// Service that provides the tool is not running
25    #[error("Service '{service_id}' for tool '{tool_name}' is not running (status: {status:?})")]
26    ServiceNotRunning {
27        tool_name: String,
28        service_id: ServiceId,
29        status: ServiceStatus,
30    },
31
32    /// No services registered
33    #[error("No services registered in the registry")]
34    NoServicesRegistered,
35
36    /// Routing failed
37    #[error("Routing failed: {0}")]
38    RoutingFailed(String),
39
40    /// Invalid tool parameters
41    #[error("Invalid parameters for tool '{tool}': {reason}")]
42    InvalidParams { tool: String, reason: String },
43
44    /// Internal error
45    #[error("Internal error: {0}")]
46    Internal(String),
47}
48
49/// Result of a tool routing operation
50#[derive(Debug, Clone)]
51pub struct ToolRouteResult {
52    /// The service ID that should handle the tool call
53    pub service_id: ServiceId,
54    /// The tool name (may be different if aliased)
55    pub tool_name: String,
56    /// The parameters to pass to the tool
57    pub params: JsonValue,
58    /// The original request ID for correlation
59    pub request_id: JsonRpcId,
60}
61
62/// Tool definition from an extension service
63#[derive(Debug, Clone)]
64pub struct ToolDefinition {
65    /// Tool name
66    pub name: String,
67    /// Service ID that provides this tool
68    pub service_id: ServiceId,
69    /// Tool description
70    pub description: Option<String>,
71    /// Risk level (safe, moderate, dangerous)
72    pub risk_level: Option<String>,
73    /// Timeout in milliseconds
74    pub timeout_ms: Option<u64>,
75}
76
77/// Tool Router
78///
79/// Routes tool execution requests to the appropriate extension service.
80/// Uses the registry service to discover which services provide each tool.
81#[derive(Debug)]
82pub struct ToolRouter {
83    /// Reference to the registry service
84    registry: Arc<RegistryService>,
85    /// Tool name to service ID mapping (cached)
86    tool_index: Arc<RwLock<HashMap<String, ToolDefinition>>>,
87    /// Default timeout for tool execution (milliseconds)
88    default_timeout_ms: u64,
89}
90
91impl ToolRouter {
92    /// Create a new tool router with a registry service
93    pub fn new(registry: Arc<RegistryService>) -> Self {
94        Self {
95            registry,
96            tool_index: Arc::new(RwLock::new(HashMap::new())),
97            default_timeout_ms: 30_000, // 30 seconds default
98        }
99    }
100
101    /// Set default timeout for tool execution
102    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
103        self.default_timeout_ms = timeout_ms;
104        self
105    }
106
107    /// Register a tool from a service
108    ///
109    /// This updates the internal tool index for fast routing.
110    pub async fn register_tool(&self, _service_id: ServiceId, tool_def: ToolDefinition) {
111        let mut index = self.tool_index.write().await;
112        index.insert(tool_def.name.clone(), tool_def);
113    }
114
115    /// Unregister all tools from a service
116    ///
117    /// Called when a service is unregistered or disconnected.
118    pub async fn unregister_service_tools(&self, service_id: &ServiceId) {
119        let mut index = self.tool_index.write().await;
120        index.retain(|_, def| def.service_id != *service_id);
121    }
122
123    /// Rebuild the tool index from the registry
124    ///
125    /// Useful after bulk registration changes.
126    pub async fn rebuild_index(&self) -> Result<(), ToolRouterError> {
127        let services = self.registry.list_all().await;
128        let mut index = self.tool_index.write().await;
129        index.clear();
130
131        for service in services {
132            if service.status != ServiceStatus::Running {
133                continue;
134            }
135
136            // Extract tools from service capabilities
137            for cap in &service.capabilities {
138                if cap.name == "tools" {
139                    // Parse tools from capability config
140                    if let Some(tools_json) = cap.config.get("tools") {
141                        if let Ok(tools) = serde_json::from_value::<Vec<JsonValue>>(tools_json.clone()) {
142                            for tool in tools {
143                                if let Some(name) = tool.get("name").and_then(|n| n.as_str()) {
144                                    let def = ToolDefinition {
145                                        name: name.to_string(),
146                                        service_id: service.id.clone(),
147                                        description: tool.get("description").and_then(|d| d.as_str()).map(|s| s.to_string()),
148                                        risk_level: tool.get("risk_level").and_then(|r| r.as_str()).map(|s| s.to_string()),
149                                        timeout_ms: tool.get("timeout_ms").and_then(|t| t.as_u64()),
150                                    };
151                                    index.insert(name.to_string(), def);
152                                }
153                            }
154                        }
155                    }
156                }
157            }
158        }
159
160        Ok(())
161    }
162
163    /// Route a tool execution request
164    ///
165    /// Given a tool name and parameters, find the appropriate service
166    /// and return routing information.
167    pub async fn route(
168        &self,
169        tool_name: &str,
170        params: JsonValue,
171        request_id: JsonRpcId,
172    ) -> Result<ToolRouteResult, ToolRouterError> {
173        // Look up the tool in the index
174        let index = self.tool_index.read().await;
175        let tool_def = index
176            .get(tool_name)
177            .cloned()
178            .ok_or_else(|| ToolRouterError::ToolNotFound(tool_name.to_string()))?;
179
180        // Check if the service is running
181        let service = self.registry.get(&tool_def.service_id).await;
182        match service {
183            Some(s) if s.status == ServiceStatus::Running => {
184                // Service is healthy, proceed with routing
185                Ok(ToolRouteResult {
186                    service_id: tool_def.service_id,
187                    tool_name: tool_def.name,
188                    params,
189                    request_id,
190                })
191            }
192            Some(s) => {
193                // Service exists but not running
194                Err(ToolRouterError::ServiceNotRunning {
195                    tool_name: tool_name.to_string(),
196                    service_id: tool_def.service_id,
197                    status: s.status,
198                })
199            }
200            None => {
201                // Service not found in registry (shouldn't happen if index is valid)
202                Err(ToolRouterError::ToolNotFound(tool_name.to_string()))
203            }
204        }
205    }
206
207    /// Check if a tool is available
208    pub async fn has_tool(&self, tool_name: &str) -> bool {
209        let index = self.tool_index.read().await;
210        index.contains_key(tool_name)
211    }
212
213    /// List all available tools
214    pub async fn list_tools(&self) -> Vec<ToolDefinition> {
215        let index = self.tool_index.read().await;
216        index.values().cloned().collect()
217    }
218
219    /// Get tool definition
220    pub async fn get_tool(&self, tool_name: &str) -> Option<ToolDefinition> {
221        let index = self.tool_index.read().await;
222        index.get(tool_name).cloned()
223    }
224
225    /// Create a JSON-RPC request for tool execution
226    ///
227    /// Creates the proper request format for the extension service.
228    pub fn create_tool_request(&self, route_result: ToolRouteResult) -> JsonRpcRequest {
229        JsonRpcRequest::with_id("tool.execute", route_result.request_id)
230            .params(serde_json::json!({
231                "tool_name": route_result.tool_name,
232                "params": route_result.params
233            }))
234    }
235
236    /// Create an error response for routing failures
237    pub async fn create_error_response(
238        &self,
239        error: ToolRouterError,
240        request_id: JsonRpcId,
241    ) -> JsonRpcResponse {
242        let (code, message, data) = match error {
243            ToolRouterError::ToolNotFound(tool) => {
244                let index = self.tool_index.read().await;
245                let available: Vec<String> = index.keys().cloned().collect();
246                (
247                    ErrorCode::RESOURCE_NOT_FOUND,
248                    format!("Tool '{}' not found", tool),
249                    Some(serde_json::json!({ "available_tools": available })),
250                )
251            }
252            ToolRouterError::ServiceNotRunning { tool_name, service_id, status } => (
253                ErrorCode::INVALID_STATE,
254                format!("Service '{}' is not running", service_id),
255                Some(serde_json::json!({
256                    "tool_name": tool_name,
257                    "service_id": service_id.to_string(),
258                    "status": serde_json::to_string(&status).unwrap_or_default()
259                })),
260            ),
261            ToolRouterError::NoServicesRegistered => (
262                ErrorCode::RESOURCE_NOT_FOUND,
263                "No services registered".to_string(),
264                None,
265            ),
266            ToolRouterError::InvalidParams { tool, reason } => (
267                ErrorCode::INVALID_PARAMS,
268                format!("Invalid parameters for tool '{}'", tool),
269                Some(serde_json::json!({ "reason": reason })),
270            ),
271            ToolRouterError::RoutingFailed(msg) | ToolRouterError::Internal(msg) => (
272                ErrorCode::INTERNAL_ERROR,
273                msg,
274                None,
275            ),
276        };
277
278        JsonRpcResponse::error(request_id, JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)))
279    }
280
281    /// Get the default timeout for a tool
282    pub async fn get_timeout(&self, tool_name: &str) -> u64 {
283        let index = self.tool_index.read().await;
284        index
285            .get(tool_name)
286            .and_then(|def| def.timeout_ms)
287            .unwrap_or(self.default_timeout_ms)
288    }
289
290    /// Get the count of registered tools
291    pub async fn tool_count(&self) -> usize {
292        let index = self.tool_index.read().await;
293        index.len()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal tool definition owned by `service_id`.
    fn tool(name: &str, service_id: &ServiceId, timeout_ms: Option<u64>) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            service_id: service_id.clone(),
            description: None,
            risk_level: None,
            timeout_ms,
        }
    }

    #[tokio::test]
    async fn test_tool_router_creation() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);
        assert_eq!(router.default_timeout_ms, 30_000);
    }

    #[tokio::test]
    async fn test_register_tool() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);

        let service_id = ServiceId::new("test-service");
        let tool_def = ToolDefinition {
            name: "test_tool".to_string(),
            service_id: service_id.clone(),
            description: Some("A test tool".to_string()),
            risk_level: Some("safe".to_string()),
            timeout_ms: Some(5000),
        };

        router.register_tool(service_id, tool_def).await;
        assert!(router.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_tool_replaces_existing_definition() {
        let router = ToolRouter::new(Arc::new(RegistryService::new()));
        let service_id = ServiceId::new("test-service");

        router.register_tool(service_id.clone(), tool("dup", &service_id, Some(1))).await;
        router.register_tool(service_id.clone(), tool("dup", &service_id, Some(2))).await;

        assert_eq!(router.tool_count().await, 1);
        let def = router.get_tool("dup").await.expect("tool was just registered");
        assert_eq!(def.timeout_ms, Some(2));
    }

    #[tokio::test]
    async fn test_list_tools() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);

        let service_id = ServiceId::new("test-service");
        router.register_tool(service_id.clone(), tool("tool1", &service_id, None)).await;
        router.register_tool(service_id.clone(), tool("tool2", &service_id, None)).await;

        let tools = router.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_get_timeout_falls_back_to_default() {
        let router = ToolRouter::new(Arc::new(RegistryService::new())).with_timeout(1_000);
        let service_id = ServiceId::new("test-service");

        router.register_tool(service_id.clone(), tool("fast", &service_id, Some(250))).await;
        router.register_tool(service_id.clone(), tool("plain", &service_id, None)).await;

        assert_eq!(router.get_timeout("fast").await, 250);
        assert_eq!(router.get_timeout("plain").await, 1_000);
        assert_eq!(router.get_timeout("missing").await, 1_000);
    }

    #[tokio::test]
    async fn test_route_tool_not_found() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);

        let result = router.route(
            "unknown_tool",
            serde_json::json!({}),
            JsonRpcId::Number(1),
        ).await;

        assert!(matches!(result, Err(ToolRouterError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_route_tool_whose_service_is_not_registered() {
        // The tool is indexed, but its owning service is absent from the (empty) registry.
        let router = ToolRouter::new(Arc::new(RegistryService::new()));
        let service_id = ServiceId::new("ghost-service");
        router.register_tool(service_id.clone(), tool("orphan", &service_id, None)).await;

        let result = router
            .route("orphan", serde_json::json!({}), JsonRpcId::Number(1))
            .await;

        assert!(matches!(
            result,
            Err(ToolRouterError::ToolNotFound(ref name)) if name == "orphan"
        ));
    }

    #[test]
    fn test_create_tool_request() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);

        let route_result = ToolRouteResult {
            service_id: ServiceId::new("test-service"),
            tool_name: "test_tool".to_string(),
            params: serde_json::json!({"arg": "value"}),
            request_id: JsonRpcId::Number(1),
        };

        let request = router.create_tool_request(route_result);
        assert_eq!(request.method, "tool.execute");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_unregister_service_tools() {
        let registry = Arc::new(RegistryService::new());
        let router = ToolRouter::new(registry);

        let service_id = ServiceId::new("test-service");
        router.register_tool(service_id.clone(), tool("tool1", &service_id, None)).await;

        assert!(router.has_tool("tool1").await);
        router.unregister_service_tools(&service_id).await;
        assert!(!router.has_tool("tool1").await);
    }
}