1use std::path::PathBuf;
5use std::time::{Duration, Instant};
6
7use tokio::process::Command;
8use tokio_util::sync::CancellationToken;
9
10use schemars::JsonSchema;
11use serde::Deserialize;
12
13use std::sync::Arc;
14
15use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
16use crate::config::ShellConfig;
17use crate::executor::{
18 ClaimSource, FilterStats, ToolCall, ToolError, ToolEvent, ToolEventTx, ToolExecutor, ToolOutput,
19};
20use crate::filter::{OutputFilterRegistry, sanitize_output};
21use crate::permissions::{PermissionAction, PermissionPolicy};
22
23mod transaction;
24use transaction::{TransactionSnapshot, affected_paths, build_scope_matchers, is_write_command};
25
/// Built-in command patterns that are blocked by default.
///
/// `ShellExecutor::new` drops an entry from this list only when the same
/// (lowercased) string appears in the configured `allowed_commands`.
const DEFAULT_BLOCKED: &[&str] = &[
    "rm -rf /", "sudo", "mkfs", "dd if=", "curl", "wget", "nc ", "ncat", "netcat", "shutdown",
    "reboot", "halt",
];

/// Public alias of the built-in blocklist ([`DEFAULT_BLOCKED`]).
pub const DEFAULT_BLOCKED_COMMANDS: &[&str] = DEFAULT_BLOCKED;

/// Binary basenames that [`effective_shell_command`] treats as shell interpreters.
pub const SHELL_INTERPRETERS: &[&str] =
    &["bash", "sh", "zsh", "fish", "dash", "ksh", "csh", "tcsh"];

