ricecoder_research/
search_engine.rs

//! Search engine for semantic and full-text search over indexed code symbols.
2
3use crate::models::{SearchResult, SymbolKind};
4use crate::semantic_index::SemanticIndex;
5use regex::Regex;
6use std::collections::HashMap;
7
8/// Search engine for code search
9pub struct SearchEngine {
10    index: SemanticIndex,
11}
12
13/// Search query options
14#[derive(Debug, Clone)]
15pub struct SearchOptions {
16    /// Maximum number of results to return
17    pub max_results: usize,
18    /// Filter by symbol kind
19    pub kind_filter: Option<SymbolKind>,
20    /// Case-sensitive search
21    pub case_sensitive: bool,
22}
23
24impl Default for SearchOptions {
25    fn default() -> Self {
26        SearchOptions {
27            max_results: 100,
28            kind_filter: None,
29            case_sensitive: false,
30        }
31    }
32}
33
34impl SearchEngine {
35    /// Create a new search engine with an index
36    pub fn new(index: SemanticIndex) -> Self {
37        SearchEngine { index }
38    }
39
40    /// Search for symbols by name
41    pub fn search_by_name(&self, query: &str, options: &SearchOptions) -> Vec<SearchResult> {
42        let mut results = self.index.search_by_name(query);
43
44        // Apply kind filter if specified
45        if let Some(kind) = options.kind_filter {
46            results.retain(|r| r.symbol.kind == kind);
47        }
48
49        // Limit results
50        results.truncate(options.max_results);
51
52        results
53    }
54
55    /// Search for symbols by kind
56    pub fn search_by_kind(&self, kind: SymbolKind, options: &SearchOptions) -> Vec<SearchResult> {
57        let symbols = self.index.search_by_kind(kind);
58
59        let mut results: Vec<SearchResult> = symbols
60            .into_iter()
61            .map(|symbol| SearchResult {
62                symbol: symbol.clone(),
63                relevance: 1.0,
64                context: None,
65            })
66            .collect();
67
68        // Limit results
69        results.truncate(options.max_results);
70
71        results
72    }
73
74    /// Full-text search across all symbols
75    pub fn full_text_search(&self, query: &str, options: &SearchOptions) -> Vec<SearchResult> {
76        let query_lower = if options.case_sensitive {
77            query.to_string()
78        } else {
79            query.to_lowercase()
80        };
81
82        let mut results = Vec::new();
83
84        for symbol in self.index.all_symbols() {
85            // Check if symbol name matches
86            let symbol_name = if options.case_sensitive {
87                symbol.name.clone()
88            } else {
89                symbol.name.to_lowercase()
90            };
91
92            if symbol_name.contains(&query_lower) {
93                // Calculate relevance based on match quality
94                let relevance = if symbol_name == query_lower {
95                    1.0
96                } else if symbol_name.starts_with(&query_lower) {
97                    0.8
98                } else {
99                    0.5
100                };
101
102                // Apply kind filter if specified
103                if let Some(kind) = options.kind_filter {
104                    if symbol.kind != kind {
105                        continue;
106                    }
107                }
108
109                results.push(SearchResult {
110                    symbol: symbol.clone(),
111                    relevance,
112                    context: None,
113                });
114            }
115        }
116
117        // Sort by relevance (descending)
118        results.sort_by(|a, b| {
119            b.relevance
120                .partial_cmp(&a.relevance)
121                .unwrap_or(std::cmp::Ordering::Equal)
122        });
123
124        // Limit results
125        results.truncate(options.max_results);
126
127        results
128    }
129
130    /// Search using a regex pattern
131    pub fn regex_search(&self, pattern: &str, options: &SearchOptions) -> Vec<SearchResult> {
132        let regex = match Regex::new(pattern) {
133            Ok(r) => r,
134            Err(_) => return Vec::new(),
135        };
136
137        let mut results = Vec::new();
138
139        for symbol in self.index.all_symbols() {
140            if regex.is_match(&symbol.name) {
141                // Apply kind filter if specified
142                if let Some(kind) = options.kind_filter {
143                    if symbol.kind != kind {
144                        continue;
145                    }
146                }
147
148                results.push(SearchResult {
149                    symbol: symbol.clone(),
150                    relevance: 0.7,
151                    context: None,
152                });
153            }
154        }
155
156        // Limit results
157        results.truncate(options.max_results);
158
159        results
160    }
161
162    /// Get all symbols
163    pub fn all_symbols(&self) -> Vec<SearchResult> {
164        self.index
165            .all_symbols()
166            .into_iter()
167            .map(|symbol| SearchResult {
168                symbol: symbol.clone(),
169                relevance: 1.0,
170                context: None,
171            })
172            .collect()
173    }
174
175    /// Get statistics about the index
176    pub fn get_statistics(&self) -> SearchStatistics {
177        let all_symbols = self.index.all_symbols();
178        let mut kind_counts: HashMap<String, usize> = HashMap::new();
179
180        for symbol in &all_symbols {
181            let kind_str = format!("{:?}", symbol.kind);
182            *kind_counts.entry(kind_str).or_insert(0) += 1;
183        }
184
185        SearchStatistics {
186            total_symbols: self.index.symbol_count(),
187            total_references: self.index.reference_count(),
188            kind_distribution: kind_counts,
189        }
190    }
191}
192
/// Summary figures describing the contents of the search index.
#[derive(Debug, Clone)]
pub struct SearchStatistics {
    /// Number of symbols in the index.
    pub total_symbols: usize,
    /// Number of references in the index.
    pub total_references: usize,
    /// Symbol count per kind, keyed by the kind's name.
    pub kind_distribution: HashMap<String, usize>,
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Symbol;
    use std::path::PathBuf;

