hehe-tools 0.0.1

Tool system and built-in tools for the hehe AI Agent framework
Documentation
use crate::error::{Result, ToolError};
use crate::registry::ToolRegistry;
use crate::traits::ToolOutput;
use hehe_core::{Context, ToolCall, ToolCallStatus};
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use tracing::{info, warn};

/// Runs tools resolved by name from a shared [`ToolRegistry`], bounding each
/// execution with a timeout.
pub struct ToolExecutor {
    /// Registry that tool names are looked up in.
    registry: Arc<ToolRegistry>,
    /// Upper bound applied to every tool execution.
    default_timeout: Duration,
    /// When `true`, tools that report themselves as dangerous are flagged by
    /// `needs_confirmation`.
    require_confirmation_for_dangerous: bool,
}

impl ToolExecutor {
    /// Creates an executor backed by `registry`.
    ///
    /// Defaults: a 60-second execution timeout, and confirmation required for
    /// tools that report themselves as dangerous.
    #[must_use]
    pub fn new(registry: Arc<ToolRegistry>) -> Self {
        Self {
            registry,
            default_timeout: Duration::from_secs(60),
            require_confirmation_for_dangerous: true,
        }
    }

    /// Replaces the default execution timeout and returns the updated executor.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.default_timeout = timeout;
        self
    }

    /// Turns off the confirmation requirement for dangerous tools and returns
    /// the updated executor.
    #[must_use]
    pub fn allow_dangerous_without_confirmation(mut self) -> Self {
        self.require_confirmation_for_dangerous = false;
        self
    }

    /// Looks up the tool called `name`, validates `input`, and runs the tool
    /// under a timeout.
    ///
    /// The timeout is `ctx.remaining()` when the context provides one, capped
    /// at the executor's default timeout; otherwise it is the default timeout.
    ///
    /// This method does not consult the confirmation setting; callers can
    /// check [`ToolExecutor::needs_confirmation`] themselves before executing.
    ///
    /// # Errors
    ///
    /// * `ToolError::not_found(name)` when no tool is registered under `name`.
    /// * `ToolError::Cancelled` when `ctx` is already cancelled before the
    ///   tool starts.
    /// * Whatever `validate_input` or the tool's own `execute` returns.
    /// * `ToolError::Timeout(ms)` when the tool does not finish in time; `ms`
    ///   is the applied timeout in milliseconds.
    pub async fn execute(
        &self,
        ctx: &Context,
        name: &str,
        input: Value,
    ) -> Result<ToolOutput> {
        let tool = self
            .registry
            .get(name)
            .ok_or_else(|| ToolError::not_found(name))?;

        if ctx.is_cancelled() {
            return Err(ToolError::Cancelled);
        }

        tool.validate_input(&input)?;

        info!(tool = name, "Executing tool");

        // The executor default acts as a ceiling: the context may shorten the
        // time budget but never extend it.
        let execute_timeout = ctx
            .remaining()
            .unwrap_or(self.default_timeout)
            .min(self.default_timeout);

        let result = timeout(execute_timeout, tool.execute(ctx, input)).await;

        match result {
            Ok(Ok(output)) => {
                info!(tool = name, is_error = output.is_error, "Tool execution completed");
                Ok(output)
            }
            Ok(Err(e)) => {
                warn!(tool = name, error = %e, "Tool execution failed");
                Err(e)
            }
            Err(_) => {
                // `as_millis` yields a u128. Clamp before narrowing so an
                // oversized duration saturates at u64::MAX instead of wrapping.
                let timeout_ms = execute_timeout.as_millis().min(u128::from(u64::MAX)) as u64;
                warn!(tool = name, timeout_ms = timeout_ms, "Tool execution timed out");
                Err(ToolError::Timeout(timeout_ms))
            }
        }
    }

    /// Executes `call` and records the outcome on it.
    ///
    /// The call is marked started, then run through [`ToolExecutor::execute`]
    /// with a clone of its input. Afterwards the call is:
    ///
    /// * failed with the output content when the tool returned an output
    ///   flagged `is_error`;
    /// * completed with the output content converted to a JSON value
    ///   (`Value::Null` if that conversion fails) on success;
    /// * failed with the error's string form when execution returned an error.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ToolExecutor::execute`] unchanged.
    pub async fn execute_call(&self, ctx: &Context, call: &mut ToolCall) -> Result<ToolOutput> {
        call.start();

        match self.execute(ctx, &call.name, call.input.clone()).await {
            Ok(output) => {
                if output.is_error {
                    call.fail(&output.content);
                } else {
                    call.complete(serde_json::to_value(&output.content).unwrap_or(Value::Null));
                }
                Ok(output)
            }
            Err(e) => {
                call.fail(e.to_string());
                Err(e)
            }
        }
    }

    /// Returns the registry this executor resolves tools from.
    pub fn registry(&self) -> &ToolRegistry {
        &self.registry
    }

    /// Returns whether the tool registered as `name` reports itself as
    /// dangerous. Unregistered names yield `false`.
    pub fn is_dangerous(&self, name: &str) -> bool {
        match self.registry.get(name) {
            Some(tool) => tool.is_dangerous(),
            None => false,
        }
    }

    /// Returns `true` when confirmation is required for dangerous tools and
    /// the tool registered as `name` is dangerous.
    pub fn needs_confirmation(&self, name: &str) -> bool {
        self.require_confirmation_for_dangerous && self.is_dangerous(name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Tool;
    use async_trait::async_trait;
    use hehe_core::ToolDefinition;

    /// Test tool that returns its JSON input rendered as text.
    struct EchoTool {
        def: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            let def = ToolDefinition::new("echo", "Echoes input");
            Self { def }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
            let rendered = input.to_string();
            Ok(ToolOutput::text(rendered))
        }
    }

    /// Test tool that sleeps for ten seconds before answering, so a short
    /// executor timeout fires first.
    struct SlowTool {
        def: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            let def = ToolDefinition::new("slow", "A slow tool");
            Self { def }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, _input: Value) -> Result<ToolOutput> {
            let nap = Duration::from_secs(10);
            tokio::time::sleep(nap).await;
            Ok(ToolOutput::text("done"))
        }
    }

    /// Wraps a populated registry in an executor with default settings.
    fn executor_for(registry: ToolRegistry) -> ToolExecutor {
        ToolExecutor::new(Arc::new(registry))
    }

    // A registered tool runs and its output is returned to the caller.
    #[tokio::test]
    async fn test_executor_execute() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(EchoTool::new())).unwrap();
        let executor = executor_for(tools);
        let ctx = Context::new();

        let payload = serde_json::json!({"message": "hello"});
        let output = executor.execute(&ctx, "echo", payload).await.unwrap();

        assert!(output.content.contains("hello"));
    }

    // Asking for a name nobody registered yields `NotFound`.
    #[tokio::test]
    async fn test_executor_not_found() {
        let executor = executor_for(ToolRegistry::new());
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "nonexistent", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    // A tool that outlives the configured timeout yields `Timeout`.
    #[tokio::test]
    async fn test_executor_timeout() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(SlowTool::new())).unwrap();
        let executor = executor_for(tools).with_timeout(Duration::from_millis(100));
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "slow", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::Timeout(_))));
    }

    // `execute_call` moves a pending call to completed on success.
    #[tokio::test]
    async fn test_executor_execute_call() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(EchoTool::new())).unwrap();
        let executor = executor_for(tools);
        let ctx = Context::new();

        let mut call = ToolCall::new("echo", serde_json::json!({"x": 1}));
        assert!(call.is_pending());

        let output = executor.execute_call(&ctx, &mut call).await.unwrap();

        assert!(call.is_completed());
        assert!(!output.is_error);
    }
}