use super::{BaseTool, Tool, ToolError};
use async_trait::async_trait;
use serde_json::Value;
/// Adapter that exposes a strongly typed [`Tool`] through the string-based
/// [`BaseTool`] interface.
///
/// The string handed to `BaseTool::run` is parsed as JSON into `T::Input`,
/// forwarded to the wrapped tool, and the tool's `T::Output` is serialized
/// back into a JSON string.
pub struct StructuredTool<T>
where
    T: Tool,
{
    /// The wrapped typed tool that performs the actual work.
    inner: T,
    /// Name reported through `BaseTool::name`.
    name: String,
    /// Description reported through `BaseTool::description`.
    description: String,
    /// Argument schema captured from `inner.args_schema()` when the wrapper
    /// is constructed.
    schema: Option<Value>,
}
impl<T: Tool> StructuredTool<T> {
pub fn new(tool: T, name: Option<&str>, description: Option<&str>) -> Self {
let schema = tool.args_schema();
Self {
inner: tool,
name: name.map(|s| s.to_string()).unwrap_or_else(|| "tool".to_string()),
description: description.map(|s| s.to_string()).unwrap_or_else(|| "A tool".to_string()),
schema,
}
}
fn parse_input(&self, input: String) -> Result<T::Input, ToolError> {
let json: Value = serde_json::from_str(&input)
.map_err(|e| ToolError::InvalidInput(format!("JSON 解析失败: {}", e)))?;
serde_json::from_value(json)
.map_err(|e| ToolError::InvalidInput(format!("输入格式不匹配: {}", e)))
}
fn serialize_output(output: T::Output) -> Result<String, ToolError> {
serde_json::to_string(&output)
.map_err(|e| ToolError::ExecutionFailed(format!("输出序列化失败: {}", e)))
}
}
/// String-in / string-out [`BaseTool`] view of a typed tool: JSON text is
/// decoded, handed to the inner tool, and the result is encoded back to JSON.
#[async_trait]
impl<T> BaseTool for StructuredTool<T>
where
    T: Tool,
{
    /// The configured name (defaults to `"tool"` when none was given).
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The configured description (defaults to `"A tool"` when none was given).
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Decodes `input`, invokes the wrapped tool, and encodes its output.
    ///
    /// Any decoding, invocation, or encoding error is propagated unchanged.
    async fn run(&self, input: String) -> Result<String, ToolError> {
        let result = self.inner.invoke(self.parse_input(input)?).await?;
        Self::serialize_output(result)
    }

    /// A copy of the schema cached when this wrapper was constructed.
    fn args_schema(&self) -> Option<Value> {
        Option::clone(&self.schema)
    }

    /// Tool output is never returned directly to the caller.
    fn return_direct(&self) -> bool {
        false
    }

    /// Builds a user-facing failure message that names this tool.
    async fn handle_error(&self, error: ToolError) -> String {
        let tool_name = &self.name;
        format!("工具 '{}' 执行失败: {}", tool_name, error)
    }
}
/// Debug-formats the wrapper's name and description.
///
/// The wrapped tool is not printed because `T` is only bounded by `Tool` (not
/// `Debug`), and the schema is omitted to keep output short.
/// `finish_non_exhaustive` renders a trailing `..` so readers can see that
/// fields were elided rather than absent.
impl<T: Tool> std::fmt::Debug for StructuredTool<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StructuredTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}