use async_trait::async_trait;
use serde_json::{json, Value};
use std::path::PathBuf;
use std::process::Stdio;
use tokio::process::Command;
use crate::error::{AgnoError, Result};
use crate::tool::{Tool, ToolRegistry};
#[derive(Clone)]
pub struct ShellConfig {
pub base_dir: Option<PathBuf>,
pub max_output_lines: usize,
pub timeout_secs: u64,
pub blocked_commands: Vec<String>,
}
impl Default for ShellConfig {
fn default() -> Self {
Self {
base_dir: None,
max_output_lines: 100,
timeout_secs: 30,
blocked_commands: vec![
"rm -rf /".into(),
"rm -rf /*".into(),
"mkfs".into(),
"dd if=".into(),
":(){:|:&};:".into(), ],
}
}
}
pub fn shell_toolkit(config: ShellConfig) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(RunShellCommandTool { config });
registry
}
struct RunShellCommandTool {
config: ShellConfig,
}
#[async_trait]
impl Tool for RunShellCommandTool {
fn name(&self) -> &str {
"run_shell_command"
}
fn description(&self) -> &str {
"Execute a shell command. Expects {\"command\": string} or {\"args\": [string array]}. Returns stdout or error."
}
async fn call(&self, input: Value) -> Result<Value> {
let (program, args): (String, Vec<String>) = if let Some(cmd) = input.get("command").and_then(Value::as_str) {
#[cfg(unix)]
{
("sh".into(), vec!["-c".into(), cmd.into()])
}
#[cfg(windows)]
{
("cmd".into(), vec!["/C".into(), cmd.into()])
}
} else if let Some(args) = input.get("args").and_then(Value::as_array) {
let args: Vec<String> = args
.iter()
.filter_map(Value::as_str)
.map(String::from)
.collect();
if args.is_empty() {
return Err(AgnoError::Protocol(
"empty `args` for run_shell_command".into(),
));
}
let program = args[0].clone();
let remaining = args.into_iter().skip(1).collect();
(program, remaining)
} else {
return Err(AgnoError::Protocol(
"missing `command` or `args` for run_shell_command".into(),
));
};
let full_command = format!("{} {}", program, args.join(" "));
for blocked in &self.config.blocked_commands {
if full_command.contains(blocked) {
return Ok(json!({
"error": format!("Command blocked for safety: contains '{}'", blocked),
"exit_code": -1
}));
}
}
let mut cmd = Command::new(&program);
cmd.args(&args)
.stdout(Stdio::piped())
.stderr(Stdio::piped());
if let Some(ref base_dir) = self.config.base_dir {
cmd.current_dir(base_dir);
}
let output = tokio::time::timeout(
std::time::Duration::from_secs(self.config.timeout_secs),
cmd.output(),
)
.await
.map_err(|_| AgnoError::ToolInvocation {
name: "run_shell_command".into(),
source: "Command timed out".into(),
})?
.map_err(|e| AgnoError::ToolInvocation {
name: "run_shell_command".into(),
source: Box::new(e),
})?;
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let stdout_lines: Vec<&str> = stdout.lines().collect();
let truncated_stdout: String = if stdout_lines.len() > self.config.max_output_lines {
stdout_lines[stdout_lines.len() - self.config.max_output_lines..]
.join("\n")
} else {
stdout.to_string()
};
if output.status.success() {
Ok(json!({
"stdout": truncated_stdout,
"exit_code": output.status.code().unwrap_or(0)
}))
} else {
Ok(json!({
"error": stderr.to_string(),
"stdout": truncated_stdout,
"exit_code": output.status.code().unwrap_or(-1)
}))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_echo_command() {
let config = ShellConfig::default();
let registry = shell_toolkit(config);
let shell = registry.get("run_shell_command").unwrap();
let result = shell
.call(json!({"command": "echo hello"}))
.await
.unwrap();
assert!(result["stdout"].as_str().unwrap().contains("hello"));
assert_eq!(result["exit_code"], 0);
}
#[tokio::test]
async fn test_blocked_command() {
let config = ShellConfig::default();
let registry = shell_toolkit(config);
let shell = registry.get("run_shell_command").unwrap();
let result = shell
.call(json!({"command": "rm -rf /"}))
.await
.unwrap();
assert!(result["error"].as_str().unwrap().contains("blocked"));
}
}