Skip to main content

agent_sdk/tools/
command_tools.rs

1use std::path::PathBuf;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use crate::error::{SdkError, SdkResult};
7use crate::traits::tool::{Tool, ToolDefinition};
8
/// Tool that lets an agent run shell commands inside a fixed working directory.
///
/// Commands are executed with `sh -c` from `work_dir`. When `allowed_commands`
/// is non-empty, only command lines whose first word exactly matches one of its
/// entries are executed; an empty list places no restriction on the command.
pub struct RunCommandTool {
    /// Directory used as the current working directory of every command.
    pub work_dir: PathBuf,
    /// Executable names accepted as the first word of a command line
    /// (empty = any command is accepted).
    pub allowed_commands: Vec<String>,
}
13
14impl RunCommandTool {
15    pub fn with_defaults(work_dir: PathBuf) -> Self {
16        Self {
17            work_dir,
18            allowed_commands: vec![],
19        }
20    }
21
22    pub fn with_commands(work_dir: PathBuf, allowed: Vec<String>) -> Self {
23        Self {
24            work_dir,
25            allowed_commands: allowed,
26        }
27    }
28}
29
30#[async_trait]
31impl Tool for RunCommandTool {
32    fn definition(&self) -> ToolDefinition {
33        let desc = if self.allowed_commands.is_empty() {
34            "Execute any shell command in the working directory.".to_string()
35        } else {
36            format!(
37                "Execute a shell command in the working directory. Allowed commands: {}",
38                self.allowed_commands.join(", ")
39            )
40        };
41        ToolDefinition {
42            name: "run_command".to_string(),
43            description: desc,
44            parameters: json!({
45                "type": "object",
46                "properties": {
47                    "command": { "type": "string", "description": "The command to execute" },
48                    "timeout_secs": { "type": "integer", "description": "Timeout in seconds (default: 30)" }
49                },
50                "required": ["command"]
51            }),
52        }
53    }
54
55    async fn execute(&self, arguments: serde_json::Value) -> SdkResult<serde_json::Value> {
56        let command = arguments["command"]
57            .as_str()
58            .ok_or_else(|| SdkError::ToolExecution {
59                tool_name: "run_command".to_string(),
60                message: "Missing 'command' argument".to_string(),
61            })?;
62
63        let timeout_secs = arguments["timeout_secs"].as_u64().unwrap_or(30);
64
65        // Check whitelist (empty = allow all)
66        if !self.allowed_commands.is_empty() {
67            let executable = command.split_whitespace().next().unwrap_or("");
68            if !self.allowed_commands.iter().any(|c| c == executable) {
69                return Ok(json!({
70                    "error": format!(
71                        "Command '{}' is not allowed. Allowed: {}",
72                        executable,
73                        self.allowed_commands.join(", ")
74                    )
75                }));
76            }
77        }
78
79        let result = tokio::time::timeout(
80            std::time::Duration::from_secs(timeout_secs),
81            tokio::process::Command::new("sh")
82                .arg("-c")
83                .arg(command)
84                .current_dir(&self.work_dir)
85                .output(),
86        )
87        .await;
88
89        match result {
90            Ok(Ok(output)) => {
91                let stdout = String::from_utf8_lossy(&output.stdout);
92                let stderr = String::from_utf8_lossy(&output.stderr);
93                let exit_code = output.status.code().unwrap_or(-1);
94
95                let max_len = 4000;
96                let stdout_truncated = if stdout.len() > max_len {
97                    format!("{}... (truncated, {} total bytes)", &stdout[..max_len], stdout.len())
98                } else {
99                    stdout.to_string()
100                };
101                let stderr_truncated = if stderr.len() > max_len {
102                    format!("{}... (truncated, {} total bytes)", &stderr[..max_len], stderr.len())
103                } else {
104                    stderr.to_string()
105                };
106
107                Ok(json!({
108                    "exit_code": exit_code,
109                    "stdout": stdout_truncated,
110                    "stderr": stderr_truncated
111                }))
112            }
113            Ok(Err(e)) => Ok(json!({ "error": format!("Failed to execute command: {}", e) })),
114            Err(_) => Ok(json!({ "error": format!("Command timed out after {}s", timeout_secs) })),
115        }
116    }
117}