1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19 pub version: u32,
20 pub dimensions: usize,
21 #[serde(default)]
24 pub model_id: Option<String>,
25 pub entries: Vec<EmbeddingEntry>,
26 pub file_hashes: HashMap<String, String>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct EmbeddingEntry {
31 pub file_path: String,
32 pub symbol_name: String,
33 pub start_line: usize,
34 pub end_line: usize,
35 pub embedding: Vec<f32>,
36 pub content_hash: String,
37}
38
/// Format version written into newly created indexes and expected when loading.
const CURRENT_VERSION: u32 = 2;
40
41impl EmbeddingIndex {
42 pub fn new(dimensions: usize) -> Self {
43 Self {
44 version: CURRENT_VERSION,
45 dimensions,
46 model_id: None,
47 entries: Vec::new(),
48 file_hashes: HashMap::new(),
49 }
50 }
51
52 pub fn new_with_model(dimensions: usize, model_id: &str) -> Self {
54 Self {
55 version: CURRENT_VERSION,
56 dimensions,
57 model_id: Some(model_id.to_string()),
58 entries: Vec::new(),
59 file_hashes: HashMap::new(),
60 }
61 }
62
63 pub fn model_mismatch<'a>(&'a self, current_model: &'a str) -> Option<(&'a str, &'a str)> {
66 match &self.model_id {
67 Some(stored) if stored != current_model => Some((stored, current_model)),
68 _ => None,
69 }
70 }
71
72 pub fn dimension_mismatch(&self, engine_dimensions: usize) -> bool {
74 self.dimensions != engine_dimensions && !self.entries.is_empty()
75 }
76
77 pub fn memory_usage_bytes(&self) -> usize {
79 let entries_size: usize = self
80 .entries
81 .iter()
82 .map(|e| {
83 e.file_path.len()
84 + e.symbol_name.len()
85 + e.content_hash.len()
86 + e.embedding.len() * 4
87 + 48
88 })
89 .sum();
90 let hashes_size: usize = self
91 .file_hashes
92 .iter()
93 .map(|(k, v)| k.len() + v.len() + 32)
94 .sum();
95 entries_size + hashes_size
96 }
97
98 pub fn unload(&mut self) {
100 let usage = self.memory_usage_bytes();
101 self.entries = Vec::new();
102 self.file_hashes = HashMap::new();
103 tracing::info!(
104 "[embeddings] unloaded index, freed ~{:.1}MB",
105 usage as f64 / 1_048_576.0
106 );
107 }
108
109 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
111 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
112 }
113
114 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
116 let current_hashes = compute_file_hashes(chunks);
117
118 let mut needs_update = Vec::new();
119 for (file, hash) in ¤t_hashes {
120 match self.file_hashes.get(file) {
121 Some(old_hash) if old_hash == hash => {}
122 _ => needs_update.push(file.clone()),
123 }
124 }
125
126 for file in self.file_hashes.keys() {
127 if !current_hashes.contains_key(file) {
128 needs_update.push(file.clone());
129 }
130 }
131
132 needs_update
133 }
134
135 pub fn update(
138 &mut self,
139 chunks: &[CodeChunk],
140 new_embeddings: &[(usize, Vec<f32>)],
141 changed_files: &[String],
142 ) {
143 self.entries
144 .retain(|e| !changed_files.contains(&e.file_path));
145
146 for file in changed_files {
147 self.file_hashes.remove(file);
148 }
149
150 let current_hashes = compute_file_hashes(chunks);
151 for file in changed_files {
152 if let Some(hash) = current_hashes.get(file) {
153 self.file_hashes.insert(file.clone(), hash.clone());
154 }
155 }
156
157 for &(chunk_idx, ref embedding) in new_embeddings {
158 if let Some(chunk) = chunks.get(chunk_idx) {
159 let content_hash = hash_content(&chunk.content);
160 self.entries.push(EmbeddingEntry {
161 file_path: chunk.file_path.clone(),
162 symbol_name: chunk.symbol_name.clone(),
163 start_line: chunk.start_line,
164 end_line: chunk.end_line,
165 embedding: embedding.clone(),
166 content_hash,
167 });
168 }
169 }
170 }
171
172 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
175 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
176 HashMap::with_capacity(self.entries.len());
177 for e in &self.entries {
178 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
179 }
180
181 let mut result = Vec::with_capacity(chunks.len());
182 for chunk in chunks {
183 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
184 result.push(entry.embedding.clone());
185 }
186 Some(result)
187 }
188
189 pub fn coverage(&self, total_chunks: usize) -> f64 {
190 if total_chunks == 0 {
191 return 0.0;
192 }
193 self.entries.len() as f64 / total_chunks as f64
194 }
195
196 pub fn save(&self, root: &Path) -> std::io::Result<()> {
197 let dir = index_dir(root);
198 std::fs::create_dir_all(&dir)?;
199 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
200 std::fs::write(dir.join("embeddings.json"), data)?;
201 Ok(())
202 }
203
204 pub fn load(root: &Path) -> Option<Self> {
205 let dir = index_dir(root);
206 let path = dir.join("embeddings.json");
207 let data = std::fs::read_to_string(&path)
208 .or_else(|_| {
209 let legacy_dir = legacy_embedding_dir(root);
210 if legacy_dir == dir {
211 return Err(std::io::Error::new(
212 std::io::ErrorKind::NotFound,
213 "same path",
214 ));
215 }
216 let legacy_path = legacy_dir.join("embeddings.json");
217 let content = std::fs::read_to_string(&legacy_path)?;
218 let _ = std::fs::create_dir_all(&dir);
219 let _ = std::fs::copy(&legacy_path, &path);
220 Ok(content)
221 })
222 .ok()?;
223 let idx: Self = serde_json::from_str(&data).ok()?;
224 match idx.version {
225 CURRENT_VERSION => Some(idx),
226 1 => {
227 tracing::info!(
228 "[embeddings] migrating index v1 → v{CURRENT_VERSION} (adding model_id field)"
229 );
230 Some(Self {
231 version: CURRENT_VERSION,
232 dimensions: idx.dimensions,
233 model_id: None,
234 entries: idx.entries,
235 file_hashes: idx.file_hashes,
236 })
237 }
238 _ => None,
239 }
240 }
241}
242
243fn index_dir(root: &Path) -> PathBuf {
244 crate::core::index_namespace::vectors_dir(root)
245}
246
247fn legacy_embedding_dir(root: &Path) -> PathBuf {
248 let mut hasher = Md5::new();
249 hasher.update(root.to_string_lossy().as_bytes());
250 let hash = format!("{:x}", hasher.finalize());
251 crate::core::data_dir::lean_ctx_data_dir()
252 .unwrap_or_else(|_| PathBuf::from("."))
253 .join("vectors")
254 .join(hash)
255}
256
257fn hash_content(content: &str) -> String {
258 let mut hasher = Md5::new();
259 hasher.update(content.as_bytes());
260 format!("{:x}", hasher.finalize())
261}
262
263fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
264 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
265 for chunk in chunks {
266 by_file
267 .entry(chunk.file_path.as_str())
268 .or_default()
269 .push(chunk);
270 }
271
272 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
273 for (file, mut file_chunks) in by_file {
274 file_chunks.sort_by(|a, b| {
275 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
276 b.start_line,
277 b.end_line,
278 b.symbol_name.as_str(),
279 ))
280 });
281
282 let mut hasher = Md5::new();
283 hasher.update(file.as_bytes());
284 for c in file_chunks {
285 hasher.update(c.start_line.to_le_bytes());
286 hasher.update(c.end_line.to_le_bytes());
287 hasher.update(c.symbol_name.as_bytes());
288 hasher.update([kind_tag(&c.kind)]);
289 hasher.update(c.content.as_bytes());
290 }
291 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
292 }
293 out
294}
295
296fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
297 use super::bm25_index::ChunkKind;
298 match kind {
299 ChunkKind::Function => 1,
300 ChunkKind::Struct => 2,
301 ChunkKind::Impl => 3,
302 ChunkKind::Module => 4,
303 ChunkKind::Class => 5,
304 ChunkKind::Method => 6,
305 ChunkKind::Other => 7,
306 ChunkKind::Issue => 8,
307 ChunkKind::PullRequest => 9,
308 ChunkKind::WikiPage => 10,
309 ChunkKind::DbSchema => 11,
310 ChunkKind::ApiEndpoint => 12,
311 ChunkKind::Ticket => 13,
312 ChunkKind::ExternalOther => 14,
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Sets `LEAN_CTX_DATA_DIR` for as long as the guard is alive and unsets it
    /// on drop, so the variable is also cleared when an assertion panics.
    struct DataDirEnvGuard;

    impl DataDirEnvGuard {
        fn set(dir: &std::path::Path) -> Self {
            std::env::set_var("LEAN_CTX_DATA_DIR", dir);
            Self
        }
    }

    impl Drop for DataDirEnvGuard {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_DATA_DIR");
        }
    }

    /// Builds a `Function` chunk with a single token equal to its name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Returns a vector of `dim` copies of `0.1`.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // Declaration order matters: the env guard is dropped before the data
        // directory it points at and before the lock is released.
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirEnvGuard::set(data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn new_with_model_sets_model_id() {
        let idx = EmbeddingIndex::new_with_model(768, "jina-code-v2");
        assert_eq!(idx.model_id, Some("jina-code-v2".to_string()));
        assert_eq!(idx.dimensions, 768);
    }

    #[test]
    fn model_mismatch_detection() {
        let idx = EmbeddingIndex::new_with_model(768, "all-MiniLM-L6-v2");
        assert!(idx.model_mismatch("all-MiniLM-L6-v2").is_none());
        assert!(idx.model_mismatch("jina-code-v2").is_some());

        let (stored, current) = idx.model_mismatch("jina-code-v2").unwrap();
        assert_eq!(stored, "all-MiniLM-L6-v2");
        assert_eq!(current, "jina-code-v2");
    }

    #[test]
    fn model_mismatch_none_when_no_model_id() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.model_mismatch("anything").is_none());
    }

    #[test]
    fn dimension_mismatch_detection() {
        let mut idx = EmbeddingIndex::new(384);
        assert!(!idx.dimension_mismatch(384));
        // An index without entries never reports a mismatch.
        assert!(!idx.dimension_mismatch(768));

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!(!idx.dimension_mismatch(384));
        assert!(idx.dimension_mismatch(768));
    }

    #[test]
    fn v1_index_migration() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _env = DataDirEnvGuard::set(data_dir.path());
        let project_dir = tempfile::tempdir().unwrap();

        // A version 1 document has no `model_id` field.
        let v1_json = serde_json::json!({
            "version": 1,
            "dimensions": 384,
            "entries": [],
            "file_hashes": {}
        });

        let dir = crate::core::index_namespace::vectors_dir(project_dir.path());
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("embeddings.json"), v1_json.to_string()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.version, CURRENT_VERSION);
        assert_eq!(loaded.dimensions, 384);
        assert!(loaded.model_id.is_none());
    }
}