1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
15#[derive(Debug, Default, Clone, Serialize)]
16pub struct SpecFlag {
17 pub name: String,
18 pub usage: String,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub help: Option<String>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub help_long: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub help_md: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub help_first_line: Option<String>,
27 pub short: Vec<char>,
28 pub long: Vec<String>,
29 #[serde(skip_serializing_if = "is_false")]
30 pub required: bool,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub deprecated: Option<String>,
33 #[serde(skip_serializing_if = "is_false")]
34 pub var: bool,
35 pub hide: bool,
36 pub global: bool,
37 #[serde(skip_serializing_if = "is_false")]
38 pub count: bool,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub arg: Option<SpecArg>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub default: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub negate: Option<String>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub env: Option<String>,
47}
48
49impl SpecFlag {
50 pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
51 let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
52 for (k, v) in node.props() {
53 match k {
54 "help" => flag.help = Some(v.ensure_string()?),
55 "long_help" => flag.help_long = Some(v.ensure_string()?),
56 "help_long" => flag.help_long = Some(v.ensure_string()?),
57 "help_md" => flag.help_md = Some(v.ensure_string()?),
58 "required" => flag.required = v.ensure_bool()?,
59 "var" => flag.var = v.ensure_bool()?,
60 "hide" => flag.hide = v.ensure_bool()?,
61 "deprecated" => {
62 flag.deprecated = match v.value.as_bool() {
63 Some(true) => Some("deprecated".into()),
64 Some(false) => None,
65 None => Some(v.ensure_string()?),
66 }
67 }
68 "global" => flag.global = v.ensure_bool()?,
69 "count" => flag.count = v.ensure_bool()?,
70 "default" => {
71 flag.default = match v.value.as_bool() {
73 Some(b) => Some(b.to_string()),
74 None => v.ensure_string().map(Some)?,
75 }
76 }
77 "negate" => flag.negate = v.ensure_string().map(Some)?,
78 "env" => flag.env = v.ensure_string().map(Some)?,
79 k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
80 }
81 }
82 if flag.default.is_some() {
83 flag.required = false;
84 }
85 for child in node.children() {
86 match child.name() {
87 "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
88 "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
89 "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
90 "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
91 "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
92 "required" => flag.required = child.arg(0)?.ensure_bool()?,
93 "var" => flag.var = child.arg(0)?.ensure_bool()?,
94 "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
95 "deprecated" => {
96 flag.deprecated = match child.arg(0)?.ensure_bool() {
97 Ok(true) => Some("deprecated".into()),
98 Ok(false) => None,
99 _ => Some(child.arg(0)?.ensure_string()?),
100 }
101 }
102 "global" => flag.global = child.arg(0)?.ensure_bool()?,
103 "count" => flag.count = child.arg(0)?.ensure_bool()?,
104 "default" => {
105 let arg = child.arg(0)?;
107 flag.default = match arg.value.as_bool() {
108 Some(b) => Some(b.to_string()),
109 None => arg.ensure_string().map(Some)?,
110 }
111 }
112 "env" => flag.env = child.arg(0)?.ensure_string().map(Some)?,
113 "choices" => {
114 if let Some(arg) = &mut flag.arg {
115 arg.choices = Some(SpecChoices::parse(ctx, &child)?);
116 } else {
117 bail_parse!(
118 ctx,
119 child.node.name().span(),
120 "flag must have value to have choices"
121 )
122 }
123 }
124 k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
125 }
126 }
127 flag.usage = flag.usage();
128 flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
129 Ok(flag)
130 }
131 pub fn usage(&self) -> String {
132 let mut parts = vec![];
133 let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
134 if name != self.name {
135 parts.push(format!("{}:", self.name));
136 }
137 if let Some(short) = self.short.first() {
138 parts.push(format!("-{short}"));
139 }
140 if let Some(long) = self.long.first() {
141 parts.push(format!("--{long}"));
142 }
143 let mut out = parts.join(" ");
144 if self.var {
145 out = format!("{out}…");
146 }
147 if let Some(arg) = &self.arg {
148 out = format!("{} {}", out, arg.usage());
149 }
150 out
151 }
152}
153
154impl From<&SpecFlag> for KdlNode {
155 fn from(flag: &SpecFlag) -> KdlNode {
156 let mut node = KdlNode::new("flag");
157 let name = flag
158 .short
159 .iter()
160 .map(|c| format!("-{c}"))
161 .chain(flag.long.iter().map(|s| format!("--{s}")))
162 .collect_vec()
163 .join(" ");
164 node.push(KdlEntry::new(name));
165 if let Some(desc) = &flag.help {
166 node.push(KdlEntry::new_prop("help", desc.clone()));
167 }
168 if let Some(desc) = &flag.help_long {
169 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
170 let mut node = KdlNode::new("long_help");
171 node.entries_mut().push(KdlEntry::new(desc.clone()));
172 children.nodes_mut().push(node);
173 }
174 if let Some(desc) = &flag.help_md {
175 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
176 let mut node = KdlNode::new("help_md");
177 node.entries_mut().push(KdlEntry::new(desc.clone()));
178 children.nodes_mut().push(node);
179 }
180 if flag.required {
181 node.push(KdlEntry::new_prop("required", true));
182 }
183 if flag.var {
184 node.push(KdlEntry::new_prop("var", true));
185 }
186 if flag.hide {
187 node.push(KdlEntry::new_prop("hide", true));
188 }
189 if flag.global {
190 node.push(KdlEntry::new_prop("global", true));
191 }
192 if flag.count {
193 node.push(KdlEntry::new_prop("count", true));
194 }
195 if let Some(negate) = &flag.negate {
196 node.push(KdlEntry::new_prop("negate", negate.clone()));
197 }
198 if let Some(env) = &flag.env {
199 node.push(KdlEntry::new_prop("env", env.clone()));
200 }
201 if let Some(deprecated) = &flag.deprecated {
202 node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
203 }
204 if let Some(arg) = &flag.arg {
205 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
206 children.nodes_mut().push(arg.into());
207 }
208 node
209 }
210}
211
212impl FromStr for SpecFlag {
213 type Err = UsageErr;
214 fn from_str(input: &str) -> Result<Self> {
215 let mut flag = Self::default();
216 let input = input.replace("...", "…").replace("…", " … ");
217 for part in input.split_whitespace() {
218 if let Some(name) = part.strip_suffix(':') {
219 flag.name = name.to_string();
220 } else if let Some(long) = part.strip_prefix("--") {
221 flag.long.push(long.to_string());
222 } else if let Some(short) = part.strip_prefix('-') {
223 if short.len() != 1 {
224 return Err(InvalidFlag(
225 short.to_string(),
226 (0, input.len()).into(),
227 input.to_string(),
228 ));
229 }
230 flag.short.push(short.chars().next().unwrap());
231 } else if part == "…" {
232 if let Some(arg) = &mut flag.arg {
233 arg.var = true;
234 } else {
235 flag.var = true;
236 }
237 } else if part.starts_with('<') && part.ends_with('>')
238 || part.starts_with('[') && part.ends_with(']')
239 {
240 flag.arg = Some(part.to_string().parse()?);
241 } else {
242 return Err(InvalidFlag(
243 part.to_string(),
244 (0, input.len()).into(),
245 input.to_string(),
246 ));
247 }
248 }
249 if flag.name.is_empty() {
250 flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
251 }
252 flag.usage = flag.usage();
253 Ok(flag)
254 }
255}
256
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Converts a clap argument definition into a flag.
    ///
    /// * `var` is set for the `Count` and `Append` actions, `count` only for `Count`.
    /// * An `arg` is created for the `Set` and `Append` actions, named after clap's value
    ///   names (joined with spaces) or, failing that, after the flag itself; clap's possible
    ///   values and their aliases become its choices.
    /// * Only the first clap default value is kept.
    /// * `help_md`, `deprecated`, `negate` and `env` are left unset.
    fn from(c: &clap::Arg) -> Self {
        let required = c.is_required_set();
        let help = c.get_help().map(|s| s.to_string());
        let help_long = c.get_long_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let hide = c.is_hide_set();
        let var = matches!(
            c.get_action(),
            clap::ArgAction::Count | clap::ArgAction::Append
        );
        let default = c
            .get_default_values()
            .first()
            .map(|s| s.to_string_lossy().to_string());
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let arg = if let clap::ArgAction::Set | clap::ArgAction::Append = c.get_action() {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    // Only clone the name when clap supplies no value names.
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required,
            help,
            help_long,
            help_md: None,
            help_first_line,
            var,
            hide,
            global: c.is_global_set(),
            arg,
            count: matches!(c.get_action(), clap::ArgAction::Count),
            default,
            deprecated: None,
            negate: None,
            env: None,
        };
        // Keep `usage` populated, as the KDL and string parsers do; previously it was
        // left empty for flags that originate from clap.
        flag.usage = flag.usage();
        flag
    }
}
324
325impl Display for SpecFlag {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 write!(f, "{}", self.usage())
374 }
375}
376impl PartialEq for SpecFlag {
377 fn eq(&self, other: &Self) -> bool {
378 self.name == other.name
379 }
380}
381impl Eq for SpecFlag {}
382impl Hash for SpecFlag {
383 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
384 self.name.hash(state);
385 }
386}
387
/// Picks the default name for a flag: the first long name if there is one, otherwise the
/// first short name, otherwise `None`.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.clone()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Spec;
    use insta::assert_snapshot;

    /// Parses a usage string into a flag, panicking on invalid input.
    fn parsed(usage: &str) -> SpecFlag {
        usage.parse().unwrap()
    }

    // Usage strings round-trip through `FromStr` and `Display`.
    #[test]
    fn from_str() {
        assert_snapshot!(parsed("-f"), @"-f");
        assert_snapshot!(parsed("--flag"), @"--flag");
        assert_snapshot!(parsed("-f --flag"), @"-f --flag");
        assert_snapshot!(parsed("-f --flag…"), @"-f --flag…");
        assert_snapshot!(parsed("-f --flag …"), @"-f --flag…");
        assert_snapshot!(parsed("--flag <arg>"), @"--flag <arg>");
        assert_snapshot!(parsed("-f --flag <arg>"), @"-f --flag <arg>");
        assert_snapshot!(parsed("-f --flag… <arg>"), @"-f --flag… <arg>");
        assert_snapshot!(parsed("-f --flag <arg>…"), @"-f --flag <arg>…");
        assert_snapshot!(parsed("myflag: -f"), @"myflag: -f");
        assert_snapshot!(parsed("myflag: -f --flag <arg>"), @"myflag: -f --flag <arg>");
    }

    // `env` given as a property.
    #[test]
    fn test_flag_with_env() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" env="MYCLI_COLOR" help="Enable color output"
