1use anyhow::{anyhow, Result};
50use rayon::prelude::*;
51use scirs2_core::ndarray_ext::Array1;
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58pub enum DistanceMetric {
59 Cosine,
61 Euclidean,
63 DotProduct,
65 Manhattan,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchConfig {
72 pub metric: DistanceMetric,
74 pub use_approximate: bool,
76 pub hnsw_m: usize,
78 pub hnsw_ef_construction: usize,
80 pub hnsw_ef_search: usize,
82 pub parallel: bool,
84 pub normalize: bool,
86}
87
88impl Default for SearchConfig {
89 fn default() -> Self {
90 Self {
91 metric: DistanceMetric::Cosine,
92 use_approximate: false,
93 hnsw_m: 16,
94 hnsw_ef_construction: 200,
95 hnsw_ef_search: 50,
96 parallel: true,
97 normalize: true,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SearchResult {
105 pub entity_id: String,
107 pub score: f32,
109 pub distance: f32,
111 pub rank: usize,
113}
114
115pub struct VectorSearchIndex {
117 config: SearchConfig,
118 embeddings: HashMap<String, Array1<f32>>,
119 entity_ids: Vec<String>,
120 embedding_matrix: Option<Vec<Vec<f32>>>,
121 dimensions: usize,
122 is_built: bool,
123}
124
125impl VectorSearchIndex {
126 pub fn new(config: SearchConfig) -> Self {
128 info!(
129 "Initialized vector search index: metric={:?}, approximate={}",
130 config.metric, config.use_approximate
131 );
132
133 Self {
134 config,
135 embeddings: HashMap::new(),
136 entity_ids: Vec::new(),
137 embedding_matrix: None,
138 dimensions: 0,
139 is_built: false,
140 }
141 }
142
143 pub fn build(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
145 if embeddings.is_empty() {
146 return Err(anyhow!("Cannot build index from empty embeddings"));
147 }
148
149 info!(
150 "Building vector search index for {} entities",
151 embeddings.len()
152 );
153
154 self.embeddings = embeddings.clone();
156 self.entity_ids = embeddings.keys().cloned().collect();
157 self.dimensions = embeddings.values().next().unwrap().len();
158
159 let mut matrix = Vec::new();
161 for entity_id in &self.entity_ids {
162 let mut emb = self.embeddings[entity_id].to_vec();
163
164 if self.config.normalize {
166 self.normalize_vector(&mut emb);
167 }
168
169 matrix.push(emb);
170 }
171 self.embedding_matrix = Some(matrix);
172
173 self.is_built = true;
174
175 info!("Vector search index built successfully");
176 Ok(())
177 }
178
179 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
181 if !self.is_built {
182 return Err(anyhow!("Index not built. Call build() first"));
183 }
184
185 if query.len() != self.dimensions {
186 return Err(anyhow!(
187 "Query dimension {} doesn't match index dimension {}",
188 query.len(),
189 self.dimensions
190 ));
191 }
192
193 let mut normalized_query = query.to_vec();
195 if self.config.normalize {
196 self.normalize_vector(&mut normalized_query);
197 }
198
199 debug!("Searching for {} nearest neighbors", k);
200
201 if self.config.use_approximate && self.embeddings.len() > 1000 {
202 self.approximate_search(&normalized_query, k)
203 } else {
204 self.exact_search(&normalized_query, k)
205 }
206 }
207
208 fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
210 let matrix = self.embedding_matrix.as_ref().unwrap();
211
212 let scores: Vec<(usize, f32)> = if self.config.parallel {
214 (0..self.entity_ids.len())
215 .into_par_iter()
216 .map(|i| {
217 let score = self.compute_similarity(query, &matrix[i]);
218 (i, score)
219 })
220 .collect()
221 } else {
222 (0..self.entity_ids.len())
223 .map(|i| {
224 let score = self.compute_similarity(query, &matrix[i]);
225 (i, score)
226 })
227 .collect()
228 };
229
230 let mut sorted_scores = scores;
232 sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
233
234 let results: Vec<SearchResult> = sorted_scores
236 .iter()
237 .take(k.min(self.entity_ids.len()))
238 .enumerate()
239 .map(|(rank, &(idx, score))| SearchResult {
240 entity_id: self.entity_ids[idx].clone(),
241 score,
242 distance: self.score_to_distance(score),
243 rank: rank + 1,
244 })
245 .collect();
246
247 debug!("Found {} results", results.len());
248 Ok(results)
249 }
250
251 fn approximate_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
253 debug!("Using exact search (HNSW not yet fully implemented)");
256 self.exact_search(query, k)
257 }
258
259 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
261 if !self.is_built {
262 return Err(anyhow!("Index not built. Call build() first"));
263 }
264
265 info!("Batch searching for {} queries", queries.len());
266
267 let results: Vec<Vec<SearchResult>> = if self.config.parallel {
268 queries
269 .par_iter()
270 .map(|query| self.search(query, k).unwrap_or_default())
271 .collect()
272 } else {
273 queries
274 .iter()
275 .map(|query| self.search(query, k).unwrap_or_default())
276 .collect()
277 };
278
279 Ok(results)
280 }
281
282 fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
284 match self.config.metric {
285 DistanceMetric::Cosine => {
286 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
288 }
289 DistanceMetric::Euclidean => {
290 let dist: f32 = a
292 .iter()
293 .zip(b.iter())
294 .map(|(x, y)| (x - y).powi(2))
295 .sum::<f32>()
296 .sqrt();
297 -dist
298 }
299 DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
300 DistanceMetric::Manhattan => {
301 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
303 -dist
304 }
305 }
306 }
307
308 fn score_to_distance(&self, score: f32) -> f32 {
310 match self.config.metric {
311 DistanceMetric::Cosine => 1.0 - score, DistanceMetric::Euclidean | DistanceMetric::Manhattan => -score, DistanceMetric::DotProduct => -score,
314 }
315 }
316
317 fn normalize_vector(&self, vec: &mut [f32]) {
319 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
320 if norm > 1e-10 {
321 for x in vec.iter_mut() {
322 *x /= norm;
323 }
324 }
325 }
326
327 pub fn get_stats(&self) -> IndexStats {
329 IndexStats {
330 num_entities: self.entity_ids.len(),
331 dimensions: self.dimensions,
332 is_built: self.is_built,
333 metric: self.config.metric,
334 use_approximate: self.config.use_approximate,
335 }
336 }
337
338 pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
340 if !self.is_built {
341 return Err(anyhow!("Index not built. Call build() first"));
342 }
343
344 let all_results = self.search(query, self.entity_ids.len())?;
345
346 Ok(all_results
347 .into_iter()
348 .filter(|r| r.distance <= radius)
349 .collect())
350 }
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct IndexStats {
356 pub num_entities: usize,
358 pub dimensions: usize,
360 pub is_built: bool,
362 pub metric: DistanceMetric,
364 pub use_approximate: bool,
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Five 3-D fixtures: three axis-aligned vectors, one near the x axis,
    /// and one on the x/y diagonal.
    fn create_test_embeddings() -> HashMap<String, Array1<f32>> {
        vec![
            ("entity1", array![1.0, 0.0, 0.0]),
            ("entity2", array![0.9, 0.1, 0.0]),
            ("entity3", array![0.0, 1.0, 0.0]),
            ("entity4", array![0.0, 0.0, 1.0]),
            ("entity5", array![0.7, 0.7, 0.0]),
        ]
        .into_iter()
        .map(|(id, emb)| (id.to_string(), emb))
        .collect()
    }

    /// Builds an index over the shared fixture with the given configuration.
    fn built_index(config: SearchConfig) -> VectorSearchIndex {
        let mut index = VectorSearchIndex::new(config);
        index
            .build(&create_test_embeddings())
            .expect("fixture embeddings should build");
        index
    }

    #[test]
    fn test_index_creation() {
        let fresh = VectorSearchIndex::new(SearchConfig::default());

        assert!(!fresh.is_built);
        assert_eq!(fresh.dimensions, 0);
    }

    #[test]
    fn test_index_building() {
        let mut index = VectorSearchIndex::new(SearchConfig::default());

        assert!(index.build(&create_test_embeddings()).is_ok());
        assert!(index.is_built);
        assert_eq!(index.dimensions, 3);
        assert_eq!(index.entity_ids.len(), 5);
    }

    #[test]
    fn test_exact_search() {
        let index = built_index(SearchConfig::default());

        let hits = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].entity_id, "entity1");
        assert!(hits[0].score > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = built_index(SearchConfig {
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });

        let hits = index.search(&[1.0, 1.0, 0.0], 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, "entity5");
    }

    #[test]
    fn test_batch_search() {
        let index = built_index(SearchConfig::default());
        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let batches = index.batch_search(&queries, 2).unwrap();

        assert_eq!(batches.len(), 2);
        assert!(batches.iter().all(|hits| hits.len() == 2));
    }

    #[test]
    fn test_radius_search() {
        let index = built_index(SearchConfig::default());

        let hits = index.radius_search(&[1.0, 0.0, 0.0], 0.3).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.iter().all(|hit| hit.distance <= 0.3));
    }

    #[test]
    fn test_different_metrics() {
        let metrics = [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for metric in metrics {
            let index = built_index(SearchConfig {
                metric,
                ..Default::default()
            });

            let hits = index.search(&[1.0, 0.0, 0.0], 3).unwrap();
            assert_eq!(hits.len(), 3);
        }
    }

    #[test]
    fn test_index_stats() {
        let stats = built_index(SearchConfig::default()).get_stats();

        assert_eq!(stats.num_entities, 5);
        assert_eq!(stats.dimensions, 3);
        assert!(stats.is_built);
    }
}