/// Substrings that open a command/process substitution (`$(`, backtick, `<(`, `>(`).
/// [`check_blocklist`] rejects any command that contains one of them.
const SUBSHELL_METACHARS: &[&str] = &["$(", "`", "<(", ">("];
45
46#[must_use]
54pub fn check_blocklist(command: &str, blocklist: &[String]) -> Option<String> {
55 let lower = command.to_lowercase();
56 for meta in SUBSHELL_METACHARS {
58 if lower.contains(meta) {
59 return Some((*meta).to_owned());
60 }
61 }
62 let cleaned = strip_shell_escapes(&lower);
63 let commands = tokenize_commands(&cleaned);
64 for blocked in blocklist {
65 for cmd_tokens in &commands {
66 if tokens_match_pattern(cmd_tokens, blocked) {
67 return Some(blocked.clone());
68 }
69 }
70 }
71 None
72}
73
74#[must_use]
79pub fn effective_shell_command<'a>(binary: &str, args: &'a [String]) -> Option<&'a str> {
80 let base = binary.rsplit('/').next().unwrap_or(binary);
81 if !SHELL_INTERPRETERS.contains(&base) {
82 return None;
83 }
84 let pos = args.iter().position(|a| a == "-c")?;
86 args.get(pos + 1).map(String::as_str)
87}
88
/// Network-capable commands that `ShellExecutor::new` adds to the blocklist
/// whenever the configuration has `allow_network` disabled.
const NETWORK_COMMANDS: &[&str] = &["curl", "wget", "nc ", "ncat", "netcat"];
90
// Parameters of the `bash` tool call, deserialized in `execute_tool_call` and
// used to generate the tool's JSON schema.
// NOTE(review): plain `//` comments are used on purpose; doc comments may be
// picked up by the `JsonSchema` derive as descriptions — confirm before converting.
#[derive(Deserialize, JsonSchema)]
pub(crate) struct BashParams {
    // The shell command text to execute.
    command: String,
}
96
/// Executes `bash` blocks with blocklist, sandbox, confirmation, audit and
/// optional transactional-rollback handling.
#[derive(Debug)]
pub struct ShellExecutor {
    /// Wall-clock limit handed to `execute_bash` for each block.
    timeout: Duration,
    /// Blocked command patterns (sorted, de-duplicated) checked by `find_blocked_command`.
    blocked_commands: Vec<String>,
    /// Directory prefixes a path token must resolve under to pass `validate_sandbox`.
    allowed_paths: Vec<PathBuf>,
    /// Substrings that require confirmation; consulted only when no permission policy is set.
    confirm_patterns: Vec<String>,
    /// Environment-variable name prefixes removed from the child process environment.
    env_blocklist: Vec<String>,
    /// When set, `log_audit` records an entry through this logger.
    audit_logger: Option<Arc<AuditLogger>>,
    /// When set, lifecycle and output events are sent on this channel.
    tool_event_tx: Option<ToolEventTx>,
    /// When set, replaces the `confirm_patterns` check in `check_permissions`.
    permission_policy: Option<PermissionPolicy>,
    /// When set, applied to sanitized command output in `execute_block`.
    output_filter_registry: Option<OutputFilterRegistry>,
    /// When set, cancelling it aborts a running command.
    cancel_token: Option<CancellationToken>,
    /// Extra environment variables for spawned commands; behind a lock so it can be
    /// replaced through `&self` (see `set_skill_env`).
    skill_env: std::sync::RwLock<Option<std::collections::HashMap<String, String>>>,
    /// Enables snapshot capture before write commands.
    transactional: bool,
    /// Enables rollback of a captured snapshot after a failing exit code.
    auto_rollback: bool,
    /// Exit codes that trigger rollback; when empty, any exit code `>= 2` does.
    auto_rollback_exit_codes: Vec<i32>,
    /// When true, a failed snapshot capture aborts the block instead of only warning.
    snapshot_required: bool,
    /// Size limit passed to `TransactionSnapshot::capture`.
    max_snapshot_bytes: u64,
    /// Glob matchers passed to `affected_paths` to scope snapshots.
    transaction_scope_matchers: Vec<globset::GlobMatcher>,
}
118
119impl ShellExecutor {
120 #[must_use]
121 pub fn new(config: &ShellConfig) -> Self {
122 let allowed: Vec<String> = config
123 .allowed_commands
124 .iter()
125 .map(|s| s.to_lowercase())
126 .collect();
127
128 let mut blocked: Vec<String> = DEFAULT_BLOCKED
129 .iter()
130 .filter(|s| !allowed.contains(&s.to_lowercase()))
131 .map(|s| (*s).to_owned())
132 .collect();
133 blocked.extend(config.blocked_commands.iter().map(|s| s.to_lowercase()));
134
135 if !config.allow_network {
136 for cmd in NETWORK_COMMANDS {
137 let lower = cmd.to_lowercase();
138 if !blocked.contains(&lower) {
139 blocked.push(lower);
140 }
141 }
142 }
143
144 blocked.sort();
145 blocked.dedup();
146
147 let allowed_paths = if config.allowed_paths.is_empty() {
148 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
149 } else {
150 config.allowed_paths.iter().map(PathBuf::from).collect()
151 };
152
153 Self {
154 timeout: Duration::from_secs(config.timeout),
155 blocked_commands: blocked,
156 allowed_paths,
157 confirm_patterns: config.confirm_patterns.clone(),
158 env_blocklist: config.env_blocklist.clone(),
159 audit_logger: None,
160 tool_event_tx: None,
161 permission_policy: None,
162 output_filter_registry: None,
163 cancel_token: None,
164 skill_env: std::sync::RwLock::new(None),
165 transactional: config.transactional,
166 auto_rollback: config.auto_rollback,
167 auto_rollback_exit_codes: config.auto_rollback_exit_codes.clone(),
168 snapshot_required: config.snapshot_required,
169 max_snapshot_bytes: config.max_snapshot_bytes,
170 transaction_scope_matchers: build_scope_matchers(&config.transaction_scope),
171 }
172 }
173
174 pub fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
176 match self.skill_env.write() {
177 Ok(mut guard) => *guard = env,
178 Err(e) => tracing::error!("skill_env RwLock poisoned: {e}"),
179 }
180 }
181
182 #[must_use]
183 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
184 self.audit_logger = Some(logger);
185 self
186 }
187
188 #[must_use]
189 pub fn with_tool_event_tx(mut self, tx: ToolEventTx) -> Self {
190 self.tool_event_tx = Some(tx);
191 self
192 }
193
194 #[must_use]
195 pub fn with_permissions(mut self, policy: PermissionPolicy) -> Self {
196 self.permission_policy = Some(policy);
197 self
198 }
199
200 #[must_use]
201 pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
202 self.cancel_token = Some(token);
203 self
204 }
205
206 #[must_use]
207 pub fn with_output_filters(mut self, registry: OutputFilterRegistry) -> Self {
208 self.output_filter_registry = Some(registry);
209 self
210 }
211
212 pub async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
218 self.execute_inner(response, true).await
219 }
220
221 async fn execute_inner(
222 &self,
223 response: &str,
224 skip_confirm: bool,
225 ) -> Result<Option<ToolOutput>, ToolError> {
226 let blocks = extract_bash_blocks(response);
227 if blocks.is_empty() {
228 return Ok(None);
229 }
230
231 let mut outputs = Vec::with_capacity(blocks.len());
232 let mut cumulative_filter_stats: Option<FilterStats> = None;
233 #[allow(clippy::cast_possible_truncation)]
234 let blocks_executed = blocks.len() as u32;
235
236 for block in &blocks {
237 let (output_line, per_block_stats) = self.execute_block(block, skip_confirm).await?;
238 if let Some(fs) = per_block_stats {
239 let stats = cumulative_filter_stats.get_or_insert_with(FilterStats::default);
240 stats.raw_chars += fs.raw_chars;
241 stats.filtered_chars += fs.filtered_chars;
242 stats.raw_lines += fs.raw_lines;
243 stats.filtered_lines += fs.filtered_lines;
244 stats.confidence = Some(match (stats.confidence, fs.confidence) {
245 (Some(prev), Some(cur)) => crate::filter::worse_confidence(prev, cur),
246 (Some(prev), None) => prev,
247 (None, Some(cur)) => cur,
248 (None, None) => unreachable!(),
249 });
250 if stats.command.is_none() {
251 stats.command = fs.command;
252 }
253 if stats.kept_lines.is_empty() && !fs.kept_lines.is_empty() {
254 stats.kept_lines = fs.kept_lines;
255 }
256 }
257 outputs.push(output_line);
258 }
259
260 Ok(Some(ToolOutput {
261 tool_name: "bash".to_owned(),
262 summary: outputs.join("\n\n"),
263 blocks_executed,
264 filter_stats: cumulative_filter_stats,
265 diff: None,
266 streamed: self.tool_event_tx.is_some(),
267 terminal_id: None,
268 locations: None,
269 raw_response: None,
270 claim_source: Some(ClaimSource::Shell),
271 }))
272 }
273
/// Runs a single bash block through the full pipeline and returns its
/// formatted output line plus optional filter statistics.
///
/// Order of operations: permission check, sandbox check, optional transaction
/// snapshot, `Started` event, execution, cancellation check, optional rollback,
/// audit logging, timeout / exit classification, output sanitizing and
/// filtering, `Completed` event.
///
/// # Errors
/// Returns the error from the permission or sandbox check,
/// `ToolError::SnapshotFailed` when a required snapshot cannot be captured,
/// `ToolError::Cancelled`, `ToolError::Timeout`, or `ToolError::Shell` when
/// `classify_shell_exit` assigns a category.
#[allow(clippy::too_many_lines)]
async fn execute_block(
    &self,
    block: &str,
    skip_confirm: bool,
) -> Result<(String, Option<FilterStats>), ToolError> {
    self.check_permissions(block, skip_confirm).await?;
    self.validate_sandbox(block)?;

    // Set when a non-required snapshot fails; prepended to the output line later.
    let mut snapshot_warning: Option<String> = None;
    // Snapshot only in transactional mode, only for write commands, and only
    // when the command touches at least one in-scope path.
    let snapshot = if self.transactional && is_write_command(block) {
        let paths = affected_paths(block, &self.transaction_scope_matchers);
        if paths.is_empty() {
            None
        } else {
            match TransactionSnapshot::capture(&paths, self.max_snapshot_bytes) {
                Ok(snap) => {
                    tracing::debug!(
                        files = snap.file_count(),
                        bytes = snap.total_bytes(),
                        "transaction snapshot captured"
                    );
                    Some(snap)
                }
                // A required snapshot that fails aborts before anything runs.
                Err(e) if self.snapshot_required => {
                    return Err(ToolError::SnapshotFailed {
                        reason: e.to_string(),
                    });
                }
                // Otherwise continue without rollback protection and tell the caller.
                Err(e) => {
                    tracing::warn!(err = %e, "transaction snapshot failed, proceeding without rollback");
                    snapshot_warning =
                        Some(format!("[warn] snapshot failed: {e}; rollback unavailable"));
                    None
                }
            }
        }
    } else {
        None
    };

    if let Some(ref tx) = self.tool_event_tx {
        // Send errors are deliberately ignored.
        let _ = tx.send(ToolEvent::Started {
            tool_name: "bash".to_owned(),
            command: block.to_owned(),
        });
    }

    let start = Instant::now();
    // Clone the skill env out of the lock so no guard is held across the await;
    // a poisoned lock is treated as "no extra env".
    let skill_env_snapshot: Option<std::collections::HashMap<String, String>> =
        self.skill_env.read().ok().and_then(|g| g.clone());
    let (out, exit_code) = execute_bash(
        block,
        self.timeout,
        self.tool_event_tx.as_ref(),
        self.cancel_token.as_ref(),
        skill_env_snapshot.as_ref(),
        &self.env_blocklist,
    )
    .await;
    // `execute_bash` reports cancellation as exit code 130; require the token to
    // actually be cancelled so a command that itself exits 130 is not misreported.
    // Note: this returns before rollback and audit logging.
    if exit_code == 130
        && self
            .cancel_token
            .as_ref()
            .is_some_and(CancellationToken::is_cancelled)
    {
        return Err(ToolError::Cancelled);
    }
    #[allow(clippy::cast_possible_truncation)]
    let duration_ms = start.elapsed().as_millis() as u64;

    if let Some(snap) = snapshot {
        // With no explicit exit-code list, any exit code >= 2 triggers rollback.
        let should_rollback = self.auto_rollback
            && if self.auto_rollback_exit_codes.is_empty() {
                exit_code >= 2
            } else {
                self.auto_rollback_exit_codes.contains(&exit_code)
            };
        if should_rollback {
            match snap.rollback() {
                Ok(report) => {
                    tracing::info!(
                        restored = report.restored_count,
                        deleted = report.deleted_count,
                        "transaction rollback completed"
                    );
                    // The rollback gets its own audit entry, in addition to the
                    // regular result entry written further down.
                    self.log_audit(
                        block,
                        AuditResult::Rollback {
                            restored: report.restored_count,
                            deleted: report.deleted_count,
                        },
                        duration_ms,
                        None,
                    )
                    .await;
                    if let Some(ref tx) = self.tool_event_tx {
                        let _ = tx.send(ToolEvent::Rollback {
                            tool_name: "bash".to_owned(),
                            command: block.to_owned(),
                            restored_count: report.restored_count,
                            deleted_count: report.deleted_count,
                        });
                    }
                }
                // A failed rollback is only logged; processing continues.
                Err(e) => {
                    tracing::error!(err = %e, "transaction rollback failed");
                }
            }
        }
    }

    // The outcome is inferred from marker strings that `execute_bash` puts into
    // the output ("[error] command timed out", "[error]", "[stderr]").
    let is_timeout = out.contains("[error] command timed out");
    let audit_result = if is_timeout {
        AuditResult::Timeout
    } else if out.contains("[error]") || out.contains("[stderr]") {
        AuditResult::Error {
            message: out.clone(),
        }
    } else {
        AuditResult::Success
    };
    self.log_audit(block, audit_result, duration_ms, None).await;

    if is_timeout {
        self.emit_completed(block, &out, false, None);
        return Err(ToolError::Timeout {
            timeout_secs: self.timeout.as_secs(),
        });
    }

    if let Some(category) = classify_shell_exit(exit_code, &out) {
        self.emit_completed(block, &out, false, None);
        return Err(ToolError::Shell {
            exit_code,
            category,
            // Only the first three output lines go into the error message.
            message: out.lines().take(3).collect::<Vec<_>>().join("; "),
        });
    }

    let sanitized = sanitize_output(&out);
    let mut per_block_stats: Option<FilterStats> = None;
    // Apply the output filter registry if one is configured and it produces a
    // result for this command; otherwise keep the sanitized text.
    let filtered = if let Some(ref registry) = self.output_filter_registry {
        match registry.apply(block, &sanitized, exit_code) {
            Some(fr) => {
                tracing::debug!(
                    command = block,
                    raw = fr.raw_chars,
                    filtered = fr.filtered_chars,
                    savings_pct = fr.savings_pct(),
                    "output filter applied"
                );
                per_block_stats = Some(FilterStats {
                    raw_chars: fr.raw_chars,
                    filtered_chars: fr.filtered_chars,
                    raw_lines: fr.raw_lines,
                    filtered_lines: fr.filtered_lines,
                    confidence: Some(fr.confidence),
                    command: Some(block.to_owned()),
                    kept_lines: fr.kept_lines.clone(),
                });
                fr.output
            }
            None => sanitized,
        }
    } else {
        sanitized
    };

    // The Completed event carries the raw output (`out`), not the filtered text.
    self.emit_completed(
        block,
        &out,
        !out.contains("[error]"),
        per_block_stats.clone(),
    );

    let output_line = if let Some(warn) = snapshot_warning {
        format!("{warn}\n$ {block}\n{filtered}")
    } else {
        format!("$ {block}\n{filtered}")
    };
    Ok((output_line, per_block_stats))
}
460
461 fn emit_completed(
462 &self,
463 command: &str,
464 output: &str,
465 success: bool,
466 filter_stats: Option<FilterStats>,
467 ) {
468 if let Some(ref tx) = self.tool_event_tx {
469 let _ = tx.send(ToolEvent::Completed {
470 tool_name: "bash".to_owned(),
471 command: command.to_owned(),
472 output: output.to_owned(),
473 success,
474 filter_stats,
475 diff: None,
476 });
477 }
478 }
479
480 async fn check_permissions(&self, block: &str, skip_confirm: bool) -> Result<(), ToolError> {
482 if let Some(blocked) = self.find_blocked_command(block) {
485 let err = ToolError::Blocked {
486 command: blocked.to_owned(),
487 };
488 self.log_audit(
489 block,
490 AuditResult::Blocked {
491 reason: format!("blocked command: {blocked}"),
492 },
493 0,
494 Some(&err),
495 )
496 .await;
497 return Err(err);
498 }
499
500 if let Some(ref policy) = self.permission_policy {
501 match policy.check("bash", block) {
502 PermissionAction::Deny => {
503 let err = ToolError::Blocked {
504 command: block.to_owned(),
505 };
506 self.log_audit(
507 block,
508 AuditResult::Blocked {
509 reason: "denied by permission policy".to_owned(),
510 },
511 0,
512 Some(&err),
513 )
514 .await;
515 return Err(err);
516 }
517 PermissionAction::Ask if !skip_confirm => {
518 return Err(ToolError::ConfirmationRequired {
519 command: block.to_owned(),
520 });
521 }
522 _ => {}
523 }
524 } else if !skip_confirm && let Some(pattern) = self.find_confirm_command(block) {
525 return Err(ToolError::ConfirmationRequired {
526 command: pattern.to_owned(),
527 });
528 }
529
530 Ok(())
531 }
532
533 fn validate_sandbox(&self, code: &str) -> Result<(), ToolError> {
534 let cwd = std::env::current_dir().unwrap_or_default();
535
536 for token in extract_paths(code) {
537 if has_traversal(&token) {
538 return Err(ToolError::SandboxViolation { path: token });
539 }
540
541 let path = if token.starts_with('/') {
542 PathBuf::from(&token)
543 } else {
544 cwd.join(&token)
545 };
546 let canonical = path
547 .canonicalize()
548 .or_else(|_| std::path::absolute(&path))
549 .unwrap_or(path);
550 if !self
551 .allowed_paths
552 .iter()
553 .any(|allowed| canonical.starts_with(allowed))
554 {
555 return Err(ToolError::SandboxViolation {
556 path: canonical.display().to_string(),
557 });
558 }
559 }
560 Ok(())
561 }
562
563 fn find_blocked_command(&self, code: &str) -> Option<&str> {
598 let cleaned = strip_shell_escapes(&code.to_lowercase());
599 let commands = tokenize_commands(&cleaned);
600 for blocked in &self.blocked_commands {
601 for cmd_tokens in &commands {
602 if tokens_match_pattern(cmd_tokens, blocked) {
603 return Some(blocked.as_str());
604 }
605 }
606 }
607 for inner in extract_subshell_contents(&cleaned) {
609 let inner_commands = tokenize_commands(&inner);
610 for blocked in &self.blocked_commands {
611 for cmd_tokens in &inner_commands {
612 if tokens_match_pattern(cmd_tokens, blocked) {
613 return Some(blocked.as_str());
614 }
615 }
616 }
617 }
618 None
619 }
620
621 fn find_confirm_command(&self, code: &str) -> Option<&str> {
622 let normalized = code.to_lowercase();
623 for pattern in &self.confirm_patterns {
624 if normalized.contains(pattern.as_str()) {
625 return Some(pattern.as_str());
626 }
627 }
628 None
629 }
630
631 async fn log_audit(
632 &self,
633 command: &str,
634 result: AuditResult,
635 duration_ms: u64,
636 error: Option<&ToolError>,
637 ) {
638 if let Some(ref logger) = self.audit_logger {
639 let (error_category, error_domain, error_phase) =
640 error.map_or((None, None, None), |e| {
641 let cat = e.category();
642 (
643 Some(cat.label().to_owned()),
644 Some(cat.domain().label().to_owned()),
645 Some(cat.phase().label().to_owned()),
646 )
647 });
648 let entry = AuditEntry {
649 timestamp: chrono_now(),
650 tool: "shell".into(),
651 command: command.into(),
652 result,
653 duration_ms,
654 error_category,
655 error_domain,
656 error_phase,
657 claim_source: Some(ClaimSource::Shell),
658 mcp_server_id: None,
659 injection_flagged: false,
660 embedding_anomalous: false,
661 cross_boundary_mcp_to_acp: false,
662 adversarial_policy_decision: None,
663 };
664 logger.log(&entry).await;
665 }
666 }
667}
668
669impl ToolExecutor for ShellExecutor {
670 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
671 self.execute_inner(response, false).await
672 }
673
674 fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
675 use crate::registry::{InvocationHint, ToolDef};
676 vec![ToolDef {
677 id: "bash".into(),
678 description: "Execute a shell command and return stdout/stderr.\n\nParameters: command (string, required) - shell command to run\nReturns: stdout and stderr combined, prefixed with exit code\nErrors: Blocked if command matches security policy; Timeout after configured seconds; SandboxViolation if path outside allowed dirs\nExample: {\"command\": \"ls -la /tmp\"}".into(),
679 schema: schemars::schema_for!(BashParams),
680 invocation: InvocationHint::FencedBlock("bash"),
681 }]
682 }
683
684 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
685 if call.tool_id != "bash" {
686 return Ok(None);
687 }
688 let params: BashParams = crate::executor::deserialize_params(&call.params)?;
689 if params.command.is_empty() {
690 return Ok(None);
691 }
692 let command = ¶ms.command;
693 let synthetic = format!("```bash\n{command}\n```");
695 self.execute_inner(&synthetic, false).await
696 }
697
698 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
699 ShellExecutor::set_skill_env(self, env);
700 }
701}
702
/// Normalizes shell quoting so obfuscated commands can be matched as plain text.
///
/// Handles, in this order at each position:
/// - ANSI-C quoting `$'...'`: decoded only when it contains at least one
///   `\xHH` or octal (`\N`, `\NN`, `\NNN`) escape and is properly closed;
///   otherwise the `$` is kept and the quotes are handled as ordinary quotes.
/// - backslash-newline (line continuation): removed.
/// - backslash followed by any other byte: the backslash is dropped.
/// - single/double quotes: the delimiters are dropped, the content is kept.
///
/// Output is assembled as raw bytes and decoded once at the end, so non-ASCII
/// text passes through unchanged and escape-encoded UTF-8 sequences (for
/// example `$'\xc3\xa9'`) decode to the intended character. Bytes that do not
/// form valid UTF-8 become U+FFFD.
pub(crate) fn strip_shell_escapes(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'$' && i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
            let mut j = i + 2;
            let mut decoded: Vec<u8> = Vec::new();
            // Becomes true once a hex/octal escape was decoded; plain `$'text'`
            // is left to the generic quote handling below.
            let mut valid = false;
            while j < bytes.len() && bytes[j] != b'\'' {
                if bytes[j] == b'\\' && j + 1 < bytes.len() {
                    let next = bytes[j + 1];
                    if next == b'x' && j + 3 < bytes.len() {
                        let hi = (bytes[j + 2] as char).to_digit(16);
                        let lo = (bytes[j + 3] as char).to_digit(16);
                        if let (Some(h), Some(l)) = (hi, lo) {
                            // Two hex digits always fit in one byte.
                            #[allow(clippy::cast_possible_truncation)]
                            let byte = ((h << 4) | l) as u8;
                            decoded.push(byte);
                            j += 4;
                            valid = true;
                            continue;
                        }
                    } else if next.is_ascii_digit() {
                        let mut val = u32::from(next - b'0');
                        // Number of input bytes consumed, backslash included.
                        let mut len = 2;
                        if j + 2 < bytes.len() && bytes[j + 2].is_ascii_digit() {
                            val = val * 8 + u32::from(bytes[j + 2] - b'0');
                            len = 3;
                            if j + 3 < bytes.len() && bytes[j + 3].is_ascii_digit() {
                                val = val * 8 + u32::from(bytes[j + 3] - b'0');
                                len = 4;
                            }
                        }
                        // Masked to one byte before the cast.
                        #[allow(clippy::cast_possible_truncation)]
                        let byte = (val & 0xFF) as u8;
                        decoded.push(byte);
                        j += len;
                        valid = true;
                        continue;
                    }
                    // Any other escape: keep the escaped byte itself.
                    decoded.push(next);
                    j += 2;
                } else {
                    decoded.push(bytes[j]);
                    j += 1;
                }
            }
            if j < bytes.len() && bytes[j] == b'\'' && valid {
                out.extend_from_slice(&decoded);
                i = j + 1;
                continue;
            }
        }

        if bytes[i] == b'\\' && i + 1 < bytes.len() {
            if bytes[i + 1] != b'\n' {
                // Escaped byte: drop the backslash, keep the byte.
                out.push(bytes[i + 1]);
            }
            // For backslash-newline both bytes are dropped.
            i += 2;
            continue;
        }

        if bytes[i] == b'"' || bytes[i] == b'\'' {
            let quote = bytes[i];
            i += 1;
            while i < bytes.len() && bytes[i] != quote {
                out.push(bytes[i]);
                i += 1;
            }
            if i < bytes.len() {
                // Skip the closing quote.
                i += 1;
            }
            continue;
        }

        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
