Skip to main content

safe_chains/
command.rs

1use crate::parse::{has_flag, Token};
2use crate::policy::{self, FlagPolicy};
3#[cfg(test)]
4use crate::policy::FlagStyle;
5
/// Signature of a hand-written validator used by [`SubDef::Custom`].
///
/// It receives the token slice starting at the subcommand name (so
/// `tokens[0]` is the subcommand itself) and returns `true` to accept.
pub type CheckFn = fn(&[Token]) -> bool;

/// One subcommand of a [`CommandDef`] together with how its arguments are
/// validated.
pub enum SubDef {
    /// Arguments are validated by `policy::check` against `policy`.
    Policy {
        /// Token that selects this subcommand.
        name: &'static str,
        /// Flag policy applied to the subcommand's tokens.
        policy: &'static FlagPolicy,
    },
    /// A grouping subcommand whose next token selects one of `subs`
    /// (e.g. `cmd package describe`).
    Nested {
        /// Token that selects this group.
        name: &'static str,
        /// Child subcommands, matched by name against the following token.
        subs: &'static [SubDef],
    },
    /// Like `Policy`, but only accepted when the guard flag is present
    /// (as determined by `has_flag` with `guard_short` / `guard_long`).
    Guarded {
        /// Token that selects this subcommand.
        name: &'static str,
        /// Optional short spelling of the guard flag.
        guard_short: Option<&'static str>,
        /// Long spelling of the guard flag; also shown in the generated docs
        /// and placed right after the name in OpenCode patterns.
        guard_long: &'static str,
        /// Flag policy applied to the subcommand's tokens.
        policy: &'static FlagPolicy,
    },
    /// Arguments are validated by an arbitrary function.
    Custom {
        /// Token that selects this subcommand.
        name: &'static str,
        /// Validator called with the tokens starting at `name`.
        check: CheckFn,
        /// Description for generated docs. An empty string renders just the
        /// name; a whitespace-only string omits the entry entirely.
        doc: &'static str,
        /// Arguments inserted between the name and an unknown flag in the
        /// auto-generated rejection test; `None` skips that test.
        test_suffix: Option<&'static str>,
    },
    /// The tokens after the first `skip` (counting the subcommand name
    /// itself) are re-joined and judged as a standalone command by
    /// `crate::is_safe_command`.
    Delegation {
        /// Token that selects this subcommand.
        name: &'static str,
        /// Number of leading tokens, including the subcommand name, dropped
        /// before the inner command begins.
        skip: usize,
        /// Description for generated docs; an empty string omits the entry.
        doc: &'static str,
    },
}
35
/// A command whose first argument selects a subcommand (e.g. `mycmd build`).
///
/// The same definition drives validation ([`CommandDef::check`]), generated
/// documentation and OpenCode permission patterns.
pub struct CommandDef {
    /// Primary executable name this definition matches.
    pub name: &'static str,
    /// Recognised subcommands; the first token after the command name must
    /// name one of these (unless it is a help/version or bare flag).
    pub subs: &'static [SubDef],
    /// Flags accepted only as the sole argument (e.g. `mycmd --info`);
    /// listed as "Info flags" in the generated docs.
    pub bare_flags: &'static [&'static str],
    /// When true, a lone `--help`, `-h`, `--version` or `-V` is accepted.
    pub help_eligible: bool,
    /// URL passed to `CommandDoc::handler` for the generated docs.
    pub url: &'static str,
    /// Alternative executable names that share this definition.
    pub aliases: &'static [&'static str],
}
44
45impl SubDef {
46    pub fn name(&self) -> &'static str {
47        match self {
48            Self::Policy { name, .. }
49            | Self::Nested { name, .. }
50            | Self::Guarded { name, .. }
51            | Self::Custom { name, .. }
52            | Self::Delegation { name, .. } => name,
53        }
54    }
55
56    pub fn check(&self, tokens: &[Token]) -> bool {
57        match self {
58            Self::Policy { policy, .. } => {
59                if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
60                    return true;
61                }
62                policy::check(tokens, policy)
63            }
64            Self::Nested { subs, .. } => {
65                if tokens.len() < 2 {
66                    return false;
67                }
68                let sub = tokens[1].as_str();
69                if tokens.len() == 2 && (sub == "--help" || sub == "-h") {
70                    return true;
71                }
72                subs.iter()
73                    .any(|s| s.name() == sub && s.check(&tokens[1..]))
74            }
75            Self::Guarded {
76                guard_short,
77                guard_long,
78                policy,
79                ..
80            } => {
81                if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
82                    return true;
83                }
84                has_flag(tokens, *guard_short, Some(guard_long))
85                    && policy::check(tokens, policy)
86            }
87            Self::Custom { check: f, .. } => {
88                if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
89                    return true;
90                }
91                f(tokens)
92            }
93            Self::Delegation { skip, .. } => {
94                if tokens.len() <= *skip {
95                    return false;
96                }
97                let inner = shell_words::join(tokens[*skip..].iter().map(|t| t.as_str()));
98                crate::is_safe_command(&inner)
99            }
100        }
101    }
102}
103
104impl CommandDef {
105    pub fn opencode_patterns(&self) -> Vec<String> {
106        let mut patterns = Vec::new();
107        let names: Vec<&str> = std::iter::once(self.name)
108            .chain(self.aliases.iter().copied())
109            .collect();
110        for name in &names {
111            for sub in self.subs {
112                sub_opencode_patterns(name, sub, &mut patterns);
113            }
114        }
115        patterns
116    }
117
118    pub fn check(&self, tokens: &[Token]) -> bool {
119        if tokens.len() < 2 {
120            return false;
121        }
122        let arg = tokens[1].as_str();
123        if self.help_eligible && tokens.len() == 2 && matches!(arg, "--help" | "-h" | "--version" | "-V") {
124            return true;
125        }
126        if tokens.len() == 2 && self.bare_flags.contains(&arg) {
127            return true;
128        }
129        self.subs
130            .iter()
131            .find(|s| s.name() == arg)
132            .is_some_and(|s| s.check(&tokens[1..]))
133    }
134
135    pub fn dispatch(
136        &self,
137        cmd: &str,
138        tokens: &[Token],
139    ) -> Option<bool> {
140        if cmd == self.name || self.aliases.contains(&cmd) {
141            Some(self.check(tokens))
142        } else {
143            None
144        }
145    }
146
147    pub fn to_doc(&self) -> crate::docs::CommandDoc {
148        let mut lines = Vec::new();
149
150        if !self.bare_flags.is_empty() {
151            lines.push(format!("- Info flags: {}", self.bare_flags.join(", ")));
152        }
153
154        let mut sub_lines: Vec<String> = Vec::new();
155        for sub in self.subs {
156            sub_doc_line(sub, "", &mut sub_lines);
157        }
158        sub_lines.sort();
159        lines.extend(sub_lines);
160
161        let mut doc = crate::docs::CommandDoc::handler(self.name, self.url, lines.join("\n"));
162        doc.aliases = self.aliases.iter().map(|a| a.to_string()).collect();
163        doc
164    }
165}
166
/// A command validated by a single [`FlagPolicy`] applied to its whole
/// argument list, with no subcommand dispatch.
pub struct FlatDef {
    /// Primary executable name this definition matches.
    pub name: &'static str,
    /// Policy applied to all tokens; its `describe()` output becomes the
    /// generated documentation.
    pub policy: &'static FlagPolicy,
    /// When true, a lone `--help`, `-h`, `--version` or `-V` is accepted
    /// without consulting the policy.
    pub help_eligible: bool,
    /// URL passed to `CommandDoc::handler` for the generated docs.
    pub url: &'static str,
    /// Alternative executable names that share this definition.
    pub aliases: &'static [&'static str],
}
174
175impl FlatDef {
176    pub fn opencode_patterns(&self) -> Vec<String> {
177        let mut patterns = Vec::new();
178        let names: Vec<&str> = std::iter::once(self.name)
179            .chain(self.aliases.iter().copied())
180            .collect();
181        for name in names {
182            patterns.push(name.to_string());
183            patterns.push(format!("{name} *"));
184        }
185        patterns
186    }
187
188    pub fn dispatch(&self, cmd: &str, tokens: &[Token]) -> Option<bool> {
189        if cmd == self.name || self.aliases.contains(&cmd) {
190            if self.help_eligible
191                && tokens.len() == 2
192                && matches!(tokens[1].as_str(), "--help" | "-h" | "--version" | "-V")
193            {
194                return Some(true);
195            }
196            Some(policy::check(tokens, self.policy))
197        } else {
198            None
199        }
200    }
201
202    pub fn to_doc(&self) -> crate::docs::CommandDoc {
203        let mut doc = crate::docs::CommandDoc::handler(self.name, self.url, self.policy.describe());
204        doc.aliases = self.aliases.iter().map(|a| a.to_string()).collect();
205        doc
206    }
207}
208
#[cfg(test)]
impl FlatDef {
    /// Test helper: asserts that an unknown long flag is rejected under the
    /// primary name and under every alias.
    ///
    /// Policies using `FlagStyle::Positional` are skipped entirely.
    pub fn auto_test_reject_unknown(&self) {
        if self.policy.flag_style != FlagStyle::Positional {
            let primary_probe = format!("{} --xyzzy-unknown-42", self.name);
            assert!(
                !crate::is_safe_command(&primary_probe),
                "{}: accepted unknown flag: {primary_probe}",
                self.name,
            );

            self.aliases.iter().for_each(|alias| {
                let test = format!("{alias} --xyzzy-unknown-42");
                assert!(
                    !crate::is_safe_command(&test),
                    "{alias}: alias accepted unknown flag: {test}",
                );
            });
        }
    }
}
230
231fn sub_opencode_patterns(prefix: &str, sub: &SubDef, out: &mut Vec<String>) {
232    match sub {
233        SubDef::Policy { name, .. } => {
234            out.push(format!("{prefix} {name}"));
235            out.push(format!("{prefix} {name} *"));
236        }
237        SubDef::Nested { name, subs } => {
238            let path = format!("{prefix} {name}");
239            for s in *subs {
240                sub_opencode_patterns(&path, s, out);
241            }
242        }
243        SubDef::Guarded {
244            name, guard_long, ..
245        } => {
246            out.push(format!("{prefix} {name} {guard_long}"));
247            out.push(format!("{prefix} {name} {guard_long} *"));
248        }
249        SubDef::Custom { name, .. } => {
250            out.push(format!("{prefix} {name}"));
251            out.push(format!("{prefix} {name} *"));
252        }
253        SubDef::Delegation { .. } => {}
254    }
255}
256
257fn sub_doc_line(sub: &SubDef, prefix: &str, out: &mut Vec<String>) {
258    match sub {
259        SubDef::Policy { name, policy } => {
260            let summary = policy.flag_summary();
261            let label = if prefix.is_empty() {
262                (*name).to_string()
263            } else {
264                format!("{prefix} {name}")
265            };
266            if summary.is_empty() {
267                out.push(format!("- **{label}**"));
268            } else {
269                out.push(format!("- **{label}**: {summary}"));
270            }
271        }
272        SubDef::Nested { name, subs } => {
273            let path = if prefix.is_empty() {
274                (*name).to_string()
275            } else {
276                format!("{prefix} {name}")
277            };
278            for s in *subs {
279                sub_doc_line(s, &path, out);
280            }
281        }
282        SubDef::Guarded {
283            name,
284            guard_long,
285            policy,
286            ..
287        } => {
288            let summary = policy.flag_summary();
289            let label = if prefix.is_empty() {
290                (*name).to_string()
291            } else {
292                format!("{prefix} {name}")
293            };
294            if summary.is_empty() {
295                out.push(format!("- **{label}** (requires {guard_long})"));
296            } else {
297                out.push(format!("- **{label}** (requires {guard_long}): {summary}"));
298            }
299        }
300        SubDef::Custom { name, doc, .. } => {
301            if !doc.is_empty() && doc.trim().is_empty() {
302                return;
303            }
304            let label = if prefix.is_empty() {
305                (*name).to_string()
306            } else {
307                format!("{prefix} {name}")
308            };
309            if doc.is_empty() {
310                out.push(format!("- **{label}**"));
311            } else {
312                out.push(format!("- **{label}**: {doc}"));
313            }
314        }
315        SubDef::Delegation { name, doc, .. } => {
316            if doc.is_empty() {
317                return;
318            }
319            let label = if prefix.is_empty() {
320                (*name).to_string()
321            } else {
322                format!("{prefix} {name}")
323            };
324            out.push(format!("- **{label}**: {doc}"));
325        }
326    }
327}
328
#[cfg(test)]
impl CommandDef {
    /// Test helper: asserts that this command rejects unknown input.
    ///
    /// The primary name and every alias are each probed for three things:
    /// - a bare invocation is rejected;
    /// - an unknown subcommand is rejected;
    /// - each subcommand passes the `auto_test_sub` probes.
    ///
    /// All problems are collected and reported in a single panic, so one run
    /// shows every offending case.
    pub fn auto_test_reject_unknown(&self) {
        let mut failures = Vec::new();

        // Aliases are probed too, mirroring `FlatDef::auto_test_reject_unknown`;
        // an alias that accepts unknown input is just as unsafe as the name.
        for name in std::iter::once(self.name).chain(self.aliases.iter().copied()) {
            if crate::is_safe_command(name) {
                failures.push(format!("{name}: accepted bare invocation"));
            }

            let test = format!("{name} xyzzy-unknown-42");
            if crate::is_safe_command(&test) {
                failures.push(format!("{name}: accepted unknown subcommand: {test}"));
            }

            for sub in self.subs {
                auto_test_sub(name, sub, &mut failures);
            }
        }

        assert!(
            failures.is_empty(),
            "{}: unknown flags/subcommands accepted:\n{}",
            self.name,
            failures.join("\n"),
        );
    }
}
358
/// Test helper: probes `sub` (reached via the command path `prefix`) with
/// unknown input, recording a message in `failures` for every probe that is
/// wrongly judged safe.
///
/// Per variant:
/// - Policy: probes an unknown flag, except under `FlagStyle::Positional`.
/// - Guarded: probes an unknown flag after the long guard.
/// - Custom: probes an unknown flag after `test_suffix`, when one is given.
/// - Nested: probes an unknown subcommand, then recurses into its children.
/// - Delegation: skipped.
#[cfg(test)]
fn auto_test_sub(prefix: &str, sub: &SubDef, failures: &mut Vec<String>) {
    const UNKNOWN: &str = "--xyzzy-unknown-42";

    // Records "<label>: accepted unknown <what>: <cmdline>" if `cmdline` passes.
    fn reject(failures: &mut Vec<String>, label: &str, what: &str, cmdline: &str) {
        if crate::is_safe_command(cmdline) {
            failures.push(format!("{label}: accepted unknown {what}: {cmdline}"));
        }
    }

    match sub {
        SubDef::Policy { name, policy } if policy.flag_style != FlagStyle::Positional => {
            let cmdline = format!("{prefix} {name} {UNKNOWN}");
            reject(failures, &format!("{prefix} {name}"), "flag", &cmdline);
        }
        SubDef::Policy { .. } | SubDef::Delegation { .. } => {}
        SubDef::Nested { name, subs } => {
            let path = format!("{prefix} {name}");
            let cmdline = format!("{path} xyzzy-unknown-42");
            reject(failures, &path, "subcommand", &cmdline);
            for child in subs.iter() {
                auto_test_sub(&path, child, failures);
            }
        }
        SubDef::Guarded {
            name, guard_long, ..
        } => {
            let cmdline = format!("{prefix} {name} {guard_long} {UNKNOWN}");
            reject(failures, &format!("{prefix} {name}"), "flag", &cmdline);
        }
        SubDef::Custom {
            name,
            test_suffix: Some(suffix),
            ..
        } => {
            let cmdline = format!("{prefix} {name} {suffix} {UNKNOWN}");
            reject(failures, &format!("{prefix} {name}"), "flag", &cmdline);
        }
        SubDef::Custom {
            test_suffix: None, ..
        } => {}
    }
}
406
#[cfg(test)]
mod tests {
    //! Unit tests for `CommandDef`, `SubDef` and `FlatDef` validation,
    //! documentation and OpenCode pattern generation.

