use std::sync::Arc;
use regex::Regex;
use crate::tool::AgentTool;
/// A single tool-name pattern used by [`ToolFilter`].
///
/// Usually built with [`ToolPattern::parse`], which picks the variant from the
/// syntax of the pattern string.
#[derive(Debug, Clone)]
pub enum ToolPattern {
    /// Matches only a tool whose name is exactly this string.
    Exact(String),
    /// Wildcard pattern: `*` matches any run of characters (including an
    /// empty one) and `?` matches exactly one character.
    Glob(String),
    /// Compiled regular expression, tested with `Regex::is_match`; it is only
    /// anchored where the expression itself uses `^` / `$`.
    Regex(Regex),
}
impl ToolPattern {
    /// Classifies a raw pattern string by its syntax.
    ///
    /// * A string that starts with `^` or ends with `$` is compiled as a
    ///   regular expression. If compilation fails, the string is kept as an
    ///   [`ToolPattern::Exact`] name instead of returning an error.
    /// * Otherwise, a string containing `*` or `?` becomes a
    ///   [`ToolPattern::Glob`].
    /// * Anything else is an [`ToolPattern::Exact`] name.
    #[must_use]
    pub fn parse(pattern: &str) -> Self {
        if pattern.starts_with('^') || pattern.ends_with('$') {
            return match Regex::new(pattern) {
                Ok(compiled) => Self::Regex(compiled),
                // Invalid regex: degrade to a literal comparison rather than failing.
                Err(_) => Self::Exact(pattern.to_owned()),
            };
        }

        let has_wildcard = pattern.chars().any(|c| matches!(c, '*' | '?'));
        let owned = pattern.to_owned();
        if has_wildcard {
            Self::Glob(owned)
        } else {
            Self::Exact(owned)
        }
    }

    /// Returns `true` when `name` satisfies this pattern.
    #[must_use]
    pub fn matches(&self, name: &str) -> bool {
        match self {
            Self::Regex(compiled) => compiled.is_match(name),
            Self::Glob(glob) => glob_matches(glob, name),
            Self::Exact(literal) => literal.as_str() == name,
        }
    }
}
/// Matches `text` against a wildcard `pattern`.
///
/// Supported wildcards:
/// * `*` matches any run of characters, including the empty run;
/// * `?` matches exactly one character.
///
/// Every other pattern character must equal the corresponding text character.
/// There is no escape syntax, so a pattern cannot express a literal `*` or `?`.
/// A `*` or `?` that appears in `text` is treated as an ordinary character.
///
/// Comparison is per `char` (Unicode scalar value), not per byte, so `?`
/// consumes a whole multi-byte character.
///
/// The algorithm is iterative, with a single backtrack point: only the most
/// recent `*` ever needs to be revisited, so no recursion is required.
fn glob_matches(pattern: &str, text: &str) -> bool {
    let pattern: Vec<char> = pattern.chars().collect();
    let text: Vec<char> = text.chars().collect();

    let mut p = 0;
    let mut t = 0;
    // `(star, start)`:
    // * `star` is the index of the most recent `*` in `pattern`;
    // * `start` is the text index at which the pattern after that `*` is
    //   currently being tried.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            // `*` must be handled before the literal comparison. Otherwise a
            // literal `*` in `text` would consume the wildcard as a plain
            // character without recording a backtrack point, and e.g.
            // `a*b` would fail to match `a*xb`.
            Some('*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star, start)) => {
                    // Let the last `*` absorb one more character, then retry
                    // the remainder of the pattern from there.
                    p = star + 1;
                    t = start + 1;
                    backtrack = Some((star, start + 1));
                }
                None => return false,
            },
        }
    }

    // Text exhausted: only trailing `*`s (which match the empty run) may remain.
    pattern[p..].iter().all(|&c| c == '*')
}
/// Allow/deny lists of [`ToolPattern`]s deciding which tools are exposed.
///
/// A rejected match always wins over an allowed match, and an empty allow
/// list admits every tool that is not rejected (see [`ToolFilter::is_allowed`]).
/// The derived `Default` starts with both lists empty.
#[derive(Debug, Clone, Default)]
pub struct ToolFilter {
    /// When non-empty, a tool must match at least one of these patterns.
    allowed: Vec<ToolPattern>,
    /// A tool matching any of these patterns is always excluded.
    rejected: Vec<ToolPattern>,
}
impl ToolFilter {
    /// Creates a filter with empty allow and reject lists, which admits every tool.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces (does not extend) the allow list with `patterns`.
    #[must_use]
    pub fn with_allowed(self, patterns: Vec<ToolPattern>) -> Self {
        Self {
            allowed: patterns,
            ..self
        }
    }

