// sift/index.rs

1use crate::embed::Embedder;
2use crate::parser::{DefKind, ParsedFile};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
/// Identifier of a [`Symbol`]: its position in [`CodeIndex::symbols`].
pub type SymbolId = usize;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Symbol {
11    pub id: SymbolId,
12    pub name: String,
13    pub kind: DefKind,
14    pub file: PathBuf,
15    pub line: usize,
16    pub end_line: usize,
17    pub doc: Option<String>,
18    pub embedding: Option<Vec<f32>>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CallEdge {
23    pub caller_name: String,
24    pub caller_file: PathBuf,
25    pub caller_line: usize,
26    pub callee_name: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ImportEdge {
31    pub file: PathBuf,
32    pub symbol_name: String,
33    pub resolved_to: Option<SymbolId>,
34    pub resolved_file: Option<PathBuf>,
35    pub resolved_line: Option<usize>,
36    pub resolved_kind: Option<String>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct CodeIndex {
41    pub symbols: Vec<Symbol>,
42    pub calls: Vec<CallEdge>,
43    pub imports: Vec<ImportEdge>,
44    pub files: Vec<PathBuf>,
45    pub root: PathBuf,
46
47    // name -> symbol IDs (for fast lookup)
48    by_name: HashMap<String, Vec<SymbolId>>,
49    // file -> symbol IDs
50    by_file: HashMap<PathBuf, Vec<SymbolId>>,
51}
52
53impl CodeIndex {
54    pub fn build(
55        parsed: Vec<ParsedFile>,
56        root: &Path,
57        embedder: Option<&dyn Embedder>,
58    ) -> Self {
59        let root = root.to_path_buf();
60        let mut idx = CodeIndex {
61            symbols: Vec::new(),
62            calls: Vec::new(),
63            imports: Vec::new(),
64            files: Vec::new(),
65            by_name: HashMap::new(),
66            by_file: HashMap::new(),
67            root,
68        };
69
70        for pf in &parsed {
71            idx.add_file(pf);
72        }
73
74        if let Some(embedder) = embedder {
75            idx.compute_embeddings(embedder);
76        }
77
78        idx.resolve_caller_names();
79        idx.resolve_imports();
80        idx
81    }
82
83    fn compute_embeddings(&mut self, embedder: &dyn Embedder) {
84        let texts: Vec<String> = self
85            .symbols
86            .iter()
87            .map(|s| {
88                let mut t = format!("{}: {:?}", s.name, s.kind);
89                if let Some(ref doc) = s.doc {
90                    t.push('\n');
91                    t.push_str(doc);
92                }
93                t
94            })
95            .collect();
96        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
97        if text_refs.is_empty() {
98            return;
99        }
100        match embedder.embed(&text_refs) {
101            Ok(embeddings) => {
102                for (sym, emb) in self.symbols.iter_mut().zip(embeddings) {
103                    sym.embedding = Some(emb);
104                }
105            }
106            Err(e) => {
107                eprintln!("warn: embedding computation failed: {:#}", e);
108            }
109        }
110    }
111
112    fn add_file(&mut self, pf: &ParsedFile) {
113        if !self.files.contains(&pf.path) {
114            self.files.push(pf.path.clone());
115        }
116
117        for def in &pf.definitions {
118            let id = self.symbols.len();
119            self.symbols.push(Symbol {
120                id,
121                name: def.name.clone(),
122                kind: def.kind,
123                file: pf.path.clone(),
124                line: def.start_line,
125                end_line: def.end_line,
126                doc: def.doc.clone(),
127                embedding: None,
128            });
129            self.by_name
130                .entry(def.name.clone())
131                .or_default()
132                .push(id);
133            self.by_file
134                .entry(pf.path.clone())
135                .or_default()
136                .push(id);
137        }
138
139        for rf in &pf.references {
140            self.calls.push(CallEdge {
141                caller_name: String::new(),
142                caller_file: pf.path.clone(),
143                caller_line: rf.line,
144                callee_name: rf.name.clone(),
145            });
146        }
147
148        for imp in &pf.imports {
149            self.imports.push(ImportEdge {
150                file: pf.path.clone(),
151                symbol_name: imp.name.clone(),
152                resolved_to: None,
153                resolved_file: None,
154                resolved_line: None,
155                resolved_kind: None,
156            });
157        }
158    }
159
160    fn resolve_caller_names(&mut self) {
161        for call in &mut self.calls {
162            let Some(sym_ids) = self.by_file.get(&call.caller_file) else {
163                continue;
164            };
165            for &sym_id in sym_ids {
166                let Some(sym) = self.symbols.get(sym_id) else {
167                    continue;
168                };
169                if sym.line <= call.caller_line && call.caller_line <= sym.end_line {
170                    call.caller_name = sym.name.clone();
171                    break;
172                }
173            }
174        }
175    }
176
177    fn resolve_imports(&mut self) {
178        for imp in &mut self.imports {
179            let Some(sym_ids) = self.by_name.get(&imp.symbol_name) else {
180                continue;
181            };
182            // Prefer a definition from a different file than the import
183            let resolved = sym_ids
184                .iter()
185                .filter_map(|id| self.symbols.get(*id))
186                .find(|s| s.file != imp.file)
187                .or_else(|| {
188                    sym_ids
189                        .iter()
190                        .filter_map(|id| self.symbols.get(*id))
191                        .next()
192                });
193            if let Some(sym) = resolved {
194                imp.resolved_to = Some(sym.id);
195                imp.resolved_file = Some(sym.file.clone());
196                imp.resolved_line = Some(sym.line);
197                imp.resolved_kind = Some(format!("{:?}", sym.kind).to_lowercase());
198            }
199        }
200    }
201
202    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
203        if let Some(parent) = path.parent() {
204            std::fs::create_dir_all(parent)?;
205        }
206        let bytes = bincode::serialize(self)?;
207        std::fs::write(path, bytes)?;
208        Ok(())
209    }
210
211    pub fn load(path: &Path) -> anyhow::Result<Self> {
212        let bytes = std::fs::read(path)?;
213        let idx: CodeIndex = bincode::deserialize(&bytes)?;
214        Ok(idx)
215    }
216
217    pub fn find_symbols_by_name(&self, name: &str) -> Vec<&Symbol> {
218        self.by_name
219            .get(name)
220            .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
221            .unwrap_or_default()
222    }
223
224    pub fn find_symbols_by_pattern(&self, pattern: &str) -> Vec<&Symbol> {
225        let lower = pattern.to_lowercase();
226        self.symbols
227            .iter()
228            .filter(|s| s.name.to_lowercase().contains(&lower))
229            .collect()
230    }
231
232    pub fn find_calls_to(&self, name: &str) -> Vec<&CallEdge> {
233        self.calls
234            .iter()
235            .filter(|c| c.callee_name == name)
236            .collect()
237    }
238
239    pub fn find_calls_by(&self, name: &str) -> Vec<&CallEdge> {
240        self.calls
241            .iter()
242            .filter(|c| c.caller_name == name)
243            .collect()
244    }
245
246    pub fn find_implementations(&self, name: &str) -> Vec<&Symbol> {
247        self.symbols
248            .iter()
249            .filter(|s| s.kind == DefKind::Impl && s.name == name)
250            .collect()
251    }
252
253    pub fn find_symbols_in_file(&self, file: &Path) -> Vec<&Symbol> {
254        self.by_file
255            .get(file)
256            .map(|ids| ids.iter().filter_map(|id| self.symbols.get(*id)).collect())
257            .unwrap_or_default()
258    }
259
260    pub fn relative_path(&self, path: &Path) -> String {
261        path.strip_prefix(&self.root)
262            .unwrap_or(path)
263            .to_string_lossy()
264            .to_string()
265    }
266
267    pub fn find_imports_in_file(&self, file: &Path) -> Vec<&ImportEdge> {
268        self.imports
269            .iter()
270            .filter(|i| i.file == file)
271            .collect()
272    }
273
274    pub fn find_importers_of(&self, name: &str) -> Vec<&ImportEdge> {
275        self.imports
276            .iter()
277            .filter(|i| {
278                i.resolved_to
279                    .and_then(|id| self.symbols.get(id))
280                    .is_some_and(|s| s.name == name)
281            })
282            .collect()
283    }
284
285    pub fn semantic_search(
286        &self,
287        query_embed: &[f32],
288        k: usize,
289    ) -> Vec<(f64, &Symbol)> {
290        let mut scores: Vec<(f64, &Symbol)> = self
291            .symbols
292            .iter()
293            .filter_map(|s| s.embedding.as_ref().map(|e| (cosine_similarity(query_embed, e), s)))
294            .collect();
295        scores.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
296        scores.truncate(k);
297        scores
298    }
299}
300
/// Cosine similarity of `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` instead of a meaningless or non-finite value when:
/// - the slices have different lengths (a prefix-only dot product divided by
///   full-length norms is not a cosine),
/// - either vector has zero magnitude (this includes empty slices),
/// - the result is NaN or infinite (e.g. NaN/inf components in the input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass over both vectors: dot product and both squared norms.
    let (mut dot, mut sq_a, mut sq_b) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{DefKind, ParsedDef, ParsedFile, ParsedImport, ParsedRef, RefKind};

    /// Builds a `ParsedFile` fixture.
    ///
    /// * `defs` - `(name, kind, start_line, end_line)` per definition
    /// * `refs` - `(callee name, line)` per reference; all are created as `RefKind::Call`
    /// * `imports` - imported names
    fn make_file(
        path: &str,
        defs: Vec<(&str, DefKind, usize, usize)>,
        refs: Vec<(&str, usize)>,
        imports: Vec<&str>,
    ) -> ParsedFile {
        let definitions = defs
            .into_iter()
            .map(|(name, kind, start_line, end_line)| ParsedDef {
                name: name.to_string(),
                kind,
                start_line,
                end_line,
                doc: None,
            })
            .collect();
        let references = refs
            .into_iter()
            .map(|(name, line)| ParsedRef {
                name: name.to_string(),
                kind: RefKind::Call,
                line,
            })
            .collect();
        let imports = imports
            .into_iter()
            .map(|name| ParsedImport {
                name: name.to_string(),
            })
            .collect();

        ParsedFile {
            path: PathBuf::from(path),
            language: crate::parser::LanguageId::Rust,
            definitions,
            references,
            imports,
        }
    }

    #[test]
    fn test_build_empty_index() {
        let index = CodeIndex::build(vec![], Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 0);
        assert_eq!(index.calls.len(), 0);
        assert_eq!(index.imports.len(), 0);
        assert_eq!(index.files.len(), 0);
    }

    #[test]
    fn test_build_index_with_symbols() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 5)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 1);
        assert_eq!(index.symbols[0].name, "main");
        assert_eq!(index.symbols[0].kind, DefKind::Function);
        assert_eq!(index.symbols[0].line, 1);
        assert_eq!(index.symbols[0].end_line, 5);
    }

    #[test]
    fn test_find_symbols_by_name() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![("foo", DefKind::Function, 1, 3), ("bar", DefKind::Function, 5, 7)],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_name("foo");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn test_find_symbols_by_pattern() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("calculate_revenue", DefKind::Function, 1, 3),
                ("calculate_expenses", DefKind::Function, 5, 7),
                ("print_report", DefKind::Function, 9, 11),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let found = index.find_symbols_by_pattern("calculate");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_calls_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            vec![("helper", 3), ("other", 5)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls.len(), 2);
    }

    #[test]
    fn test_caller_names_are_resolved() {
        let files = vec![make_file(
            "src/main.rs",
            vec![("run", DefKind::Function, 1, 10)],
            // Line 3 is inside `run`; line 12 is outside every definition.
            vec![("helper", 3), ("orphan", 12)],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.calls[0].caller_name, "run");
        assert!(index.calls[1].caller_name.is_empty());
        assert_eq!(index.find_calls_by("run").len(), 1);
        assert_eq!(index.find_calls_to("helper").len(), 1);
        assert_eq!(index.find_calls_to("missing").len(), 0);
    }

    #[test]
    fn test_imports_are_recorded() {
        let files = vec![make_file(
            "src/main.rs",
            vec![],
            vec![],
            vec!["HashMap", "Vec"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.imports.len(), 2);
        assert_eq!(index.imports[0].symbol_name, "HashMap");
        // No resolution since no symbols with those names exist
        assert!(index.imports[0].resolved_to.is_none());
    }

    #[test]
    fn test_import_resolution() {
        let files = vec![
            make_file(
                "src/lib.rs",
                vec![("HashMap", DefKind::Struct, 10, 30)],
                vec![],
                vec![],
            ),
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 5)],
                vec![],
                vec!["HashMap"],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let imports = index.find_imports_in_file(Path::new("src/main.rs"));
        assert_eq!(imports.len(), 1);
        let imp = imports[0];
        assert!(imp.resolved_to.is_some());
        assert_eq!(imp.resolved_file.as_deref(), Some(Path::new("src/lib.rs")));
        assert_eq!(imp.resolved_line, Some(10));
        assert_eq!(imp.resolved_kind.as_deref(), Some("struct"));
    }

    #[test]
    fn test_find_importers_of() {
        let files = vec![
            make_file(
                "src/lib.rs",
                vec![("Widget", DefKind::Struct, 1, 4)],
                vec![],
                vec![],
            ),
            make_file("src/a.rs", vec![], vec![], vec!["Widget"]),
            make_file("src/b.rs", vec![], vec![], vec!["Widget", "Unknown"]),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.find_importers_of("Widget").len(), 2);
        // Unresolved imports are never reported as importers.
        assert_eq!(index.find_importers_of("Unknown").len(), 0);
    }

    #[test]
    fn test_save_and_load_roundtrip() -> anyhow::Result<()> {
        let files = vec![make_file(
            "src/main.rs",
            vec![("main", DefKind::Function, 1, 10)],
            vec![("helper", 5)],
            vec!["std::fs"],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);

        // Per-process file name so concurrent test runs do not clobber each other.
        let tmp = std::env::temp_dir()
            .join(format!("sift_test_index_{}.bin", std::process::id()));
        index.save(&tmp)?;
        let loaded = CodeIndex::load(&tmp);
        // Clean up before propagating a load failure so the file never leaks.
        let _ = std::fs::remove_file(&tmp);
        let loaded = loaded?;

        assert_eq!(loaded.symbols.len(), 1);
        assert_eq!(loaded.symbols[0].name, "main");
        assert_eq!(loaded.calls.len(), 1);
        assert_eq!(loaded.imports.len(), 1);
        // Resolution persists (no match for "std::fs" so resolved_to is None)
        assert_eq!(loaded.imports[0].resolved_to, None);
        // The private lookup tables are part of the serialized index.
        assert_eq!(loaded.find_symbols_by_name("main").len(), 1);
        Ok(())
    }

    #[test]
    fn test_multiple_files_index() {
        let files = vec![
            make_file(
                "src/main.rs",
                vec![("main", DefKind::Function, 1, 10)],
                vec![("helper", 3)],
                vec![],
            ),
            make_file(
                "src/helper.rs",
                vec![("helper", DefKind::Function, 1, 5)],
                vec![],
                vec![],
            ),
        ];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.symbols.len(), 2);
        assert_eq!(index.files.len(), 2);
        assert_eq!(index.find_symbols_in_file(Path::new("src/helper.rs")).len(), 1);
    }

    #[test]
    fn test_find_implementations() {
        let files = vec![make_file(
            "src/main.rs",
            vec![
                ("Iterator", DefKind::Trait, 1, 3),
                ("Iterator", DefKind::Impl, 5, 20),
            ],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        let impls = index.find_implementations("Iterator");
        assert_eq!(impls.len(), 1);
        assert_eq!(impls[0].kind, DefKind::Impl);
    }

    #[test]
    fn test_relative_path() {
        let files = vec![make_file(
            "/root/src/main.rs",
            vec![],
            vec![],
            vec![],
        )];
        let index = CodeIndex::build(files, Path::new("/root"), None);
        assert_eq!(index.relative_path(Path::new("/root/src/main.rs")), "src/main.rs");
        // Paths outside the root are returned unchanged.
        assert_eq!(index.relative_path(Path::new("/elsewhere/x.rs")), "/elsewhere/x.rs");
    }

    #[test]
    fn test_semantic_search_ranks_by_similarity() {
        let files = vec![make_file(
            "src/lib.rs",
            vec![
                ("aligned", DefKind::Function, 1, 2),
                ("orthogonal", DefKind::Function, 4, 5),
                ("unembedded", DefKind::Function, 7, 8),
            ],
            vec![],
            vec![],
        )];
        let mut index = CodeIndex::build(files, Path::new("/root"), None);
        index.symbols[0].embedding = Some(vec![1.0, 0.0]);
        index.symbols[1].embedding = Some(vec![0.0, 1.0]);

        // Symbols without an embedding are skipped; the rest are ordered best first.
        let hits = index.semantic_search(&[1.0, 0.0], 5);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].1.name, "aligned");
        assert_eq!(hits[1].1.name, "orthogonal");

        let top = index.semantic_search(&[1.0, 0.0], 1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].1.name, "aligned");

        assert!(index.semantic_search(&[1.0, 0.0], 0).is_empty());
    }
}