Skip to main content

a3s_code_core/
tool_search.rs

//! Tool Search — keyword-based tool matching for dynamic MCP tool loading
//!
//! When the MCP ecosystem grows large (100+ tools), injecting every tool
//! description into the system prompt wastes context. Tool Search selects
//! only the tools relevant to each turn using keyword and substring matching.
//!
//! ## How It Works
//!
//! 1. All registered tools (builtin + MCP) are indexed by keywords drawn from
//!    their name, their description, and any extra keywords the caller supplies.
//! 2. Before each LLM call, the user prompt is matched against the index.
//! 3. Only tools scoring at or above the threshold are included in the request.
//! 4. Builtin tools (any tool whose name does not start with `mcp__`) bypass the
//!    threshold when `always_include_builtins` is set (the default), although
//!    they still count toward the result limit.
//!
//! ## Usage
//!
//! ```rust,no_run
//! use a3s_code_core::tool_search::{ToolIndex, ToolSearchConfig};
//!
//! let config = ToolSearchConfig::default();
//! let mut index = ToolIndex::new(config);
//!
//! // Index tools
//! index.add("mcp__github__create_issue", "Create a GitHub issue", &["github", "issue", "bug"]);
//! index.add("mcp__postgres__query", "Run a SQL query", &["sql", "database", "query"]);
//!
//! // Search
//! let matches = index.search("create a bug report on GitHub", 5);
//! // matches[0].name == "mcp__github__create_issue"
//! ```
31
32use std::collections::HashMap;
33
/// Configuration for tool search behavior.
///
/// Start from [`ToolSearchConfig::default`] and override individual fields
/// with struct update syntax.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolSearchConfig {
    /// Minimum relevance score (0.0–1.0, inclusive) a tool must reach to be
    /// included.
    /// Default: 0.3
    pub threshold: f32,
    /// Maximum number of tools returned by a single search, builtins
    /// included; the effective limit is the smaller of this value and the
    /// caller's `max_results`. Not applied when search is disabled or the
    /// query is empty.
    /// Default: 20
    pub max_tools: usize,
    /// Include builtin tools regardless of their score. They still count
    /// toward the result limit, so low-scoring builtins can be truncated.
    /// Default: true
    pub always_include_builtins: bool,
    /// Enable tool search. When false, every indexed tool is returned
    /// (legacy behavior) and neither the threshold nor any limit applies.
    /// Default: true
    pub enabled: bool,
}
50
51impl Default for ToolSearchConfig {
52    fn default() -> Self {
53        Self {
54            threshold: 0.3,
55            max_tools: 20,
56            always_include_builtins: true,
57            enabled: true,
58        }
59    }
60}
61
/// One tool as stored in the index.
#[derive(Debug, Clone)]
struct ToolEntry {
    /// Full tool name, e.g. "mcp__github__create_issue".
    name: String,
    /// Original description text. The current keyword scoring never reads
    /// it; it is kept for future semantic search, hence the `dead_code`
    /// allowance.
    #[allow(dead_code)]
    description: String,
    /// Lowercased search terms derived from the name, the description, and
    /// any caller-supplied keywords.
    keywords: Vec<String>,
    /// True for builtin tools (names not starting with `mcp__`).
    is_builtin: bool,
}
75
/// A single tool returned by [`ToolIndex::search`], together with its score.
#[derive(Debug, Clone)]
pub struct ToolMatch {
    /// Name of the matched tool.
    pub name: String,
    /// Relevance score in the range 0.0–1.0; reported as 1.0 when search is
    /// disabled or the query is empty.
    pub score: f32,
    /// True when the matched tool is a builtin rather than an MCP tool.
    pub is_builtin: bool,
}
86
87/// Index of all registered tools for semantic search.
88#[derive(Clone)]
89pub struct ToolIndex {
90    config: ToolSearchConfig,
91    entries: HashMap<String, ToolEntry>,
92}
93
94impl ToolIndex {
95    /// Create a new tool index with the given configuration.
96    pub fn new(config: ToolSearchConfig) -> Self {
97        Self {
98            config,
99            entries: HashMap::new(),
100        }
101    }
102
103    /// Add a tool to the index.
104    pub fn add(&mut self, name: &str, description: &str, extra_keywords: &[&str]) {
105        let is_builtin = !name.starts_with("mcp__");
106
107        // Extract keywords from name and description
108        let mut keywords: Vec<String> = Vec::new();
109
110        // Split tool name by underscores and double-underscores
111        for part in name.split("__").flat_map(|s| s.split('_')) {
112            if part.len() >= 2 {
113                keywords.push(part.to_lowercase());
114            }
115        }
116
117        // Extract words from description
118        for word in description.split_whitespace() {
119            let clean = word
120                .trim_matches(|c: char| !c.is_alphanumeric())
121                .to_lowercase();
122            if clean.len() >= 3 {
123                keywords.push(clean);
124            }
125        }
126
127        // Add extra keywords
128        for kw in extra_keywords {
129            keywords.push(kw.to_lowercase());
130        }
131
132        self.entries.insert(
133            name.to_string(),
134            ToolEntry {
135                name: name.to_string(),
136                description: description.to_string(),
137                keywords,
138                is_builtin,
139            },
140        );
141    }
142
143    /// Remove a tool from the index.
144    pub fn remove(&mut self, name: &str) -> bool {
145        self.entries.remove(name).is_some()
146    }
147
148    /// Search for tools relevant to the given query.
149    ///
150    /// Returns tools sorted by relevance score (highest first),
151    /// limited to `max_results` entries.
152    pub fn search(&self, query: &str, max_results: usize) -> Vec<ToolMatch> {
153        if !self.config.enabled {
154            // Return all tools when search is disabled
155            return self
156                .entries
157                .values()
158                .map(|e| ToolMatch {
159                    name: e.name.clone(),
160                    score: 1.0,
161                    is_builtin: e.is_builtin,
162                })
163                .collect();
164        }
165
166        let query_tokens = tokenize(query);
167        if query_tokens.is_empty() {
168            // Empty query: return builtins only
169            return self
170                .entries
171                .values()
172                .filter(|e| e.is_builtin)
173                .map(|e| ToolMatch {
174                    name: e.name.clone(),
175                    score: 1.0,
176                    is_builtin: true,
177                })
178                .collect();
179        }
180
181        let mut matches: Vec<ToolMatch> = self
182            .entries
183            .values()
184            .map(|entry| {
185                let score = compute_relevance(&query_tokens, entry);
186                ToolMatch {
187                    name: entry.name.clone(),
188                    score,
189                    is_builtin: entry.is_builtin,
190                }
191            })
192            .filter(|m| {
193                if self.config.always_include_builtins && m.is_builtin {
194                    true
195                } else {
196                    m.score >= self.config.threshold
197                }
198            })
199            .collect();
200
201        // Sort by score descending
202        matches.sort_by(|a, b| {
203            b.score
204                .partial_cmp(&a.score)
205                .unwrap_or(std::cmp::Ordering::Equal)
206        });
207
208        // Limit results
209        let limit = max_results.min(self.config.max_tools);
210        matches.truncate(limit);
211
212        matches
213    }
214
215    /// Number of indexed tools.
216    pub fn len(&self) -> usize {
217        self.entries.len()
218    }
219
220    /// Whether the index is empty.
221    pub fn is_empty(&self) -> bool {
222        self.entries.is_empty()
223    }
224
225    /// Get all tool names in the index.
226    pub fn tool_names(&self) -> Vec<&str> {
227        self.entries.keys().map(|s| s.as_str()).collect()
228    }
229}
230
/// Split `text` on whitespace into lowercase tokens.
///
/// Each word has leading and trailing non-alphanumeric characters stripped
/// (interior punctuation such as `-` is kept). It is discarded if the
/// lowercased result is shorter than two bytes.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for raw in text.split_whitespace() {
        let cleaned = raw.trim_matches(|c: char| !c.is_alphanumeric());
        let token = cleaned.to_lowercase();
        if token.len() < 2 {
            continue;
        }
        tokens.push(token);
    }
    tokens
}
241
242/// Compute relevance score between query tokens and a tool entry.
243fn compute_relevance(query_tokens: &[String], entry: &ToolEntry) -> f32 {
244    if query_tokens.is_empty() || entry.keywords.is_empty() {
245        return 0.0;
246    }
247
248    let mut matched = 0u32;
249    let mut partial = 0u32;
250
251    for qt in query_tokens {
252        // Exact keyword match
253        if entry.keywords.iter().any(|kw| kw == qt) {
254            matched += 2;
255        }
256        // Substring match (query token contained in keyword or vice versa)
257        // or tool name contains the token
258        else if entry
259            .keywords
260            .iter()
261            .any(|kw| kw.contains(qt.as_str()) || qt.contains(kw.as_str()))
262            || entry.name.to_lowercase().contains(qt.as_str())
263        {
264            partial += 1;
265        }
266    }
267
268    let total_score = (matched as f32 * 1.0) + (partial as f32 * 0.5);
269    let max_possible = query_tokens.len() as f32 * 2.0;
270
271    (total_score / max_possible).min(1.0)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Index with six builtins and five MCP tools, using the default config.
    fn build_index() -> ToolIndex {
        let catalog = vec![
            // Builtins
            ("bash", "Execute shell commands", vec!["shell", "terminal", "run"]),
            ("read", "Read file contents", vec!["file", "open", "cat"]),
            ("write", "Write content to a file", vec!["file", "save", "create"]),
            (
                "edit",
                "Edit a file with search and replace",
                vec!["modify", "change", "replace"],
            ),
            ("grep", "Search file contents", vec!["search", "find", "pattern"]),
            ("glob", "Find files by pattern", vec!["find", "files", "match"]),
            // MCP tools
            (
                "mcp__github__create_issue",
                "Create a GitHub issue",
                vec!["github", "issue", "bug", "ticket"],
            ),
            (
                "mcp__github__list_prs",
                "List pull requests",
                vec!["github", "pull", "request", "pr"],
            ),
            (
                "mcp__postgres__query",
                "Execute a SQL query against PostgreSQL",
                vec!["sql", "database", "postgres", "db"],
            ),
            (
                "mcp__fetch__fetch",
                "Fetch a URL and return its content",
                vec!["http", "url", "web", "download"],
            ),
            (
                "mcp__sentry__get_issues",
                "Get issues from Sentry",
                vec!["sentry", "error", "monitoring", "crash"],
            ),
        ];

        let mut index = ToolIndex::new(ToolSearchConfig::default());
        for (name, description, keywords) in catalog {
            index.add(name, description, &keywords);
        }
        index
    }

    /// Names of all matches, in result order.
    fn names(matches: &[ToolMatch]) -> Vec<&str> {
        matches.iter().map(|m| m.name.as_str()).collect()
    }

    /// Names of the matches whose builtin flag equals `builtin`, in result order.
    fn names_by_kind(matches: &[ToolMatch], builtin: bool) -> Vec<&str> {
        matches
            .iter()
            .filter(|m| m.is_builtin == builtin)
            .map(|m| m.name.as_str())
            .collect()
    }

    #[test]
    fn test_search_github() {
        let found = build_index().search("create a bug report on GitHub", 10);
        assert!(names(&found).contains(&"mcp__github__create_issue"));
    }

    #[test]
    fn test_search_database() {
        let found = build_index().search("run a SQL query on the database", 10);
        assert!(names(&found).contains(&"mcp__postgres__query"));
    }

    #[test]
    fn test_search_web() {
        let found = build_index().search("fetch the URL content", 10);
        assert!(names_by_kind(&found, false).contains(&"mcp__fetch__fetch"));
    }

    #[test]
    fn test_builtins_always_included() {
        let found = build_index().search("create a GitHub issue", 20);
        let builtins = names_by_kind(&found, true);
        // All builtins should be present
        for expected in &["bash", "read"] {
            assert!(builtins.contains(expected), "missing builtin {}", expected);
        }
    }

    #[test]
    fn test_empty_query_returns_builtins() {
        let found = build_index().search("", 20);
        assert!(!found.iter().any(|m| !m.is_builtin));
    }

    #[test]
    fn test_disabled_returns_all() {
        let mut index = build_index();
        index.config.enabled = false;
        let total = index.len();
        assert_eq!(index.search("anything", 100).len(), total);
    }

    #[test]
    fn test_max_results_limit() {
        let found = build_index().search("file search", 3);
        assert!(found.len() <= 3);
    }

    #[test]
    fn test_remove_tool() {
        let mut index = build_index();
        let count_before = index.len();
        let removed = index.remove("mcp__sentry__get_issues");
        assert!(removed);
        assert_eq!(index.len() + 1, count_before);
        assert!(!index.remove("nonexistent"));
    }

    #[test]
    fn test_threshold_filtering() {
        let mut index = ToolIndex::new(ToolSearchConfig {
            threshold: 0.9,
            always_include_builtins: false,
            ..ToolSearchConfig::default()
        });
        index.add("mcp__foo__bar", "Completely unrelated tool", &["xyz"]);
        // High threshold + unrelated tool = no matches
        assert!(index.search("github issue", 10).is_empty());
    }
}