Skip to main content

lean_ctx/core/
embedding_index.rs

1//! Persistent, incremental embedding index.
2//!
3//! Stores pre-computed chunk embeddings alongside file content hashes.
4//! On re-index, only files whose hash has changed get re-embedded,
5//! avoiding expensive model inference for unchanged code.
6//!
7//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19    pub version: u32,
20    pub dimensions: usize,
21    pub entries: Vec<EmbeddingEntry>,
22    pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27    pub file_path: String,
28    pub symbol_name: String,
29    pub start_line: usize,
30    pub end_line: usize,
31    pub embedding: Vec<f32>,
32    pub content_hash: String,
33}
34
/// Version stamped into newly created indexes; a stored index carrying any
/// other version is ignored when loading.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38    pub fn new(dimensions: usize) -> Self {
39        Self {
40            version: CURRENT_VERSION,
41            dimensions,
42            entries: Vec::new(),
43            file_hashes: HashMap::new(),
44        }
45    }
46
47    /// Load a previously saved index, or create a new empty one.
48    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50    }
51
52    /// Determine which files need re-embedding based on content hashes.
53    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54        let current_hashes = compute_file_hashes(chunks);
55
56        let mut needs_update = Vec::new();
57        for (file, hash) in &current_hashes {
58            match self.file_hashes.get(file) {
59                Some(old_hash) if old_hash == hash => {}
60                _ => needs_update.push(file.clone()),
61            }
62        }
63
64        for file in self.file_hashes.keys() {
65            if !current_hashes.contains_key(file) {
66                needs_update.push(file.clone());
67            }
68        }
69
70        needs_update
71    }
72
73    /// Update the index with new embeddings for changed files.
74    /// Preserves existing embeddings for unchanged files.
75    pub fn update(
76        &mut self,
77        chunks: &[CodeChunk],
78        new_embeddings: &[(usize, Vec<f32>)],
79        changed_files: &[String],
80    ) {
81        self.entries
82            .retain(|e| !changed_files.contains(&e.file_path));
83
84        for file in changed_files {
85            self.file_hashes.remove(file);
86        }
87
88        let current_hashes = compute_file_hashes(chunks);
89        for file in changed_files {
90            if let Some(hash) = current_hashes.get(file) {
91                self.file_hashes.insert(file.clone(), hash.clone());
92            }
93        }
94
95        for &(chunk_idx, ref embedding) in new_embeddings {
96            if let Some(chunk) = chunks.get(chunk_idx) {
97                let content_hash = hash_content(&chunk.content);
98                self.entries.push(EmbeddingEntry {
99                    file_path: chunk.file_path.clone(),
100                    symbol_name: chunk.symbol_name.clone(),
101                    start_line: chunk.start_line,
102                    end_line: chunk.end_line,
103                    embedding: embedding.clone(),
104                    content_hash,
105                });
106            }
107        }
108    }
109
110    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
111    /// Returns None if index doesn't cover all chunks.
112    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114            HashMap::with_capacity(self.entries.len());
115        for e in &self.entries {
116            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117        }
118
119        let mut result = Vec::with_capacity(chunks.len());
120        for chunk in chunks {
121            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122            result.push(entry.embedding.clone());
123        }
124        Some(result)
125    }
126
127    pub fn coverage(&self, total_chunks: usize) -> f64 {
128        if total_chunks == 0 {
129            return 0.0;
130        }
131        self.entries.len() as f64 / total_chunks as f64
132    }
133
134    pub fn save(&self, root: &Path) -> std::io::Result<()> {
135        let dir = index_dir(root);
136        std::fs::create_dir_all(&dir)?;
137        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138        std::fs::write(dir.join("embeddings.json"), data)?;
139        Ok(())
140    }
141
142    pub fn load(root: &Path) -> Option<Self> {
143        let dir = index_dir(root);
144        let path = dir.join("embeddings.json");
145        let data = std::fs::read_to_string(&path)
146            .or_else(|_| {
147                let legacy_dir = legacy_embedding_dir(root);
148                if legacy_dir == dir {
149                    return Err(std::io::Error::new(
150                        std::io::ErrorKind::NotFound,
151                        "same path",
152                    ));
153                }
154                let legacy_path = legacy_dir.join("embeddings.json");
155                let content = std::fs::read_to_string(&legacy_path)?;
156                let _ = std::fs::create_dir_all(&dir);
157                let _ = std::fs::copy(&legacy_path, &path);
158                Ok(content)
159            })
160            .ok()?;
161        let idx: Self = serde_json::from_str(&data).ok()?;
162        if idx.version != CURRENT_VERSION {
163            return None;
164        }
165        Some(idx)
166    }
167}
168
169fn index_dir(root: &Path) -> PathBuf {
170    crate::core::index_namespace::vectors_dir(root)
171}
172
173fn legacy_embedding_dir(root: &Path) -> PathBuf {
174    let mut hasher = Md5::new();
175    hasher.update(root.to_string_lossy().as_bytes());
176    let hash = format!("{:x}", hasher.finalize());
177    crate::core::data_dir::lean_ctx_data_dir()
178        .unwrap_or_else(|_| PathBuf::from("."))
179        .join("vectors")
180        .join(hash)
181}
182
183fn hash_content(content: &str) -> String {
184    let mut hasher = Md5::new();
185    hasher.update(content.as_bytes());
186    format!("{:x}", hasher.finalize())
187}
188
189fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
190    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
191    for chunk in chunks {
192        by_file
193            .entry(chunk.file_path.as_str())
194            .or_default()
195            .push(chunk);
196    }
197
198    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
199    for (file, mut file_chunks) in by_file {
200        file_chunks.sort_by(|a, b| {
201            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
202                b.start_line,
203                b.end_line,
204                b.symbol_name.as_str(),
205            ))
206        });
207
208        let mut hasher = Md5::new();
209        hasher.update(file.as_bytes());
210        for c in file_chunks {
211            hasher.update(c.start_line.to_le_bytes());
212            hasher.update(c.end_line.to_le_bytes());
213            hasher.update(c.symbol_name.as_bytes());
214            hasher.update([kind_tag(&c.kind)]);
215            hasher.update(c.content.as_bytes());
216        }
217        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
218    }
219    out
220}
221
222fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
223    use super::bm25_index::ChunkKind;
224    match kind {
225        ChunkKind::Function => 1,
226        ChunkKind::Struct => 2,
227        ChunkKind::Impl => 3,
228        ChunkKind::Module => 4,
229        ChunkKind::Class => 5,
230        ChunkKind::Method => 6,
231        ChunkKind::Other => 7,
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Builds a `Function` chunk whose only token is its symbol name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        let symbol_name = name.to_string();
        CodeChunk {
            file_path: file.to_string(),
            tokens: vec![symbol_name.clone()],
            symbol_name,
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            token_count: 1,
        }
    }

    /// A vector of `dim` components, all `0.1`.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        std::iter::repeat(0.1).take(dim).collect()
    }

    /// Owned file-path list for the `changed_files` argument of `update`.
    fn paths(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert_eq!(idx.dimensions, 384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
    }

    #[test]
    fn files_needing_update_all_new() {
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = EmbeddingIndex::new(384).files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut idx = EmbeddingIndex::new(384);
        idx.update(&chunks, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let before = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut idx = EmbeddingIndex::new(384);
        idx.update(&before, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));

        let after = [make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&after);
        assert!(
            needs.iter().any(|file| file.as_str() == "a.rs"),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let before = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        let mut idx = EmbeddingIndex::new(3);
        idx.update(
            &before,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &paths(&["a.rs"]),
        );

        let after = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&after);
        assert!(
            needs.iter().any(|file| file.as_str() == "a.rs"),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let both = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut idx = EmbeddingIndex::new(384);
        idx.update(
            &both,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &paths(&["a.rs", "b.rs"]),
        );

        let remaining = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&remaining);
        assert!(
            needs.iter().any(|file| file.as_str() == "b.rs"),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut idx = EmbeddingIndex::new(384);
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &paths(&["a.rs", "b.rs"]),
        );
        assert_eq!(idx.entries.len(), 2);

        // Re-embed only a.rs; b.rs must keep its original vector.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &paths(&["a.rs"]));
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx
            .entries
            .iter()
            .find(|entry| entry.file_path == "b.rs")
            .unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut idx = EmbeddingIndex::new(2);
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &paths(&["a.rs", "b.rs"]),
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let idx = EmbeddingIndex::new(384);
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // The data directory is selected through a process-wide env var, so
        // the shared test lock is held for the whole test.
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();
        let project_root = project_dir.path();

        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut idx = EmbeddingIndex::new(3);
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &paths(&["a.rs"]));
        idx.save(project_root).unwrap();

        let loaded = EmbeddingIndex::load(project_root).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}
431}