// lean_ctx/core/embedding_index.rs
1//! Persistent, incremental embedding index.
2//!
3//! Stores pre-computed chunk embeddings alongside file content hashes.
4//! On re-index, only files whose hash has changed get re-embedded,
5//! avoiding expensive model inference for unchanged code.
6//!
7//! Storage format: `~/.lean-ctx/vectors/<project_hash>/embeddings.json`
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::vector_index::CodeChunk;
16
/// Persistent index of pre-computed chunk embeddings for one project,
/// serialized as JSON (see `save`/`load`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// Storage format version; `load` rejects any other value.
    pub version: u32,
    /// Embedding vector length this index was built for.
    pub dimensions: usize,
    /// One entry per embedded chunk.
    pub entries: Vec<EmbeddingEntry>,
    /// File path -> hash of that file's chunk set, used for change detection.
    pub file_hashes: HashMap<String, String>,
}
24
/// A single chunk's pre-computed embedding plus identifying metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// Path of the source file the chunk came from.
    pub file_path: String,
    /// Name of the symbol (function, struct, ...) the chunk covers.
    pub symbol_name: String,
    /// First line of the chunk in the source file.
    pub start_line: usize,
    /// Last line of the chunk in the source file.
    pub end_line: usize,
    /// The chunk's embedding vector.
    pub embedding: Vec<f32>,
    /// MD5 hex digest of the chunk's text at embedding time.
    pub content_hash: String,
}
34
/// Current on-disk format version; `load` discards indexes with any other value.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38    pub fn new(dimensions: usize) -> Self {
39        Self {
40            version: CURRENT_VERSION,
41            dimensions,
42            entries: Vec::new(),
43            file_hashes: HashMap::new(),
44        }
45    }
46
47    /// Load a previously saved index, or create a new empty one.
48    pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49        Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50    }
51
52    /// Determine which files need re-embedding based on content hashes.
53    pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54        let current_hashes = compute_file_hashes(chunks);
55
56        let mut needs_update = Vec::new();
57        for (file, hash) in &current_hashes {
58            match self.file_hashes.get(file) {
59                Some(old_hash) if old_hash == hash => {}
60                _ => needs_update.push(file.clone()),
61            }
62        }
63
64        for file in self.file_hashes.keys() {
65            if !current_hashes.contains_key(file) {
66                needs_update.push(file.clone());
67            }
68        }
69
70        needs_update
71    }
72
73    /// Update the index with new embeddings for changed files.
74    /// Preserves existing embeddings for unchanged files.
75    pub fn update(
76        &mut self,
77        chunks: &[CodeChunk],
78        new_embeddings: &[(usize, Vec<f32>)],
79        changed_files: &[String],
80    ) {
81        self.entries
82            .retain(|e| !changed_files.contains(&e.file_path));
83
84        for file in changed_files {
85            self.file_hashes.remove(file);
86        }
87
88        let current_hashes = compute_file_hashes(chunks);
89        for file in changed_files {
90            if let Some(hash) = current_hashes.get(file) {
91                self.file_hashes.insert(file.clone(), hash.clone());
92            }
93        }
94
95        for &(chunk_idx, ref embedding) in new_embeddings {
96            if let Some(chunk) = chunks.get(chunk_idx) {
97                let content_hash = hash_content(&chunk.content);
98                self.entries.push(EmbeddingEntry {
99                    file_path: chunk.file_path.clone(),
100                    symbol_name: chunk.symbol_name.clone(),
101                    start_line: chunk.start_line,
102                    end_line: chunk.end_line,
103                    embedding: embedding.clone(),
104                    content_hash,
105                });
106            }
107        }
108    }
109
110    /// Get all embeddings in chunk order (aligned with BM25Index.chunks).
111    /// Returns None if index doesn't cover all chunks.
112    pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113        let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114            HashMap::with_capacity(self.entries.len());
115        for e in &self.entries {
116            map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117        }
118
119        let mut result = Vec::with_capacity(chunks.len());
120        for chunk in chunks {
121            let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122            result.push(entry.embedding.clone());
123        }
124        Some(result)
125    }
126
127    pub fn coverage(&self, total_chunks: usize) -> f64 {
128        if total_chunks == 0 {
129            return 0.0;
130        }
131        self.entries.len() as f64 / total_chunks as f64
132    }
133
134    pub fn save(&self, root: &Path) -> std::io::Result<()> {
135        let dir = index_dir(root);
136        std::fs::create_dir_all(&dir)?;
137        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138        std::fs::write(dir.join("embeddings.json"), data)?;
139        Ok(())
140    }
141
142    pub fn load(root: &Path) -> Option<Self> {
143        let path = index_dir(root).join("embeddings.json");
144        let data = std::fs::read_to_string(path).ok()?;
145        let idx: Self = serde_json::from_str(&data).ok()?;
146        if idx.version != CURRENT_VERSION {
147            return None;
148        }
149        Some(idx)
150    }
151}
152
153fn index_dir(root: &Path) -> PathBuf {
154    let mut hasher = Md5::new();
155    hasher.update(root.to_string_lossy().as_bytes());
156    let hash = format!("{:x}", hasher.finalize());
157    dirs::home_dir()
158        .unwrap_or_else(|| PathBuf::from("."))
159        .join(".lean-ctx")
160        .join("vectors")
161        .join(hash)
162}
163
164fn hash_content(content: &str) -> String {
165    let mut hasher = Md5::new();
166    hasher.update(content.as_bytes());
167    format!("{:x}", hasher.finalize())
168}
169
170fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
171    let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
172    for chunk in chunks {
173        by_file
174            .entry(chunk.file_path.as_str())
175            .or_default()
176            .push(chunk);
177    }
178
179    let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
180    for (file, mut file_chunks) in by_file {
181        file_chunks.sort_by(|a, b| {
182            (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
183                b.start_line,
184                b.end_line,
185                b.symbol_name.as_str(),
186            ))
187        });
188
189        let mut hasher = Md5::new();
190        hasher.update(file.as_bytes());
191        for c in file_chunks {
192            hasher.update(c.start_line.to_le_bytes());
193            hasher.update(c.end_line.to_le_bytes());
194            hasher.update(c.symbol_name.as_bytes());
195            hasher.update([kind_tag(&c.kind)]);
196            hasher.update(c.content.as_bytes());
197        }
198        out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
199    }
200    out
201}
202
203fn kind_tag(kind: &super::vector_index::ChunkKind) -> u8 {
204    use super::vector_index::ChunkKind;
205    match kind {
206        ChunkKind::Function => 1,
207        ChunkKind::Struct => 2,
208        ChunkKind::Impl => 3,
209        ChunkKind::Module => 4,
210        ChunkKind::Class => 5,
211        ChunkKind::Method => 6,
212        ChunkKind::Other => 7,
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vector_index::{ChunkKind, CodeChunk};

    /// Build a function-kind chunk with a single token derived from its name.
    fn mk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.into(),
            symbol_name: name.into(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.into(),
            tokens: vec![name.into()],
            token_count: 1,
        }
    }

    /// Constant embedding vector of the requested dimensionality.
    fn emb(dim: usize) -> Vec<f32> {
        std::iter::repeat(0.1).take(dim).collect()
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert_eq!(idx.dimensions, 384);
        assert_eq!(idx.entries.len(), 0);
        assert_eq!(idx.file_hashes.len(), 0);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        assert_eq!(idx.files_needing_update(&chunks).len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, emb(384))], &["a.rs".to_string()]);

        assert!(
            idx.files_needing_update(&chunks).is_empty(),
            "unchanged file should not need update"
        );
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let before = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&before, &[(0, emb(384))], &["a.rs".to_string()]);

        let after = [mk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&after);
        assert!(
            needs.iter().any(|f| f == "a.rs"),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let v1 = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &v1,
            &[(0, vec![0.1; 3]), (1, vec![0.2; 3])],
            &["a.rs".to_string()],
        );

        let v2 = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        assert!(
            idx.files_needing_update(&v2).iter().any(|f| f == "a.rs"),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, emb(384)), (1, emb(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let remaining = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&remaining);
        assert!(
            needs.iter().any(|f| f == "b.rs"),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, emb(384)), (1, emb(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Re-embed only a.rs; b.rs must keep its original vector.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b = idx
            .entries
            .iter()
            .find(|e| e.file_path == "b.rs")
            .expect("b.rs entry present");
        assert!(
            (b.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = [
            mk("a.rs", "fn_a", "fn a() {}", 1, 3),
            mk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).expect("full coverage");
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!(idx.coverage(10).abs() < 1e-6);

        let chunks = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, emb(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        let dir = std::env::temp_dir().join("lean_ctx_embed_idx_test");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = [mk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(&dir).expect("save succeeds");

        let loaded = EmbeddingIndex::load(&dir).expect("load succeeds");
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        let _ = std::fs::remove_dir_all(&dir);
    }
}