usage/spec/
flag.rs

1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
15#[derive(Debug, Default, Clone, Serialize)]
16pub struct SpecFlag {
17    pub name: String,
18    pub usage: String,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub help: Option<String>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub help_long: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub help_md: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub help_first_line: Option<String>,
27    pub short: Vec<char>,
28    pub long: Vec<String>,
29    #[serde(skip_serializing_if = "is_false")]
30    pub required: bool,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub deprecated: Option<String>,
33    #[serde(skip_serializing_if = "is_false")]
34    pub var: bool,
35    pub hide: bool,
36    pub global: bool,
37    #[serde(skip_serializing_if = "is_false")]
38    pub count: bool,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub arg: Option<SpecArg>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub default: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub negate: Option<String>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub env: Option<String>,
47}
48
49impl SpecFlag {
50    pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
51        let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
52        for (k, v) in node.props() {
53            match k {
54                "help" => flag.help = Some(v.ensure_string()?),
55                "long_help" => flag.help_long = Some(v.ensure_string()?),
56                "help_long" => flag.help_long = Some(v.ensure_string()?),
57                "help_md" => flag.help_md = Some(v.ensure_string()?),
58                "required" => flag.required = v.ensure_bool()?,
59                "var" => flag.var = v.ensure_bool()?,
60                "hide" => flag.hide = v.ensure_bool()?,
61                "deprecated" => {
62                    flag.deprecated = match v.value.as_bool() {
63                        Some(true) => Some("deprecated".into()),
64                        Some(false) => None,
65                        None => Some(v.ensure_string()?),
66                    }
67                }
68                "global" => flag.global = v.ensure_bool()?,
69                "count" => flag.count = v.ensure_bool()?,
70                "default" => {
71                    // Support both string and boolean defaults
72                    flag.default = match v.value.as_bool() {
73                        Some(b) => Some(b.to_string()),
74                        None => v.ensure_string().map(Some)?,
75                    }
76                }
77                "negate" => flag.negate = v.ensure_string().map(Some)?,
78                "env" => flag.env = v.ensure_string().map(Some)?,
79                k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
80            }
81        }
82        if flag.default.is_some() {
83            flag.required = false;
84        }
85        for child in node.children() {
86            match child.name() {
87                "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
88                "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
89                "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
90                "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
91                "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
92                "required" => flag.required = child.arg(0)?.ensure_bool()?,
93                "var" => flag.var = child.arg(0)?.ensure_bool()?,
94                "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
95                "deprecated" => {
96                    flag.deprecated = match child.arg(0)?.ensure_bool() {
97                        Ok(true) => Some("deprecated".into()),
98                        Ok(false) => None,
99                        _ => Some(child.arg(0)?.ensure_string()?),
100                    }
101                }
102                "global" => flag.global = child.arg(0)?.ensure_bool()?,
103                "count" => flag.count = child.arg(0)?.ensure_bool()?,
104                "default" => {
105                    // Support both string and boolean defaults
106                    let arg = child.arg(0)?;
107                    flag.default = match arg.value.as_bool() {
108                        Some(b) => Some(b.to_string()),
109                        None => arg.ensure_string().map(Some)?,
110                    }
111                }
112                "env" => flag.env = child.arg(0)?.ensure_string().map(Some)?,
113                "choices" => {
114                    if let Some(arg) = &mut flag.arg {
115                        arg.choices = Some(SpecChoices::parse(ctx, &child)?);
116                    } else {
117                        bail_parse!(
118                            ctx,
119                            child.node.name().span(),
120                            "flag must have value to have choices"
121                        )
122                    }
123                }
124                k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
125            }
126        }
127        flag.usage = flag.usage();
128        flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
129        Ok(flag)
130    }
131    pub fn usage(&self) -> String {
132        let mut parts = vec![];
133        let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
134        if name != self.name {
135            parts.push(format!("{}:", self.name));
136        }
137        if let Some(short) = self.short.first() {
138            parts.push(format!("-{short}"));
139        }
140        if let Some(long) = self.long.first() {
141            parts.push(format!("--{long}"));
142        }
143        let mut out = parts.join(" ");
144        if self.var {
145            out = format!("{out}…");
146        }
147        if let Some(arg) = &self.arg {
148            out = format!("{} {}", out, arg.usage());
149        }
150        out
151    }
152}
153
154impl From<&SpecFlag> for KdlNode {
155    fn from(flag: &SpecFlag) -> KdlNode {
156        let mut node = KdlNode::new("flag");
157        let name = flag
158            .short
159            .iter()
160            .map(|c| format!("-{c}"))
161            .chain(flag.long.iter().map(|s| format!("--{s}")))
162            .collect_vec()
163            .join(" ");
164        node.push(KdlEntry::new(name));
165        if let Some(desc) = &flag.help {
166            node.push(KdlEntry::new_prop("help", desc.clone()));
167        }
168        if let Some(desc) = &flag.help_long {
169            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
170            let mut node = KdlNode::new("long_help");
171            node.entries_mut().push(KdlEntry::new(desc.clone()));
172            children.nodes_mut().push(node);
173        }
174        if let Some(desc) = &flag.help_md {
175            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
176            let mut node = KdlNode::new("help_md");
177            node.entries_mut().push(KdlEntry::new(desc.clone()));
178            children.nodes_mut().push(node);
179        }
180        if flag.required {
181            node.push(KdlEntry::new_prop("required", true));
182        }
183        if flag.var {
184            node.push(KdlEntry::new_prop("var", true));
185        }
186        if flag.hide {
187            node.push(KdlEntry::new_prop("hide", true));
188        }
189        if flag.global {
190            node.push(KdlEntry::new_prop("global", true));
191        }
192        if flag.count {
193            node.push(KdlEntry::new_prop("count", true));
194        }
195        if let Some(negate) = &flag.negate {
196            node.push(KdlEntry::new_prop("negate", negate.clone()));
197        }
198        if let Some(env) = &flag.env {
199            node.push(KdlEntry::new_prop("env", env.clone()));
200        }
201        if let Some(deprecated) = &flag.deprecated {
202            node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
203        }
204        if let Some(arg) = &flag.arg {
205            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
206            children.nodes_mut().push(arg.into());
207        }
208        node
209    }
210}
211
212impl FromStr for SpecFlag {
213    type Err = UsageErr;
214    fn from_str(input: &str) -> Result<Self> {
215        let mut flag = Self::default();
216        let input = input.replace("...", "…").replace("…", " … ");
217        for part in input.split_whitespace() {
218            if let Some(name) = part.strip_suffix(':') {
219                flag.name = name.to_string();
220            } else if let Some(long) = part.strip_prefix("--") {
221                flag.long.push(long.to_string());
222            } else if let Some(short) = part.strip_prefix('-') {
223                if short.len() != 1 {
224                    return Err(InvalidFlag(
225                        short.to_string(),
226                        (0, input.len()).into(),
227                        input.to_string(),
228                    ));
229                }
230                flag.short.push(short.chars().next().unwrap());
231            } else if part == "…" {
232                if let Some(arg) = &mut flag.arg {
233                    arg.var = true;
234                } else {
235                    flag.var = true;
236                }
237            } else if part.starts_with('<') && part.ends_with('>')
238                || part.starts_with('[') && part.ends_with(']')
239            {
240                flag.arg = Some(part.to_string().parse()?);
241            } else {
242                return Err(InvalidFlag(
243                    part.to_string(),
244                    (0, input.len()).into(),
245                    input.to_string(),
246                ));
247            }
248        }
249        if flag.name.is_empty() {
250            flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
251        }
252        flag.usage = flag.usage();
253        Ok(flag)
254    }
255}
256
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Converts a clap argument definition into a [`SpecFlag`].
    ///
    /// Short and long names include clap's *visible* aliases. Arguments using the
    /// `Set` or `Append` action take a value. That value is named from clap's
    /// value names (space-joined), or else from the flag's own name, and clap's
    /// possible values plus their aliases become its choices. `Count` and
    /// `Append` make the flag variadic. Only the first default value is kept.
    /// `usage` is computed the same way as for parsed flags.
    fn from(c: &clap::Arg) -> Self {
        let action = c.get_action();
        let required = c.is_required_set();
        let help = c.get_help().map(|s| s.to_string());
        let help_long = c.get_long_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let hide = c.is_hide_set();
        let var = matches!(action, clap::ArgAction::Count | clap::ArgAction::Append);
        let default = c
            .get_default_values()
            .first()
            .map(|s| s.to_string_lossy().to_string());
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let arg = if let clap::ArgAction::Set | clap::ArgAction::Append = action {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let mut flag = Self {
            name,
            // Filled in below, once every field it depends on is set.
            usage: String::new(),
            short,
            long,
            required,
            help,
            help_long,
            help_md: None,
            help_first_line,
            var,
            hide,
            global: c.is_global_set(),
            arg,
            count: matches!(action, clap::ArgAction::Count),
            default,
            deprecated: None,
            negate: None,
            env: None,
        };
        flag.usage = flag.usage();
        flag
    }
}
324
325// #[cfg(feature = "clap")]
326// impl From<&SpecFlag> for clap::Arg {
327//     fn from(flag: &SpecFlag) -> Self {
328//         let mut a = clap::Arg::new(&flag.name);
329//         if let Some(desc) = &flag.help {
330//             a = a.help(desc);
331//         }
332//         if flag.required {
333//             a = a.required(true);
334//         }
335//         if let Some(arg) = &flag.arg {
336//             a = a.value_name(&arg.name);
337//             if arg.var {
338//                 a = a.action(clap::ArgAction::Append)
339//             } else {
340//                 a = a.action(clap::ArgAction::Set)
341//             }
342//         } else {
343//             a = a.action(clap::ArgAction::SetTrue)
344//         }
345//         // let mut a = clap::Arg::new(&flag.name)
346//         //     .required(flag.required)
347//         //     .action(clap::ArgAction::SetTrue);
348//         if let Some(short) = flag.short.first() {
349//             a = a.short(*short);
350//         }
351//         if let Some(long) = flag.long.first() {
352//             a = a.long(long);
353//         }
354//         for short in flag.short.iter().skip(1) {
355//             a = a.visible_short_alias(*short);
356//         }
357//         for long in flag.long.iter().skip(1) {
358//             a = a.visible_alias(long);
359//         }
360//         // cmd = cmd.arg(a);
361//         // if flag.multiple {
362//         //     a = a.multiple(true);
363//         // }
364//         // if flag.hide {
365//         //     a = a.hide_possible_values(true);
366//         // }
367//         a
368//     }
369// }
370
371impl Display for SpecFlag {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        write!(f, "{}", self.usage())
374    }
375}
376impl PartialEq for SpecFlag {
377    fn eq(&self, other: &Self) -> bool {
378        self.name == other.name
379    }
380}
381impl Eq for SpecFlag {}
382impl Hash for SpecFlag {
383    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
384        self.name.hash(state);
385    }
386}
387
/// Derives a flag's default name: its first long alias if it has one, else its
/// first short alias, else `None`.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.clone()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Spec;
    use insta::assert_snapshot;

    /// Parses a flag definition string, panicking if it is invalid.
    fn parse_flag(input: &str) -> SpecFlag {
        input.parse().unwrap()
    }

    /// Returns the root command's flag called `name`, panicking if it is absent.
    fn flag_named<'a>(spec: &'a Spec, name: &str) -> &'a SpecFlag {
        spec.cmd.flags.iter().find(|f| f.name == name).unwrap()
    }

    #[test]
    fn from_str() {
        assert_snapshot!(parse_flag("-f"), @"-f");
        assert_snapshot!(parse_flag("--flag"), @"--flag");
        assert_snapshot!(parse_flag("-f --flag"), @"-f --flag");
        assert_snapshot!(parse_flag("-f --flag…"), @"-f --flag…");
        assert_snapshot!(parse_flag("-f --flag …"), @"-f --flag…");
        assert_snapshot!(parse_flag("--flag <arg>"), @"--flag <arg>");
        assert_snapshot!(parse_flag("-f --flag <arg>"), @"-f --flag <arg>");
        assert_snapshot!(parse_flag("-f --flag… <arg>"), @"-f --flag… <arg>");
        assert_snapshot!(parse_flag("-f --flag <arg>…"), @"-f --flag <arg>…");
        assert_snapshot!(parse_flag("myflag: -f"), @"myflag: -f");
        assert_snapshot!(parse_flag("myflag: -f --flag <arg>"), @"myflag: -f --flag <arg>");
    }

    #[test]
    fn test_flag_with_env() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" env="MYCLI_COLOR" help="Enable color output"
