use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
/// Declaration of a single named parameter accepted by a [`Tool`].
///
/// Used by `Tool::to_schema` to build the parameter schema and by
/// `Tool::validate_params` to check incoming arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Parameter name; it is the key looked up in the params map and the
    /// property name in the generated schema.
    pub name: String,
    /// JSON type name (e.g. `"string"`, `"number"`), serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable description copied into the schema.
    pub description: String,
    /// Whether the parameter must be present; defaults to `true` when the field
    /// is absent during deserialization (see `default_true`).
    #[serde(default = "default_true")]
    pub required: bool,
    /// Optional closed set of allowed string values; emitted as `"enum"` in the
    /// schema and checked by `Tool::validate_params`. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enum_values: Option<Vec<String>>,
    /// Optional default value; omitted from serialization when `None`.
    /// NOTE(review): nothing in this file applies this default to missing
    /// arguments or emits it in the schema — confirm whether that is intended.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<serde_json::Value>,
}
/// Serde default for `ToolParameter::required`: a parameter whose declaration
/// omits `required` is treated as required.
const fn default_true() -> bool {
    true
}
/// Outcome of running a tool.
///
/// Failures are reported in-band (`success == false` plus `error`) rather than
/// through a Rust `Result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` when the tool completed successfully.
    pub success: bool,
    /// Payload produced by the tool; `Value::Null` for results built via `ToolResult::error`.
    pub data: serde_json::Value,
    /// Error message for failed results; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Optional extra key/value information; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
impl ToolResult {
    /// Builds a successful result carrying `data`, with no error and no metadata.
    pub fn success(data: serde_json::Value) -> Self {
        Self {
            data,
            success: true,
            metadata: None,
            error: None,
        }
    }

