use std::collections::HashMap;
use agent_shell_parser::parse::types::Word;
use serde::de::Deserializer;
use serde::{Deserialize, Serialize};
/// Maximum number of whitespace-separated words in a subcommand pattern.
///
/// Enforced as a hard error when a `SubcommandMap` is deserialized and only as a
/// debug assertion in `SubcommandMap::insert`; `SubcommandMap::longest_match`
/// never examines more than this many words.
pub const MAX_SUBCOMMAND_DEPTH: usize = 4;
/// Classification of what a command may do to the system.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank
/// `ReadOnly < Mutating < Destructive < Unknown`, so `Unknown` compares as the
/// greatest. Serialized in kebab-case (`"read-only"`, `"mutating"`,
/// `"destructive"`, `"unknown"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Effect {
ReadOnly,
Mutating,
Destructive,
/// No classification is available; this is the effect of `CommandInfo::unknown()`.
Unknown,
}
/// Knowledge-base description of one command and how to classify it.
///
/// Every field except `name` and `effect` falls back to its `Default` value when
/// absent from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandKnowledge {
/// Command name (presumably the same as its key in `KnowledgeBase::commands` — confirm).
pub name: String,
/// Effect of the command itself; presumably used when no subcommand pattern matches.
pub effect: Effect,
/// Subcommand patterns (at most `MAX_SUBCOMMAND_DEPTH` words each) with their own classification.
#[serde(default)]
pub subcommands: SubcommandMap,
/// Flag lists used when classifying this command's arguments.
#[serde(default)]
pub flags: FlagSchema,
/// Environment-variable conditions attached to this command.
#[serde(default)]
pub env_gates: Vec<EnvGate>,
/// Which arguments are treated as paths.
#[serde(default)]
pub paths: PathSpec,
/// Miscellaneous per-command properties.
#[serde(default)]
pub properties: CommandProperties,
}
/// Map from subcommand patterns to their classification entries.
///
/// A pattern is one or more words. `longest_match` probes prefixes of a command's
/// words joined by a single space, and deserialization rejects patterns longer
/// than `MAX_SUBCOMMAND_DEPTH` words. Serialized as
/// `{ "entries": { <pattern>: <entry>, ... } }`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct SubcommandMap {
entries: HashMap<String, SubcommandEntry>,
}
impl<'de> Deserialize<'de> for SubcommandMap {
    /// Deserializes the `{ "entries": { <pattern>: <entry>, ... } }` form produced
    /// by the derived `Serialize` impl and validates every pattern on the way in.
    ///
    /// Each pattern is canonicalized to its whitespace-separated words joined by a
    /// single space. [`SubcommandMap::longest_match`] builds its lookup keys the
    /// same way. Without this step, a key such as `"remote  add"` or `"remote\tadd"`
    /// would load successfully but could never be matched.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error if a pattern:
    /// - is empty or whitespace-only (it could never be matched),
    /// - has more than [`MAX_SUBCOMMAND_DEPTH`] words, or
    /// - collides with another pattern once whitespace is canonicalized.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Mirror of the serialized shape; validation happens after it is read.
        #[derive(Deserialize)]
        struct SubcommandMapRepr {
            #[serde(default)]
            entries: HashMap<String, SubcommandEntry>,
        }

