use crate::raw::request::tool::Tool as RawTool;
use async_trait::async_trait;
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
/// A provider of one or more callable tools.
///
/// `Send + Sync` lets implementors be stored as `Box<dyn Tool>` (as
/// `ToolBundle` does). It also satisfies `#[async_trait]`, whose generated
/// futures are `Send` by default and capture `&self`.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Returns the definitions of every tool this provider can execute.
    ///
    /// Each definition's `function.name` is the name that should be passed to
    /// [`Tool::call`].
    fn raw_tools(&self) -> Vec<RawTool>;

    /// Invokes the tool called `name` with JSON arguments `args` and returns
    /// its JSON result.
    ///
    /// NOTE(review): the signature has no `Result`, so failures are presumably
    /// reported in-band as a JSON value (`ToolBundle` returns
    /// `{"error": ...}` for unknown names). Confirm this for other implementors.
    async fn call(&self, name: &str, args: Value) -> Value;
}
pub struct ToolBundle {
tools: Vec<Box<dyn Tool>>,
index: std::collections::HashMap<String, usize>,
}
impl Default for ToolBundle {
fn default() -> Self {
Self::new()
}
}
impl ToolBundle {
    /// Creates an empty bundle with no registered tools.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tools: vec![],
            index: HashMap::new(),
        }
    }

    /// Registers `tool` and returns the bundle, for builder-style chaining.
    ///
    /// Every function name that `tool.raw_tools()` reports *at this moment* is
    /// routed to `tool`. The names are captured once here, so later changes
    /// to that tool's `raw_tools()` output are not re-indexed.
    ///
    /// If a name was already registered by an earlier tool, the new tool takes
    /// over routing for that name. The earlier tool's definition is still
    /// included in the bundle's `raw_tools()` output.
    ///
    /// Marked `#[must_use]` because `self` is consumed: discarding the return
    /// value drops the whole bundle.
    #[must_use]
    pub fn add<T: Tool + 'static>(mut self, tool: T) -> Self {
        // The tool is pushed after indexing, so its slot is the current length.
        let slot = self.tools.len();
        for raw in tool.raw_tools() {
            // `insert` overwrites: the most recent registration wins.
            self.index.insert(raw.function.name.clone(), slot);
        }
        self.tools.push(Box::new(tool));
        self
    }
}
#[async_trait]
impl Tool for ToolBundle {
fn raw_tools(&self) -> Vec<RawTool> {
self.tools.iter().flat_map(|t| t.raw_tools()).collect()
}
async fn call(&self, name: &str, args: Value) -> Value {
match self.index.get(name) {
Some(&idx) => self.tools[idx].call(name, args).await,
None => json!({ "error": format!("未知工具: {name}") }),
}
}
}