use std::collections::HashSet;
/// Name-based restriction on which tools may be used.
///
/// Usually built with [`ToolFilter::allow`] or [`ToolFilter::deny`]. The
/// default value has neither list set. How the two lists are combined is
/// decided by [`ToolFilter::is_allowed`].
#[derive(Debug, Clone, Default)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ToolFilter {
    /// Tool names that are explicitly permitted, or `None` when there is no allow list.
    pub allow_tools: Option<HashSet<String>>,
    /// Tool names that are explicitly rejected, or `None` when there is no deny list.
    pub deny_tools: Option<HashSet<String>>,
}
impl ToolFilter {
pub fn allow(tools: Vec<String>) -> Self {
Self {
allow_tools: Some(tools.into_iter().collect()),
deny_tools: None,
}
}
pub fn deny(tools: Vec<String>) -> Self {
Self {
allow_tools: None,
deny_tools: Some(tools.into_iter().collect()),
}
}
pub fn is_allowed(&self, tool_name: &str) -> bool {
if let Some(ref allow_list) = self.allow_tools {
return allow_list.contains(tool_name);
}
if let Some(ref deny_list) = self.deny_tools {
return !deny_list.contains(tool_name);
}
true
}
pub fn is_enabled(&self) -> bool {
self.allow_tools.is_some() || self.deny_tools.is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed names into the owned `Vec<String>` the constructors expect.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn test_allow_filter() {
        let filter = ToolFilter::allow(owned(&["tool1", "tool2"]));
        for permitted in ["tool1", "tool2"] {
            assert!(filter.is_allowed(permitted));
        }
        assert!(!filter.is_allowed("tool3"));
    }

    #[test]
    fn test_deny_filter() {
        let filter = ToolFilter::deny(owned(&["tool1"]));
        assert!(!filter.is_allowed("tool1"));
        for permitted in ["tool2", "tool3"] {
            assert!(filter.is_allowed(permitted));
        }
    }

    #[test]
    fn test_no_filter() {
        let unfiltered = ToolFilter::default();
        assert!(unfiltered.is_allowed("any_tool"));
        assert!(!unfiltered.is_enabled());
    }
}