Skip to main content

clap_builder/parser/
validator.rs

1// Internal
2use crate::INTERNAL_ERROR_MSG;
3use crate::builder::StyledStr;
4use crate::builder::{Arg, ArgGroup, ArgPredicate, Command, PossibleValue};
5use crate::error::{Error, Result as ClapResult};
6use crate::output::Usage;
7use crate::parser::ArgMatcher;
8use crate::util::ChildGraph;
9use crate::util::FlatMap;
10use crate::util::FlatSet;
11use crate::util::Id;
12
13pub(crate) struct Validator<'cmd> {
14    cmd: &'cmd Command,
15    required: ChildGraph<Id>,
16}
17
18impl<'cmd> Validator<'cmd> {
19    pub(crate) fn new(cmd: &'cmd Command) -> Self {
20        let required = cmd.required_graph();
21        Validator { cmd, required }
22    }
23
24    pub(crate) fn validate(&mut self, matcher: &mut ArgMatcher) -> ClapResult<()> {
25        debug!("Validator::validate");
26        let conflicts = Conflicts::with_args(self.cmd, matcher);
27        let has_subcmd = matcher.subcommand_name().is_some();
28
29        if !has_subcmd && self.cmd.is_arg_required_else_help_set() {
30            let num_user_values = matcher
31                .args()
32                .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
33                .count();
34            if num_user_values == 0 {
35                let message = self.cmd.write_help_err(false);
36                return Err(Error::display_help_error(self.cmd, message));
37            }
38        }
39        if !has_subcmd && self.cmd.is_subcommand_required_set() {
40            let bn = self.cmd.get_bin_name_fallback();
41            return Err(Error::missing_subcommand(
42                self.cmd,
43                bn.to_string(),
44                self.cmd
45                    .all_subcommand_names()
46                    .map(|s| s.to_owned())
47                    .collect::<Vec<_>>(),
48                Usage::new(self.cmd)
49                    .required(&self.required)
50                    .create_usage_with_title(&[]),
51            ));
52        }
53
54        ok!(self.validate_conflicts(matcher, &conflicts));
55        if !(self.cmd.is_subcommand_negates_reqs_set() && has_subcmd) {
56            ok!(self.validate_required(matcher, &conflicts));
57        }
58
59        Ok(())
60    }
61
62    fn validate_conflicts(
63        &mut self,
64        matcher: &ArgMatcher,
65        conflicts: &Conflicts,
66    ) -> ClapResult<()> {
67        debug!("Validator::validate_conflicts");
68
69        ok!(self.validate_exclusive(matcher));
70
71        for (arg_id, _) in matcher
72            .args()
73            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
74            .filter(|(arg_id, _)| self.cmd.find(arg_id).is_some())
75        {
76            debug!("Validator::validate_conflicts::iter: id={arg_id:?}");
77            let conflicts = conflicts.gather_conflicts(self.cmd, arg_id);
78            ok!(self.build_conflict_err(arg_id, &conflicts, matcher));
79        }
80
81        Ok(())
82    }
83
84    fn validate_exclusive(&self, matcher: &ArgMatcher) -> ClapResult<()> {
85        debug!("Validator::validate_exclusive");
86        let args_count = matcher
87            .args()
88            .filter(|(arg_id, matched)| {
89                matched.check_explicit(&ArgPredicate::IsPresent)
90                    // Avoid including our own groups by checking none of them.  If a group is present, the
91                    // args for the group will be.
92                    && self.cmd.find(arg_id).is_some()
93            })
94            .count();
95        if args_count <= 1 {
96            // Nothing present to conflict with
97            return Ok(());
98        }
99
100        matcher
101            .args()
102            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
103            .find_map(|(id, _)| {
104                debug!("Validator::validate_exclusive:iter:{id:?}");
105                self.cmd
106                    .find(id)
107                    // Find `arg`s which are exclusive but also appear with other args.
108                    .filter(|&arg| arg.is_exclusive_set() && args_count > 1)
109            })
110            .map(|arg| {
111                // Throw an error for the first conflict found.
112                Err(Error::argument_conflict(
113                    self.cmd,
114                    arg.to_string(),
115                    Vec::new(),
116                    Usage::new(self.cmd)
117                        .required(&self.required)
118                        .create_usage_with_title(&[]),
119                ))
120            })
121            .unwrap_or(Ok(()))
122    }
123
124    fn build_conflict_err(
125        &self,
126        name: &Id,
127        conflict_ids: &[Id],
128        matcher: &ArgMatcher,
129    ) -> ClapResult<()> {
130        if conflict_ids.is_empty() {
131            return Ok(());
132        }
133
134        debug!("Validator::build_conflict_err: name={name:?}");
135        let conflict_ids = conflict_ids
136            .iter()
137            .flat_map(|c_id| {
138                if self.cmd.find_group(c_id).is_some() {
139                    self.cmd.unroll_args_in_group(c_id)
140                } else {
141                    vec![c_id.clone()]
142                }
143            })
144            .collect::<FlatSet<_>>()
145            .into_vec();
146        let conflicts = conflict_ids
147            .iter()
148            .map(|c_id| {
149                let c_arg = self.cmd.find(c_id).expect(INTERNAL_ERROR_MSG);
150                c_arg.to_string()
151            })
152            .collect();
153
154        let former_arg = self.cmd.find(name).expect(INTERNAL_ERROR_MSG);
155        let usg = self.build_conflict_err_usage(matcher, &conflict_ids);
156        Err(Error::argument_conflict(
157            self.cmd,
158            former_arg.to_string(),
159            conflicts,
160            usg,
161        ))
162    }
163
164    fn build_conflict_err_usage(
165        &self,
166        matcher: &ArgMatcher,
167        conflicting_keys: &[Id],
168    ) -> Option<StyledStr> {
169        let used_filtered: Vec<Id> = matcher
170            .args()
171            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
172            .map(|(n, _)| n)
173            .filter(|n| {
174                // Filter out the args we don't want to specify.
175                self.cmd
176                    .find(n)
177                    .map(|a| !a.is_hide_set())
178                    .unwrap_or_default()
179            })
180            .filter(|key| !conflicting_keys.contains(key))
181            .cloned()
182            .collect();
183        let required: Vec<Id> = used_filtered
184            .iter()
185            .filter_map(|key| self.cmd.find(key))
186            .flat_map(|arg| arg.requires.iter().map(|item| &item.1))
187            .filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
188            .chain(used_filtered.iter())
189            .cloned()
190            .collect();
191        Usage::new(self.cmd)
192            .required(&self.required)
193            .create_usage_with_title(&required)
194    }
195
196    fn gather_requires(&mut self, matcher: &ArgMatcher) {
197        debug!("Validator::gather_requires");
198        for (name, matched) in matcher
199            .args()
200            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
201        {
202            debug!("Validator::gather_requires:iter:{name:?}");
203            if let Some(arg) = self.cmd.find(name) {
204                let is_relevant = |(val, req_arg): &(ArgPredicate, Id)| -> Option<Id> {
205                    let required = matched.check_explicit(val);
206                    required.then(|| req_arg.clone())
207                };
208
209                for req in self.cmd.unroll_arg_requires(is_relevant, arg.get_id()) {
210                    self.required.insert(req);
211                }
212            } else if let Some(g) = self.cmd.find_group(name) {
213                debug!("Validator::gather_requires:iter:{name:?}:group");
214                for r in &g.requires {
215                    self.required.insert(r.clone());
216                }
217            }
218        }
219    }
220
221    fn validate_required(&mut self, matcher: &ArgMatcher, conflicts: &Conflicts) -> ClapResult<()> {
222        debug!("Validator::validate_required: required={:?}", self.required);
223        self.gather_requires(matcher);
224
225        let mut missing_required = Vec::new();
226        let mut highest_index = 0;
227
228        let is_exclusive_present = matcher
229            .args()
230            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
231            .any(|(id, _)| {
232                self.cmd
233                    .find(id)
234                    .map(|arg| arg.is_exclusive_set())
235                    .unwrap_or_default()
236            });
237        debug!("Validator::validate_required: is_exclusive_present={is_exclusive_present}");
238
239        for arg_or_group in self
240            .required
241            .iter()
242            .filter(|r| !matcher.check_explicit(r, &ArgPredicate::IsPresent))
243        {
244            debug!("Validator::validate_required:iter:aog={arg_or_group:?}");
245            if let Some(arg) = self.cmd.find(arg_or_group) {
246                debug!("Validator::validate_required:iter: This is an arg");
247                if !is_exclusive_present && !self.is_missing_required_ok(arg, conflicts) {
248                    debug!(
249                        "Validator::validate_required:iter: Missing {:?}",
250                        arg.get_id()
251                    );
252                    missing_required.push(arg.get_id().clone());
253                    if !arg.is_last_set() {
254                        highest_index = highest_index.max(arg.get_index().unwrap_or(0));
255                    }
256                }
257            } else if let Some(group) = self.cmd.find_group(arg_or_group) {
258                debug!("Validator::validate_required:iter: This is a group");
259                if !self
260                    .cmd
261                    .unroll_args_in_group(&group.id)
262                    .iter()
263                    .any(|a| matcher.check_explicit(a, &ArgPredicate::IsPresent))
264                {
265                    debug!(
266                        "Validator::validate_required:iter: Missing {:?}",
267                        group.get_id()
268                    );
269                    missing_required.push(group.get_id().clone());
270                }
271            }
272        }
273
274        // Validate the conditionally required args
275        for a in self
276            .cmd
277            .get_arguments()
278            .filter(|a| !matcher.check_explicit(a.get_id(), &ArgPredicate::IsPresent))
279        {
280            let mut required = false;
281
282            for (other, val) in &a.r_ifs {
283                if matcher.check_explicit(other, &ArgPredicate::Equals(val.into())) {
284                    debug!(
285                        "Validator::validate_required:iter: Missing {:?}",
286                        a.get_id()
287                    );
288                    required = true;
289                }
290            }
291
292            let match_all = a.r_ifs_all.iter().all(|(other, val)| {
293                matcher.check_explicit(other, &ArgPredicate::Equals(val.into()))
294            });
295            if match_all && !a.r_ifs_all.is_empty() {
296                debug!(
297                    "Validator::validate_required:iter: Missing {:?}",
298                    a.get_id()
299                );
300                required = true;
301            }
302
303            if (!a.r_unless.is_empty() || !a.r_unless_all.is_empty())
304                && self.fails_arg_required_unless(a, matcher)
305            {
306                debug!(
307                    "Validator::validate_required:iter: Missing {:?}",
308                    a.get_id()
309                );
310                required = true;
311            }
312
313            if !is_exclusive_present && required {
314                missing_required.push(a.get_id().clone());
315                if !a.is_last_set() {
316                    highest_index = highest_index.max(a.get_index().unwrap_or(0));
317                }
318            }
319        }
320
321        // For display purposes, include all of the preceding positional arguments
322        if !self.cmd.is_allow_missing_positional_set() {
323            for pos in self
324                .cmd
325                .get_positionals()
326                .filter(|a| !matcher.check_explicit(a.get_id(), &ArgPredicate::IsPresent))
327            {
328                if pos.get_index() < Some(highest_index) {
329                    debug!(
330                        "Validator::validate_required:iter: Missing {:?}",
331                        pos.get_id()
332                    );
333                    missing_required.push(pos.get_id().clone());
334                }
335            }
336        }
337
338        if !missing_required.is_empty() {
339            ok!(self.missing_required_error(matcher, missing_required));
340        }
341
342        Ok(())
343    }
344
345    fn is_missing_required_ok(&self, a: &Arg, conflicts: &Conflicts) -> bool {
346        debug!("Validator::is_missing_required_ok: {}", a.get_id());
347        if !conflicts.gather_conflicts(self.cmd, a.get_id()).is_empty() {
348            debug!("Validator::is_missing_required_ok: true (self)");
349            return true;
350        }
351        for group_id in self.cmd.groups_for_arg(a.get_id()) {
352            if !conflicts.gather_conflicts(self.cmd, &group_id).is_empty() {
353                debug!("Validator::is_missing_required_ok: true ({group_id})");
354                return true;
355            }
356        }
357        false
358    }
359
360    // Failing a required unless means, the arg's "unless" wasn't present, and neither were they
361    fn fails_arg_required_unless(&self, a: &Arg, matcher: &ArgMatcher) -> bool {
362        debug!("Validator::fails_arg_required_unless: a={:?}", a.get_id());
363        let exists = |id| matcher.check_explicit(id, &ArgPredicate::IsPresent);
364
365        (a.r_unless_all.is_empty() || !a.r_unless_all.iter().all(exists))
366            && !a.r_unless.iter().any(exists)
367    }
368
369    // `req_args`: an arg to include in the error even if not used
370    fn missing_required_error(
371        &self,
372        matcher: &ArgMatcher,
373        raw_req_args: Vec<Id>,
374    ) -> ClapResult<()> {
375        debug!("Validator::missing_required_error; incl={raw_req_args:?}");
376        debug!(
377            "Validator::missing_required_error: reqs={:?}",
378            self.required
379        );
380
381        let usg = Usage::new(self.cmd).required(&self.required);
382
383        let req_args = {
384            #[cfg(feature = "usage")]
385            {
386                usg.get_required_usage_from(&raw_req_args, Some(matcher), true)
387                    .into_iter()
388                    .map(|s| s.to_string())
389                    .collect::<Vec<_>>()
390            }
391
392            #[cfg(not(feature = "usage"))]
393            {
394                raw_req_args
395                    .iter()
396                    .map(|id| {
397                        if let Some(arg) = self.cmd.find(id) {
398                            arg.to_string()
399                        } else if let Some(_group) = self.cmd.find_group(id) {
400                            self.cmd.format_group(id).to_string()
401                        } else {
402                            debug_assert!(false, "id={id:?} is unknown");
403                            "".to_owned()
404                        }
405                    })
406                    .collect::<FlatSet<_>>()
407                    .into_iter()
408                    .collect::<Vec<_>>()
409            }
410        };
411
412        debug!("Validator::missing_required_error: req_args={req_args:#?}");
413
414        let used: Vec<Id> = matcher
415            .args()
416            .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
417            .map(|(n, _)| n)
418            .filter(|n| {
419                // Filter out the args we don't want to specify.
420                self.cmd
421                    .find(n)
422                    .map(|a| !a.is_hide_set())
423                    .unwrap_or_default()
424            })
425            .cloned()
426            .chain(raw_req_args)
427            .collect();
428
429        Err(Error::missing_required_argument(
430            self.cmd,
431            req_args,
432            usg.create_usage_with_title(&used),
433        ))
434    }
435}
436
437#[derive(Default, Clone, Debug)]
438struct Conflicts {
439    potential: FlatMap<Id, Vec<Id>>,
440}
441
442impl Conflicts {
443    fn with_args(cmd: &Command, matcher: &ArgMatcher) -> Self {
444        let mut potential = FlatMap::new();
445        potential.extend_unchecked(
446            matcher
447                .args()
448                .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
449                .map(|(id, _)| {
450                    let conf = gather_direct_conflicts(cmd, id);
451                    (id.clone(), conf)
452                }),
453        );
454        Self { potential }
455    }
456
457    fn gather_conflicts(&self, cmd: &Command, arg_id: &Id) -> Vec<Id> {
458        debug!("Conflicts::gather_conflicts: arg={arg_id:?}");
459        let mut conflicts = Vec::new();
460
461        let arg_id_conflicts_storage;
462        let arg_id_conflicts = if let Some(arg_id_conflicts) = self.get_direct_conflicts(arg_id) {
463            arg_id_conflicts
464        } else {
465            // `is_missing_required_ok` is a case where we check not-present args for conflicts
466            arg_id_conflicts_storage = gather_direct_conflicts(cmd, arg_id);
467            &arg_id_conflicts_storage
468        };
469        for (other_arg_id, other_arg_id_conflicts) in self.potential.iter() {
470            if arg_id == other_arg_id {
471                continue;
472            }
473
474            if arg_id_conflicts.contains(other_arg_id) {
475                conflicts.push(other_arg_id.clone());
476            }
477            if other_arg_id_conflicts.contains(arg_id) {
478                conflicts.push(other_arg_id.clone());
479            }
480        }
481
482        debug!("Conflicts::gather_conflicts: conflicts={conflicts:?}");
483        conflicts
484    }
485
486    fn get_direct_conflicts(&self, arg_id: &Id) -> Option<&[Id]> {
487        self.potential.get(arg_id).map(Vec::as_slice)
488    }
489}
490
491fn gather_direct_conflicts(cmd: &Command, id: &Id) -> Vec<Id> {
492    let conf = if let Some(arg) = cmd.find(id) {
493        gather_arg_direct_conflicts(cmd, arg)
494    } else if let Some(group) = cmd.find_group(id) {
495        gather_group_direct_conflicts(group)
496    } else {
497        debug_assert!(false, "id={id:?} is unknown");
498        Vec::new()
499    };
500    debug!("Conflicts::gather_direct_conflicts id={id:?}, conflicts={conf:?}",);
501    conf
502}
503
504fn gather_arg_direct_conflicts(cmd: &Command, arg: &Arg) -> Vec<Id> {
505    let mut conf = arg.blacklist.clone();
506    for group_id in cmd.groups_for_arg(arg.get_id()) {
507        let group = cmd.find_group(&group_id).expect(INTERNAL_ERROR_MSG);
508        conf.extend(group.conflicts.iter().cloned());
509        if !group.multiple {
510            for member_id in &group.args {
511                if member_id != arg.get_id() {
512                    conf.push(member_id.clone());
513                }
514            }
515        }
516    }
517
518    // Overrides are implicitly conflicts
519    conf.extend(arg.overrides.iter().cloned());
520
521    conf
522}
523
524fn gather_group_direct_conflicts(group: &ArgGroup) -> Vec<Id> {
525    group.conflicts.clone()
526}
527
528pub(crate) fn get_possible_values_cli(a: &Arg) -> Vec<PossibleValue> {
529    if !a.is_takes_value_set() {
530        vec![]
531    } else {
532        a.get_value_parser()
533            .possible_values()
534            .map(|pvs| pvs.collect())
535            .unwrap_or_default()
536    }
537}