Skip to main content

lean_ctx/core/
embedding_index.rs

1//! Persistent, incremental embedding index.
2//!
3//! Stores pre-computed chunk embeddings alongside file content hashes.
4//! On re-index, only files whose hash has changed get re-embedded,
5//! avoiding expensive model inference for unchanged code.
6//!
7//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19    pub version: u32,
20    pub dimensions: usize,
21    /// Model identifier that generated these embeddings.
22    /// Used for mismatch detection when the user switches models.
23    #[serde(default)]
24    pub model_id: Option<String>,
25    pub entries: Vec<EmbeddingEntry>,
26    pub file_hashes: HashMap<String, String>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct EmbeddingEntry {
31    pub file_path: String,
32    pub symbol_name: String,
33    pub start_line: usize,
34    pub end_line: usize,
35    pub embedding: Vec<f32>,
36    pub content_hash: String,
37}
38
/// Index format version stamped on every newly created index.
///
/// `EmbeddingIndex::load` accepts files carrying this version unchanged,
/// migrates version 1 files, and rejects everything else.
const CURRENT_VERSION: u32 = 2;
40
41impl EmbeddingIndex {
42    pub fn new(dimensions: usize) -> Self {
43        Self {
44            version: CURRENT_VERSION,
45            dimensions,
46            model_id: None,
47            entries: Vec::new(),
48            file_hashes: HashMap::new(),
49        }
50    }
51
52    /// Create a new index tagged with a specific model identity.
53    pub fn new_with_model(dimensions: usize, model_id: &str) -> Self {
54        Self {
55            version: CURRENT_VERSION,
56            dimensions,
57            model_id: Some(model_id.to_string()),
58            entries: Vec::new(),
59            file_hashes: HashMap::new(),
60        }
61    }
62
63    /// Check if the index was built with a different model than currently selected.
64    /// Returns `Some((stored_model, current_model))` on mismatch, `None` if compatible.
65    pub fn model_mismatch<'a>(&'a self, current_model: &'a str) -> Option<(&'a str, &'a str)> {
66        match &self.model_id {
67            Some(stored) if stored != current_model => Some((stored, current_model)),
68            _ => None,
69        }
70    }
71
72    /// Check if index dimensions are incompatible with the current engine.
73    pub fn dimension_mismatch(&self, engine_dimensions: usize) -> bool {
74        self.dimensions != engine_dimensions && !self.entries.is_empty()
75    }
76
77    /// Approximate heap memory used by this index in bytes.
78    pub fn memory_usage_bytes(&self) -> usize {
79        let entries_size: usize = self
80            .entries
81            .iter()
82            .map(|e| {
83                e.file_path.len()
84                    + e.symbol_name.len()
85                    + e.content_hash.len()
86                    + e.embedding.len() * 4
87                    + 48
88            })
89            .sum();
90        let hashes_size: usize = self
91            .file_hashes
92            .iter()
93            .map(|(k, v)| k.len() + v.len() + 32)
94            .sum();
95        entries_size + hashes_size
96    }
97
98    /// Drops all in-memory data to free heap. Index can be re-loaded from disk.
99    pub fn unload(&mut self) {
100        let usage = self.memory_usage_bytes();
101        self.entries = Vec::new();
102        self.file_hashes = HashMap::new();
103        tracing::info!(
104            "[embeddings] unloaded index, freed ~{:.1}MB",
105            usage as f64 / 1_048_576.0
106        );
107    }
108
109    /// Load a previously saved index, or create a new empty one.
110    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
111        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
112    }
113
114    /// Determine which files need re-embedding based on content hashes.
115    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
116        let current_hashes = compute_file_hashes(chunks);
117
118        let mut needs_update = Vec::new();
119        for (file, hash) in &current_hashes {
120            match self.file_hashes.get(file) {
121                Some(old_hash) if old_hash == hash => {}
122                _ => needs_update.push(file.clone()),
123            }
124        }
125
126        for file in self.file_hashes.keys() {
127            if !current_hashes.contains_key(file) {
128                needs_update.push(file.clone());
129            }
130        }
131
132        needs_update
133    }
134
135    /// Update the index with new embeddings for changed files.
136    /// Preserves existing embeddings for unchanged files.
137    pub fn update(
138        &mut self,
139        chunks: &[CodeChunk],
140        new_embeddings: &[(usize, Vec<f32>)],
141        changed_files: &[String],
142    ) {
143        self.entries
144            .retain(|e| !changed_files.contains(&e.file_path));
145
146        for file in changed_files {
147            self.file_hashes.remove(file);
148        }
149
150        let current_hashes = compute_file_hashes(chunks);
151        for file in changed_files {
152            if let Some(hash) = current_hashes.get(file) {
153                self.file_hashes.insert(file.clone(), hash.clone());
154            }
155        }
156
157        for &(chunk_idx, ref embedding) in new_embeddings {
158            if let Some(chunk) = chunks.get(chunk_idx) {
159                let content_hash = hash_content(&chunk.content);
160                self.entries.push(EmbeddingEntry {
161                    file_path: chunk.file_path.clone(),
162                    symbol_name: chunk.symbol_name.clone(),
163                    start_line: chunk.start_line,
164                    end_line: chunk.end_line,
165                    embedding: embedding.clone(),
166                    content_hash,
167                });
168            }
169        }
170    }
171
172    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
173    /// Returns None if index doesn't cover all chunks.
174    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
175        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
176            HashMap::with_capacity(self.entries.len());
177        for e in &self.entries {
178            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
179        }
180
181        let mut result = Vec::with_capacity(chunks.len());
182        for chunk in chunks {
183            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
184            result.push(entry.embedding.clone());
185        }
186        Some(result)
187    }
188
189    pub fn coverage(&self, total_chunks: usize) -> f64 {
190        if total_chunks == 0 {
191            return 0.0;
192        }
193        self.entries.len() as f64 / total_chunks as f64
194    }
195
196    pub fn save(&self, root: &Path) -> std::io::Result<()> {
197        let dir = index_dir(root);
198        std::fs::create_dir_all(&dir)?;
199        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
200        std::fs::write(dir.join("embeddings.json"), data)?;
201        Ok(())
202    }
203
204    pub fn load(root: &Path) -> Option<Self> {
205        let dir = index_dir(root);
206        let path = dir.join("embeddings.json");
207        let data = std::fs::read_to_string(&path)
208            .or_else(|_| {
209                let legacy_dir = legacy_embedding_dir(root);
210                if legacy_dir == dir {
211                    return Err(std::io::Error::new(
212                        std::io::ErrorKind::NotFound,
213                        "same path",
214                    ));
215                }
216                let legacy_path = legacy_dir.join("embeddings.json");
217                let content = std::fs::read_to_string(&legacy_path)?;
218                let _ = std::fs::create_dir_all(&dir);
219                let _ = std::fs::copy(&legacy_path, &path);
220                Ok(content)
221            })
222            .ok()?;
223        let idx: Self = serde_json::from_str(&data).ok()?;
224        match idx.version {
225            CURRENT_VERSION => Some(idx),
226            1 => {
227                tracing::info!(
228                    "[embeddings] migrating index v1 → v{CURRENT_VERSION} (adding model_id field)"
229                );
230                Some(Self {
231                    version: CURRENT_VERSION,
232                    dimensions: idx.dimensions,
233                    model_id: None,
234                    entries: idx.entries,
235                    file_hashes: idx.file_hashes,
236                })
237            }
238            _ => None,
239        }
240    }
241}
242
243fn index_dir(root: &Path) -> PathBuf {
244    crate::core::index_namespace::vectors_dir(root)
245}
246
247fn legacy_embedding_dir(root: &Path) -> PathBuf {
248    let mut hasher = Md5::new();
249    hasher.update(root.to_string_lossy().as_bytes());
250    let hash = format!("{:x}", hasher.finalize());
251    crate::core::data_dir::lean_ctx_data_dir()
252        .unwrap_or_else(|_| PathBuf::from("."))
253        .join("vectors")
254        .join(hash)
255}
256
257fn hash_content(content: &str) -> String {
258    let mut hasher = Md5::new();
259    hasher.update(content.as_bytes());
260    format!("{:x}", hasher.finalize())
261}
262
263fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
264    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
265    for chunk in chunks {
266        by_file
267            .entry(chunk.file_path.as_str())
268            .or_default()
269            .push(chunk);
270    }
271
272    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
273    for (file, mut file_chunks) in by_file {
274        file_chunks.sort_by(|a, b| {
275            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
276                b.start_line,
277                b.end_line,
278                b.symbol_name.as_str(),
279            ))
280        });
281
282        let mut hasher = Md5::new();
283        hasher.update(file.as_bytes());
284        for c in file_chunks {
285            hasher.update(c.start_line.to_le_bytes());
286            hasher.update(c.end_line.to_le_bytes());
287            hasher.update(c.symbol_name.as_bytes());
288            hasher.update([kind_tag(&c.kind)]);
289            hasher.update(c.content.as_bytes());
290        }
291        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
292    }
293    out
294}
295
296fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
297    use super::bm25_index::ChunkKind;
298    match kind {
299        ChunkKind::Function => 1,
300        ChunkKind::Struct => 2,
301        ChunkKind::Impl => 3,
302        ChunkKind::Module => 4,
303        ChunkKind::Class => 5,
304        ChunkKind::Method => 6,
305        ChunkKind::Other => 7,
306        ChunkKind::Issue => 8,
307        ChunkKind::PullRequest => 9,
308        ChunkKind::WikiPage => 10,
309        ChunkKind::DbSchema => 11,
310        ChunkKind::ApiEndpoint => 12,
311        ChunkKind::Ticket => 13,
312        ChunkKind::ExternalOther => 14,
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Environment variable the disk-backed tests point at a temporary directory.
    const DATA_DIR_ENV: &str = "LEAN_CTX_DATA_DIR";

    /// Sets `LEAN_CTX_DATA_DIR` for the lifetime of the guard and removes it
    /// on drop, so the variable is cleaned up even when an assertion panics.
    ///
    /// Declare the guard after the env lock and after the temp dir it points
    /// to: locals drop in reverse order, so the variable is removed before
    /// the directory is deleted and before the lock is released.
    struct DataDirOverride;

    impl DataDirOverride {
        fn set(dir: &std::path::Path) -> Self {
            std::env::set_var(DATA_DIR_ENV, dir);
            Self
        }
    }

    impl Drop for DataDirOverride {
        fn drop(&mut self) {
            std::env::remove_var(DATA_DIR_ENV);
        }
    }

    /// Builds a function chunk with a single token equal to its name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Constant-valued embedding of the requested length.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Re-embed only a.rs; the entry for b.rs must survive untouched.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirOverride::set(data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn new_with_model_sets_model_id() {
        let idx = EmbeddingIndex::new_with_model(768, "jina-code-v2");
        assert_eq!(idx.model_id, Some("jina-code-v2".to_string()));
        assert_eq!(idx.dimensions, 768);
    }

    #[test]
    fn model_mismatch_detection() {
        let idx = EmbeddingIndex::new_with_model(768, "all-MiniLM-L6-v2");
        assert!(idx.model_mismatch("all-MiniLM-L6-v2").is_none());
        assert!(idx.model_mismatch("jina-code-v2").is_some());

        let (stored, current) = idx.model_mismatch("jina-code-v2").unwrap();
        assert_eq!(stored, "all-MiniLM-L6-v2");
        assert_eq!(current, "jina-code-v2");
    }

    #[test]
    fn model_mismatch_none_when_no_model_id() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.model_mismatch("anything").is_none());
    }

    #[test]
    fn dimension_mismatch_detection() {
        let mut idx = EmbeddingIndex::new(384);
        assert!(!idx.dimension_mismatch(384));
        assert!(!idx.dimension_mismatch(768)); // no entries = no mismatch

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!(!idx.dimension_mismatch(384));
        assert!(idx.dimension_mismatch(768));
    }

    #[test]
    fn v1_index_migration() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirOverride::set(data_dir.path());
        let project_dir = tempfile::tempdir().unwrap();

        // A version 1 file has no `model_id` field at all.
        let v1_json = serde_json::json!({
            "version": 1,
            "dimensions": 384,
            "entries": [],
            "file_hashes": {}
        });

        let dir = crate::core::index_namespace::vectors_dir(project_dir.path());
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("embeddings.json"), v1_json.to_string()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.version, CURRENT_VERSION);
        assert_eq!(loaded.dimensions, 384);
        assert!(loaded.model_id.is_none());
    }
}