1use anyhow::{anyhow, Result};
14use scirs2_core::ndarray_ext::Array1;
15use scirs2_core::parallel_ops::{IntoParallelRefIterator, ParallelIterator};
16use scirs2_core::simd::simd_dot_f32;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::Arc;
19use std::time::Instant;
20
21fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
29 if a.len() != b.len() || a.is_empty() {
30 return f32::INFINITY;
31 }
32 let a_arr = Array1::from_vec(a.to_vec());
33 let b_arr = Array1::from_vec(b.to_vec());
34
35 let dot = simd_dot_f32(&a_arr.view(), &b_arr.view());
36 let norm_a = simd_dot_f32(&a_arr.view(), &a_arr.view()).sqrt();
37 let norm_b = simd_dot_f32(&b_arr.view(), &b_arr.view()).sqrt();
38
39 if norm_a == 0.0 || norm_b == 0.0 {
40 f32::INFINITY
41 } else {
42 let sim = dot / (norm_a * norm_b);
43 1.0 - sim.clamp(-1.0, 1.0)
44 }
45}
46
/// A single stored vector together with its caller-supplied identifier.
#[derive(Debug, Clone)]
struct IndexEntry {
    /// Identifier matched on insert and reported in search results.
    id: String,
    /// Vector components compared against queries.
    data: Vec<f32>,
}
57
/// Brute-force cosine-distance index over in-memory `f32` vectors.
#[derive(Debug)]
pub struct SimdVectorSearch {
    /// Stored vectors; new ids are appended, existing ids are updated in place.
    entries: Vec<IndexEntry>,
    /// Entry count at or above which `search` scores entries in parallel.
    parallel_threshold: usize,
}
69
70impl SimdVectorSearch {
71 pub fn new(parallel_threshold: usize) -> Self {
73 Self {
74 entries: Vec::new(),
75 parallel_threshold,
76 }
77 }
78
79 pub fn default_threshold() -> Self {
81 Self::new(256)
82 }
83
84 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
87 if vector.is_empty() {
88 return Err(anyhow!("vector must not be empty"));
89 }
90 if let Some(entry) = self.entries.iter_mut().find(|e| e.id == id) {
91 entry.data = vector;
92 } else {
93 self.entries.push(IndexEntry { id, data: vector });
94 }
95 Ok(())
96 }
97
98 pub fn len(&self) -> usize {
100 self.entries.len()
101 }
102
103 pub fn is_empty(&self) -> bool {
105 self.entries.is_empty()
106 }
107
108 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
113 if query.is_empty() {
114 return Err(anyhow!("query vector must not be empty"));
115 }
116 if self.entries.is_empty() {
117 return Ok(Vec::new());
118 }
119
120 let mut scored: Vec<(usize, f32)> = if self.entries.len() >= self.parallel_threshold {
121 let indexed: Vec<(usize, &IndexEntry)> = self.entries.iter().enumerate().collect();
123 let mut v: Vec<(usize, f32)> = indexed
124 .par_iter()
125 .map(|&(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
126 .collect();
127 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
128 v
129 } else {
130 let mut v: Vec<(usize, f32)> = self
132 .entries
133 .iter()
134 .enumerate()
135 .map(|(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
136 .collect();
137 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
138 v
139 };
140
141 scored.truncate(k);
142 let results = scored
143 .into_iter()
144 .map(|(idx, dist)| (self.entries[idx].id.clone(), dist))
145 .collect();
146 Ok(results)
147 }
148
149 pub fn all_distances(&self, query: &[f32]) -> Result<Vec<(String, f32)>> {
152 if query.is_empty() {
153 return Err(anyhow!("query vector must not be empty"));
154 }
155 let results = self
156 .entries
157 .iter()
158 .map(|e| (e.id.clone(), cosine_distance_simd(query, &e.data)))
159 .collect();
160 Ok(results)
161 }
162}
163
impl Default for SimdVectorSearch {
    /// Equivalent to [`SimdVectorSearch::default_threshold`].
    fn default() -> Self {
        Self::default_threshold()
    }
}
169
/// Runs single and batched queries against a shared, read-only index.
#[derive(Debug)]
pub struct BatchSearchEngine {
    /// The index being queried, held behind an `Arc`.
    index: Arc<SimdVectorSearch>,
}
182
183impl BatchSearchEngine {
184 pub fn new(index: SimdVectorSearch) -> Self {
186 Self {
187 index: Arc::new(index),
188 }
189 }
190
191 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(String, f32)>>> {
194 if queries.is_empty() {
195 return Ok(Vec::new());
196 }
197
198 let index = Arc::clone(&self.index);
199
200 let results: Vec<Vec<(String, f32)>> = queries
201 .par_iter()
202 .map(|q| index.search(q, k).unwrap_or_default())
203 .collect();
204
205 Ok(results)
206 }
207
208 pub fn timed_search(
210 &self,
211 query: &[f32],
212 k: usize,
213 metrics: &SearchMetrics,
214 ) -> Result<Vec<(String, f32)>> {
215 let start = Instant::now();
216 let result = self.index.search(query, k)?;
217 let elapsed_us = start.elapsed().as_micros() as u64;
218 metrics.record_query(elapsed_us);
219 Ok(result)
220 }
221
222 pub fn index_size(&self) -> usize {
224 self.index.len()
225 }
226}
227
/// Latency statistics for search queries, updatable through a shared reference.
#[derive(Debug)]
pub struct SearchMetrics {
    /// Number of recorded queries.
    total_queries: AtomicU64,
    /// Sum of all recorded latencies, in microseconds.
    total_latency_us: AtomicU64,
    /// Smallest recorded latency; `u64::MAX` until the first sample.
    min_latency_us: AtomicU64,
    /// Largest recorded latency; `0` until the first sample.
    max_latency_us: AtomicU64,
    /// Retained latency samples used for percentile queries.
    reservoir: parking_lot::Mutex<Vec<u64>>,
    /// Maximum number of samples kept in `reservoir`.
    reservoir_cap: usize,
}
249
250impl SearchMetrics {
251 const DEFAULT_RESERVOIR_CAP: usize = 4096;
252
253 pub fn new() -> Self {
255 Self {
256 total_queries: AtomicU64::new(0),
257 total_latency_us: AtomicU64::new(0),
258 min_latency_us: AtomicU64::new(u64::MAX),
259 max_latency_us: AtomicU64::new(0),
260 reservoir: parking_lot::Mutex::new(Vec::with_capacity(Self::DEFAULT_RESERVOIR_CAP)),
261 reservoir_cap: Self::DEFAULT_RESERVOIR_CAP,
262 }
263 }
264
265 pub fn record_query(&self, latency_us: u64) {
267 self.total_queries.fetch_add(1, Ordering::Relaxed);
268 self.total_latency_us
269 .fetch_add(latency_us, Ordering::Relaxed);
270
271 let mut current_min = self.min_latency_us.load(Ordering::Relaxed);
273 while latency_us < current_min {
274 match self.min_latency_us.compare_exchange_weak(
275 current_min,
276 latency_us,
277 Ordering::Relaxed,
278 Ordering::Relaxed,
279 ) {
280 Ok(_) => break,
281 Err(updated) => current_min = updated,
282 }
283 }
284
285 let mut current_max = self.max_latency_us.load(Ordering::Relaxed);
287 while latency_us > current_max {
288 match self.max_latency_us.compare_exchange_weak(
289 current_max,
290 latency_us,
291 Ordering::Relaxed,
292 Ordering::Relaxed,
293 ) {
294 Ok(_) => break,
295 Err(updated) => current_max = updated,
296 }
297 }
298
299 let mut res = self.reservoir.lock();
301 if res.len() < self.reservoir_cap {
302 res.push(latency_us);
303 }
304 }
305
306 pub fn total_queries(&self) -> u64 {
308 self.total_queries.load(Ordering::Relaxed)
309 }
310
311 pub fn mean_latency_us(&self) -> Option<f64> {
314 let n = self.total_queries();
315 if n == 0 {
316 return None;
317 }
318 Some(self.total_latency_us.load(Ordering::Relaxed) as f64 / n as f64)
319 }
320
321 pub fn min_latency_us(&self) -> Option<u64> {
323 let v = self.min_latency_us.load(Ordering::Relaxed);
324 if v == u64::MAX {
325 None
326 } else {
327 Some(v)
328 }
329 }
330
331 pub fn max_latency_us(&self) -> Option<u64> {
333 let v = self.max_latency_us.load(Ordering::Relaxed);
334 if v == 0 && self.total_queries() == 0 {
335 None
336 } else {
337 Some(v)
338 }
339 }
340
341 pub fn throughput_qps(&self) -> Option<f64> {
344 let mean = self.mean_latency_us()?;
345 if mean == 0.0 {
346 return None;
347 }
348 Some(1_000_000.0 / mean)
349 }
350
351 pub fn percentile_us(&self, p: f64) -> Option<u64> {
354 let mut res = self.reservoir.lock();
355 if res.is_empty() {
356 return None;
357 }
358 res.sort_unstable();
359 let idx = ((p / 100.0) * (res.len() - 1) as f64).round() as usize;
360 Some(res[idx.min(res.len() - 1)])
361 }
362
363 pub fn reset(&self) {
365 self.total_queries.store(0, Ordering::Relaxed);
366 self.total_latency_us.store(0, Ordering::Relaxed);
367 self.min_latency_us.store(u64::MAX, Ordering::Relaxed);
368 self.max_latency_us.store(0, Ordering::Relaxed);
369 self.reservoir.lock().clear();
370 }
371}
372
impl Default for SearchMetrics {
    /// Equivalent to [`SearchMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Builds a `dim`-dimensional vector with a single 1.0 at `hot_dim % dim`.
    fn unit_vec(dim: usize, hot_dim: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot_dim % dim] = 1.0;
        v
    }

    /// Deterministic pseudo-random vector with components in `[-1.0, 1.0]`.
    ///
    /// Uses a 64-bit LCG and keeps the upper 32 bits of each state, so that
    /// dividing by `u32::MAX` covers the whole unit interval before it is
    /// mapped to `[-1, 1]`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Builds an index of `n` pseudo-random `dim`-dimensional vectors named `v0..`.
    fn build_index(n: usize, dim: usize) -> SimdVectorSearch {
        let mut idx = SimdVectorSearch::new(16);
        for i in 0..n {
            let v = random_vec(dim, i as u64 + 1);
            idx.insert(format!("v{}", i), v)
                .expect("insert should succeed");
        }
        idx
    }

    // ---- test helpers ----

    #[test]
    fn test_random_vec_spans_both_signs() {
        let v = random_vec(64, 1);
        assert_eq!(v.len(), 64);
        assert!(v.iter().all(|x| (-1.0..=1.0).contains(x)));
        assert!(v.iter().any(|&x| x > 0.0), "expected a positive component");
        assert!(v.iter().any(|&x| x < 0.0), "expected a negative component");
    }

    // ---- SimdVectorSearch ----

    #[test]
    fn test_simd_search_basic_knn() -> Result<()> {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("a".into(), vec![1.0, 0.0, 0.0])?;
        idx.insert("b".into(), vec![0.0, 1.0, 0.0])?;
        idx.insert("c".into(), vec![0.9, 0.1, 0.0])?;

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.search(&query, 2)?;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "a");
        assert!(results[0].1 < 1e-5);
        Ok(())
    }

    #[test]
    fn test_simd_search_empty_index() -> Result<()> {
        let idx = SimdVectorSearch::new(16);
        let results = idx.search(&[1.0, 0.0], 5)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_simd_search_single_entry() -> Result<()> {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("only".into(), vec![0.6, 0.8])?;
        let results = idx.search(&[0.6, 0.8], 10)?;
        assert_eq!(results.len(), 1);
        assert!(results[0].1 < 1e-5);
        Ok(())
    }

    #[test]
    fn test_simd_search_k_larger_than_index() -> Result<()> {
        let idx = build_index(5, 4);
        let query = random_vec(4, 999);
        let results = idx.search(&query, 100)?;
        assert_eq!(results.len(), 5, "should return at most index size");
        Ok(())
    }

    #[test]
    fn test_simd_search_results_sorted_ascending() -> Result<()> {
        let idx = build_index(50, 8);
        let query = random_vec(8, 42);
        let results = idx.search(&query, 20)?;
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1, "results not sorted: {:?}", w);
        }
        Ok(())
    }

    #[test]
    fn test_simd_search_parallel_threshold_switch() -> Result<()> {
        // Eight entries with a threshold of four exercises the parallel path.
        let mut idx = SimdVectorSearch::new(4);
        for i in 0..8_usize {
            idx.insert(format!("p{}", i), unit_vec(4, i))?;
        }
        let query = unit_vec(4, 0);
        let results = idx.search(&query, 3)?;
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "p0");
        Ok(())
    }

    #[test]
    fn test_simd_search_update_existing_id() -> Result<()> {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("x".into(), vec![1.0, 0.0])?;
        // Re-inserting the same id must replace the vector, not add an entry.
        idx.insert("x".into(), vec![0.0, 1.0])?;
        assert_eq!(idx.len(), 1);
        let results = idx.search(&[0.0, 1.0], 1)?;
        assert_eq!(results[0].0, "x");
        assert!(results[0].1 < 1e-5);
        Ok(())
    }

    #[test]
    fn test_simd_all_distances_length() -> Result<()> {
        let idx = build_index(10, 4);
        let query = random_vec(4, 7);
        let all = idx.all_distances(&query)?;
        assert_eq!(all.len(), 10);
        Ok(())
    }

    #[test]
    fn test_simd_orthogonal_max_distance() -> Result<()> {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("y".into(), vec![0.0, 1.0])?;

        let query = vec![1.0, 0.0];
        let results = idx.search(&query, 1)?;
        assert!((results[0].1 - 1.0).abs() < 1e-4);
        Ok(())
    }

    // ---- BatchSearchEngine ----

    #[test]
    fn test_batch_search_basic() -> Result<()> {
        let engine = BatchSearchEngine::new(build_index(20, 4));
        let queries: Vec<Vec<f32>> = (0..5).map(|i| random_vec(4, i as u64)).collect();
        let results = engine.batch_search(&queries, 3)?;
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.len() <= 3);
        }
        Ok(())
    }

    #[test]
    fn test_batch_search_empty_queries() -> Result<()> {
        let engine = BatchSearchEngine::new(build_index(10, 4));
        let results = engine.batch_search(&[], 5)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_batch_search_order_preserved() -> Result<()> {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("origin".into(), vec![0.0, 0.0, 0.0, 0.0])?;
        idx.insert("x_axis".into(), vec![1.0, 0.0, 0.0, 0.0])?;
        idx.insert("y_axis".into(), vec![0.0, 1.0, 0.0, 0.0])?;

        let engine = BatchSearchEngine::new(idx);
        let queries = vec![
            vec![1.0_f32, 0.0, 0.0, 0.0],
            vec![0.0_f32, 1.0, 0.0, 0.0],
        ];
        let results = engine.batch_search(&queries, 1)?;
        assert_eq!(results[0][0].0, "x_axis");
        assert_eq!(results[1][0].0, "y_axis");
        Ok(())
    }

    #[test]
    fn test_batch_search_large_concurrent() -> Result<()> {
        let engine = BatchSearchEngine::new(build_index(200, 16));
        let queries: Vec<Vec<f32>> = (0..50).map(|i| random_vec(16, i as u64 + 100)).collect();
        let results = engine.batch_search(&queries, 5)?;
        assert_eq!(results.len(), 50);
        for r in &results {
            assert!(!r.is_empty());
        }
        Ok(())
    }

    #[test]
    fn test_batch_timed_search_records_metrics() -> Result<()> {
        let engine = BatchSearchEngine::new(build_index(30, 8));
        let metrics = SearchMetrics::new();
        let query = random_vec(8, 77);

        let results = engine.timed_search(&query, 3, &metrics)?;
        assert!(!results.is_empty());
        assert_eq!(metrics.total_queries(), 1);
        assert!(metrics.mean_latency_us().is_some());
        Ok(())
    }

    // ---- SearchMetrics ----

    #[test]
    fn test_metrics_basic_recording() -> Result<()> {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.record_query(200);
        m.record_query(300);

        assert_eq!(m.total_queries(), 3);
        let mean = m.mean_latency_us().expect("mean latency should be Some");
        assert!((mean - 200.0).abs() < 0.01);
        Ok(())
    }

    #[test]
    fn test_metrics_min_max() -> Result<()> {
        let m = SearchMetrics::new();
        m.record_query(50);
        m.record_query(150);
        m.record_query(300);

        let min = m.min_latency_us().expect("min latency should be Some");
        assert_eq!(min, 50);
        let max = m.max_latency_us().expect("max latency should be Some");
        assert_eq!(max, 300);
        Ok(())
    }

    #[test]
    fn test_metrics_percentile_p50() -> Result<()> {
        let m = SearchMetrics::new();
        for lat in [10_u64, 20, 30, 40, 50] {
            m.record_query(lat);
        }
        let p50 = m.percentile_us(50.0).expect("p50 should be Some");
        assert_eq!(p50, 30);
        Ok(())
    }

    #[test]
    fn test_metrics_reset() {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.reset();
        assert_eq!(m.total_queries(), 0);
        assert!(m.mean_latency_us().is_none());
    }

    #[test]
    fn test_metrics_throughput_qps() -> Result<()> {
        let m = SearchMetrics::new();
        // A 1 ms mean latency corresponds to 1000 queries per second.
        m.record_query(1_000);
        let qps = m.throughput_qps().expect("throughput_qps should be Some");
        assert!((qps - 1_000.0).abs() < 0.01);
        Ok(())
    }

    #[test]
    fn test_metrics_no_queries_returns_none() {
        let m = SearchMetrics::new();
        assert!(m.mean_latency_us().is_none());
        assert!(m.min_latency_us().is_none());
        assert!(m.throughput_qps().is_none());
        assert!(m.percentile_us(50.0).is_none());
    }
}