use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::types::{Layer2Error, Layer2Result, ToolResult};
/// Descriptive metadata for a tool: its name, human-readable description,
/// parameter description, and required parameter names.
///
/// NOTE(review): not referenced elsewhere in this file; `parameters` is
/// presumably a JSON-Schema-like object — confirm against producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolMeta {
    /// Tool name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    pub parameters: serde_json::Value,
    /// Names of parameters the caller must supply (presumably keys of
    /// `parameters` — confirm).
    pub required: Vec<String>,
}
/// A request to invoke a tool by name with JSON arguments.
///
/// NOTE(review): not used in this file; presumably constructed from a
/// model-issued tool call — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolRequest {
    /// Identifier of this particular call (presumably echoed back with the
    /// result — confirm).
    pub tool_call_id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the call as a JSON value.
    pub arguments: serde_json::Value,
}
/// A tool definition in `{ "type": ..., "function": { ... } }` form.
///
/// The raw identifier `r#type` is serialized by serde under the key `type`.
/// `ToolRegistry::definitions` in this file always sets it to `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Definition kind (e.g. `"function"`).
    pub r#type: String,
    /// The function signature being described.
    pub function: FunctionDefinition,
}
/// Name, description, and parameter description of a callable function.
///
/// In this file it is filled from a [`Tool`]'s `name()`, `description()`,
/// and `parameters()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function (tool) name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// JSON describing the accepted parameters.
    pub parameters: serde_json::Value,
}
/// A named, asynchronously executable tool.
///
/// Implementors must be `Send + Sync` so they can be shared across threads
/// (e.g. stored as `Arc<dyn Tool>` in [`ToolRegistry`]).
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name of the tool; the registry uses it as the lookup key.
    fn name(&self) -> &str;

    /// Human-readable description of the tool.
    fn description(&self) -> &str;

    /// JSON description of the tool's parameters (presumably a JSON Schema
    /// object — confirm with implementors).
    fn parameters(&self) -> serde_json::Value;

    /// Runs the tool with its arguments passed as a raw string (presumably
    /// JSON-encoded — confirm with callers).
    async fn execute(&self, args: &str) -> Layer2Result<ToolResult>;

    /// Checks already-parsed arguments before execution.
    ///
    /// The default implementation accepts everything and returns `Ok(true)`.
    fn validate_args(&self, args: &serde_json::Value) -> Layer2Result<bool> {
        let _ = args;
        Ok(true)
    }
}
/// A collection of [`Tool`]s addressable by name.
///
/// All methods take `&self`, so implementations are expected to use interior
/// mutability for `register`/`unregister`.
#[async_trait]
pub trait ToolRegistryTrait: Send + Sync {
    /// Adds `tool` to the registry.
    fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()>;

    /// Removes the tool registered as `name`; the `bool` tells whether one
    /// was present.
    fn unregister(&self, name: &str) -> Layer2Result<bool>;

    /// Returns a shared handle to the tool registered as `name`, if any.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;

    /// Returns `true` if a tool is registered as `name`.
    fn exists(&self, name: &str) -> bool;

    /// Returns the names of all registered tools.
    fn list(&self) -> Vec<String>;

    /// Returns a [`ToolDefinition`] for every registered tool.
    fn definitions(&self) -> Vec<ToolDefinition>;

    /// Runs the tool registered as `name` with the raw `args` string.
    async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult>;