        let repr = SubcommandMapRepr::deserialize(deserializer)?;
        let mut entries = HashMap::with_capacity(repr.entries.len());
        for (key, entry) in repr.entries {
            let words: Vec<&str> = key.split_whitespace().collect();
            if words.is_empty() {
                return Err(serde::de::Error::custom(format!(
                    "subcommand pattern '{}' is empty",
                    key
                )));
            }
            if words.len() > MAX_SUBCOMMAND_DEPTH {
                return Err(serde::de::Error::custom(format!(
                    "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
                    key, MAX_SUBCOMMAND_DEPTH
                )));
            }
            // Same key shape that `longest_match` looks up.
            let canonical = words.join(" ");
            if entries.contains_key(&canonical) {
                return Err(serde::de::Error::custom(format!(
                    "subcommand pattern '{}' duplicates pattern '{}' after whitespace normalization",
                    key, canonical
                )));
            }
            entries.insert(canonical, entry);
        }
        Ok(SubcommandMap { entries })
    }
}
/// Classification for one subcommand pattern.
///
/// Carries the same per-command data as `CommandKnowledge` (minus `name` and
/// `properties`) and may nest further subcommand patterns. Only `effect` is
/// required when deserializing; every other field defaults to empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubcommandEntry {
/// Effect assigned when this pattern matches.
pub effect: Effect,
/// Flag lists for this subcommand.
#[serde(default)]
pub flags: FlagSchema,
/// Environment-variable conditions attached to this subcommand.
#[serde(default)]
pub env_gates: Vec<EnvGate>,
/// Which arguments are treated as paths.
#[serde(default)]
pub paths: PathSpec,
/// Nested patterns — presumably matched against the words following this pattern; confirm.
#[serde(default)]
pub subcommands: SubcommandMap,
}
/// Named flag lists that steer argument classification.
///
/// Every list defaults to empty when absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FlagSchema {
/// Flags that take a following argument (presumably skipped together with it — confirm).
#[serde(default)]
pub skip_arg: Vec<String>,
/// Flags that stand alone (presumably skipped without consuming an argument — confirm).
#[serde(default)]
pub skip_solo: Vec<String>,
/// Flags that presumably set `CommandInfo::has_escalation_flags` — confirm.
#[serde(default)]
pub escalation: Vec<String>,
/// Flags whose values are presumably treated as paths — confirm.
#[serde(default)]
pub path: Vec<String>,
}
impl FlagSchema {
    /// Merges `other` into `self` by appending each of its flag lists onto the
    /// matching list here.
    ///
    /// Entries already in `self` keep their position and `other`'s follow them;
    /// duplicates are not removed.
    pub fn extend(&mut self, other: FlagSchema) {
        let pairs = [
            (&mut self.skip_arg, other.skip_arg),
            (&mut self.skip_solo, other.skip_solo),
            (&mut self.escalation, other.escalation),
            (&mut self.path, other.path),
        ];
        for (dest, mut incoming) in pairs {
            dest.append(&mut incoming);
        }
    }
}
/// A condition on an environment variable attached to a command or subcommand.
///
/// Serialized as an internally tagged enum: a `"type"` field of `"grant"` or
/// `"require"` appears alongside the variant's own fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum EnvGate {
/// When `var` equals `value`, presumably permits effects up to `unlocks` — confirm with the evaluator.
Grant {
var: String,
value: String,
unlocks: Effect,
},
/// Presumably requires `var` to equal `value` — confirm with the evaluator.
Require { var: String, value: String },
}
/// Describes where path arguments appear in a command's arguments.
///
/// Both fields default when absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PathSpec {
/// Which positional arguments are paths; defaults to `PathPositionals::None`.
#[serde(default)]
pub positionals: PathPositionals,
/// Flags whose values are presumably paths — confirm against the classifier.
#[serde(default)]
pub flags: Vec<String>,
}
/// Selects which positional arguments are treated as paths.
///
/// Variant names are serialized in kebab-case (`none`, `all`, `tail`, `last`);
/// `Tail` carries a count.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum PathPositionals {
/// No positional is a path (the default).
#[default]
None,
/// Presumably every positional is a path — confirm.
All,
/// Presumably the last `n` positionals are paths — confirm against the classifier.
Tail(usize),
/// Presumably only the final positional is a path — confirm.
Last,
}
/// Miscellaneous per-command properties.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CommandProperties {
/// Flag that presumably prints the command's version (e.g. `--version`); `None` when unset.
#[serde(default)]
pub version_flag: Option<String>,
}
/// Knowledge-base description of a wrapper command. This is presumably a command
/// that runs another command (e.g. `sudo`, `env`); confirm with the classifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WrapperKnowledge {
/// Wrapper name (presumably the same as its key in `KnowledgeBase::wrappers`).
pub name: String,
/// Presumably the minimum effect reported for anything run through this wrapper — confirm.
pub floor_effect: Effect,
/// Marks the wrapper as clearing the environment; defaults to `false`.
#[serde(default)]
pub clears_env: bool,
/// Marks the wrapper as escalating privileges; defaults to `false`.
#[serde(default)]
pub escalates_privilege: bool,
}
/// The full set of known commands and wrappers.
///
/// Presumably both maps are keyed by command/wrapper name — confirm with the
/// loader. Each map defaults to empty when absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeBase {
#[serde(default)]
pub commands: HashMap<String, CommandKnowledge>,
#[serde(default)]
pub wrappers: HashMap<String, WrapperKnowledge>,
}
/// Result of classifying one command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandInfo {
/// Overall effect assigned to the command.
pub effect: Effect,
/// The matched subcommand pattern, if any (presumably a `SubcommandMap` key).
pub subcommand: Option<String>,
/// Presumably set when an argument appears in `FlagSchema::escalation` — confirm.
pub has_escalation_flags: bool,
/// Argument words presumably identified as paths via `PathSpec` — confirm.
pub affected_paths: Vec<Word>,
/// Environment gates relevant to this command.
pub env_gates: Vec<EnvGate>,
/// Details of the wrapper the command was run through, if any.
pub wrapper: Option<WrapperInfo>,
}
/// Wrapper details attached to a `CommandInfo`; has the same fields as `WrapperKnowledge`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrapperInfo {
pub name: String,
pub floor_effect: Effect,
pub clears_env: bool,
pub escalates_privilege: bool,
}
impl SubcommandEntry {
    /// Test helper: an entry carrying only `effect`, with no flags, env gates,
    /// path spec, or nested subcommands.
    #[cfg(test)]
    pub fn with_effect(effect: Effect) -> Self {
        let no_gates: Vec<EnvGate> = Vec::new();
        Self {
            subcommands: SubcommandMap::new(),
            paths: PathSpec::default(),
            env_gates: no_gates,
            flags: FlagSchema::default(),
            effect,
        }
    }
}
impl CommandKnowledge {
    /// Test helper: a command called `name` with top-level `effect` and no
    /// subcommands, flags, env gates, path spec, or properties.
    #[cfg(test)]
    pub fn simple(name: impl Into<String>, effect: Effect) -> Self {
        Self {
            name: name.into(),
            effect,
            properties: CommandProperties::default(),
            paths: PathSpec::default(),
            env_gates: Vec::new(),
            flags: FlagSchema::default(),
            subcommands: SubcommandMap::new(),
        }
    }
}
impl SubcommandMap {
    /// Creates an empty map; equivalent to `SubcommandMap::default()`.
    #[must_use = "returns an empty SubcommandMap"]
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Inserts `entry` under `pattern`, replacing any entry already stored there.
    ///
    /// Only patterns whose words are separated by a single space can be found by
    /// [`longest_match`](Self::longest_match), because that is how it builds its
    /// lookup keys.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `pattern` has more than [`MAX_SUBCOMMAND_DEPTH`]
    /// words. Release builds store such a pattern, but `longest_match` never
    /// looks far enough to find it.
    pub fn insert(&mut self, pattern: impl Into<String>, entry: SubcommandEntry) {
        let pattern = pattern.into();
        debug_assert!(
            pattern.split_whitespace().count() <= MAX_SUBCOMMAND_DEPTH,
            "subcommand pattern '{}' exceeds MAX_SUBCOMMAND_DEPTH ({})",
            pattern,
            MAX_SUBCOMMAND_DEPTH,
        );
        self.entries.insert(pattern, entry);
    }

