use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A tool definition, as returned by [`ToolRegistryTrait::get_tools`].
///
/// The `tool_type` field is (de)serialized under the JSON key `"type"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Tool {
    /// Kind of tool; appears as `"type"` on the wire.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Description of the function this tool exposes.
    pub function: Function,
}
/// Name, human-readable description and argument schema of a tool's function.
///
/// `name` is the key [`ToolRegistry`] uses to look the implementation up.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Function {
    /// Unique function name.
    pub name: String,
    /// Free-text description of what the function does.
    pub description: String,
    /// Schema of the accepted arguments.
    pub parameters: Parameters,
}
/// Schema describing a function's arguments.
///
/// `param_type` is (de)serialized under the JSON key `"type"`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Parameters {
    /// Schema type, serialized as `"type"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Argument properties keyed by argument name.
    ///
    /// Defaults to empty when the key is absent, so schemas for argument-less
    /// functions (e.g. `{"type": "object"}`) still deserialize.
    #[serde(default)]
    pub properties: HashMap<String, Property>,
    /// Names of required arguments; defaults to empty when the key is absent.
    #[serde(default)]
    pub required: Vec<String>,
}
/// Schema for a single argument.
///
/// Recursive through `items`, which describes the elements of an array-typed
/// property. Optional fields are omitted from the serialized output when `None`
/// instead of being written as `null`; JSON Schema requires `description` to be a
/// string and `items` to be a schema, so `null` is invalid there.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Property {
    /// Schema type, serialized as `"type"`.
    #[serde(rename = "type")]
    pub prop_type: String,
    /// Optional free-text description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional element schema; boxed because the type is recursive.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<Property>>,
}
/// A single tool invocation carried in a [`Message`].
///
/// `call_type` is (de)serialized under the JSON key `"type"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of this call; presumably echoed back in
    /// [`ToolResultMessage::tool_call_id`] — verify against callers.
    pub id: String,
    /// Kind of call, serialized as `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function to call and with what arguments.
    pub function: FunctionCall,
}
/// Function name plus its arguments as raw text.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Unparsed argument text. Kept as a string; presumably handed to
    /// [`ToolRegistryTrait::execute_tool`], which parses it as JSON.
    pub arguments: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
pub role: String,
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}
/// A message reporting the result of one tool call.
///
/// Derives `Clone` like the other message types so results can be stored in
/// conversation history and copied freely.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolResultMessage {
    /// Sender role string.
    pub role: String,
    /// Identifier of the [`ToolCall`] this result answers. The JSON key is the
    /// field name itself, so no `rename` is needed.
    pub tool_call_id: String,
    /// Name of the tool that produced the result.
    pub name: String,
    /// Result payload as text.
    pub content: String,
}
/// Top-level response body containing one or more [`Choice`]s.
///
/// Derives `Clone` for consistency with [`Choice`] and the other wire types.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct LLMResponse {
    /// Candidate completions.
    pub choices: Vec<Choice>,
}
/// One candidate completion within an [`LLMResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// The message produced for this choice.
    pub message: Message,
    /// Finish reason reported for this choice, if any.
    pub finish_reason: Option<String>,
}
/// A source of tool definitions that can also execute those tools by name.
///
/// The `Send + Sync` supertraits allow `&dyn ToolRegistryTrait` to be shared
/// between threads.
pub trait ToolRegistryTrait: Send + Sync {
    /// Returns the tool definitions this registry exposes.
    fn get_tools(&self) -> Vec<Tool>;

    /// Runs the tool called `name` with the raw `arguments` text.
    ///
    /// # Errors
    /// Returns a human-readable message when the tool is unknown, the arguments
    /// are unusable, or the tool itself fails.
    fn execute_tool(&self, name: &str, arguments: &str) -> Result<Value, String>;
}
/// Type-erased tool implementation: takes parsed JSON arguments and returns a JSON
/// result or an error message. Thread-safe so registries can be shared.
pub type ToolFunction = dyn Fn(Value) -> Result<Value, String> + Sync + Send;
/// In-process tool registry backed by Rust closures.
///
/// Both maps are keyed by the tool's function name; `register_tool` writes to
/// both so a definition and its implementation always travel together.
pub struct ToolRegistry {
    // Tool definitions, as reported by `get_tools`.
    tools: HashMap<String, Tool>,
    // Implementations invoked by `execute_tool`.
    functions: HashMap<String, Box<ToolFunction>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
            functions: HashMap::new(),
        }
    }

    /// Registers `tool` together with the closure that implements it.
    ///
    /// Both are keyed by `tool.function.name`. Registering a tool whose name is
    /// already present silently replaces the previous definition and implementation.
    pub fn register_tool<F>(&mut self, tool: Tool, function: F)
    where
        F: Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
    {
        let name = tool.function.name.clone();
        self.tools.insert(name.clone(), tool);
        self.functions.insert(name, Box::new(function));
    }
}

/// An empty registry, equivalent to [`ToolRegistry::new`].
impl Default for ToolRegistry {
    fn default() -> Self {
        Self::new()
    }
}
impl ToolRegistryTrait for ToolRegistry {
fn get_tools(&self) -> Vec<Tool> {
self.tools.values().cloned().collect()
}
fn execute_tool(&self, name: &str, arguments: &str) -> Result<Value, String> {
let args: Value = serde_json::from_str(arguments)
.map_err(|e| format!("Failed to parse arguments: {}", e))?;
if let Some(function) = self.functions.get(name) {
function(args)
} else {
Err(format!("Tool '{}' not found", name))
}
}
}
/// A borrowed pair of registries presented as a single registry.
///
/// `primary` is consulted before `secondary`.
pub struct CombinedToolRegistry<'a> {
    // Higher-priority registry.
    primary: &'a dyn ToolRegistryTrait,
    // Lower-priority registry.
    secondary: &'a dyn ToolRegistryTrait,
}
impl<'a> CombinedToolRegistry<'a> {
pub fn new(primary: &'a dyn ToolRegistryTrait, secondary: &'a dyn ToolRegistryTrait) -> Self {
Self { primary, secondary }
}
}
impl ToolRegistryTrait for CombinedToolRegistry<'_> {
    /// Returns the secondary's tools followed by the primary's.
    ///
    /// A secondary tool whose function name also appears in the primary is dropped,
    /// so the primary's definition shadows it.
    fn get_tools(&self) -> Vec<Tool> {
        let primary_tools = self.primary.get_tools();
        let mut tools = self.secondary.get_tools();
        tools.retain(|tool| {
            !primary_tools
                .iter()
                .any(|t| t.function.name == tool.function.name)
        });
        tools.extend(primary_tools);
        tools
    }

    /// Executes `name` on whichever registry owns it, primary first.
    ///
    /// Routing is by ownership (does the primary list a tool with this name?), not by
    /// success. The primary's result, including its error, is therefore returned
    /// as-is. A failing primary tool does not fall through to a same-named secondary
    /// tool; that tool is hidden by `get_tools` and must not run, nor run a second time.
    ///
    /// Note: this clones the primary's tool list on each call to check ownership.
    fn execute_tool(&self, name: &str, arguments: &str) -> Result<Value, String> {
        let owned_by_primary = self
            .primary
            .get_tools()
            .iter()
            .any(|tool| tool.function.name == name);
        if owned_by_primary {
            self.primary.execute_tool(name, arguments)
        } else {
            self.secondary.execute_tool(name, arguments)
        }
    }
}