hehe-tools 0.0.1

Tool system and built-in tools for the hehe AI Agent framework
Documentation
use crate::error::{Result, ToolError};
use crate::traits::{Tool, ToolOutput};
use async_trait::async_trait;
use hehe_core::{Context, ToolDefinition, ToolParameter};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::process::Command;
use tokio::time::timeout;

/// Tool that runs a command line through the platform shell and reports its
/// exit status, stdout and stderr.
pub struct ExecuteShellTool {
    /// Definition handed out by `Tool::definition`.
    def: ToolDefinition,
    /// Timeout used when the input does not carry its own `timeout_ms`.
    default_timeout: Duration,
}

impl ExecuteShellTool {
    /// Creates the tool with its `execute_shell` definition and a 60-second
    /// fallback timeout.
    ///
    /// The definition declares one required parameter (`command`) and three
    /// optional ones (`working_dir`, `timeout_ms`, `env`), and is flagged via
    /// `dangerous()`.
    pub fn new() -> Self {
        let command_param =
            ToolParameter::string().with_description("The shell command to execute");
        let working_dir_param =
            ToolParameter::string().with_description("Working directory for the command");
        let timeout_param = ToolParameter::integer()
            .with_description("Timeout in milliseconds (default: 60000)")
            .with_default(Value::Number(60000.into()));
        let env_param =
            ToolParameter::object().with_description("Environment variables to set");

        let def = ToolDefinition::new("execute_shell", "Execute a shell command")
            .with_required_param("command", command_param)
            .with_param("working_dir", working_dir_param)
            .with_param("timeout_ms", timeout_param)
            .with_param("env", env_param)
            .dangerous();

        Self {
            default_timeout: Duration::from_secs(60),
            def,
        }
    }

    /// Replaces the fallback timeout used when an input omits `timeout_ms`.
    ///
    /// Only the runtime fallback changes; the default advertised in the
    /// `timeout_ms` parameter of the definition is left as built by `new`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            default_timeout: timeout,
            ..self
        }
    }
}

impl Default for ExecuteShellTool {
    fn default() -> Self {
        Self::new()
    }
}

/// Arguments of one `execute_shell` call, deserialized from the JSON input.
#[derive(Deserialize)]
struct ExecuteShellInput {
    /// Command line handed to the shell (`sh -c` or `cmd /C`).
    command: String,
    /// Directory to run the command in; when absent no working directory is set
    /// on the command.
    working_dir: Option<String>,
    /// Per-call timeout in milliseconds; falls back to the tool's default when absent.
    timeout_ms: Option<u64>,
    /// Extra environment variables set on the child process.
    env: Option<HashMap<String, String>>,
}

/// Result of a finished command, serialized as the tool's JSON content.
#[derive(Serialize, Deserialize)]
struct ShellOutput {
    /// Exit code taken from `ExitStatus::code()`; `None` when the platform
    /// reports none (e.g. the process was terminated by a signal on Unix).
    exit_code: Option<i32>,
    /// Everything the command wrote to standard output.
    stdout: String,
    /// Everything the command wrote to standard error.
    stderr: String,
    /// Value of `ExitStatus::success()` for the command.
    success: bool,
}

#[async_trait]
impl Tool for ExecuteShellTool {
    fn definition(&self) -> &ToolDefinition {
        &self.def
    }

