use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{Error, Result};
/// Outcome of running a single tool, as produced by [`ToolRegistry::execute`].
///
/// Prefer building values with [`ToolResult::success`] / [`ToolResult::failure`].
/// The fields are public, so the pairing of `success` and `error` is upheld only by
/// those constructors, not by the type itself.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool this result belongs to.
    pub name: String,
    /// Value returned by the tool; the constructors set `Value::Null` on failure.
    pub value: Value,
    /// `true` when the tool completed without returning an error.
    pub success: bool,
    /// Error message for a failed call; the constructors leave it `None` on success.
    pub error: Option<String>,
}
impl ToolResult {
    /// Creates a successful result for tool `name` carrying `value`.
    ///
    /// `success` is `true` and `error` is `None`.
    pub fn success(name: impl Into<String>, value: Value) -> Self {
        let name: String = name.into();
        Self {
            success: true,
            error: None,
            name,
            value,
        }
    }

    /// Creates a failed result for tool `name` with the given error message.
    ///
    /// `success` is `false` and `value` is `Value::Null`.
    pub fn failure(name: impl Into<String>, error: impl Into<String>) -> Self {
        let name: String = name.into();
        let message: String = error.into();
        Self {
            success: false,
            value: Value::Null,
            error: Some(message),
            name,
        }
    }
}
/// A named, asynchronously executable operation that takes JSON arguments.
///
/// The `Send + Sync` supertraits allow a tool to be stored as `Arc<dyn Tool>`
/// and shared across threads.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Unique name of the tool; [`ToolRegistry::register`] uses it as the lookup key.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// JSON value describing the accepted arguments
    /// (presumably a JSON Schema document — confirm with implementors).
    fn parameters_schema(&self) -> Value;

    /// Runs the tool with `args` and returns its JSON output.
    ///
    /// # Errors
    /// Implementation-defined. [`ToolRegistry::execute`] converts any error into
    /// a failed [`ToolResult`] instead of propagating it.
    async fn execute(&self, args: Value) -> Result<Value>;

    /// Bundles the name, description and parameter schema into a [`ToolDefinition`].
    fn definition(&self) -> ToolDefinition {
        let name = self.name().to_owned();
        let description = self.description().to_owned();
        let parameters = self.parameters_schema();
        ToolDefinition {
            name,
            description,
            parameters,
        }
    }
}
/// Serializable description of a tool, as returned by [`Tool::definition`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool (from [`Tool::name`]).
    pub name: String,
    /// Human-readable description (from [`Tool::description`]).
    pub description: String,
    /// Argument description (from [`Tool::parameters_schema`]).
    pub parameters: Value,
}
/// Collection of tools indexed by name.
///
/// Tools are held as `Arc<dyn Tool>`, so [`ToolRegistry::get`] hands out shared
/// handles without cloning the tool itself. `Default` yields an empty registry.
#[derive(Default)]
pub struct ToolRegistry {
    /// Registered tools, keyed by the value `Tool::name` returned at registration.
    tools: HashMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `tool` under the name returned by [`Tool::name`].
    ///
    /// If a tool with the same name is already registered, it is silently replaced.
    pub fn register(&mut self, tool: impl Tool + 'static) {
        let name = tool.name().to_string();
        self.tools.insert(name, Arc::new(tool));
    }

    /// Returns a shared handle to the tool registered as `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(name).cloned()
    }

    /// Returns `true` if a tool is registered as `name`.
    pub fn has(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Returns the names of all registered tools, sorted in ascending order.
    ///
    /// Sorting makes the result independent of `HashMap` iteration order,
    /// which is randomized per process.
    pub fn list(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Returns the [`ToolDefinition`] of every registered tool, sorted by registered name.
    ///
    /// A deterministic order keeps repeated calls and runs byte-for-byte comparable.
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        let mut entries: Vec<_> = self.tools.iter().collect();
        // Keys are unique, so an unstable sort still gives a fully deterministic order.
        entries.sort_unstable_by_key(|&(name, _)| name);
        entries
            .into_iter()
            .map(|(_, tool)| tool.definition())
            .collect()
    }

    /// Runs the tool registered as `name` with `args`.
    ///
    /// A tool that runs but returns an error does NOT make this method fail.
    /// Instead it yields `Ok` with a [`ToolResult::failure`] carrying the error's
    /// `Display` text. A successful run yields `Ok` with [`ToolResult::success`].
    ///
    /// # Errors
    /// Returns `Err` (built with `Error::tool`) only when no tool is registered as `name`.
    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
        match self.get(name) {
            Some(tool) => match tool.execute(args).await {
                Ok(value) => Ok(ToolResult::success(name, value)),
                Err(e) => Ok(ToolResult::failure(name, e.to_string())),
            },
            None => Err(Error::tool(format!("Tool not found: {}", name))),
        }
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}
/// Debug output lists only the registered tool names, since `dyn Tool` is not `Debug`.
impl std::fmt::Debug for ToolRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is randomized per process; sort so the output is stable.
        let mut names = self.list();
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tools", &names)
            .finish()
    }
}
/// Adapts a closure returning a future into a [`Tool`].
///
/// `F` receives the JSON arguments and returns a future `Fut` that resolves to
/// the tool's output.
pub struct FunctionTool<F, Fut>
// `Fut` is not stored in any field; it is tied to `F` only through this
// where-clause, so the bounds stay on the struct definition itself.
where
    F: Fn(Value) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send,
{
    /// Closure invoked by `Tool::execute`.
    func: F,
    /// Name reported by `Tool::name`.
    name: String,
    /// Description reported by `Tool::description`.
    description: String,
    /// Parameter description, cloned out by `Tool::parameters_schema`.
    parameters: Value,
}
impl<F, Fut> FunctionTool<F, Fut>
where
    F: Fn(Value) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send,
{
    /// Wraps `func` as a tool with the given `name`, `description` and `parameters`.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: Value,
        func: F,
    ) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            func,
            parameters,
            description,
            name,
        }
    }
}
#[async_trait]
impl<F, Fut> Tool for FunctionTool<F, Fut>
where
F: Fn(Value) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<Value>> + Send,
{
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Value {
self.parameters.clone()
}
async fn execute(&self, args: Value) -> Result<Value> {
(self.func)(args).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise the registry's synchronous API without an async runtime.
    struct Echo;

    #[async_trait::async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Returns its arguments unchanged"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, args: serde_json::Value) -> crate::error::Result<serde_json::Value> {
            Ok(args)
        }
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test", serde_json::json!("hello"));
        assert!(result.success);
        assert_eq!(result.name, "test");
        assert_eq!(result.value, serde_json::json!("hello"));
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test", "something went wrong");
        assert!(!result.success);
        assert_eq!(result.name, "test");
        assert!(result.value.is_null());
        assert_eq!(result.error, Some("something went wrong".to_string()));
    }

    #[test]
    fn test_tool_registry() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_tool_registry_register_and_lookup() {
        let mut registry = ToolRegistry::new();
        registry.register(Echo);

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.has("echo"));
        assert!(!registry.has("missing"));
        assert!(registry.get("echo").is_some());
        assert!(registry.get("missing").is_none());
        assert_eq!(registry.list(), vec!["echo"]);

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
        assert_eq!(defs[0].description, "Returns its arguments unchanged");
    }

    #[test]
    fn test_tool_registry_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(Echo);
        registry.register(Echo);
        assert_eq!(registry.len(), 1);
    }
}