#![allow(dead_code)]
use crate::agency::error::AgencyResult;
use crate::agency::models::ToolResult;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// A single named parameter accepted by a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Parameter name; used as the property key in [`Tool::to_function_definition`].
    pub name: String,
    /// Type name emitted verbatim as the property's `type` (the builder helpers use
    /// `"string"`, `"number"` and `"boolean"`). Serialized under the key `type`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Whether the parameter must be supplied; `false` when absent from serialized input.
    #[serde(default)]
    pub required: bool,
    /// Optional closed set of allowed values; skipped when serializing if `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enum_values: Option<Vec<String>>,
    /// Optional default value for the parameter; skipped when serializing if `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default: Option<Value>,
}
/// Declarative definition of a tool: its name, description, parameters and flags.
///
/// This type carries no executable logic; runtime behavior is supplied separately
/// through a [`ToolExecutor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool name; [`ToolRegistry`] keys tools by this value.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Declared parameters, in declaration order; empty when absent from serialized input.
    #[serde(default)]
    pub parameters: Vec<ToolParameter>,
    /// Grouping for the tool; [`ToolCategory::Custom`] when absent from serialized input.
    #[serde(default)]
    pub category: ToolCategory,
    /// Marks tools that should be confirmed before running. Nothing visible in this
    /// file reads it — presumably enforced by the caller before execution; confirm.
    #[serde(default)]
    pub requires_confirmation: bool,
    /// Free-form extra data; skipped when serializing if empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
