use tree_sitter::{Parser, Tree};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum BashRiskLevel {
Safe,
Moderate,
High,
Forbidden,
}
#[derive(Debug, Clone)]
pub struct BashAnalysis {
pub risk: BashRiskLevel,
pub reasons: Vec<String>,
pub read_paths: Vec<String>,
pub write_paths: Vec<String>,
pub commands: Vec<String>,
}
pub fn parse_bash(source: &str) -> Option<Tree> {
let mut parser = Parser::new();
let lang = tree_sitter_bash::LANGUAGE;
parser.set_language(&lang.into()).ok()?;
parser.parse(source, None)
}
pub fn analyze_command(source: &str) -> BashAnalysis {
let mut analysis = BashAnalysis {
risk: BashRiskLevel::Safe,
reasons: Vec::new(),
read_paths: Vec::new(),
write_paths: Vec::new(),
commands: Vec::new(),
};
let tree = match parse_bash(source) {
Some(t) => t,
None => {
analysis.risk = BashRiskLevel::High;
analysis.reasons.push("Failed to parse command".into());
return analysis;
}
};
let root = tree.root_node();
if root.has_error() {
analysis.risk = BashRiskLevel::Moderate;
analysis.reasons.push("Command has parse errors".into());
}
let mut cursor = root.walk();
let mut stack = vec![root];
let bytes = source.as_bytes();
while let Some(node) = stack.pop() {
let kind = node.kind();
match kind {
"command_substitution" => {
raise(
&mut analysis,
BashRiskLevel::Moderate,
"command substitution detected",
);
}
"process_substitution" => {
raise(
&mut analysis,
BashRiskLevel::Moderate,
"process substitution detected",
);
}
"file_redirect" | "heredoc_redirect" => {
raise(
&mut analysis,
BashRiskLevel::Moderate,
"file redirection detected",
);
if let Some(dest) = node.child_by_field_name("destination") {
if let Ok(path) = dest.utf8_text(bytes) {
analysis.write_paths.push(path.to_string());
}
}
}
"pipeline" => {
raise(&mut analysis, BashRiskLevel::Moderate, "pipeline detected");
}
"command" => {
if let Some(name_node) = node.child_by_field_name("name") {
if let Ok(cmd_name) = name_node.utf8_text(bytes) {
analysis.commands.push(cmd_name.to_string());
classify_command(cmd_name, &mut analysis, &node, bytes);
}
}
}
_ => {}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
stack.push(child);
}
}
}
if analysis.commands.is_empty() && analysis.risk == BashRiskLevel::Safe {
analysis.risk = BashRiskLevel::Safe;
}
analysis
}
fn classify_command(
name: &str,
analysis: &mut BashAnalysis,
node: &tree_sitter::Node,
bytes: &[u8],
) {
match name {
"sudo" | "doas" | "su" => {
raise(analysis, BashRiskLevel::Forbidden, "privilege escalation");
}
"rm" => {
let args = extract_arguments(node, bytes);
if args
.iter()
.any(|a| a.contains("rf") || a == "/" || a == "/*")
{
raise(
analysis,
BashRiskLevel::Forbidden,
"rm -rf or root deletion",
);
} else {
raise(analysis, BashRiskLevel::High, "file deletion (rm)");
}
for arg in &args {
if !arg.starts_with('-') {
analysis.write_paths.push(arg.clone());
}
}
}
"chmod" | "chown" | "chgrp" => {
raise(
analysis,
BashRiskLevel::High,
&format!("permission change ({name})"),
);
}
"kill" | "killall" | "pkill" => {
raise(analysis, BashRiskLevel::High, "process termination");
}
"dd" | "mkfs" | "fdisk" | "mount" | "umount" => {
raise(
analysis,
BashRiskLevel::Forbidden,
&format!("disk operation ({name})"),
);
}
"curl" | "wget" => {
raise(analysis, BashRiskLevel::High, "network download");
}
"ssh" | "scp" | "rsync" => {
raise(analysis, BashRiskLevel::High, "remote access");
}
"cp" | "mv" | "install" => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("file operation ({name})"),
);
for arg in extract_arguments(node, bytes) {
if !arg.starts_with('-') {
analysis.write_paths.push(arg);
}
}
}
"mkdir" | "rmdir" | "touch" => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("directory/file creation ({name})"),
);
}
"git" => {
let args = extract_arguments(node, bytes);
let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
match subcommand {
"push" | "reset" | "checkout" | "clean" | "rebase" => {
raise(analysis, BashRiskLevel::High, &format!("git {subcommand}"));
}
"status" | "log" | "diff" | "branch" | "show" | "blame" | "stash" => {
}
_ => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("git {subcommand}"),
);
}
}
}
"npm" | "yarn" | "pnpm" | "pip" | "cargo" => {
let args = extract_arguments(node, bytes);
let subcommand = args.first().map(|s| s.as_str()).unwrap_or("");
match subcommand {
"install" | "add" | "remove" | "uninstall" | "publish" => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("{name} {subcommand}"),
);
}
"run" | "exec" | "test" | "build" | "check" | "clippy" | "fmt" => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("{name} {subcommand}"),
);
}
_ => {}
}
}
"ls" | "cat" | "head" | "tail" | "less" | "more" | "wc" | "file" | "stat" | "find"
| "grep" | "rg" | "ag" | "fd" | "tree" | "du" | "df" | "echo" | "printf" | "date"
| "whoami" | "hostname" | "uname" | "env" | "printenv" | "which" | "type" | "command"
| "pwd" | "cd" | "pushd" | "popd" | "true" | "false" | "test" | "expr" | "seq" | "sort"
| "uniq" | "tr" | "cut" | "awk" | "sed" | "jq" | "yq" | "xargs" | "tee" => {
}
_ => {
raise(
analysis,
BashRiskLevel::Moderate,
&format!("unknown command: {name}"),
);
}
}
}
fn extract_arguments(node: &tree_sitter::Node, bytes: &[u8]) -> Vec<String> {
let mut args = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"word" | "string" | "raw_string" | "number" | "concatenation" => {
if let Ok(text) = child.utf8_text(bytes) {
if child.start_byte() > node.child(0).map(|c| c.end_byte()).unwrap_or(0) {
args.push(text.trim_matches(|c| c == '"' || c == '\'').to_string());
}
}
}
_ => {}
}
}
args
}
fn raise(analysis: &mut BashAnalysis, level: BashRiskLevel, reason: &str) {
if level > analysis.risk {
analysis.risk = level;
}
analysis.reasons.push(reason.to_string());
}
pub fn is_safe(source: &str) -> bool {
analyze_command(source).risk <= BashRiskLevel::Safe
}
pub fn is_forbidden(source: &str) -> bool {
analyze_command(source).risk >= BashRiskLevel::Forbidden
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_safe_commands() {
assert!(is_safe("ls -la"));
assert!(is_safe("cat README.md"));
assert!(is_safe("grep -r 'TODO' src/"));
assert!(is_safe("pwd"));
assert!(is_safe("echo hello"));
}
#[test]
fn test_moderate_commands() {
let a = analyze_command("mkdir -p /tmp/test");
assert_eq!(a.risk, BashRiskLevel::Moderate);
let a = analyze_command("cargo build");
assert_eq!(a.risk, BashRiskLevel::Moderate);
let a = analyze_command("cp file1.txt file2.txt");
assert_eq!(a.risk, BashRiskLevel::Moderate);
}
#[test]
fn test_high_risk_commands() {
let a = analyze_command("rm important_file.txt");
assert_eq!(a.risk, BashRiskLevel::High);
let a = analyze_command("chmod 777 /tmp/file");
assert_eq!(a.risk, BashRiskLevel::High);
let a = analyze_command("curl https://example.com/script.sh");
assert_eq!(a.risk, BashRiskLevel::High);
}
#[test]
fn test_forbidden_commands() {
assert!(is_forbidden("sudo rm -rf /"));
assert!(is_forbidden("rm -rf /"));
assert!(is_forbidden("dd if=/dev/zero of=/dev/sda"));
}
#[test]
fn test_git_classification() {
let a = analyze_command("git status");
assert_eq!(a.risk, BashRiskLevel::Safe);
let a = analyze_command("git log --oneline");
assert_eq!(a.risk, BashRiskLevel::Safe);
let a = analyze_command("git push origin main");
assert_eq!(a.risk, BashRiskLevel::High);
let a = analyze_command("git add .");
assert_eq!(a.risk, BashRiskLevel::Moderate);
}
#[test]
fn test_pipeline_detection() {
let a = analyze_command("cat file | grep pattern");
assert!(a.risk >= BashRiskLevel::Moderate);
assert!(a.reasons.iter().any(|r| r.contains("pipeline")));
}
#[test]
fn test_command_extraction() {
let a = analyze_command("ls -la && echo done && cat file.txt");
assert!(a.commands.contains(&"ls".to_string()));
assert!(a.commands.contains(&"echo".to_string()));
assert!(a.commands.contains(&"cat".to_string()));
}
#[test]
fn test_parse_bash() {
let tree = parse_bash("echo hello world");
assert!(tree.is_some());
let tree = tree.unwrap();
assert!(!tree.root_node().has_error());
}
}