1use anyhow::{anyhow, Result};
50use rayon::prelude::*;
51use scirs2_core::ndarray_ext::Array1;
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58pub enum DistanceMetric {
59 Cosine,
61 Euclidean,
63 DotProduct,
65 Manhattan,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchConfig {
72 pub metric: DistanceMetric,
74 pub use_approximate: bool,
76 pub hnsw_m: usize,
78 pub hnsw_ef_construction: usize,
80 pub hnsw_ef_search: usize,
82 pub parallel: bool,
84 pub normalize: bool,
86}
87
88impl Default for SearchConfig {
89 fn default() -> Self {
90 Self {
91 metric: DistanceMetric::Cosine,
92 use_approximate: false,
93 hnsw_m: 16,
94 hnsw_ef_construction: 200,
95 hnsw_ef_search: 50,
96 parallel: true,
97 normalize: true,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SearchResult {
105 pub entity_id: String,
107 pub score: f32,
109 pub distance: f32,
111 pub rank: usize,
113}
114
115pub struct VectorSearchIndex {
117 config: SearchConfig,
118 embeddings: HashMap<String, Array1<f32>>,
119 entity_ids: Vec<String>,
120 embedding_matrix: Option<Vec<Vec<f32>>>,
121 dimensions: usize,
122 is_built: bool,
123}
124
125impl VectorSearchIndex {
126 pub fn new(config: SearchConfig) -> Self {
128 info!(
129 "Initialized vector search index: metric={:?}, approximate={}",
130 config.metric, config.use_approximate
131 );
132
133 Self {
134 config,
135 embeddings: HashMap::new(),
136 entity_ids: Vec::new(),
137 embedding_matrix: None,
138 dimensions: 0,
139 is_built: false,
140 }
141 }
142
143 pub fn build(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
145 if embeddings.is_empty() {
146 return Err(anyhow!("Cannot build index from empty embeddings"));
147 }
148
149 info!(
150 "Building vector search index for {} entities",
151 embeddings.len()
152 );
153
154 self.embeddings = embeddings.clone();
156 self.entity_ids = embeddings.keys().cloned().collect();
157 self.dimensions = embeddings
158 .values()
159 .next()
160 .expect("embeddings should not be empty")
161 .len();
162
163 let mut matrix = Vec::new();
165 for entity_id in &self.entity_ids {
166 let mut emb = self.embeddings[entity_id].to_vec();
167
168 if self.config.normalize {
170 self.normalize_vector(&mut emb);
171 }
172
173 matrix.push(emb);
174 }
175 self.embedding_matrix = Some(matrix);
176
177 self.is_built = true;
178
179 info!("Vector search index built successfully");
180 Ok(())
181 }
182
183 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
185 if !self.is_built {
186 return Err(anyhow!("Index not built. Call build() first"));
187 }
188
189 if query.len() != self.dimensions {
190 return Err(anyhow!(
191 "Query dimension {} doesn't match index dimension {}",
192 query.len(),
193 self.dimensions
194 ));
195 }
196
197 let mut normalized_query = query.to_vec();
199 if self.config.normalize {
200 self.normalize_vector(&mut normalized_query);
201 }
202
203 debug!("Searching for {} nearest neighbors", k);
204
205 if self.config.use_approximate && self.embeddings.len() > 1000 {
206 self.approximate_search(&normalized_query, k)
207 } else {
208 self.exact_search(&normalized_query, k)
209 }
210 }
211
212 fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
214 let matrix = self
215 .embedding_matrix
216 .as_ref()
217 .expect("embedding matrix should be built before search");
218
219 let scores: Vec<(usize, f32)> = if self.config.parallel {
221 (0..self.entity_ids.len())
222 .into_par_iter()
223 .map(|i| {
224 let score = self.compute_similarity(query, &matrix[i]);
225 (i, score)
226 })
227 .collect()
228 } else {
229 (0..self.entity_ids.len())
230 .map(|i| {
231 let score = self.compute_similarity(query, &matrix[i]);
232 (i, score)
233 })
234 .collect()
235 };
236
237 let mut sorted_scores = scores;
239 sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
240
241 let results: Vec<SearchResult> = sorted_scores
243 .iter()
244 .take(k.min(self.entity_ids.len()))
245 .enumerate()
246 .map(|(rank, &(idx, score))| SearchResult {
247 entity_id: self.entity_ids[idx].clone(),
248 score,
249 distance: self.score_to_distance(score),
250 rank: rank + 1,
251 })
252 .collect();
253
254 debug!("Found {} results", results.len());
255 Ok(results)
256 }
257
258 fn approximate_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
260 debug!("Using exact search (HNSW not yet fully implemented)");
263 self.exact_search(query, k)
264 }
265
266 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
268 if !self.is_built {
269 return Err(anyhow!("Index not built. Call build() first"));
270 }
271
272 info!("Batch searching for {} queries", queries.len());
273
274 let results: Vec<Vec<SearchResult>> = if self.config.parallel {
275 queries
276 .par_iter()
277 .map(|query| self.search(query, k).unwrap_or_default())
278 .collect()
279 } else {
280 queries
281 .iter()
282 .map(|query| self.search(query, k).unwrap_or_default())
283 .collect()
284 };
285
286 Ok(results)
287 }
288
289 fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
291 match self.config.metric {
292 DistanceMetric::Cosine => {
293 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
295 }
296 DistanceMetric::Euclidean => {
297 let dist: f32 = a
299 .iter()
300 .zip(b.iter())
301 .map(|(x, y)| (x - y).powi(2))
302 .sum::<f32>()
303 .sqrt();
304 -dist
305 }
306 DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
307 DistanceMetric::Manhattan => {
308 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
310 -dist
311 }
312 }
313 }
314
315 fn score_to_distance(&self, score: f32) -> f32 {
317 match self.config.metric {
318 DistanceMetric::Cosine => 1.0 - score, DistanceMetric::Euclidean | DistanceMetric::Manhattan => -score, DistanceMetric::DotProduct => -score,
321 }
322 }
323
324 fn normalize_vector(&self, vec: &mut [f32]) {
326 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
327 if norm > 1e-10 {
328 for x in vec.iter_mut() {
329 *x /= norm;
330 }
331 }
332 }
333
334 pub fn get_stats(&self) -> IndexStats {
336 IndexStats {
337 num_entities: self.entity_ids.len(),
338 dimensions: self.dimensions,
339 is_built: self.is_built,
340 metric: self.config.metric,
341 use_approximate: self.config.use_approximate,
342 }
343 }
344
345 pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
347 if !self.is_built {
348 return Err(anyhow!("Index not built. Call build() first"));
349 }
350
351 let all_results = self.search(query, self.entity_ids.len())?;
352
353 Ok(all_results
354 .into_iter()
355 .filter(|r| r.distance <= radius)
356 .collect())
357 }
358}
359
360#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct IndexStats {
363 pub num_entities: usize,
365 pub dimensions: usize,
367 pub is_built: bool,
369 pub metric: DistanceMetric,
371 pub use_approximate: bool,
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Five 3-d fixtures: two near the x axis, one on each of y and z, and one on
    /// the x/y diagonal.
    fn create_test_embeddings() -> HashMap<String, Array1<f32>> {
        vec![
            ("entity1", array![1.0, 0.0, 0.0]),
            ("entity2", array![0.9, 0.1, 0.0]),
            ("entity3", array![0.0, 1.0, 0.0]),
            ("entity4", array![0.0, 0.0, 1.0]),
            ("entity5", array![0.7, 0.7, 0.0]),
        ]
        .into_iter()
        .map(|(id, emb)| (id.to_string(), emb))
        .collect()
    }

    /// Builds an index over the standard fixtures with `config`.
    fn built_index(config: SearchConfig) -> VectorSearchIndex {
        let mut index = VectorSearchIndex::new(config);
        index.build(&create_test_embeddings()).unwrap();
        index
    }

    #[test]
    fn test_index_creation() {
        let index = VectorSearchIndex::new(SearchConfig::default());

        assert!(!index.is_built);
        assert_eq!(index.dimensions, 0);
    }

    #[test]
    fn test_index_building() {
        let mut index = VectorSearchIndex::new(SearchConfig::default());
        let outcome = index.build(&create_test_embeddings());

        assert!(outcome.is_ok());
        assert!(index.is_built);
        assert_eq!(index.dimensions, 3);
        assert_eq!(index.entity_ids.len(), 5);
    }

    #[test]
    fn test_exact_search() {
        let index = built_index(SearchConfig::default());
        let hits = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

        assert_eq!(hits.len(), 3);
        // The query coincides with entity1, so it must rank first.
        assert_eq!(hits[0].entity_id, "entity1");
        assert!(hits[0].score > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = built_index(SearchConfig {
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });
        let hits = index.search(&[1.0, 1.0, 0.0], 2).unwrap();

        assert_eq!(hits.len(), 2);
        // entity5 points along the same diagonal as the query.
        assert_eq!(hits[0].entity_id, "entity5");
    }

    #[test]
    fn test_batch_search() {
        let index = built_index(SearchConfig::default());
        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let batches = index.batch_search(&queries, 2).unwrap();

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 2);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_radius_search() {
        let index = built_index(SearchConfig::default());
        let hits = index.radius_search(&[1.0, 0.0, 0.0], 0.3).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.iter().all(|hit| hit.distance <= 0.3));
    }

    #[test]
    fn test_different_metrics() {
        let metrics = [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for &metric in &metrics {
            let index = built_index(SearchConfig {
                metric,
                ..Default::default()
            });
            let hits = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

            assert_eq!(hits.len(), 3);
        }
    }

    #[test]
    fn test_index_stats() {
        let stats = built_index(SearchConfig::default()).get_stats();

        assert_eq!(stats.num_entities, 5);
        assert_eq!(stats.dimensions, 3);
        assert!(stats.is_built);
    }
}