    /// Returns how many tools are registered.
    fn count(&self) -> usize;
}
/// Thread-safe registry of [`Tool`]s keyed by name.
///
/// Entries are stored as `Arc<dyn Tool>` behind a `parking_lot::RwLock`, so
/// lookups hand out cheap shared handles that outlive the lock guard.
pub struct ToolRegistry {
    /// Registered tools, keyed by the name the tool reported when registered.
    tools: parking_lot::RwLock<HashMap<String, Arc<dyn Tool>>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: parking_lot::RwLock::new(HashMap::new()),
}
}
pub fn with_builtin_tools() -> Self {
Self::new()
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolRegistryTrait for ToolRegistry {
    /// Registers `tool` under the name returned by [`Tool::name`].
    ///
    /// A tool already registered under the same name is silently replaced.
    /// Always returns `Ok(())`.
    fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()> {
        // Ask for the name before taking the write lock so no tool code runs
        // while the lock is held.
        let name = tool.name().to_string();
        // Bind any replaced entry so the write guard (a temporary) is released
        // before the old tool handle is dropped.
        let _previous = self.tools.write().insert(name, Arc::from(tool));
        Ok(())
    }

    /// Removes the tool registered under `name`.
    ///
    /// Returns `Ok(true)` if a tool was removed and `Ok(false)` if none was
    /// registered under that name.
    fn unregister(&self, name: &str) -> Layer2Result<bool> {
        // Binding the removed entry releases the write guard (a temporary)
        // before the tool handle is dropped.
        let removed = self.tools.write().remove(name);
        Ok(removed.is_some())
    }

    /// Returns a shared handle to the tool registered under `name`, if any.
    ///
    /// The read lock is released before returning, so callers may keep the
    /// handle (e.g. across an `.await`) without blocking registry writers.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        let tools = self.tools.read();
        tools.get(name).cloned()
    }

    /// Returns `true` if a tool is registered under `name`.
    fn exists(&self, name: &str) -> bool {
        let tools = self.tools.read();
        tools.contains_key(name)
    }

    /// Returns the names of all registered tools, sorted ascending.
    ///
    /// Sorting makes the result deterministic; `HashMap` iteration order is not.
    fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tools.read().keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns a `"function"`-typed [`ToolDefinition`] for every registered
    /// tool, ordered by registry name so the output is deterministic.
    fn definitions(&self) -> Vec<ToolDefinition> {
        // Snapshot the entries and release the read lock before calling back
        // into tool code. parking_lot's RwLock is task-fair, so a recursive
        // read from inside `description()`/`parameters()` could deadlock if a
        // writer is waiting.
        let mut entries: Vec<(String, Arc<dyn Tool>)> = self
            .tools
            .read()
            .iter()
            .map(|(name, tool)| (name.clone(), Arc::clone(tool)))
            .collect();
        // Keys are unique, so an unstable sort is fully deterministic.
        entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        entries
            .into_iter()
            .map(|(_, tool)| ToolDefinition {
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    name: tool.name().to_string(),
                    description: tool.description().to_string(),
                    parameters: tool.parameters(),
                },
            })
            .collect()
    }

    /// Runs the tool registered under `name` with the raw `args` string.
    ///
    /// # Errors
    ///
    /// Returns `Layer2Error::ToolNotFound(name)` if no tool is registered under
    /// `name`. Otherwise it returns whatever the tool's own `execute` returns.
    async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult> {
        // `get` clones the `Arc` and drops the read guard, so no lock is held
        // across the `.await` below.
        let tool = self
            .get(name)
            .ok_or_else(|| Layer2Error::ToolNotFound(name.to_string()))?;
        tool.execute(args).await
    }

    /// Returns the number of registered tools.
    fn count(&self) -> usize {
        let tools = self.tools.read();
        tools.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used only to exercise registry bookkeeping; its `execute`
    /// is never called by these tests.
    struct StubTool {
        name: &'static str,
    }

    #[async_trait]
    impl Tool for StubTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            "stub tool for registry tests"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _args: &str) -> Layer2Result<ToolResult> {
            unimplemented!("registry tests never execute tools")
        }
    }

    #[test]
    fn test_tool_registry_creation() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_tool_registry_list() {
        let registry = ToolRegistry::new();
        let list = registry.list();
        assert!(list.is_empty());
    }

    #[test]
    fn test_register_get_and_unregister() {
        let registry = ToolRegistry::new();
        assert!(registry.register(Box::new(StubTool { name: "alpha" })).is_ok());

        assert!(registry.exists("alpha"));
        assert_eq!(registry.count(), 1);
        assert_eq!(
            registry.get("alpha").map(|tool| tool.name().to_string()),
            Some("alpha".to_string())
        );
        assert!(registry.get("missing").is_none());

        assert_eq!(registry.unregister("alpha").ok(), Some(true));
        assert_eq!(registry.unregister("alpha").ok(), Some(false));
        assert!(!registry.exists("alpha"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let registry = ToolRegistry::new();
        assert!(registry.register(Box::new(StubTool { name: "alpha" })).is_ok());
        assert!(registry.register(Box::new(StubTool { name: "alpha" })).is_ok());
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_definitions_describe_registered_tools() {
        let registry = ToolRegistry::new();
        assert!(registry.register(Box::new(StubTool { name: "beta" })).is_ok());
        assert!(registry.register(Box::new(StubTool { name: "alpha" })).is_ok());

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
        for definition in &definitions {
            assert_eq!(definition.r#type, "function");
            assert_eq!(definition.function.description, "stub tool for registry tests");
        }

        // Order-independent check of the names.
        let mut names: Vec<String> = definitions
            .into_iter()
            .map(|definition| definition.function.name)
            .collect();
        names.sort();
        assert_eq!(names, vec!["alpha", "beta"]);
    }
}