usage/spec/
flag.rs

1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::builder::SpecFlagBuilder;
11use crate::spec::context::ParsingContext;
12use crate::spec::helpers::NodeHelper;
13use crate::spec::is_false;
14use crate::{string, SpecArg, SpecChoices};
15
/// A single flag (option) in a usage spec, e.g. `-f --flag <arg>`.
///
/// Built from a KDL `flag` node by `SpecFlag::parse`, from a bare definition
/// string via `FromStr`, or (with the `clap` feature) from a `clap::Arg`.
#[derive(Debug, Default, Clone, Serialize)]
pub struct SpecFlag {
    /// Identifier of the flag. Taken from an explicit `name:` prefix in the
    /// definition string, otherwise derived from the first long form (falling
    /// back to the first short form). `PartialEq`/`Hash` compare only this field.
    pub name: String,
    /// Cached result of `SpecFlag::usage()`, filled in when the flag is parsed.
    pub usage: String,
    /// Short help text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help: Option<String>,
    /// Long help text (KDL key `help_long`, or its alias `long_help`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_long: Option<String>,
    /// Help text from the `help_md` key (presumably Markdown; TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_md: Option<String>,
    /// First line of `help`, computed when the flag is parsed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_first_line: Option<String>,
    /// Short forms, e.g. `'f'` for `-f`; the first one is shown in usage.
    pub short: Vec<char>,
    /// Long forms without the leading `--`; the first one is shown in usage.
    pub long: Vec<String>,
    /// Whether the flag must be provided. `parse` may reset this to `false`
    /// when a default value is present.
    #[serde(skip_serializing_if = "is_false")]
    pub required: bool,
    /// Deprecation message, if deprecated. `deprecated=#true` is stored as the
    /// literal message `"deprecated"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecated: Option<String>,
    /// Whether the flag itself is variadic (repeatable); rendered as `…` after
    /// the flag names in the usage string.
    #[serde(skip_serializing_if = "is_false")]
    pub var: bool,
    /// Lower bound for a variadic flag — presumably a minimum occurrence count
    /// enforced elsewhere; not interpreted in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub var_min: Option<usize>,
    /// Upper bound for a variadic flag — presumably a maximum occurrence count
    /// enforced elsewhere; not interpreted in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub var_max: Option<usize>,
    /// Whether the flag is hidden — presumably omitted from generated help.
    pub hide: bool,
    /// Whether the flag is global — presumably inherited by subcommands; not
    /// interpreted in this file.
    pub global: bool,
    /// Whether the flag counts repetitions (set for clap's `ArgAction::Count`).
    #[serde(skip_serializing_if = "is_false")]
    pub count: bool,
    /// Value argument taken by the flag (e.g. `<arg>`); `None` for a flag that
    /// takes no value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arg: Option<SpecArg>,
    /// Default value(s) as strings; boolean defaults are stored as `"true"` /
    /// `"false"`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub default: Vec<String>,
    /// Negation form from the `negate` key (e.g. `--no-foo` — TODO confirm format).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negate: Option<String>,
    /// Environment variable name from the `env` key — presumably consulted as a
    /// value source elsewhere.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<String>,
}
53
54impl SpecFlag {
55    /// Create a new builder for SpecFlag
56    pub fn builder() -> SpecFlagBuilder {
57        SpecFlagBuilder::new()
58    }
59
60    pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
61        let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
62        for (k, v) in node.props() {
63            match k {
64                "help" => flag.help = Some(v.ensure_string()?),
65                "long_help" => flag.help_long = Some(v.ensure_string()?),
66                "help_long" => flag.help_long = Some(v.ensure_string()?),
67                "help_md" => flag.help_md = Some(v.ensure_string()?),
68                "required" => flag.required = v.ensure_bool()?,
69                "var" => flag.var = v.ensure_bool()?,
70                "var_min" => flag.var_min = v.ensure_usize().map(Some)?,
71                "var_max" => flag.var_max = v.ensure_usize().map(Some)?,
72                "hide" => flag.hide = v.ensure_bool()?,
73                "deprecated" => {
74                    flag.deprecated = match v.value.as_bool() {
75                        Some(true) => Some("deprecated".into()),
76                        Some(false) => None,
77                        None => Some(v.ensure_string()?),
78                    }
79                }
80                "global" => flag.global = v.ensure_bool()?,
81                "count" => flag.count = v.ensure_bool()?,
82                "default" => {
83                    // Support both string and boolean defaults
84                    let default_value = match v.value.as_bool() {
85                        Some(b) => b.to_string(),
86                        None => v.ensure_string()?,
87                    };
88                    flag.default = vec![default_value];
89                }
90                "negate" => flag.negate = v.ensure_string().map(Some)?,
91                "env" => flag.env = v.ensure_string().map(Some)?,
92                k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
93            }
94        }
95        if !flag.default.is_empty() {
96            flag.required = false;
97        }
98        for child in node.children() {
99            match child.name() {
100                "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
101                "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
102                "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
103                "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
104                "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
105                "required" => flag.required = child.arg(0)?.ensure_bool()?,
106                "var" => flag.var = child.arg(0)?.ensure_bool()?,
107                "var_min" => flag.var_min = child.arg(0)?.ensure_usize().map(Some)?,
108                "var_max" => flag.var_max = child.arg(0)?.ensure_usize().map(Some)?,
109                "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
110                "deprecated" => {
111                    flag.deprecated = match child.arg(0)?.ensure_bool() {
112                        Ok(true) => Some("deprecated".into()),
113                        Ok(false) => None,
114                        _ => Some(child.arg(0)?.ensure_string()?),
115                    }
116                }
117                "global" => flag.global = child.arg(0)?.ensure_bool()?,
118                "count" => flag.count = child.arg(0)?.ensure_bool()?,
119                "default" => {
120                    // Support both single value and multiple values
121                    // default "bar"            -> vec!["bar"]
122                    // default #true            -> vec!["true"]
123                    // default { "xyz"; "bar" } -> vec!["xyz", "bar"]
124                    let children = child.children();
125                    if children.is_empty() {
126                        // Single value: default "bar" or default #true
127                        let arg = child.arg(0)?;
128                        let default_value = match arg.value.as_bool() {
129                            Some(b) => b.to_string(),
130                            None => arg.ensure_string()?,
131                        };
132                        flag.default = vec![default_value];
133                    } else {
134                        // Multiple values from children: default { "xyz"; "bar" }
135                        // In KDL, these are child nodes where the string is the node name
136                        flag.default = children.iter().map(|c| c.name().to_string()).collect();
137                    }
138                }
139                "env" => flag.env = child.arg(0)?.ensure_string().map(Some)?,
140                "choices" => {
141                    if let Some(arg) = &mut flag.arg {
142                        arg.choices = Some(SpecChoices::parse(ctx, &child)?);
143                    } else {
144                        bail_parse!(
145                            ctx,
146                            child.node.name().span(),
147                            "flag must have value to have choices"
148                        )
149                    }
150                }
151                k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
152            }
153        }
154        flag.usage = flag.usage();
155        flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
156        Ok(flag)
157    }
158    pub fn usage(&self) -> String {
159        let mut parts = vec![];
160        let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
161        if name != self.name {
162            parts.push(format!("{}:", self.name));
163        }
164        if let Some(short) = self.short.first() {
165            parts.push(format!("-{short}"));
166        }
167        if let Some(long) = self.long.first() {
168            parts.push(format!("--{long}"));
169        }
170        let mut out = parts.join(" ");
171        if self.var {
172            out = format!("{out}…");
173        }
174        if let Some(arg) = &self.arg {
175            out = format!("{} {}", out, arg.usage());
176        }
177        out
178    }
179}
180
181impl From<&SpecFlag> for KdlNode {
182    fn from(flag: &SpecFlag) -> KdlNode {
183        let mut node = KdlNode::new("flag");
184        let name = flag
185            .short
186            .iter()
187            .map(|c| format!("-{c}"))
188            .chain(flag.long.iter().map(|s| format!("--{s}")))
189            .collect_vec()
190            .join(" ");
191        node.push(KdlEntry::new(name));
192        if let Some(desc) = &flag.help {
193            node.push(KdlEntry::new_prop("help", desc.clone()));
194        }
195        if let Some(desc) = &flag.help_long {
196            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
197            let mut node = KdlNode::new("long_help");
198            node.entries_mut().push(KdlEntry::new(desc.clone()));
199            children.nodes_mut().push(node);
200        }
201        if let Some(desc) = &flag.help_md {
202            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
203            let mut node = KdlNode::new("help_md");
204            node.entries_mut().push(KdlEntry::new(desc.clone()));
205            children.nodes_mut().push(node);
206        }
207        if flag.required {
208            node.push(KdlEntry::new_prop("required", true));
209        }
210        if flag.var {
211            node.push(KdlEntry::new_prop("var", true));
212        }
213        if let Some(var_min) = flag.var_min {
214            node.push(KdlEntry::new_prop("var_min", var_min as i128));
215        }
216        if let Some(var_max) = flag.var_max {
217            node.push(KdlEntry::new_prop("var_max", var_max as i128));
218        }
219        if flag.hide {
220            node.push(KdlEntry::new_prop("hide", true));
221        }
222        if flag.global {
223            node.push(KdlEntry::new_prop("global", true));
224        }
225        if flag.count {
226            node.push(KdlEntry::new_prop("count", true));
227        }
228        if let Some(negate) = &flag.negate {
229            node.push(KdlEntry::new_prop("negate", negate.clone()));
230        }
231        if let Some(env) = &flag.env {
232            node.push(KdlEntry::new_prop("env", env.clone()));
233        }
234        if let Some(deprecated) = &flag.deprecated {
235            node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
236        }
237        // Serialize default values
238        if !flag.default.is_empty() {
239            if flag.default.len() == 1 {
240                // Single value: use property default="bar"
241                node.push(KdlEntry::new_prop("default", flag.default[0].clone()));
242            } else {
243                // Multiple values: use child node default { "xyz"; "bar" }
244                let children = node.children_mut().get_or_insert_with(KdlDocument::new);
245                let mut default_node = KdlNode::new("default");
246                let default_children = default_node
247                    .children_mut()
248                    .get_or_insert_with(KdlDocument::new);
249                for val in &flag.default {
250                    default_children
251                        .nodes_mut()
252                        .push(KdlNode::new(val.as_str()));
253                }
254                children.nodes_mut().push(default_node);
255            }
256        }
257        if let Some(arg) = &flag.arg {
258            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
259            children.nodes_mut().push(arg.into());
260        }
261        node
262    }
263}
264
265impl FromStr for SpecFlag {
266    type Err = UsageErr;
267    fn from_str(input: &str) -> Result<Self> {
268        let mut flag = Self::default();
269        let input = input.replace("...", "…").replace("…", " … ");
270        for part in input.split_whitespace() {
271            if let Some(name) = part.strip_suffix(':') {
272                flag.name = name.to_string();
273            } else if let Some(long) = part.strip_prefix("--") {
274                flag.long.push(long.to_string());
275            } else if let Some(short) = part.strip_prefix('-') {
276                if short.len() != 1 {
277                    return Err(InvalidFlag(
278                        short.to_string(),
279                        (0, input.len()).into(),
280                        input.to_string(),
281                    ));
282                }
283                flag.short.push(short.chars().next().unwrap());
284            } else if part == "…" {
285                if let Some(arg) = &mut flag.arg {
286                    arg.var = true;
287                } else {
288                    flag.var = true;
289                }
290            } else if part.starts_with('<') && part.ends_with('>')
291                || part.starts_with('[') && part.ends_with(']')
292            {
293                flag.arg = Some(part.to_string().parse()?);
294            } else {
295                return Err(InvalidFlag(
296                    part.to_string(),
297                    (0, input.len()).into(),
298                    input.to_string(),
299                ));
300            }
301        }
302        if flag.name.is_empty() {
303            flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
304        }
305        flag.usage = flag.usage();
306        Ok(flag)
307    }
308}
309
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Convert a clap argument definition into a [`SpecFlag`].
    ///
    /// - Visible short/long aliases become additional short/long forms.
    /// - The `Set` and `Append` actions give the flag a value argument. The
    ///   argument is named after clap's value names (or the flag name if there
    ///   are none), and clap's possible values plus their aliases become its
    ///   choices.
    /// - The `Count` and `Append` actions mark the flag as variadic.
    ///
    /// The cached `usage` string is computed before returning, just as the
    /// other constructors in this module do.
    fn from(c: &clap::Arg) -> Self {
        let required = c.is_required_set();
        let help = c.get_help().map(|s| s.to_string());
        let help_long = c.get_long_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let hide = c.is_hide_set();
        let var = matches!(
            c.get_action(),
            clap::ArgAction::Count | clap::ArgAction::Append
        );
        let default: Vec<String> = c
            .get_default_values()
            .iter()
            .map(|s| s.to_string_lossy().to_string())
            .collect();
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let arg = if let clap::ArgAction::Set | clap::ArgAction::Append = c.get_action() {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required,
            help,
            help_long,
            help_md: None,
            help_first_line,
            var,
            var_min: None,
            var_max: None,
            hide,
            global: c.is_global_set(),
            arg,
            count: matches!(c.get_action(), clap::ArgAction::Count),
            default,
            deprecated: None,
            negate: None,
            env: None,
        };
        // Fill the cached usage string, as `parse` and `from_str` do.
        flag.usage = flag.usage();
        flag
    }
}
380
381// #[cfg(feature = "clap")]
382// impl From<&SpecFlag> for clap::Arg {
383//     fn from(flag: &SpecFlag) -> Self {
384//         let mut a = clap::Arg::new(&flag.name);
385//         if let Some(desc) = &flag.help {
386//             a = a.help(desc);
387//         }
388//         if flag.required {
389//             a = a.required(true);
390//         }
391//         if let Some(arg) = &flag.arg {
392//             a = a.value_name(&arg.name);
393//             if arg.var {
394//                 a = a.action(clap::ArgAction::Append)
395//             } else {
396//                 a = a.action(clap::ArgAction::Set)
397//             }
398//         } else {
399//             a = a.action(clap::ArgAction::SetTrue)
400//         }
401//         // let mut a = clap::Arg::new(&flag.name)
402//         //     .required(flag.required)
403//         //     .action(clap::ArgAction::SetTrue);
404//         if let Some(short) = flag.short.first() {
405//             a = a.short(*short);
406//         }
407//         if let Some(long) = flag.long.first() {
408//             a = a.long(long);
409//         }
410//         for short in flag.short.iter().skip(1) {
411//             a = a.visible_short_alias(*short);
412//         }
413//         for long in flag.long.iter().skip(1) {
414//             a = a.visible_alias(long);
415//         }
416//         // cmd = cmd.arg(a);
417//         // if flag.multiple {
418//         //     a = a.multiple(true);
419//         // }
420//         // if flag.hide {
421//         //     a = a.hide_possible_values(true);
422//         // }
423//         a
424//     }
425// }
426
427impl Display for SpecFlag {
428    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429        write!(f, "{}", self.usage())
430    }
431}
432impl PartialEq for SpecFlag {
433    fn eq(&self, other: &Self) -> bool {
434        self.name == other.name
435    }
436}
437impl Eq for SpecFlag {}
438impl Hash for SpecFlag {
439    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
440        self.name.hash(state);
441    }
442}
443
/// Derive a flag's default name: the first long form if there is one,
/// otherwise the first short form as a one-character string, otherwise `None`.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.clone()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
449
// Unit tests for flag definition-string parsing, KDL property/child parsing,
// and KDL serialization of defaults.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Spec;
    use insta::assert_snapshot;

    /// `FromStr` + `Display` render the usage string: short/long forms,
    /// `…` (also written with a space before it) for variadic flags or args,
    /// `<arg>` values, and an explicit `name:` prefix.
    #[test]
    fn from_str() {
        assert_snapshot!("-f".parse::<SpecFlag>().unwrap(), @"-f");
        assert_snapshot!("--flag".parse::<SpecFlag>().unwrap(), @"--flag");
        assert_snapshot!("-f --flag".parse::<SpecFlag>().unwrap(), @"-f --flag");
        assert_snapshot!("-f --flag…".parse::<SpecFlag>().unwrap(), @"-f --flag…");
        assert_snapshot!("-f --flag …".parse::<SpecFlag>().unwrap(), @"-f --flag…");
        assert_snapshot!("--flag <arg>".parse::<SpecFlag>().unwrap(), @"--flag <arg>");
        assert_snapshot!("-f --flag <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>");
        assert_snapshot!("-f --flag… <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag… <arg>");
        assert_snapshot!("-f --flag <arg>…".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>…");
        assert_snapshot!("myflag: -f".parse::<SpecFlag>().unwrap(), @"myflag: -f");
        assert_snapshot!("myflag: -f --flag <arg>".parse::<SpecFlag>().unwrap(), @"myflag: -f --flag <arg>");
    }

    /// `env` given as a property is stored on the flag and serialized back.
    #[test]
    fn test_flag_with_env() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" env="MYCLI_COLOR" help="Enable color output"
flag "--verbose" env="MYCLI_VERBOSE"
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.env, Some("MYCLI_COLOR".to_string()));

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.env, Some("MYCLI_VERBOSE".to_string()));
    }

    /// `env` given as a child node produces the same result as the property form.
    #[test]
    fn test_flag_with_env_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" help="Enable color output" {
    env "MYCLI_COLOR"
}
flag "--verbose" {
    env "MYCLI_VERBOSE"
}
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.env, Some("MYCLI_COLOR".to_string()));

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.env, Some("MYCLI_VERBOSE".to_string()));
    }

    /// Boolean and string `default` properties are both stored as strings.
    #[test]
    fn test_flag_with_boolean_defaults() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" default=#true
