use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Permissions governing which tools may be used, which state keys are
/// permitted, and how many actions are allowed.
///
/// Empty allow-lists are permissive: an empty `allowed_tools` or
/// `allowed_state_keys` permits every tool or key, respectively. Entries in
/// `denied_tools` are always rejected (see [`Self::tool_allowed`]).
///
/// Do not add `#[serde(default)]` to this struct: because an empty
/// `allowed_tools` means "allow every tool", defaulting a field that is
/// missing from serialized input would silently widen permissions.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CapabilitySet {
    /// Tools explicitly permitted. If empty, every tool not in
    /// `denied_tools` is permitted.
    pub allowed_tools: HashSet<String>,
    /// Tools that are always rejected, even when also present in
    /// `allowed_tools`.
    pub denied_tools: HashSet<String>,
    /// State keys explicitly permitted. If empty, every key is permitted.
    pub allowed_state_keys: HashSet<String>,
    /// Inclusive upper bound on the action count; `None` means no limit.
    pub max_actions: Option<u32>,
}
impl CapabilitySet {
    /// Creates an empty capability set.
    ///
    /// With every collection empty and no action limit, the result permits
    /// all tools, all state keys, and any number of actions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `tool` to the allow-list and returns the updated set.
    ///
    /// Once the allow-list is non-empty, only listed tools are permitted
    /// (and the deny-list still takes precedence).
    #[must_use]
    pub fn allow_tool(mut self, tool: &str) -> Self {
        self.allowed_tools.insert(tool.to_string());
        self
    }

    /// Adds `tool` to the deny-list and returns the updated set.
    ///
    /// A denied tool is rejected even if it is also on the allow-list.
    #[must_use]
    pub fn deny_tool(mut self, tool: &str) -> Self {
        self.denied_tools.insert(tool.to_string());
        self
    }

    /// Adds `key` to the permitted state keys and returns the updated set.
    ///
    /// Once at least one key is listed, unlisted keys are rejected.
    #[must_use]
    pub fn allow_state_key(mut self, key: &str) -> Self {
        self.allowed_state_keys.insert(key.to_string());
        self
    }

    /// Sets the inclusive upper bound on the action count, replacing any
    /// previously configured limit, and returns the updated set.
    #[must_use]
    pub fn with_max_actions(mut self, max: u32) -> Self {
        self.max_actions = Some(max);
        self
    }

    /// Reports whether `tool` may be used.
    ///
    /// The deny-list is checked first and always wins. An empty allow-list
    /// permits every tool that is not denied.
    #[must_use]
    pub fn tool_allowed(&self, tool: &str) -> bool {
        if self.denied_tools.contains(tool) {
            return false;
        }
        // An empty allow-list means "no allow-list restriction".
        if self.allowed_tools.is_empty() {
            return true;
        }
        self.allowed_tools.contains(tool)
    }

    /// Reports whether `key` is a permitted state key.
    ///
    /// An empty key list permits every key.
    #[must_use]
    pub fn state_key_allowed(&self, key: &str) -> bool {
        if self.allowed_state_keys.is_empty() {
            return true;
        }
        self.allowed_state_keys.contains(key)
    }

    /// Reports whether `count` actions fit within the configured limit.
    ///
    /// The bound is inclusive (`count == max` is within budget). With no
    /// limit configured, every count is within budget.
    #[must_use]
    pub fn actions_within_budget(&self, count: u32) -> bool {
        match self.max_actions {
            None => true,
            Some(max) => count <= max,
        }
    }
}