Skip to main content

agentzero_tools/
shell.rs

1use crate::shell_parse::{self, AnnotatedChar, QuoteContext};
2use agentzero_core::{Tool, ToolContext, ToolResult};
3use anyhow::{anyhow, Context};
4use async_trait::async_trait;
5use std::process::Stdio;
6use tokio::io::{AsyncRead, AsyncReadExt};
7use tokio::process::Command;
8
/// Default cap on the number of arguments (the command name is not counted).
const DEFAULT_MAX_SHELL_ARGS: usize = 32;
/// Default cap on the length, in bytes, of any single argument.
const DEFAULT_MAX_ARG_LENGTH: usize = 4096;
/// Default cap on captured bytes per output stream; stdout and stderr are
/// each limited to this amount independently.
const DEFAULT_MAX_OUTPUT_BYTES: usize = 65536;
/// Default legacy (quote-unaware) set of characters rejected anywhere in an
/// argument. Only consulted when no context-aware policy is configured.
const DEFAULT_FORBIDDEN_CHARS: &str = ";&|><$`\n\r";
13
/// Context-aware shell command policy.
///
/// Replaces the flat `forbidden_chars` with structured classification that
/// respects quoting context: every character of an argument is checked
/// against `always_forbidden`, and additionally against `forbidden_unquoted`
/// when the tokenizer reports it as unquoted.
#[derive(Debug, Clone)]
pub struct ShellCommandPolicy {
    /// Characters that are ALWAYS forbidden, even inside quotes.
    pub always_forbidden: Vec<char>,
    /// Characters forbidden only when they appear unquoted.
    pub forbidden_unquoted: Vec<char>,
}
25
26impl Default for ShellCommandPolicy {
27    fn default() -> Self {
28        Self {
29            always_forbidden: vec!['`', '\0'],
30            forbidden_unquoted: vec![';', '&', '|', '>', '<', '$', '\n', '\r'],
31        }
32    }
33}
34
35impl ShellCommandPolicy {
36    /// Build from the legacy flat `forbidden_chars` string.
37    pub fn from_legacy_forbidden_chars(chars: &str) -> Self {
38        let always: Vec<char> = chars.chars().filter(|c| *c == '`' || *c == '\0').collect();
39        let unquoted: Vec<char> = chars.chars().filter(|c| *c != '`' && *c != '\0').collect();
40        Self {
41            always_forbidden: always,
42            forbidden_unquoted: unquoted,
43        }
44    }
45
46    /// Validate annotated characters from a single token.
47    pub fn validate_token(&self, chars: &[AnnotatedChar]) -> anyhow::Result<()> {
48        for ac in chars {
49            if self.always_forbidden.contains(&ac.ch) {
50                anyhow::bail!(
51                    "shell argument contains always-forbidden character: {:?}",
52                    ac.ch
53                );
54            }
55            if ac.context == QuoteContext::Unquoted && self.forbidden_unquoted.contains(&ac.ch) {
56                anyhow::bail!(
57                    "shell argument contains unquoted forbidden metacharacter: {:?}",
58                    ac.ch
59                );
60            }
61        }
62        Ok(())
63    }
64}
65
/// Limits and validation rules applied to every command run by the shell tool.
#[derive(Debug, Clone)]
pub struct ShellPolicy {
    /// Command names that may be executed; matched by exact string equality.
    pub allowed_commands: Vec<String>,
    /// Maximum number of arguments following the command name.
    pub max_args: usize,
    /// Maximum length, in bytes, of each individual argument.
    pub max_arg_length: usize,
    /// Maximum number of bytes captured from stdout and from stderr (each).
    pub max_output_bytes: usize,
    /// Legacy flat set of characters rejected anywhere in an argument.
    /// Only consulted when `command_policy` is `None`.
    pub forbidden_chars: String,
    /// Context-aware policy. When `Some`, uses quote-aware validation.
    /// When `None`, falls back to legacy flat `forbidden_chars` check.
    pub command_policy: Option<ShellCommandPolicy>,
}
77
78impl ShellPolicy {
79    pub fn default_with_commands(allowed_commands: Vec<String>) -> Self {
80        Self {
81            allowed_commands,
82            max_args: DEFAULT_MAX_SHELL_ARGS,
83            max_arg_length: DEFAULT_MAX_ARG_LENGTH,
84            max_output_bytes: DEFAULT_MAX_OUTPUT_BYTES,
85            forbidden_chars: DEFAULT_FORBIDDEN_CHARS.to_string(),
86            command_policy: Some(ShellCommandPolicy::default()),
87        }
88    }
89}
90
/// Tool that runs allowlisted commands as direct child processes, subject
/// to the limits and character rules of its `ShellPolicy`.
pub struct ShellTool {
    /// Policy consulted on every invocation.
    policy: ShellPolicy,
}
94
95impl ShellTool {
96    pub fn new(policy: ShellPolicy) -> Self {
97        Self { policy }
98    }
99
100    /// Parse and validate a shell command input using context-aware or legacy mode.
101    fn parse_and_validate(
102        policy: &ShellPolicy,
103        input: &str,
104    ) -> anyhow::Result<(String, Vec<String>)> {
105        if policy.command_policy.is_some() {
106            Self::parse_context_aware(policy, input)
107        } else {
108            Self::parse_legacy(policy, input)
109        }
110    }
111
112    /// Context-aware parsing: uses quote-aware tokenizer and structured policy.
113    fn parse_context_aware(
114        policy: &ShellPolicy,
115        input: &str,
116    ) -> anyhow::Result<(String, Vec<String>)> {
117        let tokens = shell_parse::tokenize(input)?;
118        let annotated = shell_parse::tokenize_annotated(input)?;
119
120        if tokens.is_empty() {
121            return Err(anyhow!("command is required"));
122        }
123
124        let command_name = tokens[0].text.clone();
125        let args: Vec<String> = tokens[1..].iter().map(|t| t.text.clone()).collect();
126
127        if args.len() > policy.max_args {
128            return Err(anyhow!("too many shell arguments"));
129        }
130
131        // SAFETY: parse_context_aware is only called when command_policy.is_some()
132        let cmd_policy = policy
133            .command_policy
134            .as_ref()
135            .expect("command_policy must be Some in context-aware mode");
136        for (i, token) in tokens.iter().enumerate().skip(1) {
137            if token.text.is_empty() {
138                return Err(anyhow!("empty shell argument is not allowed"));
139            }
140            if token.text.len() > policy.max_arg_length {
141                return Err(anyhow!("shell argument exceeds max length"));
142            }
143            cmd_policy.validate_token(&annotated[i])?;
144        }
145
146        Ok((command_name, args))
147    }
148
149    /// Legacy parsing: flat whitespace split and flat forbidden_chars check.
150    fn parse_legacy(policy: &ShellPolicy, input: &str) -> anyhow::Result<(String, Vec<String>)> {
151        let mut parts = input.split_whitespace();
152        let command_name = parts
153            .next()
154            .ok_or_else(|| anyhow!("command is required"))?
155            .to_string();
156        let args: Vec<String> = parts.map(ToString::to_string).collect();
157
158        if args.len() > policy.max_args {
159            return Err(anyhow!("too many shell arguments"));
160        }
161        for arg in &args {
162            if arg.is_empty() {
163                return Err(anyhow!("empty shell argument is not allowed"));
164            }
165            if arg.len() > policy.max_arg_length {
166                return Err(anyhow!("shell argument exceeds max length"));
167            }
168            if arg.chars().any(|c| policy.forbidden_chars.contains(c)) {
169                return Err(anyhow!(
170                    "shell argument contains forbidden shell metacharacters"
171                ));
172            }
173        }
174
175        Ok((command_name, args))
176    }
177
178    async fn read_limited<R>(mut reader: R, max_bytes: usize) -> anyhow::Result<(Vec<u8>, bool)>
179    where
180        R: AsyncRead + Unpin,
181    {
182        let mut bytes = Vec::new();
183        let mut limited = (&mut reader).take((max_bytes + 1) as u64);
184        limited
185            .read_to_end(&mut bytes)
186            .await
187            .context("failed to capture command output")?;
188
189        let truncated = bytes.len() > max_bytes;
190        if truncated {
191            bytes.truncate(max_bytes);
192        }
193
194        Ok((bytes, truncated))
195    }
196
197    fn render_stream(name: &str, bytes: &[u8], truncated: bool, max_bytes: usize) -> String {
198        let mut out = format!("{name}:\n{}", String::from_utf8_lossy(bytes));
199        if truncated {
200            out.push_str(&format!("\n<truncated at {max_bytes} bytes>"));
201        }
202        out
203    }
204}
205
206#[async_trait]
207impl Tool for ShellTool {
208    fn name(&self) -> &'static str {
209        "shell"
210    }
211
212    fn description(&self) -> &'static str {
213        "Execute a shell command from the allowlist. Input is the full command line. Returns stdout, stderr, and exit code."
214    }
215
216    fn input_schema(&self) -> Option<serde_json::Value> {
217        Some(serde_json::json!({
218            "type": "object",
219            "properties": {
220                "command": {
221                    "type": "string",
222                    "description": "The shell command to execute (e.g. \"ls -la\", \"cargo build\")"
223                }
224            },
225            "required": ["command"]
226        }))
227    }
228
229    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
230        let (command_name, args) = Self::parse_and_validate(&self.policy, input)?;
231        if !self
232            .policy
233            .allowed_commands
234            .iter()
235            .any(|c| c == &command_name)
236        {
237            return Err(anyhow!("command is not in allowlist"));
238        }
239
240        let mut child = Command::new(&command_name)
241            .args(&args)
242            .stdout(Stdio::piped())
243            .stderr(Stdio::piped())
244            .spawn()
245            .context("shell command failed to execute")?;
246
247        let stdout_reader = child
248            .stdout
249            .take()
250            .ok_or_else(|| anyhow!("shell command did not provide stdout pipe"))?;
251        let stderr_reader = child
252            .stderr
253            .take()
254            .ok_or_else(|| anyhow!("shell command did not provide stderr pipe"))?;
255
256        let stdout_task = tokio::spawn(Self::read_limited(
257            stdout_reader,
258            self.policy.max_output_bytes,
259        ));
260        let stderr_task = tokio::spawn(Self::read_limited(
261            stderr_reader,
262            self.policy.max_output_bytes,
263        ));
264
265        let status = child.wait().await.context("shell command failed to run")?;
266        let (stdout, stdout_truncated) = stdout_task
267            .await
268            .context("failed joining stdout capture task")??;
269        let (stderr, stderr_truncated) = stderr_task
270            .await
271            .context("failed joining stderr capture task")??;
272
273        Ok(ToolResult {
274            output: format!(
275                "status={}\n{}\n{}",
276                status,
277                Self::render_stream(
278                    "stdout",
279                    &stdout,
280                    stdout_truncated,
281                    self.policy.max_output_bytes
282                ),
283                Self::render_stream(
284                    "stderr",
285                    &stderr,
286                    stderr_truncated,
287                    self.policy.max_output_bytes
288                )
289            ),
290        })
291    }
292}
293
// Integration-style tests: each one spawns a real `echo` process through
// `ShellTool::execute`, so an `echo` executable must be available on PATH.
#[cfg(test)]
mod tests {
    use super::*;
    use agentzero_core::{Tool, ToolContext};

