1use std::collections::HashMap;
2
3use agent_shell_parser::parse::types::Word;
4use serde::de::Deserializer;
5use serde::{Deserialize, Serialize};
6
/// Maximum number of whitespace-separated words a subcommand pattern may have.
///
/// Enforced when a [`SubcommandMap`] is deserialized (and by a debug assertion on
/// insert), and used to cap how many leading words `longest_match` considers.
pub const MAX_SUBCOMMAND_DEPTH: usize = 4;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
21#[serde(rename_all = "kebab-case")]
22pub enum Effect {
23 ReadOnly,
24 Mutating,
25 Destructive,
26 Unknown,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct CommandKnowledge {
35 pub name: String,
36 pub effect: Effect,
38 #[serde(default)]
39 pub subcommands: SubcommandMap,
40 #[serde(default)]
41 pub flags: FlagSchema,
42 #[serde(default)]
43 pub env_gates: Vec<EnvGate>,
44 #[serde(default)]
45 pub paths: PathSpec,
46 #[serde(default)]
47 pub properties: CommandProperties,
48}
49
50#[derive(Debug, Clone, Default, Serialize)]
59pub struct SubcommandMap {
60 entries: HashMap<String, SubcommandEntry>,
61}
62
63impl<'de> Deserialize<'de> for SubcommandMap {
64 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65 where
66 D: Deserializer<'de>,
67 {
68 #[derive(Deserialize)]
72 struct SubcommandMapRepr {
73 #[serde(default)]
74 entries: HashMap<String, SubcommandEntry>,
75 }
76
77 let repr = SubcommandMapRepr::deserialize(deserializer)?;
78 for key in repr.entries.keys() {
79 if key.split_whitespace().count() > MAX_SUBCOMMAND_DEPTH {
80 return Err(serde::de::Error::custom(format!(
81 "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
82 key, MAX_SUBCOMMAND_DEPTH
83 )));
84 }
85 }
86 Ok(SubcommandMap {
87 entries: repr.entries,
88 })
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct SubcommandEntry {
96 pub effect: Effect,
97 #[serde(default)]
98 pub flags: FlagSchema,
99 #[serde(default)]
100 pub env_gates: Vec<EnvGate>,
101 #[serde(default)]
102 pub paths: PathSpec,
103 #[serde(default)]
104 pub subcommands: SubcommandMap,
105}
106
107#[derive(Debug, Clone, Default, Serialize, Deserialize)]
114pub struct FlagSchema {
115 #[serde(default)]
117 pub skip_arg: Vec<String>,
118 #[serde(default)]
120 pub skip_solo: Vec<String>,
121 #[serde(default)]
124 pub escalation: Vec<String>,
125 #[serde(default)]
127 pub path: Vec<String>,
128}
129
130impl FlagSchema {
131 pub fn extend(&mut self, other: FlagSchema) {
136 self.skip_arg.extend(other.skip_arg);
137 self.skip_solo.extend(other.skip_solo);
138 self.escalation.extend(other.escalation);
139 self.path.extend(other.path);
140 }
141}
142
143#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
149#[serde(tag = "type", rename_all = "kebab-case")]
150pub enum EnvGate {
151 Grant {
153 var: String,
154 value: String,
155 unlocks: Effect,
156 },
157 Require { var: String, value: String },
159}
160
161#[derive(Debug, Clone, Default, Serialize, Deserialize)]
164pub struct PathSpec {
165 #[serde(default)]
166 pub positionals: PathPositionals,
167 #[serde(default)]
168 pub flags: Vec<String>,
169}
170
171#[derive(Debug, Clone, Default, Serialize, Deserialize)]
173#[serde(rename_all = "kebab-case")]
174pub enum PathPositionals {
175 #[default]
176 None,
177 All,
178 Tail(usize),
179 Last,
180}
181
182#[derive(Debug, Clone, Default, Serialize, Deserialize)]
183pub struct CommandProperties {
184 #[serde(default)]
185 pub version_flag: Option<String>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct WrapperKnowledge {
192 pub name: String,
193 pub floor_effect: Effect,
195 #[serde(default)]
196 pub clears_env: bool,
197 #[serde(default)]
198 pub escalates_privilege: bool,
199}
200
201#[derive(Debug, Clone, Default, Serialize, Deserialize)]
204pub struct KnowledgeBase {
205 #[serde(default)]
206 pub commands: HashMap<String, CommandKnowledge>,
207 #[serde(default)]
208 pub wrappers: HashMap<String, WrapperKnowledge>,
209}
210
211#[derive(Debug, Clone, PartialEq, Eq)]
215pub struct CommandInfo {
216 pub effect: Effect,
217 pub subcommand: Option<String>,
218 pub has_escalation_flags: bool,
219 pub affected_paths: Vec<Word>,
220 pub env_gates: Vec<EnvGate>,
221 pub wrapper: Option<WrapperInfo>,
222}
223
224#[derive(Debug, Clone, PartialEq, Eq)]
226pub struct WrapperInfo {
227 pub name: String,
228 pub floor_effect: Effect,
229 pub clears_env: bool,
230 pub escalates_privilege: bool,
231}
232
233impl SubcommandEntry {
234 #[cfg(test)]
236 pub fn with_effect(effect: Effect) -> Self {
237 Self {
238 effect,
239 flags: FlagSchema::default(),
240 env_gates: vec![],
241 paths: PathSpec::default(),
242 subcommands: SubcommandMap::new(),
243 }
244 }
245}
246
247impl CommandKnowledge {
248 #[cfg(test)]
250 pub fn simple(name: impl Into<String>, effect: Effect) -> Self {
251 let name = name.into();
252 Self {
253 name,
254 effect,
255 subcommands: SubcommandMap::new(),
256 flags: FlagSchema::default(),
257 env_gates: vec![],
258 paths: PathSpec::default(),
259 properties: CommandProperties::default(),
260 }
261 }
262}
263
264impl SubcommandMap {
265 #[must_use = "returns an empty SubcommandMap"]
266 pub fn new() -> Self {
267 Self {
268 entries: HashMap::new(),
269 }
270 }
271
272 pub fn insert(&mut self, pattern: impl Into<String>, entry: SubcommandEntry) {
281 let pattern = pattern.into();
282 debug_assert!(
283 pattern.split_whitespace().count() <= MAX_SUBCOMMAND_DEPTH,
284 "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
285 pattern,
286 MAX_SUBCOMMAND_DEPTH,
287 );
288 self.entries.insert(pattern, entry);
289 }
290
291 #[must_use = "returns the entry if found"]
292 pub fn get(&self, pattern: &str) -> Option<&SubcommandEntry> {
293 self.entries.get(pattern)
294 }
295
296 #[must_use = "returns whether the map has entries"]
297 pub fn is_empty(&self) -> bool {
298 self.entries.is_empty()
299 }
300
301 pub fn iter(&self) -> impl Iterator<Item = (&str, &SubcommandEntry)> {
302 self.entries.iter().map(|(k, v)| (k.as_str(), v))
303 }
304
305 pub fn extend(&mut self, other: SubcommandMap) {
306 for (pattern, entry) in other.entries {
307 self.insert(pattern, entry);
308 }
309 }
310
311 pub fn remove(&mut self, pattern: &str) {
312 self.entries.remove(pattern);
313 }
314
315 #[must_use = "returns the number of entries in the map"]
316 pub fn len(&self) -> usize {
317 self.entries.len()
318 }
319
320 #[must_use = "returns the best-matching entry and how many words it consumed"]
321 pub fn longest_match(&self, words: &[&Word]) -> Option<(&SubcommandEntry, usize)> {
322 let max_depth = words.len().min(MAX_SUBCOMMAND_DEPTH);
323 for depth in (1..=max_depth).rev() {
324 let pattern: String = words[..depth]
325 .iter()
326 .map(|w| w.as_str())
327 .collect::<Vec<_>>()
328 .join(" ");
329 if let Some(entry) = self.entries.get(&pattern) {
330 return Some((entry, depth));
331 }
332 }
333 None
334 }
335}
336
337impl<'a> IntoIterator for &'a SubcommandMap {
338 type Item = (&'a str, &'a SubcommandEntry);
339 type IntoIter = std::iter::Map<
340 std::collections::hash_map::Iter<'a, String, SubcommandEntry>,
341 fn((&'a String, &'a SubcommandEntry)) -> (&'a str, &'a SubcommandEntry),
342 >;
343
344 fn into_iter(self) -> Self::IntoIter {
345 self.entries.iter().map(|(k, v)| (k.as_str(), v))
346 }
347}
348
349impl CommandInfo {
350 #[must_use = "returns a default Unknown classification"]
351 pub fn unknown() -> Self {
352 Self {
353 effect: Effect::Unknown,
354 subcommand: None,
355 has_escalation_flags: false,
356 affected_paths: vec![],
357 env_gates: vec![],
358 wrapper: None,
359 }
360 }
361}
362
363#[cfg(test)]
364#[path = "types_tests.rs"]
365mod types_tests;
366
367#[cfg(test)]
368#[path = "types_proptest.rs"]
369mod types_proptest;