claude_agent/tools/search/
engine.rs1use super::index::{ToolIndex, ToolIndexEntry};
4
/// Strategy a [`SearchEngine`] uses to match a query against the tool index.
///
/// The default mode is [`SearchMode::Regex`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SearchMode {
    /// Interpret the query as a regular expression and report every entry
    /// whose searchable text matches it.
    #[default]
    Regex,
    /// Split the query into whitespace-separated terms and rank entries with
    /// a BM25-style relevance score.
    Bm25,
}
11
12#[derive(Debug, Clone)]
13pub struct SearchHit {
14 pub entry: ToolIndexEntry,
15 pub score: f64,
16}
17
18pub struct SearchEngine {
19 mode: SearchMode,
20}
21
22impl SearchEngine {
23 pub fn new(mode: SearchMode) -> Self {
24 Self { mode }
25 }
26
27 pub fn regex() -> Self {
28 Self::new(SearchMode::Regex)
29 }
30
31 pub fn bm25() -> Self {
32 Self::new(SearchMode::Bm25)
33 }
34
35 pub fn mode(&self) -> SearchMode {
36 self.mode
37 }
38
39 pub fn search(&self, index: &ToolIndex, query: &str, limit: usize) -> Vec<SearchHit> {
40 if query.is_empty() || index.is_empty() {
41 return Vec::new();
42 }
43
44 match self.mode {
45 SearchMode::Regex => self.search_regex(index, query, limit),
46 SearchMode::Bm25 => self.search_bm25(index, query, limit),
47 }
48 }
49
50 fn search_regex(&self, index: &ToolIndex, pattern: &str, limit: usize) -> Vec<SearchHit> {
51 let regex = match regex::Regex::new(pattern) {
52 Ok(r) => r,
53 Err(_) => return Vec::new(),
54 };
55
56 let mut hits: Vec<SearchHit> = index
57 .entries()
58 .iter()
59 .filter_map(|entry| {
60 let text = entry.searchable_text();
61 if regex.is_match(&text) {
62 Some(SearchHit {
63 entry: entry.clone(),
64 score: 1.0,
65 })
66 } else {
67 None
68 }
69 })
70 .collect();
71
72 hits.truncate(limit);
73 hits
74 }
75
76 fn search_bm25(&self, index: &ToolIndex, query: &str, limit: usize) -> Vec<SearchHit> {
77 let query_terms: Vec<&str> = query.split_whitespace().collect();
78 if query_terms.is_empty() {
79 return Vec::new();
80 }
81
82 let avg_doc_len = index
83 .entries()
84 .iter()
85 .map(|e| e.searchable_text().split_whitespace().count())
86 .sum::<usize>() as f64
87 / index.len().max(1) as f64;
88
89 let mut hits: Vec<SearchHit> = index
90 .entries()
91 .iter()
92 .map(|entry| {
93 let score = self.bm25_score(&entry.searchable_text(), &query_terms, avg_doc_len);
94 SearchHit {
95 entry: entry.clone(),
96 score,
97 }
98 })
99 .filter(|hit| hit.score > 0.0)
100 .collect();
101
102 hits.sort_by(|a, b| {
103 b.score
104 .partial_cmp(&a.score)
105 .unwrap_or(std::cmp::Ordering::Equal)
106 });
107 hits.truncate(limit);
108 hits
109 }
110
111 fn bm25_score(&self, text: &str, query_terms: &[&str], avg_doc_len: f64) -> f64 {
112 const K1: f64 = 1.2;
113 const B: f64 = 0.75;
114
115 let text_lower = text.to_lowercase();
116 let words: Vec<&str> = text_lower.split_whitespace().collect();
117 let doc_len = words.len() as f64;
118
119 let mut score = 0.0;
120 for term in query_terms {
121 let term_lower = term.to_lowercase();
122 let tf = words
123 .iter()
124 .filter(|w| w.contains(term_lower.as_str()))
125 .count() as f64;
126
127 if tf > 0.0 {
128 let idf = 1.0; let numerator = tf * (K1 + 1.0);
130 let denominator = tf + K1 * (1.0 - B + B * (doc_len / avg_doc_len.max(1.0)));
131 score += idf * (numerator / denominator);
132 }
133 }
134
135 score
136 }
137}
138
139impl Default for SearchEngine {
140 fn default() -> Self {
141 Self::regex()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::McpToolDefinition;

    /// Builds an index of five tools spread over three servers.
    fn make_index() -> ToolIndex {
        let mut index = ToolIndex::new();

        let tools = [
            (
                "weather",
                "get_weather",
                "Get current weather for a location",
            ),
            ("weather", "get_forecast", "Get weather forecast for days"),
            ("database", "query", "Execute database query"),
            ("database", "insert", "Insert data into database"),
            ("files", "read_file", "Read file contents"),
        ];

        for (server, name, desc) in tools {
            let tool = McpToolDefinition {
                name: name.to_string(),
                description: desc.to_string(),
                input_schema: serde_json::json!({"type": "object"}),
            };
            // `ToolIndexEntry` is already in scope through `use super::*`.
            index.add(ToolIndexEntry::from_mcp_tool(server, &tool));
        }

        index
    }

    #[test]
    fn test_regex_search_simple() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "weather", 5);
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_regex_search_pattern() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "get_.*", 5);
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_regex_search_respects_limit() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "weather", 1);
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_bm25_search() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let hits = engine.search(&index, "weather location", 5);
        assert!(!hits.is_empty());
        assert!(hits[0].entry.tool_name.contains("weather"));
    }

    #[test]
    fn test_bm25_scores_are_positive_and_descending() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let hits = engine.search(&index, "weather location", 5);
        assert!(hits.iter().all(|hit| hit.score > 0.0));
        assert!(hits.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_bm25_is_case_insensitive() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let lower = engine.search(&index, "weather", 5);
        let upper = engine.search(&index, "WEATHER", 5);
        assert!(!lower.is_empty());
        assert_eq!(lower.len(), upper.len());
    }

    #[test]
    fn test_bm25_no_match() {
        let engine = SearchEngine::bm25();
        let index = make_index();

        let hits = engine.search(&index, "nonexistent_term_xyz", 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_zero_limit() {
        let index = make_index();

        assert!(SearchEngine::regex().search(&index, "weather", 0).is_empty());
        assert!(SearchEngine::bm25().search(&index, "weather", 0).is_empty());
    }

    #[test]
    fn test_empty_query() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "", 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_invalid_regex() {
        let engine = SearchEngine::regex();
        let index = make_index();

        let hits = engine.search(&index, "[invalid", 5);
        assert!(hits.is_empty());
    }
}