1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::bm25_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19 pub version: u32,
20 pub dimensions: usize,
21 pub entries: Vec<EmbeddingEntry>,
22 pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27 pub file_path: String,
28 pub symbol_name: String,
29 pub start_line: usize,
30 pub end_line: usize,
31 pub embedding: Vec<f32>,
32 pub content_hash: String,
33}
34
/// Version stamped into newly created indexes; a persisted index carrying a
/// different version is discarded when loaded.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn memory_usage_bytes(&self) -> usize {
49 let entries_size: usize = self
50 .entries
51 .iter()
52 .map(|e| {
53 e.file_path.len()
54 + e.symbol_name.len()
55 + e.content_hash.len()
56 + e.embedding.len() * 4
57 + 48
58 })
59 .sum();
60 let hashes_size: usize = self
61 .file_hashes
62 .iter()
63 .map(|(k, v)| k.len() + v.len() + 32)
64 .sum();
65 entries_size + hashes_size
66 }
67
68 pub fn unload(&mut self) {
70 let usage = self.memory_usage_bytes();
71 self.entries = Vec::new();
72 self.file_hashes = HashMap::new();
73 tracing::info!(
74 "[embeddings] unloaded index, freed ~{:.1}MB",
75 usage as f64 / 1_048_576.0
76 );
77 }
78
79 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
81 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
82 }
83
84 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
86 let current_hashes = compute_file_hashes(chunks);
87
88 let mut needs_update = Vec::new();
89 for (file, hash) in ¤t_hashes {
90 match self.file_hashes.get(file) {
91 Some(old_hash) if old_hash == hash => {}
92 _ => needs_update.push(file.clone()),
93 }
94 }
95
96 for file in self.file_hashes.keys() {
97 if !current_hashes.contains_key(file) {
98 needs_update.push(file.clone());
99 }
100 }
101
102 needs_update
103 }
104
105 pub fn update(
108 &mut self,
109 chunks: &[CodeChunk],
110 new_embeddings: &[(usize, Vec<f32>)],
111 changed_files: &[String],
112 ) {
113 self.entries
114 .retain(|e| !changed_files.contains(&e.file_path));
115
116 for file in changed_files {
117 self.file_hashes.remove(file);
118 }
119
120 let current_hashes = compute_file_hashes(chunks);
121 for file in changed_files {
122 if let Some(hash) = current_hashes.get(file) {
123 self.file_hashes.insert(file.clone(), hash.clone());
124 }
125 }
126
127 for &(chunk_idx, ref embedding) in new_embeddings {
128 if let Some(chunk) = chunks.get(chunk_idx) {
129 let content_hash = hash_content(&chunk.content);
130 self.entries.push(EmbeddingEntry {
131 file_path: chunk.file_path.clone(),
132 symbol_name: chunk.symbol_name.clone(),
133 start_line: chunk.start_line,
134 end_line: chunk.end_line,
135 embedding: embedding.clone(),
136 content_hash,
137 });
138 }
139 }
140 }
141
142 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
145 let mut map: HashMap<(&str, usize, usize), &EmbeddingEntry> =
146 HashMap::with_capacity(self.entries.len());
147 for e in &self.entries {
148 map.insert((e.file_path.as_str(), e.start_line, e.end_line), e);
149 }
150
151 let mut result = Vec::with_capacity(chunks.len());
152 for chunk in chunks {
153 let entry = map.get(&(chunk.file_path.as_str(), chunk.start_line, chunk.end_line))?;
154 result.push(entry.embedding.clone());
155 }
156 Some(result)
157 }
158
159 pub fn coverage(&self, total_chunks: usize) -> f64 {
160 if total_chunks == 0 {
161 return 0.0;
162 }
163 self.entries.len() as f64 / total_chunks as f64
164 }
165
166 pub fn save(&self, root: &Path) -> std::io::Result<()> {
167 let dir = index_dir(root);
168 std::fs::create_dir_all(&dir)?;
169 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
170 std::fs::write(dir.join("embeddings.json"), data)?;
171 Ok(())
172 }
173
174 pub fn load(root: &Path) -> Option<Self> {
175 let dir = index_dir(root);
176 let path = dir.join("embeddings.json");
177 let data = std::fs::read_to_string(&path)
178 .or_else(|_| {
179 let legacy_dir = legacy_embedding_dir(root);
180 if legacy_dir == dir {
181 return Err(std::io::Error::new(
182 std::io::ErrorKind::NotFound,
183 "same path",
184 ));
185 }
186 let legacy_path = legacy_dir.join("embeddings.json");
187 let content = std::fs::read_to_string(&legacy_path)?;
188 let _ = std::fs::create_dir_all(&dir);
189 let _ = std::fs::copy(&legacy_path, &path);
190 Ok(content)
191 })
192 .ok()?;
193 let idx: Self = serde_json::from_str(&data).ok()?;
194 if idx.version != CURRENT_VERSION {
195 return None;
196 }
197 Some(idx)
198 }
199}
200
201fn index_dir(root: &Path) -> PathBuf {
202 crate::core::index_namespace::vectors_dir(root)
203}
204
205fn legacy_embedding_dir(root: &Path) -> PathBuf {
206 let mut hasher = Md5::new();
207 hasher.update(root.to_string_lossy().as_bytes());
208 let hash = format!("{:x}", hasher.finalize());
209 crate::core::data_dir::lean_ctx_data_dir()
210 .unwrap_or_else(|_| PathBuf::from("."))
211 .join("vectors")
212 .join(hash)
213}
214
215fn hash_content(content: &str) -> String {
216 let mut hasher = Md5::new();
217 hasher.update(content.as_bytes());
218 format!("{:x}", hasher.finalize())
219}
220
221fn compute_file_hashes(chunks: &[CodeChunk]) -> HashMap<String, String> {
222 let mut by_file: HashMap<&str, Vec<&CodeChunk>> = HashMap::new();
223 for chunk in chunks {
224 by_file
225 .entry(chunk.file_path.as_str())
226 .or_default()
227 .push(chunk);
228 }
229
230 let mut out: HashMap<String, String> = HashMap::with_capacity(by_file.len());
231 for (file, mut file_chunks) in by_file {
232 file_chunks.sort_by(|a, b| {
233 (a.start_line, a.end_line, a.symbol_name.as_str()).cmp(&(
234 b.start_line,
235 b.end_line,
236 b.symbol_name.as_str(),
237 ))
238 });
239
240 let mut hasher = Md5::new();
241 hasher.update(file.as_bytes());
242 for c in file_chunks {
243 hasher.update(c.start_line.to_le_bytes());
244 hasher.update(c.end_line.to_le_bytes());
245 hasher.update(c.symbol_name.as_bytes());
246 hasher.update([kind_tag(&c.kind)]);
247 hasher.update(c.content.as_bytes());
248 }
249 out.insert(file.to_string(), format!("{:x}", hasher.finalize()));
250 }
251 out
252}
253
254fn kind_tag(kind: &super::bm25_index::ChunkKind) -> u8 {
255 use super::bm25_index::ChunkKind;
256 match kind {
257 ChunkKind::Function => 1,
258 ChunkKind::Struct => 2,
259 ChunkKind::Impl => 3,
260 ChunkKind::Module => 4,
261 ChunkKind::Class => 5,
262 ChunkKind::Method => 6,
263 ChunkKind::Other => 7,
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Builds a function chunk with a single token equal to its name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Constant vector of length `dim`, every component `0.1`.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    /// Clears `LEAN_CTX_DATA_DIR` when dropped, so a failing assertion cannot
    /// leave the variable set for whichever test runs next.
    struct DataDirEnvGuard;

    impl Drop for DataDirEnvGuard {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_DATA_DIR");
        }
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_detects_change_in_later_chunk() {
        let mut idx = EmbeddingIndex::new(3);
        let chunks_v1 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() {}", 10, 12),
        ];
        idx.update(
            &chunks_v1,
            &[(0, vec![0.1, 0.1, 0.1]), (1, vec![0.2, 0.2, 0.2])],
            &["a.rs".to_string()],
        );

        // Only the second chunk of the file changes.
        let chunks_v2 = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("a.rs", "fn_b", "fn b() { changed }", 10, 12),
        ];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changing a later chunk should trigger re-embedding"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        // b.rs disappears from the chunk list.
        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Re-embed only a.rs; b.rs must keep its original vector.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        let _lock = crate::core::data_dir::test_env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        std::env::set_var("LEAN_CTX_DATA_DIR", data_dir.path());
        // Declared after the lock so it is dropped (and the variable cleared)
        // before the lock is released, even when an assertion below panics.
        let _env = DataDirEnvGuard;

        let project_dir = tempfile::tempdir().unwrap();

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(project_dir.path()).unwrap();

        let loaded = EmbeddingIndex::load(project_dir.path()).unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);
    }
}