ricecoder_research/
search_engine.rs1use crate::models::{SearchResult, SymbolKind};
4use crate::semantic_index::SemanticIndex;
5use regex::Regex;
6use std::collections::HashMap;
7
8pub struct SearchEngine {
10 index: SemanticIndex,
11}
12
13#[derive(Debug, Clone)]
15pub struct SearchOptions {
16 pub max_results: usize,
18 pub kind_filter: Option<SymbolKind>,
20 pub case_sensitive: bool,
22}
23
24impl Default for SearchOptions {
25 fn default() -> Self {
26 SearchOptions {
27 max_results: 100,
28 kind_filter: None,
29 case_sensitive: false,
30 }
31 }
32}
33
34impl SearchEngine {
35 pub fn new(index: SemanticIndex) -> Self {
37 SearchEngine { index }
38 }
39
40 pub fn search_by_name(&self, query: &str, options: &SearchOptions) -> Vec<SearchResult> {
42 let mut results = self.index.search_by_name(query);
43
44 if let Some(kind) = options.kind_filter {
46 results.retain(|r| r.symbol.kind == kind);
47 }
48
49 results.truncate(options.max_results);
51
52 results
53 }
54
55 pub fn search_by_kind(&self, kind: SymbolKind, options: &SearchOptions) -> Vec<SearchResult> {
57 let symbols = self.index.search_by_kind(kind);
58
59 let mut results: Vec<SearchResult> = symbols
60 .into_iter()
61 .map(|symbol| SearchResult {
62 symbol: symbol.clone(),
63 relevance: 1.0,
64 context: None,
65 })
66 .collect();
67
68 results.truncate(options.max_results);
70
71 results
72 }
73
74 pub fn full_text_search(&self, query: &str, options: &SearchOptions) -> Vec<SearchResult> {
76 let query_lower = if options.case_sensitive {
77 query.to_string()
78 } else {
79 query.to_lowercase()
80 };
81
82 let mut results = Vec::new();
83
84 for symbol in self.index.all_symbols() {
85 let symbol_name = if options.case_sensitive {
87 symbol.name.clone()
88 } else {
89 symbol.name.to_lowercase()
90 };
91
92 if symbol_name.contains(&query_lower) {
93 let relevance = if symbol_name == query_lower {
95 1.0
96 } else if symbol_name.starts_with(&query_lower) {
97 0.8
98 } else {
99 0.5
100 };
101
102 if let Some(kind) = options.kind_filter {
104 if symbol.kind != kind {
105 continue;
106 }
107 }
108
109 results.push(SearchResult {
110 symbol: symbol.clone(),
111 relevance,
112 context: None,
113 });
114 }
115 }
116
117 results.sort_by(|a, b| {
119 b.relevance
120 .partial_cmp(&a.relevance)
121 .unwrap_or(std::cmp::Ordering::Equal)
122 });
123
124 results.truncate(options.max_results);
126
127 results
128 }
129
130 pub fn regex_search(&self, pattern: &str, options: &SearchOptions) -> Vec<SearchResult> {
132 let regex = match Regex::new(pattern) {
133 Ok(r) => r,
134 Err(_) => return Vec::new(),
135 };
136
137 let mut results = Vec::new();
138
139 for symbol in self.index.all_symbols() {
140 if regex.is_match(&symbol.name) {
141 if let Some(kind) = options.kind_filter {
143 if symbol.kind != kind {
144 continue;
145 }
146 }
147
148 results.push(SearchResult {
149 symbol: symbol.clone(),
150 relevance: 0.7,
151 context: None,
152 });
153 }
154 }
155
156 results.truncate(options.max_results);
158
159 results
160 }
161
162 pub fn all_symbols(&self) -> Vec<SearchResult> {
164 self.index
165 .all_symbols()
166 .into_iter()
167 .map(|symbol| SearchResult {
168 symbol: symbol.clone(),
169 relevance: 1.0,
170 context: None,
171 })
172 .collect()
173 }
174
175 pub fn get_statistics(&self) -> SearchStatistics {
177 let all_symbols = self.index.all_symbols();
178 let mut kind_counts: HashMap<String, usize> = HashMap::new();
179
180 for symbol in &all_symbols {
181 let kind_str = format!("{:?}", symbol.kind);
182 *kind_counts.entry(kind_str).or_insert(0) += 1;
183 }
184
185 SearchStatistics {
186 total_symbols: self.index.symbol_count(),
187 total_references: self.index.reference_count(),
188 kind_distribution: kind_counts,
189 }
190 }
191}
192
/// Snapshot of index size figures produced by `SearchEngine::get_statistics`.
///
/// Derives `Default`, `PartialEq` and `Eq` so snapshots can be built
/// incrementally and compared directly (for example in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SearchStatistics {
    /// Number of symbols, as reported by the index.
    pub total_symbols: usize,
    /// Number of references, as reported by the index.
    pub total_references: usize,
    /// Symbol count per kind, keyed by the kind's `Debug` rendering.
    pub kind_distribution: HashMap<String, usize>,
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Symbol;
    use std::path::PathBuf;

    /// Builds a minimal symbol located at `test.rs:1:1` with no references.
    fn create_test_symbol(id: &str, name: &str, kind: SymbolKind) -> Symbol {
        Symbol {
            id: id.to_string(),
            name: name.to_string(),
            kind,
            file: PathBuf::from("test.rs"),
            line: 1,
            column: 1,
            references: Vec::new(),
        }
    }

    /// Index holding one function, one class and one constant.
    fn create_test_index() -> SemanticIndex {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol(
            "sym1",
            "my_function",
            SymbolKind::Function,
        ));
        index.add_symbol(create_test_symbol("sym2", "MyClass", SymbolKind::Class));
        index.add_symbol(create_test_symbol(
            "sym3",
            "my_constant",
            SymbolKind::Constant,
        ));
        index
    }

    /// Engine over the three-symbol fixture index.
    fn create_test_engine() -> SearchEngine {
        SearchEngine::new(create_test_index())
    }

    #[test]
    fn test_search_by_name() {
        let engine = create_test_engine();
        let options = SearchOptions::default();

        // Matching rules live in the index, so only non-emptiness is pinned here.
        let results = engine.search_by_name("my_", &options);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_search_by_kind() {
        let engine = create_test_engine();
        let options = SearchOptions::default();

        let results = engine.search_by_kind(SymbolKind::Function, &options);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].symbol.name, "my_function");
    }

    #[test]
    fn test_full_text_search() {
        let engine = create_test_engine();
        let options = SearchOptions::default();

        // Only "my_function" contains the substring "function".
        let results = engine.full_text_search("function", &options);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].symbol.name, "my_function");
    }

    #[test]
    fn test_full_text_search_case_insensitive() {
        let engine = create_test_engine();
        let options = SearchOptions {
            case_sensitive: false,
            ..Default::default()
        };

        let results = engine.full_text_search("MY_FUNCTION", &options);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].symbol.name, "my_function");
        // After case folding the name equals the query: top relevance tier.
        assert_eq!(results[0].relevance, 1.0);
    }

    #[test]
    fn test_full_text_search_case_sensitive() {
        let engine = create_test_engine();
        let options = SearchOptions {
            case_sensitive: true,
            ..Default::default()
        };

        let results = engine.full_text_search("MY_FUNCTION", &options);
        assert!(results.is_empty());
    }

    #[test]
    fn test_full_text_search_orders_by_relevance() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("a", "my_parse", SymbolKind::Function));
        index.add_symbol(create_test_symbol("b", "parser", SymbolKind::Function));
        index.add_symbol(create_test_symbol("c", "parse", SymbolKind::Function));
        let engine = SearchEngine::new(index);

        let results = engine.full_text_search("parse", &SearchOptions::default());
        let names: Vec<&str> = results.iter().map(|r| r.symbol.name.as_str()).collect();

        // Exact match, then prefix match, then plain substring match.
        assert_eq!(names, ["parse", "parser", "my_parse"]);
    }

    #[test]
    fn test_regex_search() {
        let engine = create_test_engine();
        let options = SearchOptions::default();

        // "MyClass" has no "my_" in it, so exactly two symbols match.
        let results = engine.regex_search("my_.*", &options);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.symbol.name.starts_with("my_")));
    }

    #[test]
    fn test_regex_search_invalid_pattern() {
        let engine = create_test_engine();
        let options = SearchOptions::default();

        let results = engine.regex_search("[invalid(", &options);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_with_kind_filter() {
        let engine = create_test_engine();
        let options = SearchOptions {
            kind_filter: Some(SymbolKind::Function),
            ..Default::default()
        };

        let results = engine.full_text_search("my", &options);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].symbol.kind, SymbolKind::Function);
    }

    #[test]
    fn test_search_max_results() {
        let engine = create_test_engine();
        let options = SearchOptions {
            max_results: 1,
            ..Default::default()
        };

        let results = engine.full_text_search("my", &options);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_all_symbols() {
        let engine = create_test_engine();

        let results = engine.all_symbols();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_get_statistics() {
        let engine = create_test_engine();

        let stats = engine.get_statistics();
        assert_eq!(stats.total_symbols, 3);
        // The fixture holds exactly one symbol of each kind.
        assert_eq!(stats.kind_distribution.len(), 3);
        assert_eq!(stats.kind_distribution.get("Function"), Some(&1));
        assert_eq!(stats.kind_distribution.get("Class"), Some(&1));
        assert_eq!(stats.kind_distribution.get("Constant"), Some(&1));
    }
}