clipanion_core/
runner.rs

1use std::collections::HashSet;
2
3use crate::{actions::{apply_check, apply_reducer}, errors::Error, machine::{Machine, MachineContext}, shared::{Arg, ERROR_NODE_ID, HELP_COMMAND_INDEX, INITIAL_NODE_ID}};
4
/// A token recorded against one segment of the user input.
///
/// NOTE(review): tokens are not built in this module; they are presumably pushed by the
/// reducers in `actions`. `segment_index` likely refers to a position in the user `input`
/// slice (the runner hands reducers `t - 1`, skipping the synthetic `StartOfInput`), and
/// `slice` looks like a `(start, end)` range within that segment — confirm both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Token {
    Path {
        segment_index: usize,
    },

    Positional {
        segment_index: usize,
    },

    Option {
        segment_index: usize,
        slice: Option<(usize, usize)>,
        option: String,
    },

    Assign {
        segment_index: usize,
        slice: (usize, usize),
    },

    Value {
        segment_index: usize,
        slice: Option<(usize, usize)>,
    },
}
31
/// Value recorded for an option in `RunState::options`.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare first by variant in
/// declaration order (`None < Array < Bool < String`), then by payload.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OptionValue {
    None,
    Array(Vec<String>),
    Bool(bool),
    String(String),
}
39
/// A positional argument recorded in `RunState::positionals`.
///
/// NOTE(review): the variant presumably records which kind of positional slot the value
/// filled — confirm against the reducers. Only `Required` entries count toward the fill
/// score that `select_best_state` uses to break ties between candidate states.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare variants in
/// declaration order before comparing payloads.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Positional {
    Required(String),
    Optional(String),
    Rest(String),
}
46
/// Accumulated parse state carried by each branch while the machine runs.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RunState {
    /// Index of the candidate command this branch belongs to; these indices are collected
    /// into `Error::NotFound` when no branch can make progress.
    pub candidate_index: usize,
    /// Option names that must all appear in `options` for this state to be selected
    /// (the help command is exempt from this check).
    pub required_options: Vec<String>,
    /// Error recorded on the branch. When every errored branch carries the same
    /// `Error::CommandError` payload, that error is surfaced to the caller.
    pub error_message: Option<Error>,
    pub ignore_options: bool,
    pub is_help: bool,
    /// `(name, value)` pairs for the options recorded so far.
    pub options: Vec<(String, OptionValue)>,
    /// Path segments recorded so far; the runner keeps only the branches and states with
    /// the longest path.
    pub path: Vec<String>,
    /// Positional arguments recorded so far; `Required` entries count toward the fill score.
    pub positionals: Vec<Positional>,
    pub remainder: Option<String>,
    /// Index of the selected command once the branch is terminal (`HELP_COMMAND_INDEX` for
    /// the help command); `None` for non-terminal states.
    pub selected_index: Option<usize>,
    pub tokens: Vec<Token>,
}
61
/// One live path through the state machine: the node it currently occupies and the state
/// accumulated on the way there.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RunBranch {
    /// Index into `Machine::nodes`, or `ERROR_NODE_ID` once the branch has errored.
    node_id: usize,
    state: RunState,
}
67
68impl RunBranch {
69    pub fn apply_transition(&self, transition: &crate::transition::Transition, context: &MachineContext, segment: &Arg, segment_index: usize) -> RunBranch {
70        RunBranch {
71            node_id: transition.to,
72            state: apply_reducer(&transition.reducer, context, &self.state, segment, segment_index),
73        }
74    }
75}
76
77fn trim_smaller_branches(branches: &mut Vec<RunBranch>) {
78    let max_path_size = branches.iter()
79        .map(|b| b.state.path.len())
80        .max()
81        .unwrap();
82
83    branches.retain(|b| b.state.path.len() == max_path_size);
84}
85
86fn select_best_state(_input: &[String], mut states: Vec<RunState>) -> Result<RunState, Error> {
87    states.retain(|s| {
88        s.selected_index.is_some()
89    });
90
91    if states.is_empty() {
92        panic!("No terminal states found");
93    }
94
95    states.retain(|s| {
96        s.selected_index == Some(HELP_COMMAND_INDEX) || s.required_options.iter().all(|o| s.options.iter().any(|(name, _)| name == o))
97    });
98
99    if states.is_empty() {
100        return Err(Error::InternalError);
101    }
102
103    let max_path_size = states.iter()
104        .map(|s| s.path.len())
105        .max()
106        .unwrap();
107
108    states.retain(|s| {
109        s.path.len() == max_path_size
110    });
111
112    fn get_fill_score(state: &RunState) -> usize {
113        let option_scope = state.options.len();
114
115        let positional_score = state.positionals.iter()
116            .filter(|mode| matches!(mode, Positional::Required(_)))
117            .count();
118
119        option_scope + positional_score
120    }
121
122    let best_fill_score = states.iter()
123        .map(get_fill_score)
124        .max()
125        .unwrap();
126
127    states.retain(|s| {
128        get_fill_score(s) == best_fill_score
129    });
130
131    let mut aggregated_states
132        = aggregate_help_states(states.into_iter());
133
134    if aggregated_states.len() > 1 {
135        let candidate_commands = aggregated_states.iter()
136            .map(|s| s.selected_index.unwrap())
137            .collect::<Vec<_>>();
138
139        return Err(Error::AmbiguousSyntax(candidate_commands));
140    }
141
142    Ok(std::mem::take(aggregated_states.first_mut().unwrap()))
143}
144
/// Returns the longest run of leading segments shared by every path yielded by `it`.
///
/// A single path is returned unchanged. An empty iterator yields an empty prefix; the
/// previous version panicked in that case.
fn find_common_prefix<'t, I>(mut it: I) -> Vec<String> where I: Iterator<Item = &'t Vec<String>> {
    let mut common_prefix = match it.next() {
        Some(first) => first.clone(),
        None => return Vec::new(),
    };

    for path in it {
        // Number of leading segments on which both agree; this also stops where the shorter ends.
        let shared = common_prefix.iter()
            .zip(path.iter())
            .take_while(|(a, b)| a == b)
            .count();

        common_prefix.truncate(shared);

        // Nothing left to share; later paths cannot change the result.
        if common_prefix.is_empty() {
            break;
        }
    }

    common_prefix
}
165
166fn aggregate_help_states<I>(it: I) -> Vec<RunState> where I: Iterator<Item = RunState> {
167    let (helps, mut not_helps)
168        = it.partition::<Vec<_>, _>(|s| s.selected_index == Some(HELP_COMMAND_INDEX));
169
170    if !helps.is_empty() {
171        let options = helps.iter()
172            .flat_map(|s| s.options.iter())
173            .cloned()
174            .collect();
175
176        not_helps.push(RunState {
177            selected_index: Some(HELP_COMMAND_INDEX),
178            path: find_common_prefix(helps.iter().map(|s| &s.path)),
179            options,
180            ..Default::default()
181        });
182    }
183
184    not_helps
185}
186
187fn extract_error_from_branches(_input: &[String], mut branches: Vec<RunBranch>, is_next: bool) -> Error {
188    if is_next {
189        if let Some(lead) = branches.pop() {
190            if let Some(Error::CommandError(usize, command_error)) = lead.state.error_message {
191                if branches.iter().all(|b| match &b.state.error_message {
192                    Some(Error::CommandError(_, other_error)) => other_error == &command_error,
193                    _ => false,
194                }) {
195                    return Error::CommandError(usize, command_error);
196                }
197            }
198        }
199    }
200
201    let candidate_indices = branches.iter()
202        .filter(|b| b.node_id != ERROR_NODE_ID)
203        .map(|b| b.state.candidate_index)
204        .collect::<HashSet<_>>()
205        .into_iter()
206        .collect::<Vec<_>>();
207
208    Error::NotFound(candidate_indices)
209}
210
211fn run_machine_internal(machine: &Machine, input: &[String], partial: bool) -> Result<Vec<RunBranch>, Error> {
212    let mut args = vec![Arg::StartOfInput];
213
214    args.extend(input.iter().map(|s| {
215        Arg::User(s.to_string())
216    }));
217
218    args.push(match partial {
219        true => Arg::EndOfPartialInput,
220        false => Arg::EndOfInput,
221    });
222
223    let mut branches = vec![RunBranch {
224        node_id: INITIAL_NODE_ID,
225        state: RunState::default(),
226    }];
227
228    for (t, arg) in args.iter().enumerate() {
229        let is_eoi = arg == &Arg::EndOfInput || arg == &Arg::EndOfPartialInput;
230        let mut next_branches = vec![];
231
232        for branch in &branches {
233            if branch.node_id == ERROR_NODE_ID {
234                next_branches.push(branch.clone());
235                continue;
236            }
237
238            let node = &machine.nodes[branch.node_id];
239            let context = &machine.contexts[node.context];
240
241            let has_exact_match = node.statics.contains_key(arg);
242            if !partial || t < args.len() - 1 || has_exact_match {
243                if has_exact_match {
244                    for transition in &node.statics[arg] {
245                        next_branches.push(branch.apply_transition(transition, context, arg, t.wrapping_sub(1)));
246                    }
247                }
248            } else {
249                for candidate in machine.nodes[branch.node_id].statics.keys() {
250                    if !candidate.starts_with(arg) {
251                        continue;
252                    }
253
254                    for transition in &node.statics[candidate] {
255                        next_branches.push(branch.apply_transition(transition, context, arg, t - 1));
256                    }
257                }
258            }
259
260            if !is_eoi {
261                for (check, transition) in &node.dynamics {
262                    if apply_check(check, context, &branch.state, arg, t - 1) {
263                        next_branches.push(branch.apply_transition(transition, context, arg, t - 1));
264                    }
265                }
266            }
267        }
268
269        if next_branches.is_empty() && is_eoi && input.len() == 1 {
270            return Ok(vec![RunBranch {
271                node_id: INITIAL_NODE_ID,
272                state: RunState {
273                    selected_index: Some(HELP_COMMAND_INDEX),
274                    ..RunState::default()
275                },
276            }]);
277        }
278
279        if next_branches.is_empty() {
280            return Err(extract_error_from_branches(input, branches, false));
281        }
282
283        if next_branches.iter().all(|b| b.node_id == ERROR_NODE_ID) {
284            return Err(extract_error_from_branches(input, next_branches, true));
285        }
286
287        branches = next_branches;
288        trim_smaller_branches(&mut branches);
289    }
290
291    Ok(branches)
292}
293
294pub fn run_machine(machine: &Machine, input: &[String]) -> Result<RunState, Error> {
295    let branches = run_machine_internal(machine, input, false)?;
296
297    let states = branches.into_iter()
298        .map(|b| b.state)
299        .collect();
300
301    select_best_state(input, states)
302}
303
304pub fn run_partial_machine(machine: &Machine, input: &[String]) -> Result<RunState, Error> {
305    let branches = run_machine_internal(machine, input, true)?;
306
307    let states = branches.into_iter()
308        .map(|b| b.state)
309        .collect();
310
311    select_best_state(input, states)
312}