1use crate::embed::Embedder;
2use crate::parser::{DefKind, ParsedFile};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
/// Identifier of a [`Symbol`]: its position in [`CodeIndex::symbols`].
pub type SymbolId = usize;
8
/// A single definition (function, struct, impl, ...) recorded in the index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
    /// Position of this symbol in [`CodeIndex::symbols`].
    pub id: SymbolId,
    /// Name of the definition as reported by the parser.
    pub name: String,
    /// Kind of definition.
    pub kind: DefKind,
    /// File the definition lives in, as given to [`CodeIndex::build`].
    pub file: PathBuf,
    /// First line of the definition; the index treats `line..=end_line` as inclusive.
    pub line: usize,
    /// Last line of the definition (inclusive).
    pub end_line: usize,
    /// Doc comment attached to the definition, if the parser found one.
    pub doc: Option<String>,
    /// Embedding vector; `None` when no embedder was used or embedding failed.
    pub embedding: Option<Vec<f32>>,
}
20
/// A reference from one location to a name, created for every reference the parser reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallEdge {
    /// Name of the definition enclosing the call site; empty when no indexed
    /// definition in `caller_file` contains `caller_line`.
    pub caller_name: String,
    /// File containing the reference.
    pub caller_file: PathBuf,
    /// Line of the reference.
    pub caller_line: usize,
    /// Referenced name; stored as text and not linked to a [`SymbolId`].
    pub callee_name: String,
}
28
/// An imported name in a file, plus the symbol it resolved to (if any).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportEdge {
    /// File containing the import.
    pub file: PathBuf,
    /// Imported name as reported by the parser.
    pub symbol_name: String,
    /// Symbol with the same name, preferring one defined in another file;
    /// `None` when no indexed symbol has this name.
    pub resolved_to: Option<SymbolId>,
    /// File of the resolved symbol.
    pub resolved_file: Option<PathBuf>,
    /// Start line of the resolved symbol.
    pub resolved_line: Option<usize>,
    /// Lowercased `Debug` name of the resolved symbol's [`DefKind`] (e.g. `"struct"`).
    pub resolved_kind: Option<String>,
}
38
/// Searchable index of symbols, references and imports for a source tree.
///
/// Built with [`CodeIndex::build`] and persisted with [`CodeIndex::save`] /
/// [`CodeIndex::load`] via bincode. bincode is not self-describing, so changing the
/// field list or order presumably invalidates previously saved index files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeIndex {
    /// All indexed definitions; a symbol's `id` is its position in this vector.
    pub symbols: Vec<Symbol>,
    /// One edge per parsed reference.
    pub calls: Vec<CallEdge>,
    /// One edge per parsed import.
    pub imports: Vec<ImportEdge>,
    /// Indexed file paths, in first-seen order, without duplicates.
    pub files: Vec<PathBuf>,
    /// Project root; stripped from paths by [`CodeIndex::relative_path`].
    pub root: PathBuf,

    // Symbol ids grouped by exact symbol name.
    by_name: HashMap<String, Vec<SymbolId>>,
    // Symbol ids grouped by the file that defines them.
    by_file: HashMap<PathBuf, Vec<SymbolId>>,
}
52
53impl CodeIndex {
54 pub fn build(
55 parsed: Vec<ParsedFile>,
56 root: &Path,
57 embedder: Option<&dyn Embedder>,
58 ) -> Self {
59 let root = root.to_path_buf();
60 let mut idx = CodeIndex {
61 symbols: Vec::new(),
62 calls: Vec::new(),
63 imports: Vec::new(),
64 files: Vec::new(),
65 by_name: HashMap::new(),
66 by_file: HashMap::new(),
67 root,
68 };
69
70 for pf in &parsed {
71 idx.add_file(pf);
72 }
73
74 if let Some(embedder) = embedder {
75 idx.compute_embeddings(embedder);
76 }
77
78 idx.resolve_caller_names();
79 idx.resolve_imports();
80 idx
81 }
82
83 fn compute_embeddings(&mut self, embedder: &dyn Embedder) {
84 let texts: Vec<String> = self
85 .symbols
86 .iter()
87 .map(|s| {
88 let mut t = format!("{}: {:?}", s.name, s.kind);
89 if let Some(ref doc) = s.doc {
90 t.push('\n');
91 t.push_str(doc);
92 }
93 t
94 })
95 .collect();
96 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
97 if text_refs.is_empty() {
98 return;
99 }
100 if let Ok(embeddings) = embedder.embed(&text_refs) {
101 for (sym, emb) in self.symbols.iter_mut().zip(embeddings) {
102 sym.embedding = Some(emb);
103 }
104 }
105 }
106
107 fn add_file(&mut self, pf: &ParsedFile) {
108 if !self.files.contains(&pf.path) {
109 self.files.push(pf.path.clone());
110 }
111
112 for def in &pf.definitions {
113 let id = self.symbols.len();
114 self.symbols.push(Symbol {
115 id,
116 name: def.name.clone(),
117 kind: def.kind,
118 file: pf.path.clone(),
119 line: def.start_line,
120 end_line: def.end_line,
121 doc: def.doc.clone(),
122 embedding: None,
123 });
124 self.by_name
125 .entry(def.name.clone())
126 .or_default()
127 .push(id);
128 self.by_file
129 .entry(pf.path.clone())
130 .or_default()
131 .push(id);
132 }
133
134 for rf in &pf.references {
135 self.calls.push(CallEdge {
136 caller_name: String::new(),
137 caller_file: pf.path.clone(),
138 caller_line: rf.line,
139 callee_name: rf.name.clone(),
140 });
141 }
142
143 for imp in &pf.imports {
144 self.imports.push(ImportEdge {
145 file: pf.path.clone(),
146 symbol_name: imp.name.clone(),
147 resolved_to: None,
148 resolved_file: None,
149 resolved_line: None,
150 resolved_kind: None,
151 });
152 }
153 }
154
155 fn resolve_caller_names(&mut self) {
156 for call in &mut self.calls {
157 let Some(sym_ids) = self.by_file.get(&call.caller_file) else {
158 continue;
159 };
160 for &sym_id in sym_ids {
161 let Some(sym) = self.symbols.get(sym_id) else {
162 continue;
163 };
164 if sym.line <= call.caller_line && call.caller_line <= sym.end_line {
165 call.caller_name = sym.name.clone();
166 break;
167 }
168 }
169 }
170 }
171
172 fn resolve_imports(&mut self) {
173 for imp in &mut self.imports {
174 let Some(sym_ids) = self.by_name.get(&imp.symbol_name) else {
175 continue;
176 };
177 let resolved = sym_ids
179 .iter()
180 .filter_map(|id| self.symbols.get(*id))
181 .find(|s| s.file != imp.file)
182 .or_else(|| {
183 sym_ids
184 .iter()
185 .filter_map(|id| self.symbols.get(*id))
186 .next()
187 });
188 if let Some(sym) = resolved {
189 imp.resolved_to = Some(sym.id);
190 imp.resolved_file = Some(sym.file.clone());
191 imp.resolved_line = Some(sym.line);
192 imp.resolved_kind = Some(format!("{:?}", sym.kind).to_lowercase());
193 }
194 }
195 }
196
197 pub fn save(&self, path: &Path) -> anyhow::Result<()> {
198 if let Some(parent) = path.parent() {
199 std::fs::create_dir_all(parent)?;
200 }
201 let bytes = bincode::serialize(self)?;
202 std::fs::write(path, bytes)?;
203 Ok(())
204 }
205
206 pub fn load(path: &Path) -> anyhow::Result<Self> {
207 let bytes = std::fs::read(path)?;
208 let idx: CodeIndex = bincode::deserialize(&bytes)?;
209 Ok(idx)
210 }
211
212 pub fn find_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
213 self.by_name
214 .get(name)
215 .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
216 .unwrap_or_default()
217 }
218
219 pub fn find_symbols_by_pattern(&self, pattern: &str) -> Vec<&Symbol> {
220 let lower = pattern.to_lowercase();
221 self.symbols
222 .iter()
223 .filter(|s| s.name.to_lowercase().contains(&lower))
224 .collect()
225 }
226
227 pub fn find_calls_to(&self, name: &str) -> Vec<&CallEdge> {
228 self.calls
229 .iter()
230 .filter(|c| c.callee_name == name)
231 .collect()
232 }
233
234 pub fn find_calls_by(&self, name: &str) -> Vec<&CallEdge> {
235 self.calls
236 .iter()
237 .filter(|c| c.caller_name == name)
238 .collect()
239 }
240
241 pub fn find_implementations(&self, name: &str) -> Vec<&Symbol> {
242 self.symbols
243 .iter()
244 .filter(|s| s.kind == DefKind::Impl && s.name == name)
245 .collect()
246 }
247
248 pub fn find_symbols_in_file(&self, file: &Path) -> Vec<&Symbol> {
249 self.by_file
250 .get(file)
251 .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
252 .unwrap_or_default()
253 }
254
255 pub fn relative_path(&self, path: &Path) -> String {
256 path.strip_prefix(&self.root)
257 .unwrap_or(path)
258 .to_string_lossy()
259 .to_string()
260 }
261
262 pub fn find_imports_in_file(&self, file: &Path) -> Vec<&ImportEdge> {
263 self.imports
264 .iter()
265 .filter(|i| i.file == file)
266 .collect()
267 }
268
269 pub fn find_importers_of(&self, name: &str) -> Vec<&ImportEdge> {
270 self.imports
271 .iter()
272 .filter(|i| {
273 i.resolved_to
274 .and_then(|id| self.symbols.get(id))
275 .is_some_and(|s| s.name == name)
276 })
277 .collect()
278 }
279
280 pub fn semantic_search(
281 &self,
282 query_embed: &[f32],
283 k: usize,
284 ) -> Vec<(f64, &Symbol)> {
285 let mut scores: Vec<(f64, &Symbol)> = self
286 .symbols
287 .iter()
288 .filter_map(|s| s.embedding.as_ref().map(|e| (cosine_similarity(query_embed, e), s)))
289 .collect();
290 scores.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
291 scores.truncate(k);
292 scores
293 }
294}
295
/// Cosine similarity of two embedding vectors, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero norm (including empty vectors) or when
/// the vectors have different lengths. Embeddings of different dimensionality are not
/// comparable, and zipping them would drop the longer vector's tail from the dot
/// product but not from its norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass; each sum accumulates in the same left-to-right order as separate sums.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let (na, nb) = (norm_a_sq.sqrt(), norm_b_sq.sqrt());
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{DefKind, ParsedDef, ParsedFile, ParsedImport, ParsedRef, RefKind};

    /// Builds a Rust `ParsedFile` fixture from `(name, kind, start, end)` definitions,
    /// `(name, line)` call references and imported names.
    fn make_file(
        path: &str,
        defs: Vec<(&str, DefKind, usize, usize)>,
        refs: Vec<(&str, usize)>,
        imports: Vec<&str>,
    ) -> ParsedFile {
        ParsedFile {
            path: PathBuf::from(path),
            language: crate::parser::LanguageId::Rust,
            definitions: defs
                .into_iter()
                .map(|(name, kind, start_line, end_line)| ParsedDef {
                    name: name.to_string(),
                    kind,
                    start_line,
                    end_line,
                    doc: None,
                })
                .collect(),
            references: refs
                .into_iter()
                .map(|(name, line)| ParsedRef {
                    name: name.to_string(),
                    kind: RefKind::Call,
                    line,
                })
                .collect(),
            imports: imports
                .into_iter()
                .map(|name| ParsedImport {
                    name: name.to_string(),
                })
                .collect(),
        }
    }

    #[test]
    fn test_build_empty_index() {
        let index = CodeIndex::build(vec![], Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 0);
        assert_eq!(index.calls.len(), 0);
        assert_eq!(index.imports.len(), 0);
        assert_eq!(index.files.len(), 0);
    }

    #[test]
    fn test_build_index_with_symbols() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 5)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 1);
        assert_eq!(index.symbols[0].name, "main");
        assert_eq!(index.symbols[0].kind, DefKind::Function);
        assert_eq!(index.symbols[0].line, 1);
        assert_eq!(index.symbols[0].end_line, 5);
    }

    #[test]
    fn test_find_symbols_by_name() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![("foo", DefKind::Function, 1, 3), ("bar", DefKind::Function, 5, 7)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_name("foo");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn test_find_symbols_by_pattern() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("calculate_revenue", DefKind::Function, 1, 3),
                ("calculate_expenses", DefKind::Function, 5, 7),
                ("print_report", DefKind::Function, 9, 11),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_pattern("calculate");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_calls_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            vec![("helper", 3), ("other", 5)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls.len(), 2);
    }

    #[test]
    fn test_caller_names_are_resolved() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("run", DefKind::Function, 1, 10),
                ("helper", DefKind::Function, 12, 15),
            ],
            vec![("helper", 3), ("println", 13), ("setup", 20)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let by_run = index.find_calls_by("run");
        assert_eq!(by_run.len(), 1);
        assert_eq!(by_run[0].callee_name, "helper");
        assert_eq!(index.find_calls_by("helper").len(), 1);
        // A reference outside every definition keeps an empty caller name.
        assert_eq!(index.find_calls_by("").len(), 1);
        assert_eq!(index.find_calls_to("helper").len(), 1);
    }

    #[test]
    fn test_imports_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![],
            vec![],
            vec!["HashMap", "Vec"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.imports.len(), 2);
        assert_eq!(index.imports[0].symbol_name, "HashMap");
        assert!(index.imports[0].resolved_to.is_none());
    }

    #[test]
    fn test_import_resolution() {
        let files = vec![
            make_file(
                "src/lib.rs",
                vec![("HashMap", DefKind::Struct, 10, 30)],
                vec![],
                vec![],
            ),
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 5)],
                vec![],
                vec!["HashMap"],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let imports = index.find_imports_in_file(Path::new("src/main.rs"));
        assert_eq!(imports.len(), 1);
        let imp = imports[0];
        assert!(imp.resolved_to.is_some());
        assert_eq!(imp.resolved_file.as_deref(), Some(Path::new("src/lib.rs")));
        assert_eq!(imp.resolved_line, Some(10));
        assert_eq!(imp.resolved_kind.as_deref(), Some("struct"));
    }

    #[test]
    fn test_save_and_load_roundtrip() -> anyhow::Result<()> {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 10)],
            vec![("helper", 5)],
            vec!["std::fs"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);

        // A per-process file name keeps concurrent test runs sharing the temp dir
        // from overwriting or deleting each other's file.
        let tmp = std::env::temp_dir()
            .join(format!("sift_test_index_{}.bin", std::process::id()));
        index.save(&tmp)?;
        let loaded = CodeIndex::load(&tmp);
        // Clean up before propagating a load error so a failure doesn't leak the file.
        std::fs::remove_file(&tmp)?;
        let loaded = loaded?;

        assert_eq!(loaded.symbols.len(), 1);
        assert_eq!(loaded.symbols[0].name, "main");
        assert_eq!(loaded.calls.len(), 1);
        assert_eq!(loaded.imports.len(), 1);
        assert_eq!(loaded.imports[0].resolved_to, None);
        Ok(())
    }

    #[test]
    fn test_multiple_files_index() {
        let files = vec![
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 10)],
                vec![("helper", 3)],
                vec![],
            ),
            make_file(
                "src/helper.rs",
                vec![("helper", DefKind::Function, 1, 5)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 2);
        assert_eq!(index.files.len(), 2);
    }

    #[test]
    fn test_find_implementations() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("Iterator", DefKind::Trait, 1, 3),
                ("Iterator", DefKind::Impl, 5, 20),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let impls = index.find_implementations("Iterator");
        assert_eq!(impls.len(), 1);
        assert_eq!(impls[0].kind, DefKind::Impl);
    }

    #[test]
    fn test_semantic_search_ranks_by_similarity() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("near", DefKind::Function, 1, 3),
                ("far", DefKind::Function, 5, 7),
                ("unembedded", DefKind::Function, 9, 11),
            ],
            vec![],
            vec![],
        )];
        let mut index = CodeIndex::build(files, Path::new("/root"), None);
        index.symbols[0].embedding = Some(vec![1.0, 0.1]);
        index.symbols[1].embedding = Some(vec![0.0, 1.0]);

        let results = index.semantic_search(&[1.0, 0.0], 5);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].1.name, "near");
        assert_eq!(results[1].1.name, "far");
        assert!(index.semantic_search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_relative_path() {
        let files = vec![make_file(
            "/root/src/main.rs",
            vec![],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.relative_path(Path::new("/root/src/main.rs")), "src/main.rs");
    }
}