aperion_shield/shims/
check_cmd.rs1use anyhow::{anyhow, Result};
37use serde_json::json;
38
39use crate::engine::Engine;
40use crate::{decide, Adjustments, BurstDetector, Decision, WorkspaceContext};
41
42#[derive(Debug, Clone)]
47pub struct CheckCmdReport {
48 pub command_line: String,
51 pub decision: Decision,
53 pub primary: Option<PrimaryFinding>,
55}
56
/// A flattened, display-ready copy of the most significant rule match.
#[derive(Clone, Debug)]
pub struct PrimaryFinding {
    /// Identifier of the matched rule.
    pub rule_id: String,
    /// `Debug` rendering of the match's severity (e.g. `"Critical"`).
    pub severity: String,
    /// Human-readable explanation copied from the rule match.
    pub reason: String,
    /// Optional suggestion for a less dangerous command.
    pub safer_alternative: Option<String>,
}
64
65impl CheckCmdReport {
66 pub fn exit_code(&self) -> u8 {
68 match &self.decision {
69 Decision::Allow => 0,
70 Decision::Warn { .. } => 0,
71 Decision::Block { .. } => 1,
72 Decision::Approval { .. } | Decision::IdentityVerification { .. } => 2,
73 }
74 }
75}
76
77pub fn run(engine: &Engine, argv: &[String]) -> Result<CheckCmdReport> {
88 if argv.is_empty() {
89 return Err(anyhow!(
90 "--check-cmd requires at least the command name after `--`"
91 ));
92 }
93
94 let cmd_line = reassemble_command_line(argv);
95
96 let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
102 let workspace = WorkspaceContext::probe_at(&engine.policy, &cwd);
103 let burst = BurstDetector::new(engine.policy.burst_detector.clone());
104 let adj = Adjustments {
105 workspace_is_prod: workspace.is_prod,
106 burst_in_progress: burst.in_burst(),
107 ..Default::default()
108 };
109
110 let canonical = json!({"name": "shell", "arguments": {"command": cmd_line}});
111 let eval = engine.evaluate("shell", &canonical, adj);
112 let decision = decide(&eval);
113
114 let primary = eval
115 .matches
116 .iter()
117 .max_by(|a, b| a.severity.cmp(&b.severity).then(a.points.cmp(&b.points)))
118 .map(|m| PrimaryFinding {
119 rule_id: m.rule_id.clone(),
120 severity: format!("{:?}", m.severity),
121 reason: m.reason.clone(),
122 safer_alternative: m.safer_alternative.clone(),
123 });
124
125 Ok(CheckCmdReport {
126 command_line: cmd_line,
127 decision,
128 primary,
129 })
130}
131
132fn reassemble_command_line(argv: &[String]) -> String {
142 let mut out = String::new();
143 for (i, arg) in argv.iter().enumerate() {
144 if i > 0 {
145 out.push(' ');
146 }
147 if needs_quoting(arg) {
148 out.push('\'');
149 out.push_str(&arg.replace('\'', "'\\''"));
150 out.push('\'');
151 } else {
152 out.push_str(arg);
153 }
154 }
155 out
156}
157
/// Return `true` when `arg` must be single-quoted to be read back as one
/// literal shell word.
///
/// This uses an allowlist, in the spirit of Python's `shlex.quote`. A word
/// stays bare only if it is non-empty and every character is alphanumeric or
/// one of `@ % + = : , . / - _`. Anything else is quoted, including:
///
/// * whitespace, quotes, `$`, `` ` ``, `\`
/// * `;`, `|`, `&`, redirections and parentheses
/// * glob characters (`* ? [ ]`)
/// * brace, tilde and history expansion (`{ } ~ !`)
/// * the comment introducer `#`
/// * control characters
///
/// A denylist that misses any of these misrepresents the argv. For example,
/// `["rm", "-rf", "#x", "/"]` would render as `rm -rf #x /`. A shell-aware
/// rule reads that as `rm -rf` followed by a comment and never sees `/`.
fn needs_quoting(arg: &str) -> bool {
    const SAFE_PUNCTUATION: &[char] = &['@', '%', '+', '=', ':', ',', '.', '/', '-', '_'];
    let is_safe = |c: char| c.is_alphanumeric() || SAFE_PUNCTUATION.contains(&c);
    // An empty word must be quoted (`''`) or it disappears from the line.
    arg.is_empty() || !arg.chars().all(is_safe)
}
178
179pub fn refusal_banner(report: &CheckCmdReport) -> String {
183 let mut out = String::new();
184 let decision_label = match &report.decision {
185 Decision::Block { .. } => "BLOCKED",
186 Decision::Approval { .. } => "APPROVAL-REQUIRED",
187 Decision::IdentityVerification { .. } => "IDENTITY-REQUIRED",
188 Decision::Warn { .. } => "WARN",
189 Decision::Allow => "ALLOW",
190 };
191
192 out.push_str(&format!("[aperion-shield/check-cmd] {} -- ", decision_label));
193 out.push_str(&format!("`{}`\n", short_command(&report.command_line)));
194
195 if let Some(p) = &report.primary {
196 out.push_str(&format!(
197 " rule : {} (severity={})\n",
198 p.rule_id, p.severity
199 ));
200 out.push_str(&format!(" reason : {}\n", p.reason));
201 if let Some(sa) = &p.safer_alternative {
202 out.push_str(&format!(" suggest : {}\n", sa));
203 }
204 }
205
206 match &report.decision {
207 Decision::Approval { .. } | Decision::IdentityVerification { .. } => {
208 out.push_str(
209 " note : approvals require an MCP-mediated invocation (this shim cannot prompt)\n",
210 );
211 }
212 _ => {}
213 }
214
215 out.push_str("\nbypass options for a single invocation:\n");
216 out.push_str(" SHIELD_SHIMS_DISABLE=1 <command> ... (env override, one-shot)\n");
217 out.push_str(" aperion-shield --uninstall-shims (remove all shims)\n");
218 out
219}
220
/// Shorten `cmd` for display to at most 200 characters (Unicode scalar values).
///
/// When anything is cut, ` …` is appended. Shorter input is returned
/// unchanged. The check and the cut both count characters, so multi-byte
/// text that fits is never marked as truncated. The cut always lands on a
/// char boundary, so slicing cannot panic.
fn short_command(cmd: &str) -> String {
    const CAP: usize = 200;
    // `nth(CAP)` is the first character past the cap; it exists only if the
    // string has more than CAP characters.
    match cmd.char_indices().nth(CAP) {
        Some((cut, _)) => format!("{} …", &cmd[..cut]),
        None => cmd.to_string(),
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned argv from string literals.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn empty_argv_is_an_operational_error() {
        let engine = Engine::builtin_default();
        let message = run(&engine, &[])
            .expect_err("empty argv should error")
            .to_string();
        assert!(message.contains("at least the command name"));
    }

    #[test]
    fn reassembly_preserves_simple_argv() {
        assert_eq!(
            reassemble_command_line(&owned(&["aws", "s3", "ls"])),
            "aws s3 ls"
        );
    }

    #[test]
    fn reassembly_quotes_args_with_spaces_and_metacharacters() {
        let line = reassemble_command_line(&owned(&["psql", "-c", "DROP TABLE users;"]));
        assert!(line.contains("'DROP TABLE users;'"), "got: {}", line);
    }

    #[test]
    fn reassembly_escapes_embedded_single_quotes() {
        let line = reassemble_command_line(&owned(&["sh", "-c", "echo 'hi'"]));
        assert!(line.contains("'echo '\\''hi'\\'''"), "got: {}", line);
    }

    #[test]
    fn needs_quoting_picks_up_shell_metacharacters() {
        for &quoted in &["a b", "a;b", "a|b", "a>b", "`whoami`"] {
            assert!(needs_quoting(quoted), "expected {:?} to need quoting", quoted);
        }
        for &bare in &["aws", "--recursive", "s3://prod-bucket"] {
            assert!(!needs_quoting(bare), "expected {:?} to stay bare", bare);
        }
    }

    #[test]
    fn exit_code_for_allow_is_zero() {
        let report = CheckCmdReport {
            command_line: "aws s3 ls".into(),
            decision: Decision::Allow,
            primary: None,
        };
        assert_eq!(report.exit_code(), 0);
    }

    #[test]
    fn run_with_innocuous_command_returns_allow() {
        let engine = Engine::builtin_default();
        let report = run(&engine, &owned(&["aws", "s3", "ls"])).expect("run");
        let benign = matches!(report.decision, Decision::Allow | Decision::Warn { .. });
        assert!(benign);
        assert_eq!(report.exit_code(), 0);
    }

    #[test]
    fn refusal_banner_includes_command_rule_and_bypass_note() {
        use crate::Severity;
        let rule = "fs.rm_root";
        let reason = "rm -rf / is non-recoverable";
        let alternative = "rm -rf <specific-path>";
        let report = CheckCmdReport {
            command_line: "rm -rf /".into(),
            decision: Decision::Block {
                rule_id: rule.into(),
                severity: Severity::Critical,
                reason: reason.into(),
                safer_alternative: Some(alternative.into()),
                contributing_rules: vec![rule.into()],
            },
            primary: Some(PrimaryFinding {
                rule_id: rule.into(),
                severity: "Critical".into(),
                reason: reason.into(),
                safer_alternative: Some(alternative.into()),
            }),
        };
        let banner = refusal_banner(&report);
        for &needle in &[
            "BLOCKED",
            "rm -rf /",
            "fs.rm_root",
            "SHIELD_SHIMS_DISABLE",
            "aperion-shield --uninstall-shims",
        ] {
            assert!(banner.contains(needle), "banner missing {:?}:\n{}", needle, banner);
        }
    }

    #[test]
    fn short_command_truncates_long_lines() {
        let long = "a".repeat(500);
        let shortened = short_command(&long);
        assert!(shortened.len() < long.len());
        assert!(shortened.ends_with(" …"));
    }
}