    /// Tool with the default policy and only `echo` allowlisted.
    fn echo_tool() -> ShellTool {
        ShellTool::new(ShellPolicy::default_with_commands(vec!["echo".to_string()]))
    }

    /// Tool context constructed from the string ".".
    fn ctx() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// An allowlisted command runs and its stdout is reported.
    #[tokio::test]
    async fn shell_allows_allowlisted_command() {
        let result = echo_tool()
            .execute("echo hello", &ctx())
            .await
            .expect("shell should succeed");
        assert!(result.output.contains("stdout:\nhello"));
    }

    /// An unquoted `;` inside an argument is rejected by the context-aware policy.
    #[tokio::test]
    async fn shell_rejects_unquoted_metacharacters() {
        let result = echo_tool().execute("echo hello;uname", &ctx()).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("unquoted forbidden metacharacter"));
    }

    /// A command missing from the allowlist is refused.
    #[tokio::test]
    async fn shell_rejects_non_allowlisted_command() {
        let result = echo_tool().execute("pwd", &ctx()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("command is not in allowlist"));
    }

    /// Output longer than `max_output_bytes` is cut and marked as truncated.
    #[tokio::test]
    async fn shell_truncates_stdout_to_policy_limit() {
        let mut policy = ShellPolicy::default_with_commands(vec!["echo".to_string()]);
        policy.max_output_bytes = 8;
        let tool = ShellTool::new(policy);
        let result = tool
            .execute("echo 1234567890", &ctx())
            .await
            .expect("shell should succeed");
        assert!(result.output.contains("stdout:\n12345678"));
        assert!(result.output.contains("<truncated at 8 bytes>"));
    }

    // B7: Context-aware policy tests

    /// `;` inside single quotes is accepted and passed through literally.
    #[tokio::test]
    async fn policy_allows_single_quoted_semicolon() {
        let result = echo_tool()
            .execute("echo 'hello;world'", &ctx())
            .await
            .expect("quoted semicolon should be allowed");
        assert!(result.output.contains("hello;world"));
    }

    /// `;` inside double quotes is accepted and passed through literally.
    #[tokio::test]
    async fn policy_allows_double_quoted_semicolon() {
        let result = echo_tool()
            .execute(r#"echo "hello;world""#, &ctx())
            .await
            .expect("quoted semicolon should be allowed");
        assert!(result.output.contains("hello;world"));
    }

    /// A backtick is rejected even when it appears inside single quotes.
    #[tokio::test]
    async fn policy_blocks_backtick_always() {
        let result = echo_tool().execute("echo '`uname`'", &ctx()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("always-forbidden"));
    }

    /// An unquoted `$` is rejected.
    #[tokio::test]
    async fn policy_blocks_unquoted_dollar() {
        let result = echo_tool().execute("echo $HOME", &ctx()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unquoted forbidden metacharacter"));
    }

    /// `$` inside single quotes is accepted and is not expanded.
    #[tokio::test]
    async fn policy_allows_dollar_in_single_quotes() {
        let result = echo_tool()
            .execute("echo '$HOME'", &ctx())
            .await
            .expect("dollar in single quotes should be allowed");
        assert!(result.output.contains("$HOME"));
    }

    /// With `command_policy` unset, the flat `forbidden_chars` check applies.
    #[tokio::test]
    async fn legacy_mode_flat_check() {
        let mut policy = ShellPolicy::default_with_commands(vec!["echo".to_string()]);
        policy.command_policy = None; // disable context-aware
        let tool = ShellTool::new(policy);
        let result = tool.execute("echo hello;uname", &ctx()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("forbidden shell metacharacters"));
    }

    /// A quoted argument containing a space is kept as one argument.
    #[tokio::test]
    async fn shell_quoted_argument_with_spaces() {
        let result = echo_tool()
            .execute("echo 'hello world'", &ctx())
            .await
            .expect("quoted spaces should work");
        assert!(result.output.contains("hello world"));
    }
}