    /// Builds a failed result holding the message `error`.
    ///
    /// `data` is set to `Value::Null` and no metadata is attached.
    pub fn error(error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            success: false,
            error: Some(message),
            data: serde_json::Value::Null,
            metadata: None,
        }
    }
}
/// Top-level function-calling schema produced by `Tool::to_schema`.
///
/// NOTE(review): the shape appears to follow the OpenAI-style
/// `{"type": "function", "function": {...}}` layout — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Serialized as `"type"`; `Tool::to_schema` sets it to `"function"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Name, description and parameter schema of the function.
    pub function: FunctionSchema,
}
/// Function description nested inside a [`ToolSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionSchema {
    /// Tool name, copied from `Tool::name`.
    pub name: String,
    /// Tool description, copied from `Tool::description`.
    pub description: String,
    /// Object schema describing the accepted parameters.
    pub parameters: ParametersSchema,
}
/// Object schema listing a function's parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParametersSchema {
    /// Serialized as `"type"`; `Tool::to_schema` sets it to `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Per-parameter schema keyed by parameter name (each holds `type`,
    /// `description`, and optionally `enum`).
    pub properties: HashMap<String, serde_json::Value>,
    /// Names of the parameters declared with `required == true`.
    pub required: Vec<String>,
}
/// A capability that can be invoked by name with a map of JSON parameters.
///
/// Implementors describe themselves (name, description, category, parameter list)
/// and provide an async [`Tool::execute`]. The provided methods derive a
/// function-calling schema from the parameter list and perform basic validation.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier used as the registry key.
    fn name(&self) -> &str;

    /// Human-readable description, copied into the generated schema.
    fn description(&self) -> &str;

    /// Declared parameters, used by [`Tool::to_schema`] and [`Tool::validate_params`].
    fn parameters(&self) -> &[ToolParameter];

    /// Grouping key the registry uses to index tools.
    fn category(&self) -> &str;

    /// Runs the tool; failures are reported through [`ToolResult`] rather than a `Result`.
    async fn execute(&self, params: HashMap<String, serde_json::Value>) -> ToolResult;

    /// Builds a `"function"` schema whose parameters form an `"object"` schema.
    ///
    /// Every declared parameter becomes a property carrying its `type` and
    /// `description` (plus `enum` when `enum_values` is set); parameters with
    /// `required == true` are listed in `required`, in declaration order.
    fn to_schema(&self) -> ToolSchema {
        let params = self.parameters();
        let mut properties = HashMap::with_capacity(params.len());
        let mut required = Vec::new();

        for param in params {
            let mut param_schema = serde_json::json!({
                "type": param.param_type,
                "description": param.description
            });
            if let Some(enum_values) = &param.enum_values {
                param_schema["enum"] = serde_json::json!(enum_values);
            }
            properties.insert(param.name.clone(), param_schema);
            if param.required {
                required.push(param.name.clone());
            }
        }

        ToolSchema {
            schema_type: "function".to_string(),
            function: FunctionSchema {
                name: self.name().to_string(),
                description: self.description().to_string(),
                parameters: ParametersSchema {
                    schema_type: "object".to_string(),
                    properties,
                    required,
                },
            },
        }
    }

    /// Checks `params` against the declared parameters.
    ///
    /// # Errors
    ///
    /// Returns a message for the first declared parameter that is required but
    /// missing, or that has `enum_values` and whose supplied value is not one of
    /// them. Only JSON strings can match an enum; any other JSON value is rejected.
    fn validate_params(&self, params: &HashMap<String, serde_json::Value>) -> Result<(), String> {
        for param in self.parameters() {
            let value = params.get(&param.name);

            if param.required && value.is_none() {
                return Err(format!("Missing required parameter: {}", param.name));
            }

            if let (Some(enum_values), Some(value)) = (&param.enum_values, value) {
                // Non-string values used to be compared as "", which let them pass
                // whenever "" was an allowed value; compare borrowed strs instead
                // of allocating a String per check.
                let allowed = matches!(
                    value.as_str(),
                    Some(s) if enum_values.iter().any(|allowed| allowed == s)
                );
                if !allowed {
                    return Err(format!(
                        "Invalid value for {}. Must be one of: {:?}",
                        param.name, enum_values
                    ));
                }
            }
        }
        Ok(())
    }
}
/// Collection of tools indexed by name and grouped by category.
pub struct ToolRegistry {
    // Tool name (from `Tool::name`) -> tool instance.
    tools: HashMap<String, Arc<dyn Tool>>,
    // Category (from `Tool::category`) -> names of the tools registered under it.
    categories: HashMap<String, Vec<String>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
categories: HashMap::new(),
}
}
pub fn register(&mut self, tool: Arc<dyn Tool>) {
let name = tool.name().to_string();
let category = tool.category().to_string();
self.tools.insert(name.clone(), tool);
self.categories
.entry(category)
.or_insert_with(Vec::new)
.push(name);
}
pub fn unregister(&mut self, tool_name: &str) {
if let Some(tool) = self.tools.remove(tool_name) {
let category = tool.category();
if let Some(tools) = self.categories.get_mut(category) {
tools.retain(|name| name != tool_name);
}
}
}
pub fn get(&self, tool_name: &str) -> Option<&Arc<dyn Tool>> {
self.tools.get(tool_name)
}
pub fn list_tools(&self, category: Option<&str>) -> Vec<&Arc<dyn Tool>> {
if let Some(cat) = category {
if let Some(tool_names) = self.categories.get(cat) {
return tool_names
.iter()
.filter_map(|name| self.tools.get(name))
.collect();
}
return Vec::new();
}
self.tools.values().collect()
}
pub fn get_schemas(&self) -> Vec<ToolSchema> {
self.tools.values().map(|tool| tool.to_schema()).collect()
}
pub async fn execute(
&self,
tool_name: &str,
params: HashMap<String, serde_json::Value>,
) -> ToolResult {
let tool = match self.get(tool_name) {
Some(t) => t,
None => {
return ToolResult::error(format!("Tool not found: {}", tool_name));
}
};
if let Err(e) = tool.validate_params(¶ms) {
return ToolResult::error(e);
}
match tool.execute(params).await {
result => result,
}
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
pub struct CalculatorTool;
#[async_trait]
impl Tool for CalculatorTool {
fn name(&self) -> &str {
"calculator"
}
fn description(&self) -> &str {
"Perform mathematical calculations"
}
fn parameters(&self) -> &[ToolParameter] {
&[ToolParameter {
name: "expression".to_string(),
param_type: "string".to_string(),
description: "Mathematical expression to evaluate".to_string(),
required: true,
enum_values: None,
default: None,
}]
}
fn category(&self) -> &str {
"utility"
}
async fn execute(&self, params: HashMap<String, serde_json::Value>) -> ToolResult {
let expression = match params.get("expression") {
Some(serde_json::Value::String(s)) => s,
_ => return ToolResult::error("Invalid expression parameter"),
};
match meval::eval_str(expression) {
Ok(result) => ToolResult::success(serde_json::json!(result)),
Err(e) => ToolResult::error(format!("Calculation error: {}", e)),
}
}
}
/// Built-in web search tool; currently a stub that always reports it is unconfigured.
pub struct WebSearchTool;

#[async_trait]
impl Tool for WebSearchTool {
    fn name(&self) -> &str {
        "web_search"
    }

    fn description(&self) -> &str {
        "Search the web for information"
    }

    /// Required string `query` and optional number `max_results` (declared default 5).
    fn parameters(&self) -> &[ToolParameter] {
        // Neither `to_string()` nor `json!` is const, so a slice literal here would be a
        // temporary that cannot be returned by reference (E0515); build it once instead.
        static PARAMETERS: std::sync::OnceLock<Vec<ToolParameter>> = std::sync::OnceLock::new();
        PARAMETERS.get_or_init(|| {
            vec![
                ToolParameter {
                    name: "query".to_string(),
                    param_type: "string".to_string(),
                    description: "Search query".to_string(),
                    required: true,
                    enum_values: None,
                    default: None,
                },
                ToolParameter {
                    name: "max_results".to_string(),
                    param_type: "number".to_string(),
                    description: "Maximum number of results".to_string(),
                    required: false,
                    enum_values: None,
                    default: Some(serde_json::json!(5)),
                },
            ]
        })
    }

    fn category(&self) -> &str {
        "web"
    }

    /// Stub: ignores `params` and always returns an error, since no search backend is configured.
    async fn execute(&self, _params: HashMap<String, serde_json::Value>) -> ToolResult {
        ToolResult::error("Web search requires API key configuration")
    }
}
/// Builds a registry pre-populated with the built-in tools.
///
/// Registers [`CalculatorTool`] first, then [`WebSearchTool`].
pub fn load_builtin_tools() -> ToolRegistry {
    let calculator = Arc::new(CalculatorTool);
    let web_search = Arc::new(WebSearchTool);

    let mut registry = ToolRegistry::default();
    registry.register(calculator);
    registry.register(web_search);
    registry
}