    /// Returns the entry stored under exactly `pattern`, if any.
    #[must_use = "returns the entry if found"]
    pub fn get(&self, pattern: &str) -> Option<&SubcommandEntry> {
        self.entries.get(pattern)
    }

    /// Returns `true` if the map holds no patterns.
    #[must_use = "returns whether the map has entries"]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterates over `(pattern, entry)` pairs in unspecified (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &SubcommandEntry)> {
        self.entries.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Inserts every entry of `other` via [`insert`](Self::insert). When both
    /// maps hold the same pattern, `other`'s entry replaces the existing one.
    pub fn extend(&mut self, other: SubcommandMap) {
        for (pattern, entry) in other.entries {
            self.insert(pattern, entry);
        }
    }

    /// Removes the entry stored under `pattern`, if present.
    pub fn remove(&mut self, pattern: &str) {
        self.entries.remove(pattern);
    }

    /// Returns the number of patterns in the map.
    #[must_use = "returns the number of entries in the map"]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Finds the longest pattern that matches a prefix of `words`.
    ///
    /// Tries the first `min(words.len(), MAX_SUBCOMMAND_DEPTH)` words joined by a
    /// single space, then progressively shorter prefixes down to a single word.
    /// Returns the matched entry together with the number of words it consumed,
    /// or `None` if no prefix matches (including when `words` is empty).
    #[must_use = "returns the best-matching entry and how many words it consumed"]
    pub fn longest_match(&self, words: &[&Word]) -> Option<(&SubcommandEntry, usize)> {
        let max_depth = words.len().min(MAX_SUBCOMMAND_DEPTH);

        // Join the candidate words once and remember where each prefix ends.
        // Every depth is then probed with a borrowed slice instead of a fresh
        // `Vec` + `String` allocation per depth.
        let mut joined = String::new();
        let mut prefix_ends = [0usize; MAX_SUBCOMMAND_DEPTH];
        for (i, word) in words[..max_depth].iter().enumerate() {
            if i > 0 {
                joined.push(' ');
            }
            joined.push_str(word.as_str());
            prefix_ends[i] = joined.len();
        }

        // Longest prefix first, so the most specific pattern wins.
        (1..=max_depth).rev().find_map(|depth| {
            self.entries
                .get(&joined[..prefix_ends[depth - 1]])
                .map(|entry| (entry, depth))
        })
    }
}
/// Borrowing iteration over `(pattern, entry)` pairs, so `for (p, e) in &map` works.
///
/// Pairs come out in the underlying `HashMap`'s unspecified order, the same as
/// [`SubcommandMap::iter`].
impl<'a> IntoIterator for &'a SubcommandMap {
    type Item = (&'a str, &'a SubcommandEntry);
    type IntoIter = std::iter::Map<
        std::collections::hash_map::Iter<'a, String, SubcommandEntry>,
        fn((&'a String, &'a SubcommandEntry)) -> (&'a str, &'a SubcommandEntry),
    >;

    fn into_iter(self) -> Self::IntoIter {
        // A named `fn` pointer matches the spelled-out `IntoIter` type exactly.
        let borrow_key: fn((&'a String, &'a SubcommandEntry)) -> (&'a str, &'a SubcommandEntry) =
            |(pattern, entry)| (pattern.as_str(), entry);
        self.entries.iter().map(borrow_key)
    }
}
impl CommandInfo {
    /// Returns the least-informative classification. Its effect is
    /// [`Effect::Unknown`], with no subcommand, no escalation flags, no affected
    /// paths, no env gates, and no wrapper.
    #[must_use = "returns a default Unknown classification"]
    pub fn unknown() -> Self {
        let effect = Effect::Unknown;
        Self {
            wrapper: None,
            env_gates: Vec::new(),
            affected_paths: Vec::new(),
            has_escalation_flags: false,
            subcommand: None,
            effect,
        }
    }
}
#[cfg(test)]
#[path = "types_tests.rs"]
mod types_tests;
#[cfg(test)]
#[path = "types_proptest.rs"]
mod types_proptest;