// File: ricecoder_research/semantic_index.rs

//! Semantic index for fast symbol lookup and search.

use crate::models::{SearchResult, Symbol, SymbolKind, SymbolReference};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

7/// Semantic index for code symbols
8#[derive(Debug, Clone)]
9pub struct SemanticIndex {
10    /// Map from symbol ID to symbol
11    symbols_by_id: HashMap<String, Symbol>,
12    /// Map from symbol name to symbol IDs
13    symbols_by_name: HashMap<String, Vec<String>>,
14    /// Map from file path to symbol IDs in that file
15    symbols_by_file: HashMap<PathBuf, Vec<String>>,
16    /// Map from symbol ID to all references to that symbol
17    references_by_symbol: HashMap<String, Vec<SymbolReference>>,
18}
19
20impl SemanticIndex {
21    /// Create a new semantic index
22    pub fn new() -> Self {
23        SemanticIndex {
24            symbols_by_id: HashMap::new(),
25            symbols_by_name: HashMap::new(),
26            symbols_by_file: HashMap::new(),
27            references_by_symbol: HashMap::new(),
28        }
29    }
30
31    /// Add a symbol to the index
32    pub fn add_symbol(&mut self, symbol: Symbol) {
33        let symbol_id = symbol.id.clone();
34        let symbol_name = symbol.name.clone();
35        let symbol_file = symbol.file.clone();
36
37        // Add to symbols_by_id
38        self.symbols_by_id.insert(symbol_id.clone(), symbol);
39
40        // Add to symbols_by_name
41        self.symbols_by_name
42            .entry(symbol_name)
43            .or_default()
44            .push(symbol_id.clone());
45
46        // Add to symbols_by_file
47        self.symbols_by_file
48            .entry(symbol_file)
49            .or_default()
50            .push(symbol_id);
51    }
52
53    /// Add a reference to the index
54    pub fn add_reference(&mut self, reference: SymbolReference) {
55        self.references_by_symbol
56            .entry(reference.symbol_id.clone())
57            .or_default()
58            .push(reference);
59    }
60
61    /// Get a symbol by ID
62    pub fn get_symbol(&self, symbol_id: &str) -> Option<&Symbol> {
63        self.symbols_by_id.get(symbol_id)
64    }
65
66    /// Get all symbols with a given name
67    pub fn get_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
68        self.symbols_by_name
69            .get(name)
70            .map(|ids| {
71                ids.iter()
72                    .filter_map(|id| self.symbols_by_id.get(id))
73                    .collect()
74            })
75            .unwrap_or_default()
76    }
77
78    /// Get all symbols in a file
79    pub fn get_symbols_in_file(&self, file: &PathBuf) -> Vec<&Symbol> {
80        self.symbols_by_file
81            .get(file)
82            .map(|ids| {
83                ids.iter()
84                    .filter_map(|id| self.symbols_by_id.get(id))
85                    .collect()
86            })
87            .unwrap_or_default()
88    }
89
90    /// Get all references to a symbol
91    pub fn get_references_to_symbol(&self, symbol_id: &str) -> Vec<&SymbolReference> {
92        self.references_by_symbol
93            .get(symbol_id)
94            .map(|refs| refs.iter().collect())
95            .unwrap_or_default()
96    }
97
98    /// Search for symbols by name (substring match)
99    pub fn search_by_name(&self, query: &str) -> Vec<SearchResult> {
100        let mut results = Vec::new();
101
102        for (name, symbol_ids) in &self.symbols_by_name {
103            if name.contains(query) {
104                for symbol_id in symbol_ids {
105                    if let Some(symbol) = self.symbols_by_id.get(symbol_id) {
106                        // Calculate relevance based on match quality
107                        let relevance = if name == query {
108                            1.0
109                        } else if name.starts_with(query) {
110                            0.8
111                        } else {
112                            0.5
113                        };
114
115                        results.push(SearchResult {
116                            symbol: symbol.clone(),
117                            relevance,
118                            context: None,
119                        });
120                    }
121                }
122            }
123        }
124
125        // Sort by relevance (descending)
126        results.sort_by(|a, b| {
127            b.relevance
128                .partial_cmp(&a.relevance)
129                .unwrap_or(std::cmp::Ordering::Equal)
130        });
131
132        results
133    }
134
135    /// Search for symbols by kind
136    pub fn search_by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
137        self.symbols_by_id
138            .values()
139            .filter(|symbol| symbol.kind == kind)
140            .collect()
141    }
142
143    /// Get all symbols in the index
144    pub fn all_symbols(&self) -> Vec<&Symbol> {
145        self.symbols_by_id.values().collect()
146    }
147
148    /// Get the total number of symbols
149    pub fn symbol_count(&self) -> usize {
150        self.symbols_by_id.len()
151    }
152
153    /// Get the total number of references
154    pub fn reference_count(&self) -> usize {
155        self.references_by_symbol
156            .values()
157            .map(|refs| refs.len())
158            .sum()
159    }
160
161    /// Clear the index
162    pub fn clear(&mut self) {
163        self.symbols_by_id.clear();
164        self.symbols_by_name.clear();
165        self.symbols_by_file.clear();
166        self.references_by_symbol.clear();
167    }
168}
169
170impl Default for SemanticIndex {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ReferenceKind;

    /// Build a symbol at line 1, column 1 of `test.rs` with no references.
    fn create_test_symbol(id: &str, name: &str, kind: SymbolKind) -> Symbol {
        Symbol {
            id: id.into(),
            name: name.into(),
            kind,
            file: "test.rs".into(),
            line: 1,
            column: 1,
            references: vec![],
        }
    }

    /// Create an index and add every given symbol to it, in order.
    fn index_with(symbols: Vec<Symbol>) -> SemanticIndex {
        let mut index = SemanticIndex::new();
        for symbol in symbols {
            index.add_symbol(symbol);
        }
        index
    }

    /// Run `search_by_name` against an index holding only `my_function`.
    fn search_single_function(query: &str) -> Vec<SearchResult> {
        index_with(vec![create_test_symbol(
            "sym1",
            "my_function",
            SymbolKind::Function,
        )])
        .search_by_name(query)
    }

    #[test]
    fn test_add_and_get_symbol() {
        let expected = create_test_symbol("sym1", "my_function", SymbolKind::Function);
        let index = index_with(vec![expected.clone()]);

        assert_eq!(index.get_symbol("sym1"), Some(&expected));
    }

    #[test]
    fn test_get_symbols_by_name() {
        let index = index_with(vec![
            create_test_symbol("sym1", "my_function", SymbolKind::Function),
            create_test_symbol("sym2", "my_function", SymbolKind::Function),
        ]);

        assert_eq!(index.get_symbols_by_name("my_function").len(), 2);
    }

    #[test]
    fn test_get_symbols_in_file() {
        let in_test_rs = create_test_symbol("sym1", "func1", SymbolKind::Function);
        let in_other_rs = Symbol {
            file: PathBuf::from("other.rs"),
            ..create_test_symbol("sym2", "func2", SymbolKind::Function)
        };
        let index = index_with(vec![in_test_rs, in_other_rs]);

        let found = index.get_symbols_in_file(&PathBuf::from("test.rs"));
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "func1");
    }

    #[test]
    fn test_search_by_name_exact_match() {
        let hits = search_single_function("my_function");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].relevance, 1.0);
    }

    #[test]
    fn test_search_by_name_prefix_match() {
        let hits = search_single_function("my_");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].relevance, 0.8);
    }

    #[test]
    fn test_search_by_name_substring_match() {
        let hits = search_single_function("function");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].relevance, 0.5);
    }

    #[test]
    fn test_search_by_kind() {
        let index = index_with(vec![
            create_test_symbol("sym1", "my_function", SymbolKind::Function),
            create_test_symbol("sym2", "MyClass", SymbolKind::Class),
        ]);

        let functions = index.search_by_kind(SymbolKind::Function);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].kind, SymbolKind::Function);
    }

    #[test]
    fn test_add_reference() {
        let mut index = index_with(vec![create_test_symbol(
            "sym1",
            "my_function",
            SymbolKind::Function,
        )]);

        index.add_reference(SymbolReference {
            symbol_id: "sym1".to_string(),
            file: PathBuf::from("test.rs"),
            line: 5,
            kind: ReferenceKind::Usage,
        });

        let refs = index.get_references_to_symbol("sym1");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].line, 5);
    }

    #[test]
    fn test_symbol_count() {
        let index = index_with(vec![
            create_test_symbol("sym1", "func1", SymbolKind::Function),
            create_test_symbol("sym2", "func2", SymbolKind::Function),
        ]);

        assert_eq!(index.symbol_count(), 2);
    }

    #[test]
    fn test_reference_count() {
        let mut index = index_with(vec![create_test_symbol(
            "sym1",
            "my_function",
            SymbolKind::Function,
        )]);

        for line in [5, 10] {
            index.add_reference(SymbolReference {
                symbol_id: "sym1".to_string(),
                file: PathBuf::from("test.rs"),
                line,
                kind: ReferenceKind::Usage,
            });
        }

        assert_eq!(index.reference_count(), 2);
    }

    #[test]
    fn test_clear() {
        let mut index = index_with(vec![create_test_symbol(
            "sym1",
            "my_function",
            SymbolKind::Function,
        )]);
        assert_eq!(index.symbol_count(), 1);

        index.clear();
        assert_eq!(index.symbol_count(), 0);
    }

    #[test]
    fn test_all_symbols() {
        let index = index_with(vec![
            create_test_symbol("sym1", "func1", SymbolKind::Function),
            create_test_symbol("sym2", "func2", SymbolKind::Function),
        ]);

        assert_eq!(index.all_symbols().len(), 2);
    }
}