Skip to main content

oxirs_vec/gpu_search_enhanced/
mod.rs

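//! GPU-enhanced vector search primitives for OxiRS Vector Search.
//!
//! Note: despite the module name, nothing in this module calls a GPU API;
//! all work runs on the CPU, combining SIMD dot products with thread-level
//! parallelism.
//!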
3//! This module provides three high-level building blocks for
4//! performance-critical vector search workloads:
5//!
6//! - [`SimdVectorSearch`] – SIMD-accelerated flat (brute-force) search using
7//!   parallel dot products via `scirs2_core::simd`.
8//! - [`BatchSearchEngine`] – Concurrent batch search across multiple query
9//!   vectors using `scirs2_core::parallel_ops`.
10//! - [`SearchMetrics`] – Lightweight instrumentation for measuring throughput
11//!   and latency percentiles.
12
13use anyhow::{anyhow, Result};
14use scirs2_core::ndarray_ext::Array1;
15use scirs2_core::parallel_ops::{IntoParallelRefIterator, ParallelIterator};
16use scirs2_core::simd::simd_dot_f32;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::Arc;
19use std::time::Instant;
20
21// ---------------------------------------------------------------------------
22// Internal helper
23// ---------------------------------------------------------------------------
24
25/// Compute the cosine distance between two `f32` slices using SIMD dot
26/// products from `scirs2_core`.  Returns `f32::INFINITY` when either vector
27/// has zero magnitude.
28fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
29    if a.len() != b.len() || a.is_empty() {
30        return f32::INFINITY;
31    }
32    let a_arr = Array1::from_vec(a.to_vec());
33    let b_arr = Array1::from_vec(b.to_vec());
34
35    let dot = simd_dot_f32(&a_arr.view(), &b_arr.view());
36    let norm_a = simd_dot_f32(&a_arr.view(), &a_arr.view()).sqrt();
37    let norm_b = simd_dot_f32(&b_arr.view(), &b_arr.view()).sqrt();
38
39    if norm_a == 0.0 || norm_b == 0.0 {
40        f32::INFINITY
41    } else {
42        let sim = dot / (norm_a * norm_b);
43        1.0 - sim.clamp(-1.0, 1.0)
44    }
45}
46
47// ---------------------------------------------------------------------------
48// SimdVectorSearch
49// ---------------------------------------------------------------------------
50
/// A stored entry in the [`SimdVectorSearch`] flat index.
#[derive(Debug, Clone)]
struct IndexEntry {
    /// Caller-supplied identifier; `SimdVectorSearch::insert` replaces the
    /// vector of an existing entry instead of adding a duplicate id.
    id: String,
    /// The stored vector; `SimdVectorSearch::insert` rejects empty vectors.
    data: Vec<f32>,
}
57
/// SIMD-accelerated flat (brute-force) vector search.
///
/// Stores all vectors in memory, in insertion order, and computes cosine
/// distances to a query using `scirs2_core`'s SIMD dot-product primitive.
/// Distances are computed in parallel once the number of stored vectors
/// reaches the configured threshold.
#[derive(Debug)]
pub struct SimdVectorSearch {
    /// Stored vectors in insertion order; re-inserting an id updates it in place.
    entries: Vec<IndexEntry>,
    /// Use parallel computation when the number of stored vectors is at
    /// least this value (`0` means every search is parallel).
    parallel_threshold: usize,
}
69
70impl SimdVectorSearch {
71    /// Create a new empty flat index with the given parallel threshold.
72    pub fn new(parallel_threshold: usize) -> Self {
73        Self {
74            entries: Vec::new(),
75            parallel_threshold,
76        }
77    }
78
79    /// Create a flat index with the default threshold (256 vectors).
80    pub fn default_threshold() -> Self {
81        Self::new(256)
82    }
83
84    /// Insert a vector under `id`.  If `id` already exists, the vector is
85    /// replaced.
86    pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
87        if vector.is_empty() {
88            return Err(anyhow!("vector must not be empty"));
89        }
90        if let Some(entry) = self.entries.iter_mut().find(|e| e.id == id) {
91            entry.data = vector;
92        } else {
93            self.entries.push(IndexEntry { id, data: vector });
94        }
95        Ok(())
96    }
97
98    /// Return the number of vectors in the index.
99    pub fn len(&self) -> usize {
100        self.entries.len()
101    }
102
103    /// Check whether the index is empty.
104    pub fn is_empty(&self) -> bool {
105        self.entries.is_empty()
106    }
107
108    /// Search for the `k` nearest neighbours of `query` by cosine distance.
109    ///
110    /// Returns a `Vec` of `(id, distance)` pairs sorted by ascending distance,
111    /// length ≤ `k`.
112    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
113        if query.is_empty() {
114            return Err(anyhow!("query vector must not be empty"));
115        }
116        if self.entries.is_empty() {
117            return Ok(Vec::new());
118        }
119
120        let mut scored: Vec<(usize, f32)> = if self.entries.len() >= self.parallel_threshold {
121            // Parallel path
122            let indexed: Vec<(usize, &IndexEntry)> = self.entries.iter().enumerate().collect();
123            let mut v: Vec<(usize, f32)> = indexed
124                .par_iter()
125                .map(|&(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
126                .collect();
127            v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
128            v
129        } else {
130            // Sequential path for small sets
131            let mut v: Vec<(usize, f32)> = self
132                .entries
133                .iter()
134                .enumerate()
135                .map(|(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
136                .collect();
137            v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
138            v
139        };
140
141        scored.truncate(k);
142        let results = scored
143            .into_iter()
144            .map(|(idx, dist)| (self.entries[idx].id.clone(), dist))
145            .collect();
146        Ok(results)
147    }
148
149    /// Compute raw cosine distances from `query` to all indexed vectors and
150    /// return them in insertion order (not sorted).
151    pub fn all_distances(&self, query: &[f32]) -> Result<Vec<(String, f32)>> {
152        if query.is_empty() {
153            return Err(anyhow!("query vector must not be empty"));
154        }
155        let results = self
156            .entries
157            .iter()
158            .map(|e| (e.id.clone(), cosine_distance_simd(query, &e.data)))
159            .collect();
160        Ok(results)
161    }
162}
163
164impl Default for SimdVectorSearch {
165    fn default() -> Self {
166        Self::default_threshold()
167    }
168}
169
170// ---------------------------------------------------------------------------
171// BatchSearchEngine
172// ---------------------------------------------------------------------------
173
/// Concurrent batch search over multiple query vectors.
///
/// Wraps a [`SimdVectorSearch`] index and executes multiple queries in
/// parallel using `scirs2_core::parallel_ops`.
#[derive(Debug)]
pub struct BatchSearchEngine {
    /// The wrapped index. No `BatchSearchEngine` method takes `&mut self`,
    /// so the index is only ever read after construction.
    index: Arc<SimdVectorSearch>,
}
182
183impl BatchSearchEngine {
184    /// Wrap an existing [`SimdVectorSearch`] index.
185    pub fn new(index: SimdVectorSearch) -> Self {
186        Self {
187            index: Arc::new(index),
188        }
189    }
190
191    /// Execute `queries` in parallel, each returning the `k` nearest
192    /// neighbours.  The outer `Vec` preserves query ordering.
193    pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(String, f32)>>> {
194        if queries.is_empty() {
195            return Ok(Vec::new());
196        }
197
198        let index = Arc::clone(&self.index);
199
200        let results: Vec<Vec<(String, f32)>> = queries
201            .par_iter()
202            .map(|q| index.search(q, k).unwrap_or_default())
203            .collect();
204
205        Ok(results)
206    }
207
208    /// Search for a single query and record the latency into `metrics`.
209    pub fn timed_search(
210        &self,
211        query: &[f32],
212        k: usize,
213        metrics: &SearchMetrics,
214    ) -> Result<Vec<(String, f32)>> {
215        let start = Instant::now();
216        let result = self.index.search(query, k)?;
217        let elapsed_us = start.elapsed().as_micros() as u64;
218        metrics.record_query(elapsed_us);
219        Ok(result)
220    }
221
222    /// Return the number of vectors in the underlying index.
223    pub fn index_size(&self) -> usize {
224        self.index.len()
225    }
226}
227
228// ---------------------------------------------------------------------------
229// SearchMetrics
230// ---------------------------------------------------------------------------
231
/// Lightweight performance metrics for vector search operations.
///
/// Counters, the latency sum and min/max are `AtomicU64`s; the percentile
/// sample is guarded by a `parking_lot::Mutex`, so recording a query is not
/// entirely lock-free.
///
/// Tracks:
/// - total number of queries executed
/// - cumulative latency in microseconds (for mean computation)
/// - minimum and maximum observed latency
/// - approximate percentiles such as p50, p90, p99 (computed from a sorted
///   snapshot of a bounded latency sample)
#[derive(Debug)]
pub struct SearchMetrics {
    total_queries: AtomicU64,
    total_latency_us: AtomicU64,
    // Holds u64::MAX until the first query is recorded.
    min_latency_us: AtomicU64,
    max_latency_us: AtomicU64,
    // Bounded sample of recorded latencies (at most `reservoir_cap` entries,
    // 4096 by default) used for percentile estimation.
    reservoir: parking_lot::Mutex<Vec<u64>>,
    reservoir_cap: usize,
}
249
250impl SearchMetrics {
251    const DEFAULT_RESERVOIR_CAP: usize = 4096;
252
253    /// Create a new metrics collector.
254    pub fn new() -> Self {
255        Self {
256            total_queries: AtomicU64::new(0),
257            total_latency_us: AtomicU64::new(0),
258            min_latency_us: AtomicU64::new(u64::MAX),
259            max_latency_us: AtomicU64::new(0),
260            reservoir: parking_lot::Mutex::new(Vec::with_capacity(Self::DEFAULT_RESERVOIR_CAP)),
261            reservoir_cap: Self::DEFAULT_RESERVOIR_CAP,
262        }
263    }
264
265    /// Record a single query with its observed latency in microseconds.
266    pub fn record_query(&self, latency_us: u64) {
267        self.total_queries.fetch_add(1, Ordering::Relaxed);
268        self.total_latency_us
269            .fetch_add(latency_us, Ordering::Relaxed);
270
271        // Update min
272        let mut current_min = self.min_latency_us.load(Ordering::Relaxed);
273        while latency_us < current_min {
274            match self.min_latency_us.compare_exchange_weak(
275                current_min,
276                latency_us,
277                Ordering::Relaxed,
278                Ordering::Relaxed,
279            ) {
280                Ok(_) => break,
281                Err(updated) => current_min = updated,
282            }
283        }
284
285        // Update max
286        let mut current_max = self.max_latency_us.load(Ordering::Relaxed);
287        while latency_us > current_max {
288            match self.max_latency_us.compare_exchange_weak(
289                current_max,
290                latency_us,
291                Ordering::Relaxed,
292                Ordering::Relaxed,
293            ) {
294                Ok(_) => break,
295                Err(updated) => current_max = updated,
296            }
297        }
298
299        // Push to reservoir (capped)
300        let mut res = self.reservoir.lock();
301        if res.len() < self.reservoir_cap {
302            res.push(latency_us);
303        }
304    }
305
306    /// Total number of queries recorded.
307    pub fn total_queries(&self) -> u64 {
308        self.total_queries.load(Ordering::Relaxed)
309    }
310
311    /// Mean latency in microseconds, or `None` if no queries have been
312    /// recorded yet.
313    pub fn mean_latency_us(&self) -> Option<f64> {
314        let n = self.total_queries();
315        if n == 0 {
316            return None;
317        }
318        Some(self.total_latency_us.load(Ordering::Relaxed) as f64 / n as f64)
319    }
320
321    /// Minimum observed latency in microseconds.
322    pub fn min_latency_us(&self) -> Option<u64> {
323        let v = self.min_latency_us.load(Ordering::Relaxed);
324        if v == u64::MAX {
325            None
326        } else {
327            Some(v)
328        }
329    }
330
331    /// Maximum observed latency in microseconds.
332    pub fn max_latency_us(&self) -> Option<u64> {
333        let v = self.max_latency_us.load(Ordering::Relaxed);
334        if v == 0 && self.total_queries() == 0 {
335            None
336        } else {
337            Some(v)
338        }
339    }
340
341    /// Approximate throughput in queries per second, computed from the mean
342    /// latency. Returns `None` if no queries recorded or mean is zero.
343    pub fn throughput_qps(&self) -> Option<f64> {
344        let mean = self.mean_latency_us()?;
345        if mean == 0.0 {
346            return None;
347        }
348        Some(1_000_000.0 / mean)
349    }
350
351    /// Compute the p-th percentile (0–100) latency from the current
352    /// reservoir sample.  Returns `None` if no samples are available.
353    pub fn percentile_us(&self, p: f64) -> Option<u64> {
354        let mut res = self.reservoir.lock();
355        if res.is_empty() {
356            return None;
357        }
358        res.sort_unstable();
359        let idx = ((p / 100.0) * (res.len() - 1) as f64).round() as usize;
360        Some(res[idx.min(res.len() - 1)])
361    }
362
363    /// Reset all counters.
364    pub fn reset(&self) {
365        self.total_queries.store(0, Ordering::Relaxed);
366        self.total_latency_us.store(0, Ordering::Relaxed);
367        self.min_latency_us.store(u64::MAX, Ordering::Relaxed);
368        self.max_latency_us.store(0, Ordering::Relaxed);
369        self.reservoir.lock().clear();
370    }
371}
372
373impl Default for SearchMetrics {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379// ---------------------------------------------------------------------------
380// Tests
381// ---------------------------------------------------------------------------
382
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// One-hot vector of length `dim` with `1.0` at position `hot_dim % dim`.
    fn unit_vec(dim: usize, hot_dim: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot_dim % dim] = 1.0;
        v
    }

    /// Deterministic pseudo-random vector with components in `[-1, 1]`,
    /// produced by a 64-bit LCG seeded with `seed`.
    ///
    /// The top 32 bits of each state are normalised by `u32::MAX`, so the
    /// intermediate value spans `[0, 1]`. Taking only 31 bits (`>> 33`) would
    /// confine every component to `[-1, 0)`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Index of `n` random `dim`-dimensional vectors named `v0..v{n-1}`.
    fn build_index(n: usize, dim: usize) -> SimdVectorSearch {
        let mut idx = SimdVectorSearch::new(16);
        for i in 0..n {
            let v = random_vec(dim, i as u64 + 1);
            idx.insert(format!("v{}", i), v).unwrap();
        }
        idx
    }

    #[test]
    fn test_random_vec_covers_both_signs() {
        let v = random_vec(64, 1);
        assert!(v.iter().all(|x| (-1.0..=1.0).contains(x)));
        assert!(v.iter().any(|&x| x > 0.0));
        assert!(v.iter().any(|&x| x < 0.0));
    }

    // ------------------------------------------------------------------
    // SimdVectorSearch tests
    // ------------------------------------------------------------------

    #[test]
    fn test_simd_search_basic_knn() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("a".into(), vec![1.0, 0.0, 0.0]).unwrap();
        idx.insert("b".into(), vec![0.0, 1.0, 0.0]).unwrap();
        idx.insert("c".into(), vec![0.9, 0.1, 0.0]).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        // "a" (identical) must be first
        assert_eq!(results[0].0, "a");
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_search_empty_index() {
        let idx = SimdVectorSearch::new(16);
        let results = idx.search(&[1.0, 0.0], 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_simd_search_single_entry() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("only".into(), vec![0.6, 0.8]).unwrap();
        let results = idx.search(&[0.6, 0.8], 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_search_k_larger_than_index() {
        let idx = build_index(5, 4);
        let query = random_vec(4, 999);
        let results = idx.search(&query, 100).unwrap();
        assert_eq!(results.len(), 5, "should return at most index size");
    }

    #[test]
    fn test_simd_search_results_sorted_ascending() {
        let idx = build_index(50, 8);
        let query = random_vec(8, 42);
        let results = idx.search(&query, 20).unwrap();
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1, "results not sorted: {:?}", w);
        }
    }

    #[test]
    fn test_simd_search_parallel_threshold_switch() {
        // Parallel threshold of 4; index has 8 entries → parallel path
        let mut idx = SimdVectorSearch::new(4);
        for i in 0..8_usize {
            idx.insert(format!("p{}", i), unit_vec(4, i)).unwrap();
        }
        let query = unit_vec(4, 0);
        let results = idx.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "p0");
    }

    #[test]
    fn test_simd_search_update_existing_id() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("x".into(), vec![1.0, 0.0]).unwrap();
        idx.insert("x".into(), vec![0.0, 1.0]).unwrap(); // update

        assert_eq!(idx.len(), 1);
        let results = idx.search(&[0.0, 1.0], 1).unwrap();
        assert_eq!(results[0].0, "x");
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_all_distances_length() {
        let idx = build_index(10, 4);
        let query = random_vec(4, 7);
        let all = idx.all_distances(&query).unwrap();
        assert_eq!(all.len(), 10);
    }

    #[test]
    fn test_simd_orthogonal_max_distance() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("y".into(), vec![0.0, 1.0]).unwrap();

        let query = vec![1.0, 0.0];
        let results = idx.search(&query, 1).unwrap();
        assert!((results[0].1 - 1.0).abs() < 1e-4);
    }

    // ------------------------------------------------------------------
    // BatchSearchEngine tests
    // ------------------------------------------------------------------

    #[test]
    fn test_batch_search_basic() {
        let engine = BatchSearchEngine::new(build_index(20, 4));
        let queries: Vec<Vec<f32>> = (0..5).map(|i| random_vec(4, i as u64)).collect();
        let results = engine.batch_search(&queries, 3).unwrap();
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.len() <= 3);
        }
    }

    #[test]
    fn test_batch_search_empty_queries() {
        let engine = BatchSearchEngine::new(build_index(10, 4));
        let results = engine.batch_search(&[], 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_batch_search_order_preserved() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("origin".into(), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
        idx.insert("x_axis".into(), vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        idx.insert("y_axis".into(), vec![0.0, 1.0, 0.0, 0.0]).unwrap();

        let engine = BatchSearchEngine::new(idx);
        let queries = vec![
            vec![1.0_f32, 0.0, 0.0, 0.0], // closest to x_axis
            vec![0.0_f32, 1.0, 0.0, 0.0], // closest to y_axis
        ];
        let results = engine.batch_search(&queries, 1).unwrap();
        assert_eq!(results[0][0].0, "x_axis");
        assert_eq!(results[1][0].0, "y_axis");
    }

    #[test]
    fn test_batch_search_large_concurrent() {
        let engine = BatchSearchEngine::new(build_index(200, 16));
        let queries: Vec<Vec<f32>> = (0..50).map(|i| random_vec(16, i as u64 + 100)).collect();
        let results = engine.batch_search(&queries, 5).unwrap();
        assert_eq!(results.len(), 50);
        for r in &results {
            assert!(!r.is_empty());
        }
    }

    #[test]
    fn test_batch_timed_search_records_metrics() {
        let engine = BatchSearchEngine::new(build_index(30, 8));
        let metrics = SearchMetrics::new();
        let query = random_vec(8, 77);

        let results = engine.timed_search(&query, 3, &metrics).unwrap();
        assert!(!results.is_empty());
        assert_eq!(metrics.total_queries(), 1);
        assert!(metrics.mean_latency_us().is_some());
    }

    // ------------------------------------------------------------------
    // SearchMetrics tests
    // ------------------------------------------------------------------

    #[test]
    fn test_metrics_basic_recording() {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.record_query(200);
        m.record_query(300);

        assert_eq!(m.total_queries(), 3);
        let mean = m.mean_latency_us().unwrap();
        assert!((mean - 200.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_min_max() {
        let m = SearchMetrics::new();
        m.record_query(50);
        m.record_query(150);
        m.record_query(300);

        assert_eq!(m.min_latency_us().unwrap(), 50);
        assert_eq!(m.max_latency_us().unwrap(), 300);
    }

    #[test]
    fn test_metrics_percentile_p50() {
        let m = SearchMetrics::new();
        for lat in [10_u64, 20, 30, 40, 50] {
            m.record_query(lat);
        }
        // p50 of [10,20,30,40,50] is index 2 → 30
        let p50 = m.percentile_us(50.0).unwrap();
        assert_eq!(p50, 30);
    }

    #[test]
    fn test_metrics_reset() {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.reset();
        assert_eq!(m.total_queries(), 0);
        assert!(m.mean_latency_us().is_none());
        assert!(m.min_latency_us().is_none());
        assert!(m.max_latency_us().is_none());
        assert!(m.percentile_us(50.0).is_none());
    }

    #[test]
    fn test_metrics_throughput_qps() {
        let m = SearchMetrics::new();
        m.record_query(1_000); // 1 ms = 1 000 µs → 1 000 QPS
        let qps = m.throughput_qps().unwrap();
        assert!((qps - 1_000.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_no_queries_returns_none() {
        let m = SearchMetrics::new();
        assert!(m.mean_latency_us().is_none());
        assert!(m.min_latency_us().is_none());
        assert!(m.max_latency_us().is_none());
        assert!(m.throughput_qps().is_none());
        assert!(m.percentile_us(50.0).is_none());
    }
}