1use agent_shell_parser::parse::types::Word;
2
3use crate::types::{
4 CommandInfo, Effect, FlagSchema, KnowledgeBase, PathPositionals, PathSpec, SubcommandEntry,
5 WrapperInfo,
6};
7
8#[must_use = "classification result contains the effect, subcommand, and paths — ignoring it skips policy evaluation"]
14pub fn classify(base_command: &Word, words: &[Word], kb: &KnowledgeBase) -> CommandInfo {
15 let Some(knowledge) = kb.commands.get(base_command.as_str()) else {
16 return check_wrapper(base_command, kb);
17 };
18
19 let suffix = format!("/{base_command}");
20 let cmd_idx = words
21 .iter()
22 .position(|w| w == base_command || w.ends_with(suffix.as_str()))
23 .map(|i| i + 1)
24 .unwrap_or(0);
25
26 let remaining: Vec<&Word> = skip_flags(&words[cmd_idx..], &knowledge.flags);
27
28 let (effect, subcommand, sub_depth, sub_flags, sub_env_gates, sub_paths) =
29 match knowledge.subcommands.longest_match(&remaining) {
30 Some((entry, depth)) => {
31 let sub_pattern: String = remaining[..depth]
32 .iter()
33 .map(|w| w.as_str())
34 .collect::<Vec<_>>()
35 .join(" ");
36 let (e, s, d, f, g, p) =
37 resolve_subcommand(entry, &remaining[depth..], &sub_pattern, depth, &[]);
38 (e, s, d, f, g, p)
39 }
40 None => (
41 knowledge.effect,
42 None,
43 0,
44 &knowledge.flags,
45 vec![],
46 &knowledge.paths,
47 ),
48 };
49
50 let merged_flags = merge_flag_schemas(&knowledge.flags, sub_flags);
51 let has_escalation_flags = check_escalation_flags(words, &merged_flags);
52 let positionals_after_sub: Vec<&Word> = remaining.iter().skip(sub_depth).copied().collect();
53 let affected_paths = extract_paths(words, &positionals_after_sub, sub_paths, &merged_flags);
54
55 let env_gates = {
56 let mut gates = knowledge.env_gates.clone();
57 if subcommand.is_some() {
58 gates.extend(sub_env_gates);
59 }
60 gates
61 };
62
63 CommandInfo {
64 effect,
65 subcommand,
66 has_escalation_flags,
67 affected_paths,
68 env_gates,
69 wrapper: None,
70 }
71}
72
73fn resolve_subcommand<'a>(
74 entry: &'a SubcommandEntry,
75 remaining: &[&Word],
76 pattern: &str,
77 accumulated_depth: usize,
78 accumulated_gates: &[crate::types::EnvGate],
79) -> (
80 Effect,
81 Option<String>,
82 usize,
83 &'a FlagSchema,
84 Vec<crate::types::EnvGate>,
85 &'a PathSpec,
86) {
87 let mut gates_so_far: Vec<crate::types::EnvGate> = accumulated_gates.to_vec();
89 gates_so_far.extend(entry.env_gates.iter().cloned());
90
91 if !entry.subcommands.is_empty() {
92 let remaining_owned: Vec<Word> = remaining.iter().map(|w| Word::from(w.as_str())).collect();
93 let inner_remaining: Vec<&Word> = skip_flags(&remaining_owned, &entry.flags);
94 if let Some((inner_entry, inner_depth)) = entry.subcommands.longest_match(&inner_remaining)
95 {
96 let inner_pattern: String = inner_remaining[..inner_depth]
97 .iter()
98 .map(|w| w.as_str())
99 .collect::<Vec<_>>()
100 .join(" ");
101 let full_pattern = format!("{} {}", pattern, inner_pattern);
102 let total_depth = accumulated_depth + inner_depth;
103 return resolve_subcommand(
104 inner_entry,
105 &inner_remaining[inner_depth..],
106 &full_pattern,
107 total_depth,
108 &gates_so_far,
109 );
110 }
111 }
112
113 (
114 entry.effect,
115 Some(pattern.to_string()),
116 accumulated_depth,
117 &entry.flags,
118 gates_so_far,
119 &entry.paths,
120 )
121}
122
123fn check_wrapper(base_command: &Word, kb: &KnowledgeBase) -> CommandInfo {
124 if let Some(wrapper) = kb.wrappers.get(base_command.as_str()) {
125 CommandInfo {
126 effect: Effect::Unknown,
127 subcommand: None,
128 has_escalation_flags: false,
129 affected_paths: vec![],
130 env_gates: vec![],
131 wrapper: Some(WrapperInfo {
132 name: wrapper.name.clone(),
133 floor_effect: wrapper.floor_effect,
134 clears_env: wrapper.clears_env,
135 escalates_privilege: wrapper.escalates_privilege,
136 }),
137 }
138 } else {
139 CommandInfo::unknown()
140 }
141}
142
143fn skip_flags<'a>(words: &'a [Word], schema: &FlagSchema) -> Vec<&'a Word> {
145 let mut result = Vec::new();
146 let mut i = 0;
147 while i < words.len() {
148 let w = &words[i];
149 if schema.skip_arg.iter().any(|f| w == f) {
150 i += 2;
151 continue;
152 }
153 if w.starts_with('-') && w.contains('=') {
154 if let Some((flag_part, _)) = w.split_once('=') {
155 if schema.skip_arg.iter().any(|f| f == flag_part) {
156 i += 1;
157 continue;
158 }
159 }
160 }
161 if w.as_str() == "--" {
162 result.extend(words[i + 1..].iter());
163 break;
164 }
165 if schema.skip_solo.iter().any(|f| w == f) {
166 i += 1;
167 continue;
168 }
169 if w.starts_with('-') {
170 i += 1;
171 continue;
172 }
173 result.push(w);
174 i += 1;
175 }
176 result
177}
178
179fn check_escalation_flags(words: &[Word], schema: &FlagSchema) -> bool {
180 words.iter().any(|w| {
181 schema.escalation.iter().any(|f| {
182 w == f
183 || w.split_once('=')
184 .is_some_and(|(prefix, _)| prefix == f.as_str())
185 })
186 })
187}
188
189fn merge_flag_schemas(parent: &FlagSchema, child: &FlagSchema) -> FlagSchema {
190 FlagSchema {
191 skip_arg: parent
192 .skip_arg
193 .iter()
194 .chain(child.skip_arg.iter())
195 .cloned()
196 .collect(),
197 skip_solo: parent
198 .skip_solo
199 .iter()
200 .chain(child.skip_solo.iter())
201 .cloned()
202 .collect(),
203 escalation: parent
204 .escalation
205 .iter()
206 .chain(child.escalation.iter())
207 .cloned()
208 .collect(),
209 path: parent
210 .path
211 .iter()
212 .chain(child.path.iter())
213 .cloned()
214 .collect(),
215 }
216}
217
218fn extract_paths(
225 words: &[Word],
226 positionals: &[&Word],
227 path_spec: &PathSpec,
228 flag_schema: &FlagSchema,
229) -> Vec<Word> {
230 let mut paths = Vec::new();
231
232 for flag in &flag_schema.path {
233 let prefix = format!("{flag}=");
234 let mut i = 0;
235 while i < words.len() {
236 if words[i] == flag.as_str() && i + 1 < words.len() {
237 paths.push(words[i + 1].clone());
238 i += 2;
239 continue;
240 }
241 if let Some(value) = words[i].strip_prefix(prefix.as_str()) {
242 paths.push(Word::from(value));
243 }
244 i += 1;
245 }
246 }
247
248 match &path_spec.positionals {
249 PathPositionals::None => {}
250 PathPositionals::All => {
251 paths.extend(positionals.iter().map(|p| (*p).clone()));
252 }
253 PathPositionals::Tail(skip) => {
254 paths.extend(positionals.iter().skip(*skip).map(|p| (*p).clone()));
255 }
256 PathPositionals::Last => {
257 if let Some(last) = positionals.last() {
258 paths.push((*last).clone());
259 }
260 }
261 }
262
263 paths
264}
265
266#[cfg(test)]
267#[path = "lookup_tests.rs"]
268mod lookup_tests;
269
270#[cfg(test)]
271#[path = "classify_proptest.rs"]
272mod classify_proptest;