1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16use super::embedding_quant::{self, QuantizedVector};
17
/// Per-project index of chunk embeddings, serialized as JSON to `embeddings.json`
/// by `EmbeddingIndex::save` and read back by `EmbeddingIndex::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// On-disk format version; `load` accepts `CURRENT_VERSION` and migrates 1 and 2.
    pub version: u32,
    /// Vector width the index was created for (see `dimension_mismatch`).
    pub dimensions: usize,
    /// Embedding model that produced `entries`. `serde(default)` lets files written
    /// before this field existed deserialize with `None`.
    #[serde(default)]
    pub model_id: Option<String>,
    /// One entry per embedded chunk.
    pub entries: Vec<EmbeddingEntry>,
    /// File path -> hex MD5 fingerprint of that file's chunks, used to decide which
    /// files need re-embedding.
    pub file_hashes: HashMap<String, String>,
}
29
/// One embedded chunk, identified by its file path and line span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    pub file_path: String,
    pub symbol_name: String,
    pub start_line: usize,
    pub end_line: usize,
    /// Raw f32 vector as stored by pre-v3 indexes. Entries written by `update` leave
    /// this empty (and it is then omitted from the JSON) in favour of `quant`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub embedding: Vec<f32>,
    /// Quantized vector; when present it takes precedence over `embedding`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quant: Option<QuantizedVector>,
    /// Hex MD5 of the chunk content at the time it was embedded.
    pub content_hash: String,
}
45
46impl EmbeddingEntry {
47 fn embedding_f32(&self) -> Vec<f32> {
50 match &self.quant {
51 Some(q) => q.dequantize(),
52 None => self.embedding.clone(),
53 }
54 }
55}
56
/// On-disk format version written by this code. Version 3 stores quantized vectors
/// in `EmbeddingEntry::quant`; `EmbeddingIndex::load` migrates version 1 and 2 files.
const CURRENT_VERSION: u32 = 3;
59
60impl EmbeddingIndex {
61 pub fn new(dimensions: usize) -> Self {
62 Self {
63 version: CURRENT_VERSION,
64 dimensions,
65 model_id: None,
66 entries: Vec::new(),
67 file_hashes: HashMap::new(),
68 }
69 }
70
71 pub fn new_with_model(dimensions: usize, model_id: &str) -> Self {
73 Self {
74 version: CURRENT_VERSION,
75 dimensions,
76 model_id: Some(model_id.to_string()),
77 entries: Vec::new(),
78 file_hashes: HashMap::new(),
79 }
80 }
81
82 pub fn model_mismatch<'a>(&'a self, current_model: &'a str) -> Option<(&'a str, &'a str)> {
85 match &self.model_id {
86 Some(stored) if stored != current_model => Some((stored, current_model)),
87 _ => None,
88 }
89 }
90
91 pub fn dimension_mismatch(&self, engine_dimensions: usize) -> bool {
93 self.dimensions != engine_dimensions && !self.entries.is_empty()
94 }
95
96 pub fn memory_usage_bytes(&self) -> usize {
98 let entries_size: usize = self
99 .entries
100 .iter()
101 .map(|e| {
102 e.file_path.len()
103 + e.symbol_name.len()
104 + e.content_hash.len()
105 + e.quant
106 .as_ref()
107 .map_or(e.embedding.len() * 4, |q| q.code.len() + 4)
108 + 48
109 })
110 .sum();
111 let hashes_size: usize = self
112 .file_hashes
113 .iter()
114 .map(|(k, v)| k.len() + v.len() + 32)
115 .sum();
116 entries_size + hashes_size
117 }
118
119 pub fn unload(&mut self) {
121 let usage = self.memory_usage_bytes();
122 self.entries = Vec::new();
123 self.file_hashes = HashMap::new();
124 tracing::info!(
125 "[embeddings] unloaded index, freed ~{:.1}MB",
126 usage as f64 / 1_048_576.0
127 );
128 }
129
130 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
132 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
133 }
134
135 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
137 let current_hashes = compute_file_hashes(chunks);
138
139 let mut needs_update = Vec::new();
140 for (file, hash) in ¤t_hashes {
141 match self.file_hashes.get(file) {
142 Some(old_hash) if old_hash == hash => {}
143 _ => needs_update.push(file.clone()),
144 }
145 }
146
147 for file in self.file_hashes.keys() {
148 if !current_hashes.contains_key(file) {
149 needs_update.push(file.clone());
150 }
151 }
152
153 needs_update
154 }
155
156 pub fn update(
159 &mut self,
160 chunks: &[CodeChunk],
161 new_embeddings: &[(usize, Vec<f32>)],
162 changed_files: &[String],
163 ) {
164 self.entries
165 .retain(|e| !changed_files.contains(&e.file_path));
166
167 for file in changed_files {
168 self.file_hashes.remove(file);
169 }
170
171 let current_hashes = compute_file_hashes(chunks);
172 for file in changed_files {
173 if let Some(hash) = current_hashes.get(file) {
174 self.file_hashes.insert(file.clone(), hash.clone());
175 }
176 }
177
178 for &(chunk_idx, ref embedding) in new_embeddings {
179 if let Some(chunk) = chunks.get(chunk_idx) {
180 let content_hash = hash_content(&chunk.content);
181 self.entries.push(EmbeddingEntry {
182 file_path: chunk.file_path.clone(),
183 symbol_name: chunk.symbol_name.clone(),
184 start_line: chunk.start_line,
185 end_line: chunk.end_line,
186 embedding: Vec::new(),
187 quant: Some(embedding_quant::quantize(embedding)),
188 content_hash,
189 });
190 }
191 }
192 }
193
194 fn migrate_legacy_entries(&mut self) -> bool {
197 let mut changed = false;
198 for e in &mut self.entries {
199 if e.quant.is_none() && !e.embedding.is_empty() {
200 e.quant = Some(embedding_quant::quantize(&e.embedding));
201 e.embedding = Vec::new();
202 changed = true;
203 }
204 }
205 changed
206 }
207
208 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
211 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
212 HashMap::with_capacity(self.entries.len());
213 for e in &self.entries {
214 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
215 }
216
217 let mut result = Vec::with_capacity(chunks.len());
218 for chunk in chunks {
219 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
220 result.push(entry.embedding_f32());
221 }
222 Some(result)
223 }
224
225 pub fn coverage(&self, total_chunks: usize) -> f64 {
226 if total_chunks == 0 {
227 return 0.0;
228 }
229 self.entries.len() as f64 / total_chunks as f64
230 }
231
232 pub fn save(&self, root: &Path) -> std::io::Result<()> {
233 let dir = index_dir(root);
234 std::fs::create_dir_all(&dir)?;
235 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
236 std::fs::write(dir.join("embeddings.json"), data)?;
237 Ok(())
238 }
239
240 pub fn load(root: &Path) -> Option<Self> {
241 let dir = index_dir(root);
242 let path = dir.join("embeddings.json");
243 let data = std::fs::read_to_string(&path)
244 .or_else(|_| {
245 let legacy_dir = legacy_embedding_dir(root);
246 if legacy_dir == dir {
247 return Err(std::io::Error::new(
248 std::io::ErrorKind::NotFound,
249 "same path",
250 ));
251 }
252 let legacy_path = legacy_dir.join("embeddings.json");
253 let content = std::fs::read_to_string(&legacy_path)?;
254 let _ = std::fs::create_dir_all(&dir);
255 let _ = std::fs::copy(&legacy_path, &path);
256 Ok(content)
257 })
258 .ok()?;
259 let mut idx: Self = serde_json::from_str(&data).ok()?;
260 match idx.version {
261 CURRENT_VERSION => Some(idx),
262 1 | 2 => {
263 tracing::info!(
264 "[embeddings] migrating index v{} → v{CURRENT_VERSION} (int8 quantization)",
265 idx.version
266 );
267 idx.version = CURRENT_VERSION;
268 let quantized = idx.migrate_legacy_entries();
269 if quantized {
271 let _ = idx.save(root);
272 }
273 Some(idx)
274 }
275 _ => None,
276 }
277 }
278}
279
280fn index_dir(root: &Path) -> PathBuf {
281 crate::core::index_namespace::vectors_dir(root)
282}
283
284fn legacy_embedding_dir(root: &Path) -> PathBuf {
285 let mut hasher = Md5::new();
286 hasher.update(root.to_string_lossy().as_bytes());
287 let hash = format!("{:x}", hasher.finalize());
288 crate::core::data_dir::lean_ctx_data_dir()
289 .unwrap_or_else(|_| PathBuf::from("."))
290 .join("vectors")
291 .join(hash)
292}
293
294fn hash_content(content: &str) -> String {
295 let mut hasher = Md5::new();
296 hasher.update(content.as_bytes());
297 format!("{:x}", hasher.finalize())
298}
299
300fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
301 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
302 for chunk in chunks {
303 by_file
304 .entry(chunk.file_path.as_str())
305 .or_default()
306 .push(chunk);
307 }
308
309 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
310 for (file, mut file_chunks) in by_file {
311 file_chunks.sort_by(|a, b| {
312 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
313 b.start_line,
314 b.end_line,
315 b.symbol_name.as_str(),
316 ))
317 });
318
319 let mut hasher = Md5::new();
320 hasher.update(file.as_bytes());
321 for c in file_chunks {
322 hasher.update(c.start_line.to_le_bytes());
323 hasher.update(c.end_line.to_le_bytes());
324 hasher.update(c.symbol_name.as_bytes());
325 hasher.update([kind_tag(&c.kind)]);
326 hasher.update(c.content.as_bytes());
327 }
328 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
329 }
330 out
331}
332
333fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
334 use super::bm25_index::ChunkKind;
335 match kind {
336 ChunkKind::Function => 1,
337 ChunkKind::Struct => 2,
338 ChunkKind::Impl => 3,
339 ChunkKind::Module => 4,
340 ChunkKind::Class => 5,
341 ChunkKind::Method => 6,
342 ChunkKind::Other => 7,
343 ChunkKind::Issue => 8,
344 ChunkKind::PullRequest => 9,
345 ChunkKind::WikiPage => 10,
346 ChunkKind::DbSchema => 11,
347 ChunkKind::ApiEndpoint => 12,
348 ChunkKind::Ticket => 13,
349 ChunkKind::ExternalOther => 14,
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Points `LEAN_CTX_DATA_DIR` at a directory for the guard's lifetime.
    ///
    /// On drop the previous value is restored (or the variable is removed), also when
    /// an assertion panics, so a failing test cannot leak a dangling data dir into
    /// later tests. Declare it after the env lock guard so it drops before the lock.
    struct DataDirEnvGuard {
        previous: Option<std::ffi::OsString>,
    }

    impl DataDirEnvGuard {
        fn set(path: &std::path::Path) -> Self {
            let previous = std::env::var_os("LEAN_CTX_DATA_DIR");
            std::env::set_var("LEAN_CTX_DATA_DIR", path);
            Self { previous }
        }
    }

    impl Drop for DataDirEnvGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var("LEAN_CTX_DATA_DIR", value),
                None => std::env::remove_var("LEAN_CTX_DATA_DIR"),
            }
        }
    }

    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding_f32()[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn update_drops_hash_and_entries_for_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks_after, &[], &["b.rs".to_string()]);

        assert!(!idx.file_hashes.contains_key("b.rs"));
        assert!(idx.entries.iter().all(|e| e.file_path != "b.rs"));
        assert!(idx.files_needing_update(&chunks_after).is_empty());
    }

    #[test]
    fn update_skips_out_of_range_chunk_index() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(5, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.contains_key("a.rs"));
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirEnvGuard::set(data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        // Quantization is lossy; allow a small reconstruction error.
        let recon = loaded.entries[0].embedding_f32();
        assert!((recon[0] - 1.0).abs() < 0.02);
        assert!((recon[1] - 2.0).abs() < 0.02);
        assert!((recon[2] - 3.0).abs() < 0.02);
    }

    #[test]
    fn new_with_model_sets_model_id() {
        let idx = EmbeddingIndex::new_with_model(768, "jina-code-v2");
        assert_eq!(idx.model_id, Some("jina-code-v2".to_string()));
        assert_eq!(idx.dimensions, 768);
    }

    #[test]
    fn model_mismatch_detection() {
        let idx = EmbeddingIndex::new_with_model(768, "all-MiniLM-L6-v2");
        assert!(idx.model_mismatch("all-MiniLM-L6-v2").is_none());
        assert!(idx.model_mismatch("jina-code-v2").is_some());

        let (stored, current) = idx.model_mismatch("jina-code-v2").unwrap();
        assert_eq!(stored, "all-MiniLM-L6-v2");
        assert_eq!(current, "jina-code-v2");
    }

    #[test]
    fn model_mismatch_none_when_no_model_id() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.model_mismatch("anything").is_none());
    }

    #[test]
    fn dimension_mismatch_detection() {
        let mut idx = EmbeddingIndex::new(384);
        // An empty index never reports a mismatch.
        assert!(!idx.dimension_mismatch(384));
        assert!(!idx.dimension_mismatch(768));

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!(!idx.dimension_mismatch(384));
        assert!(idx.dimension_mismatch(768));
    }

    #[test]
    fn v1_index_migration() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirEnvGuard::set(data_dir.path());
        let project_dir = tempfile::tempdir().unwrap();

        let v1_json = serde_json::json!({
            "version": 1,
            "dimensions": 384,
            "entries": [],
            "file_hashes": {}
        });

        let dir = crate::core::index_namespace::vectors_dir(project_dir.path());
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("embeddings.json"), v1_json.to_string()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.version, CURRENT_VERSION);
        assert_eq!(loaded.dimensions, 384);
        assert!(loaded.model_id.is_none());
    }

    #[test]
    fn v2_index_quantizes_on_migration() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirEnvGuard::set(data_dir.path());
        let project_dir = tempfile::tempdir().unwrap();

        // A v2 file stores raw f32 vectors in `embedding` and has no `quant` field.
        let v2_json = serde_json::json!({
            "version": 2,
            "dimensions": 3,
            "model_id": "all-MiniLM-L6-v2",
            "entries": [{
                "file_path": "a.rs",
                "symbol_name": "fn_a",
                "start_line": 1,
                "end_line": 3,
                "embedding": [1.0, 2.0, 3.0],
                "content_hash": "abc"
            }],
            "file_hashes": {}
        });

        let dir = crate::core::index_namespace::vectors_dir(project_dir.path());
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("embeddings.json"), v2_json.to_string()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.version, CURRENT_VERSION);
        let entry = &loaded.entries[0];
        assert!(
            entry.embedding.is_empty(),
            "f32 field cleared after migration"
        );
        assert!(entry.quant.is_some(), "entry is now quantized");
        let recon = entry.embedding_f32();
        assert!((recon[2] - 3.0).abs() < 0.02);

        // The migrated form was persisted, so a second load sees the new version.
        let reloaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(reloaded.version, CURRENT_VERSION);
        assert!(reloaded.entries[0].quant.is_some());
    }
}