    /// Builds a symbol located at `test.rs:1:1` with no references.
    fn create_test_symbol(id: &str, name: &str, kind: SymbolKind) -> Symbol {
        Symbol {
            id: id.to_string(),
            name: name.to_string(),
            kind,
            file: PathBuf::from("test.rs"),
            line: 1,
            column: 1,
            references: Vec::new(),
        }
    }

    /// Builds an index holding one function, one class, and one constant.
    fn create_test_index() -> SemanticIndex {
        let fixtures = vec![
            ("sym1", "my_function", SymbolKind::Function),
            ("sym2", "MyClass", SymbolKind::Class),
            ("sym3", "my_constant", SymbolKind::Constant),
        ];

        let mut index = SemanticIndex::new();
        for (id, name, kind) in fixtures {
            index.add_symbol(create_test_symbol(id, name, kind));
        }
        index
    }

    /// Builds a search engine over the three-symbol test index.
    fn create_test_engine() -> SearchEngine {
        SearchEngine::new(create_test_index())
    }

    #[test]
    fn test_search_by_name() {
        let engine = create_test_engine();

        let found = engine.search_by_name("my_", &SearchOptions::default());

        assert!(!found.is_empty());
    }

    #[test]
    fn test_search_by_kind() {
        let engine = create_test_engine();

        let found = engine.search_by_kind(SymbolKind::Function, &SearchOptions::default());

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].symbol.name, "my_function");
    }

    #[test]
    fn test_full_text_search() {
        let engine = create_test_engine();

        let found = engine.full_text_search("function", &SearchOptions::default());

        assert!(!found.is_empty());
    }

    #[test]
    fn test_full_text_search_case_insensitive() {
        let engine = create_test_engine();
        let ignore_case = SearchOptions {
            case_sensitive: false,
            ..SearchOptions::default()
        };

        // The upper-case query must still find the lower-case symbol name.
        let found = engine.full_text_search("MY_FUNCTION", &ignore_case);

        assert!(!found.is_empty());
    }

    #[test]
    fn test_full_text_search_case_sensitive() {
        let engine = create_test_engine();
        let exact_case = SearchOptions {
            case_sensitive: true,
            ..SearchOptions::default()
        };

        // No symbol is spelled in upper case, so nothing may match.
        let found = engine.full_text_search("MY_FUNCTION", &exact_case);

        assert!(found.is_empty());
    }

    #[test]
    fn test_regex_search() {
        let engine = create_test_engine();

        let found = engine.regex_search("my_.*", &SearchOptions::default());

        assert!(!found.is_empty());
    }

    #[test]
    fn test_regex_search_invalid_pattern() {
        let engine = create_test_engine();

        // An uncompilable pattern is reported as "no results", not as an error.
        let found = engine.regex_search("[invalid(", &SearchOptions::default());

        assert!(found.is_empty());
    }

    #[test]
    fn test_search_with_kind_filter() {
        let engine = create_test_engine();
        let functions_only = SearchOptions {
            kind_filter: Some(SymbolKind::Function),
            ..SearchOptions::default()
        };

        let found = engine.full_text_search("my", &functions_only);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].symbol.kind, SymbolKind::Function);
    }

    #[test]
    fn test_search_max_results() {
        let engine = create_test_engine();
        let single_result = SearchOptions {
            max_results: 1,
            ..SearchOptions::default()
        };

        // "my" matches all three symbols; the cap must trim that to one.
        let found = engine.full_text_search("my", &single_result);

        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_all_symbols() {
        let engine = create_test_engine();

        assert_eq!(engine.all_symbols().len(), 3);
    }

    #[test]
    fn test_get_statistics() {
        let engine = create_test_engine();

        let stats = engine.get_statistics();

        assert_eq!(stats.total_symbols, 3);
        for expected_kind in &["Function", "Class", "Constant"] {
            assert!(stats.kind_distribution.contains_key(*expected_kind));
        }
    }
}