impl Tool {
    /// Creates a tool with the given name and description and no parameters.
    ///
    /// The category is [`ToolCategory::Custom`], confirmation is not required and
    /// metadata is empty.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters: Vec::new(),
            category: ToolCategory::Custom,
            requires_confirmation: false,
            metadata: HashMap::new(),
        }
    }

    /// Renders this tool as a function-calling definition of the shape:
    ///
    /// ```text
    /// {"type": "function",
    ///  "function": {"name": ..., "description": ...,
    ///               "parameters": {"type": "object",
    ///                              "properties": {...},
    ///                              "required": [...]}}}
    /// ```
    ///
    /// Each parameter becomes a property carrying `type` and `description`, plus
    /// `enum` when [`ToolParameter::enum_values`] is set and `default` when
    /// [`ToolParameter::default`] is set. Names of parameters with `required == true`
    /// are listed under `required` in declaration order. `category`,
    /// `requires_confirmation` and `metadata` are not part of the output.
    pub fn to_function_definition(&self) -> Value {
        let mut properties = serde_json::Map::new();
        let mut required = Vec::new();
        for param in &self.parameters {
            let mut prop = serde_json::Map::new();
            prop.insert("type".to_string(), Value::String(param.param_type.clone()));
            prop.insert(
                "description".to_string(),
                Value::String(param.description.clone()),
            );
            if let Some(enum_vals) = &param.enum_values {
                prop.insert(
                    "enum".to_string(),
                    Value::Array(enum_vals.iter().cloned().map(Value::String).collect()),
                );
            }
            // `default` is a standard JSON-Schema keyword; previously the field was
            // stored on the parameter but never reached the generated schema.
            if let Some(default) = &param.default {
                prop.insert("default".to_string(), default.clone());
            }
            properties.insert(param.name.clone(), Value::Object(prop));
            if param.required {
                required.push(Value::String(param.name.clone()));
            }
        }
        serde_json::json!({
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required
                }
            }
        })
    }
}
/// Coarse grouping for tools.
///
/// Serialized in `snake_case` (e.g. `Communication` ⇄ `"communication"`); the
/// default variant is [`ToolCategory::Custom`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCategory {
    /// User-defined tools; the default when no category is given.
    #[default]
    Custom,
    /// Search tools (used by `web_search`).
    Search,
    /// Code tools (used by `code_execution`).
    Code,
    /// Filesystem tools (used by `read_file`, `write_file`, `list_directory`).
    File,
    /// Data tools (used by `calculator`).
    Data,
    /// Communication tools (used by `http_request`).
    Communication,
    /// System tools; no built-in tool in this file uses it.
    System,
    /// Built-in tools; no built-in definition in this file is tagged with it.
    Builtin,
}
/// Fluent builder for [`Tool`] definitions.
///
/// Every method consumes the builder and returns it; call [`ToolBuilder::build`]
/// to obtain the finished [`Tool`].
#[derive(Debug, Clone)]
#[must_use = "a ToolBuilder does nothing until `build` is called"]
pub struct ToolBuilder {
    tool: Tool,
}
impl ToolBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
tool: Tool {
name: name.into(),
description: String::new(),
parameters: Vec::new(),
category: ToolCategory::Custom,
requires_confirmation: false,
metadata: HashMap::new(),
},
}
}
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.tool.description = desc.into();
self
}
pub fn parameter(
mut self,
name: impl Into<String>,
param_type: impl Into<String>,
description: impl Into<String>,
required: bool,
) -> Self {
self.tool.parameters.push(ToolParameter {
name: name.into(),
param_type: param_type.into(),
description: description.into(),
required,
enum_values: None,
default: None,
});
self
}
pub fn string_param(
self,
name: impl Into<String>,
description: impl Into<String>,
required: bool,
) -> Self {
self.parameter(name, "string", description, required)
}
pub fn number_param(
self,
name: impl Into<String>,
description: impl Into<String>,
required: bool,
) -> Self {
self.parameter(name, "number", description, required)
}
pub fn bool_param(
self,
name: impl Into<String>,
description: impl Into<String>,
required: bool,
) -> Self {
self.parameter(name, "boolean", description, required)
}
pub fn category(mut self, category: ToolCategory) -> Self {
self.tool.category = category;
self
}
pub fn requires_confirmation(mut self, requires: bool) -> Self {
self.tool.requires_confirmation = requires;
self
}
pub fn build(self) -> Tool {
self.tool
}
}
/// Runtime implementation of a tool.
///
/// The `Send + Sync` bound allows executors to be stored as
/// `Arc<dyn ToolExecutor>` (as [`ToolRegistry`] does) and shared across threads.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Returns the definition (name, parameters, flags) this executor implements.
    fn definition(&self) -> &Tool;
    /// Runs the tool with the given JSON arguments.
    ///
    /// NOTE(review): `args` is presumably an object keyed by parameter name,
    /// matching `definition().parameters` — confirm against callers.
    async fn execute(&self, args: Value) -> AgencyResult<ToolResult>;
}
/// Boxed closure type for an async tool implementation.
///
/// The closure takes the JSON arguments and returns a pinned, boxed, `Send`
/// future resolving to `AgencyResult<ToolResult>`. The closure itself is
/// `Send + Sync`. Not referenced elsewhere in this file.
pub type ToolFn = Box<
    dyn Fn(Value) -> Pin<Box<dyn Future<Output = AgencyResult<ToolResult>> + Send>> + Send + Sync,
>;
/// Name-keyed collection of tool definitions and (optionally) their executors.
#[derive(Default)]
pub struct ToolRegistry {
    // Tool definitions keyed by `Tool::name`.
    tools: HashMap<String, Arc<Tool>>,
    // Executors keyed by tool name. A name may have a definition here in `tools`
    // without an entry in this map (see `register`).
    executors: HashMap<String, Arc<dyn ToolExecutor>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a registry holding every definition from [`BuiltinTools::all`].
    ///
    /// Only definitions are registered; no executors are attached.
    pub fn with_builtins() -> Self {
        let mut registry = Self::new();
        registry.register_builtins();
        registry
    }

    /// Registers a tool definition without an executor.
    ///
    /// Replaces any definition already registered under the same name. An
    /// executor previously registered under that name is left in place.
    pub fn register(&mut self, tool: Tool) {
        self.tools.insert(tool.name.clone(), Arc::new(tool));
    }

