1use super::EmbeddingVector;
7use dashmap::DashMap;
8use std::hash::Hash;
9use std::hash::Hasher;
10use std::path::Path;
11use std::path::PathBuf;
12use std::sync::Arc;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum IndexError {
17 #[error("Index not found for key: {0:?}")]
18 IndexNotFound(IndexKey),
19
20 #[error("Dimension mismatch: expected {expected}, got {actual}")]
21 DimensionMismatch { expected: usize, actual: usize },
22
23 #[error("IO error: {0}")]
24 Io(#[from] std::io::Error),
25
26 #[error("Serialization error: {0}")]
27 Serialization(String),
28}
29
/// Identity of one embedding index: which repository, which embedding model,
/// and which vector length. Keys differing in any field name distinct indexes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IndexKey {
    /// Repository root; `IndexKey::new` stores the canonicalized form.
    pub repo_root: PathBuf,
    /// Embedding model identifier, e.g. `"openai:text-embedding-3-small"`.
    pub model_id: String,
    /// Number of components in each vector stored under this key.
    pub dimensions: usize,
}
41
42impl IndexKey {
43 pub fn new(repo: &Path, model: &str, dimensions: usize) -> Result<Self, IndexError> {
44 let repo_root = repo.canonicalize()?;
45 Ok(Self {
46 repo_root,
47 model_id: model.to_string(),
48 dimensions,
49 })
50 }
51
52 pub fn storage_hash(&self) -> String {
54 use std::collections::hash_map::DefaultHasher;
55 let mut hasher = DefaultHasher::new();
56 self.hash(&mut hasher);
57 format!("{:x}", hasher.finish())
58 }
59}
60
61pub struct VectorIndex {
63 dimensions: usize,
64 vectors: Vec<(Vec<f32>, ChunkMetadata)>,
65 }
67
68impl VectorIndex {
69 pub const fn new(dimensions: usize) -> Self {
70 Self {
71 dimensions,
72 vectors: Vec::new(),
73 }
74 }
75
76 pub fn insert(&mut self, vector: Vec<f32>, metadata: ChunkMetadata) -> Result<(), IndexError> {
77 if vector.len() != self.dimensions {
78 return Err(IndexError::DimensionMismatch {
79 expected: self.dimensions,
80 actual: vector.len(),
81 });
82 }
83 self.vectors.push((vector, metadata));
84 Ok(())
85 }
86
87 pub fn search(&self, query: &[f32], limit: usize) -> Result<Vec<SearchResult>, IndexError> {
88 if query.len() != self.dimensions {
89 return Err(IndexError::DimensionMismatch {
90 expected: self.dimensions,
91 actual: query.len(),
92 });
93 }
94
95 let mut results: Vec<_> = self
97 .vectors
98 .iter()
99 .map(|(vec, meta)| {
100 let similarity = cosine_similarity(query, vec);
101 SearchResult {
102 similarity,
103 metadata: meta.clone(),
104 }
105 })
106 .collect();
107
108 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
109 results.truncate(limit);
110
111 Ok(results)
112 }
113
114 pub async fn save_to_disk(&self, path: &Path) -> Result<(), IndexError> {
115 let data = bincode::encode_to_vec(&self.vectors, bincode::config::standard())
116 .map_err(|e| IndexError::Serialization(e.to_string()))?;
117 tokio::fs::write(path, data).await?;
118 Ok(())
119 }
120
121 pub async fn load_from_disk(path: &Path) -> Result<Self, IndexError> {
122 let data = tokio::fs::read(path).await?;
123 let vectors: Vec<(EmbeddingVector, ChunkMetadata)> =
124 bincode::decode_from_slice(&data, bincode::config::standard())
125 .map_err(|e| IndexError::Serialization(e.to_string()))?
126 .0; let dimensions = vectors.first().map(|(v, _)| v.len()).unwrap_or(1536);
130
131 Ok(Self {
132 dimensions,
133 vectors,
134 })
135 }
136}
137
138#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, bincode::Encode, bincode::Decode)]
140pub struct ChunkMetadata {
141 pub file_path: PathBuf,
142 pub start_line: usize,
143 pub end_line: usize,
144 pub content_hash: u64,
145 pub chunk_type: String,
146 pub symbols: Vec<String>,
147}
148
149#[derive(Debug, Clone)]
151pub struct SearchResult {
152 pub similarity: f32,
153 pub metadata: ChunkMetadata,
154}
155
156pub struct EmbeddingIndexManager {
158 indexes: DashMap<IndexKey, Arc<VectorIndex>>,
160 storage_dir: PathBuf,
162}
163
164impl EmbeddingIndexManager {
165 pub fn new(storage_dir: PathBuf) -> Self {
166 Self {
167 indexes: DashMap::new(),
168 storage_dir,
169 }
170 }
171
172 pub fn get_or_create_index(
174 &self,
175 repo: &Path,
176 model: &str,
177 dimensions: usize,
178 ) -> Result<Arc<VectorIndex>, IndexError> {
179 let key = IndexKey::new(repo, model, dimensions)?;
180
181 if let Some(index) = self.indexes.get(&key) {
183 return Ok(index.clone());
184 }
185
186 let storage_path = self.index_storage_path(&key);
188 if storage_path.exists() {
189 let runtime = tokio::runtime::Runtime::new()?;
190 let index = runtime.block_on(VectorIndex::load_from_disk(&storage_path))?;
191 let index = Arc::new(index);
192 self.indexes.insert(key, index.clone());
193 return Ok(index);
194 }
195
196 let index = Arc::new(VectorIndex::new(dimensions));
198 self.indexes.insert(key, index.clone());
199 Ok(index)
200 }
201
202 pub fn search(
204 &self,
205 repo: &Path,
206 model: &str,
207 dimensions: usize,
208 query_vector: &[f32],
209 limit: usize,
210 ) -> Result<Vec<SearchResult>, IndexError> {
211 let index = self.get_or_create_index(repo, model, dimensions)?;
212 index.search(query_vector, limit)
213 }
214
215 fn index_storage_path(&self, key: &IndexKey) -> PathBuf {
217 self.storage_dir
218 .join(key.storage_hash())
219 .join(&key.model_id)
220 .join(format!("dim_{}", key.dimensions))
221 .join("index.bincode")
222 }
223
224 pub async fn save_all(&self) -> Result<(), IndexError> {
226 for entry in self.indexes.iter() {
227 let key = entry.key();
228 let index = entry.value();
229 let path = self.index_storage_path(key);
230
231 if let Some(parent) = path.parent() {
233 tokio::fs::create_dir_all(parent).await?;
234 }
235
236 index.save_to_disk(&path).await?;
237 }
238 Ok(())
239 }
240
241 pub fn stats(&self) -> IndexManagerStats {
243 let mut stats = IndexManagerStats::default();
244
245 for entry in self.indexes.iter() {
246 stats.total_indexes += 1;
247
248 let model = &entry.key().model_id;
250 *stats.indexes_by_model.entry(model.clone()).or_insert(0) += 1;
251
252 let dims = entry.key().dimensions;
254 *stats.indexes_by_dimensions.entry(dims).or_insert(0) += 1;
255 }
256
257 stats
258 }
259}
260
/// Summary of the indexes held by an `EmbeddingIndexManager`.
#[derive(Default, Debug)]
pub struct IndexManagerStats {
    /// Number of indexes currently cached.
    pub total_indexes: usize,
    /// Index count keyed by model id.
    pub indexes_by_model: std::collections::HashMap<String, usize>,
    /// Index count keyed by vector dimensionality.
    pub indexes_by_dimensions: std::collections::HashMap<usize, usize>,
}
268
/// Cosine similarity of `a` and `b`: the dot product divided by the product of
/// their Euclidean norms.
///
/// Returns `0.0` when that product of norms is zero. The dot product pairs
/// components positionally and stops at the shorter slice. Callers check
/// lengths before calling.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let denominator = norm(a) * norm(b);

    if denominator == 0.0 {
        return 0.0;
    }
    dot / denominator
}
281
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates directory `name` under `parent` and returns its path.
    fn make_repo(parent: &Path, name: &str) -> PathBuf {
        let repo = parent.join(name);
        std::fs::create_dir(&repo).unwrap();
        repo
    }

    #[test]
    fn test_index_key_separation() {
        let workspace = tempdir().unwrap();
        let first_repo = make_repo(workspace.path(), "repo1");
        let second_repo = make_repo(workspace.path(), "repo2");

        let baseline = IndexKey::new(&first_repo, "model1", 1536).unwrap();
        let variants = [
            // Different model, same repo and dimensions.
            IndexKey::new(&first_repo, "model2", 1536).unwrap(),
            // Different repo.
            IndexKey::new(&second_repo, "model1", 1536).unwrap(),
            // Different dimensionality.
            IndexKey::new(&first_repo, "model1", 768).unwrap(),
        ];

        for variant in &variants {
            assert_ne!(&baseline, variant);
        }
    }

    #[test]
    fn test_dimension_mismatch_protection() {
        let mut index = VectorIndex::new(1536);
        let metadata = ChunkMetadata {
            file_path: PathBuf::from("test.rs"),
            start_line: 1,
            end_line: 10,
            content_hash: 12345,
            chunk_type: "function".to_string(),
            symbols: vec!["test_fn".to_string()],
        };

        index
            .insert(vec![0.1; 1536], metadata.clone())
            .expect("a vector of matching length must be accepted");

        let rejected = index.insert(vec![0.1; 768], metadata);
        assert!(matches!(rejected, Err(IndexError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_index_manager_separation() {
        let storage = tempdir().unwrap();
        let manager = EmbeddingIndexManager::new(storage.path().to_path_buf());

        let repos = tempdir().unwrap();
        let first_repo = make_repo(repos.path(), "repo1");
        let second_repo = make_repo(repos.path(), "repo2");

        let openai_first = manager
            .get_or_create_index(&first_repo, "openai:text-embedding-3-small", 1536)
            .unwrap();
        let gemini_first = manager
            .get_or_create_index(&first_repo, "gemini:embedding-001", 768)
            .unwrap();
        let openai_second = manager
            .get_or_create_index(&second_repo, "openai:text-embedding-3-small", 1536)
            .unwrap();

        let pairs = [
            (&openai_first, &gemini_first),
            (&openai_first, &openai_second),
            (&gemini_first, &openai_second),
        ];
        for (left, right) in pairs {
            assert!(!Arc::ptr_eq(left, right));
        }

        assert_eq!(manager.stats().total_indexes, 3);
    }
}