1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::vector_index::CodeChunk;
16
/// Persistent cache of code-chunk embeddings for one project root.
///
/// Serialized to JSON on disk (see `save`/`load`) so embeddings survive
/// across runs and only changed files need re-embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// On-disk schema version; a mismatch invalidates the whole index on load.
    pub version: u32,
    /// Dimensionality of every vector stored in `entries`.
    pub dimensions: usize,
    /// One stored embedding per indexed code chunk.
    pub entries: Vec<EmbeddingEntry>,
    /// Per-file MD5 fingerprints used to detect which files changed.
    pub file_hashes: HashMap<String, String>,
}
24
/// A single cached embedding, keyed by (file, start line, end line).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// Path of the source file the chunk came from.
    pub file_path: String,
    /// Name of the symbol (function, struct, ...) the chunk represents.
    pub symbol_name: String,
    /// First line of the chunk in the source file.
    pub start_line: usize,
    /// Last line of the chunk in the source file.
    pub end_line: usize,
    /// The embedding vector; length must equal `EmbeddingIndex::dimensions`.
    pub embedding: Vec<f32>,
    /// MD5 of the chunk's content at embedding time.
    pub content_hash: String,
}
34
/// Bump when the on-disk format changes; older files are then discarded on load.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50 }
51
52 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54 let current_hashes = compute_file_hashes(chunks);
55
56 let mut needs_update = Vec::new();
57 for (file, hash) in ¤t_hashes {
58 match self.file_hashes.get(file) {
59 Some(old_hash) if old_hash == hash => {}
60 _ => needs_update.push(file.clone()),
61 }
62 }
63
64 for file in self.file_hashes.keys() {
65 if !current_hashes.contains_key(file) {
66 needs_update.push(file.clone());
67 }
68 }
69
70 needs_update
71 }
72
73 pub fn update(
76 &mut self,
77 chunks: &[CodeChunk],
78 new_embeddings: &[(usize, Vec<f32>)],
79 changed_files: &[String],
80 ) {
81 self.entries
82 .retain(|e| !changed_files.contains(&e.file_path));
83
84 for file in changed_files {
85 self.file_hashes.remove(file);
86 }
87
88 let current_hashes = compute_file_hashes(chunks);
89 for file in changed_files {
90 if let Some(hash) = current_hashes.get(file) {
91 self.file_hashes.insert(file.clone(), hash.clone());
92 }
93 }
94
95 for &(chunk_idx, ref embedding) in new_embeddings {
96 if let Some(chunk) = chunks.get(chunk_idx) {
97 let content_hash = hash_content(&chunk.content);
98 self.entries.push(EmbeddingEntry {
99 file_path: chunk.file_path.clone(),
100 symbol_name: chunk.symbol_name.clone(),
101 start_line: chunk.start_line,
102 end_line: chunk.end_line,
103 embedding: embedding.clone(),
104 content_hash,
105 });
106 }
107 }
108 }
109
110 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114 HashMap::with_capacity(self.entries.len());
115 for e in &self.entries {
116 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117 }
118
119 let mut result = Vec::with_capacity(chunks.len());
120 for chunk in chunks {
121 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122 result.push(entry.embedding.clone());
123 }
124 Some(result)
125 }
126
127 pub fn coverage(&self, total_chunks: usize) -> f64 {
128 if total_chunks == 0 {
129 return 0.0;
130 }
131 self.entries.len() as f64 / total_chunks as f64
132 }
133
134 pub fn save(&self, root: &Path) -> std::io::Result<()> {
135 let dir = index_dir(root);
136 std::fs::create_dir_all(&dir)?;
137 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138 std::fs::write(dir.join("embeddings.json"), data)?;
139 Ok(())
140 }
141
142 pub fn load(root: &Path) -> Option<Self> {
143 let path = index_dir(root).join("embeddings.json");
144 let data = std::fs::read_to_string(path).ok()?;
145 let idx: Self = serde_json::from_str(&data).ok()?;
146 if idx.version != CURRENT_VERSION {
147 return None;
148 }
149 Some(idx)
150 }
151}
152
153fn index_dir(root: &Path) -> PathBuf {
154 let mut hasher = Md5::new();
155 hasher.update(root.to_string_lossy().as_bytes());
156 let hash = format!("{:x}", hasher.finalize());
157 dirs::home_dir()
158 .unwrap_or_else(|| PathBuf::from("."))
159 .join(".lean-ctx")
160 .join("vectors")
161 .join(hash)
162}
163
164fn hash_content(content: &str) -> String {
165 let mut hasher = Md5::new();
166 hasher.update(content.as_bytes());
167 format!("{:x}", hasher.finalize())
168}
169
170fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
171 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
172 for chunk in chunks {
173 by_file
174 .entry(chunk.file_path.as_str())
175 .or_default()
176 .push(chunk);
177 }
178
179 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
180 for (file, mut file_chunks) in by_file {
181 file_chunks.sort_by(|a, b| {
182 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
183 b.start_line,
184 b.end_line,
185 b.symbol_name.as_str(),
186 ))
187 });
188
189 let mut hasher = Md5::new();
190 hasher.update(file.as_bytes());
191 for c in file_chunks {
192 hasher.update(c.start_line.to_le_bytes());
193 hasher.update(c.end_line.to_le_bytes());
194 hasher.update(c.symbol_name.as_bytes());
195 hasher.update([kind_tag(&c.kind)]);
196 hasher.update(c.content.as_bytes());
197 }
198 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
199 }
200 out
201}
202
203fn kind_tag(kind: &super::vector_index::ChunkKind) -> u8 {
204 use super::vector_index::ChunkKind;
205 match kind {
206 ChunkKind::Function => 1,
207 ChunkKind::Struct => 2,
208 ChunkKind::Impl => 3,
209 ChunkKind::Module => 4,
210 ChunkKind::Class => 5,
211 ChunkKind::Method => 6,
212 ChunkKind::Other => 7,
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vector_index::{ChunkKind, CodeChunk};

    // Builds a minimal Function chunk; token fields are filled with
    // placeholders since the index never reads them.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    // Constant vector whose first component (0.1) tests assert on later.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        // An empty index has no stored hashes, so every file is "new".
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        // Re-checking with identical chunks must report nothing stale.
        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        // Same file/symbol/lines but different content -> different file hash.
        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        // The file hash covers ALL chunks of a file, so changing only the
        // second chunk must still flag the whole file.
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        // b.rs disappears from the chunk set -> must be reported stale.
        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Second update touches only a.rs; b.rs's entry must survive intact.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        // Output order must follow `chunks`, not insertion order.
        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        // Any chunk without a stored embedding makes the whole call None.
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // NOTE(review): `save`/`load` derive the storage dir from a hash of
        // `root` under $HOME, so this test writes outside `dir` itself; `dir`
        // only serves as a unique root key.
        let dir = std::env::temp_dir().join("lean_ctx_embed_idx_test");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(&dir).unwrap();

        let loaded = EmbeddingIndex::load(&dir).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        let _ = std::fs::remove_dir_all(&dir);
    }
}
410}