    /// Registers an executor together with a copy of its definition, keyed by
    /// the definition's name.
    ///
    /// Replaces any existing definition and executor with the same name.
    pub fn register_with_executor(&mut self, executor: impl ToolExecutor + 'static) {
        let tool = executor.definition().clone();
        let name = tool.name.clone();
        self.tools.insert(name.clone(), Arc::new(tool));
        self.executors.insert(name, Arc::new(executor));
    }

    /// Looks up a tool definition by name.
    pub fn get(&self, name: &str) -> Option<&Arc<Tool>> {
        self.tools.get(name)
    }

    /// Looks up the executor registered for `name`, if any.
    pub fn get_executor(&self, name: &str) -> Option<&Arc<dyn ToolExecutor>> {
        self.executors.get(name)
    }

    /// Returns all registered definitions, sorted by tool name.
    pub fn list(&self) -> Vec<&Tool> {
        self.sorted_tools().map(|t| t.as_ref()).collect()
    }

    /// Returns the function-calling definition of every registered tool,
    /// sorted by tool name.
    pub fn to_definitions(&self) -> Vec<Value> {
        self.sorted_tools()
            .map(|t| t.to_function_definition())
            .collect()
    }

    /// Iterates over registered definitions in ascending name order.
    ///
    /// `HashMap` iteration order is unspecified and differs between processes,
    /// so sorting keeps `list`/`to_definitions` output reproducible (stable
    /// prompts, deterministic tests).
    fn sorted_tools(&self) -> impl Iterator<Item = &Arc<Tool>> {
        let mut entries: Vec<(&String, &Arc<Tool>)> = self.tools.iter().collect();
        // Keys are unique, so an unstable sort yields a fully determined order.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        entries.into_iter().map(|(_, tool)| tool)
    }

    /// Registers every built-in definition (definitions only, no executors).
    fn register_builtins(&mut self) {
        for tool in BuiltinTools::all() {
            self.register(tool);
        }
    }
}
pub struct BuiltinTools;
impl BuiltinTools {
pub fn all() -> Vec<Tool> {
vec![
Self::web_search(),
Self::code_execution(),
Self::read_file(),
Self::write_file(),
Self::list_directory(),
Self::http_request(),
Self::calculator(),
]
}
pub fn web_search() -> Tool {
ToolBuilder::new("web_search")
.description("Search the web for information. Returns relevant snippets and URLs.")
.string_param("query", "The search query", true)
.number_param(
"max_results",
"Maximum number of results (default: 5)",
false,
)
.category(ToolCategory::Search)
.build()
}
pub fn code_execution() -> Tool {
ToolBuilder::new("code_execution")
.description("Execute code in a sandboxed environment. Supports Python, JavaScript, and shell scripts.")
.string_param("code", "The code to execute", true)
.string_param("language", "Programming language (python, javascript, shell)", true)
.number_param("timeout", "Execution timeout in seconds (default: 30)", false)
.category(ToolCategory::Code)
.requires_confirmation(true)
.build()
}
pub fn read_file() -> Tool {
ToolBuilder::new("read_file")
.description("Read the contents of a file from the filesystem.")
.string_param("path", "Path to the file to read", true)
.string_param("encoding", "File encoding (default: utf-8)", false)
.category(ToolCategory::File)
.build()
}
pub fn write_file() -> Tool {
ToolBuilder::new("write_file")
.description("Write content to a file. Creates the file if it doesn't exist.")
.string_param("path", "Path to the file to write", true)
.string_param("content", "Content to write to the file", true)
.bool_param("append", "Append to file instead of overwriting", false)
.category(ToolCategory::File)
.requires_confirmation(true)
.build()
}
pub fn list_directory() -> Tool {
ToolBuilder::new("list_directory")
.description("List the contents of a directory.")
.string_param("path", "Path to the directory", true)
.bool_param("recursive", "Include subdirectories", false)
.bool_param("include_hidden", "Include hidden files", false)
.category(ToolCategory::File)
.build()
}
pub fn http_request() -> Tool {
ToolBuilder::new("http_request")
.description("Make an HTTP request to a URL.")
.string_param("url", "The URL to request", true)
.string_param("method", "HTTP method (GET, POST, PUT, DELETE)", false)
.string_param("body", "Request body (for POST/PUT)", false)
.string_param("headers", "JSON object of headers", false)
.category(ToolCategory::Communication)
.build()
}
pub fn calculator() -> Tool {
ToolBuilder::new("calculator")
.description("Evaluate mathematical expressions. Supports basic arithmetic, functions, and constants.")
.string_param("expression", "The mathematical expression to evaluate", true)
.category(ToolCategory::Data)
.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builder records parameters in order with their `required` flags.
    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test_tool")
            .description("A test tool")
            .string_param("input", "Input parameter", true)
            .number_param("count", "Count parameter", false)
            .category(ToolCategory::Custom)
            .build();
        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
        assert_eq!(tool.parameters.len(), 2);
        assert!(tool.parameters[0].required);
        assert!(!tool.parameters[1].required);
    }

