use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::message::ToolResult;
/// Serializable description of a tool: its name, a human-readable
/// description, and a JSON value describing its parameters.
///
/// [`ToolRegistry`] uses the `name` field as the lookup key for a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
/// Name of the tool; used as the registry key.
pub name: String,
/// Free-form description of the tool.
pub description: String,
/// Arbitrary JSON describing the tool's parameters.
// NOTE(review): presumably a JSON Schema object — confirm against callers.
pub parameters: serde_json::Value,
}
/// An asynchronously callable tool that can describe itself via a [`ToolSchema`].
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn Tool>`.
#[async_trait]
pub trait Tool: Send + Sync {
/// Returns the schema describing this tool. The `name` in the returned
/// schema is what [`ToolRegistry`] uses as the key when registering it.
fn schema(&self) -> ToolSchema;
/// Invokes the tool with the given JSON `args` and returns its result.
///
/// `call_id` is an identifier for this invocation.
// NOTE(review): presumably `call_id` is echoed back in the `ToolResult` — confirm.
async fn call(&self, call_id: &str, args: serde_json::Value) -> ToolResult;
}
/// A collection of tools keyed by name.
///
/// Tools are held as `Arc<dyn Tool>`, so cloning the registry clones the map
/// and the `Arc` handles rather than the tools themselves.
#[derive(Clone, Default)]
pub struct ToolRegistry {
/// Tool name -> shared tool handle. A `BTreeMap` keeps entries ordered by name.
tools: BTreeMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register<T: Tool + 'static>(&mut self, tool: T) {
let name = tool.schema().name;
self.tools.insert(name, Arc::new(tool));
}
pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
let name = tool.schema().name;
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>> {
self.tools
.get(name)
.cloned()
.ok_or_else(|| Error::ToolNotFound { name: name.into() })
}
pub fn schemas(&self) -> Vec<ToolSchema> {
self.tools.values().map(|t| t.schema()).collect()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn len(&self) -> usize {
self.tools.len()
}
}