Skip to main content

agent_shell_parser/parse/
shell.rs

1//! Shell command parsing backed by tree-sitter-bash.
2//!
3//! Public API:
4//!
5//! - [`parse_with_substitutions`] — decomposes a shell command into a
6//!   recursive [`ParsedPipeline`] tree.
7//! - [`has_output_redirection`] — mutation-detection for redirects.
8//! - [`dump_ast`] — diagnostic output.
9//!
10//! The parser uses tree-sitter-bash for a full AST, then walks it to
11//! produce segments joined by operators. Substitutions (`$()`, backticks,
12//! `<()`, `>()`) are recursively parsed into nested pipelines — the
13//! result is a tree that can be evaluated bottom-up (catamorphism).
14//!
15//! # Control flow handling
16//!
17//! Shell keywords (`for`, `if`, `while`, `case`) are grammar structure,
18//! not commands. The walker recurses into their bodies and extracts the
19//! actual commands as segments.
20//!
21//! # Redirection propagation
22//!
23//! When a control flow construct has output redirection
24//! (e.g. `for ... done > file`), it propagates to inner segments via
25//! [`ShellSegment::redirection`].
26
27use super::redirect::detect_redirections;
28use super::subst::{assign_substitutions, build_segments, collect_substitutions};
29use super::types::{ParseError, ParsedPipeline, ShellSegment, Word};
30use super::walk::walk_ast;
31use std::cell::{Cell, RefCell};
32use tree_sitter::{Parser, Tree};
33
/// Upper bound on tree-sitter parse calls for one top-level request, counted
/// across every level of substitution recursion (512 = 2^9).
///
/// Guards against exponential fan-out, e.g. `echo $(a) $(b) $(c) ...`
/// nested many times over.
const MAX_TOTAL_PARSES: usize = 1 << 9;
37
/// Largest command, in bytes, that the parser will analyse: 64 KiB (2^16).
const MAX_INPUT_LENGTH: usize = 1 << 16;
40
41// ---------------------------------------------------------------------------
42// Thread-local parser
43// ---------------------------------------------------------------------------
44
45thread_local! {
46    /// tree-sitter `Parser` is `!Send`, so we use `thread_local!` storage.
47    ///
48    /// # Async safety
49    ///
50    /// The `RefCell` borrow is acquired and released within the synchronous
51    /// `parse_tree()` call — it never crosses an `.await` point. Each
52    /// thread in an async runtime pool gets its own parser instance.
53    /// `parse_tree()` must remain synchronous.
54    static TS_PARSER: RefCell<Parser> = RefCell::new({
55        let mut p = Parser::new();
56        p.set_language(&tree_sitter_bash::LANGUAGE.into())
57            .expect("failed to load bash grammar");
58        p
59    });
60}
61
62fn parse_tree(source: &str, budget: &Cell<usize>) -> Result<Tree, ParseError> {
63    let count = budget.get();
64    if count >= MAX_TOTAL_PARSES {
65        return Err(ParseError);
66    }
67    budget.set(count + 1);
68    TS_PARSER.with(|p| p.borrow_mut().parse(source, None).ok_or(ParseError))
69}
70
71// ---------------------------------------------------------------------------
72// Public API
73// ---------------------------------------------------------------------------
74
75/// Parse a shell command into a recursive pipeline tree.
76///
77/// Substitutions are recursively parsed: `echo $(cmd1 && cmd2)` produces
78/// a segment whose substitution contains a two-segment pipeline. The tree
79/// can be evaluated bottom-up — inner substitutions execute first.
80///
81/// Recursion depth is capped at 32 levels. Deeper nesting produces an
82/// empty pipeline with `has_parse_errors: true`.
83pub fn parse_with_substitutions(command: &str) -> Result<ParsedPipeline, ParseError> {
84    if command.len() > MAX_INPUT_LENGTH {
85        return Ok(ParsedPipeline::empty_with_error());
86    }
87    let budget = Cell::new(0);
88    parse_with_substitutions_impl(command, 0, &budget)
89}
90
91fn parse_with_substitutions_impl(
92    command: &str,
93    depth: usize,
94    budget: &Cell<usize>,
95) -> Result<ParsedPipeline, ParseError> {
96    let tree = parse_tree(command, budget)?;
97    let root = tree.root_node();
98    let source = command.as_bytes();
99    let has_parse_errors = root.has_error();
100
101    let mut raw_substs = Vec::new();
102    collect_substitutions(root, source, &mut raw_substs);
103
104    let walk = walk_ast(root, source);
105
106    let trimmed = command.trim();
107    let is_trivial = walk.segments.len() <= 1
108        && raw_substs.is_empty()
109        && walk
110            .segments
111            .first()
112            .is_none_or(|seg| seg.start == 0 && seg.end >= trimmed.len());
113
114    if is_trivial {
115        let first_seg = walk.segments.first();
116        let redir = first_seg
117            .and_then(|seg| seg.redirection.clone())
118            .or_else(|| detect_redirections(root, source));
119        let words = first_seg.map(|seg| seg.words.clone()).unwrap_or_else(|| {
120            // No segment produced (e.g. empty program) — shlex the trimmed text.
121            shlex::split(trimmed)
122                .unwrap_or_else(|| trimmed.split_whitespace().map(String::from).collect())
123                .into_iter()
124                .map(Word::from)
125                .collect()
126        });
127        return Ok(ParsedPipeline {
128            segments: vec![ShellSegment {
129                command: trimmed.to_string(),
130                words,
131                redirection: redir,
132                substitutions: vec![],
133            }],
134            operators: vec![],
135            structural_substitutions: vec![],
136            has_parse_errors,
137        });
138    }
139
140    let built = build_segments(&walk, command);
141    let (per_segment_subs, structural_subs) =
142        assign_substitutions(&raw_substs, &built, depth, &|inner, d| {
143            parse_with_substitutions_impl(inner, d, budget)
144        });
145
146    let segments: Vec<ShellSegment> = built
147        .into_iter()
148        .zip(per_segment_subs)
149        .map(|(b, subs)| ShellSegment {
150            command: b.command,
151            words: b.words,
152            redirection: b.redirection,
153            substitutions: subs,
154        })
155        .collect();
156
157    Ok(ParsedPipeline {
158        segments,
159        operators: walk.operators,
160        structural_substitutions: structural_subs,
161        has_parse_errors,
162    })
163}
164
165/// Check whether `command` contains output redirection.
166pub fn has_output_redirection(
167    command: &str,
168) -> Result<Option<super::types::Redirection>, ParseError> {
169    let budget = Cell::new(0);
170    let tree = parse_tree(command, &budget)?;
171    Ok(detect_redirections(tree.root_node(), command.as_bytes()))
172}
173
174/// Diagnostic: dump the tree-sitter AST and parsed pipeline.
175///
176/// Sections 1 (AST dump) and 3 (redirection check) share a single
177/// parse tree. Section 2 (pipeline decomposition) calls
178/// [`parse_with_substitutions`] separately — it builds the recursive
179/// pipeline structure from scratch.
180pub fn dump_ast(command: &str) -> Result<String, ParseError> {
181    use std::fmt::Write;
182    let mut out = String::new();
183
184    let budget = Cell::new(0);
185    let tree = parse_tree(command, &budget)?;
186    let root = tree.root_node();
187    let source = command.as_bytes();
188
189    // Section 1: raw AST
190    writeln!(out, "── tree-sitter AST ──").unwrap();
191    fn print_node(out: &mut String, node: tree_sitter::Node, source: &[u8], indent: usize) {
192        let text = node.utf8_text(source).unwrap_or("???");
193        let short: String = text.chars().take(60).collect();
194        let tag = if node.is_named() { "named" } else { "anon" };
195        writeln!(
196            out,
197            "{}{} [{}] {:?}",
198            "  ".repeat(indent),
199            node.kind(),
200            tag,
201            short
202        )
203        .unwrap();
204        let mut cursor = node.walk();
205        for child in node.children(&mut cursor) {
206            print_node(out, child, source, indent + 1);
207        }
208    }
209    print_node(&mut out, root, source, 0);
210
211    // Section 2: parsed pipeline (reuses the public API — separate parse is
212    // unavoidable here since parse_with_substitutions_impl builds from scratch,
213    // but this is a diagnostic function so the cost is acceptable)
214    let pipeline = parse_with_substitutions(command)?;
215    writeln!(out, "\n── parsed pipeline ──").unwrap();
216    if pipeline.has_parse_errors {
217        writeln!(out, "  (parse errors detected — best-effort result)").unwrap();
218    }
219    fn print_pipeline(out: &mut String, p: &ParsedPipeline, indent: usize) {
220        let pad = "  ".repeat(indent);
221        for sub in &p.structural_substitutions {
222            writeln!(
223                out,
224                "{pad}structural subst bytes {}..{}:",
225                sub.start, sub.end
226            )
227            .unwrap();
228            print_pipeline(out, &sub.pipeline, indent + 1);
229        }
230        for (i, seg) in p.segments.iter().enumerate() {
231            let redir = seg
232                .redirection
233                .as_ref()
234                .map(|r| format!(" [{r}]"))
235                .unwrap_or_default();
236            writeln!(out, "{pad}segment {i}: {:?}{redir}", seg.command).unwrap();
237            if !seg.words.is_empty() {
238                writeln!(out, "{pad}  words: {:?}", seg.words).unwrap();
239            }
240            for sub in &seg.substitutions {
241                writeln!(out, "{pad}  subst bytes {}..{}:", sub.start, sub.end).unwrap();
242                print_pipeline(out, &sub.pipeline, indent + 2);
243            }
244            if i < p.operators.len() {
245                writeln!(out, "{pad}operator: {}", p.operators[i]).unwrap();
246            }
247        }
248    }
249    print_pipeline(&mut out, &pipeline, 1);
250
251    // Section 3: redirection check (reuses the tree from section 1)
252    let redir = detect_redirections(root, source);
253    writeln!(out, "\n── output redirection ──").unwrap();
254    match redir {
255        Some(r) => writeln!(out, "  {r}").unwrap(),
256        None => writeln!(out, "  (none)").unwrap(),
257    }
258
259    Ok(out)
260}
261
262#[cfg(test)]
263#[path = "shell_inline_tests.rs"]
264mod shell_inline_tests;
265
266#[cfg(test)]
267#[path = "shell_tests.rs"]
268mod shell_tests;