use std::collections::HashSet;
use std::sync::Arc;
use crate::executor::ToolPolicy;
pub use crate::executor::AllowAll;
/// A [`ToolPolicy`] that permits only the tools whose names were explicitly listed.
///
/// Matching is an exact, case-sensitive comparison against the full tool name.
/// An empty list (including [`AllowList::default`]) denies every tool.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AllowList(HashSet<String>);

impl AllowList {
    /// Builds an allow list from any iterable of tool names.
    ///
    /// Each item is converted into an owned `String`; duplicate names are
    /// collapsed because the names are stored in a set.
    pub fn new<I, S>(tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self(tools.into_iter().map(Into::into).collect())
    }
}

impl ToolPolicy for AllowList {
    /// Returns `true` only if `name` exactly matches one of the listed tools.
    fn is_allowed(&self, name: &str) -> bool {
        self.0.contains(name)
    }
}
/// A [`ToolPolicy`] that permits a tool only when *both* wrapped policies permit it.
///
/// Combining policies this way can only narrow what is allowed: a tool denied by
/// either side is denied by the intersection. `left` is always consulted first,
/// and `right` is not consulted at all when `left` already denies the tool.
pub struct IntersectPolicy {
    /// Policy consulted first.
    pub left: Arc<dyn ToolPolicy>,
    /// Policy consulted only when `left` allows the tool.
    pub right: Arc<dyn ToolPolicy>,
}

impl ToolPolicy for IntersectPolicy {
    fn is_allowed(&self, name: &str) -> bool {
        // Bail out early so `right` is never asked about a tool `left` rejected.
        if !self.left.is_allowed(name) {
            return false;
        }
        self.right.is_allowed(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only policy that permits every tool except the names it was built
    /// with, i.e. the complement of an `AllowList` over the same names.
    struct DenyNamed(HashSet<String>);

    impl DenyNamed {
        fn new<I, S>(names: I) -> Self
        where
            I: IntoIterator<Item = S>,
            S: Into<String>,
        {
            let denied: HashSet<String> = names.into_iter().map(|n| n.into()).collect();
            DenyNamed(denied)
        }
    }

    impl ToolPolicy for DenyNamed {
        fn is_allowed(&self, name: &str) -> bool {
            let DenyNamed(denied) = self;
            !denied.contains(name)
        }
    }

    /// Wraps two concrete policies into an `IntersectPolicy`, with `left`
    /// playing the parent role and `right` the child role.
    fn intersect<L, R>(left: L, right: R) -> IntersectPolicy
    where
        L: ToolPolicy + 'static,
        R: ToolPolicy + 'static,
    {
        IntersectPolicy {
            left: Arc::new(left),
            right: Arc::new(right),
        }
    }

    /// Checks `policy` against a table of `(tool, expected_verdict)` rows.
    /// The tool name is part of the compared tuple so a failure shows which row broke.
    fn assert_verdicts(policy: &dyn ToolPolicy, table: &[(&str, bool)]) {
        for &(tool, allowed) in table {
            assert_eq!((tool, policy.is_allowed(tool)), (tool, allowed));
        }
    }

    #[test]
    fn allow_list_only_allows_named_tools() {
        let policy = AllowList::new(["read", "grep"]);
        assert_verdicts(&policy, &[("read", true), ("write", false)]);
    }

    #[test]
    fn allow_list_empty_denies_everything() {
        let policy = AllowList::new(Vec::<String>::new());
        assert_verdicts(&policy, &[("read", false), ("", false)]);
    }

    #[test]
    fn intersect_policy_is_deny_monotonic() {
        let policy = intersect(
            AllowList::new(["read", "write"]),
            AllowList::new(["read", "grep"]),
        );
        assert_verdicts(&policy, &[("read", true), ("write", false), ("grep", false)]);
    }

    #[test]
    fn intersect_with_allow_all_reduces_to_other_side() {
        let policy = intersect(AllowAll, AllowList::new(["read"]));
        assert_verdicts(
            &policy,
            &[("read", true), ("write", false), ("anything-else", false)],
        );
    }

    #[test]
    fn intersect_with_deny_named_subtracts_from_allow_list() {
        let policy = intersect(
            AllowList::new(["read", "write", "bash"]),
            DenyNamed::new(["bash"]),
        );
        assert_verdicts(
            &policy,
            &[("read", true), ("write", true), ("bash", false), ("grep", false)],
        );
    }
}