use std::{collections::HashMap, path::PathBuf};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
/// Static description of a tool.
///
/// `ToolRegistry::register` uses `name` as the registry key, so it should be
/// unique among registered tools.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolSpec {
/// Identifier the tool is registered and looked up under.
pub name: String,
/// Free-form human-readable description of what the tool does.
pub description: String,
/// JSON value describing the tool's expected arguments (presumably a JSON
/// Schema object — TODO confirm with the consumer of `ToolRegistry::specs`).
pub input_schema: Value,
}
/// A request to invoke a tool, carried as serde-(de)serializable data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier; presumably forwarded as the `id` argument of `Tool::call`
/// and echoed in `ToolResult::id` — confirm at the dispatch site.
pub id: String,
/// Tool name; presumably looked up via `ToolRegistry::get`.
pub name: String,
/// Raw JSON arguments; presumably passed as `args` to `Tool::call`.
pub arguments: Value,
}
/// The value a successful `Tool::call` returns.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolResult {
/// Identifier of the call this result answers (presumably the `id` passed to
/// `Tool::call` — TODO confirm in implementations).
pub id: String,
/// Success flag set by the tool implementation (NOTE(review): semantics are
/// defined by implementors, not enforced here).
pub ok: bool,
/// Textual output of the tool.
pub content: String,
}
/// Per-call environment handed to `Tool::call` by value.
#[derive(Clone, Debug)]
pub struct ToolContext {
/// Working directory; presumably the base passed to `resolve_path` for
/// relative paths — confirm in tool implementations.
pub cwd: PathBuf,
/// Output size limit in bytes; presumably the `max` passed to `truncate`.
pub max_output_bytes: usize,
}
/// Errors returned from `Tool::call`.
///
/// The `#[from]` variants let `?` convert the wrapped error types into a
/// `ToolError` automatically.
#[derive(Debug, Error)]
pub enum ToolError {
/// Any `serde_json::Error`; the name indicates it is meant for argument
/// (de)serialization failures.
#[error("invalid arguments: {0}")]
InvalidArguments(#[from] serde_json::Error),
/// Wrapped `std::io::Error`.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// Bytes that failed to convert into a `String`.
#[error("utf-8 error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
/// Free-form error message, displayed verbatim.
#[error("{0}")]
Message(String),
}
/// A callable tool, stored by `ToolRegistry` as `Box<dyn Tool>`.
///
/// The `Send + Sync` supertraits make `dyn Tool` shareable across threads
/// (presumably required by the async runtime that drives `call` — confirm).
#[async_trait::async_trait]
pub trait Tool: Send + Sync {
/// Describes this tool. `ToolRegistry` calls it at registration (using the
/// returned `name` as the key) and again on every `specs()` call, so keep it cheap.
fn spec(&self) -> ToolSpec;
/// Runs the tool with JSON `args` and a per-call `ctx`.
///
/// `id` is presumably the originating `ToolCall::id`, to be copied into the
/// returned `ToolResult::id` — confirm in implementations.
///
/// # Errors
/// Returns a `ToolError` when the tool cannot produce a result.
async fn call(
&self,
args: Value,
ctx: ToolContext,
id: String,
) -> Result<ToolResult, ToolError>;
}
/// A name-keyed collection of [`Tool`] trait objects.
///
/// Tools of different concrete types can be registered side by side because
/// they are stored as `Box<dyn Tool>`.
#[derive(Default)]
pub struct ToolRegistry {
    tools: HashMap<String, Box<dyn Tool>>,
}

impl ToolRegistry {
    /// Creates an empty registry (equivalent to `ToolRegistry::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `tool` under the `name` reported by its [`ToolSpec`].
    ///
    /// If a tool with the same name is already registered, it is silently
    /// replaced by the new one.
    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
        self.tools.insert(tool.spec().name, Box::new(tool));
    }

    /// Looks up a registered tool by exact name; `None` if there is no match.
    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(|t| t.as_ref())
    }

    /// Returns the specs of all registered tools, sorted by name.
    ///
    /// `HashMap` iteration order is unspecified and randomized per process.
    /// Sorting gives callers a stable, deterministic order across runs.
    pub fn specs(&self) -> Vec<ToolSpec> {
        let mut specs: Vec<ToolSpec> = self.tools.values().map(|t| t.spec()).collect();
        specs.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        specs
    }
}
/// Resolves `path` against `cwd`.
///
/// Relative paths are appended to `cwd`; absolute paths are returned as-is.
/// No normalization is performed (`..` components are kept), and the result is
/// not restricted to lie inside `cwd`.
pub fn resolve_path(cwd: &std::path::Path, path: &str) -> PathBuf {
    // `Path::join` already discards the base when its argument is absolute,
    // so a single call covers both the relative and the absolute case.
    cwd.join(path)
}
/// Truncates `s` to at most `max` bytes and appends a `"\n[truncated]"` marker.
///
/// If `s` already fits in `max` bytes it is returned unchanged. Otherwise the
/// cut point is moved back to the nearest UTF-8 character boundary at or below
/// `max`, so multi-byte characters are never split. Note that the marker is
/// appended after the cut, so the result can be longer than `max` by the
/// marker's length.
pub fn truncate(mut s: String, max: usize) -> String {
    if s.len() <= max {
        return s;
    }
    // `String::truncate` panics if the index is not on a char boundary, which
    // happens for any non-ASCII text. Index 0 is always a boundary, so this
    // loop terminates.
    let mut cut = max;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    s.truncate(cut);
    s.push_str("\n[truncated]");
    s
}