use std::collections::HashSet;
use std::ffi::OsString;
use crate::cli::Cli;
/// Inserts the default `check` subcommand when the command line does not name one.
///
/// `args` is a complete argv, program name first. Tokens after the program name are
/// scanned left to right:
///
/// * dash-prefixed tokens are flags. An option that takes a value (at the top level or in
///   any subcommand) also skips the token holding its value, so `--command check` never
///   reads `check` as a subcommand. Short clusters (`-qo json`) are inspected flag by flag.
/// * the first positional decides the outcome: a known subcommand name or alias (or
///   `help`) leaves `args` untouched; anything else (e.g. a path) gets `check` inserted.
/// * `--` followed by more tokens gets `check` inserted; a trailing bare `--` does not.
///
/// When no positional appears at all, `check` is inserted only if some flag is not
/// accepted by the top-level command, so `anc --help` / `anc --version` keep working.
///
/// `check` always goes directly after the program name (`anc -q .` → `anc check -q .`).
pub fn inject_default_subcommand<I>(args: I) -> Vec<OsString>
where
    I: IntoIterator<Item = OsString>,
{
    let args: Vec<OsString> = args.into_iter().collect();
    if args.len() <= 1 {
        return args;
    }

    let cmd = <Cli as clap::CommandFactory>::command();
    let known = subcommand_names(&cmd);
    let flags = FlagTables::new(&cmd);

    let mut i = 1;
    let mut saw_subcommand_flag = false;
    while i < args.len() {
        let token = args[i].to_string_lossy();
        if token == "--" {
            // Everything after `--` is positional, so it belongs to `check`; a bare
            // trailing `--` carries no positional and is passed through.
            return if i + 1 >= args.len() {
                args
            } else {
                prepend_check(args)
            };
        }
        if token.starts_with('-') {
            let scan = flags.scan(&token);
            saw_subcommand_flag |= scan.subcommand_scoped;
            i += if scan.consumes_next { 2 } else { 1 };
            continue;
        }
        return if known.contains(&*token) {
            args
        } else {
            prepend_check(args)
        };
    }

    if saw_subcommand_flag {
        prepend_check(args)
    } else {
        args
    }
}

/// Names and aliases of every declared subcommand, plus `help`, which clap generates
/// itself rather than listing among the declared subcommands.
fn subcommand_names(cmd: &clap::Command) -> HashSet<String> {
    let mut names: HashSet<String> = cmd
        .get_subcommands()
        .flat_map(|sc| std::iter::once(sc.get_name()).chain(sc.get_all_aliases()))
        .map(String::from)
        .collect();
    names.insert(String::from("help"));
    names
}

/// Returns `args` with `check` inserted directly after the program name.
fn prepend_check(mut args: Vec<OsString>) -> Vec<OsString> {
    args.insert(1, OsString::from("check"));
    args
}

/// Whether `arg` is an option that expects a value (`Set` or `Append` action).
fn takes_value(arg: &clap::Arg) -> bool {
    matches!(
        arg.get_action(),
        clap::ArgAction::Set | clap::ArgAction::Append
    )
}

/// Long spellings of `arg` without the leading `--`: its long name plus every alias.
fn long_names(arg: &clap::Arg) -> impl Iterator<Item = &str> {
    arg.get_long()
        .into_iter()
        .chain(arg.get_all_aliases().unwrap_or_default())
}

/// Short spellings of `arg`: its short name plus every short alias.
fn short_names(arg: &clap::Arg) -> impl Iterator<Item = char> {
    arg.get_short()
        .into_iter()
        .chain(arg.get_all_short_aliases().unwrap_or_default())
}

/// Flag spellings derived from the clap command tree, used to classify dash-prefixed tokens.
struct FlagTables {
    /// Long names (without `--`) accepted by the top-level command.
    top_longs: HashSet<String>,
    /// Short names accepted by the top-level command.
    top_shorts: HashSet<char>,
    /// Long names (without `--`) of value-taking options at the top level or in any subcommand.
    value_longs: HashSet<String>,
    /// Short names of value-taking options at the top level or in any subcommand.
    value_shorts: HashSet<char>,
}

/// How a single dash-prefixed token should be treated.
struct FlagScan {
    /// At least one flag in the token is not accepted by the top-level command.
    subcommand_scoped: bool,
    /// The token ends in an option whose value is the next argv element.
    consumes_next: bool,
}

impl FlagTables {
    /// Collects the flag tables from the (unbuilt) command `cmd`.
    fn new(cmd: &clap::Command) -> Self {
        // clap's auto-generated help/version flags are not among the declared arguments,
        // so they are seeded explicitly.
        let mut top_longs: HashSet<String> =
            ["help", "version"].into_iter().map(String::from).collect();
        let mut top_shorts: HashSet<char> = HashSet::from(['h', 'V']);
        for arg in cmd.get_arguments().filter(|a| !a.is_positional()) {
            top_longs.extend(long_names(arg).map(String::from));
            top_shorts.extend(short_names(arg));
        }

        let mut value_longs = HashSet::new();
        let mut value_shorts = HashSet::new();
        let value_args = std::iter::once(cmd)
            .chain(cmd.get_subcommands())
            .flat_map(|c| c.get_arguments())
            .filter(|a| !a.is_positional() && takes_value(a));
        for arg in value_args {
            value_longs.extend(long_names(arg).map(String::from));
            value_shorts.extend(short_names(arg));
        }

        Self {
            top_longs,
            top_shorts,
            value_longs,
            value_shorts,
        }
    }

