1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
/// Embedding index for one project root, persisted as JSON.
///
/// Written by `EmbeddingIndex::save` and read back by `EmbeddingIndex::load`,
/// which ignores files whose `version` differs from `CURRENT_VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingIndex {
    /// On-disk format version; compared against `CURRENT_VERSION` when loading.
    pub version: u32,
    /// Dimensionality passed to `EmbeddingIndex::new`. NOTE(review): presumably
    /// every `entries[i].embedding` has this length, but nothing here enforces it.
    pub dimensions: usize,
    /// One entry per embedded chunk, in insertion order.
    pub entries: Vec<EmbeddingEntry>,
    /// File path -> hex MD5 over all of that file's chunks (see
    /// `compute_file_hashes`); compared to detect files needing re-embedding.
    pub file_hashes: HashMap<String, String>,
}
24
/// A single embedded chunk, identified by its file and line range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// Path of the file the chunk came from (copied from `CodeChunk::file_path`).
    pub file_path: String,
    /// Symbol name of the chunk (copied from `CodeChunk::symbol_name`).
    pub symbol_name: String,
    /// First line of the chunk (copied from `CodeChunk::start_line`).
    pub start_line: usize,
    /// Last line of the chunk (copied from `CodeChunk::end_line`).
    pub end_line: usize,
    /// The embedding vector supplied to `EmbeddingIndex::update`.
    pub embedding: Vec<f32>,
    /// Hex MD5 of the chunk's content at the time it was embedded.
    pub content_hash: String,
}
34
/// On-disk format version. `EmbeddingIndex::load` returns `None` for a saved
/// index with any other version, so bumping this invalidates existing files.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn memory_usage_bytes(&self) -> usize {
49 let entries_size: usize = self
50 .entries
51 .iter()
52 .map(|e| {
53 e.file_path.len()
54 + e.symbol_name.len()
55 + e.content_hash.len()
56 + e.embedding.len() * 4
57 + 48
58 })
59 .sum();
60 let hashes_size: usize = self
61 .file_hashes
62 .iter()
63 .map(|(k, v)| k.len() + v.len() + 32)
64 .sum();
65 entries_size + hashes_size
66 }
67
68 pub fn unload(&mut self) {
70 let usage = self.memory_usage_bytes();
71 self.entries = Vec::new();
72 self.file_hashes = HashMap::new();
73 tracing::info!(
74 "[embeddings] unloaded index, freed ~{:.1}MB",
75 usage as f64 / 1_048_576.0
76 );
77 }
78
79 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
81 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
82 }
83
84 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
86 let current_hashes = compute_file_hashes(chunks);
87
88 let mut needs_update = Vec::new();
89 for (file, hash) in ¤t_hashes {
90 match self.file_hashes.get(file) {
91 Some(old_hash) if old_hash == hash => {}
92 _ => needs_update.push(file.clone()),
93 }
94 }
95
96 for file in self.file_hashes.keys() {
97 if !current_hashes.contains_key(file) {
98 needs_update.push(file.clone());
99 }
100 }
101
102 needs_update
103 }
104
105 pub fn update(
108 &mut self,
109 chunks: &[CodeChunk],
110 new_embeddings: &[(usize, Vec<f32>)],
111 changed_files: &[String],
112 ) {
113 self.entries
114 .retain(|e| !changed_files.contains(&e.file_path));
115
116 for file in changed_files {
117 self.file_hashes.remove(file);
118 }
119
120 let current_hashes = compute_file_hashes(chunks);
121 for file in changed_files {
122 if let Some(hash) = current_hashes.get(file) {
123 self.file_hashes.insert(file.clone(), hash.clone());
124 }
125 }
126
127 for &(chunk_idx, ref embedding) in new_embeddings {
128 if let Some(chunk) = chunks.get(chunk_idx) {
129 let content_hash = hash_content(&chunk.content);
130 self.entries.push(EmbeddingEntry {
131 file_path: chunk.file_path.clone(),
132 symbol_name: chunk.symbol_name.clone(),
133 start_line: chunk.start_line,
134 end_line: chunk.end_line,
135 embedding: embedding.clone(),
136 content_hash,
137 });
138 }
139 }
140 }
141
142 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
145 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
146 HashMap::with_capacity(self.entries.len());
147 for e in &self.entries {
148 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
149 }
150
151 let mut result = Vec::with_capacity(chunks.len());
152 for chunk in chunks {
153 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
154 result.push(entry.embedding.clone());
155 }
156 Some(result)
157 }
158
159 pub fn coverage(&self, total_chunks: usize) -> f64 {
160 if total_chunks == 0 {
161 return 0.0;
162 }
163 self.entries.len() as f64 / total_chunks as f64
164 }
165
166 pub fn save(&self, root: &Path) -> std::io::Result<()> {
167 let dir = index_dir(root);
168 std::fs::create_dir_all(&dir)?;
169 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
170 std::fs::write(dir.join("embeddings.json"), data)?;
171 Ok(())
172 }
173
174 pub fn load(root: &Path) -> Option<Self> {
175 let dir = index_dir(root);
176 let path = dir.join("embeddings.json");
177 let data = std::fs::read_to_string(&path)
178 .or_else(|_| {
179 let legacy_dir = legacy_embedding_dir(root);
180 if legacy_dir == dir {
181 return Err(std::io::Error::new(
182 std::io::ErrorKind::NotFound,
183 "same path",
184 ));
185 }
186 let legacy_path = legacy_dir.join("embeddings.json");
187 let content = std::fs::read_to_string(&legacy_path)?;
188 let _ = std::fs::create_dir_all(&dir);
189 let _ = std::fs::copy(&legacy_path, &path);
190 Ok(content)
191 })
192 .ok()?;
193 let idx: Self = serde_json::from_str(&data).ok()?;
194 if idx.version != CURRENT_VERSION {
195 return None;
196 }
197 Some(idx)
198 }
199}
200
201fn index_dir(root: &Path) -> PathBuf {
202 crate::core::index_namespace::vectors_dir(root)
203}
204
205fn legacy_embedding_dir(root: &Path) -> PathBuf {
206 let mut hasher = Md5::new();
207 hasher.update(root.to_string_lossy().as_bytes());
208 let hash = format!("{:x}", hasher.finalize());
209 crate::core::data_dir::lean_ctx_data_dir()
210 .unwrap_or_else(|_| PathBuf::from("."))
211 .join("vectors")
212 .join(hash)
213}
214
215fn hash_content(content: &str) -> String {
216 let mut hasher = Md5::new();
217 hasher.update(content.as_bytes());
218 format!("{:x}", hasher.finalize())
219}
220
221fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
222 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
223 for chunk in chunks {
224 by_file
225 .entry(chunk.file_path.as_str())
226 .or_default()
227 .push(chunk);
228 }
229
230 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
231 for (file, mut file_chunks) in by_file {
232 file_chunks.sort_by(|a, b| {
233 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
234 b.start_line,
235 b.end_line,
236 b.symbol_name.as_str(),
237 ))
238 });
239
240 let mut hasher = Md5::new();
241 hasher.update(file.as_bytes());
242 for c in file_chunks {
243 hasher.update(c.start_line.to_le_bytes());
244 hasher.update(c.end_line.to_le_bytes());
245 hasher.update(c.symbol_name.as_bytes());
246 hasher.update([kind_tag(&c.kind)]);
247 hasher.update(c.content.as_bytes());
248 }
249 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
250 }
251 out
252}
253
/// Maps a chunk kind to the one-byte tag mixed into `compute_file_hashes`.
///
/// These bytes become part of the hashes stored in `EmbeddingIndex::file_hashes`,
/// which `EmbeddingIndex::save` persists. Renumbering an existing variant changes
/// the stored hash of every file containing that kind, so those files are
/// reported by `files_needing_update`. Give new variants fresh numbers instead.
/// The match is deliberately exhaustive (no `_` arm), so adding a `ChunkKind`
/// variant fails to compile until it is assigned a tag here.
fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
    use super::bm25_index::ChunkKind;
    match kind {
        ChunkKind::Function => 1,
        ChunkKind::Struct => 2,
        ChunkKind::Impl => 3,
        ChunkKind::Module => 4,
        ChunkKind::Class => 5,
        ChunkKind::Method => 6,
        ChunkKind::Other => 7,
        ChunkKind::Issue => 8,
        ChunkKind::PullRequest => 9,
        ChunkKind::WikiPage => 10,
        ChunkKind::DbSchema => 11,
        ChunkKind::ApiEndpoint => 12,
        ChunkKind::Ticket => 13,
        ChunkKind::ExternalOther => 14,
    }
}
273
// Unit tests for change detection, incremental updates, embedding alignment,
// coverage and save/load persistence of `EmbeddingIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    // Builds a `Function` chunk whose only token is its symbol name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    // Constant vector of length `dim`; every component is 0.1.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    // With no stored hashes, every file present in the chunks is reported.
    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    // After `update`, the same chunks must hash identically and report nothing.
    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    // Guards against a per-file hash that only covers the first chunk.
    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    // A file with a stored hash but no current chunks is reported as stale.
    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    // Updating only a.rs must replace its entry and leave b.rs's entry intact.
    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    // Output order follows the `chunks` argument, not the entry order.
    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    // Any chunk without an entry makes the whole alignment `None`.
    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    // Points LEAN_CTX_DATA_DIR at a temp dir and holds `test_env_lock`.
    // NOTE(review): the lock presumably serializes tests that mutate process
    // environment variables. Also, `remove_var` at the end is skipped if an
    // assertion above it panics.
    #[test]
    fn save_and_load_roundtrip() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);

        std::env::remove_var("LEAN_CTX_DATA_DIR");
    }
}