1use std::collections::HashSet;
2
3use crate::{actions::{apply_check, apply_reducer}, errors::Error, machine::{Machine, MachineContext}, shared::{Arg, ERROR_NODE_ID, HELP_COMMAND_INDEX, INITIAL_NODE_ID}};
4
/// A classified piece of the command line, stored in [`RunState::tokens`].
///
/// Every variant carries a `segment_index`; presumably this is the index of the
/// user argument the token was produced from (TODO confirm against the reducers).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Token {
    /// A segment consumed as part of the command path.
    Path { segment_index: usize },

    /// A segment consumed as a positional argument.
    Positional { segment_index: usize },

    /// An option name. `slice` presumably locates the name inside the segment
    /// when it does not span the whole segment (TODO confirm).
    Option {
        segment_index: usize,
        slice: Option<(usize, usize)>,
        option: String,
    },

    /// An assignment marker inside a segment, located by `slice` (assumed to be
    /// a `(start, end)` range — TODO confirm).
    Assign { segment_index: usize, slice: (usize, usize) },

    /// An option value, optionally restricted to a sub-range (`slice`) of its segment.
    Value { segment_index: usize, slice: Option<(usize, usize)> },
}
31
/// The value recorded for an option in [`RunState::options`].
///
/// NOTE: the derived `Ord` follows declaration order, so the variants must stay
/// in this sequence: `None < Array < Bool < String`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OptionValue {
    /// No value recorded.
    None,
    /// A list of string values.
    Array(Vec<String>),
    /// A boolean flag value.
    Bool(bool),
    /// A single string value.
    String(String),
}
39
/// A positional argument captured by a branch, tagged with the kind of slot it filled.
///
/// Only `Required` entries count toward the fill score used to rank states.
/// NOTE: the derived `Ord` follows declaration order (`Required < Optional < Rest`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Positional {
    /// Value for a mandatory positional slot.
    Required(String),
    /// Value for an optional positional slot.
    Optional(String),
    /// Value captured by a trailing "rest" slot.
    Rest(String),
}
46
47#[derive(Debug, Default, Clone, PartialEq, Eq)]
48pub struct RunState {
49 pub candidate_index: usize,
50 pub required_options: Vec<String>,
51 pub error_message: Option<Error>,
52 pub ignore_options: bool,
53 pub is_help: bool,
54 pub options: Vec<(String, OptionValue)>,
55 pub path: Vec<String>,
56 pub positionals: Vec<Positional>,
57 pub remainder: Option<String>,
58 pub selected_index: Option<usize>,
59 pub tokens: Vec<Token>,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq)]
63struct RunBranch {
64 node_id: usize,
65 state: RunState,
66}
67
68impl RunBranch {
69 pub fn apply_transition(&self, transition: &crate::transition::Transition, context: &MachineContext, segment: &Arg, segment_index: usize) -> RunBranch {
70 RunBranch {
71 node_id: transition.to,
72 state: apply_reducer(&transition.reducer, context, &self.state, segment, segment_index),
73 }
74 }
75}
76
77fn trim_smaller_branches(branches: &mut Vec<RunBranch>) {
78 let max_path_size = branches.iter()
79 .map(|b| b.state.path.len())
80 .max()
81 .unwrap();
82
83 branches.retain(|b| b.state.path.len() == max_path_size);
84}
85
86fn select_best_state(_input: &[String], mut states: Vec<RunState>) -> Result<RunState, Error> {
87 states.retain(|s| {
88 s.selected_index.is_some()
89 });
90
91 if states.is_empty() {
92 panic!("No terminal states found");
93 }
94
95 states.retain(|s| {
96 s.selected_index == Some(HELP_COMMAND_INDEX) || s.required_options.iter().all(|o| s.options.iter().any(|(name, _)| name == o))
97 });
98
99 if states.is_empty() {
100 return Err(Error::InternalError);
101 }
102
103 let max_path_size = states.iter()
104 .map(|s| s.path.len())
105 .max()
106 .unwrap();
107
108 states.retain(|s| {
109 s.path.len() == max_path_size
110 });
111
112 fn get_fill_score(state: &RunState) -> usize {
113 let option_scope = state.options.len();
114
115 let positional_score = state.positionals.iter()
116 .filter(|mode| matches!(mode, Positional::Required(_)))
117 .count();
118
119 option_scope + positional_score
120 }
121
122 let best_fill_score = states.iter()
123 .map(get_fill_score)
124 .max()
125 .unwrap();
126
127 states.retain(|s| {
128 get_fill_score(s) == best_fill_score
129 });
130
131 let mut aggregated_states
132 = aggregate_help_states(states.into_iter());
133
134 if aggregated_states.len() > 1 {
135 let candidate_commands = aggregated_states.iter()
136 .map(|s| s.selected_index.unwrap())
137 .collect::<Vec<_>>();
138
139 return Err(Error::AmbiguousSyntax(candidate_commands));
140 }
141
142 Ok(std::mem::take(aggregated_states.first_mut().unwrap()))
143}
144
/// Returns the longest run of leading segments shared by every path yielded by `it`.
///
/// An empty iterator yields an empty prefix instead of panicking.
fn find_common_prefix<'t, I>(mut it: I) -> Vec<String>
where
    I: Iterator<Item = &'t Vec<String>>,
{
    let mut common_prefix = match it.next() {
        Some(first) => first.clone(),
        None => return Vec::new(),
    };

    for path in it {
        // `zip` stops at the shorter of the two, so this also clamps the prefix
        // to `path.len()`.
        let shared = common_prefix
            .iter()
            .zip(path.iter())
            .take_while(|(a, b)| a == b)
            .count();

        common_prefix.truncate(shared);
    }

    common_prefix
}
165
166fn aggregate_help_states<I>(it: I) -> Vec<RunState> where I: Iterator<Item = RunState> {
167 let (helps, mut not_helps)
168 = it.partition::<Vec<_>, _>(|s| s.selected_index == Some(HELP_COMMAND_INDEX));
169
170 if !helps.is_empty() {
171 let options = helps.iter()
172 .flat_map(|s| s.options.iter())
173 .cloned()
174 .collect();
175
176 not_helps.push(RunState {
177 selected_index: Some(HELP_COMMAND_INDEX),
178 path: find_common_prefix(helps.iter().map(|s| &s.path)),
179 options,
180 ..Default::default()
181 });
182 }
183
184 not_helps
185}
186
187fn extract_error_from_branches(_input: &[String], mut branches: Vec<RunBranch>, is_next: bool) -> Error {
188 if is_next {
189 if let Some(lead) = branches.pop() {
190 if let Some(Error::CommandError(usize, command_error)) = lead.state.error_message {
191 if branches.iter().all(|b| match &b.state.error_message {
192 Some(Error::CommandError(_, other_error)) => other_error == &command_error,
193 _ => false,
194 }) {
195 return Error::CommandError(usize, command_error);
196 }
197 }
198 }
199 }
200
201 let candidate_indices = branches.iter()
202 .filter(|b| b.node_id != ERROR_NODE_ID)
203 .map(|b| b.state.candidate_index)
204 .collect::<HashSet<_>>()
205 .into_iter()
206 .collect::<Vec<_>>();
207
208 Error::NotFound(candidate_indices)
209}
210
211fn run_machine_internal(machine: &Machine, input: &[String], partial: bool) -> Result<Vec<RunBranch>, Error> {
212 let mut args = vec![Arg::StartOfInput];
213
214 args.extend(input.iter().map(|s| {
215 Arg::User(s.to_string())
216 }));
217
218 args.push(match partial {
219 true => Arg::EndOfPartialInput,
220 false => Arg::EndOfInput,
221 });
222
223 let mut branches = vec![RunBranch {
224 node_id: INITIAL_NODE_ID,
225 state: RunState::default(),
226 }];
227
228 for (t, arg) in args.iter().enumerate() {
229 let is_eoi = arg == &Arg::EndOfInput || arg == &Arg::EndOfPartialInput;
230 let mut next_branches = vec![];
231
232 for branch in &branches {
233 if branch.node_id == ERROR_NODE_ID {
234 next_branches.push(branch.clone());
235 continue;
236 }
237
238 let node = &machine.nodes[branch.node_id];
239 let context = &machine.contexts[node.context];
240
241 let has_exact_match = node.statics.contains_key(arg);
242 if !partial || t < args.len() - 1 || has_exact_match {
243 if has_exact_match {
244 for transition in &node.statics[arg] {
245 next_branches.push(branch.apply_transition(transition, context, arg, t.wrapping_sub(1)));
246 }
247 }
248 } else {
249 for candidate in machine.nodes[branch.node_id].statics.keys() {
250 if !candidate.starts_with(arg) {
251 continue;
252 }
253
254 for transition in &node.statics[candidate] {
255 next_branches.push(branch.apply_transition(transition, context, arg, t - 1));
256 }
257 }
258 }
259
260 if !is_eoi {
261 for (check, transition) in &node.dynamics {
262 if apply_check(check, context, &branch.state, arg, t - 1) {
263 next_branches.push(branch.apply_transition(transition, context, arg, t - 1));
264 }
265 }
266 }
267 }
268
269 if next_branches.is_empty() && is_eoi && input.len() == 1 {
270 return Ok(vec![RunBranch {
271 node_id: INITIAL_NODE_ID,
272 state: RunState {
273 selected_index: Some(HELP_COMMAND_INDEX),
274 ..RunState::default()
275 },
276 }]);
277 }
278
279 if next_branches.is_empty() {
280 return Err(extract_error_from_branches(input, branches, false));
281 }
282
283 if next_branches.iter().all(|b| b.node_id == ERROR_NODE_ID) {
284 return Err(extract_error_from_branches(input, next_branches, true));
285 }
286
287 branches = next_branches;
288 trim_smaller_branches(&mut branches);
289 }
290
291 Ok(branches)
292}
293
294pub fn run_machine(machine: &Machine, input: &[String]) -> Result<RunState, Error> {
295 let branches = run_machine_internal(machine, input, false)?;
296
297 let states = branches.into_iter()
298 .map(|b| b.state)
299 .collect();
300
301 select_best_state(input, states)
302}
303
304pub fn run_partial_machine(machine: &Machine, input: &[String]) -> Result<RunState, Error> {
305 let branches = run_machine_internal(machine, input, true)?;
306
307 let states = branches.into_iter()
308 .map(|b| b.state)
309 .collect();
310
311 select_best_state(input, states)
312}