    /// Replaces (does not extend) the reject list with `patterns`.
    #[must_use]
    pub fn with_rejected(self, patterns: Vec<ToolPattern>) -> Self {
        Self {
            rejected: patterns,
            ..self
        }
    }

    /// Decides whether a tool called `name` passes the filter.
    ///
    /// Rules, in order:
    /// 1. a match against any rejected pattern denies the tool;
    /// 2. otherwise, an empty allow list admits it;
    /// 3. otherwise, it is admitted only if some allowed pattern matches.
    #[must_use]
    pub fn is_allowed(&self, name: &str) -> bool {
        let any_match = |patterns: &[ToolPattern]| patterns.iter().any(|p| p.matches(name));
        !any_match(&self.rejected) && (self.allowed.is_empty() || any_match(&self.allowed))
    }

    /// Keeps only the tools whose `name()` passes [`Self::is_allowed`],
    /// preserving their original relative order.
    #[must_use]
    pub fn filter_tools(&self, mut tools: Vec<Arc<dyn AgentTool>>) -> Vec<Arc<dyn AgentTool>> {
        tools.retain(|tool| self.is_allowed(tool.name()));
        tools
    }
}
// Compile-time guard: this block emits no runtime code, but the crate stops
// compiling if either public type ever loses its `Send` or `Sync` impl.
const _: () = {
    const fn require_thread_safe<T: Send + Sync>() {}
    require_thread_safe::<ToolPattern>();
    require_thread_safe::<ToolFilter>();
};
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `pattern` and asserts it accepts every name in `hits` and
    /// rejects every name in `misses`.
    fn check_pattern(pattern: &str, hits: &[&str], misses: &[&str]) {
        let parsed = ToolPattern::parse(pattern);
        for name in hits {
            assert!(parsed.matches(name), "{pattern:?} should match {name:?}");
        }
        for name in misses {
            assert!(!parsed.matches(name), "{pattern:?} should not match {name:?}");
        }
    }

    /// Parses every string in `patterns` with [`ToolPattern::parse`].
    fn parse_all(patterns: &[&str]) -> Vec<ToolPattern> {
        patterns.iter().map(|p| ToolPattern::parse(p)).collect()
    }

    /// Builds a filter from allow/reject pattern strings.
    fn filter_of(allowed: &[&str], rejected: &[&str]) -> ToolFilter {
        ToolFilter::new()
            .with_allowed(parse_all(allowed))
            .with_rejected(parse_all(rejected))
    }

    #[test]
    fn exact_pattern_matches() {
        check_pattern("bash", &["bash"], &["read_file"]);
    }

    #[test]
    fn glob_pattern_matches() {
        check_pattern("read_*", &["read_file", "read_secret"], &["write_file"]);
    }

    #[test]
    fn glob_question_mark_matches_single_char() {
        check_pattern("tool_?", &["tool_a"], &["tool_ab"]);
    }

    #[test]
    fn glob_star_backtracks_without_regex() {
        check_pattern(
            "read_*_file",
            &["read_secret_file", "read_very_secret_file"],
            &["read_secret_dir"],
        );
    }

    #[test]
    fn glob_handles_unicode_chars() {
        check_pattern("t?ol_*", &["t🦀ol_alpha"], &["tool"]);
    }

    #[test]
    fn regex_pattern_matches() {
        check_pattern("^file_.*$", &["file_read", "file_write"], &["bash"]);
    }

    #[test]
    fn rejected_takes_precedence() {
        let filter = filter_of(&["read_*"], &["read_secret"]);
        assert!(filter.is_allowed("read_file"));
        assert!(!filter.is_allowed("read_secret"));
    }

    #[test]
    fn empty_filter_allows_all() {
        let filter = ToolFilter::new();
        for name in ["anything", "bash"] {
            assert!(filter.is_allowed(name));
        }
    }

    #[test]
    fn allowed_only_restricts_to_matching() {
        let filter = filter_of(&["bash"], &[]);
        assert!(filter.is_allowed("bash"));
        assert!(!filter.is_allowed("read_file"));
    }

    #[test]
    fn rejected_only_excludes_matching() {
        let filter = filter_of(&[], &["bash"]);
        assert!(!filter.is_allowed("bash"));
        assert!(filter.is_allowed("read_file"));
    }

    #[test]
    fn invalid_regex_falls_back_to_exact() {
        check_pattern("^[invalid", &["^[invalid"], &[]);
    }

    #[test]
    fn parse_detects_pattern_type() {
        assert!(matches!(ToolPattern::parse("exact"), ToolPattern::Exact(_)));
        assert!(matches!(ToolPattern::parse("glob_*"), ToolPattern::Glob(_)));
        assert!(matches!(ToolPattern::parse("^regex$"), ToolPattern::Regex(_)));
    }
}