Skip to main content

sift/
index.rs

1use crate::embed::Embedder;
2use crate::parser::{DefKind, ParsedFile};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
7pub type SymbolId = usize;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Symbol {
11    pub id: SymbolId,
12    pub name: String,
13    pub kind: DefKind,
14    pub file: PathBuf,
15    pub line: usize,
16    pub end_line: usize,
17    pub doc: Option<String>,
18    pub embedding: Option<Vec<f32>>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CallEdge {
23    pub caller_name: String,
24    pub caller_file: PathBuf,
25    pub caller_line: usize,
26    pub callee_name: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ImportEdge {
31    pub file: PathBuf,
32    pub symbol_name: String,
33    pub resolved_to: Option<SymbolId>,
34    pub resolved_file: Option<PathBuf>,
35    pub resolved_line: Option<usize>,
36    pub resolved_kind: Option<String>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct CodeIndex {
41    pub symbols: Vec<Symbol>,
42    pub calls: Vec<CallEdge>,
43    pub imports: Vec<ImportEdge>,
44    pub files: Vec<PathBuf>,
45    pub root: PathBuf,
46
47    // name -> symbol IDs (for fast lookup)
48    by_name: HashMap<String, Vec<SymbolId>>,
49    // file -> symbol IDs
50    by_file: HashMap<PathBuf, Vec<SymbolId>>,
51}
52
53impl CodeIndex {
54    pub fn build(
55        parsed: Vec<ParsedFile>,
56        root: &Path,
57        embedder: Option<&dyn Embedder>,
58    ) -> Self {
59        let root = root.to_path_buf();
60        let mut idx = CodeIndex {
61            symbols: Vec::new(),
62            calls: Vec::new(),
63            imports: Vec::new(),
64            files: Vec::new(),
65            by_name: HashMap::new(),
66            by_file: HashMap::new(),
67            root,
68        };
69
70        for pf in &parsed {
71            idx.add_file(pf);
72        }
73
74        if let Some(embedder) = embedder {
75            idx.compute_embeddings(embedder);
76        }
77
78        idx.resolve_caller_names();
79        idx.resolve_imports();
80        idx
81    }
82
83    fn compute_embeddings(&mut self, embedder: &dyn Embedder) {
84        let texts: Vec<String> = self
85            .symbols
86            .iter()
87            .map(|s| {
88                let mut t = format!("{}: {:?}", s.name, s.kind);
89                if let Some(ref doc) = s.doc {
90                    t.push('\n');
91                    t.push_str(doc);
92                }
93                t
94            })
95            .collect();
96        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
97        if text_refs.is_empty() {
98            return;
99        }
100        if let Ok(embeddings) = embedder.embed(&text_refs) {
101            for (sym, emb) in self.symbols.iter_mut().zip(embeddings) {
102                sym.embedding = Some(emb);
103            }
104        }
105    }
106
107    fn add_file(&mut self, pf: &ParsedFile) {
108        if !self.files.contains(&pf.path) {
109            self.files.push(pf.path.clone());
110        }
111
112        for def in &pf.definitions {
113            let id = self.symbols.len();
114            self.symbols.push(Symbol {
115                id,
116                name: def.name.clone(),
117                kind: def.kind,
118                file: pf.path.clone(),
119                line: def.start_line,
120                end_line: def.end_line,
121                doc: def.doc.clone(),
122                embedding: None,
123            });
124            self.by_name
125                .entry(def.name.clone())
126                .or_default()
127                .push(id);
128            self.by_file
129                .entry(pf.path.clone())
130                .or_default()
131                .push(id);
132        }
133
134        for rf in &pf.references {
135            self.calls.push(CallEdge {
136                caller_name: String::new(),
137                caller_file: pf.path.clone(),
138                caller_line: rf.line,
139                callee_name: rf.name.clone(),
140            });
141        }
142
143        for imp in &pf.imports {
144            self.imports.push(ImportEdge {
145                file: pf.path.clone(),
146                symbol_name: imp.name.clone(),
147                resolved_to: None,
148                resolved_file: None,
149                resolved_line: None,
150                resolved_kind: None,
151            });
152        }
153    }
154
155    fn resolve_caller_names(&mut self) {
156        for call in &mut self.calls {
157            let Some(sym_ids) = self.by_file.get(&call.caller_file) else {
158                continue;
159            };
160            for &sym_id in sym_ids {
161                let Some(sym) = self.symbols.get(sym_id) else {
162                    continue;
163                };
164                if sym.line <= call.caller_line && call.caller_line <= sym.end_line {
165                    call.caller_name = sym.name.clone();
166                    break;
167                }
168            }
169        }
170    }
171
172    fn resolve_imports(&mut self) {
173        for imp in &mut self.imports {
174            let Some(sym_ids) = self.by_name.get(&imp.symbol_name) else {
175                continue;
176            };
177            // Prefer a definition from a different file than the import
178            let resolved = sym_ids
179                .iter()
180                .filter_map(|id| self.symbols.get(*id))
181                .find(|s| s.file != imp.file)
182                .or_else(|| {
183                    sym_ids
184                        .iter()
185                        .filter_map(|id| self.symbols.get(*id))
186                        .next()
187                });
188            if let Some(sym) = resolved {
189                imp.resolved_to = Some(sym.id);
190                imp.resolved_file = Some(sym.file.clone());
191                imp.resolved_line = Some(sym.line);
192                imp.resolved_kind = Some(format!("{:?}", sym.kind).to_lowercase());
193            }
194        }
195    }
196
197    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
198        if let Some(parent) = path.parent() {
199            std::fs::create_dir_all(parent)?;
200        }
201        let bytes = bincode::serialize(self)?;
202        std::fs::write(path, bytes)?;
203        Ok(())
204    }
205
206    pub fn load(path: &Path) -> anyhow::Result<Self> {
207        let bytes = std::fs::read(path)?;
208        let idx: CodeIndex = bincode::deserialize(&bytes)?;
209        Ok(idx)
210    }
211
212    pub fn find_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
213        self.by_name
214            .get(name)
215            .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
216            .unwrap_or_default()
217    }
218
219    pub fn find_symbols_by_pattern(&self, pattern: &str) -> Vec<&Symbol> {
220        let lower = pattern.to_lowercase();
221        self.symbols
222            .iter()
223            .filter(|s| s.name.to_lowercase().contains(&lower))
224            .collect()
225    }
226
227    pub fn find_calls_to(&self, name: &str) -> Vec<&CallEdge> {
228        self.calls
229            .iter()
230            .filter(|c| c.callee_name == name)
231            .collect()
232    }
233
234    pub fn find_calls_by(&self, name: &str) -> Vec<&CallEdge> {
235        self.calls
236            .iter()
237            .filter(|c| c.caller_name == name)
238            .collect()
239    }
240
241    pub fn find_implementations(&self, name: &str) -> Vec<&Symbol> {
242        self.symbols
243            .iter()
244            .filter(|s| s.kind == DefKind::Impl && s.name == name)
245            .collect()
246    }
247
248    pub fn find_symbols_in_file(&self, file: &Path) -> Vec<&Symbol> {
249        self.by_file
250            .get(file)
251            .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
252            .unwrap_or_default()
253    }
254
255    pub fn relative_path(&self, path: &Path) -> String {
256        path.strip_prefix(&self.root)
257            .unwrap_or(path)
258            .to_string_lossy()
259            .to_string()
260    }
261
262    pub fn find_imports_in_file(&self, file: &Path) -> Vec<&ImportEdge> {
263        self.imports
264            .iter()
265            .filter(|i| i.file == file)
266            .collect()
267    }
268
269    pub fn find_importers_of(&self, name: &str) -> Vec<&ImportEdge> {
270        self.imports
271            .iter()
272            .filter(|i| {
273                i.resolved_to
274                    .and_then(|id| self.symbols.get(id))
275                    .is_some_and(|s| s.name == name)
276            })
277            .collect()
278    }
279
280    pub fn semantic_search(
281        &self,
282        query_embed: &[f32],
283        k: usize,
284    ) -> Vec<(f64, &Symbol)> {
285        let mut scores: Vec<(f64, &Symbol)> = self
286            .symbols
287            .iter()
288            .filter_map(|s| s.embedding.as_ref().map(|e| (cosine_similarity(query_embed, e), s)))
289            .collect();
290        scores.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
291        scores.truncate(k);
292        scores
293    }
294}
295
/// Cosine similarity of `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero length (including empty input).
/// When the slices differ in length, the dot product covers only the common
/// prefix while each norm covers its whole slice.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Euclidean norm of a vector.
    let norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&c| f64::from(c) * f64::from(c))
            .sum::<f64>()
            .sqrt()
    };

    let (na, nb) = (norm(a), norm(b));
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    dot / (na * nb)
}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{DefKind, ParsedDef, ParsedFile, ParsedImport, ParsedRef, RefKind};

    /// Builds a `ParsedFile` fixture.
    ///
    /// * `defs` - `(name, kind, start_line, end_line)` per definition.
    /// * `refs` - `(callee name, line)` per reference; all are `RefKind::Call`.
    /// * `imports` - imported names.
    fn make_file(
        path: &str,
        defs: Vec<(&str, DefKind, usize, usize)>,
        refs: Vec<(&str, usize)>,
        imports: Vec<&str>,
    ) -> ParsedFile {
        ParsedFile {
            path: PathBuf::from(path),
            language: crate::parser::LanguageId::Rust,
            definitions: defs
                .into_iter()
                .map(|(name, kind, start_line, end_line)| ParsedDef {
                    name: name.to_string(),
                    kind,
                    start_line,
                    end_line,
                    doc: None,
                })
                .collect(),
            references: refs
                .into_iter()
                .map(|(name, line)| ParsedRef {
                    name: name.to_string(),
                    kind: RefKind::Call,
                    line,
                })
                .collect(),
            imports: imports
                .into_iter()
                .map(|name| ParsedImport {
                    name: name.to_string(),
                })
                .collect(),
        }
    }

    #[test]
    fn test_build_empty_index() {
        let index = CodeIndex::build(vec![], Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 0);
        assert_eq!(index.calls.len(), 0);
        assert_eq!(index.imports.len(), 0);
        assert_eq!(index.files.len(), 0);
    }

    #[test]
    fn test_build_index_with_symbols() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 5)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 1);
        assert_eq!(index.symbols[0].id, 0);
        assert_eq!(index.symbols[0].name, "main");
        assert_eq!(index.symbols[0].kind, DefKind::Function);
        assert_eq!(index.symbols[0].line, 1);
        assert_eq!(index.symbols[0].end_line, 5);
        // Without an embedder no embedding is computed.
        assert!(index.symbols[0].embedding.is_none());
    }

    #[test]
    fn test_find_symbols_by_name() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![("foo", DefKind::Function, 1, 3), ("bar", DefKind::Function, 5, 7)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_name("foo");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
        assert!(index.find_symbols_by_name("missing").is_empty());
    }

    #[test]
    fn test_find_symbols_by_pattern() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("calculate_revenue", DefKind::Function, 1, 3),
                ("calculate_expenses", DefKind::Function, 5, 7),
                ("print_report", DefKind::Function, 9, 11),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_pattern("calculate");
        assert_eq!(found.len(), 2);
        // Matching ignores case.
        assert_eq!(index.find_symbols_by_pattern("CALCULATE").len(), 2);
        assert!(index.find_symbols_by_pattern("nothing_like_this").is_empty());
    }

    #[test]
    fn test_calls_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            vec![("helper", 3), ("other", 5)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls.len(), 2);
        // Both references sit inside `run`, so both are attributed to it.
        assert!(index.calls.iter().all(|c| c.caller_name == "run"));
        assert_eq!(index.find_calls_by("run").len(), 2);
        let to_helper = index.find_calls_to("helper");
        assert_eq!(to_helper.len(), 1);
        assert_eq!(to_helper[0].caller_line, 3);
    }

    #[test]
    fn test_call_outside_any_definition_has_empty_caller() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            vec![("helper", 20)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls.len(), 1);
        assert_eq!(index.calls[0].caller_name, "");
    }

    #[test]
    fn test_imports_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![],
            vec![],
            vec!["HashMap", "Vec"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.imports.len(), 2);
        assert_eq!(index.imports[0].symbol_name, "HashMap");
        // No resolution since no symbols with those names exist
        assert!(index.imports[0].resolved_to.is_none());
    }

    #[test]
    fn test_import_resolution() {
        let files = vec![
            make_file(
                "src/lib.rs",
                vec![("HashMap", DefKind::Struct, 10, 30)],
                vec![],
                vec![],
            ),
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 5)],
                vec![],
                vec!["HashMap"],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let imports = index.find_imports_in_file(Path::new("src/main.rs"));
        assert_eq!(imports.len(), 1);
        let imp = imports[0];
        assert!(imp.resolved_to.is_some());
        assert_eq!(imp.resolved_file.as_deref(), Some(Path::new("src/lib.rs")));
        assert_eq!(imp.resolved_line, Some(10));
        assert_eq!(imp.resolved_kind.as_deref(), Some("struct"));

        let importers = index.find_importers_of("HashMap");
        assert_eq!(importers.len(), 1);
        assert_eq!(importers[0].file, PathBuf::from("src/main.rs"));
        assert!(index.find_importers_of("main").is_empty());
    }

    #[test]
    fn test_import_prefers_definition_in_other_file() {
        let files = vec![
            make_file(
                "src/a.rs",
                vec![("Foo", DefKind::Struct, 1, 3)],
                vec![],
                vec!["Foo"],
            ),
            make_file(
                "src/b.rs",
                vec![("Foo", DefKind::Struct, 7, 9)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let imports = index.find_imports_in_file(Path::new("src/a.rs"));
        assert_eq!(imports.len(), 1);
        assert_eq!(imports[0].resolved_file.as_deref(), Some(Path::new("src/b.rs")));
        assert_eq!(imports[0].resolved_line, Some(7));
    }

    #[test]
    fn test_import_falls_back_to_same_file() {
        let files = vec![make_file(
            "src/a.rs",
            vec![("Foo", DefKind::Struct, 1, 3)],
            vec![],
            vec!["Foo"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.imports.len(), 1);
        assert_eq!(
            index.imports[0].resolved_file.as_deref(),
            Some(Path::new("src/a.rs"))
        );
    }

    #[test]
    fn test_save_and_load_roundtrip() -> anyhow::Result<()> {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 10)],
            vec![("helper", 5)],
            vec!["std::fs"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);

        // Per-process file name so concurrent test runs do not clobber each other.
        let tmp = std::env::temp_dir()
            .join(format!("sift_test_index_{}.bin", std::process::id()));
        index.save(&tmp)?;
        let loaded = CodeIndex::load(&tmp);
        // Remove the file before propagating a load error so nothing is leaked.
        std::fs::remove_file(&tmp)?;
        let loaded = loaded?;

        assert_eq!(loaded.symbols.len(), 1);
        assert_eq!(loaded.symbols[0].name, "main");
        assert_eq!(loaded.calls.len(), 1);
        assert_eq!(loaded.imports.len(), 1);
        // Resolution persists (no match for "std::fs" so resolved_to is None)
        assert_eq!(loaded.imports[0].resolved_to, None);
        // Lookup tables are persisted too, so name queries work after loading.
        assert_eq!(loaded.find_symbols_by_name("main").len(), 1);
        Ok(())
    }

    #[test]
    fn test_multiple_files_index() {
        let files = vec![
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 10)],
                vec![("helper", 3)],
                vec![],
            ),
            make_file(
                "src/helper.rs",
                vec![("helper", DefKind::Function, 1, 5)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 2);
        assert_eq!(index.files.len(), 2);
    }

    #[test]
    fn test_find_symbols_in_file() {
        let files = vec![
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 10)],
                vec![],
                vec![],
            ),
            make_file(
                "src/helper.rs",
                vec![("helper", DefKind::Function, 1, 5)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let in_main = index.find_symbols_in_file(Path::new("src/main.rs"));
        assert_eq!(in_main.len(), 1);
        assert_eq!(in_main[0].name, "main");
        assert!(index
            .find_symbols_in_file(Path::new("src/missing.rs"))
            .is_empty());
    }

    #[test]
    fn test_find_implementations() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("Iterator", DefKind::Trait, 1, 3),
                ("Iterator", DefKind::Impl, 5, 20),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let impls = index.find_implementations("Iterator");
        assert_eq!(impls.len(), 1);
        assert_eq!(impls[0].kind, DefKind::Impl);
    }

    #[test]
    fn test_relative_path() {
        let files = vec![make_file(
            "/root/src/main.rs",
            vec![],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.relative_path(Path::new("/root/src/main.rs")), "src/main.rs");
        // A path outside the root is returned unchanged.
        assert_eq!(
            index.relative_path(Path::new("/elsewhere/lib.rs")),
            "/elsewhere/lib.rs"
        );
    }

    #[test]
    fn test_semantic_search_ranks_by_similarity() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("a", DefKind::Function, 1, 2),
                ("b", DefKind::Function, 3, 4),
                ("c", DefKind::Function, 5, 6),
            ],
            vec![],
            vec![],
        )];
        let mut index = CodeIndex::build(files, Path::new("/root"), None);
        index.symbols[0].embedding = Some(vec![1.0, 0.0]);
        index.symbols[1].embedding = Some(vec![0.0, 1.0]);
        // `c` has no embedding and must never be returned.

        let hits = index.semantic_search(&[1.0, 0.0], 5);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].1.name, "a");
        assert_eq!(hits[0].0, 1.0);
        assert_eq!(hits[1].1.name, "b");
        assert_eq!(hits[1].0, 0.0);

        let top = index.semantic_search(&[1.0, 0.0], 1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].1.name, "a");

        assert!(index.semantic_search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]), 1.0);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[3.0, 4.0], &[-3.0, -4.0]), -1.0);
        // A zero vector has no direction; similarity is defined as 0.
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }
}