hehe-tools 0.0.1

Tool system and built-in tools for hehe AI Agent framework
Documentation
use crate::error::{Result, ToolError};
use crate::traits::Tool;
use hehe_core::ToolDefinition;
use std::collections::HashMap;
use std::sync::Arc;

pub struct ToolRegistry {
    tools: HashMap<String, Arc<dyn Tool>>,
}

impl ToolRegistry {
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<()> {
        let name = tool.name().to_string();
        if self.tools.contains_key(&name) {
            return Err(ToolError::AlreadyRegistered(name));
        }
        self.tools.insert(name, tool);
        Ok(())
    }

    pub fn register_boxed(&mut self, tool: Box<dyn Tool>) -> Result<()> {
        self.register(Arc::from(tool))
    }

    pub fn unregister(&mut self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.remove(name)
    }

    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
        self.tools.get(name)
    }

    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    pub fn list(&self) -> Vec<&str> {
        self.tools.keys().map(|s| s.as_str()).collect()
    }

    pub fn definitions(&self) -> Vec<ToolDefinition> {
        self.tools
            .values()
            .map(|t| t.definition().clone())
            .collect()
    }

    pub fn len(&self) -> usize {
        self.tools.len()
    }

    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    pub fn dangerous_tools(&self) -> Vec<&str> {
        self.tools
            .iter()
            .filter(|(_, t)| t.is_dangerous())
            .map(|(name, _)| name.as_str())
            .collect()
    }

    pub fn safe_tools(&self) -> Vec<&str> {
        self.tools
            .iter()
            .filter(|(_, t)| !t.is_dangerous())
            .map(|(name, _)| name.as_str())
            .collect()
    }
}

impl Default for ToolRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::ToolOutput;
    use async_trait::async_trait;
    use hehe_core::{Context, ToolDefinition, ToolParameter};
    use serde_json::Value;

    struct MockTool {
        def: ToolDefinition,
    }

    impl MockTool {
        fn new(name: &str, dangerous: bool) -> Self {
            let mut def = ToolDefinition::new(name, format!("{} tool", name));
            if dangerous {
                def = def.dangerous();
            }
            Self { def }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, _input: Value) -> Result<ToolOutput> {
            Ok(ToolOutput::text("mock output"))
        }
    }

    #[test]
    fn test_registry_register() {
        let mut registry = ToolRegistry::new();
        
        let tool = Arc::new(MockTool::new("test_tool", false));
        registry.register(tool).unwrap();

        assert!(registry.contains("test_tool"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_duplicate() {
        let mut registry = ToolRegistry::new();
        
        registry.register(Arc::new(MockTool::new("dup", false))).unwrap();
        let result = registry.register(Arc::new(MockTool::new("dup", false)));
        
        assert!(result.is_err());
    }

    #[test]
    fn test_registry_list() {
        let mut registry = ToolRegistry::new();
        
        registry.register(Arc::new(MockTool::new("tool_a", false))).unwrap();
        registry.register(Arc::new(MockTool::new("tool_b", true))).unwrap();

        let list = registry.list();
        assert_eq!(list.len(), 2);
        assert!(list.contains(&"tool_a"));
        assert!(list.contains(&"tool_b"));
    }

    #[test]
    fn test_registry_dangerous_tools() {
        let mut registry = ToolRegistry::new();
        
        registry.register(Arc::new(MockTool::new("safe", false))).unwrap();
        registry.register(Arc::new(MockTool::new("danger", true))).unwrap();

        let dangerous = registry.dangerous_tools();
        assert_eq!(dangerous.len(), 1);
        assert!(dangerous.contains(&"danger"));

        let safe = registry.safe_tools();
        assert_eq!(safe.len(), 1);
        assert!(safe.contains(&"safe"));
    }

    #[test]
    fn test_registry_unregister() {
        let mut registry = ToolRegistry::new();
        
        registry.register(Arc::new(MockTool::new("removable", false))).unwrap();
        assert!(registry.contains("removable"));

        let removed = registry.unregister("removable");
        assert!(removed.is_some());
        assert!(!registry.contains("removable"));
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        
        registry.register(Arc::new(MockTool::new("tool1", false))).unwrap();
        registry.register(Arc::new(MockTool::new("tool2", false))).unwrap();

        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);
    }
}