1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::vector_index::CodeChunk;
16
/// Persisted cache of code-chunk embeddings for one project.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// On-disk schema version; a mismatch with `CURRENT_VERSION` discards the index on load.
    pub version: u32,
    /// Dimensionality of every vector in `entries`.
    pub dimensions: usize,
    /// One cached embedding per previously indexed chunk.
    pub entries: Vec<EmbeddingEntry>,
    /// file path -> combined hash of that file's chunks, used for change detection.
    pub file_hashes: HashMap<String, String>,
}
24
/// A single cached embedding, keyed by file path plus line span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// File the chunk came from (as recorded in `CodeChunk::file_path`).
    pub file_path: String,
    /// Name of the symbol the chunk covers.
    pub symbol_name: String,
    /// First line of the chunk within the file.
    pub start_line: usize,
    /// Last line of the chunk within the file.
    pub end_line: usize,
    /// The embedding vector itself (`dimensions` elements).
    pub embedding: Vec<f32>,
    /// MD5 hash of the chunk content at embedding time.
    pub content_hash: String,
}
34
35const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50 }
51
52 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54 let current_hashes = compute_file_hashes(chunks);
55
56 let mut needs_update = Vec::new();
57 for (file, hash) in ¤t_hashes {
58 match self.file_hashes.get(file) {
59 Some(old_hash) if old_hash == hash => {}
60 _ => needs_update.push(file.clone()),
61 }
62 }
63
64 for file in self.file_hashes.keys() {
65 if !current_hashes.contains_key(file) {
66 needs_update.push(file.clone());
67 }
68 }
69
70 needs_update
71 }
72
73 pub fn update(
76 &mut self,
77 chunks: &[CodeChunk],
78 new_embeddings: &[(usize, Vec<f32>)],
79 changed_files: &[String],
80 ) {
81 self.entries
82 .retain(|e| !changed_files.contains(&e.file_path));
83
84 for file in changed_files {
85 self.file_hashes.remove(file);
86 }
87
88 let current_hashes = compute_file_hashes(chunks);
89 for file in changed_files {
90 if let Some(hash) = current_hashes.get(file) {
91 self.file_hashes.insert(file.clone(), hash.clone());
92 }
93 }
94
95 for &(chunk_idx, ref embedding) in new_embeddings {
96 if let Some(chunk) = chunks.get(chunk_idx) {
97 let content_hash = hash_content(&chunk.content);
98 self.entries.push(EmbeddingEntry {
99 file_path: chunk.file_path.clone(),
100 symbol_name: chunk.symbol_name.clone(),
101 start_line: chunk.start_line,
102 end_line: chunk.end_line,
103 embedding: embedding.clone(),
104 content_hash,
105 });
106 }
107 }
108 }
109
110 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114 HashMap::with_capacity(self.entries.len());
115 for e in &self.entries {
116 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117 }
118
119 let mut result = Vec::with_capacity(chunks.len());
120 for chunk in chunks {
121 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122 result.push(entry.embedding.clone());
123 }
124 Some(result)
125 }
126
127 pub fn coverage(&self, total_chunks: usize) -> f64 {
128 if total_chunks == 0 {
129 return 0.0;
130 }
131 self.entries.len() as f64 / total_chunks as f64
132 }
133
134 pub fn save(&self, root: &Path) -> std::io::Result<()> {
135 let dir = index_dir(root);
136 std::fs::create_dir_all(&dir)?;
137 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138 std::fs::write(dir.join("embeddings.json"), data)?;
139 Ok(())
140 }
141
142 pub fn load(root: &Path) -> Option<Self> {
143 let path = index_dir(root).join("embeddings.json");
144 let data = std::fs::read_to_string(path).ok()?;
145 let idx: Self = serde_json::from_str(&data).ok()?;
146 if idx.version != CURRENT_VERSION {
147 return None;
148 }
149 Some(idx)
150 }
151}
152
153fn index_dir(root: &Path) -> PathBuf {
154 let mut hasher = Md5::new();
155 hasher.update(root.to_string_lossy().as_bytes());
156 let hash = format!("{:x}", hasher.finalize());
157 crate::core::data_dir::lean_ctx_data_dir()
158 .unwrap_or_else(|_| PathBuf::from("."))
159 .join("vectors")
160 .join(hash)
161}
162
163fn hash_content(content: &str) -> String {
164 let mut hasher = Md5::new();
165 hasher.update(content.as_bytes());
166 format!("{:x}", hasher.finalize())
167}
168
169fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
170 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
171 for chunk in chunks {
172 by_file
173 .entry(chunk.file_path.as_str())
174 .or_default()
175 .push(chunk);
176 }
177
178 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
179 for (file, mut file_chunks) in by_file {
180 file_chunks.sort_by(|a, b| {
181 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
182 b.start_line,
183 b.end_line,
184 b.symbol_name.as_str(),
185 ))
186 });
187
188 let mut hasher = Md5::new();
189 hasher.update(file.as_bytes());
190 for c in file_chunks {
191 hasher.update(c.start_line.to_le_bytes());
192 hasher.update(c.end_line.to_le_bytes());
193 hasher.update(c.symbol_name.as_bytes());
194 hasher.update([kind_tag(&c.kind)]);
195 hasher.update(c.content.as_bytes());
196 }
197 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
198 }
199 out
200}
201
202fn kind_tag(kind: &super::vector_index::ChunkKind) -> u8 {
203 use super::vector_index::ChunkKind;
204 match kind {
205 ChunkKind::Function => 1,
206 ChunkKind::Struct => 2,
207 ChunkKind::Impl => 3,
208 ChunkKind::Module => 4,
209 ChunkKind::Class => 5,
210 ChunkKind::Method => 6,
211 ChunkKind::Other => 7,
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vector_index::{ChunkKind, CodeChunk};

    /// Builds a minimal `Function` chunk for the given file/name/content/span.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Constant-valued embedding of the requested dimensionality.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    /// With no stored hashes, every file counts as needing an update.
    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    /// After indexing, re-scanning identical chunks reports nothing to do.
    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    /// Editing a chunk's content changes the file hash, triggering re-embedding.
    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    /// The per-file hash covers all chunks, not just the first one.
    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    /// A file tracked in the index but absent from the scan is reported so
    /// its stale entries can be purged.
    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    /// A partial update replaces only the changed file's entries; others keep
    /// their original embeddings.
    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    /// Cached vectors come back in the same order as the requested chunks.
    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    /// Any chunk without a cached entry makes alignment return `None`.
    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    /// Coverage is entries/total, with 0.0 for an empty denominator.
    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    /// Save then load through a temp data dir recovers the same index.
    /// The env lock serializes tests that mutate process-wide env vars.
    #[test]
    fn save_and_load_roundtrip() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}