    /// Classifies a token that starts with `-` (but is not exactly `--`).
    fn scan(&self, token: &str) -> FlagScan {
        if let Some(rest) = token.strip_prefix("--") {
            // `--name` or `--name=value`; an inline value means nothing else is consumed.
            let (name, has_inline_value) = match rest.split_once('=') {
                Some((name, _)) => (name, true),
                None => (rest, false),
            };
            return FlagScan {
                subcommand_scoped: !self.top_longs.contains(name),
                consumes_next: !has_inline_value && self.value_longs.contains(name),
            };
        }

        // Short cluster such as `-q`, `-qv`, `-ojson` or `-qo json`. clap reads the flags
        // left to right; the first value-taking flag swallows the rest of the cluster as
        // its value, or the next argv element when it is the cluster's last character.
        let cluster = token.strip_prefix('-').unwrap_or(token);
        let mut scan = FlagScan {
            subcommand_scoped: false,
            consumes_next: false,
        };
        for (offset, c) in cluster.char_indices() {
            if !self.top_shorts.contains(&c) {
                scan.subcommand_scoped = true;
            }
            if self.value_shorts.contains(&c) {
                scan.consumes_next = offset + c.len_utf8() == cluster.len();
                break;
            }
        }
        scan
    }
}
#[cfg(test)]
mod tests {
    use super::inject_default_subcommand;
    use std::ffi::OsString;

    /// Feeds `input` through `inject_default_subcommand` and asserts that the rewritten
    /// argv (lossily decoded as UTF-8) equals `expected`.
    fn assert_rewrites(input: &[&str], expected: &[&str]) {
        let argv: Vec<OsString> = input.iter().map(OsString::from).collect();
        let rewritten: Vec<String> = inject_default_subcommand(argv)
            .iter()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect();
        assert_eq!(rewritten, expected, "rewriting {input:?}");
    }

    /// Asserts that `input` comes back exactly as it went in.
    fn assert_untouched(input: &[&str]) {
        assert_rewrites(input, input);
    }

    #[test]
    fn bare_invocation_is_untouched() {
        assert_untouched(&["anc"]);
    }

    #[test]
    fn dot_path_gets_check_injected() {
        assert_rewrites(&["anc", "."], &["anc", "check", "."]);
    }

    #[test]
    fn global_short_flag_before_path_gets_check_injected_in_canonical_position() {
        assert_rewrites(&["anc", "-q", "."], &["anc", "check", "-q", "."]);
    }

    #[test]
    fn global_long_flag_before_path_gets_check_injected() {
        assert_rewrites(&["anc", "--quiet", "."], &["anc", "check", "--quiet", "."]);
    }

    #[test]
    fn explicit_check_subcommand_is_untouched() {
        assert_untouched(&["anc", "check", "."]);
    }

    #[test]
    fn explicit_completions_subcommand_is_untouched() {
        assert_untouched(&["anc", "completions", "bash"]);
    }

    #[test]
    fn help_flag_alone_is_untouched() {
        assert_untouched(&["anc", "--help"]);
    }

    #[test]
    fn version_flag_alone_is_untouched() {
        assert_untouched(&["anc", "--version"]);
    }

    #[test]
    fn quiet_flag_alone_is_untouched() {
        assert_untouched(&["anc", "-q"]);
    }

    #[test]
    fn help_subcommand_passes_through() {
        assert_untouched(&["anc", "help"]);
    }

    #[test]
    fn help_subcommand_with_target_passes_through() {
        assert_untouched(&["anc", "help", "check"]);
    }

    #[test]
    fn command_flag_value_matching_subcommand_name_is_paired() {
        assert_rewrites(
            &["anc", "--command", "check"],
            &["anc", "check", "--command", "check"],
        );
    }

    #[test]
    fn command_flag_with_no_positional_injects_check() {
        assert_rewrites(
            &["anc", "--command", "rg"],
            &["anc", "check", "--command", "rg"],
        );
    }

    #[test]
    fn output_flag_with_no_positional_injects_check() {
        assert_rewrites(
            &["anc", "--output", "json", "--source"],
            &["anc", "check", "--output", "json", "--source"],
        );
    }

    #[test]
    fn equals_form_value_flag_is_recognized_as_subcommand_scoped() {
        assert_rewrites(
            &["anc", "--output=json", "--source"],
            &["anc", "check", "--output=json", "--source"],
        );
    }

    #[test]
    fn principle_value_flag_pairs_with_numeric_value() {
        assert_rewrites(
            &["anc", "--principle", "4"],
            &["anc", "check", "--principle", "4"],
        );
    }

    #[test]
    fn double_dash_separator_injects_check_before_separator() {
        assert_rewrites(&["anc", "--", "."], &["anc", "check", "--", "."]);
    }

    #[test]
    fn double_dash_alone_passes_through() {
        assert_untouched(&["anc", "--"]);
    }

    #[test]
    fn directory_path_gets_check_injected() {
        assert_rewrites(&["anc", "/some/dir"], &["anc", "check", "/some/dir"]);
    }

    #[test]
    fn trailing_flags_pass_through() {
        assert_rewrites(
            &["anc", ".", "--output", "json"],
            &["anc", "check", ".", "--output", "json"],
        );
    }
}