Skip to main content

cersei_tools/tool_primitives/
bash_safety.rs

1//! Tree-sitter based bash command safety analysis.
2//!
3//! Parses bash commands into ASTs and validates them for safety before execution.
4//! Detects dangerous constructs like command substitution, process substitution,
5//! variable expansion in dangerous contexts, and destructive operations.
6
7use tree_sitter::{Parser, Tree};
8
/// How dangerous a bash command is judged to be.
///
/// Variants are declared from least to most dangerous and the derived `Ord`
/// follows that declaration order, so two levels can be compared directly and
/// combined with `max`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum BashRiskLevel {
    /// Read-only work: listing, inspecting, navigating.
    Safe,
    /// Changes local state: writes files, runs builds, alters the workspace.
    Moderate,
    /// Destructive, networked, or privilege-related operations.
    High,
    /// Must never be auto-approved (e.g. `sudo`, `rm -rf /`).
    Forbidden,
}
21
22/// Result of analyzing a bash command.
23#[derive(Debug, Clone)]
24pub struct BashAnalysis {
25    pub risk: BashRiskLevel,
26    pub reasons: Vec<String>,
27    /// File paths that the command reads from.
28    pub read_paths: Vec<String>,
29    /// File paths that the command writes to.
30    pub write_paths: Vec<String>,
31    /// Commands detected in the input.
32    pub commands: Vec<String>,
33}
34
35/// Parse a bash command string into a tree-sitter AST.
36pub fn parse_bash(source: &str) -> Option<Tree> {
37    let mut parser = Parser::new();
38    let lang = tree_sitter_bash::LANGUAGE;
39    parser.set_language(&lang.into()).ok()?;
40    parser.parse(source, None)
41}
42
43/// Analyze a bash command for safety.
44pub fn analyze_command(source: &str) -> BashAnalysis {
45    let mut analysis = BashAnalysis {
46        risk: BashRiskLevel::Safe,
47        reasons: Vec::new(),
48        read_paths: Vec::new(),
49        write_paths: Vec::new(),
50        commands: Vec::new(),
51    };
52
53    let tree = match parse_bash(source) {
54        Some(t) => t,
55        None => {
56            analysis.risk = BashRiskLevel::High;
57            analysis.reasons.push("Failed to parse command".into());
58            return analysis;
59        }
60    };
61
62    let root = tree.root_node();
63    if root.has_error() {
64        analysis.risk = BashRiskLevel::Moderate;
65        analysis.reasons.push("Command has parse errors".into());
66    }
67
68    // Walk the AST
69    let mut cursor = root.walk();
70    let mut stack = vec![root];
71    let bytes = source.as_bytes();
72
73    while let Some(node) = stack.pop() {
74        let kind = node.kind();
75
76        // Check for dangerous constructs
77        match kind {
78            // Command substitution: $(cmd) or `cmd`
79            "command_substitution" => {
80                raise(
81                    &mut analysis,
82                    BashRiskLevel::Moderate,
83                    "command substitution detected",
84                );
85            }
86            // Process substitution: <(cmd) or >(cmd)
87            "process_substitution" => {
88                raise(
89                    &mut analysis,
90                    BashRiskLevel::Moderate,
91                    "process substitution detected",
92                );
93            }
94            // Redirections: could overwrite files
95            "file_redirect" | "heredoc_redirect" => {
96                raise(
97                    &mut analysis,
98                    BashRiskLevel::Moderate,
99                    "file redirection detected",
100                );
101                // Try to extract target path
102                if let Some(dest) = node.child_by_field_name("destination") {
103                    if let Ok(path) = dest.utf8_text(bytes) {
104                        analysis.write_paths.push(path.to_string());
105                    }
106                }
107            }
108            // Pipeline: moderate risk (data flows between processes)
109            "pipeline" => {
110                raise(&mut analysis, BashRiskLevel::Moderate, "pipeline detected");
111            }
112            // Extract command names
113            "command" => {
114                if let Some(name_node) = node.child_by_field_name("name") {
115                    if let Ok(cmd_name) = name_node.utf8_text(bytes) {
116                        analysis.commands.push(cmd_name.to_string());
117                        classify_command(cmd_name, &mut analysis, &node, bytes);
118                    }
119                }
120            }
121            _ => {}
122        }
123
124        // Push children for traversal
125        for i in 0..node.child_count() {
126            if let Some(child) = node.child(i) {
127                stack.push(child);
128            }
129        }
130    }
131
132    // If no commands detected, it's likely safe (empty or just comments)
133    if analysis.commands.is_empty() && analysis.risk == BashRiskLevel::Safe {
134        analysis.risk = BashRiskLevel::Safe;
135    }
136
137    analysis
138}
139
140/// Classify a specific command by name.
141fn classify_command(
142    name: &str,
143    analysis: &mut BashAnalysis,
144    node: &tree_sitter::Node,
145    bytes: &[u8],
146) {
147    match name {
148        // ── Forbidden ──
149        "sudo" | "doas" | "su" => {
150            raise(analysis, BashRiskLevel::Forbidden, "privilege escalation");
151        }
152
153        // ── High risk: destructive ──
154        "rm" => {
155            // Check for -rf or dangerous flags
156            let args = extract_arguments(node, bytes);
157            if args
158                .iter()
159                .any(|a| a.contains("rf") || a == "/" || a == "/*")
160            {
161                raise(
162                    analysis,
163                    BashRiskLevel::Forbidden,
164                    "rm -rf or root deletion",
165                );
166            } else {
167                raise(analysis, BashRiskLevel::High, "file deletion (rm)");
168            }
169            for arg in &args {
170                if !arg.starts_with('-') {
171                    analysis.write_paths.push(arg.clone());
172                }
173            }
174        }
175        "chmod" | "chown" | "chgrp" => {
176            raise(
177                analysis,
178                BashRiskLevel::High,
179                &format!("permission change ({name})"),
180            );
181        }
182        "kill" | "killall" | "pkill" => {
183            raise(analysis, BashRiskLevel::High, "process termination");
184        }
185        "dd" | "mkfs" | "fdisk" | "mount" | "umount" => {
186            raise(
187                analysis,
188                BashRiskLevel::Forbidden,
189                &format!("disk operation ({name})"),
190            );
191        }
192        "curl" | "wget" => {
193            raise(analysis, BashRiskLevel::High, "network download");
194        }
195        "ssh" | "scp" | "rsync" => {
196            raise(analysis, BashRiskLevel::High, "remote access");
197        }
198
199        // ── Moderate risk: writes ──
200        "cp" | "mv" | "install" => {
201            raise(
202                analysis,
203                BashRiskLevel::Moderate,
204                &format!("file operation ({name})"),
205            );
206            for arg in extract_arguments(node, bytes) {
207                if !arg.starts_with('-') {
208                    analysis.write_paths.push(arg);
209                }
210            }
211        }
212        "mkdir" | "rmdir" | "touch" => {
213            raise(
214                analysis,
215                BashRiskLevel::Moderate,
216                &format!("directory/file creation ({name})"),
217            );
218        }
219        "git" => {
220            let args = extract_arguments(node, bytes);
221            let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
222            match subcommand {
223                "push" | "reset" | "checkout" | "clean" | "rebase" => {
224                    raise(analysis, BashRiskLevel::High, &format!("git {subcommand}"));
225                }
226                "status" | "log" | "diff" | "branch" | "show" | "blame" | "stash" => {
227                    // Read-only git commands are safe
228                }
229                _ => {
230                    raise(
231                        analysis,
232                        BashRiskLevel::Moderate,
233                        &format!("git {subcommand}"),
234                    );
235                }
236            }
237        }
238        "npm" | "yarn" | "pnpm" | "pip" | "cargo" => {
239            let args = extract_arguments(node, bytes);
240            let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
241            match subcommand {
242                "install" | "add" | "remove" | "uninstall" | "publish" => {
243                    raise(
244                        analysis,
245                        BashRiskLevel::Moderate,
246                        &format!("{name} {subcommand}"),
247                    );
248                }
249                "run" | "exec" | "test" | "build" | "check" | "clippy" | "fmt" => {
250                    raise(
251                        analysis,
252                        BashRiskLevel::Moderate,
253                        &format!("{name} {subcommand}"),
254                    );
255                }
256                _ => {}
257            }
258        }
259
260        // ── Safe: read-only ──
261        "ls" | "cat" | "head" | "tail" | "less" | "more" | "wc" | "file" | "stat" | "find"
262        | "grep" | "rg" | "ag" | "fd" | "tree" | "du" | "df" | "echo" | "printf" | "date"
263        | "whoami" | "hostname" | "uname" | "env" | "printenv" | "which" | "type" | "command"
264        | "pwd" | "cd" | "pushd" | "popd" | "true" | "false" | "test" | "expr" | "seq" | "sort"
265        | "uniq" | "tr" | "cut" | "awk" | "sed" | "jq" | "yq" | "xargs" | "tee" => {
266            // These are generally safe (read-only or formatting)
267        }
268
269        // ── Unknown: moderate by default ──
270        _ => {
271            raise(
272                analysis,
273                BashRiskLevel::Moderate,
274                &format!("unknown command: {name}"),
275            );
276        }
277    }
278}
279
280/// Extract command arguments from AST.
281fn extract_arguments(node: &tree_sitter::Node, bytes: &[u8]) -> Vec<String> {
282    let mut args = Vec::new();
283    let mut cursor = node.walk();
284
285    for child in node.children(&mut cursor) {
286        match child.kind() {
287            "word" | "string" | "raw_string" | "number" | "concatenation" => {
288                if let Ok(text) = child.utf8_text(bytes) {
289                    // Skip the command name (first word)
290                    if child.start_byte() > node.child(0).map(|c| c.end_byte()).unwrap_or(0) {
291                        args.push(text.trim_matches(|c| c == '"' || c == '\'').to_string());
292                    }
293                }
294            }
295            _ => {}
296        }
297    }
298
299    args
300}
301
302/// Raise the risk level if the new level is higher.
303fn raise(analysis: &mut BashAnalysis, level: BashRiskLevel, reason: &str) {
304    if level > analysis.risk {
305        analysis.risk = level;
306    }
307    analysis.reasons.push(reason.to_string());
308}
309
310/// Quick check: is a command safe to auto-approve?
311pub fn is_safe(source: &str) -> bool {
312    analyze_command(source).risk <= BashRiskLevel::Safe
313}
314
315/// Quick check: should a command be blocked?
316pub fn is_forbidden(source: &str) -> bool {
317    analyze_command(source).risk >= BashRiskLevel::Forbidden
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the overall risk of a command.
    fn risk_of(command: &str) -> BashRiskLevel {
        analyze_command(command).risk
    }

    #[test]
    fn test_safe_commands() {
        let commands = ["ls -la", "cat README.md", "grep -r 'TODO' src/", "pwd", "echo hello"];
        for command in commands {
            assert!(is_safe(command), "expected `{command}` to be safe");
        }
    }

    #[test]
    fn test_moderate_commands() {
        for command in ["mkdir -p /tmp/test", "cargo build", "cp file1.txt file2.txt"] {
            assert_eq!(risk_of(command), BashRiskLevel::Moderate, "{command}");
        }
    }

    #[test]
    fn test_high_risk_commands() {
        let commands = [
            "rm important_file.txt",
            "chmod 777 /tmp/file",
            "curl https://example.com/script.sh",
        ];
        for command in commands {
            assert_eq!(risk_of(command), BashRiskLevel::High, "{command}");
        }
    }

    #[test]
    fn test_forbidden_commands() {
        for command in ["sudo rm -rf /", "rm -rf /", "dd if=/dev/zero of=/dev/sda"] {
            assert!(is_forbidden(command), "expected `{command}` to be forbidden");
        }
    }

    #[test]
    fn test_git_classification() {
        let cases = [
            ("git status", BashRiskLevel::Safe),
            ("git log --oneline", BashRiskLevel::Safe),
            ("git push origin main", BashRiskLevel::High),
            ("git add .", BashRiskLevel::Moderate),
        ];
        for (command, expected) in cases {
            assert_eq!(risk_of(command), expected, "{command}");
        }
    }

    #[test]
    fn test_pipeline_detection() {
        let analysis = analyze_command("cat file | grep pattern");
        assert!(analysis.risk >= BashRiskLevel::Moderate);
        assert!(analysis
            .reasons
            .iter()
            .any(|reason| reason.contains("pipeline")));
    }

    #[test]
    fn test_command_extraction() {
        let analysis = analyze_command("ls -la && echo done && cat file.txt");
        for expected in ["ls", "echo", "cat"] {
            assert!(
                analysis.commands.iter().any(|c| c == expected),
                "missing command `{expected}`"
            );
        }
    }

    #[test]
    fn test_parse_bash() {
        let tree = parse_bash("echo hello world").expect("a simple command should parse");
        assert!(!tree.root_node().has_error());
    }
}