Skip to main content

agent_command_knowledge/
lookup.rs

1use agent_shell_parser::parse::types::Word;
2
3use crate::types::{
4    CommandInfo, Effect, FlagSchema, KnowledgeBase, PathPositionals, PathSpec, SubcommandEntry,
5    WrapperInfo,
6};
7
8/// Classify a command using the knowledge base.
9///
10/// Takes the base command name and the full word list (including the command
11/// itself). Returns a `CommandInfo` describing the command's effect,
12/// subcommand, escalation flags, affected paths, and env gates.
13#[must_use = "classification result contains the effect, subcommand, and paths — ignoring it skips policy evaluation"]
14pub fn classify(base_command: &Word, words: &[Word], kb: &KnowledgeBase) -> CommandInfo {
15    let Some(knowledge) = kb.commands.get(base_command.as_str()) else {
16        return check_wrapper(base_command, kb);
17    };
18
19    let suffix = format!("/{base_command}");
20    let cmd_idx = words
21        .iter()
22        .position(|w| w == base_command || w.ends_with(suffix.as_str()))
23        .map(|i| i + 1)
24        .unwrap_or(0);
25
26    let remaining: Vec<&Word> = skip_flags(&words[cmd_idx..], &knowledge.flags);
27
28    let (effect, subcommand, sub_depth, sub_flags, sub_env_gates, sub_paths) =
29        match knowledge.subcommands.longest_match(&remaining) {
30            Some((entry, depth)) => {
31                let sub_pattern: String = remaining[..depth]
32                    .iter()
33                    .map(|w| w.as_str())
34                    .collect::<Vec<_>>()
35                    .join(" ");
36                let (e, s, d, f, g, p) =
37                    resolve_subcommand(entry, &remaining[depth..], &sub_pattern, depth, &[]);
38                (e, s, d, f, g, p)
39            }
40            None => (
41                knowledge.effect,
42                None,
43                0,
44                &knowledge.flags,
45                vec![],
46                &knowledge.paths,
47            ),
48        };
49
50    let merged_flags = merge_flag_schemas(&knowledge.flags, sub_flags);
51    let has_escalation_flags = check_escalation_flags(words, &merged_flags);
52    let positionals_after_sub: Vec<&Word> = remaining.iter().skip(sub_depth).copied().collect();
53    let affected_paths = extract_paths(words, &positionals_after_sub, sub_paths, &merged_flags);
54
55    let env_gates = {
56        let mut gates = knowledge.env_gates.clone();
57        if subcommand.is_some() {
58            gates.extend(sub_env_gates);
59        }
60        gates
61    };
62
63    CommandInfo {
64        effect,
65        subcommand,
66        has_escalation_flags,
67        affected_paths,
68        env_gates,
69        wrapper: None,
70    }
71}
72
73fn resolve_subcommand<'a>(
74    entry: &'a SubcommandEntry,
75    remaining: &[&Word],
76    pattern: &str,
77    accumulated_depth: usize,
78    accumulated_gates: &[crate::types::EnvGate],
79) -> (
80    Effect,
81    Option<String>,
82    usize,
83    &'a FlagSchema,
84    Vec<crate::types::EnvGate>,
85    &'a PathSpec,
86) {
87    // Merge gates from this level into the accumulator.
88    let mut gates_so_far: Vec<crate::types::EnvGate> = accumulated_gates.to_vec();
89    gates_so_far.extend(entry.env_gates.iter().cloned());
90
91    if !entry.subcommands.is_empty() {
92        let remaining_owned: Vec<Word> = remaining.iter().map(|w| Word::from(w.as_str())).collect();
93        let inner_remaining: Vec<&Word> = skip_flags(&remaining_owned, &entry.flags);
94        if let Some((inner_entry, inner_depth)) = entry.subcommands.longest_match(&inner_remaining)
95        {
96            let inner_pattern: String = inner_remaining[..inner_depth]
97                .iter()
98                .map(|w| w.as_str())
99                .collect::<Vec<_>>()
100                .join(" ");
101            let full_pattern = format!("{} {}", pattern, inner_pattern);
102            let total_depth = accumulated_depth + inner_depth;
103            return resolve_subcommand(
104                inner_entry,
105                &inner_remaining[inner_depth..],
106                &full_pattern,
107                total_depth,
108                &gates_so_far,
109            );
110        }
111    }
112
113    (
114        entry.effect,
115        Some(pattern.to_string()),
116        accumulated_depth,
117        &entry.flags,
118        gates_so_far,
119        &entry.paths,
120    )
121}
122
123fn check_wrapper(base_command: &Word, kb: &KnowledgeBase) -> CommandInfo {
124    if let Some(wrapper) = kb.wrappers.get(base_command.as_str()) {
125        CommandInfo {
126            effect: Effect::Unknown,
127            subcommand: None,
128            has_escalation_flags: false,
129            affected_paths: vec![],
130            env_gates: vec![],
131            wrapper: Some(WrapperInfo {
132                name: wrapper.name.clone(),
133                floor_effect: wrapper.floor_effect,
134                clears_env: wrapper.clears_env,
135                escalates_privilege: wrapper.escalates_privilege,
136            }),
137        }
138    } else {
139        CommandInfo::unknown()
140    }
141}
142
143/// Skip past flags in a word list to find subcommand words.
144fn skip_flags<'a>(words: &'a [Word], schema: &FlagSchema) -> Vec<&'a Word> {
145    let mut result = Vec::new();
146    let mut i = 0;
147    while i < words.len() {
148        let w = &words[i];
149        if schema.skip_arg.iter().any(|f| w == f) {
150            i += 2;
151            continue;
152        }
153        if w.starts_with('-') && w.contains('=') {
154            if let Some((flag_part, _)) = w.split_once('=') {
155                if schema.skip_arg.iter().any(|f| f == flag_part) {
156                    i += 1;
157                    continue;
158                }
159            }
160        }
161        if w.as_str() == "--" {
162            result.extend(words[i + 1..].iter());
163            break;
164        }
165        if schema.skip_solo.iter().any(|f| w == f) {
166            i += 1;
167            continue;
168        }
169        if w.starts_with('-') {
170            i += 1;
171            continue;
172        }
173        result.push(w);
174        i += 1;
175    }
176    result
177}
178
179fn check_escalation_flags(words: &[Word], schema: &FlagSchema) -> bool {
180    words.iter().any(|w| {
181        schema.escalation.iter().any(|f| {
182            w == f
183                || w.split_once('=')
184                    .is_some_and(|(prefix, _)| prefix == f.as_str())
185        })
186    })
187}
188
189fn merge_flag_schemas(parent: &FlagSchema, child: &FlagSchema) -> FlagSchema {
190    FlagSchema {
191        skip_arg: parent
192            .skip_arg
193            .iter()
194            .chain(child.skip_arg.iter())
195            .cloned()
196            .collect(),
197        skip_solo: parent
198            .skip_solo
199            .iter()
200            .chain(child.skip_solo.iter())
201            .cloned()
202            .collect(),
203        escalation: parent
204            .escalation
205            .iter()
206            .chain(child.escalation.iter())
207            .cloned()
208            .collect(),
209        path: parent
210            .path
211            .iter()
212            .chain(child.path.iter())
213            .cloned()
214            .collect(),
215    }
216}
217
218/// Extract affected paths from a command invocation.
219///
220/// `positionals` is the flag-skipped word list after subcommand consumption —
221/// the same view that `longest_match` operated on, minus the subcommand words.
222/// This ensures positional path extraction agrees with subcommand resolution
223/// about which words are arguments vs flag values.
224fn extract_paths(
225    words: &[Word],
226    positionals: &[&Word],
227    path_spec: &PathSpec,
228    flag_schema: &FlagSchema,
229) -> Vec<Word> {
230    let mut paths = Vec::new();
231
232    for flag in &flag_schema.path {
233        let prefix = format!("{flag}=");
234        let mut i = 0;
235        while i < words.len() {
236            if words[i] == flag.as_str() && i + 1 < words.len() {
237                paths.push(words[i + 1].clone());
238                i += 2;
239                continue;
240            }
241            if let Some(value) = words[i].strip_prefix(prefix.as_str()) {
242                paths.push(Word::from(value));
243            }
244            i += 1;
245        }
246    }
247
248    match &path_spec.positionals {
249        PathPositionals::None => {}
250        PathPositionals::All => {
251            paths.extend(positionals.iter().map(|p| (*p).clone()));
252        }
253        PathPositionals::Tail(skip) => {
254            paths.extend(positionals.iter().skip(*skip).map(|p| (*p).clone()));
255        }
256        PathPositionals::Last => {
257            if let Some(last) = positionals.last() {
258                paths.push((*last).clone());
259            }
260        }
261    }
262
263    paths
264}
265
266#[cfg(test)]
267#[path = "lookup_tests.rs"]
268mod lookup_tests;
269
270#[cfg(test)]
271#[path = "classify_proptest.rs"]
272mod classify_proptest;