794
/// Collects the inner text of every top-level substitution in `s`.
///
/// Recognized forms are backtick pairs and `$(...)`, `<(...)`, `>(...)` with
/// balanced parentheses. Nested substitutions are returned as part of their
/// outer substitution's text, not separately. Scanning stops at the first
/// unterminated construct, which contributes nothing.
pub(crate) fn extract_subshell_contents(s: &str) -> Vec<String> {
    let chars: Vec<char> = s.chars().collect();
    let mut found: Vec<String> = Vec::new();
    let mut pos = 0;

    while pos < chars.len() {
        let current = chars[pos];

        if current == '`' {
            let body_start = pos + 1;
            let Some(offset) = chars[body_start..].iter().position(|&c| c == '`') else {
                // No closing backtick: nothing further can be extracted.
                break;
            };
            let body_end = body_start + offset;
            found.push(chars[body_start..body_end].iter().collect());
            pos = body_end + 1;
            continue;
        }

        let opens_substitution =
            matches!(current, '$' | '<' | '>') && chars.get(pos + 1) == Some(&'(');
        if !opens_substitution {
            pos += 1;
            continue;
        }

        let body_start = pos + 2;
        let mut depth: usize = 1;
        let mut closing: Option<usize> = None;
        for (idx, &c) in chars.iter().enumerate().skip(body_start) {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        closing = Some(idx);
                        break;
                    }
                }
                _ => {}
            }
        }

        let Some(body_end) = closing else {
            // Unbalanced parentheses run to the end of the input.
            break;
        };
        found.push(chars[body_start..body_end].iter().collect());
        pos = body_end + 1;
    }

    found
}
857
/// Splits normalized shell text into command segments, each a list of
/// whitespace-separated tokens.
///
/// Segments are separated by `;`, `|`, `||`, `&&`, newline and a lone `&`.
/// The background operator `&` starts a new command just like `;` does, so
/// `sleep 1 & curl x` yields `curl x` as its own segment for the blocklist
/// check. Redirections such as `2>&1` are split as well; that only produces an
/// extra harmless segment. Empty segments are dropped.
pub(crate) fn tokenize_commands(normalized: &str) -> Vec<Vec<String>> {
    // Splitting on the single characters also covers `||` and `&&`: the empty
    // segment between the two characters is filtered out below.
    normalized
        .split([';', '|', '&', '\n'])
        .map(|segment| {
            segment
                .split_whitespace()
                .map(str::to_owned)
                .collect::<Vec<String>>()
        })
        .filter(|tokens| !tokens.is_empty())
        .collect()
}
873
/// Wrapper commands that `tokens_match_pattern` skips at the start of a
/// segment so the command they launch is the one that gets matched.
const TRANSPARENT_PREFIXES: &[&str] = &["env", "command", "exec", "nice", "nohup", "time", "xargs"];
877
/// Returns the part of `tok` after its last `/`, or all of `tok` when it has none.
fn cmd_basename(tok: &str) -> &str {
    match tok.rfind('/') {
        Some(slash) => &tok[slash + 1..],
        None => tok,
    }
}
882
883pub(crate) fn tokens_match_pattern(tokens: &[String], pattern: &str) -> bool {
890 if tokens.is_empty() || pattern.is_empty() {
891 return false;
892 }
893 let pattern = pattern.trim();
894 let pattern_tokens: Vec<&str> = pattern.split_whitespace().collect();
895 if pattern_tokens.is_empty() {
896 return false;
897 }
898
899 let start = tokens
901 .iter()
902 .position(|t| !TRANSPARENT_PREFIXES.contains(&cmd_basename(t)))
903 .unwrap_or(0);
904 let effective = &tokens[start..];
905 if effective.is_empty() {
906 return false;
907 }
908
909 if pattern_tokens.len() == 1 {
910 let pat = pattern_tokens[0];
911 let base = cmd_basename(&effective[0]);
912 base == pat || base.starts_with(&format!("{pat}."))
914 } else {
915 let n = pattern_tokens.len().min(effective.len());
917 let mut parts: Vec<&str> = vec![cmd_basename(&effective[0])];
918 parts.extend(effective[1..n].iter().map(String::as_str));
919 let joined = parts.join(" ");
920 if joined.starts_with(pattern) {
921 return true;
922 }
923 if effective.len() > n {
924 let mut parts2: Vec<&str> = vec![cmd_basename(&effective[0])];
925 parts2.extend(effective[1..=n].iter().map(String::as_str));
926 parts2.join(" ").starts_with(pattern)
927 } else {
928 false
929 }
930 }
931}
932
/// Extracts path-looking tokens from shell `code`.
///
/// The text is split on whitespace, `;`, `|` and `&`; single- or double-quoted
/// runs are kept inside the current token with the quote characters removed
/// (an unterminated quote runs to the end of the input). Trailing `;`, `&`
/// and `|` are trimmed from each token, and a token is returned when it
/// starts with `/`, `./` or `../`, or is exactly `..`.
fn extract_paths(code: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();
    let mut word = String::new();
    // The quote character currently open, if any.
    let mut open_quote: Option<char> = None;

    for ch in code.chars() {
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            } else {
                word.push(ch);
            }
            continue;
        }

        if ch == '"' || ch == '\'' {
            open_quote = Some(ch);
        } else if ch.is_whitespace() || matches!(ch, ';' | '|' | '&') {
            if !word.is_empty() {
                words.push(std::mem::take(&mut word));
            }
        } else {
            word.push(ch);
        }
    }
    if !word.is_empty() {
        words.push(word);
    }

    words
        .into_iter()
        // Separators can still end a token when they came from inside quotes.
        .map(|raw| raw.trim_end_matches([';', '&', '|']).to_owned())
        .filter(|token| {
            token.starts_with('/')
                || token.starts_with("./")
                || token.starts_with("../")
                || token == ".."
        })
        .collect()
}
979
980fn classify_shell_exit(
986 exit_code: i32,
987 output: &str,
988) -> Option<crate::error_taxonomy::ToolErrorCategory> {
989 use crate::error_taxonomy::ToolErrorCategory;
990 match exit_code {
991 126 => Some(ToolErrorCategory::PolicyBlocked),
993 127 => Some(ToolErrorCategory::PermanentFailure),
995 _ => {
996 let lower = output.to_lowercase();
997 if lower.contains("permission denied") {
998 Some(ToolErrorCategory::PolicyBlocked)
999 } else if lower.contains("no such file or directory") {
1000 Some(ToolErrorCategory::PermanentFailure)
1001 } else {
1002 None
1003 }
1004 }
1005 }
1006}
1007
/// Reports whether any `/`-separated segment of `path` is exactly `..`.
fn has_traversal(path: &str) -> bool {
    for segment in path.split('/') {
        if segment == ".." {
            return true;
        }
    }
    false
}
1011
1012fn extract_bash_blocks(text: &str) -> Vec<&str> {
1013 crate::executor::extract_fenced_blocks(text, "bash")
1014}
1015
1016async fn kill_process_tree(child: &mut tokio::process::Child) {
1020 #[cfg(unix)]
1021 if let Some(pid) = child.id() {
1022 let _ = Command::new("pkill")
1023 .args(["-KILL", "-P", &pid.to_string()])
1024 .status()
1025 .await;
1026 }
1027 let _ = child.kill().await;
1028}
1029
1030async fn execute_bash(
1031 code: &str,
1032 timeout: Duration,
1033 event_tx: Option<&ToolEventTx>,
1034 cancel_token: Option<&CancellationToken>,
1035 extra_env: Option<&std::collections::HashMap<String, String>>,
1036 env_blocklist: &[String],
1037) -> (String, i32) {
1038 use std::process::Stdio;
1039 use tokio::io::{AsyncBufReadExt, BufReader};
1040
1041 let timeout_secs = timeout.as_secs();
1042
1043 let mut cmd = Command::new("bash");
1044 cmd.arg("-c")
1045 .arg(code)
1046 .stdout(Stdio::piped())
1047 .stderr(Stdio::piped());
1048
1049 for (key, _) in std::env::vars() {
1050 if env_blocklist
1051 .iter()
1052 .any(|prefix| key.starts_with(prefix.as_str()))
1053 {
1054 cmd.env_remove(&key);
1055 }
1056 }
1057
1058 if let Some(env) = extra_env {
1059 cmd.envs(env);
1060 }
1061 let child_result = cmd.spawn();
1062
1063 let mut child = match child_result {
1064 Ok(c) => c,
1065 Err(e) => return (format!("[error] {e}"), 1),
1066 };
1067
1068 let stdout = child.stdout.take().expect("stdout piped");
1069 let stderr = child.stderr.take().expect("stderr piped");
1070
1071 let (line_tx, mut line_rx) = tokio::sync::mpsc::channel::<String>(64);
1072
1073 let stdout_tx = line_tx.clone();
1074 tokio::spawn(async move {
1075 let mut reader = BufReader::new(stdout);
1076 let mut buf = String::new();
1077 while reader.read_line(&mut buf).await.unwrap_or(0) > 0 {
1078 let _ = stdout_tx.send(buf.clone()).await;
1079 buf.clear();
1080 }
1081 });
1082
1083 tokio::spawn(async move {
1084 let mut reader = BufReader::new(stderr);
1085 let mut buf = String::new();
1086 while reader.read_line(&mut buf).await.unwrap_or(0) > 0 {
1087 let _ = line_tx.send(format!("[stderr] {buf}")).await;
1088 buf.clear();
1089 }
1090 });
1091
1092 let mut combined = String::new();
1093 let deadline = tokio::time::Instant::now() + timeout;
1094
1095 loop {
1096 tokio::select! {
1097 line = line_rx.recv() => {
1098 match line {
1099 Some(chunk) => {
1100 if let Some(tx) = event_tx {
1101 let _ = tx.send(ToolEvent::OutputChunk {
1102 tool_name: "bash".to_owned(),
1103 command: code.to_owned(),
1104 chunk: chunk.clone(),
1105 });
1106 }
1107 combined.push_str(&chunk);
1108 }
1109 None => break,
1110 }
1111 }
1112 () = tokio::time::sleep_until(deadline) => {
1113 kill_process_tree(&mut child).await;
1114 return (format!("[error] command timed out after {timeout_secs}s"), 1);
1115 }
1116 () = async {
1117 match cancel_token {
1118 Some(t) => t.cancelled().await,
1119 None => std::future::pending().await,
1120 }
1121 } => {
1122 kill_process_tree(&mut child).await;
1123 return ("[cancelled] operation aborted".to_string(), 130);
1124 }
1125 }
1126 }
1127
1128 let status = child.wait().await;
1129 let exit_code = status.ok().and_then(|s| s.code()).unwrap_or(1);
1130
1131 if combined.is_empty() {
1132 ("(no output)".to_string(), exit_code)
1133 } else {
1134 (combined, exit_code)
1135 }
1136}
1137
1138#[cfg(test)]
1139mod tests;