1use std::collections::HashMap;
2
3use agent_shell_parser::parse::types::Word;
4use serde::de::Deserializer;
5use serde::{Deserialize, Serialize};
6
/// Maximum number of whitespace-separated words a subcommand pattern may contain.
///
/// Deserializing a [`SubcommandMap`] returns an error for deeper patterns,
/// [`SubcommandMap::insert`] checks it with `debug_assert!`, and
/// [`SubcommandMap::longest_match`] never tries more than this many words.
pub const MAX_SUBCOMMAND_DEPTH: usize = 4;
12
/// Side-effect classification of a command or subcommand.
///
/// Serialized in kebab-case: `"read-only"`, `"mutating"`, `"unknown"`.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `ReadOnly < Mutating < Unknown`. NOTE(review): callers presumably combine effects by
/// taking the maximum (treating `Unknown` as most severe) — confirm before reordering
/// or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Effect {
    /// Classified as read-only.
    ReadOnly,
    /// Classified as mutating.
    Mutating,
    /// No classification could be made (see [`CommandInfo::unknown`]).
    Unknown,
}
27
/// Knowledge about one command: its base effect plus optional subcommands, flags,
/// environment gates, path rules, and miscellaneous properties.
///
/// Only `name` and `effect` are required when deserializing; every other field is
/// `#[serde(default)]` and may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandKnowledge {
    /// Command name. NOTE(review): presumably equal to its key in
    /// [`KnowledgeBase::commands`] — confirm.
    pub name: String,
    /// Base effect of the command.
    pub effect: Effect,
    /// Subcommand-specific knowledge keyed by pattern (e.g. `"remote add"`).
    #[serde(default)]
    pub subcommands: SubcommandMap,
    /// Flag lists that apply to the command.
    #[serde(default)]
    pub flags: FlagSchema,
    /// Environment-variable gates for the command.
    #[serde(default)]
    pub env_gates: Vec<EnvGate>,
    /// Which arguments count as paths.
    #[serde(default)]
    pub paths: PathSpec,
    /// Miscellaneous properties.
    #[serde(default)]
    pub properties: CommandProperties,
}
48
/// Subcommand entries keyed by pattern: one or more subcommand words joined by single
/// spaces (for example `"remote add"`).
///
/// A pattern may contain at most [`MAX_SUBCOMMAND_DEPTH`] whitespace-separated words.
/// This is enforced with an error when deserializing and only with `debug_assert!` in
/// [`SubcommandMap::insert`]. The derived `Serialize` writes the map as
/// `{ "entries": { ... } }`, which is the shape the hand-written `Deserialize` impl reads.
#[derive(Debug, Clone, Default, Serialize)]
pub struct SubcommandMap {
    entries: HashMap<String, SubcommandEntry>,
}
61
62impl<'de> Deserialize<'de> for SubcommandMap {
63 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64 where
65 D: Deserializer<'de>,
66 {
67 #[derive(Deserialize)]
71 struct SubcommandMapRepr {
72 #[serde(default)]
73 entries: HashMap<String, SubcommandEntry>,
74 }
75
76 let repr = SubcommandMapRepr::deserialize(deserializer)?;
77 for key in repr.entries.keys() {
78 if key.split_whitespace().count() > MAX_SUBCOMMAND_DEPTH {
79 return Err(serde::de::Error::custom(format!(
80 "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
81 key, MAX_SUBCOMMAND_DEPTH
82 )));
83 }
84 }
85 Ok(SubcommandMap {
86 entries: repr.entries,
87 })
88 }
89}
90
/// Knowledge for one subcommand pattern inside a [`SubcommandMap`].
///
/// Only `effect` is required when deserializing; the remaining fields default to empty.
/// An entry may carry its own nested `subcommands`, which are validated by the same
/// depth check as the outer map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubcommandEntry {
    /// Effect of this subcommand.
    pub effect: Effect,
    /// Flag lists that apply to this subcommand.
    #[serde(default)]
    pub flags: FlagSchema,
    /// Environment-variable gates for this subcommand.
    #[serde(default)]
    pub env_gates: Vec<EnvGate>,
    /// Which arguments of this subcommand count as paths.
    #[serde(default)]
    pub paths: PathSpec,
    /// Nested subcommands below this pattern.
    #[serde(default)]
    pub subcommands: SubcommandMap,
}
105
/// Flag lists attached to a command or subcommand; every list defaults to empty.
///
/// NOTE(review): the per-field meanings below are inferred from the field names and
/// from [`CommandInfo`] — confirm against the classifier that consumes them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FlagSchema {
    /// Presumably flags that consume the following argument, so it is not mistaken for a
    /// positional — confirm.
    #[serde(default)]
    pub skip_arg: Vec<String>,
    /// Presumably flags that take no argument and are skipped on their own — confirm.
    #[serde(default)]
    pub skip_solo: Vec<String>,
    /// Flags that presumably set [`CommandInfo::has_escalation_flags`] — confirm.
    #[serde(default)]
    pub escalation: Vec<String>,
    /// Flags whose argument is presumably a filesystem path — confirm.
    #[serde(default)]
    pub path: Vec<String>,
}
128
129impl FlagSchema {
130 pub fn extend(&mut self, other: FlagSchema) {
135 self.skip_arg.extend(other.skip_arg);
136 self.skip_solo.extend(other.skip_solo);
137 self.escalation.extend(other.escalation);
138 self.path.extend(other.path);
139 }
140}
141
/// An environment-variable condition attached to a command or subcommand.
///
/// Serialized as an internally tagged object whose `"type"` field is `"grant"` or
/// `"require"`, e.g. `{ "type": "require", "var": "...", "value": "..." }`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum EnvGate {
    /// NOTE(review): presumably, when `var` equals `value`, the command may be classified
    /// with the `unlocks` effect — confirm against the classifier.
    Grant {
        var: String,
        value: String,
        unlocks: Effect,
    },
    /// NOTE(review): presumably `var` must equal `value` for the command's classification
    /// to hold — confirm against the classifier.
    Require { var: String, value: String },
}
159
/// Which arguments of a command are treated as paths: positional arguments selected by
/// [`PathPositionals`] and/or the values of the listed flags. Both parts default to empty.
///
/// NOTE(review): presumably the source of [`CommandInfo::affected_paths`] — confirm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PathSpec {
    /// Which positional arguments count as paths (defaults to `PathPositionals::None`).
    #[serde(default)]
    pub positionals: PathPositionals,
    /// Flags whose values count as paths.
    #[serde(default)]
    pub flags: Vec<String>,
}
169
/// Which positional arguments count as paths; defaults to `None`.
///
/// Serialized in kebab-case: `"none"`, `"all"`, `"last"`, and `{ "tail": n }` for `Tail`.
/// NOTE(review): the variant descriptions below follow the variant names — confirm against
/// the code that interprets them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum PathPositionals {
    /// No positional argument counts as a path.
    #[default]
    None,
    /// Every positional argument counts as a path.
    All,
    /// Presumably the positionals from the given index onward (or the last `n`) — confirm.
    Tail(usize),
    /// Only the final positional argument counts as a path.
    Last,
}
180
/// Miscellaneous per-command properties; every field is optional.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CommandProperties {
    /// The command's version flag (e.g. `--version`), if any. Absent means `None`.
    #[serde(default)]
    pub version_flag: Option<String>,
}
186
/// Knowledge about a wrapper command that runs another command.
///
/// NOTE(review): "wrapper" semantics (e.g. `sudo`/`env`-like tools) are inferred from the
/// field names — confirm against the classifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WrapperKnowledge {
    /// Wrapper name.
    pub name: String,
    /// Presumably the minimum effect assigned to anything run through this wrapper — confirm.
    pub floor_effect: Effect,
    /// Whether the wrapper clears the environment; defaults to `false`.
    #[serde(default)]
    pub clears_env: bool,
    /// Whether the wrapper escalates privileges; defaults to `false`.
    #[serde(default)]
    pub escalates_privilege: bool,
}
199
/// The complete set of command and wrapper knowledge; both maps default to empty.
///
/// NOTE(review): keys are presumably the command / wrapper names matching each value's
/// `name` field — confirm with the loader.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeBase {
    /// Known commands.
    #[serde(default)]
    pub commands: HashMap<String, CommandKnowledge>,
    /// Known wrapper commands.
    #[serde(default)]
    pub wrappers: HashMap<String, WrapperKnowledge>,
}
209
/// Classification result for a single command invocation.
///
/// [`CommandInfo::unknown`] provides the fallback value with `Effect::Unknown`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandInfo {
    /// Classified effect.
    pub effect: Effect,
    /// Matched subcommand, if any (presumably a [`SubcommandMap`] pattern — confirm).
    pub subcommand: Option<String>,
    /// Whether escalation flags were seen (presumably those in
    /// [`FlagSchema::escalation`] — confirm).
    pub has_escalation_flags: bool,
    /// Words identified as paths the command affects.
    pub affected_paths: Vec<Word>,
    /// Environment gates relevant to this invocation.
    pub env_gates: Vec<EnvGate>,
    /// Wrapper details when the command was run through a wrapper.
    pub wrapper: Option<WrapperInfo>,
}
222
/// Wrapper details carried in a [`CommandInfo`]; mirrors the fields of
/// [`WrapperKnowledge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrapperInfo {
    /// Wrapper name.
    pub name: String,
    /// The wrapper's floor effect.
    pub floor_effect: Effect,
    /// Whether the wrapper clears the environment.
    pub clears_env: bool,
    /// Whether the wrapper escalates privileges.
    pub escalates_privilege: bool,
}
231
232impl SubcommandEntry {
233 #[cfg(test)]
235 pub fn with_effect(effect: Effect) -> Self {
236 Self {
237 effect,
238 flags: FlagSchema::default(),
239 env_gates: vec![],
240 paths: PathSpec::default(),
241 subcommands: SubcommandMap::new(),
242 }
243 }
244}
245
246impl CommandKnowledge {
247 #[cfg(test)]
249 pub fn simple(name: impl Into<String>, effect: Effect) -> Self {
250 let name = name.into();
251 Self {
252 name,
253 effect,
254 subcommands: SubcommandMap::new(),
255 flags: FlagSchema::default(),
256 env_gates: vec![],
257 paths: PathSpec::default(),
258 properties: CommandProperties::default(),
259 }
260 }
261}
262
263impl SubcommandMap {
264 #[must_use = "returns an empty SubcommandMap"]
265 pub fn new() -> Self {
266 Self {
267 entries: HashMap::new(),
268 }
269 }
270
271 pub fn insert(&mut self, pattern: impl Into<String>, entry: SubcommandEntry) {
280 let pattern = pattern.into();
281 debug_assert!(
282 pattern.split_whitespace().count() <= MAX_SUBCOMMAND_DEPTH,
283 "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
284 pattern,
285 MAX_SUBCOMMAND_DEPTH,
286 );
287 self.entries.insert(pattern, entry);
288 }
289
290 #[must_use = "returns the entry if found"]
291 pub fn get(&self, pattern: &str) -> Option<&SubcommandEntry> {
292 self.entries.get(pattern)
293 }
294
295 #[must_use = "returns whether the map has entries"]
296 pub fn is_empty(&self) -> bool {
297 self.entries.is_empty()
298 }
299
300 pub fn iter(&self) -> impl Iterator<Item = (&str, &SubcommandEntry)> {
301 self.entries.iter().map(|(k, v)| (k.as_str(), v))
302 }
303
304 pub fn extend(&mut self, other: SubcommandMap) {
305 for (pattern, entry) in other.entries {
306 self.insert(pattern, entry);
307 }
308 }
309
310 pub fn remove(&mut self, pattern: &str) {
311 self.entries.remove(pattern);
312 }
313
314 #[must_use = "returns the number of entries in the map"]
315 pub fn len(&self) -> usize {
316 self.entries.len()
317 }
318
319 #[must_use = "returns the best-matching entry and how many words it consumed"]
320 pub fn longest_match(&self, words: &[&Word]) -> Option<(&SubcommandEntry, usize)> {
321 let max_depth = words.len().min(MAX_SUBCOMMAND_DEPTH);
322 for depth in (1..=max_depth).rev() {
323 let pattern: String = words[..depth]
324 .iter()
325 .map(|w| w.as_str())
326 .collect::<Vec<_>>()
327 .join(" ");
328 if let Some(entry) = self.entries.get(&pattern) {
329 return Some((entry, depth));
330 }
331 }
332 None
333 }
334}
335
336impl<'a> IntoIterator for &'a SubcommandMap {
337 type Item = (&'a str, &'a SubcommandEntry);
338 type IntoIter = std::iter::Map<
339 std::collections::hash_map::Iter<'a, String, SubcommandEntry>,
340 fn((&'a String, &'a SubcommandEntry)) -> (&'a str, &'a SubcommandEntry),
341 >;
342
343 fn into_iter(self) -> Self::IntoIter {
344 self.entries.iter().map(|(k, v)| (k.as_str(), v))
345 }
346}
347
348impl CommandInfo {
349 #[must_use = "returns a default Unknown classification"]
350 pub fn unknown() -> Self {
351 Self {
352 effect: Effect::Unknown,
353 subcommand: None,
354 has_escalation_flags: false,
355 affected_paths: vec![],
356 env_gates: vec![],
357 wrapper: None,
358 }
359 }
360}
361
// Unit tests for the types in this module, kept in a sibling file.
#[cfg(test)]
#[path = "types_tests.rs"]
mod types_tests;

// Additional tests in a sibling file (presumably property-based, given the name).
#[cfg(test)]
#[path = "types_proptest.rs"]
mod types_proptest;