matrixcode_core/matrixrpc/router/
tool_router.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use serde_json::Value as JsonValue;
11
12use crate::matrixrpc::{
13 ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
14 RegistryService, ServiceId, ServiceStatus,
15};
16
/// Errors produced while resolving a tool call to the service that provides it.
///
/// `ToolRouter::create_error_response` converts each variant into a JSON-RPC
/// error response.
#[derive(Debug, thiserror::Error)]
pub enum ToolRouterError {
    /// The tool name is not in the router's index. `ToolRouter::route` also
    /// returns this when the index points at a service the registry no longer
    /// returns.
    #[error("Tool '{0}' not found in any registered service")]
    ToolNotFound(String),

    /// The tool's owning service exists but is not `ServiceStatus::Running`.
    #[error("Service '{service_id}' for tool '{tool_name}' is not running (status: {status:?})")]
    ServiceNotRunning {
        /// Tool the caller asked for.
        tool_name: String,
        /// Service the index maps that tool to.
        service_id: ServiceId,
        /// Status the registry reported for that service.
        status: ServiceStatus,
    },

    /// The registry has no services. Not produced by `ToolRouter` in this
    /// file; presumably raised by callers.
    #[error("No services registered in the registry")]
    NoServicesRegistered,

    /// Generic routing failure carrying a human-readable message.
    #[error("Routing failed: {0}")]
    RoutingFailed(String),

    /// The parameters supplied for `tool` were rejected. `reason` explains why.
    /// Not produced by `ToolRouter` in this file.
    #[error("Invalid parameters for tool '{tool}': {reason}")]
    InvalidParams { tool: String, reason: String },

