use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde_json::Value as JsonValue;
use crate::matrixrpc::{
ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
RegistryService, ServiceId, ServiceStatus,
};
/// Failures that can occur while resolving a tool name to the service that
/// provides it.
///
/// Each variant is translated into a JSON-RPC error by
/// `ToolRouter::create_error_response`.
#[derive(Debug, thiserror::Error)]
pub enum ToolRouterError {
    /// No entry for this tool name exists in the router's index.
    #[error("Tool '{0}' not found in any registered service")]
    ToolNotFound(String),

    /// The tool is indexed, but its owning service reported a status other
    /// than `Running`.
    #[error("Service '{service_id}' for tool '{tool_name}' is not running (status: {status:?})")]
    ServiceNotRunning {
        /// Name of the tool that was requested.
        tool_name: String,
        /// Service that owns the tool.
        service_id: ServiceId,
        /// Status the registry reported for that service.
        status: ServiceStatus,
    },

    /// The registry contains no services at all.
    #[error("No services registered in the registry")]
    NoServicesRegistered,

    /// Generic routing failure carrying a free-form message.
    #[error("Routing failed: {0}")]
    RoutingFailed(String),

    /// The parameters supplied for `tool` were rejected.
    #[error("Invalid parameters for tool '{tool}': {reason}")]
    InvalidParams {
        /// Tool whose parameters were rejected.
        tool: String,
        /// Human-readable explanation of the rejection.
        reason: String,
    },

