Skip to main content

codewhale_execpolicy/
bash_arity.rs

//! Bash arity dictionary for command-prefix allow rule matching.
//!
//! [`BashArityDict`] maps a command prefix (space-separated, lowercase) to the
//! number of positional (non-flag) words, *including the base command word*,
//! that form the canonical prefix.
//!
//! ## Invariant
//!
//! Flags (tokens starting with `-`) are **never** counted toward arity.
//! `auto_allow = ["git status"]` must match `git status -s` and
//! `git status --porcelain`, but **not** `git push`.
//!
//! ## Coverage
//!
//! More than 200 command prefixes are covered across 20+ common tools: git,
//! npm, yarn, pnpm, cargo, docker, kubectl, go, python/pip, gh, rustup, deno,
//! bun, aws, terraform, make, and more.
18
/// Static arity table: `(prefix, arity)`.
///
/// `prefix` is a lowercase command prefix whose words are separated by a
/// single space.  `arity` is the number of words in the canonical prefix that
/// is derived from it, *including* the base command and every word of
/// `prefix` itself.  For example:
///
/// * `("git status", 2)` — 2 words: `git` + `status`.
/// * `("npm run", 3)` — 3 words: `npm` + `run` + `<script>`.
/// * `("make", 1)` — only the base command, no sub-command.
///
/// Every entry satisfies `arity >= number of words in prefix`, and no prefix
/// appears twice.  Two keys (`python -m`, `python3 -m`) embed a literal flag
/// word; that word is part of the key and is counted like a sub-command.
pub static BASH_ARITY_TABLE: &[(&str, u8)] = &[
    // ── git ──────────────────────────────────────────────────────────────────
    ("git add", 2),
    ("git am", 2),
    ("git apply", 2),
    ("git bisect", 2),
    ("git blame", 2),
    ("git branch", 2),
    ("git cat-file", 2),
    ("git checkout", 2),
    ("git cherry-pick", 2),
    ("git clean", 2),
    ("git clone", 2),
    ("git commit", 2),
    ("git config", 2),
    ("git describe", 2),
    ("git diff", 2),
    ("git fetch", 2),
    ("git format-patch", 2),
    ("git grep", 2),
    ("git init", 2),
    ("git log", 2),
    ("git ls-files", 2),
    ("git merge", 2),
    ("git mv", 2),
    ("git notes", 2),
    ("git pull", 2),
    ("git push", 2),
    ("git rebase", 2),
    ("git reflog", 2),
    ("git remote", 2),
    ("git reset", 2),
    ("git restore", 2),
    ("git revert", 2),
    ("git rm", 2),
    ("git show", 2),
    ("git stash", 2),
    ("git status", 2),
    ("git submodule", 2),
    ("git switch", 2),
    ("git tag", 2),
    ("git worktree", 2),
    // ── npm ──────────────────────────────────────────────────────────────────
    ("npm audit", 2),
    ("npm build", 2),
    ("npm cache", 2),
    ("npm ci", 2),
    ("npm dedupe", 2),
    ("npm fund", 2),
    ("npm help", 2),
    ("npm info", 2),
    ("npm init", 2),
    ("npm install", 2),
    ("npm link", 2),
    ("npm list", 2),
    ("npm ls", 2),
    ("npm outdated", 2),
    ("npm pack", 2),
    ("npm prune", 2),
    ("npm publish", 2),
    ("npm rebuild", 2),
    ("npm run", 3),
    ("npm start", 2),
    ("npm stop", 2),
    ("npm test", 2),
    ("npm uninstall", 2),
    ("npm update", 2),
    ("npm version", 2),
    ("npm view", 2),
    // ── yarn ─────────────────────────────────────────────────────────────────
    ("yarn add", 2),
    ("yarn audit", 2),
    ("yarn build", 2),
    ("yarn install", 2),
    ("yarn run", 3),
    ("yarn start", 2),
    ("yarn test", 2),
    ("yarn upgrade", 2),
    ("yarn workspace", 3),
    // ── pnpm ─────────────────────────────────────────────────────────────────
    ("pnpm add", 2),
    ("pnpm build", 2),
    ("pnpm install", 2),
    ("pnpm run", 3),
    ("pnpm start", 2),
    ("pnpm test", 2),
    ("pnpm update", 2),
    // ── cargo ────────────────────────────────────────────────────────────────
    ("cargo add", 2),
    ("cargo bench", 2),
    ("cargo build", 2),
    ("cargo check", 2),
    ("cargo clean", 2),
    ("cargo clippy", 2),
    ("cargo doc", 2),
    ("cargo fix", 2),
    ("cargo fmt", 2),
    ("cargo generate", 2),
    ("cargo install", 2),
    ("cargo metadata", 2),
    ("cargo package", 2),
    ("cargo publish", 2),
    ("cargo remove", 2),
    ("cargo run", 2),
    ("cargo search", 2),
    ("cargo test", 2),
    ("cargo tree", 2),
    ("cargo uninstall", 2),
    ("cargo update", 2),
    ("cargo yank", 2),
    // ── docker ───────────────────────────────────────────────────────────────
    ("docker build", 2),
    ("docker compose", 3),
    ("docker container", 3),
    ("docker cp", 2),
    ("docker exec", 2),
    ("docker image", 3),
    ("docker images", 2),
    ("docker inspect", 2),
    ("docker kill", 2),
    ("docker logs", 2),
    ("docker network", 3),
    ("docker ps", 2),
    ("docker pull", 2),
    ("docker push", 2),
    ("docker rm", 2),
    ("docker rmi", 2),
    ("docker run", 2),
    ("docker start", 2),
    ("docker stop", 2),
    ("docker system", 3),
    ("docker tag", 2),
    ("docker volume", 3),
    // ── kubectl ──────────────────────────────────────────────────────────────
    ("kubectl apply", 2),
    ("kubectl create", 3),
    ("kubectl delete", 3),
    ("kubectl describe", 3),
    ("kubectl exec", 2),
    ("kubectl explain", 2),
    ("kubectl get", 3),
    ("kubectl label", 2),
    ("kubectl logs", 2),
    ("kubectl patch", 2),
    ("kubectl port-forward", 2),
    ("kubectl rollout", 3),
    ("kubectl scale", 2),
    ("kubectl set", 2),
    ("kubectl top", 3),
    // ── go ───────────────────────────────────────────────────────────────────
    ("go build", 2),
    ("go clean", 2),
    ("go env", 2),
    ("go fmt", 2),
    ("go generate", 2),
    ("go get", 2),
    ("go install", 2),
    ("go list", 2),
    ("go mod", 3),
    ("go run", 2),
    ("go test", 2),
    ("go vet", 2),
    ("go work", 3),
    // ── python / pip ─────────────────────────────────────────────────────────
    ("pip install", 2),
    ("pip uninstall", 2),
    ("pip list", 2),
    ("pip show", 2),
    ("pip freeze", 2),
    ("pip3 install", 2),
    ("pip3 uninstall", 2),
    ("pip3 list", 2),
    ("pip3 show", 2),
    // Mirrors `pip freeze` so both spellings classify the same way.
    ("pip3 freeze", 2),
    ("python -m", 3),
    ("python3 -m", 3),
    // ── make / cmake ─────────────────────────────────────────────────────────
    ("make", 1),
    ("cmake", 1),
    // ── gh (GitHub CLI) ──────────────────────────────────────────────────────
    ("gh pr", 3),
    ("gh issue", 3),
    ("gh repo", 3),
    ("gh release", 3),
    ("gh workflow", 3),
    ("gh run", 3),
    ("gh secret", 3),
    // ── rustup ───────────────────────────────────────────────────────────────
    ("rustup default", 2),
    ("rustup install", 2),
    ("rustup show", 2),
    ("rustup target", 3),
    ("rustup toolchain", 3),
    ("rustup update", 2),
    // ── deno / bun ───────────────────────────────────────────────────────────
    ("deno run", 2),
    ("deno test", 2),
    ("deno fmt", 2),
    ("deno lint", 2),
    ("bun add", 2),
    ("bun build", 2),
    ("bun install", 2),
    ("bun run", 3),
    ("bun test", 2),
    // ── npx ──────────────────────────────────────────────────────────────────
    // Arity 2: the canonical prefix is `npx <package>`.
    ("npx", 2),
    // ── aws CLI ──────────────────────────────────────────────────────────────
    ("aws s3", 3),
    ("aws ec2", 3),
    ("aws iam", 3),
    ("aws lambda", 3),
    ("aws cloudformation", 3),
    ("aws ecs", 3),
    ("aws eks", 3),
    ("aws rds", 3),
    ("aws sts", 3),
    ("aws configure", 2),
    // ── terraform ────────────────────────────────────────────────────────────
    ("terraform init", 2),
    ("terraform plan", 2),
    ("terraform apply", 2),
    ("terraform destroy", 2),
    ("terraform validate", 2),
    ("terraform output", 2),
    ("terraform state", 3),
    ("terraform workspace", 3),
    // ── helm ─────────────────────────────────────────────────────────────────
    ("helm install", 2),
    ("helm upgrade", 2),
    ("helm uninstall", 2),
    ("helm list", 2),
    ("helm repo", 3),
    ("helm status", 2),
    ("helm template", 2),
];
260
261/// Arity dictionary for bash command-prefix allow rules.
262///
263/// Provides arity-aware prefix extraction so that `auto_allow = ["git status"]`
264/// correctly matches `git status -s` and `git status --porcelain` without
265/// also matching `git push`.
266///
267/// # Example
268///
269/// ```rust
270/// use codewhale_execpolicy::bash_arity::BashArityDict;
271///
272/// let dict = BashArityDict::new();
273/// assert_eq!(dict.classify(&["git", "status", "-s"]),   "git status");
274/// assert_eq!(dict.classify(&["git", "push", "origin"]), "git push");
275/// assert_eq!(dict.classify(&["npm", "run", "dev"]),     "npm run dev");
276/// assert_eq!(dict.classify(&["ls", "-la"]),             "ls");
277/// ```
278#[derive(Debug, Clone)]
279pub struct BashArityDict {
280    /// Internal table sorted longest-prefix-first for greedy matching.
281    entries: Vec<(&'static str, u8)>,
282}
283
284impl BashArityDict {
285    /// Construct a new dictionary pre-loaded with [`BASH_ARITY_TABLE`].
286    #[must_use]
287    pub fn new() -> Self {
288        let mut entries: Vec<(&'static str, u8)> = BASH_ARITY_TABLE.to_vec();
289        // Longest prefix first so greedy matching works correctly.
290        entries.sort_by_key(|entry| std::cmp::Reverse(entry.0.len()));
291        Self { entries }
292    }
293
294    /// Return the canonical command prefix for a slice of command tokens.
295    ///
296    /// # Algorithm
297    ///
298    /// 1. Strip all flag tokens (tokens that start with `-`).
299    /// 2. Build candidates of depth 1..=3 from positional tokens (longest first).
300    /// 3. If a candidate matches a dictionary entry, return `arity` positional
301    ///    tokens joined with spaces.
302    /// 4. If no dictionary entry matches, return the single base command name.
303    #[must_use]
304    pub fn classify(&self, tokens: &[&str]) -> String {
305        if tokens.is_empty() {
306            return String::new();
307        }
308
309        // Collect positional (non-flag) tokens, lowercased.
310        let positional: Vec<String> = tokens
311            .iter()
312            .filter(|t| !t.starts_with('-'))
313            .map(|t| t.to_ascii_lowercase())
314            .collect();
315
316        if positional.is_empty() {
317            return String::new();
318        }
319
320        // Try candidates from longest to shortest (max depth 3).
321        let max_depth = positional.len().min(3);
322        for depth in (1..=max_depth).rev() {
323            let candidate = positional[..depth].join(" ");
324            if let Some(&(_key, arity)) = self
325                .entries
326                .iter()
327                .find(|(key, _)| *key == candidate.as_str())
328            {
329                let take = (arity as usize).min(positional.len());
330                return positional[..take].join(" ");
331            }
332        }
333
334        // No match: return base command name only.
335        positional[0].clone()
336    }
337
338    /// Return `true` if the allow-rule `pattern` (a command prefix string such
339    /// as `"git status"`) matches the concrete command `command`.
340    ///
341    /// Matching is arity-aware:
342    /// - `"git status"` matches `"git status -s"` and `"git status --porcelain"`.
343    /// - `"git status"` does **not** match `"git push origin main"`.
344    /// - Exact string patterns (e.g. `"ls"`) still work as before.
345    ///
346    /// For patterns that are not in the arity table, the function falls back to
347    /// a plain prefix test on the normalised command so that existing exact-match
348    /// rules continue to work unchanged.
349    #[must_use]
350    pub fn allow_rule_matches(&self, pattern: &str, command: &str) -> bool {
351        let pattern_lower = pattern.trim().to_ascii_lowercase();
352        let command_tokens: Vec<&str> = command.split_whitespace().collect();
353
354        // Classify the concrete command through the arity dictionary.
355        let canonical = self.classify(&command_tokens);
356
357        // Primary check: the classified prefix equals the allow-rule pattern.
358        if canonical == pattern_lower {
359            return true;
360        }
361
362        // Fallback: plain normalised prefix match for patterns not in the table
363        // (preserves backward compatibility with exact-match allow rules).
364        let command_lower = command.trim().to_ascii_lowercase();
365        // Normalise whitespace in both sides before comparing.
366        let pattern_norm: String = pattern_lower
367            .split_whitespace()
368            .collect::<Vec<_>>()
369            .join(" ");
370        let command_norm: String = command_lower
371            .split_whitespace()
372            .collect::<Vec<_>>()
373            .join(" ");
374        command_norm == pattern_norm || command_norm.starts_with(&format!("{pattern_norm} "))
375    }
376
377    /// Iterate over all entries in the dictionary.
378    pub fn entries(&self) -> impl Iterator<Item = (&str, u8)> {
379        self.entries.iter().map(|(k, v)| (*k, *v))
380    }
381
382    /// Return the number of entries in the dictionary.
383    #[must_use]
384    pub fn len(&self) -> usize {
385        self.entries.len()
386    }
387
388    /// Return `true` if the dictionary is empty.
389    #[must_use]
390    pub fn is_empty(&self) -> bool {
391        self.entries.is_empty()
392    }
393}
394
395impl Default for BashArityDict {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify `tokens` with a freshly built dictionary.
    fn canonical(tokens: &[&str]) -> String {
        BashArityDict::new().classify(tokens)
    }

    /// Ask a freshly built dictionary whether `pattern` allows `command`.
    fn allows(pattern: &str, command: &str) -> bool {
        BashArityDict::new().allow_rule_matches(pattern, command)
    }

    // ── classify ─────────────────────────────────────────────────────────────

    #[test]
    fn classify_git_status_bare() {
        let got = canonical(&["git", "status"]);
        assert_eq!(got, "git status");
    }

    #[test]
    fn classify_git_status_with_short_flag() {
        let got = canonical(&["git", "status", "-s"]);
        assert_eq!(got, "git status");
    }

    #[test]
    fn classify_git_status_with_long_flag() {
        let got = canonical(&["git", "status", "--porcelain"]);
        assert_eq!(got, "git status");
    }

    #[test]
    fn classify_git_push() {
        let got = canonical(&["git", "push", "origin", "main"]);
        assert_eq!(got, "git push");
    }

    #[test]
    fn classify_git_push_force() {
        let got = canonical(&["git", "push", "--force"]);
        assert_eq!(got, "git push");
    }

    #[test]
    fn classify_npm_run_dev_arity_3() {
        let got = canonical(&["npm", "run", "dev"]);
        assert_eq!(got, "npm run dev");
    }

    #[test]
    fn classify_npm_install() {
        let got = canonical(&["npm", "install"]);
        assert_eq!(got, "npm install");
    }

    #[test]
    fn classify_cargo_check_with_flag() {
        let got = canonical(&["cargo", "check", "--workspace"]);
        assert_eq!(got, "cargo check");
    }

    #[test]
    fn classify_docker_compose_up_arity_3() {
        let got = canonical(&["docker", "compose", "up"]);
        assert_eq!(got, "docker compose up");
    }

    #[test]
    fn classify_kubectl_get_pods_arity_3() {
        let got = canonical(&["kubectl", "get", "pods"]);
        assert_eq!(got, "kubectl get pods");
    }

    #[test]
    fn classify_go_mod_tidy_arity_3() {
        let got = canonical(&["go", "mod", "tidy"]);
        assert_eq!(got, "go mod tidy");
    }

    #[test]
    fn classify_make_no_subcommand() {
        let got = canonical(&["make", "all"]);
        assert_eq!(got, "make");
    }

    #[test]
    fn classify_aws_s3_arity_3() {
        let got = canonical(&["aws", "s3", "ls"]);
        assert_eq!(got, "aws s3 ls");
    }

    #[test]
    fn classify_terraform_plan() {
        let got = canonical(&["terraform", "plan", "-out=tfplan"]);
        assert_eq!(got, "terraform plan");
    }

    #[test]
    fn classify_unknown_falls_back_to_base() {
        let got = canonical(&["ls", "-la"]);
        assert_eq!(got, "ls");
    }

    #[test]
    fn classify_empty_returns_empty() {
        let got = canonical(&[]);
        assert_eq!(got, "");
    }

    // ── allow_rule_matches ────────────────────────────────────────────────────

    #[test]
    fn allow_rule_git_status_matches_with_flag() {
        assert!(allows("git status", "git status -s"));
    }

    #[test]
    fn allow_rule_git_status_matches_porcelain() {
        assert!(allows("git status", "git status --porcelain"));
    }

    #[test]
    fn allow_rule_git_status_does_not_match_push() {
        assert!(!allows("git status", "git push origin main"));
    }

    #[test]
    fn allow_rule_git_status_does_not_match_checkout() {
        assert!(!allows("git status", "git checkout main"));
    }

    #[test]
    fn allow_rule_npm_run_matches_dev() {
        assert!(allows("npm run dev", "npm run dev"));
    }

    #[test]
    fn allow_rule_npm_run_dev_does_not_match_build() {
        assert!(!allows("npm run dev", "npm run build"));
    }

    #[test]
    fn allow_rule_cargo_check_matches_with_flags() {
        assert!(allows("cargo check", "cargo check --workspace"));
    }

    #[test]
    fn allow_rule_exact_match_still_works() {
        // `ls` has no table entry; it classifies as its base command, which
        // equals the rule.
        assert!(allows("ls", "ls -la"));
    }

    #[test]
    fn allow_rule_make_matches_with_target() {
        for command in ["make all", "make clean"] {
            assert!(allows("make", command));
        }
    }

    #[test]
    fn allow_rule_aws_s3_ls() {
        assert!(allows("aws s3 ls", "aws s3 ls"));
        // A different third word (`cp`) must not satisfy the `ls` rule.
        assert!(!allows("aws s3 ls", "aws s3 cp src dst"));
    }

    // ── coverage count ────────────────────────────────────────────────────────

    #[test]
    fn dict_covers_at_least_30_commands() {
        // The issue requires 30+ common commands covered.
        let covered = BashArityDict::new().len();
        assert!(covered >= 30, "expected at least 30 entries, got {}", covered);
    }
}