flag "--verbose" env="MYCLI_VERBOSE"
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        assert_eq!(flag_named(&spec, "color").env.as_deref(), Some("MYCLI_COLOR"));
        assert_eq!(flag_named(&spec, "verbose").env.as_deref(), Some("MYCLI_VERBOSE"));
    }

    #[test]
    fn test_flag_with_env_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" help="Enable color output" {
    env "MYCLI_COLOR"
}
flag "--verbose" {
    env "MYCLI_VERBOSE"
}
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        assert_eq!(flag_named(&spec, "color").env.as_deref(), Some("MYCLI_COLOR"));
        assert_eq!(flag_named(&spec, "verbose").env.as_deref(), Some("MYCLI_VERBOSE"));
    }

    #[test]
    fn test_flag_with_boolean_defaults() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" default=#true
flag "--verbose" default=#false
flag "--debug" default="true"
flag "--quiet" default="false"
            "#,
        )
        .unwrap();

        let expectations = [
            ("color", "true"),
            ("verbose", "false"),
            ("debug", "true"),
            ("quiet", "false"),
        ];
        for (name, expected) in expectations {
            assert_eq!(
                flag_named(&spec, name).default.as_deref(),
                Some(expected),
                "default of flag {name}"
            );
        }
    }

    #[test]
    fn test_flag_with_boolean_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" {
    default #true
}
flag "--verbose" {
    default #false
}
            "#,
        )
        .unwrap();

        for (name, expected) in [("color", "true"), ("verbose", "false")] {
            assert_eq!(
                flag_named(&spec, name).default.as_deref(),
                Some(expected),
                "default of flag {name}"
            );
        }
    }
}