1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19 pub version: u32,
20 pub dimensions: usize,
21 pub entries: Vec<EmbeddingEntry>,
22 pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27 pub file_path: String,
28 pub symbol_name: String,
29 pub start_line: usize,
30 pub end_line: usize,
31 pub embedding: Vec<f32>,
32 pub content_hash: String,
33}
34
/// Version written into new indexes; persisted indexes carrying any other
/// version are ignored when loading.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50 }
51
52 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54 let current_hashes = compute_file_hashes(chunks);
55
56 let mut needs_update = Vec::new();
57 for (file, hash) in ¤t_hashes {
58 match self.file_hashes.get(file) {
59 Some(old_hash) if old_hash == hash => {}
60 _ => needs_update.push(file.clone()),
61 }
62 }
63
64 for file in self.file_hashes.keys() {
65 if !current_hashes.contains_key(file) {
66 needs_update.push(file.clone());
67 }
68 }
69
70 needs_update
71 }
72
73 pub fn update(
76 &mut self,
77 chunks: &[CodeChunk],
78 new_embeddings: &[(usize, Vec<f32>)],
79 changed_files: &[String],
80 ) {
81 self.entries
82 .retain(|e| !changed_files.contains(&e.file_path));
83
84 for file in changed_files {
85 self.file_hashes.remove(file);
86 }
87
88 let current_hashes = compute_file_hashes(chunks);
89 for file in changed_files {
90 if let Some(hash) = current_hashes.get(file) {
91 self.file_hashes.insert(file.clone(), hash.clone());
92 }
93 }
94
95 for &(chunk_idx, ref embedding) in new_embeddings {
96 if let Some(chunk) = chunks.get(chunk_idx) {
97 let content_hash = hash_content(&chunk.content);
98 self.entries.push(EmbeddingEntry {
99 file_path: chunk.file_path.clone(),
100 symbol_name: chunk.symbol_name.clone(),
101 start_line: chunk.start_line,
102 end_line: chunk.end_line,
103 embedding: embedding.clone(),
104 content_hash,
105 });
106 }
107 }
108 }
109
110 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
114 HashMap::with_capacity(self.entries.len());
115 for e in &self.entries {
116 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
117 }
118
119 let mut result = Vec::with_capacity(chunks.len());
120 for chunk in chunks {
121 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
122 result.push(entry.embedding.clone());
123 }
124 Some(result)
125 }
126
127 pub fn coverage(&self, total_chunks: usize) -> f64 {
128 if total_chunks == 0 {
129 return 0.0;
130 }
131 self.entries.len() as f64 / total_chunks as f64
132 }
133
134 pub fn save(&self, root: &Path) -> std::io::Result<()> {
135 let dir = index_dir(root);
136 std::fs::create_dir_all(&dir)?;
137 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138 std::fs::write(dir.join("embeddings.json"), data)?;
139 Ok(())
140 }
141
142 pub fn load(root: &Path) -> Option<Self> {
143 let dir = index_dir(root);
144 let path = dir.join("embeddings.json");
145 let data = std::fs::read_to_string(&path)
146 .or_else(|_| {
147 let legacy_dir = legacy_embedding_dir(root);
148 if legacy_dir == dir {
149 return Err(std::io::Error::new(
150 std::io::ErrorKind::NotFound,
151 "same path",
152 ));
153 }
154 let legacy_path = legacy_dir.join("embeddings.json");
155 let content = std::fs::read_to_string(&legacy_path)?;
156 let _ = std::fs::create_dir_all(&dir);
157 let _ = std::fs::copy(&legacy_path, &path);
158 Ok(content)
159 })
160 .ok()?;
161 let idx: Self = serde_json::from_str(&data).ok()?;
162 if idx.version != CURRENT_VERSION {
163 return None;
164 }
165 Some(idx)
166 }
167}
168
169fn index_dir(root: &Path) -> PathBuf {
170 crate::core::index_namespace::vectors_dir(root)
171}
172
173fn legacy_embedding_dir(root: &Path) -> PathBuf {
174 let mut hasher = Md5::new();
175 hasher.update(root.to_string_lossy().as_bytes());
176 let hash = format!("{:x}", hasher.finalize());
177 crate::core::data_dir::lean_ctx_data_dir()
178 .unwrap_or_else(|_| PathBuf::from("."))
179 .join("vectors")
180 .join(hash)
181}
182
183fn hash_content(content: &str) -> String {
184 let mut hasher = Md5::new();
185 hasher.update(content.as_bytes());
186 format!("{:x}", hasher.finalize())
187}
188
189fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
190 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
191 for chunk in chunks {
192 by_file
193 .entry(chunk.file_path.as_str())
194 .or_default()
195 .push(chunk);
196 }
197
198 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
199 for (file, mut file_chunks) in by_file {
200 file_chunks.sort_by(|a, b| {
201 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
202 b.start_line,
203 b.end_line,
204 b.symbol_name.as_str(),
205 ))
206 });
207
208 let mut hasher = Md5::new();
209 hasher.update(file.as_bytes());
210 for c in file_chunks {
211 hasher.update(c.start_line.to_le_bytes());
212 hasher.update(c.end_line.to_le_bytes());
213 hasher.update(c.symbol_name.as_bytes());
214 hasher.update([kind_tag(&c.kind)]);
215 hasher.update(c.content.as_bytes());
216 }
217 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
218 }
219 out
220}
221
222fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
223 use super::bm25_index::ChunkKind;
224 match kind {
225 ChunkKind::Function => 1,
226 ChunkKind::Struct => 2,
227 ChunkKind::Impl => 3,
228 ChunkKind::Module => 4,
229 ChunkKind::Class => 5,
230 ChunkKind::Method => 6,
231 ChunkKind::Other => 7,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Builds a function chunk whose only token is its own name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        let symbol = name.to_owned();
        CodeChunk {
            file_path: file.to_owned(),
            symbol_name: symbol.clone(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_owned(),
            tokens: vec![symbol],
            token_count: 1,
        }
    }

    /// Embedding of length `dim` filled with `0.1`.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        std::iter::repeat(0.1_f32).take(dim).collect()
    }

    /// Converts string literals into the owned paths `update` expects.
    fn paths(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn new_index_is_empty() {
        let index = EmbeddingIndex::new(384);
        assert!(index.entries.is_empty());
        assert!(index.file_hashes.is_empty());
        assert_eq!(index.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let index = EmbeddingIndex::new(384);
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let stale = index.files_needing_update(&chunks);
        assert_eq!(stale.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut index = EmbeddingIndex::new(384);

        index.update(&chunks, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));

        let stale = index.files_needing_update(&chunks);
        assert!(stale.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let before = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut index = EmbeddingIndex::new(384);
        index.update(&before, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));

        let after = [make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let stale = index.files_needing_update(&after);
        assert!(
            stale.iter().any(|f| f.as_str() == "a.rs"),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let before = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        let mut index = EmbeddingIndex::new(3);
        index.update(
            &before,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &paths(&["a.rs"]),
        );

        let after = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let stale = index.files_needing_update(&after);
        assert!(
            stale.iter().any(|f| f.as_str() == "a.rs"),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let both = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut index = EmbeddingIndex::new(384);
        index.update(
            &both,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &paths(&["a.rs", "b.rs"]),
        );

        let remaining = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let stale = index.files_needing_update(&remaining);
        assert!(
            stale.iter().any(|f| f.as_str() == "b.rs"),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut index = EmbeddingIndex::new(384);
        index.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &paths(&["a.rs", "b.rs"]),
        );
        assert_eq!(index.entries.len(), 2);

        // Re-embed only a.rs; b.rs must be left untouched.
        index.update(&chunks, &[(0, vec![0.5; 384])], &paths(&["a.rs"]));
        assert_eq!(index.entries.len(), 2);

        let kept = index
            .entries
            .iter()
            .find(|entry| entry.file_path == "b.rs")
            .unwrap();
        assert!(
            (kept.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let chunks = [
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let mut index = EmbeddingIndex::new(2);
        index.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &paths(&["a.rs", "b.rs"]),
        );

        let aligned = index.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let index = EmbeddingIndex::new(384);
        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(index.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut index = EmbeddingIndex::new(384);
        assert!((index.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        index.update(&chunks, &[(0, dummy_embedding(384))], &paths(&["a.rs"]));
        assert!((index.coverage(2) - 0.5).abs() < 1e-6);
        assert!((index.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // Serialize access to the process-wide environment variable.
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();
        let project_root = project_dir.path();

        let chunks = [make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let mut index = EmbeddingIndex::new(3);
        index.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &paths(&["a.rs"]));
        index.save(project_root).unwrap();

        let restored = EmbeddingIndex::load(project_root).unwrap();
        assert_eq!(restored.dimensions, 3);
        assert_eq!(restored.entries.len(), 1);
        assert!((restored.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}