use std::collections::HashMap;
use std::sync::Arc;
use uira_core::ToolOutput;
use crate::tools::provider::ToolProvider;
use crate::tools::{BoxedTool, Tool, ToolContext, ToolError};
/// Routes tool invocations by name.
///
/// A name is resolved first against tools registered directly on the router and,
/// failing that, against the attached [`ToolProvider`]s.
pub struct ToolRouter {
    /// Directly registered tools, keyed by the name each tool reports.
    tools: HashMap<String, BoxedTool>,
    /// Fallback providers, consulted in the order they were registered.
    providers: Vec<Arc<dyn ToolProvider>>,
}
impl ToolRouter {
    /// Creates an empty router with no tools and no providers.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
            providers: Vec::new(),
        }
    }

    /// Registers a concrete tool under the name it reports via `name()`.
    ///
    /// Any tool previously registered under the same name is replaced.
    pub fn register(&mut self, tool: impl Tool + 'static) {
        self.register_boxed(Arc::new(tool));
    }

    /// Registers an already type-erased tool under the name it reports via `name()`.
    ///
    /// Any tool previously registered under the same name is replaced.
    pub fn register_boxed(&mut self, tool: BoxedTool) {
        let name = tool.name().to_string();
        self.tools.insert(name, tool);
    }

    /// Attaches a provider that is consulted by [`ToolRouter::dispatch`] for names
    /// that are not registered directly. Providers are consulted in the order they
    /// were added.
    pub fn register_provider(&mut self, provider: Arc<dyn ToolProvider>) {
        self.providers.push(provider);
    }

    /// Returns the directly registered tool called `name`, if any.
    ///
    /// Tools served only by a provider are not visible through this lookup.
    pub fn get(&self, name: &str) -> Option<&BoxedTool> {
        self.tools.get(name)
    }

    /// Returns `true` if a tool called `name` was registered directly.
    ///
    /// Provider-backed tools are not considered, so this can return `false` for a
    /// name that [`ToolRouter::dispatch`] would still accept.
    pub fn has(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Reports whether the directly registered tool `name` supports parallel
    /// execution.
    ///
    /// Returns `false` for unknown names and for names served only by a provider.
    pub fn tool_supports_parallel(&self, name: &str) -> bool {
        self.tools
            .get(name)
            .is_some_and(|tool| tool.supports_parallel())
    }

    /// Executes the tool called `name` with `input`.
    ///
    /// A directly registered tool takes precedence; otherwise the first provider
    /// (in registration order) whose `handles` returns `true` executes the call.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::NotFound`] when neither a registered tool nor any
    /// provider accepts `name`; errors from the selected tool or provider are
    /// returned unchanged.
    pub async fn dispatch(
        &self,
        name: &str,
        input: serde_json::Value,
        ctx: &ToolContext,
    ) -> Result<ToolOutput, ToolError> {
        if let Some(tool) = self.tools.get(name) {
            return tool.execute(input, ctx).await;
        }
        for provider in &self.providers {
            if provider.handles(name) {
                return provider.execute(name, input, ctx).await;
            }
        }
        Err(ToolError::NotFound {
            name: name.to_string(),
        })
    }

    /// Returns the names of all directly registered tools, sorted ascending.
    pub fn names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Number of directly registered tools (provider-backed tools are not counted).
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered directly (providers are ignored).
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Builds the tool specifications exposed by this router.
    ///
    /// Directly registered tools come first, sorted by name; specs from each
    /// provider follow in provider registration order.
    pub fn specs(&self) -> Vec<uira_core::ToolSpec> {
        // `HashMap` iteration order is randomized per process; sort by the map key
        // (which equals the tool's name) so the emitted list is stable across runs,
        // matching the ordering guarantee of `names()`.
        let mut local: Vec<(&String, &BoxedTool)> = self.tools.iter().collect();
        local.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut specs: Vec<uira_core::ToolSpec> = local
            .into_iter()
            .map(|(_, t)| uira_core::ToolSpec::new(t.name(), t.description(), t.schema()))
            .collect();
        for provider in &self.providers {
            specs.extend(provider.specs());
        }
        specs
    }
}
impl Default for ToolRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::FunctionTool;
    use serde_json::json;
    use uira_core::JsonSchema;

    /// Registers a tool called `name` that echoes its JSON input back as text.
    fn register_echo(router: &mut ToolRouter, name: &'static str) {
        router.register(FunctionTool::new(
            name,
            "Echo input",
            JsonSchema::object(),
            |input: serde_json::Value| async move { Ok(ToolOutput::text(input.to_string())) },
        ));
    }

    #[tokio::test]
    async fn test_router_dispatch() {
        let mut router = ToolRouter::new();
        register_echo(&mut router, "echo");
        assert!(router.has("echo"));
        assert!(!router.has("nonexistent"));
        let ctx = ToolContext::default();
        let result = router
            .dispatch("echo", json!({"msg": "hello"}), &ctx)
            .await
            .unwrap();
        assert!(result.as_text().unwrap().contains("hello"));
    }

    #[tokio::test]
    async fn test_router_missing_tool() {
        let router = ToolRouter::new();
        let ctx = ToolContext::default();
        let err = router
            .dispatch("missing", json!({}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::NotFound { .. }));
    }

    #[test]
    fn test_router_names_sorted_and_counted() {
        let empty = ToolRouter::default();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.names().is_empty());

        let mut router = ToolRouter::new();
        register_echo(&mut router, "zeta");
        register_echo(&mut router, "alpha");
        register_echo(&mut router, "mid");
        assert!(!router.is_empty());
        assert_eq!(router.len(), 3);
        assert_eq!(router.names(), vec!["alpha", "mid", "zeta"]);
        assert_eq!(router.specs().len(), 3);
    }

    #[test]
    fn test_router_register_replaces_duplicate_name() {
        let mut router = ToolRouter::new();
        register_echo(&mut router, "echo");
        register_echo(&mut router, "echo");
        assert_eq!(router.len(), 1);
        assert_eq!(router.names(), vec!["echo"]);
    }

    #[test]
    fn test_router_unknown_name_lookups() {
        let router = ToolRouter::new();
        assert!(router.get("missing").is_none());
        assert!(!router.tool_supports_parallel("missing"));
    }
}