Skip to main content

lean_ctx/core/
embedding_index.rs

1//! Persistent, incremental embedding index.
2//!
3//! Stores pre-computed chunk embeddings alongside file content hashes.
4//! On re-index, only files whose hash has changed get re-embedded,
5//! avoiding expensive model inference for unchanged code.
6//!
7//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::vector_index::CodeChunk;
16
/// On-disk index of pre-computed chunk embeddings plus per-file change-detection hashes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// Storage-format version; `load` rejects indexes whose version differs
    /// from `CURRENT_VERSION`.
    pub version: u32,
    /// Width of each embedding vector (model output dimension).
    pub dimensions: usize,
    /// One stored embedding per indexed code chunk.
    pub entries: Vec<EmbeddingEntry>,
    /// file path -> content hash, used by `files_needing_update` to detect
    /// which files changed since the last indexing pass.
    pub file_hashes: HashMap<String, String>,
}
24
/// A single stored embedding, keyed by (file_path, start_line, end_line)
/// for re-alignment against a fresh chunk list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    pub file_path: String,
    pub symbol_name: String,
    pub start_line: usize,
    pub end_line: usize,
    /// Pre-computed embedding vector for this chunk.
    pub embedding: Vec<f32>,
    /// MD5 of the chunk's content at embedding time.
    pub content_hash: String,
}
34
/// Bump when the serialized layout changes; old indexes are then discarded on load.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38    pub fn new(dimensions: usize) -> Self {
39        Self {
40            version: CURRENT_VERSION,
41            dimensions,
42            entries: Vec::new(),
43            file_hashes: HashMap::new(),
44        }
45    }
46
47    /// Load a previously saved index, or create a new empty one.
48    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50    }
51
52    /// Determine which files need re-embedding based on content hashes.
53    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54        let current_hashes = compute_file_hashes(chunks);
55
56        let mut needs_update = Vec::new();
57        for (file, hash) in &current_hashes {
58            match self.file_hashes.get(file) {
59                Some(old_hash) if old_hash == hash => {}
60                _ => needs_update.push(file.clone()),
61            }
62        }
63
64        for file in self.file_hashes.keys() {
65            if !current_hashes.contains_key(file) {
66                needs_update.push(file.clone());
67            }
68        }
69
70        needs_update
71    }
72
73    /// Update the index with new embeddings for changed files.
74    /// Preserves existing embeddings for unchanged files.
75    pub fn update(
76        &mut self,
77        chunks: &[CodeChunk],
78        new_embeddings: &[(usize, Vec<f32>)],
79        changed_files: &[String],
80    ) {
81        self.entries
82            .retain(|e| !changed_files.contains(&e.file_path));
83
84        for file in changed_files {
85            self.file_hashes.remove(file);
86        }
87
88        let current_hashes = compute_file_hashes(chunks);
89        for file in changed_files {
90            if let Some(hash) = current_hashes.get(file) {
91                self.file_hashes.insert(file.clone(), hash.clone());
92            }
93        }
94
95        for &(chunk_idx, ref embedding) in new_embeddings {
96            if let Some(chunk) = chunks.get(chunk_idx) {
97                let content_hash = hash_content(&chunk.content);
98                self.entries.push(EmbeddingEntry {
99                    file_path: chunk.file_path.clone(),
100                    symbol_name: chunk.symbol_name.clone(),
101                    start_line: chunk.start_line,
102                    end_line: chunk.end_line,
103                    embedding: embedding.clone(),
104                    content_hash,
105                });
106            }
107        }
108    }
109
110    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
111    /// Returns None if index doesn't cover all chunks.
112    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114            HashMap::with_capacity(self.entries.len());
115        for e in &self.entries {
116            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117        }
118
119        let mut result = Vec::with_capacity(chunks.len());
120        for chunk in chunks {
121            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122            result.push(entry.embedding.clone());
123        }
124        Some(result)
125    }
126
127    pub fn coverage(&self, total_chunks: usize) -> f64 {
128        if total_chunks == 0 {
129            return 0.0;
130        }
131        self.entries.len() as f64 / total_chunks as f64
132    }
133
134    pub fn save(&self, root: &Path) -> std::io::Result<()> {
135        let dir = index_dir(root);
136        std::fs::create_dir_all(&dir)?;
137        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138        std::fs::write(dir.join("embeddings.json"), data)?;
139        Ok(())
140    }
141
142    pub fn load(root: &Path) -> Option<Self> {
143        let path = index_dir(root).join("embeddings.json");
144        let data = std::fs::read_to_string(path).ok()?;
145        let idx: Self = serde_json::from_str(&data).ok()?;
146        if idx.version != CURRENT_VERSION {
147            return None;
148        }
149        Some(idx)
150    }
151}
152
153fn index_dir(root: &Path) -> PathBuf {
154    let mut hasher = Md5::new();
155    hasher.update(root.to_string_lossy().as_bytes());
156    let hash = format!("{:x}", hasher.finalize());
157    crate::core::data_dir::lean_ctx_data_dir()
158        .unwrap_or_else(|_| PathBuf::from("."))
159        .join("vectors")
160        .join(hash)
161}
162
163fn hash_content(content: &str) -> String {
164    let mut hasher = Md5::new();
165    hasher.update(content.as_bytes());
166    format!("{:x}", hasher.finalize())
167}
168
169fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
170    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
171    for chunk in chunks {
172        by_file
173            .entry(chunk.file_path.as_str())
174            .or_default()
175            .push(chunk);
176    }
177
178    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
179    for (file, mut file_chunks) in by_file {
180        file_chunks.sort_by(|a, b| {
181            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
182                b.start_line,
183                b.end_line,
184                b.symbol_name.as_str(),
185            ))
186        });
187
188        let mut hasher = Md5::new();
189        hasher.update(file.as_bytes());
190        for c in file_chunks {
191            hasher.update(c.start_line.to_le_bytes());
192            hasher.update(c.end_line.to_le_bytes());
193            hasher.update(c.symbol_name.as_bytes());
194            hasher.update([kind_tag(&c.kind)]);
195            hasher.update(c.content.as_bytes());
196        }
197        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
198    }
199    out
200}
201
202fn kind_tag(kind: &super::vector_index::ChunkKind) -> u8 {
203    use super::vector_index::ChunkKind;
204    match kind {
205        ChunkKind::Function => 1,
206        ChunkKind::Struct => 2,
207        ChunkKind::Impl => 3,
208        ChunkKind::Module => 4,
209        ChunkKind::Class => 5,
210        ChunkKind::Method => 6,
211        ChunkKind::Other => 7,
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vector_index::{ChunkKind, CodeChunk};

    /// Minimal chunk fixture; every test chunk is a `Function` with a
    /// single token so only the fields under test vary.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Constant-valued embedding of the requested width.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    // An empty index has no stored hashes, so every file counts as new.
    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    // After indexing, re-checking the identical chunks must report nothing.
    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    // Editing a chunk's content changes the file hash and flags the file.
    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    // The per-file hash must cover every chunk, not just the first one.
    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        // Only the second chunk differs from v1.
        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    // A file present in the stored hashes but absent from the current
    // chunk list (deleted) must still be reported so its entries get purged.
    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    // Re-embedding only a.rs must leave b.rs's original embedding intact.
    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Second pass touches only a.rs with a distinguishable value (0.5).
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    // Aligned output must follow chunk order, not insertion order.
    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    // Any chunk without a stored embedding makes alignment impossible.
    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // Env-var mutation is process-global: hold the shared test lock for
        // the whole test, and set the data dir before any save/load call.
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        // Clean up while still holding the lock so other tests see a
        // pristine environment.
        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}