hefa_core/tools/mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use thiserror::Error;
7
8/// Result of running a tool function.
9#[derive(Debug, Clone, PartialEq)]
10pub struct ToolResult {
11    pub content: Value,
12}
13
14/// Errors tool implementations can return.
15#[derive(Debug, Error, PartialEq)]
16pub enum ToolError {
17    #[error("tool `{name}` not found")]
18    NotFound { name: String },
19    #[error("invalid input: {0}")]
20    InvalidInput(String),
21    #[error("execution error: {0}")]
22    Execution(String),
23}
24
25#[async_trait]
26pub trait Tool: Send + Sync {
27    fn name(&self) -> &'static str;
28    fn json_schema(&self) -> Value;
29    async fn call(&self, args: Value) -> Result<ToolResult, ToolError>;
30}
31
32/// Registry mapping tool names to implementations.
33#[derive(Default)]
34pub struct ToolRegistry {
35    tools: HashMap<String, Arc<dyn Tool>>,
36}
37
38impl ToolRegistry {
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    pub fn register<T>(&mut self, tool: T)
44    where
45        T: Tool + 'static,
46    {
47        self.tools.insert(tool.name().to_string(), Arc::new(tool));
48    }
49
50    pub fn register_boxed(&mut self, tool: Box<dyn Tool>) {
51        let name = tool.name().to_string();
52        let arc: Arc<dyn Tool> = tool.into();
53        self.tools.insert(name, arc);
54    }
55
56    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
57        self.tools.get(name).cloned()
58    }
59
60    pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolResult, ToolError> {
61        match self.get(name) {
62            Some(tool) => tool.call(args).await,
63            None => Err(ToolError::NotFound {
64                name: name.to_string(),
65            }),
66        }
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that echoes back its `message` argument.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn json_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {"type": "string"}
                },
                "required": ["message"]
            })
        }

        async fn call(&self, args: Value) -> Result<ToolResult, ToolError> {
            let msg = args
                .get("message")
                .and_then(Value::as_str)
                .ok_or_else(|| ToolError::InvalidInput("missing message".into()))?;
            Ok(ToolResult {
                content: serde_json::json!({ "echo": msg }),
            })
        }
    }

    #[tokio::test]
    async fn registry_registers_and_invokes_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        let args = serde_json::json!({ "message": "hello" });
        let result = registry.invoke("echo", args).await.expect("tool result");
        assert_eq!(result.content, serde_json::json!({ "echo": "hello" }));
    }

    #[tokio::test]
    async fn registry_returns_not_found() {
        let registry = ToolRegistry::new();
        let err = registry
            .invoke("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        // Check the carried name too, not just the variant.
        assert_eq!(
            err,
            ToolError::NotFound {
                name: "missing".to_string()
            }
        );
    }

    #[tokio::test]
    async fn registry_propagates_tool_errors() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        let err = registry
            .invoke("echo", serde_json::json!({}))
            .await
            .unwrap_err();
        assert_eq!(err, ToolError::InvalidInput("missing message".into()));
    }

    #[tokio::test]
    async fn registry_accepts_boxed_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_boxed(Box::new(EchoTool));

        let tool = registry.get("echo").expect("boxed tool should be registered");
        assert_eq!(tool.name(), "echo");
        assert_eq!(
            tool.json_schema()["required"],
            serde_json::json!(["message"])
        );

        let result = registry
            .invoke("echo", serde_json::json!({ "message": "boxed" }))
            .await
            .expect("tool result");
        assert_eq!(result.content, serde_json::json!({ "echo": "boxed" }));
    }

    #[test]
    fn get_returns_none_for_unknown_tool() {
        let registry = ToolRegistry::new();
        assert!(registry.get("echo").is_none());
    }

    #[test]
    fn not_found_error_message_names_the_tool() {
        let err = ToolError::NotFound {
            name: "missing".into(),
        };
        assert_eq!(err.to_string(), "tool `missing` not found");
    }
}