use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::parallel_ops::{IntoParallelRefIterator, ParallelIterator};
use scirs2_core::simd::simd_dot_f32;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
20
21fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
29 if a.len() != b.len() || a.is_empty() {
30 return f32::INFINITY;
31 }
32 let a_arr = Array1::from_vec(a.to_vec());
33 let b_arr = Array1::from_vec(b.to_vec());
34
35 let dot = simd_dot_f32(&a_arr.view(), &b_arr.view());
36 let norm_a = simd_dot_f32(&a_arr.view(), &a_arr.view()).sqrt();
37 let norm_b = simd_dot_f32(&b_arr.view(), &b_arr.view()).sqrt();
38
39 if norm_a == 0.0 || norm_b == 0.0 {
40 f32::INFINITY
41 } else {
42 let sim = dot / (norm_a * norm_b);
43 1.0 - sim.clamp(-1.0, 1.0)
44 }
45}
46
/// A stored vector together with its caller-supplied identifier.
#[derive(Debug, Clone)]
struct IndexEntry {
    /// Identifier passed to `SimdVectorSearch::insert`; re-inserting the same id replaces `data`.
    id: String,
    /// The stored vector; `insert` rejects empty vectors, so this is never empty.
    data: Vec<f32>,
}
57
/// Brute-force nearest-neighbour index ranking `f32` vectors by cosine distance.
///
/// Every query is scored against every stored entry. Scoring switches from a sequential scan
/// to a parallel one once the index holds at least `parallel_threshold` entries.
#[derive(Debug)]
pub struct SimdVectorSearch {
    /// Stored vectors in insertion order; an existing id is updated in place on re-insert.
    entries: Vec<IndexEntry>,
    /// Minimum number of entries at which `search` scores candidates in parallel.
    parallel_threshold: usize,
}
69
70impl SimdVectorSearch {
71 pub fn new(parallel_threshold: usize) -> Self {
73 Self {
74 entries: Vec::new(),
75 parallel_threshold,
76 }
77 }
78
79 pub fn default_threshold() -> Self {
81 Self::new(256)
82 }
83
84 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
87 if vector.is_empty() {
88 return Err(anyhow!("vector must not be empty"));
89 }
90 if let Some(entry) = self.entries.iter_mut().find(|e| e.id == id) {
91 entry.data = vector;
92 } else {
93 self.entries.push(IndexEntry { id, data: vector });
94 }
95 Ok(())
96 }
97
98 pub fn len(&self) -> usize {
100 self.entries.len()
101 }
102
103 pub fn is_empty(&self) -> bool {
105 self.entries.is_empty()
106 }
107
108 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
113 if query.is_empty() {
114 return Err(anyhow!("query vector must not be empty"));
115 }
116 if self.entries.is_empty() {
117 return Ok(Vec::new());
118 }
119
120 let mut scored: Vec<(usize, f32)> = if self.entries.len() >= self.parallel_threshold {
121 let indexed: Vec<(usize, &IndexEntry)> = self.entries.iter().enumerate().collect();
123 let mut v: Vec<(usize, f32)> = indexed
124 .par_iter()
125 .map(|&(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
126 .collect();
127 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
128 v
129 } else {
130 let mut v: Vec<(usize, f32)> = self
132 .entries
133 .iter()
134 .enumerate()
135 .map(|(idx, entry)| (idx, cosine_distance_simd(query, &entry.data)))
136 .collect();
137 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
138 v
139 };
140
141 scored.truncate(k);
142 let results = scored
143 .into_iter()
144 .map(|(idx, dist)| (self.entries[idx].id.clone(), dist))
145 .collect();
146 Ok(results)
147 }
148
149 pub fn all_distances(&self, query: &[f32]) -> Result<Vec<(String, f32)>> {
152 if query.is_empty() {
153 return Err(anyhow!("query vector must not be empty"));
154 }
155 let results = self
156 .entries
157 .iter()
158 .map(|e| (e.id.clone(), cosine_distance_simd(query, &e.data)))
159 .collect();
160 Ok(results)
161 }
162}
163
164impl Default for SimdVectorSearch {
165 fn default() -> Self {
166 Self::default_threshold()
167 }
168}
169
/// Runs searches against a shared [`SimdVectorSearch`] index.
///
/// The index is held behind an [`Arc`]; `batch_search` fans queries out across worker threads.
#[derive(Debug)]
pub struct BatchSearchEngine {
    /// The wrapped index, shared read-only.
    index: Arc<SimdVectorSearch>,
}
182
183impl BatchSearchEngine {
184 pub fn new(index: SimdVectorSearch) -> Self {
186 Self {
187 index: Arc::new(index),
188 }
189 }
190
191 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(String, f32)>>> {
194 if queries.is_empty() {
195 return Ok(Vec::new());
196 }
197
198 let index = Arc::clone(&self.index);
199
200 let results: Vec<Vec<(String, f32)>> = queries
201 .par_iter()
202 .map(|q| index.search(q, k).unwrap_or_default())
203 .collect();
204
205 Ok(results)
206 }
207
208 pub fn timed_search(
210 &self,
211 query: &[f32],
212 k: usize,
213 metrics: &SearchMetrics,
214 ) -> Result<Vec<(String, f32)>> {
215 let start = Instant::now();
216 let result = self.index.search(query, k)?;
217 let elapsed_us = start.elapsed().as_micros() as u64;
218 metrics.record_query(elapsed_us);
219 Ok(result)
220 }
221
222 pub fn index_size(&self) -> usize {
224 self.index.len()
225 }
226}
227
/// Thread-safe latency statistics for search queries, in microseconds.
///
/// Counters are atomics updated with `Relaxed` ordering, so values read concurrently with
/// recording may come from slightly different moments. Percentiles are computed from a
/// bounded, mutex-protected sample buffer.
#[derive(Debug)]
pub struct SearchMetrics {
    /// Number of recorded queries.
    total_queries: AtomicU64,
    /// Sum of all recorded latencies.
    total_latency_us: AtomicU64,
    /// Smallest recorded latency; `u64::MAX` means nothing has been recorded yet.
    min_latency_us: AtomicU64,
    /// Largest recorded latency; `0` before any query is recorded.
    max_latency_us: AtomicU64,
    /// Bounded buffer of latency samples (at most `reservoir_cap`) used by `percentile_us`.
    reservoir: parking_lot::Mutex<Vec<u64>>,
    /// Maximum number of samples kept in `reservoir`.
    reservoir_cap: usize,
}
249
250impl SearchMetrics {
251 const DEFAULT_RESERVOIR_CAP: usize = 4096;
252
253 pub fn new() -> Self {
255 Self {
256 total_queries: AtomicU64::new(0),
257 total_latency_us: AtomicU64::new(0),
258 min_latency_us: AtomicU64::new(u64::MAX),
259 max_latency_us: AtomicU64::new(0),
260 reservoir: parking_lot::Mutex::new(Vec::with_capacity(Self::DEFAULT_RESERVOIR_CAP)),
261 reservoir_cap: Self::DEFAULT_RESERVOIR_CAP,
262 }
263 }
264
265 pub fn record_query(&self, latency_us: u64) {
267 self.total_queries.fetch_add(1, Ordering::Relaxed);
268 self.total_latency_us
269 .fetch_add(latency_us, Ordering::Relaxed);
270
271 let mut current_min = self.min_latency_us.load(Ordering::Relaxed);
273 while latency_us < current_min {
274 match self.min_latency_us.compare_exchange_weak(
275 current_min,
276 latency_us,
277 Ordering::Relaxed,
278 Ordering::Relaxed,
279 ) {
280 Ok(_) => break,
281 Err(updated) => current_min = updated,
282 }
283 }
284
285 let mut current_max = self.max_latency_us.load(Ordering::Relaxed);
287 while latency_us > current_max {
288 match self.max_latency_us.compare_exchange_weak(
289 current_max,
290 latency_us,
291 Ordering::Relaxed,
292 Ordering::Relaxed,
293 ) {
294 Ok(_) => break,
295 Err(updated) => current_max = updated,
296 }
297 }
298
299 let mut res = self.reservoir.lock();
301 if res.len() < self.reservoir_cap {
302 res.push(latency_us);
303 }
304 }
305
306 pub fn total_queries(&self) -> u64 {
308 self.total_queries.load(Ordering::Relaxed)
309 }
310
311 pub fn mean_latency_us(&self) -> Option<f64> {
314 let n = self.total_queries();
315 if n == 0 {
316 return None;
317 }
318 Some(self.total_latency_us.load(Ordering::Relaxed) as f64 / n as f64)
319 }
320
321 pub fn min_latency_us(&self) -> Option<u64> {
323 let v = self.min_latency_us.load(Ordering::Relaxed);
324 if v == u64::MAX {
325 None
326 } else {
327 Some(v)
328 }
329 }
330
331 pub fn max_latency_us(&self) -> Option<u64> {
333 let v = self.max_latency_us.load(Ordering::Relaxed);
334 if v == 0 && self.total_queries() == 0 {
335 None
336 } else {
337 Some(v)
338 }
339 }
340
341 pub fn throughput_qps(&self) -> Option<f64> {
344 let mean = self.mean_latency_us()?;
345 if mean == 0.0 {
346 return None;
347 }
348 Some(1_000_000.0 / mean)
349 }
350
351 pub fn percentile_us(&self, p: f64) -> Option<u64> {
354 let mut res = self.reservoir.lock();
355 if res.is_empty() {
356 return None;
357 }
358 res.sort_unstable();
359 let idx = ((p / 100.0) * (res.len() - 1) as f64).round() as usize;
360 Some(res[idx.min(res.len() - 1)])
361 }
362
363 pub fn reset(&self) {
365 self.total_queries.store(0, Ordering::Relaxed);
366 self.total_latency_us.store(0, Ordering::Relaxed);
367 self.min_latency_us.store(u64::MAX, Ordering::Relaxed);
368 self.max_latency_us.store(0, Ordering::Relaxed);
369 self.reservoir.lock().clear();
370 }
371}
372
373impl Default for SearchMetrics {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot vector of length `dim` with `1.0` at position `hot_dim % dim`.
    fn unit_vec(dim: usize, hot_dim: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot_dim % dim] = 1.0;
        v
    }

    /// Deterministic pseudo-random vector (64-bit LCG) with components in `[-1.0, 1.0)`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                // Map the top 24 bits (exactly representable in f32) onto [0, 1). Dividing
                // `state >> 33` (31 bits) by `u32::MAX` only reached [0, 0.5), which made every
                // component negative and put all fixtures in a single orthant.
                ((state >> 40) as f32) / ((1u64 << 24) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Index of `n` random `dim`-dimensional vectors named `v0..`, with parallel threshold 16.
    fn build_index(n: usize, dim: usize) -> SimdVectorSearch {
        let mut idx = SimdVectorSearch::new(16);
        for i in 0..n {
            let v = random_vec(dim, i as u64 + 1);
            idx.insert(format!("v{}", i), v).unwrap();
        }
        idx
    }

    // ---- test helpers ----

    #[test]
    fn test_random_vec_covers_both_signs() {
        let v = random_vec(64, 3);
        assert!(v.iter().all(|x| (-1.0..1.0).contains(x)));
        assert!(v.iter().any(|&x| x > 0.0));
        assert!(v.iter().any(|&x| x < 0.0));
    }

    // ---- SimdVectorSearch ----

    #[test]
    fn test_simd_search_basic_knn() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("a".into(), vec![1.0, 0.0, 0.0]).unwrap();
        idx.insert("b".into(), vec![0.0, 1.0, 0.0]).unwrap();
        idx.insert("c".into(), vec![0.9, 0.1, 0.0]).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = idx.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "a");
        assert_eq!(results[1].0, "c");
        // Identical direction => distance ~0.
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_search_empty_index() {
        let idx = SimdVectorSearch::new(16);
        let results = idx.search(&[1.0, 0.0], 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_simd_search_single_entry() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("only".into(), vec![0.6, 0.8]).unwrap();
        let results = idx.search(&[0.6, 0.8], 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_search_k_larger_than_index() {
        let idx = build_index(5, 4);
        let query = random_vec(4, 999);
        let results = idx.search(&query, 100).unwrap();
        assert_eq!(results.len(), 5, "should return at most index size");
    }

    #[test]
    fn test_simd_search_k_zero_returns_empty() {
        let idx = build_index(5, 4);
        let results = idx.search(&random_vec(4, 9), 0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_simd_search_results_sorted_ascending() {
        let idx = build_index(50, 8);
        let query = random_vec(8, 42);
        let results = idx.search(&query, 20).unwrap();
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1, "results not sorted: {:?}", w);
        }
    }

    #[test]
    fn test_simd_search_parallel_threshold_switch() {
        // 8 entries >= threshold 4, so the parallel path is exercised.
        let mut idx = SimdVectorSearch::new(4);
        for i in 0..8_usize {
            idx.insert(format!("p{}", i), unit_vec(4, i)).unwrap();
        }
        let query = unit_vec(4, 0);
        let results = idx.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
        // p0 and p4 tie at distance 0; insertion order decides.
        assert_eq!(results[0].0, "p0");
    }

    #[test]
    fn test_simd_search_update_existing_id() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("x".into(), vec![1.0, 0.0]).unwrap();
        idx.insert("x".into(), vec![0.0, 1.0]).unwrap();
        assert_eq!(idx.len(), 1);

        let results = idx.search(&[0.0, 1.0], 1).unwrap();
        assert_eq!(results[0].0, "x");
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn test_simd_insert_rejects_empty_vector() {
        let mut idx = SimdVectorSearch::new(4);
        assert!(idx.insert("e".into(), Vec::new()).is_err());
        assert!(idx.is_empty());
    }

    #[test]
    fn test_simd_search_rejects_empty_query() {
        let idx = build_index(3, 4);
        assert!(idx.search(&[], 1).is_err());
        assert!(idx.all_distances(&[]).is_err());
    }

    #[test]
    fn test_simd_dimension_mismatch_is_infinite() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("three".into(), vec![1.0, 0.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1.is_infinite());
    }

    #[test]
    fn test_simd_all_distances_length() {
        let idx = build_index(10, 4);
        let query = random_vec(4, 7);
        let all = idx.all_distances(&query).unwrap();
        assert_eq!(all.len(), 10);
    }

    #[test]
    fn test_simd_orthogonal_max_distance() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("y".into(), vec![0.0, 1.0]).unwrap();

        let query = vec![1.0, 0.0];
        let results = idx.search(&query, 1).unwrap();
        assert!((results[0].1 - 1.0).abs() < 1e-4);
    }

    // ---- BatchSearchEngine ----

    #[test]
    fn test_batch_search_basic() {
        let engine = BatchSearchEngine::new(build_index(20, 4));
        let queries: Vec<Vec<f32>> = (0..5).map(|i| random_vec(4, i as u64)).collect();
        let results = engine.batch_search(&queries, 3).unwrap();
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.len() <= 3);
        }
    }

    #[test]
    fn test_batch_search_empty_queries() {
        let engine = BatchSearchEngine::new(build_index(10, 4));
        let results = engine.batch_search(&[], 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_batch_search_order_preserved() {
        let mut idx = SimdVectorSearch::new(4);
        idx.insert("origin".into(), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
        idx.insert("x_axis".into(), vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        idx.insert("y_axis".into(), vec![0.0, 1.0, 0.0, 0.0]).unwrap();

        let engine = BatchSearchEngine::new(idx);
        let queries = vec![
            vec![1.0_f32, 0.0, 0.0, 0.0],
            vec![0.0_f32, 1.0, 0.0, 0.0],
        ];
        let results = engine.batch_search(&queries, 1).unwrap();
        assert_eq!(results[0][0].0, "x_axis");
        assert_eq!(results[1][0].0, "y_axis");
    }

    #[test]
    fn test_batch_search_large_concurrent() {
        let engine = BatchSearchEngine::new(build_index(200, 16));
        let queries: Vec<Vec<f32>> = (0..50).map(|i| random_vec(16, i as u64 + 100)).collect();
        let results = engine.batch_search(&queries, 5).unwrap();
        assert_eq!(results.len(), 50);
        for r in &results {
            assert!(!r.is_empty());
        }
    }

    #[test]
    fn test_batch_timed_search_records_metrics() {
        let engine = BatchSearchEngine::new(build_index(30, 8));
        let metrics = SearchMetrics::new();
        let query = random_vec(8, 77);

        let results = engine.timed_search(&query, 3, &metrics).unwrap();
        assert!(!results.is_empty());
        assert_eq!(metrics.total_queries(), 1);
        assert!(metrics.mean_latency_us().is_some());
    }

    // ---- SearchMetrics ----

    #[test]
    fn test_metrics_basic_recording() {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.record_query(200);
        m.record_query(300);

        assert_eq!(m.total_queries(), 3);
        let mean = m.mean_latency_us().unwrap();
        assert!((mean - 200.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_min_max() {
        let m = SearchMetrics::new();
        m.record_query(50);
        m.record_query(150);
        m.record_query(300);

        assert_eq!(m.min_latency_us().unwrap(), 50);
        assert_eq!(m.max_latency_us().unwrap(), 300);
    }

    #[test]
    fn test_metrics_max_none_without_queries() {
        let m = SearchMetrics::new();
        assert!(m.max_latency_us().is_none());
        m.record_query(0);
        assert_eq!(m.max_latency_us(), Some(0));
        assert_eq!(m.min_latency_us(), Some(0));
    }

    #[test]
    fn test_metrics_percentile_p50() {
        let m = SearchMetrics::new();
        for lat in [10_u64, 20, 30, 40, 50] {
            m.record_query(lat);
        }
        // Nearest rank: round(0.5 * 4) = 2 -> third smallest sample.
        let p50 = m.percentile_us(50.0).unwrap();
        assert_eq!(p50, 30);
    }

    #[test]
    fn test_metrics_reservoir_is_bounded() {
        let m = SearchMetrics::new();
        let total = SearchMetrics::DEFAULT_RESERVOIR_CAP as u64 + 1_000;
        for lat in 0..total {
            m.record_query(lat);
        }
        assert_eq!(m.total_queries(), total);
        assert!(m.reservoir.lock().len() <= SearchMetrics::DEFAULT_RESERVOIR_CAP);
        assert!(m.percentile_us(100.0).is_some());
    }

    #[test]
    fn test_metrics_reset() {
        let m = SearchMetrics::new();
        m.record_query(100);
        m.reset();
        assert_eq!(m.total_queries(), 0);
        assert!(m.mean_latency_us().is_none());
    }

    #[test]
    fn test_metrics_throughput_qps() {
        let m = SearchMetrics::new();
        m.record_query(1_000);
        let qps = m.throughput_qps().unwrap();
        assert!((qps - 1_000.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_no_queries_returns_none() {
        let m = SearchMetrics::new();
        assert!(m.mean_latency_us().is_none());
        assert!(m.min_latency_us().is_none());
        assert!(m.throughput_qps().is_none());
        assert!(m.percentile_us(50.0).is_none());
    }
}