1use crate::embed::Embedder;
2use crate::parser::{DefKind, ParsedFile};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
/// Identifier of a [`Symbol`]: its position in [`CodeIndex::symbols`].
pub type SymbolId = usize;
8
/// A single definition recorded by [`CodeIndex::build`].
///
/// NOTE: field order is part of the bincode on-disk format used by
/// `CodeIndex::save`/`CodeIndex::load`; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
    /// Index of this symbol in `CodeIndex::symbols`.
    pub id: SymbolId,
    /// Definition name as reported by the parser.
    pub name: String,
    /// Kind of definition (function, struct, impl, ...).
    pub kind: DefKind,
    /// Path of the defining file, exactly as given in the parsed input (not relativized).
    pub file: PathBuf,
    /// Start line of the definition, as reported by the parser.
    pub line: usize,
    /// End line of the definition, as reported by the parser; used to attribute calls.
    pub end_line: usize,
    /// Documentation text attached by the parser, if any.
    pub doc: Option<String>,
    /// Embedding vector; only set when an embedder was passed to `CodeIndex::build`
    /// and embedding succeeded.
    pub embedding: Option<Vec<f32>>,
}
20
/// One reference reported by the parser, recorded as a (possibly unattributed) call.
///
/// Every parsed reference becomes an edge regardless of its reference kind.
/// NOTE: field order is part of the bincode on-disk format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallEdge {
    /// Name of the definition enclosing the call site; empty when no definition in
    /// `caller_file` spans `caller_line`.
    pub caller_name: String,
    /// File containing the call site.
    pub caller_file: PathBuf,
    /// Line of the call site.
    pub caller_line: usize,
    /// Referenced name as written at the call site (not resolved to a symbol).
    pub callee_name: String,
}
28
/// One imported name in a file, with the symbol it resolved to (if any).
///
/// NOTE: field order is part of the bincode on-disk format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportEdge {
    /// File containing the import.
    pub file: PathBuf,
    /// Imported name exactly as reported by the parser; resolution matches it against
    /// symbol names verbatim.
    pub symbol_name: String,
    /// Id of the resolved symbol; `None` when no symbol has this name.
    pub resolved_to: Option<SymbolId>,
    /// File of the resolved symbol.
    pub resolved_file: Option<PathBuf>,
    /// Start line of the resolved symbol.
    pub resolved_line: Option<usize>,
    /// Lowercased `Debug` name of the resolved symbol's `DefKind` (e.g. `"struct"`).
    pub resolved_kind: Option<String>,
}
38
/// In-memory code index: symbols, call edges and import edges for a set of files,
/// plus name/file lookup tables. Persisted with bincode via `save`/`load`.
///
/// NOTE: field order (including the private maps) is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeIndex {
    /// All definitions; `symbols[i].id == i`.
    pub symbols: Vec<Symbol>,
    /// One edge per parsed reference.
    pub calls: Vec<CallEdge>,
    /// One edge per parsed import.
    pub imports: Vec<ImportEdge>,
    /// Indexed file paths, deduplicated, in first-seen order.
    pub files: Vec<PathBuf>,
    /// Root directory used by `relative_path`.
    pub root: PathBuf,

    // Symbol name -> ids of all symbols with that name, in insertion order.
    by_name: HashMap<String, Vec<SymbolId>>,
    // File path -> ids of the symbols defined in it, in insertion order.
    by_file: HashMap<PathBuf, Vec<SymbolId>>,
}
52
53impl CodeIndex {
54 pub fn build(
55 parsed: Vec<ParsedFile>,
56 root: &Path,
57 embedder: Option<&dyn Embedder>,
58 ) -> Self {
59 let root = root.to_path_buf();
60 let mut idx = CodeIndex {
61 symbols: Vec::new(),
62 calls: Vec::new(),
63 imports: Vec::new(),
64 files: Vec::new(),
65 by_name: HashMap::new(),
66 by_file: HashMap::new(),
67 root,
68 };
69
70 for pf in &parsed {
71 idx.add_file(pf);
72 }
73
74 if let Some(embedder) = embedder {
75 idx.compute_embeddings(embedder);
76 }
77
78 idx.resolve_caller_names();
79 idx.resolve_imports();
80 idx
81 }
82
83 fn compute_embeddings(&mut self, embedder: &dyn Embedder) {
84 let texts: Vec<String> = self
85 .symbols
86 .iter()
87 .map(|s| {
88 let mut t = format!("{}: {:?}", s.name, s.kind);
89 if let Some(ref doc) = s.doc {
90 t.push('\n');
91 t.push_str(doc);
92 }
93 t
94 })
95 .collect();
96 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
97 if text_refs.is_empty() {
98 return;
99 }
100 match embedder.embed(&text_refs) {
101 Ok(embeddings) => {
102 for (sym, emb) in self.symbols.iter_mut().zip(embeddings) {
103 sym.embedding = Some(emb);
104 }
105 }
106 Err(e) => {
107 eprintln!("warn: embedding computation failed: {:#}", e);
108 }
109 }
110 }
111
112 fn add_file(&mut self, pf: &ParsedFile) {
113 if !self.files.contains(&pf.path) {
114 self.files.push(pf.path.clone());
115 }
116
117 for def in &pf.definitions {
118 let id = self.symbols.len();
119 self.symbols.push(Symbol {
120 id,
121 name: def.name.clone(),
122 kind: def.kind,
123 file: pf.path.clone(),
124 line: def.start_line,
125 end_line: def.end_line,
126 doc: def.doc.clone(),
127 embedding: None,
128 });
129 self.by_name
130 .entry(def.name.clone())
131 .or_default()
132 .push(id);
133 self.by_file
134 .entry(pf.path.clone())
135 .or_default()
136 .push(id);
137 }
138
139 for rf in &pf.references {
140 self.calls.push(CallEdge {
141 caller_name: String::new(),
142 caller_file: pf.path.clone(),
143 caller_line: rf.line,
144 callee_name: rf.name.clone(),
145 });
146 }
147
148 for imp in &pf.imports {
149 self.imports.push(ImportEdge {
150 file: pf.path.clone(),
151 symbol_name: imp.name.clone(),
152 resolved_to: None,
153 resolved_file: None,
154 resolved_line: None,
155 resolved_kind: None,
156 });
157 }
158 }
159
160 fn resolve_caller_names(&mut self) {
161 for call in &mut self.calls {
162 let Some(sym_ids) = self.by_file.get(&call.caller_file) else {
163 continue;
164 };
165 for &sym_id in sym_ids {
166 let Some(sym) = self.symbols.get(sym_id) else {
167 continue;
168 };
169 if sym.line <= call.caller_line && call.caller_line <= sym.end_line {
170 call.caller_name = sym.name.clone();
171 break;
172 }
173 }
174 }
175 }
176
177 fn resolve_imports(&mut self) {
178 for imp in &mut self.imports {
179 let Some(sym_ids) = self.by_name.get(&imp.symbol_name) else {
180 continue;
181 };
182 let resolved = sym_ids
184 .iter()
185 .filter_map(|id| self.symbols.get(*id))
186 .find(|s| s.file != imp.file)
187 .or_else(|| {
188 sym_ids
189 .iter()
190 .filter_map(|id| self.symbols.get(*id))
191 .next()
192 });
193 if let Some(sym) = resolved {
194 imp.resolved_to = Some(sym.id);
195 imp.resolved_file = Some(sym.file.clone());
196 imp.resolved_line = Some(sym.line);
197 imp.resolved_kind = Some(format!("{:?}", sym.kind).to_lowercase());
198 }
199 }
200 }
201
202 pub fn save(&self, path: &Path) -> anyhow::Result<()> {
203 if let Some(parent) = path.parent() {
204 std::fs::create_dir_all(parent)?;
205 }
206 let bytes = bincode::serialize(self)?;
207 std::fs::write(path, bytes)?;
208 Ok(())
209 }
210
211 pub fn load(path: &Path) -> anyhow::Result<Self> {
212 let bytes = std::fs::read(path)?;
213 let idx: CodeIndex = bincode::deserialize(&bytes)?;
214 Ok(idx)
215 }
216
217 pub fn find_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
218 self.by_name
219 .get(name)
220 .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
221 .unwrap_or_default()
222 }
223
224 pub fn find_symbols_by_pattern(&self, pattern: &str) -> Vec<&Symbol> {
225 let lower = pattern.to_lowercase();
226 self.symbols
227 .iter()
228 .filter(|s| s.name.to_lowercase().contains(&lower))
229 .collect()
230 }
231
232 pub fn find_calls_to(&self, name: &str) -> Vec<&CallEdge> {
233 self.calls
234 .iter()
235 .filter(|c| c.callee_name == name)
236 .collect()
237 }
238
239 pub fn find_calls_by(&self, name: &str) -> Vec<&CallEdge> {
240 self.calls
241 .iter()
242 .filter(|c| c.caller_name == name)
243 .collect()
244 }
245
246 pub fn find_implementations(&self, name: &str) -> Vec<&Symbol> {
247 self.symbols
248 .iter()
249 .filter(|s| s.kind == DefKind::Impl && s.name == name)
250 .collect()
251 }
252
253 pub fn find_symbols_in_file(&self, file: &Path) -> Vec<&Symbol> {
254 self.by_file
255 .get(file)
256 .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
257 .unwrap_or_default()
258 }
259
260 pub fn relative_path(&self, path: &Path) -> String {
261 path.strip_prefix(&self.root)
262 .unwrap_or(path)
263 .to_string_lossy()
264 .to_string()
265 }
266
267 pub fn find_imports_in_file(&self, file: &Path) -> Vec<&ImportEdge> {
268 self.imports
269 .iter()
270 .filter(|i| i.file == file)
271 .collect()
272 }
273
274 pub fn find_importers_of(&self, name: &str) -> Vec<&ImportEdge> {
275 self.imports
276 .iter()
277 .filter(|i| {
278 i.resolved_to
279 .and_then(|id| self.symbols.get(id))
280 .is_some_and(|s| s.name == name)
281 })
282 .collect()
283 }
284
285 pub fn semantic_search(
286 &self,
287 query_embed: &[f32],
288 k: usize,
289 ) -> Vec<(f64, &Symbol)> {
290 let mut scores: Vec<(f64, &Symbol)> = self
291 .symbols
292 .iter()
293 .filter_map(|s| s.embedding.as_ref().map(|e| (cosine_similarity(query_embed, e), s)))
294 .collect();
295 scores.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
296 scores.truncate(k);
297 scores
298 }
299}
300
/// Cosine similarity of `a` and `b`, computed in `f64`.
///
/// Returns `0.0` when either vector has zero norm (including empty input) or when the
/// dimensions differ. Vectors of different lengths come from different embedding spaces,
/// so any number computed from them would be meaningless.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f64, 0.0f64, 0.0f64), |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let (na, nb) = (sq_a.sqrt(), sq_b.sqrt());
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{DefKind, ParsedDef, ParsedFile, ParsedImport, ParsedRef, RefKind};

    /// Builds a Rust `ParsedFile` fixture.
    ///
    /// * `defs` — `(name, kind, start_line, end_line)` tuples, with no docs.
    /// * `refs` — `(name, line)` call references.
    /// * `imports` — imported names.
    fn make_file(
        path: &str,
        defs: Vec<(&str, DefKind, usize, usize)>,
        refs: Vec<(&str, usize)>,
        imports: Vec<&str>,
    ) -> ParsedFile {
        ParsedFile {
            path: PathBuf::from(path),
            language: crate::parser::LanguageId::Rust,
            definitions: defs
                .into_iter()
                .map(|(name, kind, start_line, end_line)| ParsedDef {
                    name: name.to_string(),
                    kind,
                    start_line,
                    end_line,
                    doc: None,
                })
                .collect(),
            references: refs
                .into_iter()
                .map(|(name, line)| ParsedRef {
                    name: name.to_string(),
                    kind: RefKind::Call,
                    line,
                })
                .collect(),
            imports: imports
                .into_iter()
                .map(|name| ParsedImport {
                    name: name.to_string(),
                })
                .collect(),
        }
    }

    #[test]
    fn test_build_empty_index() {
        let index = CodeIndex::build(vec![], Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 0);
        assert_eq!(index.calls.len(), 0);
        assert_eq!(index.imports.len(), 0);
        assert_eq!(index.files.len(), 0);
    }

    #[test]
    fn test_build_index_with_symbols() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 5)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 1);
        assert_eq!(index.symbols[0].name, "main");
        assert_eq!(index.symbols[0].kind, DefKind::Function);
        assert_eq!(index.symbols[0].line, 1);
        assert_eq!(index.symbols[0].end_line, 5);
    }

    #[test]
    fn test_find_symbols_by_name() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![("foo", DefKind::Function, 1, 3), ("bar", DefKind::Function, 5, 7)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_name("foo");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn test_find_symbols_by_pattern() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("calculate_revenue", DefKind::Function, 1, 3),
                ("calculate_expenses", DefKind::Function, 5, 7),
                ("print_report", DefKind::Function, 9, 11),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_pattern("calculate");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_calls_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            vec![("helper", 3), ("other", 5)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls.len(), 2);
    }

    #[test]
    fn test_caller_names_are_resolved() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("run", DefKind::Function, 1, 10),
                ("other", DefKind::Function, 12, 20),
            ],
            vec![("helper", 3), ("helper", 15), ("stray", 30)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let by_run = index.find_calls_by("run");
        assert_eq!(by_run.len(), 1);
        assert_eq!(by_run[0].callee_name, "helper");
        assert_eq!(index.find_calls_by("other").len(), 1);
        assert_eq!(index.find_calls_to("helper").len(), 2);
        // A reference outside every definition keeps an empty caller name.
        assert_eq!(index.find_calls_by("").len(), 1);
    }

    #[test]
    fn test_imports_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![],
            vec![],
            vec!["HashMap", "Vec"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.imports.len(), 2);
        assert_eq!(index.imports[0].symbol_name, "HashMap");
        assert!(index.imports[0].resolved_to.is_none());
    }

    #[test]
    fn test_import_resolution() {
        let files = vec![
            make_file(
                "src/lib.rs",
                vec![("HashMap", DefKind::Struct, 10, 30)],
                vec![],
                vec![],
            ),
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 5)],
                vec![],
                vec!["HashMap"],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let imports = index.find_imports_in_file(Path::new("src/main.rs"));
        assert_eq!(imports.len(), 1);
        let imp = imports[0];
        assert!(imp.resolved_to.is_some());
        assert_eq!(imp.resolved_file.as_deref(), Some(Path::new("src/lib.rs")));
        assert_eq!(imp.resolved_line, Some(10));
        assert_eq!(imp.resolved_kind.as_deref(), Some("struct"));
    }

    #[test]
    fn test_save_and_load_roundtrip() -> anyhow::Result<()> {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 10)],
            vec![("helper", 5)],
            vec!["std::fs"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);

        // Include the process id so concurrently running test binaries sharing a temp
        // directory never read or delete each other's file.
        let tmp = std::env::temp_dir()
            .join(format!("sift_test_index_{}.bin", std::process::id()));
        index.save(&tmp)?;
        let loaded = CodeIndex::load(&tmp);
        // Clean up before propagating a load error so no stale file is left behind.
        std::fs::remove_file(&tmp)?;
        let loaded = loaded?;

        assert_eq!(loaded.symbols.len(), 1);
        assert_eq!(loaded.symbols[0].name, "main");
        assert_eq!(loaded.calls.len(), 1);
        assert_eq!(loaded.imports.len(), 1);
        assert_eq!(loaded.imports[0].resolved_to, None);
        Ok(())
    }

    #[test]
    fn test_multiple_files_index() {
        let files = vec![
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 10)],
                vec![("helper", 3)],
                vec![],
            ),
            make_file(
                "src/helper.rs",
                vec![("helper", DefKind::Function, 1, 5)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 2);
        assert_eq!(index.files.len(), 2);
    }

    #[test]
    fn test_find_implementations() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("Iterator", DefKind::Trait, 1, 3),
                ("Iterator", DefKind::Impl, 5, 20),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let impls = index.find_implementations("Iterator");
        assert_eq!(impls.len(), 1);
        assert_eq!(impls[0].kind, DefKind::Impl);
    }

    #[test]
    fn test_relative_path() {
        let files = vec![make_file(
            "/root/src/main.rs",
            vec![],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.relative_path(Path::new("/root/src/main.rs")), "src/main.rs");
    }
}