use crate::models::{SearchResult, Symbol, SymbolKind, SymbolReference};
use std::collections::HashMap;
use std::hash::Hash;
use std::path::{Path, PathBuf};
6
/// In-memory index of symbols and of references to them.
///
/// `symbols_by_id` owns the `Symbol` values; `symbols_by_name` and
/// `symbols_by_file` are secondary indices that store symbol ids pointing
/// back into it.
#[derive(Debug, Clone)]
pub struct SemanticIndex {
    /// Primary storage: symbol id -> symbol.
    symbols_by_id: HashMap<String, Symbol>,
    /// Symbol name -> ids of every stored symbol with that name (names are not unique).
    symbols_by_name: HashMap<String, Vec<String>>,
    /// File path -> ids of the stored symbols whose `file` is that path.
    symbols_by_file: HashMap<PathBuf, Vec<String>>,
    /// Symbol id -> references recorded for it, keyed by `SymbolReference::symbol_id`.
    /// The id is not required to match a stored symbol.
    references_by_symbol: HashMap<String, Vec<SymbolReference>>,
}
19
20impl SemanticIndex {
21 pub fn new() -> Self {
23 SemanticIndex {
24 symbols_by_id: HashMap::new(),
25 symbols_by_name: HashMap::new(),
26 symbols_by_file: HashMap::new(),
27 references_by_symbol: HashMap::new(),
28 }
29 }
30
31 pub fn add_symbol(&mut self, symbol: Symbol) {
33 let symbol_id = symbol.id.clone();
34 let symbol_name = symbol.name.clone();
35 let symbol_file = symbol.file.clone();
36
37 self.symbols_by_id.insert(symbol_id.clone(), symbol);
39
40 self.symbols_by_name
42 .entry(symbol_name)
43 .or_default()
44 .push(symbol_id.clone());
45
46 self.symbols_by_file
48 .entry(symbol_file)
49 .or_default()
50 .push(symbol_id);
51 }
52
53 pub fn add_reference(&mut self, reference: SymbolReference) {
55 self.references_by_symbol
56 .entry(reference.symbol_id.clone())
57 .or_default()
58 .push(reference);
59 }
60
61 pub fn get_symbol(&self, symbol_id: &str) -> Option<&Symbol> {
63 self.symbols_by_id.get(symbol_id)
64 }
65
66 pub fn get_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
68 self.symbols_by_name
69 .get(name)
70 .map(|ids| {
71 ids.iter()
72 .filter_map(|id| self.symbols_by_id.get(id))
73 .collect()
74 })
75 .unwrap_or_default()
76 }
77
78 pub fn get_symbols_in_file(&self, file: &PathBuf) -> Vec<&Symbol> {
80 self.symbols_by_file
81 .get(file)
82 .map(|ids| {
83 ids.iter()
84 .filter_map(|id| self.symbols_by_id.get(id))
85 .collect()
86 })
87 .unwrap_or_default()
88 }
89
90 pub fn get_references_to_symbol(&self, symbol_id: &str) -> Vec<&SymbolReference> {
92 self.references_by_symbol
93 .get(symbol_id)
94 .map(|refs| refs.iter().collect())
95 .unwrap_or_default()
96 }
97
98 pub fn search_by_name(&self, query: &str) -> Vec<SearchResult> {
100 let mut results = Vec::new();
101
102 for (name, symbol_ids) in &self.symbols_by_name {
103 if name.contains(query) {
104 for symbol_id in symbol_ids {
105 if let Some(symbol) = self.symbols_by_id.get(symbol_id) {
106 let relevance = if name == query {
108 1.0
109 } else if name.starts_with(query) {
110 0.8
111 } else {
112 0.5
113 };
114
115 results.push(SearchResult {
116 symbol: symbol.clone(),
117 relevance,
118 context: None,
119 });
120 }
121 }
122 }
123 }
124
125 results.sort_by(|a, b| {
127 b.relevance
128 .partial_cmp(&a.relevance)
129 .unwrap_or(std::cmp::Ordering::Equal)
130 });
131
132 results
133 }
134
135 pub fn search_by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
137 self.symbols_by_id
138 .values()
139 .filter(|symbol| symbol.kind == kind)
140 .collect()
141 }
142
143 pub fn all_symbols(&self) -> Vec<&Symbol> {
145 self.symbols_by_id.values().collect()
146 }
147
148 pub fn symbol_count(&self) -> usize {
150 self.symbols_by_id.len()
151 }
152
153 pub fn reference_count(&self) -> usize {
155 self.references_by_symbol
156 .values()
157 .map(|refs| refs.len())
158 .sum()
159 }
160
161 pub fn clear(&mut self) {
163 self.symbols_by_id.clear();
164 self.symbols_by_name.clear();
165 self.symbols_by_file.clear();
166 self.references_by_symbol.clear();
167 }
168}
169
170impl Default for SemanticIndex {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a symbol located at `test.rs:1:1` with no references.
    fn create_test_symbol(id: &str, name: &str, kind: SymbolKind) -> Symbol {
        Symbol {
            id: id.to_string(),
            name: name.to_string(),
            kind,
            file: PathBuf::from("test.rs"),
            line: 1,
            column: 1,
            references: Vec::new(),
        }
    }

    #[test]
    fn test_add_and_get_symbol() {
        let mut index = SemanticIndex::new();
        let symbol = create_test_symbol("sym1", "my_function", SymbolKind::Function);

        index.add_symbol(symbol.clone());

        assert_eq!(index.get_symbol("sym1"), Some(&symbol));
    }

    #[test]
    fn test_get_symbols_by_name() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym2", "my_function", SymbolKind::Function));

        let results = index.get_symbols_by_name("my_function");
        assert_eq!(results.len(), 2);

        let mut ids: Vec<&str> = results.iter().map(|symbol| symbol.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["sym1", "sym2"]);
    }

    #[test]
    fn test_get_symbols_in_file() {
        let mut index = SemanticIndex::new();
        let mut other = create_test_symbol("sym2", "func2", SymbolKind::Function);
        other.file = PathBuf::from("other.rs");

        index.add_symbol(create_test_symbol("sym1", "func1", SymbolKind::Function));
        index.add_symbol(other);

        let results = index.get_symbols_in_file(&PathBuf::from("test.rs"));
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "func1");
    }

    #[test]
    fn test_search_by_name_exact_match() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));

        let results = index.search_by_name("my_function");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].relevance, 1.0);
    }

    #[test]
    fn test_search_by_name_prefix_match() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));

        let results = index.search_by_name("my_");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].relevance, 0.8);
    }

    #[test]
    fn test_search_by_name_substring_match() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));

        let results = index.search_by_name("function");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].relevance, 0.5);
    }

    #[test]
    fn test_search_by_name_orders_by_relevance() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "call_parse", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym2", "parse_args", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym3", "parse", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym4", "render", SymbolKind::Function));

        let results = index.search_by_name("parse");
        let names: Vec<&str> = results
            .iter()
            .map(|result| result.symbol.name.as_str())
            .collect();
        // exact (1.0) > prefix (0.8) > substring (0.5); "render" does not match.
        assert_eq!(names, ["parse", "parse_args", "call_parse"]);
    }

    #[test]
    fn test_search_by_kind() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym2", "MyClass", SymbolKind::Class));

        let results = index.search_by_kind(SymbolKind::Function);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].kind, SymbolKind::Function);
    }

    #[test]
    fn test_add_reference() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));

        index.add_reference(SymbolReference {
            symbol_id: "sym1".to_string(),
            file: PathBuf::from("test.rs"),
            line: 5,
            kind: crate::models::ReferenceKind::Usage,
        });

        let refs = index.get_references_to_symbol("sym1");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].line, 5);
    }

    #[test]
    fn test_lookups_on_unknown_keys_are_empty() {
        let index = SemanticIndex::new();

        assert!(index.get_symbol("missing").is_none());
        assert!(index.get_symbols_by_name("missing").is_empty());
        assert!(index.get_symbols_in_file(&PathBuf::from("missing.rs")).is_empty());
        assert!(index.get_references_to_symbol("missing").is_empty());
    }

    #[test]
    fn test_symbol_count() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "func1", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym2", "func2", SymbolKind::Function));

        assert_eq!(index.symbol_count(), 2);
    }

    #[test]
    fn test_reference_count() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));

        index.add_reference(SymbolReference {
            symbol_id: "sym1".to_string(),
            file: PathBuf::from("test.rs"),
            line: 5,
            kind: crate::models::ReferenceKind::Usage,
        });
        index.add_reference(SymbolReference {
            symbol_id: "sym1".to_string(),
            file: PathBuf::from("test.rs"),
            line: 10,
            kind: crate::models::ReferenceKind::Usage,
        });

        assert_eq!(index.reference_count(), 2);
    }

    #[test]
    fn test_clear() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "my_function", SymbolKind::Function));
        index.add_reference(SymbolReference {
            symbol_id: "sym1".to_string(),
            file: PathBuf::from("test.rs"),
            line: 5,
            kind: crate::models::ReferenceKind::Usage,
        });

        assert_eq!(index.symbol_count(), 1);
        assert_eq!(index.reference_count(), 1);

        index.clear();

        // Every index, not just the primary map, must be emptied.
        assert_eq!(index.symbol_count(), 0);
        assert_eq!(index.reference_count(), 0);
        assert!(index.get_symbols_by_name("my_function").is_empty());
        assert!(index.get_symbols_in_file(&PathBuf::from("test.rs")).is_empty());
        assert!(index.get_references_to_symbol("sym1").is_empty());
        assert!(index.search_by_name("my").is_empty());
    }

    #[test]
    fn test_all_symbols() {
        let mut index = SemanticIndex::new();
        index.add_symbol(create_test_symbol("sym1", "func1", SymbolKind::Function));
        index.add_symbol(create_test_symbol("sym2", "func2", SymbolKind::Function));

        assert_eq!(index.all_symbols().len(), 2);
    }

    #[test]
    fn test_default_is_empty() {
        let index = SemanticIndex::default();

        assert_eq!(index.symbol_count(), 0);
        assert_eq!(index.reference_count(), 0);
        assert!(index.all_symbols().is_empty());
    }
}