    /// Unexpected internal failure carrying a human-readable message.
    #[error("Internal error: {0}")]
    Internal(String),
}
48
/// Successful outcome of [`ToolRouter::route`]: the service that should run
/// the tool, plus the call details to forward to it.
#[derive(Debug, Clone)]
pub struct ToolRouteResult {
    /// Service that owns the tool. `route` only returns a result when the
    /// registry reported this service as `ServiceStatus::Running`.
    pub service_id: ServiceId,
    /// Name of the tool, taken from the indexed definition.
    pub tool_name: String,
    /// Tool arguments, forwarded unchanged from the caller of `route`.
    pub params: JsonValue,
    /// JSON-RPC id of the originating request, forwarded unchanged.
    pub request_id: JsonRpcId,
}
61
/// Metadata describing one tool and the service that provides it.
///
/// Definitions are added directly via `ToolRouter::register_tool`, or parsed
/// from a running service's `"tools"` capability by `ToolRouter::rebuild_index`.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Tool name; the router's index is keyed by this value.
    pub name: String,
    /// Service that provides the tool and receives routed calls for it.
    pub service_id: ServiceId,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional free-form risk label (the tests use `"safe"`). Nothing in this
    /// file interprets it. NOTE(review): confirm the allowed values with the
    /// consumers of this field.
    pub risk_level: Option<String>,
    /// Per-tool timeout in milliseconds. When `None`,
    /// `ToolRouter::get_timeout` falls back to the router's default.
    pub timeout_ms: Option<u64>,
}
76
/// Maps tool names to the registered service that provides them.
///
/// It also builds the JSON-RPC request or error response for a tool call.
#[derive(Debug)]
pub struct ToolRouter {
    /// Registry consulted for service status when routing and when
    /// rebuilding the index.
    registry: Arc<RegistryService>,
    /// Tool name -> definition, behind an async `RwLock`. Lookups take a read
    /// guard; registration, unregistration and rebuilds take the write guard.
    tool_index: Arc<RwLock<HashMap<String, ToolDefinition>>>,
    /// Timeout in milliseconds for tools whose definition has no `timeout_ms`.
    default_timeout_ms: u64,
}
90
91impl ToolRouter {
92 pub fn new(registry: Arc<RegistryService>) -> Self {
94 Self {
95 registry,
96 tool_index: Arc::new(RwLock::new(HashMap::new())),
97 default_timeout_ms: 30_000, }
99 }
100
101 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
103 self.default_timeout_ms = timeout_ms;
104 self
105 }
106
107 pub async fn register_tool(&self, _service_id: ServiceId, tool_def: ToolDefinition) {
111 let mut index = self.tool_index.write().await;
112 index.insert(tool_def.name.clone(), tool_def);
113 }
114
115 pub async fn unregister_service_tools(&self, service_id: &ServiceId) {
119 let mut index = self.tool_index.write().await;
120 index.retain(|_, def| def.service_id != *service_id);
121 }
122
123 pub async fn rebuild_index(&self) -> Result<(), ToolRouterError> {
127 let services = self.registry.list_all().await;
128 let mut index = self.tool_index.write().await;
129 index.clear();
130
131 for service in services {
132 if service.status != ServiceStatus::Running {
133 continue;
134 }
135
136 for cap in &service.capabilities {
138 if cap.name == "tools" {
139 if let Some(tools_json) = cap.config.get("tools") {
141 if let Ok(tools) = serde_json::from_value::<Vec<JsonValue>>(tools_json.clone()) {
142 for tool in tools {
143 if let Some(name) = tool.get("name").and_then(|n| n.as_str()) {
144 let def = ToolDefinition {
145 name: name.to_string(),
146 service_id: service.id.clone(),
147 description: tool.get("description").and_then(|d| d.as_str()).map(|s| s.to_string()),
148 risk_level: tool.get("risk_level").and_then(|r| r.as_str()).map(|s| s.to_string()),
149 timeout_ms: tool.get("timeout_ms").and_then(|t| t.as_u64()),
150 };
151 index.insert(name.to_string(), def);
152 }
153 }
154 }
155 }
156 }
157 }
158 }
159
160 Ok(())
161 }
162
163 pub async fn route(
168 &self,
169 tool_name: &str,
170 params: JsonValue,
171 request_id: JsonRpcId,
172 ) -> Result<ToolRouteResult, ToolRouterError> {
173 let index = self.tool_index.read().await;
175 let tool_def = index
176 .get(tool_name)
177 .cloned()
178 .ok_or_else(|| ToolRouterError::ToolNotFound(tool_name.to_string()))?;
179
180 let service = self.registry.get(&tool_def.service_id).await;
182 match service {
183 Some(s) if s.status == ServiceStatus::Running => {
184 Ok(ToolRouteResult {
186 service_id: tool_def.service_id,
187 tool_name: tool_def.name,
188 params,
189 request_id,
190 })
191 }
192 Some(s) => {
193 Err(ToolRouterError::ServiceNotRunning {
195 tool_name: tool_name.to_string(),
196 service_id: tool_def.service_id,
197 status: s.status,
198 })
199 }
200 None => {
201 Err(ToolRouterError::ToolNotFound(tool_name.to_string()))
203 }
204 }
205 }
206
207 pub async fn has_tool(&self, tool_name: &str) -> bool {
209 let index = self.tool_index.read().await;
210 index.contains_key(tool_name)
211 }
212
213 pub async fn list_tools(&self) -> Vec<ToolDefinition> {
215 let index = self.tool_index.read().await;
216 index.values().cloned().collect()
217 }
218
219 pub async fn get_tool(&self, tool_name: &str) -> Option<ToolDefinition> {
221 let index = self.tool_index.read().await;
222 index.get(tool_name).cloned()
223 }
224
225 pub fn create_tool_request(&self, route_result: ToolRouteResult) -> JsonRpcRequest {
229 JsonRpcRequest::with_id("tool.execute", route_result.request_id)
230 .params(serde_json::json!({
231 "tool_name": route_result.tool_name,
232 "params": route_result.params
233 }))
234 }
235
236 pub async fn create_error_response(
238 &self,
239 error: ToolRouterError,
240 request_id: JsonRpcId,
241 ) -> JsonRpcResponse {
242 let (code, message, data) = match error {
243 ToolRouterError::ToolNotFound(tool) => {
244 let index = self.tool_index.read().await;
245 let available: Vec<String> = index.keys().cloned().collect();
246 (
247 ErrorCode::RESOURCE_NOT_FOUND,
248 format!("Tool '{}' not found", tool),
249 Some(serde_json::json!({ "available_tools": available })),
250 )
251 }
252 ToolRouterError::ServiceNotRunning { tool_name, service_id, status } => (
253 ErrorCode::INVALID_STATE,
254 format!("Service '{}' is not running", service_id),
255 Some(serde_json::json!({
256 "tool_name": tool_name,
257 "service_id": service_id.to_string(),
258 "status": serde_json::to_string(&status).unwrap_or_default()
259 })),
260 ),
261 ToolRouterError::NoServicesRegistered => (
262 ErrorCode::RESOURCE_NOT_FOUND,
263 "No services registered".to_string(),
264 None,
265 ),
266 ToolRouterError::InvalidParams { tool, reason } => (
267 ErrorCode::INVALID_PARAMS,
268 format!("Invalid parameters for tool '{}'", tool),
269 Some(serde_json::json!({ "reason": reason })),
270 ),
271 ToolRouterError::RoutingFailed(msg) | ToolRouterError::Internal(msg) => (
272 ErrorCode::INTERNAL_ERROR,
273 msg,
274 None,
275 ),
276 };
277
278 JsonRpcResponse::error(request_id, JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)))
279 }
280
281 pub async fn get_timeout(&self, tool_name: &str) -> u64 {
283 let index = self.tool_index.read().await;
284 index
285 .get(tool_name)
286 .and_then(|def| def.timeout_ms)
287 .unwrap_or(self.default_timeout_ms)
288 }
289
290 pub async fn tool_count(&self) -> usize {
292 let index = self.tool_index.read().await;
293 index.len()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh router over an empty registry.
    fn new_router() -> ToolRouter {
        ToolRouter::new(Arc::new(RegistryService::new()))
    }

    /// Builds a minimal definition named `name` and owned by `service_id`.
    fn tool(name: &str, service_id: &ServiceId, timeout_ms: Option<u64>) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            service_id: service_id.clone(),
            description: None,
            risk_level: None,
            timeout_ms,
        }
    }

    #[tokio::test]
    async fn test_tool_router_creation() {
        let router = new_router();
        assert_eq!(router.default_timeout_ms, 30_000);
    }

    #[tokio::test]
    async fn test_register_tool() {
        let router = new_router();

        let service_id = ServiceId::new("test-service");
        let tool_def = ToolDefinition {
            name: "test_tool".to_string(),
            service_id: service_id.clone(),
            description: Some("A test tool".to_string()),
            risk_level: Some("safe".to_string()),
            timeout_ms: Some(5000),
        };

        router.register_tool(service_id, tool_def).await;
        assert!(router.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_tool_replaces_same_name() {
        let router = new_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_tool(service_id.clone(), tool("tool1", &service_id, Some(100)))
            .await;
        router
            .register_tool(service_id.clone(), tool("tool1", &service_id, Some(200)))
            .await;

        assert_eq!(router.tool_count().await, 1);
        assert_eq!(router.get_timeout("tool1").await, 200);
    }

    #[tokio::test]
    async fn test_list_tools() {
        let router = new_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_tool(service_id.clone(), tool("tool1", &service_id, None))
            .await;
        router
            .register_tool(service_id.clone(), tool("tool2", &service_id, None))
            .await;

        let tools = router.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_count_and_get_tool() {
        let router = new_router();
        let service_id = ServiceId::new("test-service");
        assert_eq!(router.tool_count().await, 0);

        router
            .register_tool(service_id.clone(), tool("tool1", &service_id, None))
            .await;

        assert_eq!(router.tool_count().await, 1);
        let found = router.get_tool("tool1").await.expect("tool1 was just registered");
        assert_eq!(found.name, "tool1");
        assert!(router.get_tool("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_get_timeout_prefers_tool_value() {
        let router = new_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_tool(service_id.clone(), tool("fast", &service_id, Some(5000)))
            .await;
        router
            .register_tool(service_id.clone(), tool("plain", &service_id, None))
            .await;

        assert_eq!(router.get_timeout("fast").await, 5000);
        assert_eq!(router.get_timeout("plain").await, 30_000);
        assert_eq!(router.get_timeout("missing").await, 30_000);
    }

    #[tokio::test]
    async fn test_with_timeout_overrides_default() {
        let router = new_router().with_timeout(1234);
        assert_eq!(router.default_timeout_ms, 1234);
        assert_eq!(router.get_timeout("missing").await, 1234);
    }

    #[tokio::test]
    async fn test_route_tool_not_found() {
        let router = new_router();

        let result = router
            .route("unknown_tool", serde_json::json!({}), JsonRpcId::Number(1))
            .await;

        assert!(matches!(result, Err(ToolRouterError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_create_tool_request() {
        let router = new_router();

        let route_result = ToolRouteResult {
            service_id: ServiceId::new("test-service"),
            tool_name: "test_tool".to_string(),
            params: serde_json::json!({"arg": "value"}),
            request_id: JsonRpcId::Number(1),
        };

        let request = router.create_tool_request(route_result);
        assert_eq!(request.method, "tool.execute");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_unregister_service_tools() {
        let router = new_router();
        let service_id = ServiceId::new("test-service");

        router
            .register_tool(service_id.clone(), tool("tool1", &service_id, None))
            .await;

        assert!(router.has_tool("tool1").await);
        router.unregister_service_tools(&service_id).await;
        assert!(!router.has_tool("tool1").await);
    }

    #[tokio::test]
    async fn test_unregister_only_removes_matching_service() {
        let router = new_router();
        let service_a = ServiceId::new("svc-a");
        let service_b = ServiceId::new("svc-b");

        router
            .register_tool(service_a.clone(), tool("tool_a", &service_a, None))
            .await;
        router
            .register_tool(service_b.clone(), tool("tool_b", &service_b, None))
            .await;

        router.unregister_service_tools(&service_a).await;
        assert!(!router.has_tool("tool_a").await);
        assert!(router.has_tool("tool_b").await);
    }
}