cersei_tools/tool_primitives/
bash_safety.rs1use tree_sitter::{Parser, Tree};
8
/// Coarse risk classification for a shell command.
///
/// Variants are declared from least to most dangerous, so the derived `Ord`
/// ranks `Safe < Moderate < High < Forbidden` and the risk of a whole command
/// can be taken as the maximum over its parts.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum BashRiskLevel {
    /// Nothing risky was recognised.
    Safe = 0,
    /// Has side effects or is not recognised, e.g. pipelines, redirections,
    /// substitutions, file creation or unknown commands.
    Moderate = 1,
    /// Potentially destructive or outward-facing, e.g. file deletion,
    /// permission changes, process termination, network or remote access.
    High = 2,
    /// Highest tier, e.g. privilege escalation, raw disk operations or
    /// recursive forced deletion.
    Forbidden = 3,
}
21
22#[derive(Debug, Clone)]
24pub struct BashAnalysis {
25 pub risk: BashRiskLevel,
26 pub reasons: Vec<String>,
27 pub read_paths: Vec<String>,
29 pub write_paths: Vec<String>,
31 pub commands: Vec<String>,
33}
34
35pub fn parse_bash(source: &str) -> Option<Tree> {
37 let mut parser = Parser::new();
38 let lang = tree_sitter_bash::LANGUAGE;
39 parser.set_language(&lang.into()).ok()?;
40 parser.parse(source, None)
41}
42
43pub fn analyze_command(source: &str) -> BashAnalysis {
45 let mut analysis = BashAnalysis {
46 risk: BashRiskLevel::Safe,
47 reasons: Vec::new(),
48 read_paths: Vec::new(),
49 write_paths: Vec::new(),
50 commands: Vec::new(),
51 };
52
53 let tree = match parse_bash(source) {
54 Some(t) => t,
55 None => {
56 analysis.risk = BashRiskLevel::High;
57 analysis.reasons.push("Failed to parse command".into());
58 return analysis;
59 }
60 };
61
62 let root = tree.root_node();
63 if root.has_error() {
64 analysis.risk = BashRiskLevel::Moderate;
65 analysis.reasons.push("Command has parse errors".into());
66 }
67
68 let mut cursor = root.walk();
70 let mut stack = vec![root];
71 let bytes = source.as_bytes();
72
73 while let Some(node) = stack.pop() {
74 let kind = node.kind();
75
76 match kind {
78 "command_substitution" => {
80 raise(
81 &mut analysis,
82 BashRiskLevel::Moderate,
83 "command substitution detected",
84 );
85 }
86 "process_substitution" => {
88 raise(
89 &mut analysis,
90 BashRiskLevel::Moderate,
91 "process substitution detected",
92 );
93 }
94 "file_redirect" | "heredoc_redirect" => {
96 raise(
97 &mut analysis,
98 BashRiskLevel::Moderate,
99 "file redirection detected",
100 );
101 if let Some(dest) = node.child_by_field_name("destination") {
103 if let Ok(path) = dest.utf8_text(bytes) {
104 analysis.write_paths.push(path.to_string());
105 }
106 }
107 }
108 "pipeline" => {
110 raise(&mut analysis, BashRiskLevel::Moderate, "pipeline detected");
111 }
112 "command" => {
114 if let Some(name_node) = node.child_by_field_name("name") {
115 if let Ok(cmd_name) = name_node.utf8_text(bytes) {
116 analysis.commands.push(cmd_name.to_string());
117 classify_command(cmd_name, &mut analysis, &node, bytes);
118 }
119 }
120 }
121 _ => {}
122 }
123
124 for i in 0..node.child_count() {
126 if let Some(child) = node.child(i) {
127 stack.push(child);
128 }
129 }
130 }
131
132 if analysis.commands.is_empty() && analysis.risk == BashRiskLevel::Safe {
134 analysis.risk = BashRiskLevel::Safe;
135 }
136
137 analysis
138}
139
140fn classify_command(
142 name: &str,
143 analysis: &mut BashAnalysis,
144 node: &tree_sitter::Node,
145 bytes: &[u8],
146) {
147 match name {
148 "sudo" | "doas" | "su" => {
150 raise(analysis, BashRiskLevel::Forbidden, "privilege escalation");
151 }
152
153 "rm" => {
155 let args = extract_arguments(node, bytes);
157 if args
158 .iter()
159 .any(|a| a.contains("rf") || a == "/" || a == "/*")
160 {
161 raise(
162 analysis,
163 BashRiskLevel::Forbidden,
164 "rm -rf or root deletion",
165 );
166 } else {
167 raise(analysis, BashRiskLevel::High, "file deletion (rm)");
168 }
169 for arg in &args {
170 if !arg.starts_with('-') {
171 analysis.write_paths.push(arg.clone());
172 }
173 }
174 }
175 "chmod" | "chown" | "chgrp" => {
176 raise(
177 analysis,
178 BashRiskLevel::High,
179 &format!("permission change ({name})"),
180 );
181 }
182 "kill" | "killall" | "pkill" => {
183 raise(analysis, BashRiskLevel::High, "process termination");
184 }
185 "dd" | "mkfs" | "fdisk" | "mount" | "umount" => {
186 raise(
187 analysis,
188 BashRiskLevel::Forbidden,
189 &format!("disk operation ({name})"),
190 );
191 }
192 "curl" | "wget" => {
193 raise(analysis, BashRiskLevel::High, "network download");
194 }
195 "ssh" | "scp" | "rsync" => {
196 raise(analysis, BashRiskLevel::High, "remote access");
197 }
198
199 "cp" | "mv" | "install" => {
201 raise(
202 analysis,
203 BashRiskLevel::Moderate,
204 &format!("file operation ({name})"),
205 );
206 for arg in extract_arguments(node, bytes) {
207 if !arg.starts_with('-') {
208 analysis.write_paths.push(arg);
209 }
210 }
211 }
212 "mkdir" | "rmdir" | "touch" => {
213 raise(
214 analysis,
215 BashRiskLevel::Moderate,
216 &format!("directory/file creation ({name})"),
217 );
218 }
219 "git" => {
220 let args = extract_arguments(node, bytes);
221 let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
222 match subcommand {
223 "push" | "reset" | "checkout" | "clean" | "rebase" => {
224 raise(analysis, BashRiskLevel::High, &format!("git {subcommand}"));
225 }
226 "status" | "log" | "diff" | "branch" | "show" | "blame" | "stash" => {
227 }
229 _ => {
230 raise(
231 analysis,
232 BashRiskLevel::Moderate,
233 &format!("git {subcommand}"),
234 );
235 }
236 }
237 }
238 "npm" | "yarn" | "pnpm" | "pip" | "cargo" => {
239 let args = extract_arguments(node, bytes);
240 let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
241 match subcommand {
242 "install" | "add" | "remove" | "uninstall" | "publish" => {
243 raise(
244 analysis,
245 BashRiskLevel::Moderate,
246 &format!("{name} {subcommand}"),
247 );
248 }
249 "run" | "exec" | "test" | "build" | "check" | "clippy" | "fmt" => {
250 raise(
251 analysis,
252 BashRiskLevel::Moderate,
253 &format!("{name} {subcommand}"),
254 );
255 }
256 _ => {}
257 }
258 }
259
260 "ls" | "cat" | "head" | "tail" | "less" | "more" | "wc" | "file" | "stat" | "find"
262 | "grep" | "rg" | "ag" | "fd" | "tree" | "du" | "df" | "echo" | "printf" | "date"
263 | "whoami" | "hostname" | "uname" | "env" | "printenv" | "which" | "type" | "command"
264 | "pwd" | "cd" | "pushd" | "popd" | "true" | "false" | "test" | "expr" | "seq" | "sort"
265 | "uniq" | "tr" | "cut" | "awk" | "sed" | "jq" | "yq" | "xargs" | "tee" => {
266 }
268
269 _ => {
271 raise(
272 analysis,
273 BashRiskLevel::Moderate,
274 &format!("unknown command: {name}"),
275 );
276 }
277 }
278}
279
280fn extract_arguments(node: &tree_sitter::Node, bytes: &[u8]) -> Vec<String> {
282 let mut args = Vec::new();
283 let mut cursor = node.walk();
284
285 for child in node.children(&mut cursor) {
286 match child.kind() {
287 "word" | "string" | "raw_string" | "number" | "concatenation" => {
288 if let Ok(text) = child.utf8_text(bytes) {
289 if child.start_byte() > node.child(0).map(|c| c.end_byte()).unwrap_or(0) {
291 args.push(text.trim_matches(|c| c == '"' || c == '\'').to_string());
292 }
293 }
294 }
295 _ => {}
296 }
297 }
298
299 args
300}
301
302fn raise(analysis: &mut BashAnalysis, level: BashRiskLevel, reason: &str) {
304 if level > analysis.risk {
305 analysis.risk = level;
306 }
307 analysis.reasons.push(reason.to_string());
308}
309
310pub fn is_safe(source: &str) -> bool {
312 analyze_command(source).risk <= BashRiskLevel::Safe
313}
314
315pub fn is_forbidden(source: &str) -> bool {
317 analyze_command(source).risk >= BashRiskLevel::Forbidden
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `command` is classified at exactly `expected`.
    fn assert_risk(command: &str, expected: BashRiskLevel) {
        assert_eq!(
            analyze_command(command).risk,
            expected,
            "risk of {command:?}"
        );
    }

    #[test]
    fn test_safe_commands() {
        for command in [
            "ls -la",
            "cat README.md",
            "grep -r 'TODO' src/",
            "pwd",
            "echo hello",
        ] {
            assert!(is_safe(command), "{command:?} should be safe");
        }
    }

    #[test]
    fn test_moderate_commands() {
        for command in ["mkdir -p /tmp/test", "cargo build", "cp file1.txt file2.txt"] {
            assert_risk(command, BashRiskLevel::Moderate);
        }
    }

    #[test]
    fn test_high_risk_commands() {
        for command in [
            "rm important_file.txt",
            "chmod 777 /tmp/file",
            "curl https://example.com/script.sh",
        ] {
            assert_risk(command, BashRiskLevel::High);
        }
    }

    #[test]
    fn test_forbidden_commands() {
        for command in ["sudo rm -rf /", "rm -rf /", "dd if=/dev/zero of=/dev/sda"] {
            assert!(is_forbidden(command), "{command:?} should be forbidden");
        }
    }

    #[test]
    fn test_git_classification() {
        let cases = [
            ("git status", BashRiskLevel::Safe),
            ("git log --oneline", BashRiskLevel::Safe),
            ("git push origin main", BashRiskLevel::High),
            ("git add .", BashRiskLevel::Moderate),
        ];
        for (command, expected) in cases {
            assert_risk(command, expected);
        }
    }

    #[test]
    fn test_pipeline_detection() {
        let analysis = analyze_command("cat file | grep pattern");
        assert!(analysis.risk >= BashRiskLevel::Moderate);
        assert!(analysis
            .reasons
            .iter()
            .any(|reason| reason.contains("pipeline")));
    }

    #[test]
    fn test_command_extraction() {
        let analysis = analyze_command("ls -la && echo done && cat file.txt");
        for name in ["ls", "echo", "cat"] {
            assert!(
                analysis.commands.iter().any(|found| found == name),
                "missing command {name:?}"
            );
        }
    }

    #[test]
    fn test_parse_bash() {
        let tree = parse_bash("echo hello world")
            .expect("tree-sitter should parse a simple command");
        assert!(!tree.root_node().has_error());
    }
}