    use super::*;
    use crate::parse::WordSet;
    use crate::policy::FlagStyle;

    /// Builds test tokens from plain words.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--verbose", "-v"]),
        valued: WordSet::new(&["--output", "-o"]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    static SIMPLE_CMD: CommandDef = CommandDef {
        name: "mycmd",
        subs: &[SubDef::Policy {
            name: "build",
            policy: &TEST_POLICY,
        }],
        bare_flags: &["--info"],
        help_eligible: true,
        url: "",
        aliases: &[],
    };

    #[test]
    fn bare_rejected() {
        assert!(!SIMPLE_CMD.check(&toks(&["mycmd"])));
    }

    #[test]
    fn bare_flag_accepted() {
        assert!(SIMPLE_CMD.check(&toks(&["mycmd", "--info"])));
    }

    #[test]
    fn bare_flag_with_extra_rejected() {
        assert!(!SIMPLE_CMD.check(&toks(&["mycmd", "--info", "extra"])));
    }

    #[test]
    fn policy_sub_bare() {
        assert!(SIMPLE_CMD.check(&toks(&["mycmd", "build"])));
    }

    #[test]
    fn policy_sub_with_flag() {
        assert!(SIMPLE_CMD.check(&toks(&["mycmd", "build", "--verbose"])));
    }

    #[test]
    fn policy_sub_unknown_flag() {
        assert!(!SIMPLE_CMD.check(&toks(&["mycmd", "build", "--bad"])));
    }

    #[test]
    fn unknown_sub_rejected() {
        assert!(!SIMPLE_CMD.check(&toks(&["mycmd", "deploy"])));
    }

    #[test]
    fn dispatch_matches() {
        assert_eq!(
            SIMPLE_CMD.dispatch("mycmd", &toks(&["mycmd", "build"])),
            Some(true)
        );
    }

    #[test]
    fn dispatch_no_match() {
        assert_eq!(
            SIMPLE_CMD.dispatch("other", &toks(&["other", "build"])),
            None
        );
    }

    // Top-level help/version short-circuit: only a lone flag, only when eligible.
    #[test]
    fn help_and_version_accepted_when_eligible() {
        for flag in ["--help", "-h", "--version", "-V"] {
            assert!(SIMPLE_CMD.check(&toks(&["mycmd", flag])), "{flag}");
        }
    }

    #[test]
    fn help_with_extra_rejected() {
        assert!(!SIMPLE_CMD.check(&toks(&["mycmd", "--help", "extra"])));
    }

    #[test]
    fn help_rejected_when_not_eligible() {
        for flag in ["--help", "-h", "--version", "-V"] {
            assert!(!NESTED_CMD.check(&toks(&["nested", flag])), "{flag}");
        }
    }

    static ALIASED_CMD: CommandDef = CommandDef {
        name: "mycmd",
        subs: &[SubDef::Policy {
            name: "build",
            policy: &TEST_POLICY,
        }],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &["mc"],
    };

    #[test]
    fn dispatch_alias_matches() {
        assert_eq!(
            ALIASED_CMD.dispatch("mc", &toks(&["mc", "build"])),
            Some(true)
        );
        assert_eq!(
            ALIASED_CMD.dispatch("mc", &toks(&["mc", "deploy"])),
            Some(false)
        );
        assert_eq!(
            ALIASED_CMD.dispatch("other", &toks(&["other", "build"])),
            None
        );
    }

    static NESTED_CMD: CommandDef = CommandDef {
        name: "nested",
        subs: &[SubDef::Nested {
            name: "package",
            subs: &[SubDef::Policy {
                name: "describe",
                policy: &TEST_POLICY,
            }],
        }],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &[],
    };

    #[test]
    fn nested_sub() {
        assert!(NESTED_CMD.check(&toks(&["nested", "package", "describe"])));
    }

    #[test]
    fn nested_sub_with_flag() {
        assert!(NESTED_CMD.check(
            &toks(&["nested", "package", "describe", "--verbose"]),
        ));
    }

    #[test]
    fn nested_bare_rejected() {
        assert!(!NESTED_CMD.check(&toks(&["nested", "package"])));
    }

    #[test]
    fn nested_unknown_sub_rejected() {
        assert!(!NESTED_CMD.check(&toks(&["nested", "package", "deploy"])));
    }

    static GUARDED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--all", "--check"]),
        valued: WordSet::new(&[]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    static GUARDED_CMD: CommandDef = CommandDef {
        name: "guarded",
        subs: &[SubDef::Guarded {
            name: "fmt",
            guard_short: None,
            guard_long: "--check",
            policy: &GUARDED_POLICY,
        }],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &[],
    };

    #[test]
    fn guarded_with_guard() {
        assert!(GUARDED_CMD.check(&toks(&["guarded", "fmt", "--check"])));
    }

    #[test]
    fn guarded_without_guard() {
        assert!(!GUARDED_CMD.check(&toks(&["guarded", "fmt"])));
    }

    #[test]
    fn guarded_with_guard_and_flag() {
        assert!(GUARDED_CMD.check(
            &toks(&["guarded", "fmt", "--check", "--all"]),
        ));
    }

    static DELEGATION_CMD: CommandDef = CommandDef {
        name: "runner",
        subs: &[SubDef::Delegation {
            name: "run",
            skip: 2,
            doc: "run delegates to inner command.",
        }],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &[],
    };

    #[test]
    fn delegation_safe_inner() {
        assert!(DELEGATION_CMD.check(
            &toks(&["runner", "run", "stable", "echo", "hello"]),
        ));
    }

    #[test]
    fn delegation_unsafe_inner() {
        assert!(!DELEGATION_CMD.check(
            &toks(&["runner", "run", "stable", "rm", "-rf"]),
        ));
    }

    #[test]
    fn delegation_no_inner() {
        assert!(!DELEGATION_CMD.check(
            &toks(&["runner", "run", "stable"]),
        ));
    }

    fn custom_check(tokens: &[Token]) -> bool {
        tokens.len() >= 2 && tokens[1] == "safe"
    }

    static CUSTOM_CMD: CommandDef = CommandDef {
        name: "custom",
        subs: &[SubDef::Custom {
            name: "special",
            check: custom_check,
            doc: "special (safe only).",
            test_suffix: Some("safe"),
        }],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &[],
    };

    #[test]
    fn custom_passes() {
        assert!(CUSTOM_CMD.check(&toks(&["custom", "special", "safe"])));
    }

    #[test]
    fn custom_fails() {
        assert!(!CUSTOM_CMD.check(&toks(&["custom", "special", "bad"])));
    }

    // Sub-level help short-circuit: a lone --help/-h after the subcommand is
    // accepted for Policy, Nested, Guarded and Custom variants, even when the
    // variant's own validation would reject it.
    #[test]
    fn sub_help_accepted() {
        assert!(SIMPLE_CMD.check(&toks(&["mycmd", "build", "--help"])));
        assert!(SIMPLE_CMD.check(&toks(&["mycmd", "build", "-h"])));
        assert!(NESTED_CMD.check(&toks(&["nested", "package", "--help"])));
        assert!(NESTED_CMD.check(&toks(&["nested", "package", "describe", "-h"])));
        assert!(GUARDED_CMD.check(&toks(&["guarded", "fmt", "--help"])));
        assert!(CUSTOM_CMD.check(&toks(&["custom", "special", "-h"])));
    }

    #[test]
    fn sub_help_with_extra_not_short_circuited() {
        assert!(!CUSTOM_CMD.check(&toks(&["custom", "special", "-h", "x"])));
        assert!(!NESTED_CMD.check(&toks(&["nested", "package", "--help", "x"])));
    }

    #[test]
    fn doc_simple() {
        let doc = SIMPLE_CMD.to_doc();
        assert_eq!(doc.name, "mycmd");
        assert_eq!(
            doc.description,
            "- Info flags: --info\n- **build**: Flags: --verbose, -v. Valued: --output, -o"
        );
    }

    #[test]
    fn doc_nested() {
        let doc = NESTED_CMD.to_doc();
        assert_eq!(
            doc.description,
            "- **package describe**: Flags: --verbose, -v. Valued: --output, -o"
        );
    }

    #[test]
    fn doc_guarded() {
        let doc = GUARDED_CMD.to_doc();
        assert_eq!(
            doc.description,
            "- **fmt** (requires --check): Flags: --all, --check"
        );
    }

    #[test]
    fn doc_delegation() {
        let doc = DELEGATION_CMD.to_doc();
        assert_eq!(doc.description, "- **run**: run delegates to inner command.");
    }

    #[test]
    fn doc_custom() {
        let doc = CUSTOM_CMD.to_doc();
        assert_eq!(doc.description, "- **special**: special (safe only).");
    }

    // Whitespace-only Custom docs and empty Delegation docs are omitted;
    // an empty Custom doc still lists the bare label.
    static DOC_VISIBILITY_CMD: CommandDef = CommandDef {
        name: "docvis",
        subs: &[
            SubDef::Custom {
                name: "hidden",
                check: custom_check,
                doc: "   ",
                test_suffix: None,
            },
            SubDef::Custom {
                name: "plain",
                check: custom_check,
                doc: "",
                test_suffix: None,
            },
            SubDef::Delegation {
                name: "exec",
                skip: 1,
                doc: "",
            },
        ],
        bare_flags: &[],
        help_eligible: false,
        url: "",
        aliases: &[],
    };

    #[test]
    fn doc_blank_entries() {
        let doc = DOC_VISIBILITY_CMD.to_doc();
        assert_eq!(doc.description, "- **plain**");
    }

    #[test]
    fn opencode_patterns_simple() {
        let patterns = SIMPLE_CMD.opencode_patterns();
        assert!(patterns.contains(&"mycmd build".to_string()));
        assert!(patterns.contains(&"mycmd build *".to_string()));
    }

    #[test]
    fn opencode_patterns_nested() {
        let patterns = NESTED_CMD.opencode_patterns();
        assert!(patterns.contains(&"nested package describe".to_string()));
        assert!(patterns.contains(&"nested package describe *".to_string()));
        assert!(!patterns.iter().any(|p| p == "nested package"));
    }

    #[test]
    fn opencode_patterns_guarded() {
        let patterns = GUARDED_CMD.opencode_patterns();
        assert!(patterns.contains(&"guarded fmt --check".to_string()));
        assert!(patterns.contains(&"guarded fmt --check *".to_string()));
        assert!(!patterns.iter().any(|p| p == "guarded fmt"));
    }

    #[test]
    fn opencode_patterns_delegation_skipped() {
        let patterns = DELEGATION_CMD.opencode_patterns();
        assert!(patterns.is_empty());
    }

    #[test]
    fn opencode_patterns_custom() {
        let patterns = CUSTOM_CMD.opencode_patterns();
        assert!(patterns.contains(&"custom special".to_string()));
        assert!(patterns.contains(&"custom special *".to_string()));
    }

    #[test]
    fn opencode_patterns_aliases() {
        static ALIASED: CommandDef = CommandDef {
            name: "primary",
            subs: &[SubDef::Policy {
                name: "list",
                policy: &TEST_POLICY,
            }],
            bare_flags: &[],
            help_eligible: false,
            url: "",
            aliases: &["alt"],
        };
        let patterns = ALIASED.opencode_patterns();
        assert!(patterns.contains(&"primary list".to_string()));
        assert!(patterns.contains(&"alt list".to_string()));
        assert!(patterns.contains(&"alt list *".to_string()));
    }

    #[test]
    fn flat_def_opencode_patterns() {
        static FLAT: FlatDef = FlatDef {
            name: "grep",
            policy: &TEST_POLICY,
            help_eligible: true,
            url: "",
            aliases: &["rg"],
        };
        let patterns = FLAT.opencode_patterns();
        assert_eq!(patterns, vec!["grep", "grep *", "rg", "rg *"]);
    }

    #[test]
    fn flat_def_dispatch() {
        static FLAT_CMD: FlatDef = FlatDef {
            name: "grep",
            policy: &TEST_POLICY,
            help_eligible: true,
            url: "",
            aliases: &["rg"],
        };
        assert_eq!(FLAT_CMD.dispatch("other", &toks(&["other"])), None);
        assert_eq!(
            FLAT_CMD.dispatch("grep", &toks(&["grep", "--help"])),
            Some(true)
        );
        assert_eq!(FLAT_CMD.dispatch("rg", &toks(&["rg", "-V"])), Some(true));
        assert_eq!(
            FLAT_CMD.dispatch("rg", &toks(&["rg", "--verbose"])),
            Some(true)
        );
        assert_eq!(
            FLAT_CMD.dispatch("grep", &toks(&["grep", "--bad"])),
            Some(false)
        );
    }
}