use anyhow::{anyhow, Result};
use serde_json::json;
use crate::engine::Engine;
use crate::{decide, Adjustments, BurstDetector, Decision, WorkspaceContext};
/// Outcome of evaluating one command line through `--check-cmd`.
#[derive(Debug, Clone)]
pub struct CheckCmdReport {
/// The command line rebuilt from argv by `reassemble_command_line`.
pub command_line: String,
/// The policy decision produced by `decide` for this command.
pub decision: Decision,
/// The highest-severity rule match (ties broken by points), if any rule matched.
pub primary: Option<PrimaryFinding>,
}
/// Display-oriented summary of the single most significant rule match.
#[derive(Debug, Clone)]
pub struct PrimaryFinding {
/// Identifier of the matched rule.
pub rule_id: String,
/// Severity rendered with its `Debug` formatting (e.g. `"Critical"`).
pub severity: String,
/// Human-readable explanation carried by the matched rule.
pub reason: String,
/// Optional safer command suggested by the matched rule.
pub safer_alternative: Option<String>,
}
impl CheckCmdReport {
    /// Process exit code for this report.
    ///
    /// * `0` for `Allow` and `Warn`.
    /// * `1` for `Block`.
    /// * `2` for `Approval` and `IdentityVerification`.
    pub fn exit_code(&self) -> u8 {
        match self.decision {
            Decision::Allow | Decision::Warn { .. } => 0,
            Decision::Block { .. } => 1,
            Decision::Approval { .. } | Decision::IdentityVerification { .. } => 2,
        }
    }
}
/// Evaluate a shell command, given as argv, against the engine's policy.
///
/// The argv is rebuilt into one command line and wrapped in a canonical
/// `shell` tool-call JSON payload. That payload is evaluated with adjustments
/// taken from the workspace context of the current directory and from a
/// burst detector built from the policy's burst settings. The report carries
/// the final decision plus the highest-severity match (ties broken by points).
///
/// # Errors
///
/// Returns an error when `argv` is empty.
pub fn run(engine: &Engine, argv: &[String]) -> Result<CheckCmdReport> {
    if argv.is_empty() {
        return Err(anyhow!(
            "--check-cmd requires at least the command name after `--`"
        ));
    }

    let command_line = reassemble_command_line(argv);

    // Fall back to "." when the current directory cannot be determined.
    let cwd = match std::env::current_dir() {
        Ok(dir) => dir,
        Err(_) => std::path::PathBuf::from("."),
    };
    let workspace = WorkspaceContext::probe_at(&engine.policy, &cwd);

    // NOTE(review): the detector is created fresh on every call; confirm that
    // `in_burst()` can observe prior activity without shared state.
    let burst = BurstDetector::new(engine.policy.burst_detector.clone());

    let adjustments = Adjustments {
        workspace_is_prod: workspace.is_prod,
        burst_in_progress: burst.in_burst(),
        ..Default::default()
    };

    let payload = json!({"name": "shell", "arguments": {"command": command_line}});
    let evaluation = engine.evaluate("shell", &payload, adjustments);
    let decision = decide(&evaluation);

    // Order by severity, then points; `max_by` returns the last of equal maxima.
    let primary = evaluation
        .matches
        .iter()
        .max_by(|lhs, rhs| {
            lhs.severity
                .cmp(&rhs.severity)
                .then(lhs.points.cmp(&rhs.points))
        })
        .map(|finding| PrimaryFinding {
            rule_id: finding.rule_id.clone(),
            severity: format!("{:?}", finding.severity),
            reason: finding.reason.clone(),
            safer_alternative: finding.safer_alternative.clone(),
        });

    Ok(CheckCmdReport {
        command_line,
        decision,
        primary,
    })
}
/// Rebuild a single shell-style command line from an argv vector.
///
/// Arguments are joined with single spaces. Any argument for which
/// `needs_quoting` returns `true` is wrapped in single quotes. Each embedded
/// `'` is rewritten as `'\''`: close the quote, emit an escaped quote, then
/// reopen the quote.
fn reassemble_command_line(argv: &[String]) -> String {
    argv.iter()
        .map(|arg| {
            if needs_quoting(arg) {
                format!("'{}'", arg.replace('\'', "'\\''"))
            } else {
                arg.clone()
            }
        })
        .collect::<Vec<String>>()
        .join(" ")
}
/// Return `true` if `arg` should be single-quoted when rebuilding a shell
/// command line.
///
/// Quoting applies to:
/// * empty arguments;
/// * arguments containing any whitespace;
/// * arguments containing any character in `SPECIAL`.
fn needs_quoting(arg: &str) -> bool {
    // `#` is listed because an unquoted word starting with `#` opens a shell
    // comment. Left bare, it would hide every following argument from anything
    // that parses the reassembled line as shell (e.g. `rm '#' -rf /`).
    // NOTE(review): glob/tilde/brace characters are not listed; they only
    // affect their own word, and quoting them could change rule matching.
    const SPECIAL: &[char] = &[
        '"', '\'', '`', '$', '\\', ';', '|', '&', '<', '>', '(', ')', '#',
    ];
    arg.is_empty()
        || arg
            .chars()
            .any(|c| c.is_whitespace() || SPECIAL.contains(&c))
}
/// Render the human-readable banner printed when `--check-cmd` reports a decision.
///
/// The banner has these parts, in order:
/// * a header with the decision label and the (possibly truncated) command;
/// * the primary finding's rule, reason and optional suggestion, when present;
/// * for approval and identity decisions, a note that this shim cannot prompt;
/// * the bypass options, which are always printed.
pub fn refusal_banner(report: &CheckCmdReport) -> String {
    let label = match &report.decision {
        Decision::Block { .. } => "BLOCKED",
        Decision::Approval { .. } => "APPROVAL-REQUIRED",
        Decision::IdentityVerification { .. } => "IDENTITY-REQUIRED",
        Decision::Warn { .. } => "WARN",
        Decision::Allow => "ALLOW",
    };
    let cannot_prompt = matches!(
        report.decision,
        Decision::Approval { .. } | Decision::IdentityVerification { .. }
    );

    let mut banner = format!(
        "[aperion-shield/check-cmd] {} -- `{}`\n",
        label,
        short_command(&report.command_line)
    );

    if let Some(finding) = report.primary.as_ref() {
        banner += &format!(
            " rule : {} (severity={})\n",
            finding.rule_id, finding.severity
        );
        banner += &format!(" reason : {}\n", finding.reason);
        if let Some(alt) = finding.safer_alternative.as_deref() {
            banner += &format!(" suggest : {}\n", alt);
        }
    }

    if cannot_prompt {
        banner +=
            " note : approvals require an MCP-mediated invocation (this shim cannot prompt)\n";
    }

    banner += "\nbypass options for a single invocation:\n";
    banner += " SHIELD_SHIMS_DISABLE=1 <command> ... (env override, one-shot)\n";
    banner += " aperion-shield --uninstall-shims (remove all shims)\n";
    banner
}
/// Shorten `cmd` for display, keeping at most `CAP` characters.
///
/// Commands of up to `CAP` characters (not bytes) are returned unchanged.
/// Longer commands are cut after the `CAP`-th character and suffixed with ` …`.
fn short_command(cmd: &str) -> String {
    const CAP: usize = 200;
    // Decide and cut in the same unit (chars). The previous byte-length check
    // appended " …" to multibyte commands that were never actually truncated.
    match cmd.char_indices().nth(CAP) {
        // A (CAP+1)-th char exists; its byte offset is a valid char boundary.
        Some((cut, _)) => format!("{} …", &cmd[..cut]),
        None => cmd.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned argv vector from string slices.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    #[test]
    fn empty_argv_is_an_operational_error() {
        let engine = Engine::builtin_default();
        let err = run(&engine, &[]).expect_err("empty argv should error");
        assert!(err.to_string().contains("at least the command name"));
    }

    #[test]
    fn reassembly_preserves_simple_argv() {
        let line = reassemble_command_line(&argv(&["aws", "s3", "ls"]));
        assert_eq!(line, "aws s3 ls");
    }

    #[test]
    fn reassembly_quotes_args_with_spaces_and_metacharacters() {
        let line = reassemble_command_line(&argv(&["psql", "-c", "DROP TABLE users;"]));
        assert!(line.contains("'DROP TABLE users;'"), "got: {}", line);
    }

    #[test]
    fn reassembly_escapes_embedded_single_quotes() {
        let line = reassemble_command_line(&argv(&["sh", "-c", "echo 'hi'"]));
        assert!(line.contains("'echo '\\''hi'\\'''"), "got: {}", line);
    }

    #[test]
    fn needs_quoting_picks_up_shell_metacharacters() {
        for quoted in ["a b", "a;b", "a|b", "a>b", "`whoami`"] {
            assert!(needs_quoting(quoted), "expected quoting for {:?}", quoted);
        }
        for bare in ["aws", "--recursive", "s3://prod-bucket"] {
            assert!(!needs_quoting(bare), "expected no quoting for {:?}", bare);
        }
    }

    #[test]
    fn exit_code_for_allow_is_zero() {
        let report = CheckCmdReport {
            command_line: "aws s3 ls".into(),
            decision: Decision::Allow,
            primary: None,
        };
        assert_eq!(report.exit_code(), 0);
    }

    #[test]
    fn run_with_innocuous_command_returns_allow() {
        let engine = Engine::builtin_default();
        let report = run(&engine, &argv(&["aws", "s3", "ls"])).expect("run");
        assert!(matches!(
            report.decision,
            Decision::Allow | Decision::Warn { .. }
        ));
        assert_eq!(report.exit_code(), 0);
    }

    #[test]
    fn refusal_banner_includes_command_rule_and_bypass_note() {
        use crate::Severity;
        let rule = "fs.rm_root";
        let reason = "rm -rf / is non-recoverable";
        let suggestion = "rm -rf <specific-path>";
        let report = CheckCmdReport {
            command_line: "rm -rf /".into(),
            decision: Decision::Block {
                rule_id: rule.into(),
                severity: Severity::Critical,
                reason: reason.into(),
                safer_alternative: Some(suggestion.into()),
                contributing_rules: vec![rule.into()],
            },
            primary: Some(PrimaryFinding {
                rule_id: rule.into(),
                severity: "Critical".into(),
                reason: reason.into(),
                safer_alternative: Some(suggestion.into()),
            }),
        };
        let banner = refusal_banner(&report);
        for needle in [
            "BLOCKED",
            "rm -rf /",
            "fs.rm_root",
            "SHIELD_SHIMS_DISABLE",
            "aperion-shield --uninstall-shims",
        ] {
            assert!(
                banner.contains(needle),
                "banner missing {:?}:\n{}",
                needle,
                banner
            );
        }
    }

    #[test]
    fn short_command_truncates_long_lines() {
        let long = "a".repeat(500);
        let shortened = short_command(&long);
        assert!(shortened.len() < long.len());
        assert!(shortened.ends_with(" …"));
    }
}