1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12use md5::{Digest, Md5};
13use serde::{Deserialize, Serialize};
14
15use super::vector_index::CodeChunk;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingIndex {
19 pub version: u32,
20 pub dimensions: usize,
21 pub entries: Vec<EmbeddingEntry>,
22 pub file_hashes: HashMap<String, String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct EmbeddingEntry {
27 pub file_path: String,
28 pub symbol_name: String,
29 pub start_line: usize,
30 pub end_line: usize,
31 pub embedding: Vec<f32>,
32 pub content_hash: String,
33}
34
/// Format version stamped on newly created indexes; persisted indexes with a
/// different version are discarded on load.
const CURRENT_VERSION: u32 = 1;
36
37impl EmbeddingIndex {
38 pub fn new(dimensions: usize) -> Self {
39 Self {
40 version: CURRENT_VERSION,
41 dimensions,
42 entries: Vec::new(),
43 file_hashes: HashMap::new(),
44 }
45 }
46
47 pub fn load_or_new(root: &Path, dimensions: usize) -> Self {
49 Self::load(root).unwrap_or_else(|| Self::new(dimensions))
50 }
51
52 pub fn files_needing_update(&self, chunks: &[CodeChunk]) -> Vec<String> {
54 let mut current_hashes: HashMap<String, String> = HashMap::new();
55 for chunk in chunks {
56 current_hashes
57 .entry(chunk.file_path.clone())
58 .or_insert_with(|| hash_content(&chunk.content));
59 }
60
61 let mut needs_update = Vec::new();
62 for (file, hash) in ¤t_hashes {
63 match self.file_hashes.get(file) {
64 Some(old_hash) if old_hash == hash => {}
65 _ => needs_update.push(file.clone()),
66 }
67 }
68
69 for file in self.file_hashes.keys() {
70 if !current_hashes.contains_key(file) {
71 needs_update.push(file.clone());
72 }
73 }
74
75 needs_update
76 }
77
78 pub fn update(
81 &mut self,
82 chunks: &[CodeChunk],
83 new_embeddings: &[(usize, Vec<f32>)],
84 changed_files: &[String],
85 ) {
86 self.entries
87 .retain(|e| !changed_files.contains(&e.file_path));
88
89 for file in changed_files {
90 self.file_hashes.remove(file);
91 }
92
93 for &(chunk_idx, ref embedding) in new_embeddings {
94 if let Some(chunk) = chunks.get(chunk_idx) {
95 let content_hash = hash_content(&chunk.content);
96 self.file_hashes
97 .insert(chunk.file_path.clone(), content_hash.clone());
98 self.entries.push(EmbeddingEntry {
99 file_path: chunk.file_path.clone(),
100 symbol_name: chunk.symbol_name.clone(),
101 start_line: chunk.start_line,
102 end_line: chunk.end_line,
103 embedding: embedding.clone(),
104 content_hash,
105 });
106 }
107 }
108 }
109
110 pub fn get_aligned_embeddings(&self, chunks: &[CodeChunk]) -> Option<Vec<Vec<f32>>> {
113 let mut result = Vec::with_capacity(chunks.len());
114
115 for chunk in chunks {
116 let entry = self.entries.iter().find(|e| {
117 e.file_path == chunk.file_path
118 && e.start_line == chunk.start_line
119 && e.end_line == chunk.end_line
120 })?;
121 result.push(entry.embedding.clone());
122 }
123
124 Some(result)
125 }
126
127 pub fn coverage(&self, total_chunks: usize) -> f64 {
128 if total_chunks == 0 {
129 return 0.0;
130 }
131 self.entries.len() as f64 / total_chunks as f64
132 }
133
134 pub fn save(&self, root: &Path) -> std::io::Result<()> {
135 let dir = index_dir(root);
136 std::fs::create_dir_all(&dir)?;
137 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
138 std::fs::write(dir.join("embeddings.json"), data)?;
139 Ok(())
140 }
141
142 pub fn load(root: &Path) -> Option<Self> {
143 let path = index_dir(root).join("embeddings.json");
144 let data = std::fs::read_to_string(path).ok()?;
145 let idx: Self = serde_json::from_str(&data).ok()?;
146 if idx.version != CURRENT_VERSION {
147 return None;
148 }
149 Some(idx)
150 }
151}
152
153fn index_dir(root: &Path) -> PathBuf {
154 let mut hasher = Md5::new();
155 hasher.update(root.to_string_lossy().as_bytes());
156 let hash = format!("{:x}", hasher.finalize());
157 dirs::home_dir()
158 .unwrap_or_else(|| PathBuf::from("."))
159 .join(".lean-ctx")
160 .join("vectors")
161 .join(hash)
162}
163
164fn hash_content(content: &str) -> String {
165 let mut hasher = Md5::new();
166 hasher.update(content.as_bytes());
167 format!("{:x}", hasher.finalize())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vector_index::{ChunkKind, CodeChunk};

    /// Builds a function chunk with a single token equal to its name.
    fn make_chunk(file: &str, name: &str, content: &str, start: usize, end: usize) -> CodeChunk {
        CodeChunk {
            file_path: file.to_string(),
            symbol_name: name.to_string(),
            kind: ChunkKind::Function,
            start_line: start,
            end_line: end,
            content: content.to_string(),
            tokens: vec![name.to_string()],
            token_count: 1,
        }
    }

    /// Constant-valued embedding of the requested length.
    fn dummy_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    #[test]
    fn new_index_is_empty() {
        let idx = EmbeddingIndex::new(384);
        assert!(idx.entries.is_empty());
        assert!(idx.file_hashes.is_empty());
        assert_eq!(idx.dimensions, 384);
    }

    #[test]
    fn files_needing_update_all_new() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        let needs = idx.files_needing_update(&chunks);
        assert_eq!(needs.len(), 2);
    }

    #[test]
    fn files_needing_update_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];

        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);

        let needs = idx.files_needing_update(&chunks);
        assert!(needs.is_empty(), "unchanged file should not need update");
    }

    #[test]
    fn files_needing_update_changed_content() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks_v1 = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(
            &chunks_v1,
            &[(0, dummy_embedding(384))],
            &["a.rs".to_string()],
        );

        let chunks_v2 = vec![make_chunk("a.rs", "fn_a", "fn a() { modified }", 1, 3)];
        let needs = idx.files_needing_update(&chunks_v2);
        assert!(
            needs.contains(&"a.rs".to_string()),
            "changed file should need update"
        );
    }

    #[test]
    fn files_needing_update_deleted_file() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let chunks_after = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        let needs = idx.files_needing_update(&chunks_after);
        assert!(
            needs.contains(&"b.rs".to_string()),
            "deleted file should trigger update"
        );
    }

    #[test]
    fn update_preserves_unchanged() {
        let mut idx = EmbeddingIndex::new(384);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, dummy_embedding(384)), (1, dummy_embedding(384))],
            &["a.rs".to_string(), "b.rs".to_string()],
        );
        assert_eq!(idx.entries.len(), 2);

        // Re-embed only a.rs; b.rs must keep its original vector.
        idx.update(&chunks, &[(0, vec![0.5; 384])], &["a.rs".to_string()]);
        assert_eq!(idx.entries.len(), 2);

        let b_entry = idx.entries.iter().find(|e| e.file_path == "b.rs").unwrap();
        assert!(
            (b_entry.embedding[0] - 0.1).abs() < 1e-6,
            "b.rs embedding should be preserved"
        );
    }

    #[test]
    fn get_aligned_embeddings() {
        let mut idx = EmbeddingIndex::new(2);
        let chunks = vec![
            make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3),
            make_chunk("b.rs", "fn_b", "fn b() {}", 1, 3),
        ];
        idx.update(
            &chunks,
            &[(0, vec![1.0, 0.0]), (1, vec![0.0, 1.0])],
            &["a.rs".to_string(), "b.rs".to_string()],
        );

        let aligned = idx.get_aligned_embeddings(&chunks).unwrap();
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        assert!((aligned[1][1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn get_aligned_embeddings_missing() {
        let idx = EmbeddingIndex::new(384);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        assert!(idx.get_aligned_embeddings(&chunks).is_none());
    }

    #[test]
    fn coverage_calculation() {
        let mut idx = EmbeddingIndex::new(384);
        assert!((idx.coverage(10) - 0.0).abs() < 1e-6);

        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, dummy_embedding(384))], &["a.rs".to_string()]);
        assert!((idx.coverage(2) - 0.5).abs() < 1e-6);
        assert!((idx.coverage(1) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // The root is only hashed to pick the storage directory, so it does
        // not have to exist. The process id keeps concurrent runs apart.
        let root = std::env::temp_dir().join(format!(
            "lean_ctx_embed_idx_test_{}",
            std::process::id()
        ));
        // `save` writes below the home directory, not below `root`; this is
        // the directory that has to be cleaned up.
        let storage = index_dir(&root);
        let _ = std::fs::remove_dir_all(&storage);

        let mut idx = EmbeddingIndex::new(3);
        let chunks = vec![make_chunk("a.rs", "fn_a", "fn a() {}", 1, 3)];
        idx.update(&chunks, &[(0, vec![1.0, 2.0, 3.0])], &["a.rs".to_string()]);
        idx.save(&root).unwrap();

        let loaded = EmbeddingIndex::load(&root);
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_dir_all(&storage);

        let loaded = loaded.unwrap();
        assert_eq!(loaded.dimensions, 3);
        assert_eq!(loaded.entries.len(), 1);
        assert!((loaded.entries[0].embedding[0] - 1.0).abs() < 1e-6);
    }
}