use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde_json::{Value, json};
use crate::error::{Error, Result};
use crate::tool::{PermissionLevel, Tool, ToolContext, ToolFn};
/// Heap-allocated, pinned, `Send` future that resolves to a tool's `Result<Value>`.
type ToolFuture = Pin<Box<dyn Future<Output = Result<Value>> + Send>>;
/// Boxed `Send + Sync` callback that receives the call arguments and a shared
/// [`ToolContext`] and returns a [`ToolFuture`].
type ToolHandler = Box<dyn Fn(Value, Arc<ToolContext>) -> ToolFuture + Send + Sync>;
/// Builds a tool that ignores its arguments and always resolves to `return_value`.
///
/// The declared parameter schema is an empty object with
/// `additionalProperties: false`. Every invocation resolves to a fresh clone of
/// `return_value`.
pub fn create_simple_tool(
    name: impl Into<String>,
    description: impl Into<String>,
    return_value: Value,
) -> ToolFn<ToolHandler> {
    ToolFn::new(
        name,
        description,
        json!({
            "type": "object",
            "properties": {},
            "additionalProperties": false
        }),
        Box::new(move |_, _| {
            // The handler is `Fn`, so it keeps ownership of `return_value` and
            // hands a copy to each returned future.
            let ret = return_value.clone();
            Box::pin(async move { Ok(ret) })
        }),
    )
}
/// Builds a tool whose result is the argument value it was invoked with.
///
/// The schema declares one required string property, `message`, and sets
/// `additionalProperties: false`. The handler does not inspect the arguments;
/// it returns them unchanged.
pub fn create_echo_tool(
    name: impl Into<String>,
    description: impl Into<String>,
) -> ToolFn<ToolHandler> {
    let schema = json!({
        "type": "object",
        "properties": {
            "message": {
                "type": "string",
                "description": "要返回的消息"
            }
        },
        "required": ["message"],
        "additionalProperties": false
    });

    // The context argument is unused: echoing needs nothing but the arguments.
    let echo: ToolHandler = Box::new(|arguments, _context| {
        Box::pin(async move { Ok(arguments) })
    });

    ToolFn::new(name, description, schema, echo)
}
pub fn create_delay_tool(
name: impl Into<String>, description: impl Into<String>,
) -> ToolFn<ToolHandler> {
ToolFn::new(
name,
description,
json!({
"type": "object",
"properties": {
"delay_ms": {
"type": "integer",
"description": "延迟毫秒数",
"minimum": 0,
"maximum": 10000
},
"result": {
"description": "延迟后返回的结果"
}
},
"required": ["delay_ms", "result"],
"additionalProperties": false
}),
Box::new(|params, _| {
let delay_ms = params["delay_ms"].as_u64().unwrap_or(0);
let result = params["result"].clone();
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
Ok(result)
})
}),
)
}
/// Reference-counted, `Send + Sync + 'static` caller-supplied handler: takes the
/// call arguments and a shared [`ToolContext`] and returns a [`ToolFuture`].
type CustomToolFn = Arc<dyn Fn(Value, Arc<ToolContext>) -> ToolFuture + Send + Sync + 'static>;
pub fn create_authenticated_tool(
name: impl Into<String>, description: impl Into<String>, handler: CustomToolFn,
) -> ToolFn<ToolHandler> {
let handler_fn: ToolHandler = Box::new(move |params, context| {
let handler = handler.clone();
Box::pin(async move { handler(params, context).await })
});
ToolFn::new(
name,
description,
json!({
"type": "object",
"properties": {},
"additionalProperties": true
}),
handler_fn,
)
.with_permission_level(PermissionLevel::Authenticated)
}
pub fn create_admin_tool(
name: impl Into<String>, description: impl Into<String>, handler: CustomToolFn,
) -> ToolFn<ToolHandler> {
let handler_fn: ToolHandler = Box::new(move |params, context| {
let handler = handler.clone();
Box::pin(async move { handler(params, context).await })
});
ToolFn::new(
name,
description,
json!({
"type": "object",
"properties": {},
"additionalProperties": true
}),
handler_fn,
)
.with_permission_level(PermissionLevel::Admin)
}
pub fn create_pipeline_tool(
name: impl Into<String>, description: impl Into<String>, tools: Vec<Arc<dyn Tool>>,
) -> ToolFn<ToolHandler> {
let tools = tools.clone();
ToolFn::new(
name,
description,
json!({
"type": "object",
"properties": {
"input": {
"description": "输入参数"
}
},
"required": ["input"],
"additionalProperties": false
}),
Box::new(move |params, context| {
let tools = tools.clone();
Box::pin(async move {
let mut current_params = params["input"].clone();
for tool in &tools {
current_params = tool.execute(current_params, context.clone()).await?;
}
Ok(current_params)
})
}),
)
}
pub fn create_parallel_tool(
name: impl Into<String>, description: impl Into<String>, tools: Vec<Arc<dyn Tool>>,
) -> ToolFn<ToolHandler> {
let tools = tools.clone();
ToolFn::new(
name,
description,
json!({
"type": "object",
"properties": {
"inputs": {
"type": "array",
"description": "每个工具的输入参数"
}
},
"required": ["inputs"],
"additionalProperties": false
}),
Box::new(move |params, context| {
let tools = tools.clone();
Box::pin(async move {
let inputs = params["inputs"]
.as_array()
.ok_or_else(|| Error::InvalidInput("参数 inputs 必须是数组".to_string()))?;
if inputs.len() != tools.len() {
return Err(Error::InvalidInput(format!(
"输入参数数量 ({}) 必须与工具数量 ({}) 一致",
inputs.len(),
tools.len()
)));
}
let mut futures = Vec::with_capacity(tools.len());
for (i, tool) in tools.iter().enumerate() {
let input = inputs[i].clone();
let context = context.clone();
futures.push(tool.execute(input, context));
}
let mut results = Vec::with_capacity(tools.len());
for future in futures {
results.push(future.await?);
}
Ok(json!(results))
})
}),
)
}