agcodex_execpolicy/
arg_resolver.rs

1use serde::Serialize;
2
3use crate::arg_matcher::ArgMatcher;
4use crate::arg_matcher::ArgMatcherCardinality;
5use crate::error::Error;
6use crate::error::Result;
7use crate::valid_exec::MatchedArg;
8
9#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
10pub struct PositionalArg {
11    pub index: usize,
12    pub value: String,
13}
14
15pub fn resolve_observed_args_with_patterns(
16    program: &str,
17    args: Vec<PositionalArg>,
18    arg_patterns: &Vec<ArgMatcher>,
19) -> Result<Vec<MatchedArg>> {
20    // Naive matching implementation. Among `arg_patterns`, there is allowed to
21    // be at most one vararg pattern. Assuming `arg_patterns` is non-empty, we
22    // end up with either:
23    //
24    // - all `arg_patterns` in `prefix_patterns`
25    // - `arg_patterns` split across `prefix_patterns` (which could be empty),
26    //   one `vararg_pattern`, and `suffix_patterns` (which could also empty).
27    //
28    // From there, we start by matching everything in `prefix_patterns`.
29    // Then we calculate how many positional args should be matched by
30    // `suffix_patterns` and use that to determine how many args are left to
31    // be matched by `vararg_pattern` (which could be zero).
32    //
33    // After associating positional args with `vararg_pattern`, we match the
34    // `suffix_patterns` with the remaining args.
35    let ParitionedArgs {
36        num_prefix_args,
37        num_suffix_args,
38        prefix_patterns,
39        suffix_patterns,
40        vararg_pattern,
41    } = partition_args(program, arg_patterns)?;
42
43    let mut matched_args = Vec::<MatchedArg>::new();
44
45    let prefix = get_range_checked(&args, 0..num_prefix_args)?;
46    let mut prefix_arg_index = 0;
47    for pattern in prefix_patterns {
48        let n = pattern
49            .cardinality()
50            .is_exact()
51            .ok_or(Error::InternalInvariantViolation {
52                message: "expected exact cardinality".to_string(),
53            })?;
54        for positional_arg in &prefix[prefix_arg_index..prefix_arg_index + n] {
55            let matched_arg = MatchedArg::new(
56                positional_arg.index,
57                pattern.arg_type(),
58                &positional_arg.value.clone(),
59            )?;
60            matched_args.push(matched_arg);
61        }
62        prefix_arg_index += n;
63    }
64
65    if num_suffix_args > args.len() {
66        return Err(Error::NotEnoughArgs {
67            program: program.to_string(),
68            args,
69            arg_patterns: arg_patterns.clone(),
70        });
71    }
72
73    let initial_suffix_args_index = args.len() - num_suffix_args;
74    if prefix_arg_index > initial_suffix_args_index {
75        return Err(Error::PrefixOverlapsSuffix {});
76    }
77
78    if let Some(pattern) = vararg_pattern {
79        let vararg = get_range_checked(&args, prefix_arg_index..initial_suffix_args_index)?;
80        match pattern.cardinality() {
81            ArgMatcherCardinality::One => {
82                return Err(Error::InternalInvariantViolation {
83                    message: "vararg pattern should not have cardinality of one".to_string(),
84                });
85            }
86            ArgMatcherCardinality::AtLeastOne => {
87                if vararg.is_empty() {
88                    return Err(Error::VarargMatcherDidNotMatchAnything {
89                        program: program.to_string(),
90                        matcher: pattern,
91                    });
92                } else {
93                    for positional_arg in vararg {
94                        let matched_arg = MatchedArg::new(
95                            positional_arg.index,
96                            pattern.arg_type(),
97                            &positional_arg.value.clone(),
98                        )?;
99                        matched_args.push(matched_arg);
100                    }
101                }
102            }
103            ArgMatcherCardinality::ZeroOrMore => {
104                for positional_arg in vararg {
105                    let matched_arg = MatchedArg::new(
106                        positional_arg.index,
107                        pattern.arg_type(),
108                        &positional_arg.value.clone(),
109                    )?;
110                    matched_args.push(matched_arg);
111                }
112            }
113        }
114    }
115
116    let suffix = get_range_checked(&args, initial_suffix_args_index..args.len())?;
117    let mut suffix_arg_index = 0;
118    for pattern in suffix_patterns {
119        let n = pattern
120            .cardinality()
121            .is_exact()
122            .ok_or(Error::InternalInvariantViolation {
123                message: "expected exact cardinality".to_string(),
124            })?;
125        for positional_arg in &suffix[suffix_arg_index..suffix_arg_index + n] {
126            let matched_arg = MatchedArg::new(
127                positional_arg.index,
128                pattern.arg_type(),
129                &positional_arg.value.clone(),
130            )?;
131            matched_args.push(matched_arg);
132        }
133        suffix_arg_index += n;
134    }
135
136    if matched_args.len() < args.len() {
137        let extra_args = get_range_checked(&args, matched_args.len()..args.len())?;
138        Err(Error::UnexpectedArguments {
139            program: program.to_string(),
140            args: extra_args.to_vec(),
141        })
142    } else {
143        Ok(matched_args)
144    }
145}
146
147#[derive(Default)]
148struct ParitionedArgs {
149    num_prefix_args: usize,
150    num_suffix_args: usize,
151    prefix_patterns: Vec<ArgMatcher>,
152    suffix_patterns: Vec<ArgMatcher>,
153    vararg_pattern: Option<ArgMatcher>,
154}
155
156fn partition_args(program: &str, arg_patterns: &Vec<ArgMatcher>) -> Result<ParitionedArgs> {
157    let mut in_prefix = true;
158    let mut partitioned_args = ParitionedArgs::default();
159
160    for pattern in arg_patterns {
161        match pattern.cardinality().is_exact() {
162            Some(n) => {
163                if in_prefix {
164                    partitioned_args.prefix_patterns.push(pattern.clone());
165                    partitioned_args.num_prefix_args += n;
166                } else {
167                    partitioned_args.suffix_patterns.push(pattern.clone());
168                    partitioned_args.num_suffix_args += n;
169                }
170            }
171            None => match partitioned_args.vararg_pattern {
172                None => {
173                    partitioned_args.vararg_pattern = Some(pattern.clone());
174                    in_prefix = false;
175                }
176                Some(existing_pattern) => {
177                    return Err(Error::MultipleVarargPatterns {
178                        program: program.to_string(),
179                        first: existing_pattern,
180                        second: pattern.clone(),
181                    });
182                }
183            },
184        }
185    }
186
187    Ok(partitioned_args)
188}
189
190fn get_range_checked<T>(vec: &[T], range: std::ops::Range<usize>) -> Result<&[T]> {
191    if range.start > range.end {
192        Err(Error::RangeStartExceedsEnd {
193            start: range.start,
194            end: range.end,
195        })
196    } else if range.end > vec.len() {
197        Err(Error::RangeEndOutOfBounds {
198            end: range.end,
199            len: vec.len(),
200        })
201    } else {
202        Ok(&vec[range])
203    }
204}