mcp_common/
tool_filter.rs

1//! 工具过滤器
2//!
3//! 提供白名单和黑名单两种过滤模式
4
5use std::collections::HashSet;
6
7/// 工具过滤配置
8#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
9pub struct ToolFilter {
10    /// 白名单(只允许这些工具)
11    pub allow_tools: Option<HashSet<String>>,
12    /// 黑名单(排除这些工具)
13    pub deny_tools: Option<HashSet<String>>,
14}
15
16impl ToolFilter {
17    /// 创建白名单过滤器
18    pub fn allow(tools: Vec<String>) -> Self {
19        Self {
20            allow_tools: Some(tools.into_iter().collect()),
21            deny_tools: None,
22        }
23    }
24
25    /// 创建黑名单过滤器
26    pub fn deny(tools: Vec<String>) -> Self {
27        Self {
28            allow_tools: None,
29            deny_tools: Some(tools.into_iter().collect()),
30        }
31    }
32
33    /// 检查工具是否被允许
34    pub fn is_allowed(&self, tool_name: &str) -> bool {
35        // 白名单模式:只有在白名单中的工具才被允许
36        if let Some(ref allow_list) = self.allow_tools {
37            return allow_list.contains(tool_name);
38        }
39        // 黑名单模式:不在黑名单中的工具都被允许
40        if let Some(ref deny_list) = self.deny_tools {
41            return !deny_list.contains(tool_name);
42        }
43        // 无过滤:全部允许
44        true
45    }
46
47    /// 检查是否启用了过滤
48    pub fn is_enabled(&self) -> bool {
49        self.allow_tools.is_some() || self.deny_tools.is_some()
50    }
51}
52
53#[cfg(test)]
54mod tests {
55    use super::*;
56
57    #[test]
58    fn test_allow_filter() {
59        let filter = ToolFilter::allow(vec!["tool1".to_string(), "tool2".to_string()]);
60        assert!(filter.is_allowed("tool1"));
61        assert!(filter.is_allowed("tool2"));
62        assert!(!filter.is_allowed("tool3"));
63    }
64
65    #[test]
66    fn test_deny_filter() {
67        let filter = ToolFilter::deny(vec!["tool1".to_string()]);
68        assert!(!filter.is_allowed("tool1"));
69        assert!(filter.is_allowed("tool2"));
70        assert!(filter.is_allowed("tool3"));
71    }
72
73    #[test]
74    fn test_no_filter() {
75        let filter = ToolFilter::default();
76        assert!(filter.is_allowed("any_tool"));
77        assert!(!filter.is_enabled());
78    }
79}