1use std::collections::HashSet;
9
10use crate::mcp::classify::{classify, Class, ToolMeta, ALL_GROUPS};
11
/// Coarse permission level that decides which tool classes may be exposed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Profile {
    /// Every tool class is allowed.
    All,
    /// Only read-class tools are allowed.
    Read,
    /// Everything except destructive tools is allowed.
    Safe,
}
18
19impl Profile {
20 fn parse(s: &str) -> Option<Self> {
21 match s {
22 "all" => Some(Profile::All),
23 "read" => Some(Profile::Read),
24 "safe" => Some(Profile::Safe),
25 _ => None,
26 }
27 }
28
29 fn allows_class(self, class: Class) -> bool {
30 match (self, class) {
31 (Profile::All, _) => true,
32 (Profile::Read, Class::Read) => true,
33 (Profile::Read, _) => false,
34 (Profile::Safe, Class::Destructive) => false,
35 (Profile::Safe, _) => true,
36 }
37 }
38
39 pub fn as_str(self) -> &'static str {
40 match self {
41 Profile::All => "all",
42 Profile::Read => "read",
43 Profile::Safe => "safe",
44 }
45 }
46}
47
/// Unvalidated filter settings, one field per CLI option.
///
/// Nothing here has been checked yet; [`Filter::resolve`] validates these values
/// and turns them into a concrete tool set.
#[derive(Clone, Debug, Default)]
pub struct RawFilter {
    /// Profile name (`all`, `read` or `safe`); `None` when not given.
    pub profile: Option<String>,
    /// Read-only mode; compatible only with the `read` profile.
    pub read_only: bool,
    /// When given, keep only tools belonging to one of these groups.
    pub groups: Option<Vec<String>>,
    /// When given, drop tools belonging to any of these groups.
    pub exclude_groups: Option<Vec<String>>,
    /// When given, keep only tools with one of these names.
    pub tools: Option<Vec<String>>,
    /// When given, drop tools with any of these names.
    pub exclude_tools: Option<Vec<String>>,
}
58
/// Reasons [`Filter::resolve`] can reject a [`RawFilter`].
#[derive(Debug)]
pub enum FilterError {
    /// The profile name is not one of `all`, `read` or `safe`.
    UnknownProfile { name: String },
    /// A requested group `name` is not in the list of `valid` groups.
    UnknownGroup { name: String, valid: Vec<&'static str> },
    /// A requested tool is not in the catalog; `suggestion` holds the closest known
    /// name when one is near enough.
    UnknownTool { name: String, suggestion: Option<String> },
    /// Read-only mode was combined with a profile other than `read`.
    ConflictingProfile { profile: String },
    /// An explicitly requested tool belongs to a class the active profile rejects.
    ToolExcludedByProfile { tool: String, profile: String },
    /// Every tool was filtered out.
    EmptyFilter,
}
81
82impl std::fmt::Display for FilterError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 FilterError::UnknownProfile { name } => {
86 write!(f, "unknown --profile: {} (valid: all, read, safe)", name)
87 }
88 FilterError::UnknownGroup { name, valid } => {
89 write!(f, "unknown group: {} (valid: {})", name, valid.join(", "))
90 }
91 FilterError::UnknownTool { name, suggestion } => match suggestion {
92 Some(s) => write!(f, "unknown tool: {} (did you mean {}?)", name, s),
93 None => write!(f, "unknown tool: {}", name),
94 },
95 FilterError::ConflictingProfile { profile } => {
96 write!(
97 f,
98 "conflicting profile flags: --read-only and --profile {}",
99 profile
100 )
101 }
102 FilterError::ToolExcludedByProfile { tool, profile } => write!(
103 f,
104 "tool {} is excluded by profile={}; drop --profile or remove {} from --tools",
105 tool, profile, tool
106 ),
107 FilterError::EmptyFilter => {
108 write!(
109 f,
110 "filter pipeline produced an empty tool set; nothing to expose"
111 )
112 }
113 }
114 }
115}
116
117impl std::error::Error for FilterError {}
118
119#[derive(Debug)]
120pub struct Filter {
121 pub profile: Profile,
122 pub groups: Option<Vec<String>>,
123 pub exclude_groups: Option<Vec<String>>,
124 allowed: HashSet<String>,
125}
126
127impl Filter {
128 pub fn allows(&self, tool_name: &str) -> bool {
129 self.allowed.contains(tool_name)
130 }
131
132 pub fn allowed_count(&self) -> usize {
133 self.allowed.len()
134 }
135
136 pub fn resolve(raw: RawFilter) -> Result<Self, FilterError> {
137 let profile = match (raw.profile.as_deref(), raw.read_only) {
139 (None, false) => Profile::All,
140 (None, true) => Profile::Read,
141 (Some(p), false) => {
142 Profile::parse(p).ok_or_else(|| FilterError::UnknownProfile { name: p.into() })?
143 }
144 (Some("read"), true) => Profile::Read,
145 (Some(p), true) => return Err(FilterError::ConflictingProfile { profile: p.into() }),
146 };
147 let profile_label = profile.as_str();
148
149 for groups in [&raw.groups, &raw.exclude_groups].into_iter().flatten() {
151 for g in groups {
152 if !ALL_GROUPS.contains(&g.as_str()) {
153 return Err(FilterError::UnknownGroup {
154 name: g.clone(),
155 valid: ALL_GROUPS.to_vec(),
156 });
157 }
158 }
159 }
160
161 let all_names: Vec<(String, ToolMeta)> = crate::mcp::tool_list()
163 .as_array()
164 .expect("tool_list returns array")
165 .iter()
166 .filter_map(|t| {
167 t.get("name")
168 .and_then(|v| v.as_str())
169 .and_then(|n| classify(n).map(|m| (n.to_string(), m)))
170 })
171 .collect();
172 let known_names: HashSet<&str> = all_names.iter().map(|(n, _)| n.as_str()).collect();
173
174 for tools in [&raw.tools, &raw.exclude_tools].into_iter().flatten() {
176 for t in tools {
177 if !known_names.contains(t.as_str()) {
178 return Err(FilterError::UnknownTool {
179 name: t.clone(),
180 suggestion: closest_name(t, &known_names),
181 });
182 }
183 }
184 }
185
186 if let Some(tools) = &raw.tools {
188 for t in tools {
189 if let Some(meta) = all_names.iter().find(|(n, _)| n == t).map(|(_, m)| m) {
190 if !profile.allows_class(meta.class) {
191 return Err(FilterError::ToolExcludedByProfile {
192 tool: t.clone(),
193 profile: profile_label.into(),
194 });
195 }
196 }
197 }
198 }
199
200 let mut allowed: HashSet<String> = all_names
202 .iter()
203 .filter(|(_, m)| profile.allows_class(m.class))
204 .map(|(n, _)| n.clone())
205 .collect();
206
207 if let Some(groups) = &raw.groups {
208 allowed.retain(|n| {
209 let g = all_names
210 .iter()
211 .find(|(name, _)| name == n)
212 .map(|(_, m)| m.group)
213 .unwrap_or("");
214 groups.iter().any(|wanted| wanted == g)
215 });
216 }
217
218 if let Some(excl) = &raw.exclude_groups {
219 allowed.retain(|n| {
220 let g = all_names
221 .iter()
222 .find(|(name, _)| name == n)
223 .map(|(_, m)| m.group)
224 .unwrap_or("");
225 !excl.iter().any(|bad| bad == g)
226 });
227 }
228
229 if let Some(tools) = &raw.tools {
230 let wanted: HashSet<&str> = tools.iter().map(String::as_str).collect();
231 allowed.retain(|n| wanted.contains(n.as_str()));
232 }
233
234 if let Some(excl) = &raw.exclude_tools {
235 for t in excl {
236 allowed.remove(t);
237 }
238 }
239
240 if allowed.is_empty() {
241 return Err(FilterError::EmptyFilter);
242 }
243
244 Ok(Filter {
245 profile,
246 groups: raw.groups,
247 exclude_groups: raw.exclude_groups,
248 allowed,
249 })
250 }
251}
252
253fn closest_name(needle: &str, haystack: &HashSet<&str>) -> Option<String> {
254 haystack
255 .iter()
256 .map(|c| (c, levenshtein(needle, c)))
257 .min_by_key(|&(_, d)| d)
258 .filter(|&(_, d)| d <= 3)
259 .map(|(s, _)| s.to_string())
260}
261
/// Levenshtein edit distance between `a` and `b`, counted in Unicode scalar values.
///
/// The distance is the minimum number of single-character insertions, deletions
/// and substitutions that turn `a` into `b`. Characters rather than UTF-8 bytes are
/// compared, so a single substituted non-ASCII character costs one edit, not two
/// or more. Uses two rolling rows: O(|a|·|b|) time, O(|b|) extra space.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    if a.is_empty() {
        return b.len();
    }
    if b.is_empty() {
        return a.len();
    }

    // `prev[j]` is the distance between the previous prefix of `a` and `b[..j]`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0usize; b.len() + 1];
    for (i, &ca) in a.iter().enumerate() {
        // Turning `a[..=i]` into the empty string takes i + 1 deletions.
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitution = usize::from(ca != cb);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + substitution);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}