flag "--verbose" env="MYCLI_VERBOSE"
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        for (name, env) in [("color", "MYCLI_COLOR"), ("verbose", "MYCLI_VERBOSE")] {
            let found = spec.cmd.flags.iter().find(|f| f.name == name).unwrap();
            assert_eq!(found.env, Some(env.to_string()));
        }
    }

    // `env` given as a child node serialises the same way as the property form.
    #[test]
    fn test_flag_with_env_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" help="Enable color output" {
    env "MYCLI_COLOR"
}
flag "--verbose" {
    env "MYCLI_VERBOSE"
}
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        for (name, env) in [("color", "MYCLI_COLOR"), ("verbose", "MYCLI_VERBOSE")] {
            let found = spec.cmd.flags.iter().find(|f| f.name == name).unwrap();
            assert_eq!(found.env, Some(env.to_string()));
        }
    }

    // Boolean and string defaults both end up as strings.
    #[test]
    fn test_flag_with_boolean_defaults() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" default=#true
flag "--verbose" default=#false
flag "--debug" default="true"
flag "--quiet" default="false"
            "#,
        )
        .unwrap();

        let expectations = [
            ("color", "true"),
            ("verbose", "false"),
            ("debug", "true"),
            ("quiet", "false"),
        ];
        for (name, default) in expectations {
            let found = spec.cmd.flags.iter().find(|f| f.name == name).unwrap();
            assert_eq!(found.default, Some(default.to_string()));
        }
    }

    // Same as above, with the default written as a child node.
    #[test]
    fn test_flag_with_boolean_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" {
    default #true
}
flag "--verbose" {
    default #false
}
            "#,
        )
        .unwrap();

        for (name, default) in [("color", "true"), ("verbose", "false")] {
            let found = spec.cmd.flags.iter().find(|f| f.name == name).unwrap();
            assert_eq!(found.default, Some(default.to_string()));
        }
    }
}