    /// Unexpected internal failure carrying a free-form message.
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Outcome of a successful `ToolRouter::route` call: which service should run
/// the tool, together with the caller's parameters and request id.
#[derive(Clone, Debug)]
pub struct ToolRouteResult {
    /// Service that owns (and should execute) the tool.
    pub service_id: ServiceId,
    /// Name of the tool, as stored in the router's index.
    pub tool_name: String,
    /// Caller-supplied parameters, passed through unchanged.
    pub params: JsonValue,
    /// JSON-RPC id of the originating request, passed through unchanged.
    pub request_id: JsonRpcId,
}
/// A single tool entry in the router's index.
///
/// Entries are either registered directly or parsed from a service's `"tools"`
/// capability config by `ToolRouter::rebuild_index`.
#[derive(Clone, Debug)]
pub struct ToolDefinition {
    /// Tool name; also the key under which the entry is indexed.
    pub name: String,
    /// Service that provides the tool.
    pub service_id: ServiceId,
    /// Optional human-readable description (the `"description"` string in config).
    pub description: Option<String>,
    /// Optional risk label (the `"risk_level"` string in config); uninterpreted
    /// by the router itself.
    pub risk_level: Option<String>,
    /// Optional per-tool timeout in milliseconds (the `"timeout_ms"` integer in
    /// config); when `None`, the router's default timeout applies.
    pub timeout_ms: Option<u64>,
}
/// Maps tool names to the services that provide them and produces the
/// JSON-RPC requests/errors needed to dispatch tool calls.
#[derive(
    Debug,
)]
pub struct ToolRouter {
    /// Source of truth for service existence and status.
    registry: Arc<RegistryService>,
    /// Tool name -> definition; shared so clones of the handle see one index.
    tool_index: Arc<RwLock<HashMap<String, ToolDefinition>>>,
    /// Fallback timeout (milliseconds) for tools without their own `timeout_ms`.
    default_timeout_ms: u64,
}
impl ToolRouter {
pub fn new(registry: Arc<RegistryService>) -> Self {
Self {
registry,
tool_index: Arc::new(RwLock::new(HashMap::new())),
default_timeout_ms: 30_000, }
}
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.default_timeout_ms = timeout_ms;
self
}
pub async fn register_tool(&self, _service_id: ServiceId, tool_def: ToolDefinition) {
let mut index = self.tool_index.write().await;
index.insert(tool_def.name.clone(), tool_def);
}
pub async fn unregister_service_tools(&self, service_id: &ServiceId) {
let mut index = self.tool_index.write().await;
index.retain(|_, def| def.service_id != *service_id);
}
pub async fn rebuild_index(&self) -> Result<(), ToolRouterError> {
let services = self.registry.list_all().await;
let mut index = self.tool_index.write().await;
index.clear();
for service in services {
if service.status != ServiceStatus::Running {
continue;
}
for cap in &service.capabilities {
if cap.name == "tools" {
if let Some(tools_json) = cap.config.get("tools") {
if let Ok(tools) = serde_json::from_value::<Vec<JsonValue>>(tools_json.clone()) {
for tool in tools {
if let Some(name) = tool.get("name").and_then(|n| n.as_str()) {
let def = ToolDefinition {
name: name.to_string(),
service_id: service.id.clone(),
description: tool.get("description").and_then(|d| d.as_str()).map(|s| s.to_string()),
risk_level: tool.get("risk_level").and_then(|r| r.as_str()).map(|s| s.to_string()),
timeout_ms: tool.get("timeout_ms").and_then(|t| t.as_u64()),
};
index.insert(name.to_string(), def);
}
}
}
}
}
}
}
Ok(())
}
pub async fn route(
&self,
tool_name: &str,
params: JsonValue,
request_id: JsonRpcId,
) -> Result<ToolRouteResult, ToolRouterError> {
let index = self.tool_index.read().await;
let tool_def = index
.get(tool_name)
.cloned()
.ok_or_else(|| ToolRouterError::ToolNotFound(tool_name.to_string()))?;
let service = self.registry.get(&tool_def.service_id).await;
match service {
Some(s) if s.status == ServiceStatus::Running => {
Ok(ToolRouteResult {
service_id: tool_def.service_id,
tool_name: tool_def.name,
params,
request_id,
})
}
Some(s) => {
Err(ToolRouterError::ServiceNotRunning {
tool_name: tool_name.to_string(),
service_id: tool_def.service_id,
status: s.status,
})
}
None => {
Err(ToolRouterError::ToolNotFound(tool_name.to_string()))
}
}
}
pub async fn has_tool(&self, tool_name: &str) -> bool {
let index = self.tool_index.read().await;
index.contains_key(tool_name)
}
pub async fn list_tools(&self) -> Vec<ToolDefinition> {
let index = self.tool_index.read().await;
index.values().cloned().collect()
}
pub async fn get_tool(&self, tool_name: &str) -> Option<ToolDefinition> {
let index = self.tool_index.read().await;
index.get(tool_name).cloned()
}
pub fn create_tool_request(&self, route_result: ToolRouteResult) -> JsonRpcRequest {
JsonRpcRequest::with_id("tool.execute", route_result.request_id)
.params(serde_json::json!({
"tool_name": route_result.tool_name,
"params": route_result.params
}))
}
pub async fn create_error_response(
&self,
error: ToolRouterError,
request_id: JsonRpcId,
) -> JsonRpcResponse {
let (code, message, data) = match error {
ToolRouterError::ToolNotFound(tool) => {
let index = self.tool_index.read().await;
let available: Vec<String> = index.keys().cloned().collect();
(
ErrorCode::RESOURCE_NOT_FOUND,
format!("Tool '{}' not found", tool),
Some(serde_json::json!({ "available_tools": available })),
)
}
ToolRouterError::ServiceNotRunning { tool_name, service_id, status } => (
ErrorCode::INVALID_STATE,
format!("Service '{}' is not running", service_id),
Some(serde_json::json!({
"tool_name": tool_name,
"service_id": service_id.to_string(),
"status": serde_json::to_string(&status).unwrap_or_default()
})),
),
ToolRouterError::NoServicesRegistered => (
ErrorCode::RESOURCE_NOT_FOUND,
"No services registered".to_string(),
None,
),
ToolRouterError::InvalidParams { tool, reason } => (
ErrorCode::INVALID_PARAMS,
format!("Invalid parameters for tool '{}'", tool),
Some(serde_json::json!({ "reason": reason })),
),
ToolRouterError::RoutingFailed(msg) | ToolRouterError::Internal(msg) => (
ErrorCode::INTERNAL_ERROR,
msg,
None,
),
};
JsonRpcResponse::error(request_id, JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)))
}
pub async fn get_timeout(&self, tool_name: &str) -> u64 {
let index = self.tool_index.read().await;
index
.get(tool_name)
.and_then(|def| def.timeout_ms)
.unwrap_or(self.default_timeout_ms)
}
pub async fn tool_count(&self) -> usize {
let index = self.tool_index.read().await;
index.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Router over a fresh, empty registry.
    fn empty_router() -> ToolRouter {
        ToolRouter::new(Arc::new(RegistryService::new()))
    }

    /// Tool definition with only a name and owner; all optional fields `None`.
    fn bare_tool(name: &str, service_id: &ServiceId) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            service_id: service_id.clone(),
            description: None,
            risk_level: None,
            timeout_ms: None,
        }
    }

    #[tokio::test]
    async fn test_tool_router_creation() {
        let router = empty_router();
        assert_eq!(router.default_timeout_ms, 30_000);
    }

    #[tokio::test]
    async fn test_with_timeout_overrides_default() {
        let router = empty_router().with_timeout(1_000);
        assert_eq!(router.get_timeout("missing").await, 1_000);
    }

    #[tokio::test]
    async fn test_register_tool() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        let tool_def = ToolDefinition {
            name: "test_tool".to_string(),
            service_id: service_id.clone(),
            description: Some("A test tool".to_string()),
            risk_level: Some("safe".to_string()),
            timeout_ms: Some(5000),
        };
        router.register_tool(service_id, tool_def).await;
        assert!(router.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_tool_replaces_same_name() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        let mut first = bare_tool("dup", &service_id);
        first.description = Some("first".to_string());
        let mut second = bare_tool("dup", &service_id);
        second.description = Some("second".to_string());

        router.register_tool(service_id.clone(), first).await;
        router.register_tool(service_id, second).await;

        assert_eq!(router.tool_count().await, 1);
        let stored = router.get_tool("dup").await.expect("tool was registered");
        assert_eq!(stored.description.as_deref(), Some("second"));
    }

    #[tokio::test]
    async fn test_get_timeout_prefers_tool_override() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        let mut slow = bare_tool("slow", &service_id);
        slow.timeout_ms = Some(5_000);
        router.register_tool(service_id.clone(), slow).await;
        router
            .register_tool(service_id.clone(), bare_tool("plain", &service_id))
            .await;

        assert_eq!(router.get_timeout("slow").await, 5_000);
        assert_eq!(router.get_timeout("plain").await, 30_000);
        assert_eq!(router.get_timeout("unknown").await, 30_000);
    }

    #[tokio::test]
    async fn test_list_tools() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        router
            .register_tool(service_id.clone(), bare_tool("tool1", &service_id))
            .await;
        router
            .register_tool(service_id.clone(), bare_tool("tool2", &service_id))
            .await;
        let tools = router.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_route_tool_not_found() {
        let router = empty_router();
        let result = router
            .route("unknown_tool", serde_json::json!({}), JsonRpcId::Number(1))
            .await;
        assert!(matches!(result, Err(ToolRouterError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_route_service_missing_from_registry() {
        // The tool is indexed, but its owning service was never added to the
        // registry, so routing reports the tool as not found.
        let router = empty_router();
        let service_id = ServiceId::new("ghost-service");
        router
            .register_tool(service_id.clone(), bare_tool("orphan", &service_id))
            .await;
        let result = router
            .route("orphan", serde_json::json!({}), JsonRpcId::Number(7))
            .await;
        assert!(matches!(result, Err(ToolRouterError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_create_tool_request() {
        let router = empty_router();
        let route_result = ToolRouteResult {
            service_id: ServiceId::new("test-service"),
            tool_name: "test_tool".to_string(),
            params: serde_json::json!({"arg": "value"}),
            request_id: JsonRpcId::Number(1),
        };
        let request = router.create_tool_request(route_result);
        assert_eq!(request.method, "tool.execute");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_unregister_service_tools() {
        let router = empty_router();
        let service_id = ServiceId::new("test-service");
        router
            .register_tool(service_id.clone(), bare_tool("tool1", &service_id))
            .await;
        assert!(router.has_tool("tool1").await);
        router.unregister_service_tools(&service_id).await;
        assert!(!router.has_tool("tool1").await);
    }
}