Skip to main content

lean_ctx/core/
embedding_index.rs

1//! Persistent, incremental embedding index.
2//!
3//! Stores pre-computed chunk embeddings alongside file content hashes.
4//! On re-index, only files whose hash has changed get re-embedded,
5//! avoiding expensive model inference for unchanged code.
6//!
7//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19    pub version: u32,
20    pub dimensions: usize,
21    pub entries: Vec<EmbeddingEntry>,
22    pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27    pub file_path: String,
28    pub symbol_name: String,
29    pub start_line: usize,
30    pub end_line: usize,
31    pub embedding: Vec<f32>,
32    pub content_hash: String,
33}
34
/// Version written into every new index. `EmbeddingIndex::load` discards
/// persisted indexes whose `version` differs from this value.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38    pub fn new(dimensions: usize) -> Self {
39        Self {
40            version: CURRENT_VERSION,
41            dimensions,
42            entries: Vec::new(),
43            file_hashes: HashMap::new(),
44        }
45    }
46
47    /// Approximate heap memory used by this index in bytes.
48    pub fn memory_usage_bytes(&self) -> usize {
49        let entries_size: usize = self
50            .entries
51            .iter()
52            .map(|e| {
53                e.file_path.len()
54                    + e.symbol_name.len()
55                    + e.content_hash.len()
56                    + e.embedding.len() * 4
57                    + 48
58            })
59            .sum();
60        let hashes_size: usize = self
61            .file_hashes
62            .iter()
63            .map(|(k, v)| k.len() + v.len() + 32)
64            .sum();
65        entries_size + hashes_size
66    }
67
68    /// Drops all in-memory data to free heap. Index can be re-loaded from disk.
69    pub fn unload(&mut self) {
70        let usage = self.memory_usage_bytes();
71        self.entries = Vec::new();
72        self.file_hashes = HashMap::new();
73        tracing::info!(
74            "[embeddings] unloaded index, freed ~{:.1}MB",
75            usage as f64 / 1_048_576.0
76        );
77    }
78
79    /// Load a previously saved index, or create a new empty one.
80    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
81        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
82    }
83
84    /// Determine which files need re-embedding based on content hashes.
85    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
86        let current_hashes = compute_file_hashes(chunks);
87
88        let mut needs_update = Vec::new();
89        for (file, hash) in &current_hashes {
90            match self.file_hashes.get(file) {
91                Some(old_hash) if old_hash == hash => {}
92                _ => needs_update.push(file.clone()),
93            }
94        }
95
96        for file in self.file_hashes.keys() {
97            if !current_hashes.contains_key(file) {
98                needs_update.push(file.clone());
99            }
100        }
101
102        needs_update
103    }
104
105    /// Update the index with new embeddings for changed files.
106    /// Preserves existing embeddings for unchanged files.
107    pub fn update(
108        &mut self,
109        chunks: &[CodeChunk],
110        new_embeddings: &[(usize, Vec<f32>)],
111        changed_files: &[String],
112    ) {
113        self.entries
114            .retain(|e| !changed_files.contains(&e.file_path));
115
116        for file in changed_files {
117            self.file_hashes.remove(file);
118        }
119
120        let current_hashes = compute_file_hashes(chunks);
121        for file in changed_files {
122            if let Some(hash) = current_hashes.get(file) {
123                self.file_hashes.insert(file.clone(), hash.clone());
124            }
125        }
126
127        for &(chunk_idx, ref embedding) in new_embeddings {
128            if let Some(chunk) = chunks.get(chunk_idx) {
129                let content_hash = hash_content(&chunk.content);
130                self.entries.push(EmbeddingEntry {
131                    file_path: chunk.file_path.clone(),
132                    symbol_name: chunk.symbol_name.clone(),
133                    start_line: chunk.start_line,
134                    end_line: chunk.end_line,
135                    embedding: embedding.clone(),
136                    content_hash,
137                });
138            }
139        }
140    }
141
142    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
143    /// Returns None if index doesn't cover all chunks.
144    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
145        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
146            HashMap::with_capacity(self.entries.len());
147        for e in &self.entries {
148            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
149        }
150
151        let mut result = Vec::with_capacity(chunks.len());
152        for chunk in chunks {
153            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
154            result.push(entry.embedding.clone());
155        }
156        Some(result)
157    }
158
159    pub fn coverage(&self, total_chunks: usize) -> f64 {
160        if total_chunks == 0 {
161            return 0.0;
162        }
163        self.entries.len() as f64 / total_chunks as f64
164    }
165
166    pub fn save(&self, root: &Path) -> std::io::Result<()> {
167        let dir = index_dir(root);
168        std::fs::create_dir_all(&dir)?;
169        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
170        std::fs::write(dir.join("embeddings.json"), data)?;
171        Ok(())
172    }
173
174    pub fn load(root: &Path) -> Option<Self> {
175        let dir = index_dir(root);
176        let path = dir.join("embeddings.json");
177        let data = std::fs::read_to_string(&path)
178            .or_else(|_| {
179                let legacy_dir = legacy_embedding_dir(root);
180                if legacy_dir == dir {
181                    return Err(std::io::Error::new(
182                        std::io::ErrorKind::NotFound,
183                        "same path",
184                    ));
185                }
186                let legacy_path = legacy_dir.join("embeddings.json");
187                let content = std::fs::read_to_string(&legacy_path)?;
188                let _ = std::fs::create_dir_all(&dir);
189                let _ = std::fs::copy(&legacy_path, &path);
190                Ok(content)
191            })
192            .ok()?;
193        let idx: Self = serde_json::from_str(&data).ok()?;
194        if idx.version != CURRENT_VERSION {
195            return None;
196        }
197        Some(idx)
198    }
199}
200
201fn index_dir(root: &Path) -> PathBuf {
202    crate::core::index_namespace::vectors_dir(root)
203}
204
205fn legacy_embedding_dir(root: &Path) -> PathBuf {
206    let mut hasher = Md5::new();
207    hasher.update(root.to_string_lossy().as_bytes());
208    let hash = format!("{:x}", hasher.finalize());
209    crate::core::data_dir::lean_ctx_data_dir()
210        .unwrap_or_else(|_| PathBuf::from("."))
211        .join("vectors")
212        .join(hash)
213}
214
215fn hash_content(content: &str) -> String {
216    let mut hasher = Md5::new();
217    hasher.update(content.as_bytes());
218    format!("{:x}", hasher.finalize())
219}
220
221fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
222    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
223    for chunk in chunks {
224        by_file
225            .entry(chunk.file_path.as_str())
226            .or_default()
227            .push(chunk);
228    }
229
230    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
231    for (file, mut file_chunks) in by_file {
232        file_chunks.sort_by(|a, b| {
233            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
234                b.start_line,
235                b.end_line,
236                b.symbol_name.as_str(),
237            ))
238        });
239
240        let mut hasher = Md5::new();
241        hasher.update(file.as_bytes());
242        for c in file_chunks {
243            hasher.update(c.start_line.to_le_bytes());
244            hasher.update(c.end_line.to_le_bytes());
245            hasher.update(c.symbol_name.as_bytes());
246            hasher.update([kind_tag(&c.kind)]);
247            hasher.update(c.content.as_bytes());
248        }
249        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
250    }
251    out
252}
253
254fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
255    use super::bm25_index::ChunkKind;
256    match kind {
257        ChunkKind::Function => 1,
258        ChunkKind::Struct => 2,
259        ChunkKind::Impl => 3,
260        ChunkKind::Module => 4,
261        ChunkKind::Class => 5,
262        ChunkKind::Method => 6,
263        ChunkKind::Other => 7,
264        ChunkKind::Issue => 8,
265        ChunkKind::PullRequest => 9,
266        ChunkKind::WikiPage => 10,
267        ChunkKind::DbSchema => 11,
268        ChunkKind::ApiEndpoint => 12,
269        ChunkKind::Ticket => 13,
270        ChunkKind::ExternalOther => 14,
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Builds a `Function` chunk whose only token is its symbol name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        let symbol = name.to_string();
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: symbol.clone(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![symbol],
            token_count: 1,
        }
    }

    /// Constant-valued embedding of the requested width.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        std::iter::repeat(0.1_f32).take(dim).collect()
    }

    /// Owned file-name list in the shape `update` expects.
    fn files(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    /// The two-file fixture shared by several tests.
    fn a_and_b() -> Vec<CodeChunk> {
        vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ]
    }

    #[test]
    fn new_index_is_empty() {
        let fresh = EmbeddingIndex::new(384);
        assert_eq!(fresh.dimensions, 384);
        assert!(fresh.entries.is_empty());
        assert!(fresh.file_hashes.is_empty());
    }

    #[test]
    fn files_needing_update_all_new() {
        let index = EmbeddingIndex::new(384);
        let stale = index.files_needing_update(&a_and_b());
        assert_eq!(stale.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut index = EmbeddingIndex::new(384);
        index.update(&chunks, &[(0, dummy_embedding(384))], &files(&["a.rs"]));

        let stale = index.files_needing_update(&chunks);
        assert!(stale.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let before = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let after = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];

        let mut index = EmbeddingIndex::new(384);
        index.update(&before, &[(0, dummy_embedding(384))], &files(&["a.rs"]));

        let stale = index.files_needing_update(&after);
        assert!(
            stale.iter().any(|f| f == "a.rs"),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let before = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        let after = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];

        let mut index = EmbeddingIndex::new(3);
        index.update(
            &before,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &files(&["a.rs"]),
        );

        let stale = index.files_needing_update(&after);
        assert!(
            stale.iter().any(|f| f == "a.rs"),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let both = a_and_b();
        let mut index = EmbeddingIndex::new(384);
        index.update(
            &both,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &files(&["a.rs", "b.rs"]),
        );

        let only_a = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let stale = index.files_needing_update(&only_a);
        assert!(
            stale.iter().any(|f| f == "b.rs"),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let chunks = a_and_b();
        let mut index = EmbeddingIndex::new(384);
        index.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &files(&["a.rs", "b.rs"]),
        );
        assert_eq!(index.entries.len(), 2);

        // Re-embed only a.rs; b.rs must keep its original vector.
        index.update(&chunks, &[(0, vec![0.5; 384])], &files(&["a.rs"]));
        assert_eq!(index.entries.len(), 2);

        let kept = index
            .entries
            .iter()
            .find(|e| e.file_path == "b.rs")
            .unwrap();
        assert!(
            (kept.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let chunks = a_and_b();
        let mut index = EmbeddingIndex::new(2);
        index.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &files(&["a.rs", "b.rs"]),
        );

        let aligned = index.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let empty = EmbeddingIndex::new(384);
        assert!(empty.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut index = EmbeddingIndex::new(384);
        assert!((index.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        index.update(&chunks, &[(0, dummy_embedding(384))], &files(&["a.rs"]));

        assert!((index.coverage(2) - 0.5).abs() < 1e-6);
        assert!((index.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // Serialise access to the process-wide environment variable.
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();
        let project_root = project_dir.path();

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut index = EmbeddingIndex::new(3);
        index.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &files(&["a.rs"]));
        index.save(project_root).unwrap();

        let restored = EmbeddingIndex::load(project_root).unwrap();
        assert_eq!(restored.dimensions, 3);
        assert_eq!(restored.entries.len(), 1);
        assert!((restored.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}