1#![allow(clippy::type_complexity)]
2use crate::gpu_acceleration::GPUAccelerator;
3use crate::utils::simd::SimdVectorOps;
4use ndarray::{Array1, Array2};
5use rayon::prelude::*;
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap};
8use std::time::{Duration, Instant};
/// One stored position: its vector, the evaluation supplied with it, and a
/// precomputed norm value.
#[derive(Debug, Clone)]
pub struct PositionEntry {
    /// The vector that searches compare queries against.
    pub vector: Array1<f32>,
    /// Evaluation value passed to `add_position` together with the vector.
    pub evaluation: f32,
    /// Result of `SimdVectorOps::squared_norm(&vector)`, computed once in
    /// `add_position` so similarity routines can reuse it for every query.
    pub norm_squared: f32,
}
18
/// A search hit that borrows the matched vector instead of cloning it.
///
/// NOTE(review): nothing in this file constructs this type; the `*_ref`
/// search methods return plain tuples. Confirm whether it is used elsewhere.
#[derive(Debug)]
pub struct SearchResultRef<'a> {
    /// Similarity score of the hit.
    pub similarity: f32,
    /// Evaluation associated with the matched vector.
    pub evaluation: f32,
    /// Borrowed reference to the matched vector.
    pub vector: &'a Array1<f32>,
}
26
27#[derive(Debug, Clone)]
29pub struct SearchResult {
30 pub similarity: f32,
31 pub evaluation: f32,
32 pub vector: Array1<f32>,
33}
34
35impl PartialEq for SearchResult {
36 fn eq(&self, other: &Self) -> bool {
37 self.similarity == other.similarity
38 }
39}
40
41impl Eq for SearchResult {}
42
43impl PartialOrd for SearchResult {
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49impl Ord for SearchResult {
50 fn cmp(&self, other: &Self) -> Ordering {
51 other
53 .similarity
54 .partial_cmp(&self.similarity)
55 .unwrap_or(Ordering::Equal)
56 }
57}
58
/// A node of the cluster tree built by `build_cluster_tree`.
#[derive(Debug, Clone)]
pub struct ClusterNode {
    /// Centroid of the member vectors, as returned by `compute_centroid`.
    pub centroid: Array1<f32>,
    /// Indices into the owning index's `positions` for every member of this
    /// node (internal nodes hold the concatenation of their children's members).
    pub position_indices: Vec<usize>,
    /// Child clusters; empty for a leaf.
    pub children: Vec<ClusterNode>,
    /// Largest `1 - cosine similarity` between the centroid and any member,
    /// as computed by `compute_cluster_radius`.
    pub radius: f32,
    /// Number of indices the node was built from.
    pub size: usize,
}
73
/// A cached answer for one `search` call.
#[derive(Debug, Clone)]
pub struct SearchResultCache {
    /// The cached hits as `(vector, evaluation, similarity)` tuples.
    pub results: Vec<(Array1<f32>, f32, f32)>,
    /// When the entry was stored; compared against the cache TTL on lookup.
    pub timestamp: Instant,
}
80
/// Snapshot of cache sizes, limits and hit counters, as produced by
/// `SimilaritySearch::get_cache_stats`.
///
/// Derives `PartialEq` and `Default` so snapshots can be compared directly
/// (for example in tests) and an all-zero snapshot can be built without
/// spelling out every field.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SimilarityCacheStats {
    /// Number of entries currently in the result cache.
    pub result_cache_size: usize,
    /// Number of entries currently in the similarity cache.
    pub similarity_cache_size: usize,
    /// Configured maximum cache size.
    pub max_cache_size: usize,
    /// Cache entry time-to-live, in whole seconds.
    pub cache_ttl_secs: u64,
    /// Number of cache hits recorded so far.
    pub cache_hits: u64,
    /// Number of cache misses recorded so far.
    pub cache_misses: u64,
    /// `cache_hits / (cache_hits + cache_misses)`, or `0.0` when no lookups
    /// have been recorded.
    pub hit_ratio: f32,
}
92
/// In-memory similarity index over fixed-length `f32` vectors, with an
/// optional cluster tree and two time-limited caches.
#[derive(Clone)]
pub struct SimilaritySearch {
    /// Stored entries, in insertion order.
    positions: Vec<PositionEntry>,
    /// Length every stored vector and query must have (checked by assertions).
    vector_size: usize,
    /// Cluster tree built by `build_cluster_tree`; reset to `None` whenever a
    /// position is added.
    cluster_tree: Option<ClusterNode>,
    /// Cached similarities with the time they were stored.
    similarity_cache: HashMap<(usize, usize), (f32, Instant)>,
    /// Cached `search` results, keyed by the value returned from `hash_query`.
    result_cache: HashMap<u64, SearchResultCache>,
    /// Size threshold for the similarity cache; the result cache is limited
    /// to a tenth of this value.
    max_cache_size: usize,
    /// Maximum age of a cache entry before it is treated as expired.
    cache_ttl: Duration,
    /// Number of cache hits recorded.
    cache_hits: u64,
    /// Number of cache misses recorded.
    cache_misses: u64,
}
114
115impl SimilaritySearch {
116 pub fn new(vector_size: usize) -> Self {
118 Self {
119 positions: Vec::new(),
120 vector_size,
121 cluster_tree: None,
122 similarity_cache: HashMap::with_capacity(10000),
123 result_cache: HashMap::with_capacity(1000),
124 max_cache_size: 10000,
125 cache_ttl: Duration::from_secs(300), cache_hits: 0,
127 cache_misses: 0,
128 }
129 }
130
131 pub fn with_cache_config(vector_size: usize, max_cache_size: usize, cache_ttl_secs: u64) -> Self {
133 Self {
134 positions: Vec::new(),
135 vector_size,
136 cluster_tree: None,
137 similarity_cache: HashMap::with_capacity(max_cache_size),
138 result_cache: HashMap::with_capacity(max_cache_size / 10),
139 max_cache_size,
140 cache_ttl: Duration::from_secs(cache_ttl_secs),
141 cache_hits: 0,
142 cache_misses: 0,
143 }
144 }
145
146 pub fn add_position(&mut self, vector: Array1<f32>, evaluation: f32) {
148 assert_eq!(vector.len(), self.vector_size, "Vector size mismatch");
149
150 let norm_squared = SimdVectorOps::squared_norm(&vector);
151
152 self.positions.push(PositionEntry {
153 vector,
154 evaluation,
155 norm_squared,
156 });
157
158 self.cluster_tree = None;
160
161 self.evict_expired_cache_entries();
163 if self.similarity_cache.len() > self.max_cache_size {
164 self.evict_oldest_cache_entries();
165 }
166 }
167
168 pub fn search_ref(&self, query: &Array1<f32>, k: usize) -> Vec<(&Array1<f32>, f32, f32)> {
170 if self.positions.len() > 1000 {
175 self.hierarchical_search_ref(query, k)
176 } else if self.positions.len() > 100 {
177 self.parallel_search_ref(query, k)
178 } else {
179 self.sequential_search_ref(query, k)
180 }
181 }
182
183 pub fn search(&mut self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
185 let query_hash = self.hash_query(query, k);
187
188 if let Some(cached_result) = self.get_cached_result(query_hash) {
190 return cached_result;
191 }
192
193 let results = self.search_uncached(query, k);
195
196 self.cache_search_result(query_hash, results.clone());
198
199 results
200 }
201
202 fn search_uncached(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
204 if self.positions.len() > 50 {
206 return self.search_optimized(query, k);
207 }
208
209 let gpu_accelerator = GPUAccelerator::global();
210
211 if gpu_accelerator.is_gpu_enabled() && self.positions.len() > 500 {
213 match self.gpu_accelerated_search(query, k) {
214 Ok(results) => return results,
215 Err(e) => {
216 println!("GPU search failed ({e}), falling back to CPU");
217 }
218 }
219 }
220
221 if self.positions.len() > 1000 {
223 self.hierarchical_search(query, k)
224 } else if self.positions.len() > 100 {
225 self.parallel_search(query, k)
226 } else {
227 self.sequential_search(query, k)
228 }
229 }
230
231 pub fn gpu_accelerated_search(
233 &self,
234 query: &Array1<f32>,
235 k: usize,
236 ) -> Result<Vec<(Array1<f32>, f32, f32)>, Box<dyn std::error::Error>> {
237 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
238
239 if self.positions.is_empty() {
240 return Ok(Vec::new());
241 }
242
243 let gpu_accelerator = GPUAccelerator::global();
244
245 let mut vectors_data = Vec::with_capacity(self.positions.len() * self.vector_size);
247 for entry in &self.positions {
248 vectors_data.extend_from_slice(entry.vector.as_slice().unwrap());
249 }
250
251 let vectors_matrix =
252 Array2::from_shape_vec((self.positions.len(), self.vector_size), vectors_data)?;
253
254 let similarities = gpu_accelerator.cosine_similarity_batch(query, &vectors_matrix)?;
256
257 let mut indexed_similarities: Vec<(usize, f32)> = similarities
259 .iter()
260 .enumerate()
261 .map(|(i, &sim)| (i, sim))
262 .collect();
263
264 indexed_similarities
266 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
267
268 let mut results = Vec::new();
270 for (idx, similarity) in indexed_similarities.into_iter().take(k) {
271 let entry = &self.positions[idx];
272 results.push((entry.vector.clone(), entry.evaluation, similarity));
273 }
274
275 Ok(results)
276 }
277
278 pub fn sequential_search_ref(
280 &self,
281 query: &Array1<f32>,
282 k: usize,
283 ) -> Vec<(&Array1<f32>, f32, f32)> {
284 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
285
286 if self.positions.is_empty() {
287 return Vec::new();
288 }
289
290 let query_norm_squared = SimdVectorOps::squared_norm(query);
291
292 let mut indexed_similarities: Vec<(usize, f32)> = self
294 .positions
295 .iter()
296 .enumerate()
297 .map(|(idx, entry)| {
298 let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
299 (idx, similarity)
300 })
301 .collect();
302
303 indexed_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
305
306 indexed_similarities
308 .into_iter()
309 .take(k)
310 .map(|(idx, similarity)| {
311 let entry = &self.positions[idx];
312 (&entry.vector, entry.evaluation, similarity)
313 })
314 .collect()
315 }
316
317 pub fn sequential_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
319 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
320
321 if self.positions.is_empty() {
322 return Vec::new();
323 }
324
325 let query_norm_squared = SimdVectorOps::squared_norm(query);
326
327 let mut heap = BinaryHeap::new();
329
330 for entry in &self.positions {
331 let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
332
333 let result = SearchResult {
334 similarity,
335 evaluation: entry.evaluation,
336 vector: entry.vector.clone(),
337 };
338
339 if heap.len() < k {
340 heap.push(result);
341 } else if similarity > heap.peek().unwrap().similarity {
342 heap.pop();
343 heap.push(result);
344 }
345 }
346
347 let mut results = Vec::new();
349 while let Some(result) = heap.pop() {
350 results.push((result.vector, result.evaluation, result.similarity));
351 }
352
353 results.reverse();
355 results
356 }
357
358 pub fn parallel_search_ref(
360 &self,
361 query: &Array1<f32>,
362 k: usize,
363 ) -> Vec<(&Array1<f32>, f32, f32)> {
364 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
365
366 if self.positions.is_empty() {
367 return Vec::new();
368 }
369
370 let query_norm_squared = SimdVectorOps::squared_norm(query);
371
372 let mut indexed_similarities: Vec<(usize, f32)> = self
374 .positions
375 .par_iter()
376 .enumerate()
377 .map(|(idx, entry)| {
378 let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
379 (idx, similarity)
380 })
381 .collect();
382
383 indexed_similarities
385 .par_sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
386 indexed_similarities.truncate(k);
387
388 indexed_similarities
390 .into_iter()
391 .map(|(idx, similarity)| {
392 let entry = &self.positions[idx];
393 (&entry.vector, entry.evaluation, similarity)
394 })
395 .collect()
396 }
397
398 pub fn parallel_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
400 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
401
402 if self.positions.is_empty() {
403 return Vec::new();
404 }
405
406 let query_norm_squared = SimdVectorOps::squared_norm(query);
407
408 let mut results: Vec<_> = self
410 .positions
411 .par_iter()
412 .map(|entry| {
413 let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
414 (entry.vector.clone(), entry.evaluation, similarity)
415 })
416 .collect();
417
418 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
420 results.truncate(k);
421
422 results
423 }
424
425 pub fn brute_force_search(
427 &self,
428 query: &Array1<f32>,
429 k: usize,
430 ) -> Vec<(Array1<f32>, f32, f32)> {
431 let mut results: Vec<_> = if self.positions.len() > 100 {
432 self.positions
434 .par_iter()
435 .map(|entry| {
436 let similarity = self.cosine_similarity(query, &entry.vector);
437 (entry.vector.clone(), entry.evaluation, similarity)
438 })
439 .collect()
440 } else {
441 self.positions
443 .iter()
444 .map(|entry| {
445 let similarity = self.cosine_similarity(query, &entry.vector);
446 (entry.vector.clone(), entry.evaluation, similarity)
447 })
448 .collect()
449 };
450
451 if results.len() > 1000 {
453 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
454 } else {
455 results.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
456 }
457
458 results.truncate(k);
460 results
461 }
462
463 fn cosine_similarity_fast(
465 &mut self,
466 query: &Array1<f32>,
467 query_norm_squared: f32,
468 entry_index: usize,
469 ) -> f32 {
470 let now = Instant::now();
472 let cache_key = (0, entry_index); if let Some((cached_similarity, cached_time)) = self.similarity_cache.get(&cache_key) {
475 if now.duration_since(*cached_time) < self.cache_ttl {
476 self.cache_hits += 1;
477 return *cached_similarity;
478 }
479 }
480
481 self.cache_misses += 1;
483 let entry = &self.positions[entry_index];
484
485 if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
487 return 0.0;
488 }
489
490 let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
491
492 let query_norm_inv = 1.0 / query_norm_squared.sqrt();
494 let entry_norm_inv = 1.0 / entry.norm_squared.sqrt();
495
496 let similarity = dot_product * query_norm_inv * entry_norm_inv;
497
498 self.similarity_cache.insert(cache_key, (similarity, now));
500
501 similarity
502 }
503
504 fn cosine_similarity_fast_uncached(
506 &self,
507 query: &Array1<f32>,
508 query_norm_squared: f32,
509 entry: &PositionEntry,
510 ) -> f32 {
511 if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
513 return 0.0;
514 }
515
516 let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
517
518 let query_norm_inv = 1.0 / query_norm_squared.sqrt();
520 let entry_norm_inv = 1.0 / entry.norm_squared.sqrt();
521
522 dot_product * query_norm_inv * entry_norm_inv
523 }
524
525 fn cosine_similarity_ultra_fast(
527 &self,
528 query: &Array1<f32>,
529 query_norm: f32,
530 entry: &PositionEntry,
531 entry_norm: f32,
532 ) -> f32 {
533 if query_norm == 0.0 || entry_norm == 0.0 {
534 return 0.0;
535 }
536
537 let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
538 dot_product / (query_norm * entry_norm)
539 }
540
/// Cosine similarity of `a` and `b`; forwards directly to
/// `SimdVectorOps::cosine_similarity`.
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    SimdVectorOps::cosine_similarity(a, b)
}
545
546 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
548 (a - b).mapv(|x| x * x).sum().sqrt()
549 }
550
551 pub fn search_by_distance(
553 &self,
554 query: &Array1<f32>,
555 k: usize,
556 ) -> Vec<(Array1<f32>, f32, f32)> {
557 let mut results: Vec<_> = self
558 .positions
559 .iter()
560 .map(|entry| {
561 let distance = self.euclidean_distance(query, &entry.vector);
562 (entry.vector.clone(), entry.evaluation, distance)
563 })
564 .collect();
565
566 results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
568
569 results.truncate(k);
571 results
572 }
573
/// Number of stored positions.
pub fn size(&self) -> usize {
    self.positions.len()
}
578
/// Returns `true` when no positions are stored.
pub fn is_empty(&self) -> bool {
    self.positions.is_empty()
}
583
584 pub fn clear(&mut self) {
586 self.positions.clear();
587 }
588
589 pub fn statistics(&self) -> SimilaritySearchStats {
591 if self.positions.is_empty() {
592 return SimilaritySearchStats {
593 count: 0,
594 avg_evaluation: 0.0,
595 min_evaluation: 0.0,
596 max_evaluation: 0.0,
597 };
598 }
599
600 let evaluations: Vec<f32> = self.positions.iter().map(|p| p.evaluation).collect();
601 let sum: f32 = evaluations.iter().sum();
602 let avg = sum / evaluations.len() as f32;
603 let min = evaluations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
604 let max = evaluations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
605
606 SimilaritySearchStats {
607 count: self.positions.len(),
608 avg_evaluation: avg,
609 min_evaluation: min,
610 max_evaluation: max,
611 }
612 }
613
614 pub fn get_all_positions(&self) -> Vec<(Array1<f32>, f32)> {
616 self.positions
617 .iter()
618 .map(|entry| (entry.vector.clone(), entry.evaluation))
619 .collect()
620 }
621
622 pub fn get_position_ref(&self, index: usize) -> Option<(&Array1<f32>, f32)> {
624 self.positions
625 .get(index)
626 .map(|entry| (&entry.vector, entry.evaluation))
627 }
628
629 pub fn iter_positions(&self) -> impl Iterator<Item = (&Array1<f32>, f32)> {
631 self.positions
632 .iter()
633 .map(|entry| (&entry.vector, entry.evaluation))
634 }
635
636 pub fn build_cluster_tree(&mut self) {
638 if self.positions.is_empty() {
639 self.cluster_tree = None;
640 return;
641 }
642
643 let indices: Vec<usize> = (0..self.positions.len()).collect();
644 self.cluster_tree = Some(self.build_cluster_recursive(indices, 0));
645 }
646
647 fn build_cluster_recursive(&self, indices: Vec<usize>, depth: usize) -> ClusterNode {
649 let max_depth = 10;
650 let min_cluster_size = 32;
651
652 if indices.len() <= min_cluster_size || depth >= max_depth {
653 let centroid = self.compute_centroid(&indices);
655 let radius = self.compute_cluster_radius(¢roid, &indices);
656
657 return ClusterNode {
658 centroid,
659 position_indices: indices.clone(),
660 children: Vec::new(),
661 radius,
662 size: indices.len(),
663 };
664 }
665
666 let k = if indices.len() > 200 { 4 } else { 2 };
668 let clusters = self.k_means_clustering(&indices, k);
669
670 let mut children = Vec::new();
671 let mut all_indices = Vec::new();
672
673 for cluster_indices in clusters {
674 if !cluster_indices.is_empty() {
675 let child = self.build_cluster_recursive(cluster_indices.clone(), depth + 1);
676 all_indices.extend(cluster_indices);
677 children.push(child);
678 }
679 }
680
681 let centroid = self.compute_centroid(&all_indices);
682 let radius = self.compute_cluster_radius(¢roid, &all_indices);
683
684 ClusterNode {
685 centroid,
686 position_indices: all_indices,
687 children,
688 radius,
689 size: indices.len(),
690 }
691 }
692
693 fn compute_centroid(&self, indices: &[usize]) -> Array1<f32> {
695 if indices.is_empty() {
696 return Array1::zeros(self.vector_size);
697 }
698
699 let mut centroid = Array1::zeros(self.vector_size);
700 for &idx in indices {
701 centroid = SimdVectorOps::add_vectors(¢roid, &self.positions[idx].vector);
702 }
703
704 SimdVectorOps::scale_vector(¢roid, 1.0 / indices.len() as f32)
705 }
706
707 fn compute_cluster_radius(&self, centroid: &Array1<f32>, indices: &[usize]) -> f32 {
709 indices
710 .iter()
711 .map(|&idx| 1.0 - self.cosine_similarity_cached(centroid, &self.positions[idx].vector))
712 .fold(0.0, f32::max)
713 }
714
715 fn k_means_clustering(&self, indices: &[usize], k: usize) -> Vec<Vec<usize>> {
717 if indices.len() <= k {
718 return indices.iter().map(|&i| vec![i]).collect();
719 }
720
721 let mut centroids = Vec::new();
723 let step = indices.len() / k;
724 for i in 0..k {
725 let idx = indices[i * step];
726 centroids.push(self.positions[idx].vector.clone());
727 }
728
729 const MAX_ITERATIONS: usize = 10;
730
731 for _ in 0..MAX_ITERATIONS {
732 let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
734
735 for &idx in indices {
736 let mut best_cluster = 0;
737 let mut best_similarity = -1.0;
738
739 for (cluster_idx, centroid) in centroids.iter().enumerate() {
740 let similarity =
741 self.cosine_similarity_cached(centroid, &self.positions[idx].vector);
742 if similarity > best_similarity {
743 best_similarity = similarity;
744 best_cluster = cluster_idx;
745 }
746 }
747
748 clusters[best_cluster].push(idx);
749 }
750
751 let mut converged = true;
753 for (cluster_idx, cluster) in clusters.iter().enumerate() {
754 if !cluster.is_empty() {
755 let new_centroid = self.compute_centroid(cluster);
756 let similarity =
757 self.cosine_similarity_cached(¢roids[cluster_idx], &new_centroid);
758
759 if similarity < 0.99 {
760 converged = false;
761 }
762
763 centroids[cluster_idx] = new_centroid;
764 }
765 }
766
767 if converged {
768 break;
769 }
770 }
771
772 let mut final_clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
774 for &idx in indices {
775 let mut best_cluster = 0;
776 let mut best_similarity = -1.0;
777
778 for (cluster_idx, centroid) in centroids.iter().enumerate() {
779 let similarity =
780 self.cosine_similarity_cached(centroid, &self.positions[idx].vector);
781 if similarity > best_similarity {
782 best_similarity = similarity;
783 best_cluster = cluster_idx;
784 }
785 }
786
787 final_clusters[best_cluster].push(idx);
788 }
789
790 final_clusters
791 .into_iter()
792 .filter(|cluster| !cluster.is_empty())
793 .collect()
794 }
795
796 fn hierarchical_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
798 if self.cluster_tree.is_none() {
800 return self.parallel_search(query, k);
802 }
803
804 let cluster_tree = self.cluster_tree.as_ref().unwrap();
805 let mut candidates = Vec::new();
806
807 self.traverse_cluster_tree(query, cluster_tree, &mut candidates, k * 5);
809
810 let mut results: Vec<_> = candidates
812 .into_iter()
813 .map(|idx| {
814 let entry = &self.positions[idx];
815 let similarity = self.cosine_similarity_cached(query, &entry.vector);
816 (entry.vector.clone(), entry.evaluation, similarity)
817 })
818 .collect();
819
820 results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
822 results.truncate(k);
823
824 results
825 }
826
827 fn hierarchical_search_ref(
829 &self,
830 query: &Array1<f32>,
831 k: usize,
832 ) -> Vec<(&Array1<f32>, f32, f32)> {
833 if self.cluster_tree.is_none() {
835 return self.parallel_search_ref(query, k);
837 }
838
839 let cluster_tree = self.cluster_tree.as_ref().unwrap();
840 let mut candidates = Vec::new();
841
842 self.traverse_cluster_tree(query, cluster_tree, &mut candidates, k * 5);
844
845 let mut results: Vec<_> = candidates
847 .into_iter()
848 .map(|idx| {
849 let entry = &self.positions[idx];
850 let similarity = self.cosine_similarity_cached(query, &entry.vector);
851 (&entry.vector, entry.evaluation, similarity)
852 })
853 .collect();
854
855 results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
857 results.truncate(k);
858
859 results
860 }
861
862 fn traverse_cluster_tree(
864 &self,
865 query: &Array1<f32>,
866 node: &ClusterNode,
867 candidates: &mut Vec<usize>,
868 max_candidates: usize,
869 ) {
870 if candidates.len() >= max_candidates {
871 return;
872 }
873
874 let centroid_similarity = self.cosine_similarity_cached(query, &node.centroid);
876
877 let distance_threshold = 0.1; if centroid_similarity < distance_threshold {
880 return;
881 }
882
883 if node.children.is_empty() {
884 for &idx in &node.position_indices {
886 if candidates.len() < max_candidates {
887 candidates.push(idx);
888 }
889 }
890 } else {
891 let mut child_similarities: Vec<_> = node
893 .children
894 .iter()
895 .enumerate()
896 .map(|(i, child)| {
897 let similarity = self.cosine_similarity_cached(query, &child.centroid);
898 (i, similarity)
899 })
900 .collect();
901
902 child_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
903
904 for (child_idx, _) in child_similarities {
906 self.traverse_cluster_tree(
907 query,
908 &node.children[child_idx],
909 candidates,
910 max_candidates,
911 );
912 }
913 }
914 }
915
916 fn cosine_similarity_cached(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
918 let a_norm_sq = SimdVectorOps::squared_norm(a);
920 let b_norm_sq = SimdVectorOps::squared_norm(b);
921
922 if a_norm_sq == 0.0 || b_norm_sq == 0.0 {
923 return 0.0;
924 }
925
926 let dot_product = SimdVectorOps::dot_product(a, b);
927 dot_product / (a_norm_sq.sqrt() * b_norm_sq.sqrt())
928 }
929
/// Discards any existing cluster tree and builds a new one from the
/// currently stored positions.
pub fn rebuild_cluster_tree(&mut self) {
    self.cluster_tree = None;
    self.build_cluster_tree();
}
935
936 pub fn cluster_tree_stats(&self) -> Option<ClusterTreeStats> {
938 self.cluster_tree.as_ref().map(|tree| {
939 let mut stats = ClusterTreeStats {
940 total_nodes: 0,
941 leaf_nodes: 0,
942 max_depth: 0,
943 avg_cluster_size: 0.0,
944 max_cluster_size: 0,
945 };
946
947 self.collect_cluster_stats(tree, 0, &mut stats);
948
949 if stats.leaf_nodes > 0 {
950 stats.avg_cluster_size = self.positions.len() as f32 / stats.leaf_nodes as f32;
951 }
952
953 stats
954 })
955 }
956
957 fn collect_cluster_stats(
959 &self,
960 node: &ClusterNode,
961 depth: usize,
962 stats: &mut ClusterTreeStats,
963 ) {
964 stats.total_nodes += 1;
965 stats.max_depth = stats.max_depth.max(depth);
966 stats.max_cluster_size = stats.max_cluster_size.max(node.size);
967
968 if node.children.is_empty() {
969 stats.leaf_nodes += 1;
970 } else {
971 for child in &node.children {
972 self.collect_cluster_stats(child, depth + 1, stats);
973 }
974 }
975 }
976
977 fn hash_query(&self, query: &Array1<f32>, k: usize) -> u64 {
981 use std::collections::hash_map::DefaultHasher;
982 use std::hash::{Hash, Hasher};
983
984 let mut hasher = DefaultHasher::new();
985
986 for i in (0..query.len()).step_by(16) { ((query[i] * 1000.0) as i32).hash(&mut hasher);
989 }
990 k.hash(&mut hasher);
991 self.positions.len().hash(&mut hasher); hasher.finish()
994 }
995
996 fn get_cached_result(&mut self, query_hash: u64) -> Option<Vec<(Array1<f32>, f32, f32)>> {
998 let now = Instant::now();
999
1000 if let Some(cached_entry) = self.result_cache.get(&query_hash) {
1001 if now.duration_since(cached_entry.timestamp) < self.cache_ttl {
1002 self.cache_hits += 1;
1003 return Some(cached_entry.results.clone());
1004 } else {
1005 self.result_cache.remove(&query_hash);
1007 }
1008 }
1009
1010 self.cache_misses += 1;
1011 None
1012 }
1013
1014 fn cache_search_result(&mut self, query_hash: u64, results: Vec<(Array1<f32>, f32, f32)>) {
1016 let now = Instant::now();
1017
1018 self.result_cache.insert(query_hash, SearchResultCache {
1019 results,
1020 timestamp: now,
1021 });
1022
1023 if self.result_cache.len() > self.max_cache_size / 10 {
1025 self.evict_oldest_result_cache_entries();
1026 }
1027 }
1028
1029 fn evict_expired_cache_entries(&mut self) {
1031 let now = Instant::now();
1032
1033 self.similarity_cache.retain(|_, (_, cached_time)| {
1035 now.duration_since(*cached_time) < self.cache_ttl
1036 });
1037
1038 self.result_cache.retain(|_, cached_result| {
1040 now.duration_since(cached_result.timestamp) < self.cache_ttl
1041 });
1042 }
1043
1044 fn evict_oldest_cache_entries(&mut self) {
1046 let entries_to_remove = self.similarity_cache.len() / 4;
1048 if entries_to_remove > 0 {
1049 let mut entries: Vec<_> = self.similarity_cache.iter().map(|(k, v)| (*k, *v)).collect();
1050 entries.sort_by_key(|(_, (_, time))| *time);
1051
1052 for i in 0..entries_to_remove {
1053 if let Some((key, _)) = entries.get(i) {
1054 self.similarity_cache.remove(key);
1055 }
1056 }
1057 }
1058 }
1059
1060 fn evict_oldest_result_cache_entries(&mut self) {
1062 let entries_to_remove = self.result_cache.len() / 4;
1064 if entries_to_remove > 0 {
1065 let mut entries: Vec<_> = self.result_cache.iter().map(|(k, v)| (*k, v.timestamp)).collect();
1066 entries.sort_by_key(|(_, timestamp)| *timestamp);
1067
1068 for i in 0..entries_to_remove {
1069 if let Some((key, _)) = entries.get(i) {
1070 self.result_cache.remove(key);
1071 }
1072 }
1073 }
1074 }
1075
1076 pub fn get_cache_stats(&self) -> SimilarityCacheStats {
1078 let hit_ratio = if self.cache_hits + self.cache_misses > 0 {
1079 self.cache_hits as f32 / (self.cache_hits + self.cache_misses) as f32
1080 } else {
1081 0.0
1082 };
1083
1084 SimilarityCacheStats {
1085 result_cache_size: self.result_cache.len(),
1086 similarity_cache_size: self.similarity_cache.len(),
1087 max_cache_size: self.max_cache_size,
1088 cache_ttl_secs: self.cache_ttl.as_secs(),
1089 cache_hits: self.cache_hits,
1090 cache_misses: self.cache_misses,
1091 hit_ratio,
1092 }
1093 }
1094
1095 pub fn clear_caches(&mut self) {
1097 self.similarity_cache.clear();
1098 self.result_cache.clear();
1099 self.cache_hits = 0;
1100 self.cache_misses = 0;
1101 }
1102
/// Resets the hit and miss counters to zero without touching cached entries.
pub fn reset_cache_stats(&mut self) {
    self.cache_hits = 0;
    self.cache_misses = 0;
}
1108}
1109
/// Summary of the evaluations stored in a `SimilaritySearch`.
///
/// Derives `PartialEq` and `Default` so summaries can be compared directly.
/// The default value is the all-zero summary reported for an empty index.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SimilaritySearchStats {
    /// Number of stored positions.
    pub count: usize,
    /// Mean of the stored evaluations.
    pub avg_evaluation: f32,
    /// Smallest stored evaluation.
    pub min_evaluation: f32,
    /// Largest stored evaluation.
    pub max_evaluation: f32,
}
1118
/// Summary of the shape of a cluster tree.
///
/// Derives `PartialEq` and `Default` so summaries can be compared directly
/// and an all-zero accumulator can be created without listing every field.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ClusterTreeStats {
    /// Total number of nodes in the tree.
    pub total_nodes: usize,
    /// Number of nodes without children.
    pub leaf_nodes: usize,
    /// Depth of the deepest node, with the root at depth 0.
    pub max_depth: usize,
    /// Average number of positions per leaf.
    pub avg_cluster_size: f32,
    /// Largest `size` recorded on any node.
    pub max_cluster_size: usize,
}
1128
1129impl SimilaritySearch {
1130 pub fn search_optimized(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
1132 assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
1133
1134 if self.positions.is_empty() {
1135 return Vec::new();
1136 }
1137
1138 let query_norm_squared = SimdVectorOps::squared_norm(query);
1140 let query_norm = query_norm_squared.sqrt();
1141
1142 if k <= 10 && self.positions.len() > k * 10 {
1144 return self.search_with_bounded_heap(query, query_norm_squared, k);
1145 }
1146
1147 self.search_parallel_optimized(query, query_norm, k)
1149 }
1150
1151 fn search_with_bounded_heap(
1153 &self,
1154 query: &Array1<f32>,
1155 query_norm_squared: f32,
1156 k: usize,
1157 ) -> Vec<(Array1<f32>, f32, f32)> {
1158 let mut heap = BinaryHeap::with_capacity(k + 1);
1159 let mut min_similarity = f32::NEG_INFINITY;
1160
1161 for entry in &self.positions {
1162 if heap.len() == k && self.can_skip_entry(query, entry, min_similarity) {
1164 continue;
1165 }
1166
1167 let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
1168
1169 let result = SearchResult {
1170 similarity,
1171 evaluation: entry.evaluation,
1172 vector: entry.vector.clone(),
1173 };
1174
1175 if heap.len() < k {
1176 if heap.is_empty() || similarity < min_similarity {
1177 min_similarity = similarity;
1178 }
1179 heap.push(result);
1180 } else if similarity > min_similarity {
1181 heap.pop(); heap.push(result);
1183 min_similarity = heap.peek().map(|r| r.similarity).unwrap_or(f32::NEG_INFINITY);
1185 }
1186 }
1187
1188 self.heap_to_sorted_results(heap)
1190 }
1191
1192 fn search_parallel_optimized(
1194 &self,
1195 query: &Array1<f32>,
1196 query_norm: f32,
1197 k: usize,
1198 ) -> Vec<(Array1<f32>, f32, f32)> {
1199 let chunk_size = (self.positions.len() / rayon::current_num_threads()).max(1000);
1201
1202 let mut results: Vec<_> = self
1203 .positions
1204 .par_chunks(chunk_size)
1205 .flat_map(|chunk| {
1206 chunk.par_iter().map(|entry| {
1207 let entry_norm = entry.norm_squared.sqrt();
1208 let similarity = self.cosine_similarity_ultra_fast(query, query_norm, entry, entry_norm);
1209 (entry.vector.clone(), entry.evaluation, similarity)
1210 })
1211 })
1212 .collect();
1213
1214 if k * 10 < results.len() {
1216 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1218 results.truncate(k);
1219 } else {
1220 results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1222 results.truncate(k);
1223 }
1224
1225 results
1226 }
1227
1228 fn can_skip_entry(&self, _query: &Array1<f32>, _entry: &PositionEntry, min_similarity: f32) -> bool {
1230 min_similarity > 0.95 }
1237
1238 fn heap_to_sorted_results(&self, mut heap: BinaryHeap<SearchResult>) -> Vec<(Array1<f32>, f32, f32)> {
1240 let mut results = Vec::with_capacity(heap.len());
1241 while let Some(result) = heap.pop() {
1242 results.push((result.vector, result.evaluation, result.similarity));
1243 }
1244 results.reverse(); results
1246 }
1247
1248 pub fn batch_search_optimized(
1250 &self,
1251 queries: &[Array1<f32>],
1252 k: usize,
1253 ) -> Vec<Vec<(Array1<f32>, f32, f32)>> {
1254 if queries.is_empty() || self.positions.is_empty() {
1255 return vec![Vec::new(); queries.len()];
1256 }
1257
1258 let query_norms: Vec<f32> = queries
1260 .par_iter()
1261 .map(|q| SimdVectorOps::squared_norm(q).sqrt())
1262 .collect();
1263
1264 queries
1266 .par_iter()
1267 .zip(query_norms.par_iter())
1268 .map(|(query, &query_norm)| {
1269 self.search_parallel_optimized(query, query_norm, k)
1270 })
1271 .collect()
1272 }
1273}
1274
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// A freshly created index holds nothing.
    #[test]
    fn test_similarity_search_creation() {
        let index = SimilaritySearch::new(100);
        assert!(index.is_empty());
        assert_eq!(index.size(), 0);
    }

    /// Searching with a stored vector returns that vector as the best hit.
    #[test]
    fn test_add_and_search() {
        let mut index = SimilaritySearch::new(3);

        let x_axis = Array1::from(vec![1.0, 0.0, 0.0]);
        index.add_position(x_axis.clone(), 1.0);
        index.add_position(Array1::from(vec![0.0, 1.0, 0.0]), 0.5);
        index.add_position(Array1::from(vec![0.0, 0.0, 1.0]), 0.0);
        assert_eq!(index.size(), 3);

        let hits = index.search(&x_axis, 2);
        assert_eq!(hits.len(), 2);

        let (_, evaluation, similarity) = &hits[0];
        assert!(approx(*similarity, 1.0));
        assert!(approx(*evaluation, 1.0));
    }

    /// Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let index = SimilaritySearch::new(2);

        let base = Array1::from(vec![1.0, 0.0]);
        let same = Array1::from(vec![1.0, 0.0]);
        let orthogonal = Array1::from(vec![0.0, 1.0]);

        assert!(approx(index.cosine_similarity(&base, &same), 1.0));
        assert!(approx(index.cosine_similarity(&base, &orthogonal), 0.0));
    }

    /// Count, mean, minimum and maximum reflect the stored evaluations.
    #[test]
    fn test_statistics() {
        let mut index = SimilaritySearch::new(2);

        let vector = Array1::from(vec![1.0, 0.0]);
        for evaluation in [1.0, 2.0, 3.0] {
            index.add_position(vector.clone(), evaluation);
        }

        let summary = index.statistics();
        assert_eq!(summary.count, 3);
        assert!(approx(summary.avg_evaluation, 2.0));
        assert!(approx(summary.min_evaluation, 1.0));
        assert!(approx(summary.max_evaluation, 3.0));
    }
}