use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::tool_error::ToolError;
/// Static description of a tool: its name, a human-readable description, and
/// its parameter schema stored as raw JSON text.
///
/// Every field is a `&'static str`, which lets definitions be built in `const`
/// context through [`ToolDefinition::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name; `ToolRegistry` uses it as the registration key.
pub name: &'static str,
/// Description emitted as `description` in the provider formats.
pub description: &'static str,
/// Parameter schema as JSON text. The `to_*_format` methods parse it and
/// substitute `{}` when it is not valid JSON.
pub parameters: &'static str,
}
impl ToolDefinition {
pub const fn new(
name: &'static str,
description: &'static str,
parameters: &'static str,
) -> Self {
Self {
name,
description,
parameters,
}
}
pub fn to_openai_format(&self) -> serde_json::Value {
serde_json::json!({
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": serde_json::from_str::<serde_json::Value>(self.parameters)
.unwrap_or(serde_json::json!({}))
}
})
}
pub fn to_anthropic_format(&self) -> serde_json::Value {
serde_json::json!({
"name": self.name,
"description": self.description,
"input_schema": serde_json::from_str::<serde_json::Value>(self.parameters)
.unwrap_or(serde_json::json!({}))
})
}
}
/// Capability tags a tool reports through `Tool::capabilities`.
///
/// NOTE(review): the variant names suggest the kinds of resources a tool
/// touches, but nothing in this file enforces or interprets them. Confirm how
/// callers use each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Capability {
/// The value `Tool::capabilities` returns by default.
PureComputation,
Network,
FileSystem,
Subprocess,
Environment,
Cryptography,
}
/// A tool that can be described to a model and executed with JSON arguments.
///
/// Implementors must supply [`Tool::definition`] and [`Tool::execute`]. Every
/// other method has a default that can be overridden.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Static metadata for this tool (name, description, parameter schema).
    fn definition(&self) -> &ToolDefinition;

    /// Runs the tool on `args` and produces a JSON result.
    async fn execute(&self, args: Value) -> Result<Value, ToolError>;

    /// Argument check hook. The default accepts any input.
    fn validate(&self, _args: &Value) -> Result<(), ToolError> {
        Ok(())
    }

    /// Capabilities this tool declares. The default is only
    /// [`Capability::PureComputation`].
    fn capabilities(&self) -> Vec<Capability> {
        Vec::from([Capability::PureComputation])
    }

    /// Timeout this tool reports; enforcing it is up to the caller.
    /// The default is thirty seconds.
    fn timeout(&self) -> Duration {
        Duration::from_millis(30_000)
    }

    /// Whether the tool is currently usable. The default is always `true`.
    fn is_available(&self) -> bool {
        true
    }
}
/// Collection of tools keyed by the `name` in each tool's definition.
///
/// Tools are held as `Arc<dyn Tool>`, so lookups hand out shared handles
/// rather than copies of the tool.
#[derive(Default)]
pub struct ToolRegistry {
/// Tool name -> tool. `HashMap` iteration order is unspecified.
tools: HashMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, tool: Arc<dyn Tool>) -> bool {
let name = tool.definition().name.to_string();
if self.tools.contains_key(&name) {
tracing::warn!("Tool '{}' already registered, skipping duplicate", name);
return false;
}
self.tools.insert(name, tool);
true
}
pub fn register_replace(&mut self, tool: Arc<dyn Tool>) {
let name = tool.definition().name.to_string();
if self.tools.contains_key(&name) {
tracing::warn!("Replacing existing tool '{}'", name);
}
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools.get(name).cloned()
}
pub fn contains(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
pub fn remove(&mut self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools.remove(name)
}
pub fn names(&self) -> Vec<&str> {
self.tools.keys().map(|s| s.as_str()).collect()
}
pub fn definitions(&self) -> Vec<&ToolDefinition> {
self.tools.values().map(|t| t.definition()).collect()
}
pub fn to_openai_format(&self) -> Vec<serde_json::Value> {
self.tools
.values()
.map(|t| t.definition().to_openai_format())
.collect()
}
pub fn to_anthropic_format(&self) -> Vec<serde_json::Value> {
self.tools
.values()
.map(|t| t.definition().to_anthropic_format())
.collect()
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn available(&self) -> Vec<Arc<dyn Tool>> {
self.tools
.values()
.filter(|t| t.is_available())
.cloned()
.collect()
}
}
impl std::fmt::Debug for ToolRegistry {
    /// Shows the registry as the list of registered tool names.
    ///
    /// The names are sorted here so the output is stable across runs. The
    /// backing `HashMap` would otherwise yield them in arbitrary order, which
    /// makes logs and snapshot comparisons flaky.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names = self.names();
        names.sort_unstable();
        f.debug_struct("ToolRegistry").field("tools", &names).finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_definition() {
        const TEST_TOOL: ToolDefinition =
            ToolDefinition::new("test_tool", "A test tool", r#"{"type": "object"}"#);
        assert_eq!(TEST_TOOL.name, "test_tool");
        assert_eq!(TEST_TOOL.description, "A test tool");
        assert_eq!(TEST_TOOL.parameters, r#"{"type": "object"}"#);
    }

    #[test]
    fn test_openai_format() {
        let tool = ToolDefinition::new(
            "search",
            "Search the web",
            r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#,
        );
        let json = tool.to_openai_format();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "search");
        assert_eq!(json["function"]["description"], "Search the web");
        assert_eq!(json["function"]["parameters"]["type"], "object");
    }

    #[test]
    fn test_anthropic_format() {
        let tool = ToolDefinition::new(
            "search",
            "Search the web",
            r#"{"type": "object", "properties": {"query": {"type": "string"}}}"#,
        );
        let json = tool.to_anthropic_format();
        assert_eq!(json["name"], "search");
        assert!(json.get("input_schema").is_some());
        assert_eq!(json["input_schema"]["type"], "object");
    }

    // Malformed schema text must not panic; both formats fall back to `{}`.
    #[test]
    fn test_invalid_parameters_fall_back_to_empty_object() {
        let tool = ToolDefinition::new("broken", "Bad schema", "not json");
        let openai = tool.to_openai_format();
        let anthropic = tool.to_anthropic_format();
        assert_eq!(openai["function"]["parameters"], serde_json::json!({}));
        assert_eq!(anthropic["input_schema"], serde_json::json!({}));
    }

    /// Minimal `Tool` used to exercise the registry.
    struct MockTool {
        definition: ToolDefinition,
    }

    impl MockTool {
        fn new(name: &'static str) -> Self {
            Self {
                definition: ToolDefinition::new(name, "A mock tool", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: Value) -> Result<Value, ToolError> {
            Ok(serde_json::json!({"mock": true}))
        }
    }

    #[test]
    fn test_registry_basic() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());
        let tool = Arc::new(MockTool::new("mock"));
        assert!(registry.register(tool));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("mock"));
    }

    #[test]
    fn test_registry_duplicate_rejection() {
        let mut registry = ToolRegistry::new();
        assert!(registry.register(Arc::new(MockTool::new("dup"))));
        let duplicate = Arc::new(MockTool::new("dup"));
        assert!(!registry.register(duplicate));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_and_remove() {
        let mut registry = ToolRegistry::new();
        registry.register_replace(Arc::new(MockTool::new("swap")));
        registry.register_replace(Arc::new(MockTool::new("swap")));
        assert_eq!(registry.len(), 1);
        assert!(registry.remove("swap").is_some());
        assert!(registry.is_empty());
        assert!(registry.remove("swap").is_none());
    }

    #[test]
    fn test_registry_lookup() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool::new("finder")));
        assert!(registry.get("finder").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_listings() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool::new("b")));
        registry.register(Arc::new(MockTool::new("a")));
        // Sort locally so this test does not depend on the registry's iteration order.
        let mut names = registry.names();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);
        assert_eq!(registry.definitions().len(), 2);
        assert_eq!(registry.to_openai_format().len(), 2);
        assert_eq!(registry.to_anthropic_format().len(), 2);
        assert_eq!(registry.available().len(), 2);
    }

    #[test]
    fn test_capability_enum() {
        let caps = vec![Capability::PureComputation, Capability::Network];
        assert!(caps.contains(&Capability::PureComputation));
        assert!(!caps.contains(&Capability::FileSystem));
    }
}