// mcp_common/tool_filter.rs
use std::collections::HashSet;
7#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
9pub struct ToolFilter {
10 pub allow_tools: Option<HashSet<String>>,
12 pub deny_tools: Option<HashSet<String>>,
14}
15
16impl ToolFilter {
17 pub fn allow(tools: Vec<String>) -> Self {
19 Self {
20 allow_tools: Some(tools.into_iter().collect()),
21 deny_tools: None,
22 }
23 }
24
25 pub fn deny(tools: Vec<String>) -> Self {
27 Self {
28 allow_tools: None,
29 deny_tools: Some(tools.into_iter().collect()),
30 }
31 }
32
33 pub fn is_allowed(&self, tool_name: &str) -> bool {
35 if let Some(ref allow_list) = self.allow_tools {
37 return allow_list.contains(tool_name);
38 }
39 if let Some(ref deny_list) = self.deny_tools {
41 return !deny_list.contains(tool_name);
42 }
43 true
45 }
46
47 pub fn is_enabled(&self) -> bool {
49 self.allow_tools.is_some() || self.deny_tools.is_some()
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned `Vec<String>` the constructors take.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_allow_filter() {
        let filter = ToolFilter::allow(names(&["tool1", "tool2"]));
        assert!(filter.is_allowed("tool1"));
        assert!(filter.is_allowed("tool2"));
        assert!(!filter.is_allowed("tool3"));
        assert!(filter.is_enabled());
    }

    #[test]
    fn test_deny_filter() {
        let filter = ToolFilter::deny(names(&["tool1"]));
        assert!(!filter.is_allowed("tool1"));
        assert!(filter.is_allowed("tool2"));
        assert!(filter.is_allowed("tool3"));
        assert!(filter.is_enabled());
    }

    #[test]
    fn test_no_filter() {
        let filter = ToolFilter::default();
        assert!(filter.is_allowed("any_tool"));
        assert!(!filter.is_enabled());
    }

    #[test]
    fn test_empty_allow_list_denies_everything() {
        // An explicit but empty allow-list is still a configured restriction.
        let filter = ToolFilter::allow(Vec::new());
        assert!(!filter.is_allowed("tool1"));
        assert!(filter.is_enabled());
    }

    #[test]
    fn test_empty_deny_list_allows_everything() {
        let filter = ToolFilter::deny(Vec::new());
        assert!(filter.is_allowed("tool1"));
        assert!(filter.is_enabled());
    }

    #[test]
    fn test_matching_is_exact() {
        // Names are compared as exact strings: no case folding, trimming, or prefixes.
        let filter = ToolFilter::allow(names(&["tool1"]));
        assert!(!filter.is_allowed("Tool1"));
        assert!(!filter.is_allowed("tool"));
        assert!(!filter.is_allowed("tool1 "));
    }
}