1#![allow(clippy::type_complexity)]
2use crate::gpu_acceleration::GPUAccelerator;
3use ndarray::{Array1, Array2};
4use rayon::prelude::*;
5use std::cmp::Ordering;
6use std::collections::BinaryHeap;
7
8#[cfg(target_arch = "aarch64")]
9use std::arch::aarch64::*;
10#[cfg(target_arch = "x86_64")]
11use std::arch::x86_64::*;
12
13#[derive(Debug, Clone)]
15pub struct PositionEntry {
16 pub vector: Array1<f32>,
17 pub evaluation: f32,
18 pub norm_squared: f32,
19}
20
21#[derive(Debug)]
23pub struct SearchResultRef<'a> {
24 pub similarity: f32,
25 pub evaluation: f32,
26 pub vector: &'a Array1<f32>,
27}
28
29#[derive(Debug, Clone)]
31pub struct SearchResult {
32 pub similarity: f32,
33 pub evaluation: f32,
34 pub vector: Array1<f32>,
35}
36
37impl PartialEq for SearchResult {
38 fn eq(&self, other: &Self) -> bool {
39 self.similarity == other.similarity
40 }
41}
42
43impl Eq for SearchResult {}
44
45impl PartialOrd for SearchResult {
46 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47 Some(self.cmp(other))
48 }
49}
50
51impl Ord for SearchResult {
52 fn cmp(&self, other: &Self) -> Ordering {
53 other
55 .similarity
56 .partial_cmp(&self.similarity)
57 .unwrap_or(Ordering::Equal)
58 }
59}
60
61#[derive(Clone)]
63pub struct SimilaritySearch {
64 positions: Vec<PositionEntry>,
66 vector_size: usize,
68}
69
70impl SimilaritySearch {
71 pub fn new(vector_size: usize) -> Self {
73 Self {
74 positions: Vec::new(),
75 vector_size,
76 }
77 }
78
79 pub fn add_position(&mut self, vector: Array1<f32>, evaluation: f32) {
81 assert_eq!(vector.len(), self.vector_size, "Vector size mismatch");
82
83 let norm_squared =
84 self.simd_dot_product(vector.as_slice().unwrap(), vector.as_slice().unwrap());
85
86 self.positions.push(PositionEntry {
87 vector,
88 evaluation,
89 norm_squared,
90 });
91 }
92
93 pub fn search_ref(&self, query: &Array1<f32>, k: usize) -> Vec<(&Array1<f32>, f32, f32)> {
95 if self.positions.len() > 100 {
100 self.parallel_search_ref(query, k)
101 } else {
102 self.sequential_search_ref(query, k)
103 }
104 }
105
106 pub fn search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
108 let gpu_accelerator = GPUAccelerator::global();
109
110 if gpu_accelerator.is_gpu_enabled() && self.positions.len() > 500 {
112 match self.gpu_accelerated_search(query, k) {
113 Ok(results) => return results,
114 Err(e) => {
115 println!("GPU search failed ({e}), falling back to CPU");
116 }
117 }
118 }
119
120 if self.positions.len() > 100 {
122 self.parallel_search(query, k)
123 } else {
124 self.sequential_search(query, k)
125 }
126 }
127
128 pub fn gpu_accelerated_search(
130 &self,
131 query: &Array1<f32>,
132 k: usize,
133 ) -> Result<Vec<(Array1<f32>, f32, f32)>, Box<dyn std::error::Error>> {
134 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
135
136 if self.positions.is_empty() {
137 return Ok(Vec::new());
138 }
139
140 let gpu_accelerator = GPUAccelerator::global();
141
142 let mut vectors_data = Vec::with_capacity(self.positions.len() * self.vector_size);
144 for entry in &self.positions {
145 vectors_data.extend_from_slice(entry.vector.as_slice().unwrap());
146 }
147
148 let vectors_matrix =
149 Array2::from_shape_vec((self.positions.len(), self.vector_size), vectors_data)?;
150
151 let similarities = gpu_accelerator.cosine_similarity_batch(query, &vectors_matrix)?;
153
154 let mut indexed_similarities: Vec<(usize, f32)> = similarities
156 .iter()
157 .enumerate()
158 .map(|(i, &sim)| (i, sim))
159 .collect();
160
161 indexed_similarities
163 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164
165 let mut results = Vec::new();
167 for (idx, similarity) in indexed_similarities.into_iter().take(k) {
168 let entry = &self.positions[idx];
169 results.push((entry.vector.clone(), entry.evaluation, similarity));
170 }
171
172 Ok(results)
173 }
174
175 pub fn sequential_search_ref(
177 &self,
178 query: &Array1<f32>,
179 k: usize,
180 ) -> Vec<(&Array1<f32>, f32, f32)> {
181 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
182
183 if self.positions.is_empty() {
184 return Vec::new();
185 }
186
187 let query_norm_squared = query.dot(query);
188
189 let mut indexed_similarities: Vec<(usize, f32)> = self
191 .positions
192 .iter()
193 .enumerate()
194 .map(|(idx, entry)| {
195 let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
196 (idx, similarity)
197 })
198 .collect();
199
200 indexed_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
202
203 indexed_similarities
205 .into_iter()
206 .take(k)
207 .map(|(idx, similarity)| {
208 let entry = &self.positions[idx];
209 (&entry.vector, entry.evaluation, similarity)
210 })
211 .collect()
212 }
213
214 pub fn sequential_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
216 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
217
218 if self.positions.is_empty() {
219 return Vec::new();
220 }
221
222 let query_norm_squared = query.dot(query);
223
224 let mut heap = BinaryHeap::new();
226
227 for entry in &self.positions {
228 let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
229
230 let result = SearchResult {
231 similarity,
232 evaluation: entry.evaluation,
233 vector: entry.vector.clone(),
234 };
235
236 if heap.len() < k {
237 heap.push(result);
238 } else if similarity > heap.peek().unwrap().similarity {
239 heap.pop();
240 heap.push(result);
241 }
242 }
243
244 let mut results = Vec::new();
246 while let Some(result) = heap.pop() {
247 results.push((result.vector, result.evaluation, result.similarity));
248 }
249
250 results.reverse();
252 results
253 }
254
255 pub fn parallel_search_ref(
257 &self,
258 query: &Array1<f32>,
259 k: usize,
260 ) -> Vec<(&Array1<f32>, f32, f32)> {
261 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
262
263 if self.positions.is_empty() {
264 return Vec::new();
265 }
266
267 let query_norm_squared = query.dot(query);
268
269 let mut indexed_similarities: Vec<(usize, f32)> = self
271 .positions
272 .par_iter()
273 .enumerate()
274 .map(|(idx, entry)| {
275 let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
276 (idx, similarity)
277 })
278 .collect();
279
280 indexed_similarities
282 .par_sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
283 indexed_similarities.truncate(k);
284
285 indexed_similarities
287 .into_iter()
288 .map(|(idx, similarity)| {
289 let entry = &self.positions[idx];
290 (&entry.vector, entry.evaluation, similarity)
291 })
292 .collect()
293 }
294
295 pub fn parallel_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
297 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
298
299 if self.positions.is_empty() {
300 return Vec::new();
301 }
302
303 let query_norm_squared = query.dot(query);
304
305 let mut results: Vec<_> = self
307 .positions
308 .par_iter()
309 .map(|entry| {
310 let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
311 (entry.vector.clone(), entry.evaluation, similarity)
312 })
313 .collect();
314
315 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
317 results.truncate(k);
318
319 results
320 }
321
322 pub fn brute_force_search(
324 &self,
325 query: &Array1<f32>,
326 k: usize,
327 ) -> Vec<(Array1<f32>, f32, f32)> {
328 let mut results: Vec<_> = if self.positions.len() > 100 {
329 self.positions
331 .par_iter()
332 .map(|entry| {
333 let similarity = self.cosine_similarity(query, &entry.vector);
334 (entry.vector.clone(), entry.evaluation, similarity)
335 })
336 .collect()
337 } else {
338 self.positions
340 .iter()
341 .map(|entry| {
342 let similarity = self.cosine_similarity(query, &entry.vector);
343 (entry.vector.clone(), entry.evaluation, similarity)
344 })
345 .collect()
346 };
347
348 if results.len() > 1000 {
350 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
351 } else {
352 results.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
353 }
354
355 results.truncate(k);
357 results
358 }
359
360 fn cosine_similarity_fast(
362 &self,
363 query: &Array1<f32>,
364 query_norm_squared: f32,
365 entry: &PositionEntry,
366 ) -> f32 {
367 let dot_product =
368 self.simd_dot_product(query.as_slice().unwrap(), entry.vector.as_slice().unwrap());
369
370 if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
371 0.0
372 } else {
373 dot_product / (query_norm_squared.sqrt() * entry.norm_squared.sqrt())
374 }
375 }
376
377 #[inline]
379 fn simd_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
380 #[cfg(target_arch = "x86_64")]
381 {
382 if is_x86_feature_detected!("avx2") {
383 return unsafe { self.avx2_dot_product(a, b) };
384 } else if is_x86_feature_detected!("sse4.1") {
385 return unsafe { self.sse_dot_product(a, b) };
386 }
387 }
388
389 #[cfg(target_arch = "aarch64")]
390 {
391 if std::arch::is_aarch64_feature_detected!("neon") {
392 return unsafe { self.neon_dot_product(a, b) };
393 }
394 }
395
396 self.scalar_dot_product(a, b)
398 }
399
400 #[cfg(target_arch = "x86_64")]
401 #[target_feature(enable = "avx2")]
402 unsafe fn avx2_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
403 let len = a.len().min(b.len());
404 let mut sum = _mm256_setzero_ps();
405 let mut i = 0;
406
407 while i + 8 <= len {
409 let va = _mm256_loadu_ps(a.as_ptr().add(i));
410 let vb = _mm256_loadu_ps(b.as_ptr().add(i));
411 let vmul = _mm256_mul_ps(va, vb);
412 sum = _mm256_add_ps(sum, vmul);
413 i += 8;
414 }
415
416 let mut result = [0.0f32; 8];
418 _mm256_storeu_ps(result.as_mut_ptr(), sum);
419 let mut final_sum = result.iter().sum::<f32>();
420
421 while i < len {
423 final_sum += a[i] * b[i];
424 i += 1;
425 }
426
427 final_sum
428 }
429
430 #[cfg(target_arch = "x86_64")]
431 #[target_feature(enable = "sse4.1")]
432 unsafe fn sse_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
433 let len = a.len().min(b.len());
434 let mut sum = _mm_setzero_ps();
435 let mut i = 0;
436
437 while i + 4 <= len {
439 let va = _mm_loadu_ps(a.as_ptr().add(i));
440 let vb = _mm_loadu_ps(b.as_ptr().add(i));
441 let vmul = _mm_mul_ps(va, vb);
442 sum = _mm_add_ps(sum, vmul);
443 i += 4;
444 }
445
446 let mut result = [0.0f32; 4];
448 _mm_storeu_ps(result.as_mut_ptr(), sum);
449 let mut final_sum = result.iter().sum::<f32>();
450
451 while i < len {
453 final_sum += a[i] * b[i];
454 i += 1;
455 }
456
457 final_sum
458 }
459
460 #[cfg(target_arch = "aarch64")]
461 #[target_feature(enable = "neon")]
462 unsafe fn neon_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
463 let len = a.len().min(b.len());
464 let mut sum = vdupq_n_f32(0.0);
465 let mut i = 0;
466
467 while i + 4 <= len {
469 let va = vld1q_f32(a.as_ptr().add(i));
470 let vb = vld1q_f32(b.as_ptr().add(i));
471 let vmul = vmulq_f32(va, vb);
472 sum = vaddq_f32(sum, vmul);
473 i += 4;
474 }
475
476 let mut result = [0.0f32; 4];
478 vst1q_f32(result.as_mut_ptr(), sum);
479 let mut final_sum = result.iter().sum::<f32>();
480
481 while i < len {
483 final_sum += a[i] * b[i];
484 i += 1;
485 }
486
487 final_sum
488 }
489
490 #[inline]
491 fn scalar_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
492 let len = a.len().min(b.len());
493 let mut sum = 0.0f32;
494
495 let mut i = 0;
497 while i + 4 <= len {
498 sum += a[i] * b[i] + a[i + 1] * b[i + 1] + a[i + 2] * b[i + 2] + a[i + 3] * b[i + 3];
499 i += 4;
500 }
501
502 while i < len {
504 sum += a[i] * b[i];
505 i += 1;
506 }
507
508 sum
509 }
510
511 fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
513 let dot_product = a.dot(b);
514 let norm_a = a.dot(a).sqrt();
515 let norm_b = b.dot(b).sqrt();
516
517 if norm_a == 0.0 || norm_b == 0.0 {
518 0.0
519 } else {
520 dot_product / (norm_a * norm_b)
521 }
522 }
523
524 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
526 (a - b).mapv(|x| x * x).sum().sqrt()
527 }
528
529 pub fn search_by_distance(
531 &self,
532 query: &Array1<f32>,
533 k: usize,
534 ) -> Vec<(Array1<f32>, f32, f32)> {
535 let mut results: Vec<_> = self
536 .positions
537 .iter()
538 .map(|entry| {
539 let distance = self.euclidean_distance(query, &entry.vector);
540 (entry.vector.clone(), entry.evaluation, distance)
541 })
542 .collect();
543
544 results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
546
547 results.truncate(k);
549 results
550 }
551
552 pub fn size(&self) -> usize {
554 self.positions.len()
555 }
556
557 pub fn is_empty(&self) -> bool {
559 self.positions.is_empty()
560 }
561
562 pub fn clear(&mut self) {
564 self.positions.clear();
565 }
566
567 pub fn statistics(&self) -> SimilaritySearchStats {
569 if self.positions.is_empty() {
570 return SimilaritySearchStats {
571 count: 0,
572 avg_evaluation: 0.0,
573 min_evaluation: 0.0,
574 max_evaluation: 0.0,
575 };
576 }
577
578 let evaluations: Vec<f32> = self.positions.iter().map(|p| p.evaluation).collect();
579 let sum: f32 = evaluations.iter().sum();
580 let avg = sum / evaluations.len() as f32;
581 let min = evaluations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
582 let max = evaluations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
583
584 SimilaritySearchStats {
585 count: self.positions.len(),
586 avg_evaluation: avg,
587 min_evaluation: min,
588 max_evaluation: max,
589 }
590 }
591
592 pub fn get_all_positions(&self) -> Vec<(Array1<f32>, f32)> {
594 self.positions
595 .iter()
596 .map(|entry| (entry.vector.clone(), entry.evaluation))
597 .collect()
598 }
599
600 pub fn get_position_ref(&self, index: usize) -> Option<(&Array1<f32>, f32)> {
602 self.positions
603 .get(index)
604 .map(|entry| (&entry.vector, entry.evaluation))
605 }
606
607 pub fn iter_positions(&self) -> impl Iterator<Item = (&Array1<f32>, f32)> {
609 self.positions
610 .iter()
611 .map(|entry| (&entry.vector, entry.evaluation))
612 }
613}
614
/// Summary of the evaluations stored in a `SimilaritySearch`, as produced by its `statistics`
/// method.
///
/// When the index is empty, every field is zero.
#[derive(Clone, Debug)]
pub struct SimilaritySearchStats {
    /// Number of stored positions.
    pub count: usize,
    /// Arithmetic mean of the stored evaluations.
    pub avg_evaluation: f32,
    /// Smallest stored evaluation.
    pub min_evaluation: f32,
    /// Largest stored evaluation.
    pub max_evaluation: f32,
}
623
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Absolute tolerance shared by every floating-point check below.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_similarity_search_creation() {
        let index = SimilaritySearch::new(100);
        assert!(index.is_empty());
        assert_eq!(index.size(), 0);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = SimilaritySearch::new(3);
        let basis = [
            Array1::from(vec![1.0, 0.0, 0.0]),
            Array1::from(vec![0.0, 1.0, 0.0]),
            Array1::from(vec![0.0, 0.0, 1.0]),
        ];
        for (vector, evaluation) in basis.iter().cloned().zip([1.0, 0.5, 0.0]) {
            index.add_position(vector, evaluation);
        }
        assert_eq!(index.size(), 3);

        let hits = index.search(&basis[0], 2);
        assert_eq!(hits.len(), 2);

        // The query is itself stored, so the top hit is an exact match carrying evaluation 1.0.
        let (_, best_evaluation, best_similarity) = &hits[0];
        assert!(close(*best_similarity, 1.0));
        assert!(close(*best_evaluation, 1.0));
    }

    #[test]
    fn test_cosine_similarity() {
        let index = SimilaritySearch::new(2);

        let x_axis = Array1::from(vec![1.0, 0.0]);
        let x_axis_again = Array1::from(vec![1.0, 0.0]);
        let y_axis = Array1::from(vec![0.0, 1.0]);

        assert!(close(index.cosine_similarity(&x_axis, &x_axis_again), 1.0));
        assert!(close(index.cosine_similarity(&x_axis, &y_axis), 0.0));
    }

    #[test]
    fn test_statistics() {
        let mut index = SimilaritySearch::new(2);

        let vector = Array1::from(vec![1.0, 0.0]);
        for evaluation in [1.0, 2.0, 3.0] {
            index.add_position(vector.clone(), evaluation);
        }

        let stats = index.statistics();
        assert_eq!(stats.count, 3);
        assert!(close(stats.avg_evaluation, 2.0));
        assert!(close(stats.min_evaluation, 1.0));
        assert!(close(stats.max_evaluation, 3.0));
    }
}