Skip to main content

lean_ctx/core/
embedding_index.rs

//! Persistent, incremental embedding index.
//!
//! Stores pre-computed chunk embeddings alongside file content hashes.
//! On re-index, only files whose hash has changed get re-embedded,
//! avoiding expensive model inference for unchanged code.
//!
//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19    pub version: u32,
20    pub dimensions: usize,
21    pub entries: Vec<EmbeddingEntry>,
22    pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27    pub file_path: String,
28    pub symbol_name: String,
29    pub start_line: usize,
30    pub end_line: usize,
31    pub embedding: Vec<f32>,
32    pub content_hash: String,
33}
34
35const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38    pub fn new(dimensions: usize) -> Self {
39        Self {
40            version: CURRENT_VERSION,
41            dimensions,
42            entries: Vec::new(),
43            file_hashes: HashMap::new(),
44        }
45    }
46
47    /// Approximate heap memory used by this index in bytes.
48    pub fn memory_usage_bytes(&self) -> usize {
49        let entries_size: usize = self
50            .entries
51            .iter()
52            .map(|e| {
53                e.file_path.len()
54                    + e.symbol_name.len()
55                    + e.content_hash.len()
56                    + e.embedding.len() * 4
57                    + 48
58            })
59            .sum();
60        let hashes_size: usize = self
61            .file_hashes
62            .iter()
63            .map(|(k, v)| k.len() + v.len() + 32)
64            .sum();
65        entries_size + hashes_size
66    }
67
68    /// Drops all in-memory data to free heap. Index can be re-loaded from disk.
69    pub fn unload(&mut self) {
70        let usage = self.memory_usage_bytes();
71        self.entries = Vec::new();
72        self.file_hashes = HashMap::new();
73        tracing::info!(
74            "[embeddings] unloaded index, freed ~{:.1}MB",
75            usage as f64 / 1_048_576.0
76        );
77    }
78
79    /// Load a previously saved index, or create a new empty one.
80    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
81        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
82    }
83
84    /// Determine which files need re-embedding based on content hashes.
85    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
86        let current_hashes = compute_file_hashes(chunks);
87
88        let mut needs_update = Vec::new();
89        for (file, hash) in &current_hashes {
90            match self.file_hashes.get(file) {
91                Some(old_hash) if old_hash == hash => {}
92                _ => needs_update.push(file.clone()),
93            }
94        }
95
96        for file in self.file_hashes.keys() {
97            if !current_hashes.contains_key(file) {
98                needs_update.push(file.clone());
99            }
100        }
101
102        needs_update
103    }
104
105    /// Update the index with new embeddings for changed files.
106    /// Preserves existing embeddings for unchanged files.
107    pub fn update(
108        &mut self,
109        chunks: &[CodeChunk],
110        new_embeddings: &[(usize, Vec<f32>)],
111        changed_files: &[String],
112    ) {
113        self.entries
114            .retain(|e| !changed_files.contains(&e.file_path));
115
116        for file in changed_files {
117            self.file_hashes.remove(file);
118        }
119
120        let current_hashes = compute_file_hashes(chunks);
121        for file in changed_files {
122            if let Some(hash) = current_hashes.get(file) {
123                self.file_hashes.insert(file.clone(), hash.clone());
124            }
125        }
126
127        for &(chunk_idx, ref embedding) in new_embeddings {
128            if let Some(chunk) = chunks.get(chunk_idx) {
129                let content_hash = hash_content(&chunk.content);
130                self.entries.push(EmbeddingEntry {
131                    file_path: chunk.file_path.clone(),
132                    symbol_name: chunk.symbol_name.clone(),
133                    start_line: chunk.start_line,
134                    end_line: chunk.end_line,
135                    embedding: embedding.clone(),
136                    content_hash,
137                });
138            }
139        }
140    }
141
142    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
143    /// Returns None if index doesn't cover all chunks.
144    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
145        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
146            HashMap::with_capacity(self.entries.len());
147        for e in &self.entries {
148            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
149        }
150
151        let mut result = Vec::with_capacity(chunks.len());
152        for chunk in chunks {
153            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
154            result.push(entry.embedding.clone());
155        }
156        Some(result)
157    }
158
159    pub fn coverage(&self, total_chunks: usize) -> f64 {
160        if total_chunks == 0 {
161            return 0.0;
162        }
163        self.entries.len() as f64 / total_chunks as f64
164    }
165
166    pub fn save(&self, root: &Path) -> std::io::Result<()> {
167        let dir = index_dir(root);
168        std::fs::create_dir_all(&dir)?;
169        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
170        std::fs::write(dir.join("embeddings.json"), data)?;
171        Ok(())
172    }
173
174    pub fn load(root: &Path) -> Option<Self> {
175        let dir = index_dir(root);
176        let path = dir.join("embeddings.json");
177        let data = std::fs::read_to_string(&path)
178            .or_else(|_| {
179                let legacy_dir = legacy_embedding_dir(root);
180                if legacy_dir == dir {
181                    return Err(std::io::Error::new(
182                        std::io::ErrorKind::NotFound,
183                        "same path",
184                    ));
185                }
186                let legacy_path = legacy_dir.join("embeddings.json");
187                let content = std::fs::read_to_string(&legacy_path)?;
188                let _ = std::fs::create_dir_all(&dir);
189                let _ = std::fs::copy(&legacy_path, &path);
190                Ok(content)
191            })
192            .ok()?;
193        let idx: Self = serde_json::from_str(&data).ok()?;
194        if idx.version != CURRENT_VERSION {
195            return None;
196        }
197        Some(idx)
198    }
199}
200
201fn index_dir(root: &Path) -> PathBuf {
202    crate::core::index_namespace::vectors_dir(root)
203}
204
205fn legacy_embedding_dir(root: &Path) -> PathBuf {
206    let mut hasher = Md5::new();
207    hasher.update(root.to_string_lossy().as_bytes());
208    let hash = format!("{:x}", hasher.finalize());
209    crate::core::data_dir::lean_ctx_data_dir()
210        .unwrap_or_else(|_| PathBuf::from("."))
211        .join("vectors")
212        .join(hash)
213}
214
215fn hash_content(content: &str) -> String {
216    let mut hasher = Md5::new();
217    hasher.update(content.as_bytes());
218    format!("{:x}", hasher.finalize())
219}
220
221fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
222    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
223    for chunk in chunks {
224        by_file
225            .entry(chunk.file_path.as_str())
226            .or_default()
227            .push(chunk);
228    }
229
230    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
231    for (file, mut file_chunks) in by_file {
232        file_chunks.sort_by(|a, b| {
233            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
234                b.start_line,
235                b.end_line,
236                b.symbol_name.as_str(),
237            ))
238        });
239
240        let mut hasher = Md5::new();
241        hasher.update(file.as_bytes());
242        for c in file_chunks {
243            hasher.update(c.start_line.to_le_bytes());
244            hasher.update(c.end_line.to_le_bytes());
245            hasher.update(c.symbol_name.as_bytes());
246            hasher.update([kind_tag(&c.kind)]);
247            hasher.update(c.content.as_bytes());
248        }
249        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
250    }
251    out
252}
253
254fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
255    use super::bm25_index::ChunkKind;
256    match kind {
257        ChunkKind::Function => 1,
258        ChunkKind::Struct => 2,
259        ChunkKind::Impl => 3,
260        ChunkKind::Module => 4,
261        ChunkKind::Class => 5,
262        ChunkKind::Method => 6,
263        ChunkKind::Other => 7,
264    }
265}
266
267#[cfg(test)]
268mod tests {
269    use super::*;
270    use crate::core::bm25_index::{ChunkKind, CodeChunk};
271
272    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
273        CodeChunk {
274            file_path: file.to_string(),
275            symbol_name: name.to_string(),
276            kind: ChunkKind::Function,
277            start_line: start,
278            end_line: end,
279            content: content.to_string(),
280            tokens: vec![name.to_string()],
281            token_count: 1,
282        }
283    }
284
285    fn dummy_embedding(dim: usize) -> Vec<f32> {
286        vec![0.1; dim]
287    }
288
289    #[test]
290    fn new_index_is_empty() {
291        let idx = EmbeddingIndex::new(384);
292        assert!(idx.entries.is_empty());
293        assert!(idx.file_hashes.is_empty());
294        assert_eq!(idx.dimensions, 384);
295    }
296
297    #[test]
298    fn files_needing_update_all_new() {
299        let idx = EmbeddingIndex::new(384);
300        let chunks = vec![
301            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
302            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
303        ];
304        let needs = idx.files_needing_update(&chunks);
305        assert_eq!(needs.len(), 2);
306    }
307
308    #[test]
309    fn files_needing_update_unchanged() {
310        let mut idx = EmbeddingIndex::new(384);
311        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
312
313        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
314
315        let needs = idx.files_needing_update(&chunks);
316        assert!(needs.is_empty(), "unchanged file should not need update");
317    }
318
319    #[test]
320    fn files_needing_update_changed_content() {
321        let mut idx = EmbeddingIndex::new(384);
322        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
323        idx.update(
324            &chunks_v1,
325            &[(0, dummy_embedding(384))],
326            &["a.rs".to_string()],
327        );
328
329        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
330        let needs = idx.files_needing_update(&chunks_v2);
331        assert!(
332            needs.contains(&"a.rs".to_string()),
333            "changed file should need update"
334        );
335    }
336
337    #[test]
338    fn files_needing_update_detects_change_in_later_chunk() {
339        let mut idx = EmbeddingIndex::new(3);
340        let chunks_v1 = vec![
341            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
342            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
343        ];
344        idx.update(
345            &chunks_v1,
346            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
347            &["a.rs".to_string()],
348        );
349
350        let chunks_v2 = vec![
351            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
352            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
353        ];
354        let needs = idx.files_needing_update(&chunks_v2);
355        assert!(
356            needs.contains(&"a.rs".to_string()),
357            "changing a later chunk should trigger re-embedding"
358        );
359    }
360
361    #[test]
362    fn files_needing_update_deleted_file() {
363        let mut idx = EmbeddingIndex::new(384);
364        let chunks = vec![
365            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
366            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
367        ];
368        idx.update(
369            &chunks,
370            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
371            &["a.rs".to_string(), "b.rs".to_string()],
372        );
373
374        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
375        let needs = idx.files_needing_update(&chunks_after);
376        assert!(
377            needs.contains(&"b.rs".to_string()),
378            "deleted file should trigger update"
379        );
380    }
381
382    #[test]
383    fn update_preserves_unchanged() {
384        let mut idx = EmbeddingIndex::new(384);
385        let chunks = vec![
386            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
387            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
388        ];
389        idx.update(
390            &chunks,
391            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
392            &["a.rs".to_string(), "b.rs".to_string()],
393        );
394        assert_eq!(idx.entries.len(), 2);
395
396        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
397        assert_eq!(idx.entries.len(), 2);
398
399        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
400        assert!(
401            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
402            "b.rs embedding should be preserved"
403        );
404    }
405
406    #[test]
407    fn get_aligned_embeddings() {
408        let mut idx = EmbeddingIndex::new(2);
409        let chunks = vec![
410            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
411            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
412        ];
413        idx.update(
414            &chunks,
415            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
416            &["a.rs".to_string(), "b.rs".to_string()],
417        );
418
419        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
420        assert_eq!(aligned.len(), 2);
421        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
422        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
423    }
424
425    #[test]
426    fn get_aligned_embeddings_missing() {
427        let idx = EmbeddingIndex::new(384);
428        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
429        assert!(idx.get_aligned_embeddings(&chunks).is_none());
430    }
431
432    #[test]
433    fn coverage_calculation() {
434        let mut idx = EmbeddingIndex::new(384);
435        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);
436
437        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
438        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
439        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
440        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
441    }
442
443    #[test]
444    fn save_and_load_roundtrip() {
445        let _lock = crate::core::data_dir::test_env_lock();
446        let data_dir = tempfile::tempdir().unwrap();
447        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());
448
449        let project_dir = tempfile::tempdir().unwrap();
450
451        let mut idx = EmbeddingIndex::new(3);
452        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
453        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
454        idx.save(project_dir.path()).unwrap();
455
456        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
457        assert_eq!(loaded.dimensions, 3);
458        assert_eq!(loaded.entries.len(), 1);
459        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);
460
461        std::env::remove_var("LEAN_CTX_DATA_DIR");
462    }
463}