claude_agent/tools/search/
engine.rs

1//! Search engine implementations for tool discovery.
2
3use super::index::{ToolIndex, ToolIndexEntry};
4
/// Strategy a [`SearchEngine`] uses to match a query against the tool index.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SearchMode {
    /// Interpret the query as a regular expression; every matching tool scores `1.0`.
    /// This is the default mode.
    #[default]
    Regex,
    /// Rank tools by BM25 relevance to the whitespace-separated query terms.
    Bm25,
}
11
12#[derive(Debug, Clone)]
13pub struct SearchHit {
14    pub entry: ToolIndexEntry,
15    pub score: f64,
16}
17
18pub struct SearchEngine {
19    mode: SearchMode,
20}
21
22impl SearchEngine {
23    pub fn new(mode: SearchMode) -> Self {
24        Self { mode }
25    }
26
27    pub fn regex() -> Self {
28        Self::new(SearchMode::Regex)
29    }
30
31    pub fn bm25() -> Self {
32        Self::new(SearchMode::Bm25)
33    }
34
35    pub fn mode(&self) -> SearchMode {
36        self.mode
37    }
38
39    pub fn search(&self, index: &ToolIndex, query: &str, limit: usize) -> Vec<SearchHit> {
40        if query.is_empty() || index.is_empty() {
41            return Vec::new();
42        }
43
44        match self.mode {
45            SearchMode::Regex => self.search_regex(index, query, limit),
46            SearchMode::Bm25 => self.search_bm25(index, query, limit),
47        }
48    }
49
50    fn search_regex(&self, index: &ToolIndex, pattern: &str, limit: usize) -> Vec<SearchHit> {
51        let regex = match regex::Regex::new(pattern) {
52            Ok(r) => r,
53            Err(_) => return Vec::new(),
54        };
55
56        let mut hits: Vec<SearchHit> = index
57            .entries()
58            .iter()
59            .filter_map(|entry| {
60                let text = entry.searchable_text();
61                if regex.is_match(&text) {
62                    Some(SearchHit {
63                        entry: entry.clone(),
64                        score: 1.0,
65                    })
66                } else {
67                    None
68                }
69            })
70            .collect();
71
72        hits.truncate(limit);
73        hits
74    }
75
76    fn search_bm25(&self, index: &ToolIndex, query: &str, limit: usize) -> Vec<SearchHit> {
77        let query_terms: Vec<&str> = query.split_whitespace().collect();
78        if query_terms.is_empty() {
79            return Vec::new();
80        }
81
82        let avg_doc_len = index
83            .entries()
84            .iter()
85            .map(|e| e.searchable_text().split_whitespace().count())
86            .sum::<usize>() as f64
87            / index.len().max(1) as f64;
88
89        let mut hits: Vec<SearchHit> = index
90            .entries()
91            .iter()
92            .map(|entry| {
93                let score = self.bm25_score(&entry.searchable_text(), &query_terms, avg_doc_len);
94                SearchHit {
95                    entry: entry.clone(),
96                    score,
97                }
98            })
99            .filter(|hit| hit.score > 0.0)
100            .collect();
101
102        hits.sort_by(|a, b| {
103            b.score
104                .partial_cmp(&a.score)
105                .unwrap_or(std::cmp::Ordering::Equal)
106        });
107        hits.truncate(limit);
108        hits
109    }
110
111    fn bm25_score(&self, text: &str, query_terms: &[&str], avg_doc_len: f64) -> f64 {
112        const K1: f64 = 1.2;
113        const B: f64 = 0.75;
114
115        let text_lower = text.to_lowercase();
116        let words: Vec<&str> = text_lower.split_whitespace().collect();
117        let doc_len = words.len() as f64;
118
119        let mut score = 0.0;
120        for term in query_terms {
121            let term_lower = term.to_lowercase();
122            let tf = words
123                .iter()
124                .filter(|w| w.contains(term_lower.as_str()))
125                .count() as f64;
126
127            if tf > 0.0 {
128                let idf = 1.0; // Simplified IDF
129                let numerator = tf * (K1 + 1.0);
130                let denominator = tf + K1 * (1.0 - B + B * (doc_len / avg_doc_len.max(1.0)));
131                score += idf * (numerator / denominator);
132            }
133        }
134
135        score
136    }
137}
138
139impl Default for SearchEngine {
140    fn default() -> Self {
141        Self::regex()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::McpToolDefinition;

    /// Builds a small index of five tools spread across three servers.
    fn make_index() -> ToolIndex {
        let mut index = ToolIndex::new();

        let tools = [
            (
                "weather",
                "get_weather",
                "Get current weather for a location",
            ),
            ("weather", "get_forecast", "Get weather forecast for days"),
            ("database", "query", "Execute database query"),
            ("database", "insert", "Insert data into database"),
            ("files", "read_file", "Read file contents"),
        ];

        for (server, name, desc) in tools {
            let tool = McpToolDefinition {
                name: name.to_string(),
                description: desc.to_string(),
                input_schema: serde_json::json!({"type": "object"}),
            };
            index.add(ToolIndexEntry::from_mcp_tool(server, &tool));
        }

        index
    }

    #[test]
    fn test_regex_search_simple() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "weather", 5);
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_regex_search_pattern() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "get_.*", 5);
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_regex_search_respects_limit() {
        let engine = SearchEngine::regex();
        let index = make_index();

        // "." matches every non-empty entry, so only `limit` bounds the result.
        let hits = engine.search(&index, ".", 3);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_zero_limit() {
        let index = make_index();

        assert!(SearchEngine::regex().search(&index, "weather", 0).is_empty());
        assert!(SearchEngine::bm25().search(&index, "weather", 0).is_empty());
    }

    #[test]
    fn test_bm25_search() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let hits = engine.search(&index, "weather location", 5);
        assert!(!hits.is_empty());
        assert!(hits[0].entry.tool_name.contains("weather"));
    }

    #[test]
    fn test_bm25_results_sorted_by_score() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let hits = engine.search(&index, "get database weather", 5);
        assert!(!hits.is_empty());
        assert!(hits.iter().all(|hit| hit.score > 0.0));
        assert!(hits.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_bm25_is_case_insensitive() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let lower: Vec<f64> = engine
            .search(&index, "weather", 5)
            .iter()
            .map(|hit| hit.score)
            .collect();
        let upper: Vec<f64> = engine
            .search(&index, "WEATHER", 5)
            .iter()
            .map(|hit| hit.score)
            .collect();
        assert!(!lower.is_empty());
        assert_eq!(lower, upper);
    }

    #[test]
    fn test_bm25_no_match() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        assert!(engine.search(&index, "zzzzzz", 5).is_empty());
    }

    #[test]
    fn test_empty_query() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "", 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_invalid_regex() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "[invalid", 5);
        assert!(hits.is_empty());
    }
}