use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
/// Object-safe, string-in / string-out interface for a tool.
///
/// Implementors must supply [`BaseTool::name`], [`BaseTool::description`] and
/// the async [`BaseTool::run`]; every other method has a default that may be
/// overridden.
#[async_trait]
pub trait BaseTool: Send + Sync {
    /// Identifier of the tool; the default [`BaseTool::handle_error`] embeds it
    /// in its message.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// Executes the tool on the raw `input` string.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool cannot produce a result.
    async fn run(&self, input: String) -> Result<String, ToolError>;

    /// Optional JSON schema describing the tool's arguments.
    ///
    /// The default reports no schema (`None`).
    fn args_schema(&self) -> Option<Value> {
        Option::None
    }

    /// Flag exposed to callers; defaults to `false`.
    ///
    /// NOTE(review): presumably signals that the tool's output should be handed
    /// back directly without further processing — confirm against callers.
    fn return_direct(&self) -> bool {
        false
    }

    /// Turns a [`ToolError`] into a message string.
    ///
    /// The default yields `工具 '<name>' 执行失败: <error>`
    /// ("tool '<name>' failed: <error>"), where `<error>` is the error's
    /// `Display` output.
    async fn handle_error(&self, error: ToolError) -> String {
        let tool_name = self.name();
        format!(
            "工具 '{tool}' 执行失败: {cause}",
            tool = tool_name,
            cause = error
        )
    }
}
#[async_trait]
pub trait Tool: Send + Sync {
type Input: DeserializeOwned + JsonSchema + Send + Sync + 'static;
type Output: Serialize + Send + Sync;
async fn invoke(&self, input: Self::Input) -> Result<Self::Output, ToolError>;
fn args_schema(&self) -> Option<Value> {
use schemars::schema_for;
serde_json::to_value(schema_for!(Self::Input)).ok()
}
}
/// Errors a tool can report.
///
/// The `Display` impl renders each variant as a Chinese message. The English
/// meaning of each message is noted beside its match arm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolError {
    /// The supplied input was rejected; carries a description of the problem.
    InvalidInput(String),
    /// Execution failed; carries a description of the failure.
    ExecutionFailed(String),
    /// Execution timed out; carries the timeout in seconds.
    Timeout(u64),
    /// No tool was found under the carried name.
    ToolNotFound(String),
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // "Invalid input: <msg>"
            Self::InvalidInput(msg) => write!(f, "输入无效: {}", msg),
            // "Execution failed: <msg>"
            Self::ExecutionFailed(msg) => write!(f, "执行失败: {}", msg),
            // "Execution timed out: <n> seconds"
            Self::Timeout(seconds) => write!(f, "执行超时: {}秒", seconds),
            // "Tool not found: <name>"
            Self::ToolNotFound(name) => write!(f, "工具未找到: {}", name),
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (None) is correct.
impl std::error::Error for ToolError {}
use super::ToolDefinition;
/// Builds a [`ToolDefinition`] from a [`BaseTool`] trait object.
///
/// The tool's name and description are passed through unchanged.
/// [`BaseTool::args_schema`] becomes the parameter schema; tools that report
/// no schema fall back to `{"type": "object"}`.
pub fn to_tool_definition(tool: &dyn BaseTool) -> ToolDefinition {
    ToolDefinition::new(tool.name(), tool.description()).with_parameters(
        tool.args_schema()
            // Build the fallback lazily: `json!` allocates a map, which is
            // wasted work whenever the tool already supplies its own schema.
            .unwrap_or_else(|| serde_json::json!({"type": "object"})),
    )
}