    async fn execute(&self, ctx: &Context, input: Value) -> Result<ToolOutput> {
        let input: ExecuteShellInput = serde_json::from_value(input)?;

        if ctx.is_cancelled() {
            return Err(ToolError::Cancelled);
        }

        let shell = if cfg!(target_os = "windows") {
            ("cmd", "/C")
        } else {
            ("sh", "-c")
        };

        let mut cmd = Command::new(shell.0);
        cmd.arg(shell.1).arg(&input.command);

        if let Some(dir) = &input.working_dir {
            cmd.current_dir(dir);
        }

        if let Some(env) = &input.env {
            for (key, value) in env {
                cmd.env(key, value);
            }
        }

        cmd.stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .stdin(Stdio::null());

        let timeout_duration = input
            .timeout_ms
            .map(Duration::from_millis)
            .unwrap_or(self.default_timeout);

        let result = timeout(timeout_duration, async {
            let mut child = cmd.spawn()?;

            let mut stdout = String::new();
            let mut stderr = String::new();

            if let Some(mut stdout_handle) = child.stdout.take() {
                stdout_handle.read_to_string(&mut stdout).await?;
            }
            if let Some(mut stderr_handle) = child.stderr.take() {
                stderr_handle.read_to_string(&mut stderr).await?;
            }

            let status = child.wait().await?;

            Ok::<_, std::io::Error>((status, stdout, stderr))
        })
        .await;

        match result {
            Ok(Ok((status, stdout, stderr))) => {
                let exit_code = status.code();
                let output = ShellOutput {
                    exit_code,
                    stdout,
                    stderr,
                    success: status.success(),
                };

                if status.success() {
                    ToolOutput::json(&output)?
                        .with_metadata("command", &input.command)
                        .with_metadata("exit_code", exit_code.unwrap_or(-1));
                    Ok(ToolOutput::json(&output)?)
                } else {
                    Ok(ToolOutput::json(&output)?.with_metadata("is_error", true))
                }
            }
            Ok(Err(e)) => Ok(ToolOutput::error(format!("Failed to execute command: {}", e))),
            Err(_) => Ok(ToolOutput::error(format!(
                "Command timed out after {}ms",
                timeout_duration.as_millis()
            ))),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial command succeeds and its stdout is captured.
    #[tokio::test]
    async fn test_execute_shell_echo() {
        let tool = ExecuteShellTool::new();
        let ctx = Context::new();
        let input = serde_json::json!({
            "command": "echo hello"
        });

        let output = tool.execute(&ctx, input).await.unwrap();
        assert!(!output.is_error);

        let result: ShellOutput = serde_json::from_str(&output.content).unwrap();
        assert!(result.success);
        assert!(result.stdout.contains("hello"));
    }

    /// Variables passed via `env` are visible to the command.
    #[tokio::test]
    async fn test_execute_shell_with_env() {
        let tool = ExecuteShellTool::new();
        let ctx = Context::new();
        let input = serde_json::json!({
            "command": if cfg!(target_os = "windows") { "echo %TEST_VAR%" } else { "echo $TEST_VAR" },
            "env": {
                "TEST_VAR": "test_value"
            }
        });

        let output = tool.execute(&ctx, input).await.unwrap();
        let result: ShellOutput = serde_json::from_str(&output.content).unwrap();
        assert!(result.stdout.contains("test_value"));
    }

    /// A non-zero exit status is reported through `success` and `exit_code`.
    #[tokio::test]
    async fn test_execute_shell_failure() {
        let tool = ExecuteShellTool::new();
        let ctx = Context::new();
        let input = serde_json::json!({
            "command": "exit 1"
        });

        let output = tool.execute(&ctx, input).await.unwrap();
        let result: ShellOutput = serde_json::from_str(&output.content).unwrap();
        assert!(!result.success);
        assert_eq!(result.exit_code, Some(1));
    }

    /// A command that outlives `timeout_ms` yields a "timed out" error output.
    #[tokio::test]
    async fn test_execute_shell_timeout() {
        let tool = ExecuteShellTool::new();
        let ctx = Context::new();
        // `sleep` is not a `cmd` command, so on Windows use a ping loop that
        // runs for several seconds instead.
        let long_running = if cfg!(target_os = "windows") {
            "ping -n 10 127.0.0.1 > NUL"
        } else {
            "sleep 10"
        };
        let input = serde_json::json!({
            "command": long_running,
            "timeout_ms": 100
        });

        let output = tool.execute(&ctx, input).await.unwrap();
        assert!(output.content.contains("timed out"));
    }
}