    // Top-level envelope of the function-calling definition.
    #[test]
    fn test_function_definition() {
        let tool = BuiltinTools::web_search();
        let def = tool.to_function_definition();
        assert_eq!(def["type"], "function");
        assert_eq!(def["function"]["name"], "web_search");
    }

    // Properties carry type/description; only required params are listed.
    #[test]
    fn test_function_definition_schema() {
        let def = BuiltinTools::web_search().to_function_definition();
        let params = &def["function"]["parameters"];
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["query"]["type"], "string");
        assert_eq!(
            params["properties"]["query"]["description"],
            "The search query"
        );
        assert_eq!(params["properties"]["max_results"]["type"], "number");
        assert_eq!(params["required"], serde_json::json!(["query"]));
    }

    #[test]
    fn test_registry() {
        let registry = ToolRegistry::with_builtins();
        assert!(registry.get("web_search").is_some());
        assert!(registry.get("code_execution").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    // Every builtin is reachable and nothing is lost or duplicated.
    #[test]
    fn test_registry_contains_all_builtins() {
        let registry = ToolRegistry::with_builtins();
        let builtins = BuiltinTools::all();
        assert_eq!(registry.list().len(), builtins.len());
        assert_eq!(registry.to_definitions().len(), builtins.len());
        for tool in &builtins {
            assert!(
                registry.get(&tool.name).is_some(),
                "missing builtin `{}`",
                tool.name
            );
        }
    }

    #[test]
    fn test_confirmation_flags() {
        assert!(BuiltinTools::code_execution().requires_confirmation);
        assert!(BuiltinTools::write_file().requires_confirmation);
        assert!(!BuiltinTools::read_file().requires_confirmation);
        assert!(!BuiltinTools::web_search().requires_confirmation);
    }

    #[test]
    fn test_category_serialization() {
        assert_eq!(
            serde_json::to_value(ToolCategory::Communication).unwrap(),
            "communication"
        );
        assert_eq!(ToolCategory::default(), ToolCategory::Custom);
    }

    // Optional fields fall back to their serde defaults; `type` maps to `param_type`.
    #[test]
    fn test_tool_deserialize_defaults() {
        let tool: Tool = serde_json::from_value(serde_json::json!({
            "name": "t",
            "description": "d",
            "parameters": [{"name": "p", "type": "string", "description": "x"}]
        }))
        .unwrap();
        assert_eq!(tool.category, ToolCategory::Custom);
        assert!(!tool.requires_confirmation);
        assert!(tool.metadata.is_empty());
        assert_eq!(tool.parameters.len(), 1);
        assert_eq!(tool.parameters[0].param_type, "string");
        assert!(!tool.parameters[0].required);
        assert!(tool.parameters[0].enum_values.is_none());
        assert!(tool.parameters[0].default.is_none());
    }
}