flag "--verbose" default=#false
flag "--debug" default="true"
flag "--quiet" default="false"
            "#,
        )
        .unwrap();

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.default, vec!["true".to_string()]);

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.default, vec!["false".to_string()]);

        let debug_flag = spec.cmd.flags.iter().find(|f| f.name == "debug").unwrap();
        assert_eq!(debug_flag.default, vec!["true".to_string()]);

        let quiet_flag = spec.cmd.flags.iter().find(|f| f.name == "quiet").unwrap();
        assert_eq!(quiet_flag.default, vec!["false".to_string()]);
    }

    /// Boolean `default` child nodes are stringified like the property form.
    #[test]
    fn test_flag_with_boolean_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" {
    default #true
}
flag "--verbose" {
    default #false
}
            "#,
        )
        .unwrap();

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.default, vec!["true".to_string()]);

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.default, vec!["false".to_string()]);
    }

    /// A single string `default` property on a variadic flag yields one value.
    #[test]
    fn test_flag_with_single_default() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true default="bar"
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["bar".to_string()]);
    }

    /// A `default { ... }` block collects each child node name, in order.
    #[test]
    fn test_flag_with_multiple_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default {
        "xyz"
        "bar"
    }
}
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["xyz".to_string(), "bar".to_string()]);
    }

    /// A `default "bar"` child node (no children) yields one value.
    #[test]
    fn test_flag_with_single_default_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default "bar"
}
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["bar".to_string()]);
    }

    /// A single default is serialized as a `default=` property. The check
    /// accepts the quoted and unquoted spellings.
    #[test]
    fn test_flag_default_serialization_single() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" default="bar"
            "#,
        )
        .unwrap();

        // When serialized, single default should use property format
        let output = spec.to_string();
        assert!(output.contains("default=bar") || output.contains(r#"default="bar""#));
    }

    /// Multiple defaults are serialized as a `default { ... }` child block.
    #[test]
    fn test_flag_default_serialization_multiple() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default {
        "xyz"
        "bar"
    }
}
            "#,
        )
        .unwrap();

        // When serialized, multiple defaults should use child node format
        let output = spec.to_string();
        // The output should contain a default block with children
        assert!(output.contains("default {"));
    }
}
656}