1use std::collections::HashMap;
33
/// Tuning knobs for [`ToolIndex::search`].
///
/// See the `Default` impl for the stock values.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolSearchConfig {
    /// Minimum relevance score a tool needs to appear in a ranked search.
    ///
    /// Relevance scores lie in `0.0..=1.0`. Builtin tools skip this check when
    /// [`always_include_builtins`](Self::always_include_builtins) is set.
    pub threshold: f32,
    /// Cap on the number of results of a ranked search. The effective limit is
    /// the smaller of this value and the caller's `max_results`. It is not
    /// applied when searching is disabled or the query has no usable tokens.
    pub max_tools: usize,
    /// When `true`, builtin (non-`mcp__`) tools bypass `threshold` in a ranked
    /// search. They are still subject to the result limit.
    pub always_include_builtins: bool,
    /// When `false`, scoring is skipped and every indexed tool is returned
    /// with a score of `1.0`.
    pub enabled: bool,
}
50
51impl Default for ToolSearchConfig {
52 fn default() -> Self {
53 Self {
54 threshold: 0.3,
55 max_tools: 20,
56 always_include_builtins: true,
57 enabled: true,
58 }
59 }
60}
61
/// One registered tool plus the lowercase keywords it is scored against.
#[derive(Clone, Debug)]
struct ToolEntry {
    /// Full tool name exactly as registered; also its key in the index map.
    name: String,
    /// Original description text. The index stores it but the search code
    /// never reads it back (its words are copied into `keywords` instead).
    /// Derived impls do not count as uses for the `dead_code` lint, hence the
    /// allowance.
    #[allow(dead_code)]
    description: String,
    /// Lowercase keywords compared with query tokens during scoring.
    keywords: Vec<String>,
    /// `true` unless the tool name starts with `mcp__`.
    is_builtin: bool,
}
75
/// A single result returned by [`ToolIndex::search`].
#[derive(Debug, Clone, PartialEq)]
pub struct ToolMatch {
    /// Name of the matching tool, exactly as it was registered.
    pub name: String,
    /// Relevance in `0.0..=1.0`; higher is better. Unranked result sets
    /// (searching disabled, or a query with no usable tokens) report `1.0`.
    pub score: f32,
    /// Whether the tool is a builtin, i.e. its name does not start with `mcp__`.
    pub is_builtin: bool,
}
86
87#[derive(Clone)]
89pub struct ToolIndex {
90 config: ToolSearchConfig,
91 entries: HashMap<String, ToolEntry>,
92}
93
94impl ToolIndex {
95 pub fn new(config: ToolSearchConfig) -> Self {
97 Self {
98 config,
99 entries: HashMap::new(),
100 }
101 }
102
103 pub fn add(&mut self, name: &str, description: &str, extra_keywords: &[&str]) {
105 let is_builtin = !name.starts_with("mcp__");
106
107 let mut keywords: Vec<String> = Vec::new();
109
110 for part in name.split("__").flat_map(|s| s.split('_')) {
112 if part.len() >= 2 {
113 keywords.push(part.to_lowercase());
114 }
115 }
116
117 for word in description.split_whitespace() {
119 let clean = word
120 .trim_matches(|c: char| !c.is_alphanumeric())
121 .to_lowercase();
122 if clean.len() >= 3 {
123 keywords.push(clean);
124 }
125 }
126
127 for kw in extra_keywords {
129 keywords.push(kw.to_lowercase());
130 }
131
132 self.entries.insert(
133 name.to_string(),
134 ToolEntry {
135 name: name.to_string(),
136 description: description.to_string(),
137 keywords,
138 is_builtin,
139 },
140 );
141 }
142
143 pub fn remove(&mut self, name: &str) -> bool {
145 self.entries.remove(name).is_some()
146 }
147
148 pub fn search(&self, query: &str, max_results: usize) -> Vec<ToolMatch> {
153 if !self.config.enabled {
154 return self
156 .entries
157 .values()
158 .map(|e| ToolMatch {
159 name: e.name.clone(),
160 score: 1.0,
161 is_builtin: e.is_builtin,
162 })
163 .collect();
164 }
165
166 let query_tokens = tokenize(query);
167 if query_tokens.is_empty() {
168 return self
170 .entries
171 .values()
172 .filter(|e| e.is_builtin)
173 .map(|e| ToolMatch {
174 name: e.name.clone(),
175 score: 1.0,
176 is_builtin: true,
177 })
178 .collect();
179 }
180
181 let mut matches: Vec<ToolMatch> = self
182 .entries
183 .values()
184 .map(|entry| {
185 let score = compute_relevance(&query_tokens, entry);
186 ToolMatch {
187 name: entry.name.clone(),
188 score,
189 is_builtin: entry.is_builtin,
190 }
191 })
192 .filter(|m| {
193 if self.config.always_include_builtins && m.is_builtin {
194 true
195 } else {
196 m.score >= self.config.threshold
197 }
198 })
199 .collect();
200
201 matches.sort_by(|a, b| {
203 b.score
204 .partial_cmp(&a.score)
205 .unwrap_or(std::cmp::Ordering::Equal)
206 });
207
208 let limit = max_results.min(self.config.max_tools);
210 matches.truncate(limit);
211
212 matches
213 }
214
215 pub fn len(&self) -> usize {
217 self.entries.len()
218 }
219
220 pub fn is_empty(&self) -> bool {
222 self.entries.is_empty()
223 }
224
225 pub fn tool_names(&self) -> Vec<&str> {
227 self.entries.keys().map(|s| s.as_str()).collect()
228 }
229}
230
/// Splits free text into lowercase search tokens.
///
/// Every non-alphanumeric character acts as a separator. As a result,
/// `"create_issue"`, `"pull-request"` and `"read/write"` each yield their
/// component words, consistent with tool names being split on `_` when
/// indexed. Tokens shorter than 2 bytes (e.g. `"a"`) are dropped.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        .map(str::to_lowercase)
        .filter(|word| word.len() >= 2)
        .collect()
}
241
242fn compute_relevance(query_tokens: &[String], entry: &ToolEntry) -> f32 {
244 if query_tokens.is_empty() || entry.keywords.is_empty() {
245 return 0.0;
246 }
247
248 let mut matched = 0u32;
249 let mut partial = 0u32;
250
251 for qt in query_tokens {
252 if entry.keywords.iter().any(|kw| kw == qt) {
254 matched += 2;
255 }
256 else if entry
259 .keywords
260 .iter()
261 .any(|kw| kw.contains(qt.as_str()) || qt.contains(kw.as_str()))
262 || entry.name.to_lowercase().contains(qt.as_str())
263 {
264 partial += 1;
265 }
266 }
267
268 let total_score = (matched as f32 * 1.0) + (partial as f32 * 0.5);
269 let max_possible = query_tokens.len() as f32 * 2.0;
270
271 (total_score / max_possible).min(1.0)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Index with six builtin tools and five MCP tools, registered in a fixed order.
    fn build_index() -> ToolIndex {
        let tools: &[(&str, &str, &[&str])] = &[
            ("bash", "Execute shell commands", &["shell", "terminal", "run"]),
            ("read", "Read file contents", &["file", "open", "cat"]),
            ("write", "Write content to a file", &["file", "save", "create"]),
            (
                "edit",
                "Edit a file with search and replace",
                &["modify", "change", "replace"],
            ),
            ("grep", "Search file contents", &["search", "find", "pattern"]),
            ("glob", "Find files by pattern", &["find", "files", "match"]),
            (
                "mcp__github__create_issue",
                "Create a GitHub issue",
                &["github", "issue", "bug", "ticket"],
            ),
            (
                "mcp__github__list_prs",
                "List pull requests",
                &["github", "pull", "request", "pr"],
            ),
            (
                "mcp__postgres__query",
                "Execute a SQL query against PostgreSQL",
                &["sql", "database", "postgres", "db"],
            ),
            (
                "mcp__fetch__fetch",
                "Fetch a URL and return its content",
                &["http", "url", "web", "download"],
            ),
            (
                "mcp__sentry__get_issues",
                "Get issues from Sentry",
                &["sentry", "error", "monitoring", "crash"],
            ),
        ];

        let mut index = ToolIndex::new(ToolSearchConfig::default());
        for &(name, description, keywords) in tools {
            index.add(name, description, keywords);
        }
        index
    }

    /// Names of the returned matches, in result order.
    fn names(matches: &[ToolMatch]) -> Vec<&str> {
        matches.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn test_search_github() {
        let found = build_index().search("create a bug report on GitHub", 10);
        assert!(names(&found).contains(&"mcp__github__create_issue"));
    }

    #[test]
    fn test_search_database() {
        let found = build_index().search("run a SQL query on the database", 10);
        assert!(names(&found).contains(&"mcp__postgres__query"));
    }

    #[test]
    fn test_search_web() {
        let found = build_index().search("fetch the URL content", 10);
        assert!(found
            .iter()
            .any(|m| !m.is_builtin && m.name == "mcp__fetch__fetch"));
    }

    #[test]
    fn test_builtins_always_included() {
        let found = build_index().search("create a GitHub issue", 20);
        for builtin in &["bash", "read"] {
            assert!(found.iter().any(|m| m.is_builtin && m.name == *builtin));
        }
    }

    #[test]
    fn test_empty_query_returns_builtins() {
        let found = build_index().search("", 20);
        assert!(!found.iter().any(|m| !m.is_builtin));
    }

    #[test]
    fn test_disabled_returns_all() {
        let mut index = build_index();
        index.config.enabled = false;
        let total = index.len();
        assert_eq!(index.search("anything", 100).len(), total);
    }

    #[test]
    fn test_max_results_limit() {
        let found = build_index().search("file search", 3);
        assert!(found.len() <= 3);
    }

    #[test]
    fn test_remove_tool() {
        let mut index = build_index();
        let before = index.len();
        assert!(index.remove("mcp__sentry__get_issues"));
        assert_eq!(index.len() + 1, before);
        assert!(!index.remove("nonexistent"));
    }

    #[test]
    fn test_threshold_filtering() {
        let mut index = ToolIndex::new(ToolSearchConfig {
            threshold: 0.9,
            always_include_builtins: false,
            ..Default::default()
        });
        index.add("mcp__foo__bar", "Completely unrelated tool", &["xyz"]);
        assert!(index.search("github issue", 10).is_empty());
    }
}