1use crate::shell_parse::{self, AnnotatedChar, QuoteContext};
2use agentzero_core::{Tool, ToolContext, ToolResult};
3use anyhow::{anyhow, Context};
4use async_trait::async_trait;
5use std::process::Stdio;
6use tokio::io::{AsyncRead, AsyncReadExt};
7use tokio::process::Command;
8
/// Default for `ShellPolicy::max_args`: the maximum number of arguments that may follow the
/// command name.
const DEFAULT_MAX_SHELL_ARGS: usize = 32;
/// Default for `ShellPolicy::max_arg_length`: the maximum length of a single argument, in bytes.
const DEFAULT_MAX_ARG_LENGTH: usize = 4096;
/// Default for `ShellPolicy::max_output_bytes`: the capture limit applied to stdout and to
/// stderr, each on its own.
const DEFAULT_MAX_OUTPUT_BYTES: usize = 65536;
/// Default for `ShellPolicy::forbidden_chars`: characters the legacy (non-quote-aware) validator
/// rejects anywhere inside an argument.
const DEFAULT_FORBIDDEN_CHARS: &str = ";&|><$`\n\r";
13
/// Quote-aware character policy for shell arguments.
///
/// `ShellCommandPolicy::validate_token` checks every character of an argument against the two
/// lists below, taking the character's quote context into account.
#[derive(Debug, Clone)]
pub struct ShellCommandPolicy {
    /// Characters that are rejected wherever they appear, regardless of quote context.
    pub always_forbidden: Vec<char>,
    /// Characters that are rejected only when they appear in `QuoteContext::Unquoted` context.
    pub forbidden_unquoted: Vec<char>,
}
25
26impl Default for ShellCommandPolicy {
27 fn default() -> Self {
28 Self {
29 always_forbidden: vec!['`', '\0'],
30 forbidden_unquoted: vec![';', '&', '|', '>', '<', '$', '\n', '\r'],
31 }
32 }
33}
34
35impl ShellCommandPolicy {
36 pub fn from_legacy_forbidden_chars(chars: &str) -> Self {
38 let always: Vec<char> = chars.chars().filter(|c| *c == '`' || *c == '\0').collect();
39 let unquoted: Vec<char> = chars.chars().filter(|c| *c != '`' && *c != '\0').collect();
40 Self {
41 always_forbidden: always,
42 forbidden_unquoted: unquoted,
43 }
44 }
45
46 pub fn validate_token(&self, chars: &[AnnotatedChar]) -> anyhow::Result<()> {
48 for ac in chars {
49 if self.always_forbidden.contains(&ac.ch) {
50 anyhow::bail!(
51 "shell argument contains always-forbidden character: {:?}",
52 ac.ch
53 );
54 }
55 if ac.context == QuoteContext::Unquoted && self.forbidden_unquoted.contains(&ac.ch) {
56 anyhow::bail!(
57 "shell argument contains unquoted forbidden metacharacter: {:?}",
58 ac.ch
59 );
60 }
61 }
62 Ok(())
63 }
64}
65
/// Limits and validation settings enforced by `ShellTool` on every command line.
#[derive(Debug, Clone)]
pub struct ShellPolicy {
    /// Command names that may be executed; the parsed command name must equal one of these.
    pub allowed_commands: Vec<String>,
    /// Maximum number of arguments that may follow the command name.
    pub max_args: usize,
    /// Maximum length of a single argument, in bytes.
    pub max_arg_length: usize,
    /// Maximum number of bytes captured from stdout and from stderr (each stream separately).
    pub max_output_bytes: usize,
    /// Characters rejected anywhere in an argument by the legacy validator
    /// (used when `command_policy` is `None`).
    pub forbidden_chars: String,
    /// Quote-aware policy. `Some` selects context-aware parsing; `None` selects the legacy
    /// whitespace-splitting validator driven by `forbidden_chars`.
    pub command_policy: Option<ShellCommandPolicy>,
}
77
78impl ShellPolicy {
79 pub fn default_with_commands(allowed_commands: Vec<String>) -> Self {
80 Self {
81 allowed_commands,
82 max_args: DEFAULT_MAX_SHELL_ARGS,
83 max_arg_length: DEFAULT_MAX_ARG_LENGTH,
84 max_output_bytes: DEFAULT_MAX_OUTPUT_BYTES,
85 forbidden_chars: DEFAULT_FORBIDDEN_CHARS.to_string(),
86 command_policy: Some(ShellCommandPolicy::default()),
87 }
88 }
89}
90
/// Tool that runs a command line after validating it against a `ShellPolicy`.
pub struct ShellTool {
    /// Allowlist, size limits and character policy applied to every command.
    policy: ShellPolicy,
}
94
95impl ShellTool {
96 pub fn new(policy: ShellPolicy) -> Self {
97 Self { policy }
98 }
99
100 fn parse_and_validate(
102 policy: &ShellPolicy,
103 input: &str,
104 ) -> anyhow::Result<(String, Vec<String>)> {
105 if policy.command_policy.is_some() {
106 Self::parse_context_aware(policy, input)
107 } else {
108 Self::parse_legacy(policy, input)
109 }
110 }
111
112 fn parse_context_aware(
114 policy: &ShellPolicy,
115 input: &str,
116 ) -> anyhow::Result<(String, Vec<String>)> {
117 let tokens = shell_parse::tokenize(input)?;
118 let annotated = shell_parse::tokenize_annotated(input)?;
119
120 if tokens.is_empty() {
121 return Err(anyhow!("command is required"));
122 }
123
124 let command_name = tokens[0].text.clone();
125 let args: Vec<String> = tokens[1..].iter().map(|t| t.text.clone()).collect();
126
127 if args.len() > policy.max_args {
128 return Err(anyhow!("too many shell arguments"));
129 }
130
131 let cmd_policy = policy
133 .command_policy
134 .as_ref()
135 .expect("command_policy must be Some in context-aware mode");
136 for (i, token) in tokens.iter().enumerate().skip(1) {
137 if token.text.is_empty() {
138 return Err(anyhow!("empty shell argument is not allowed"));
139 }
140 if token.text.len() > policy.max_arg_length {
141 return Err(anyhow!("shell argument exceeds max length"));
142 }
143 cmd_policy.validate_token(&annotated[i])?;
144 }
145
146 Ok((command_name, args))
147 }
148
149 fn parse_legacy(policy: &ShellPolicy, input: &str) -> anyhow::Result<(String, Vec<String>)> {
151 let mut parts = input.split_whitespace();
152 let command_name = parts
153 .next()
154 .ok_or_else(|| anyhow!("command is required"))?
155 .to_string();
156 let args: Vec<String> = parts.map(ToString::to_string).collect();
157
158 if args.len() > policy.max_args {
159 return Err(anyhow!("too many shell arguments"));
160 }
161 for arg in &args {
162 if arg.is_empty() {
163 return Err(anyhow!("empty shell argument is not allowed"));
164 }
165 if arg.len() > policy.max_arg_length {
166 return Err(anyhow!("shell argument exceeds max length"));
167 }
168 if arg.chars().any(|c| policy.forbidden_chars.contains(c)) {
169 return Err(anyhow!(
170 "shell argument contains forbidden shell metacharacters"
171 ));
172 }
173 }
174
175 Ok((command_name, args))
176 }
177
178 async fn read_limited<R>(mut reader: R, max_bytes: usize) -> anyhow::Result<(Vec<u8>, bool)>
179 where
180 R: AsyncRead + Unpin,
181 {
182 let mut bytes = Vec::new();
183 let mut limited = (&mut reader).take((max_bytes + 1) as u64);
184 limited
185 .read_to_end(&mut bytes)
186 .await
187 .context("failed to capture command output")?;
188
189 let truncated = bytes.len() > max_bytes;
190 if truncated {
191 bytes.truncate(max_bytes);
192 }
193
194 Ok((bytes, truncated))
195 }
196
197 fn render_stream(name: &str, bytes: &[u8], truncated: bool, max_bytes: usize) -> String {
198 let mut out = format!("{name}:\n{}", String::from_utf8_lossy(bytes));
199 if truncated {
200 out.push_str(&format!("\n<truncated at {max_bytes} bytes>"));
201 }
202 out
203 }
204}
205
206#[async_trait]
207impl Tool for ShellTool {
208 fn name(&self) -> &'static str {
209 "shell"
210 }
211
212 fn description(&self) -> &'static str {
213 "Execute a shell command from the allowlist. Input is the full command line. Returns stdout, stderr, and exit code."
214 }
215
216 fn input_schema(&self) -> Option<serde_json::Value> {
217 Some(serde_json::json!({
218 "type": "object",
219 "properties": {
220 "command": {
221 "type": "string",
222 "description": "The shell command to execute (e.g. \"ls -la\", \"cargo build\")"
223 }
224 },
225 "required": ["command"]
226 }))
227 }
228
229 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
230 let (command_name, args) = Self::parse_and_validate(&self.policy, input)?;
231 if !self
232 .policy
233 .allowed_commands
234 .iter()
235 .any(|c| c == &command_name)
236 {
237 return Err(anyhow!("command is not in allowlist"));
238 }
239
240 let mut child = Command::new(&command_name)
241 .args(&args)
242 .stdout(Stdio::piped())
243 .stderr(Stdio::piped())
244 .spawn()
245 .context("shell command failed to execute")?;
246
247 let stdout_reader = child
248 .stdout
249 .take()
250 .ok_or_else(|| anyhow!("shell command did not provide stdout pipe"))?;
251 let stderr_reader = child
252 .stderr
253 .take()
254 .ok_or_else(|| anyhow!("shell command did not provide stderr pipe"))?;
255
256 let stdout_task = tokio::spawn(Self::read_limited(
257 stdout_reader,
258 self.policy.max_output_bytes,
259 ));
260 let stderr_task = tokio::spawn(Self::read_limited(
261 stderr_reader,
262 self.policy.max_output_bytes,
263 ));
264
265 let status = child.wait().await.context("shell command failed to run")?;
266 let (stdout, stdout_truncated) = stdout_task
267 .await
268 .context("failed joining stdout capture task")??;
269 let (stderr, stderr_truncated) = stderr_task
270 .await
271 .context("failed joining stderr capture task")??;
272
273 Ok(ToolResult {
274 output: format!(
275 "status={}\n{}\n{}",
276 status,
277 Self::render_stream(
278 "stdout",
279 &stdout,
280 stdout_truncated,
281 self.policy.max_output_bytes
282 ),
283 Self::render_stream(
284 "stderr",
285 &stderr,
286 stderr_truncated,
287 self.policy.max_output_bytes
288 )
289 ),
290 })
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use agentzero_core::{Tool, ToolContext};

    /// Default policy that allowlists only `echo`.
    fn echo_policy() -> ShellPolicy {
        ShellPolicy::default_with_commands(vec!["echo".to_string()])
    }

    /// Tool built from [`echo_policy`].
    fn echo_tool() -> ShellTool {
        ShellTool::new(echo_policy())
    }

    /// Tool context rooted at the current directory.
    fn ctx() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// Runs `command` and returns the result, panicking with `on_failure` if it was rejected.
    async fn run_ok(tool: &ShellTool, command: &str, on_failure: &str) -> ToolResult {
        tool.execute(command, &ctx()).await.expect(on_failure)
    }

    /// Runs `command`, asserts that it was rejected, and returns the error message.
    async fn rejection_message(tool: &ShellTool, command: &str) -> String {
        let outcome = tool.execute(command, &ctx()).await;
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    #[tokio::test]
    async fn shell_allows_allowlisted_command() {
        let result = run_ok(&echo_tool(), "echo hello", "shell should succeed").await;
        assert!(result.output.contains("stdout:\nhello"));
    }

    #[tokio::test]
    async fn shell_rejects_unquoted_metacharacters() {
        let msg = rejection_message(&echo_tool(), "echo hello;uname").await;
        assert!(msg.contains("unquoted forbidden metacharacter"));
    }

    #[tokio::test]
    async fn shell_rejects_non_allowlisted_command() {
        let msg = rejection_message(&echo_tool(), "pwd").await;
        assert!(msg.contains("command is not in allowlist"));
    }

    #[tokio::test]
    async fn shell_truncates_stdout_to_policy_limit() {
        let mut policy = echo_policy();
        policy.max_output_bytes = 8;
        let tool = ShellTool::new(policy);

        let result = run_ok(&tool, "echo 1234567890", "shell should succeed").await;
        assert!(result.output.contains("stdout:\n12345678"));
        assert!(result.output.contains("<truncated at 8 bytes>"));
    }

    #[tokio::test]
    async fn policy_allows_single_quoted_semicolon() {
        let result = run_ok(
            &echo_tool(),
            "echo 'hello;world'",
            "quoted semicolon should be allowed",
        )
        .await;
        assert!(result.output.contains("hello;world"));
    }

    #[tokio::test]
    async fn policy_allows_double_quoted_semicolon() {
        let result = run_ok(
            &echo_tool(),
            r#"echo "hello;world""#,
            "quoted semicolon should be allowed",
        )
        .await;
        assert!(result.output.contains("hello;world"));
    }

    #[tokio::test]
    async fn policy_blocks_backtick_always() {
        // Backticks are rejected even inside single quotes.
        let msg = rejection_message(&echo_tool(), "echo '`uname`'").await;
        assert!(msg.contains("always-forbidden"));
    }

    #[tokio::test]
    async fn policy_blocks_unquoted_dollar() {
        let msg = rejection_message(&echo_tool(), "echo $HOME").await;
        assert!(msg.contains("unquoted forbidden metacharacter"));
    }

    #[tokio::test]
    async fn policy_allows_dollar_in_single_quotes() {
        let result = run_ok(
            &echo_tool(),
            "echo '$HOME'",
            "dollar in single quotes should be allowed",
        )
        .await;
        assert!(result.output.contains("$HOME"));
    }

    #[tokio::test]
    async fn legacy_mode_flat_check() {
        // Clearing `command_policy` selects the legacy whitespace-splitting validator.
        let mut policy = echo_policy();
        policy.command_policy = None;
        let tool = ShellTool::new(policy);

        let msg = rejection_message(&tool, "echo hello;uname").await;
        assert!(msg.contains("forbidden shell metacharacters"));
    }

    #[tokio::test]
    async fn shell_quoted_argument_with_spaces() {
        let result = run_ok(
            &echo_tool(),
            "echo 'hello world'",
            "quoted spaces should work",
        )
        .await;
        assert!(result.output.contains("hello world"));
    }
}