chess_vector_engine/
similarity_search.rs

1#![allow(clippy::type_complexity)]
2use crate::gpu_acceleration::GPUAccelerator;
3use crate::utils::simd::SimdVectorOps;
4use ndarray::{Array1, Array2};
5use rayon::prelude::*;
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap};
8use std::time::{Duration, Instant};
9// Removed unused import
10
11/// Entry in the similarity search index
12#[derive(Debug, Clone)]
13pub struct PositionEntry {
14    pub vector: Array1<f32>,
15    pub evaluation: f32,
16    pub norm_squared: f32,
17}
18
19/// Result from similarity search (reference-based)
20#[derive(Debug)]
21pub struct SearchResultRef<'a> {
22    pub similarity: f32,
23    pub evaluation: f32,
24    pub vector: &'a Array1<f32>,
25}
26
27/// Result from similarity search (owned)
28#[derive(Debug, Clone)]
29pub struct SearchResult {
30    pub similarity: f32,
31    pub evaluation: f32,
32    pub vector: Array1<f32>,
33}
34
35impl PartialEq for SearchResult {
36    fn eq(&self, other: &Self) -> bool {
37        self.similarity == other.similarity
38    }
39}
40
41impl Eq for SearchResult {}
42
43impl PartialOrd for SearchResult {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
49impl Ord for SearchResult {
50    fn cmp(&self, other: &Self) -> Ordering {
51        // Reverse ordering for max-heap behavior in BinaryHeap
52        other
53            .similarity
54            .partial_cmp(&self.similarity)
55            .unwrap_or(Ordering::Equal)
56    }
57}
58
59/// Hierarchical clustering node for improved search performance
60#[derive(Debug, Clone)]
61pub struct ClusterNode {
62    /// Centroid of the cluster
63    pub centroid: Array1<f32>,
64    /// Indices of positions in this cluster
65    pub position_indices: Vec<usize>,
66    /// Child clusters (for hierarchical clustering)
67    pub children: Vec<ClusterNode>,
68    /// Cluster radius (maximum distance from centroid)
69    pub radius: f32,
70    /// Number of positions in this cluster (including children)
71    pub size: usize,
72}
73
74/// Cache entry for similarity search results with TTL
75#[derive(Debug, Clone)]
76pub struct SearchResultCache {
77    pub results: Vec<(Array1<f32>, f32, f32)>,
78    pub timestamp: Instant,
79}
80
/// Cache statistics for monitoring performance.
///
/// A plain-data snapshot: every field is a primitive, so the type is `Copy`,
/// comparable with `==`, and has an all-zero `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct SimilarityCacheStats {
    /// Number of entries currently held in the result cache.
    pub result_cache_size: usize,
    /// Number of entries currently held in the pairwise similarity cache.
    pub similarity_cache_size: usize,
    /// Configured upper bound on cache entries.
    pub max_cache_size: usize,
    /// Configured time-to-live of cache entries, in seconds.
    pub cache_ttl_secs: u64,
    /// Number of cache hits recorded so far.
    pub cache_hits: u64,
    /// Number of cache misses recorded so far.
    pub cache_misses: u64,
    /// Ratio of hits to lookups.
    pub hit_ratio: f32,
}
92
93/// Similarity search engine for chess positions with production-optimized caching
94#[derive(Clone)]
95pub struct SimilaritySearch {
96    /// All stored positions
97    positions: Vec<PositionEntry>,
98    /// Dimension of vectors
99    vector_size: usize,
100    /// Hierarchical clustering tree for fast search
101    cluster_tree: Option<ClusterNode>,
102    /// Cache for frequently accessed similarity results (pairwise similarities)
103    similarity_cache: HashMap<(usize, usize), (f32, Instant)>,
104    /// Cache for complete search results (query_hash -> results)
105    result_cache: HashMap<u64, SearchResultCache>,
106    /// Maximum cache size to prevent memory bloat
107    max_cache_size: usize,
108    /// TTL for cached results
109    cache_ttl: Duration,
110    /// Cache performance metrics
111    cache_hits: u64,
112    cache_misses: u64,
113}
114
115impl SimilaritySearch {
116    /// Create a new similarity search engine with default caching settings
117    pub fn new(vector_size: usize) -> Self {
118        Self {
119            positions: Vec::new(),
120            vector_size,
121            cluster_tree: None,
122            similarity_cache: HashMap::with_capacity(10000),
123            result_cache: HashMap::with_capacity(1000),
124            max_cache_size: 10000,
125            cache_ttl: Duration::from_secs(300), // 5 minutes
126            cache_hits: 0,
127            cache_misses: 0,
128        }
129    }
130
131    /// Create a new similarity search engine with custom cache configuration
132    pub fn with_cache_config(vector_size: usize, max_cache_size: usize, cache_ttl_secs: u64) -> Self {
133        Self {
134            positions: Vec::new(),
135            vector_size,
136            cluster_tree: None,
137            similarity_cache: HashMap::with_capacity(max_cache_size),
138            result_cache: HashMap::with_capacity(max_cache_size / 10),
139            max_cache_size,
140            cache_ttl: Duration::from_secs(cache_ttl_secs),
141            cache_hits: 0,
142            cache_misses: 0,
143        }
144    }
145
146    /// Add a position to the search index
147    pub fn add_position(&mut self, vector: Array1<f32>, evaluation: f32) {
148        assert_eq!(vector.len(), self.vector_size, "Vector size mismatch");
149
150        let norm_squared = SimdVectorOps::squared_norm(&vector);
151
152        self.positions.push(PositionEntry {
153            vector,
154            evaluation,
155            norm_squared,
156        });
157
158        // Invalidate cluster tree when adding new positions
159        self.cluster_tree = None;
160
161        // Evict expired cache entries and manage cache size
162        self.evict_expired_cache_entries();
163        if self.similarity_cache.len() > self.max_cache_size {
164            self.evict_oldest_cache_entries();
165        }
166    }
167
168    /// Search for k most similar positions with references (memory efficient)
169    pub fn search_ref(&self, query: &Array1<f32>, k: usize) -> Vec<(&Array1<f32>, f32, f32)> {
170        // Note: GPU search not supported for reference version due to lifetime constraints
171        // Fall back to CPU-based search methods
172
173        // Use hierarchical clustering for large datasets
174        if self.positions.len() > 1000 {
175            self.hierarchical_search_ref(query, k)
176        } else if self.positions.len() > 100 {
177            self.parallel_search_ref(query, k)
178        } else {
179            self.sequential_search_ref(query, k)
180        }
181    }
182
183    /// Search for k most similar positions with comprehensive caching (automatically chooses best method)
184    pub fn search(&mut self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
185        // Generate cache key from query vector and k
186        let query_hash = self.hash_query(query, k);
187        
188        // Check result cache first
189        if let Some(cached_result) = self.get_cached_result(query_hash) {
190            return cached_result;
191        }
192        
193        // Cache miss - perform actual search
194        let results = self.search_uncached(query, k);
195        
196        // Cache the results for future use
197        self.cache_search_result(query_hash, results.clone());
198        
199        results
200    }
201    
202    /// Internal search method without caching (for cache misses)
203    fn search_uncached(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
204        // Use optimized search as primary method for better performance
205        if self.positions.len() > 50 {
206            return self.search_optimized(query, k);
207        }
208
209        let gpu_accelerator = GPUAccelerator::global();
210
211        // Use GPU acceleration for large datasets if available
212        if gpu_accelerator.is_gpu_enabled() && self.positions.len() > 500 {
213            match self.gpu_accelerated_search(query, k) {
214                Ok(results) => return results,
215                Err(e) => {
216                    println!("GPU search failed ({e}), falling back to CPU");
217                }
218            }
219        }
220
221        // Use hierarchical clustering for large datasets
222        if self.positions.len() > 1000 {
223            self.hierarchical_search(query, k)
224        } else if self.positions.len() > 100 {
225            self.parallel_search(query, k)
226        } else {
227            self.sequential_search(query, k)
228        }
229    }
230
231    /// GPU-accelerated similarity search for large datasets
232    pub fn gpu_accelerated_search(
233        &self,
234        query: &Array1<f32>,
235        k: usize,
236    ) -> Result<Vec<(Array1<f32>, f32, f32)>, Box<dyn std::error::Error>> {
237        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
238
239        if self.positions.is_empty() {
240            return Ok(Vec::new());
241        }
242
243        let gpu_accelerator = GPUAccelerator::global();
244
245        // Prepare vectors matrix for GPU computation
246        let mut vectors_data = Vec::with_capacity(self.positions.len() * self.vector_size);
247        for entry in &self.positions {
248            vectors_data.extend_from_slice(entry.vector.as_slice().unwrap());
249        }
250
251        let vectors_matrix =
252            Array2::from_shape_vec((self.positions.len(), self.vector_size), vectors_data)?;
253
254        // Compute similarities on GPU
255        let similarities = gpu_accelerator.cosine_similarity_batch(query, &vectors_matrix)?;
256
257        // Find top-k results
258        let mut indexed_similarities: Vec<(usize, f32)> = similarities
259            .iter()
260            .enumerate()
261            .map(|(i, &sim)| (i, sim))
262            .collect();
263
264        // Sort by similarity (descending)
265        indexed_similarities
266            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
267
268        // Take top-k and prepare results
269        let mut results = Vec::new();
270        for (idx, similarity) in indexed_similarities.into_iter().take(k) {
271            let entry = &self.positions[idx];
272            results.push((entry.vector.clone(), entry.evaluation, similarity));
273        }
274
275        Ok(results)
276    }
277
278    /// Sequential search implementation with references (memory efficient)
279    pub fn sequential_search_ref(
280        &self,
281        query: &Array1<f32>,
282        k: usize,
283    ) -> Vec<(&Array1<f32>, f32, f32)> {
284        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
285
286        if self.positions.is_empty() {
287            return Vec::new();
288        }
289
290        let query_norm_squared = SimdVectorOps::squared_norm(query);
291
292        // Collect all similarities with indices
293        let mut indexed_similarities: Vec<(usize, f32)> = self
294            .positions
295            .iter()
296            .enumerate()
297            .map(|(idx, entry)| {
298                let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
299                (idx, similarity)
300            })
301            .collect();
302
303        // Sort by similarity (descending)
304        indexed_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
305
306        // Take top k and return references
307        indexed_similarities
308            .into_iter()
309            .take(k)
310            .map(|(idx, similarity)| {
311                let entry = &self.positions[idx];
312                (&entry.vector, entry.evaluation, similarity)
313            })
314            .collect()
315    }
316
317    /// Sequential search implementation (for small datasets)
318    pub fn sequential_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
319        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
320
321        if self.positions.is_empty() {
322            return Vec::new();
323        }
324
325        let query_norm_squared = SimdVectorOps::squared_norm(query);
326
327        // Use a min-heap to keep track of top-k results
328        let mut heap = BinaryHeap::new();
329
330        for entry in &self.positions {
331            let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
332
333            let result = SearchResult {
334                similarity,
335                evaluation: entry.evaluation,
336                vector: entry.vector.clone(),
337            };
338
339            if heap.len() < k {
340                heap.push(result);
341            } else if similarity > heap.peek().unwrap().similarity {
342                heap.pop();
343                heap.push(result);
344            }
345        }
346
347        // Convert heap to sorted vector (highest similarity first)
348        let mut results = Vec::new();
349        while let Some(result) = heap.pop() {
350            results.push((result.vector, result.evaluation, result.similarity));
351        }
352
353        // Reverse to get highest similarity first
354        results.reverse();
355        results
356    }
357
358    /// Parallel search implementation with references (memory efficient)
359    pub fn parallel_search_ref(
360        &self,
361        query: &Array1<f32>,
362        k: usize,
363    ) -> Vec<(&Array1<f32>, f32, f32)> {
364        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
365
366        if self.positions.is_empty() {
367            return Vec::new();
368        }
369
370        let query_norm_squared = SimdVectorOps::squared_norm(query);
371
372        // Calculate similarities in parallel with indices
373        let mut indexed_similarities: Vec<(usize, f32)> = self
374            .positions
375            .par_iter()
376            .enumerate()
377            .map(|(idx, entry)| {
378                let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
379                (idx, similarity)
380            })
381            .collect();
382
383        // Sort by similarity (descending) and take top k
384        indexed_similarities
385            .par_sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
386        indexed_similarities.truncate(k);
387
388        // Return references instead of clones
389        indexed_similarities
390            .into_iter()
391            .map(|(idx, similarity)| {
392                let entry = &self.positions[idx];
393                (&entry.vector, entry.evaluation, similarity)
394            })
395            .collect()
396    }
397
398    /// Parallel search implementation (for larger datasets)
399    pub fn parallel_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
400        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
401
402        if self.positions.is_empty() {
403            return Vec::new();
404        }
405
406        let query_norm_squared = SimdVectorOps::squared_norm(query);
407
408        // Calculate similarities in parallel
409        let mut results: Vec<_> = self
410            .positions
411            .par_iter()
412            .map(|entry| {
413                let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
414                (entry.vector.clone(), entry.evaluation, similarity)
415            })
416            .collect();
417
418        // Sort by similarity (descending) and take top k
419        results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
420        results.truncate(k);
421
422        results
423    }
424
425    /// Brute force search (for small datasets or comparison)
426    pub fn brute_force_search(
427        &self,
428        query: &Array1<f32>,
429        k: usize,
430    ) -> Vec<(Array1<f32>, f32, f32)> {
431        let mut results: Vec<_> = if self.positions.len() > 100 {
432            // Use parallel processing for larger datasets
433            self.positions
434                .par_iter()
435                .map(|entry| {
436                    let similarity = self.cosine_similarity(query, &entry.vector);
437                    (entry.vector.clone(), entry.evaluation, similarity)
438                })
439                .collect()
440        } else {
441            // Use sequential processing for smaller datasets
442            self.positions
443                .iter()
444                .map(|entry| {
445                    let similarity = self.cosine_similarity(query, &entry.vector);
446                    (entry.vector.clone(), entry.evaluation, similarity)
447                })
448                .collect()
449        };
450
451        // Sort by similarity (descending)
452        if results.len() > 1000 {
453            results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
454        } else {
455            results.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
456        }
457
458        // Take top k
459        results.truncate(k);
460        results
461    }
462
463    /// Calculate cosine similarity between query vector and a position entry with caching (SIMD optimized)
464    fn cosine_similarity_fast(
465        &mut self,
466        query: &Array1<f32>,
467        query_norm_squared: f32,
468        entry_index: usize,
469    ) -> f32 {
470        // Check cache for pairwise similarity
471        let now = Instant::now();
472        let cache_key = (0, entry_index); // Using 0 as query index placeholder
473        
474        if let Some((cached_similarity, cached_time)) = self.similarity_cache.get(&cache_key) {
475            if now.duration_since(*cached_time) < self.cache_ttl {
476                self.cache_hits += 1;
477                return *cached_similarity;
478            }
479        }
480        
481        // Cache miss - compute similarity
482        self.cache_misses += 1;
483        let entry = &self.positions[entry_index];
484        
485        // Early termination for zero vectors
486        if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
487            return 0.0;
488        }
489
490        let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
491        
492        // Pre-computed inverse square roots for better performance
493        let query_norm_inv = 1.0 / query_norm_squared.sqrt();
494        let entry_norm_inv = 1.0 / entry.norm_squared.sqrt();
495        
496        let similarity = dot_product * query_norm_inv * entry_norm_inv;
497        
498        // Cache the result
499        self.similarity_cache.insert(cache_key, (similarity, now));
500        
501        similarity
502    }
503    
504    /// Calculate cosine similarity between query vector and a position entry (uncached version)
505    fn cosine_similarity_fast_uncached(
506        &self,
507        query: &Array1<f32>,
508        query_norm_squared: f32,
509        entry: &PositionEntry,
510    ) -> f32 {
511        // Early termination for zero vectors
512        if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
513            return 0.0;
514        }
515
516        let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
517        
518        // Pre-computed inverse square roots for better performance
519        let query_norm_inv = 1.0 / query_norm_squared.sqrt();
520        let entry_norm_inv = 1.0 / entry.norm_squared.sqrt();
521        
522        dot_product * query_norm_inv * entry_norm_inv
523    }
524
525    /// Ultra-fast similarity calculation with pre-computed norms (avoids sqrt when possible)
526    fn cosine_similarity_ultra_fast(
527        &self,
528        query: &Array1<f32>,
529        query_norm: f32,
530        entry: &PositionEntry,
531        entry_norm: f32,
532    ) -> f32 {
533        if query_norm == 0.0 || entry_norm == 0.0 {
534            return 0.0;
535        }
536
537        let dot_product = SimdVectorOps::dot_product(query, &entry.vector);
538        dot_product / (query_norm * entry_norm)
539    }
540
541    /// Calculate cosine similarity between two vectors (SIMD optimized)
542    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
543        SimdVectorOps::cosine_similarity(a, b)
544    }
545
546    /// Calculate Euclidean distance between two vectors
547    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
548        (a - b).mapv(|x| x * x).sum().sqrt()
549    }
550
551    /// Search using Euclidean distance (alternative to cosine similarity)
552    pub fn search_by_distance(
553        &self,
554        query: &Array1<f32>,
555        k: usize,
556    ) -> Vec<(Array1<f32>, f32, f32)> {
557        let mut results: Vec<_> = self
558            .positions
559            .iter()
560            .map(|entry| {
561                let distance = self.euclidean_distance(query, &entry.vector);
562                (entry.vector.clone(), entry.evaluation, distance)
563            })
564            .collect();
565
566        // Sort by distance (ascending - smaller distance = more similar)
567        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
568
569        // Take top k
570        results.truncate(k);
571        results
572    }
573
574    /// Get number of positions in the index
575    pub fn size(&self) -> usize {
576        self.positions.len()
577    }
578
579    /// Check if the index is empty
580    pub fn is_empty(&self) -> bool {
581        self.positions.is_empty()
582    }
583
584    /// Clear all positions
585    pub fn clear(&mut self) {
586        self.positions.clear();
587    }
588
589    /// Get statistics about the stored vectors
590    pub fn statistics(&self) -> SimilaritySearchStats {
591        if self.positions.is_empty() {
592            return SimilaritySearchStats {
593                count: 0,
594                avg_evaluation: 0.0,
595                min_evaluation: 0.0,
596                max_evaluation: 0.0,
597            };
598        }
599
600        let evaluations: Vec<f32> = self.positions.iter().map(|p| p.evaluation).collect();
601        let sum: f32 = evaluations.iter().sum();
602        let avg = sum / evaluations.len() as f32;
603        let min = evaluations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
604        let max = evaluations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
605
606        SimilaritySearchStats {
607            count: self.positions.len(),
608            avg_evaluation: avg,
609            min_evaluation: min,
610            max_evaluation: max,
611        }
612    }
613
614    /// Get all stored positions (for LSH indexing)
615    pub fn get_all_positions(&self) -> Vec<(Array1<f32>, f32)> {
616        self.positions
617            .iter()
618            .map(|entry| (entry.vector.clone(), entry.evaluation))
619            .collect()
620    }
621
622    /// Get position vector by reference to avoid cloning
623    pub fn get_position_ref(&self, index: usize) -> Option<(&Array1<f32>, f32)> {
624        self.positions
625            .get(index)
626            .map(|entry| (&entry.vector, entry.evaluation))
627    }
628
629    /// Get all positions as references (memory efficient iterator)
630    pub fn iter_positions(&self) -> impl Iterator<Item = (&Array1<f32>, f32)> {
631        self.positions
632            .iter()
633            .map(|entry| (&entry.vector, entry.evaluation))
634    }
635
636    /// Build hierarchical clustering tree for improved search performance
637    pub fn build_cluster_tree(&mut self) {
638        if self.positions.is_empty() {
639            self.cluster_tree = None;
640            return;
641        }
642
643        let indices: Vec<usize> = (0..self.positions.len()).collect();
644        self.cluster_tree = Some(self.build_cluster_recursive(indices, 0));
645    }
646
647    /// Recursively build clustering tree using k-means-like approach
648    fn build_cluster_recursive(&self, indices: Vec<usize>, depth: usize) -> ClusterNode {
649        let max_depth = 10;
650        let min_cluster_size = 32;
651
652        if indices.len() <= min_cluster_size || depth >= max_depth {
653            // Leaf node - compute centroid and radius
654            let centroid = self.compute_centroid(&indices);
655            let radius = self.compute_cluster_radius(&centroid, &indices);
656
657            return ClusterNode {
658                centroid,
659                position_indices: indices.clone(),
660                children: Vec::new(),
661                radius,
662                size: indices.len(),
663            };
664        }
665
666        // Use k-means clustering to split into 2 or 4 clusters
667        let k = if indices.len() > 200 { 4 } else { 2 };
668        let clusters = self.k_means_clustering(&indices, k);
669
670        let mut children = Vec::new();
671        let mut all_indices = Vec::new();
672
673        for cluster_indices in clusters {
674            if !cluster_indices.is_empty() {
675                let child = self.build_cluster_recursive(cluster_indices.clone(), depth + 1);
676                all_indices.extend(cluster_indices);
677                children.push(child);
678            }
679        }
680
681        let centroid = self.compute_centroid(&all_indices);
682        let radius = self.compute_cluster_radius(&centroid, &all_indices);
683
684        ClusterNode {
685            centroid,
686            position_indices: all_indices,
687            children,
688            radius,
689            size: indices.len(),
690        }
691    }
692
693    /// Compute centroid of a cluster
694    fn compute_centroid(&self, indices: &[usize]) -> Array1<f32> {
695        if indices.is_empty() {
696            return Array1::zeros(self.vector_size);
697        }
698
699        let mut centroid = Array1::zeros(self.vector_size);
700        for &idx in indices {
701            centroid = SimdVectorOps::add_vectors(&centroid, &self.positions[idx].vector);
702        }
703
704        SimdVectorOps::scale_vector(&centroid, 1.0 / indices.len() as f32)
705    }
706
707    /// Compute radius of a cluster (maximum distance from centroid)
708    fn compute_cluster_radius(&self, centroid: &Array1<f32>, indices: &[usize]) -> f32 {
709        indices
710            .iter()
711            .map(|&idx| 1.0 - self.cosine_similarity_cached(centroid, &self.positions[idx].vector))
712            .fold(0.0, f32::max)
713    }
714
715    /// K-means clustering implementation
716    fn k_means_clustering(&self, indices: &[usize], k: usize) -> Vec<Vec<usize>> {
717        if indices.len() <= k {
718            return indices.iter().map(|&i| vec![i]).collect();
719        }
720
721        // Initialize centroids randomly
722        let mut centroids = Vec::new();
723        let step = indices.len() / k;
724        for i in 0..k {
725            let idx = indices[i * step];
726            centroids.push(self.positions[idx].vector.clone());
727        }
728
729        const MAX_ITERATIONS: usize = 10;
730
731        for _ in 0..MAX_ITERATIONS {
732            // Assign points to clusters
733            let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
734
735            for &idx in indices {
736                let mut best_cluster = 0;
737                let mut best_similarity = -1.0;
738
739                for (cluster_idx, centroid) in centroids.iter().enumerate() {
740                    let similarity =
741                        self.cosine_similarity_cached(centroid, &self.positions[idx].vector);
742                    if similarity > best_similarity {
743                        best_similarity = similarity;
744                        best_cluster = cluster_idx;
745                    }
746                }
747
748                clusters[best_cluster].push(idx);
749            }
750
751            // Update centroids
752            let mut converged = true;
753            for (cluster_idx, cluster) in clusters.iter().enumerate() {
754                if !cluster.is_empty() {
755                    let new_centroid = self.compute_centroid(cluster);
756                    let similarity =
757                        self.cosine_similarity_cached(&centroids[cluster_idx], &new_centroid);
758
759                    if similarity < 0.99 {
760                        converged = false;
761                    }
762
763                    centroids[cluster_idx] = new_centroid;
764                }
765            }
766
767            if converged {
768                break;
769            }
770        }
771
772        // Final assignment
773        let mut final_clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
774        for &idx in indices {
775            let mut best_cluster = 0;
776            let mut best_similarity = -1.0;
777
778            for (cluster_idx, centroid) in centroids.iter().enumerate() {
779                let similarity =
780                    self.cosine_similarity_cached(centroid, &self.positions[idx].vector);
781                if similarity > best_similarity {
782                    best_similarity = similarity;
783                    best_cluster = cluster_idx;
784                }
785            }
786
787            final_clusters[best_cluster].push(idx);
788        }
789
790        final_clusters
791            .into_iter()
792            .filter(|cluster| !cluster.is_empty())
793            .collect()
794    }
795
796    /// Hierarchical search using cluster tree
797    fn hierarchical_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
798        // Build cluster tree if not already built
799        if self.cluster_tree.is_none() {
800            // Can't modify self in this context, fall back to parallel search
801            return self.parallel_search(query, k);
802        }
803
804        let cluster_tree = self.cluster_tree.as_ref().unwrap();
805        let mut candidates = Vec::new();
806
807        // Traverse cluster tree to find promising candidates
808        self.traverse_cluster_tree(query, cluster_tree, &mut candidates, k * 5);
809
810        // Calculate similarities for candidates
811        let mut results: Vec<_> = candidates
812            .into_iter()
813            .map(|idx| {
814                let entry = &self.positions[idx];
815                let similarity = self.cosine_similarity_cached(query, &entry.vector);
816                (entry.vector.clone(), entry.evaluation, similarity)
817            })
818            .collect();
819
820        // Sort and return top k
821        results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
822        results.truncate(k);
823
824        results
825    }
826
827    /// Hierarchical search with references
828    fn hierarchical_search_ref(
829        &self,
830        query: &Array1<f32>,
831        k: usize,
832    ) -> Vec<(&Array1<f32>, f32, f32)> {
833        // Build cluster tree if not already built
834        if self.cluster_tree.is_none() {
835            // Can't modify self in this context, fall back to parallel search
836            return self.parallel_search_ref(query, k);
837        }
838
839        let cluster_tree = self.cluster_tree.as_ref().unwrap();
840        let mut candidates = Vec::new();
841
842        // Traverse cluster tree to find promising candidates
843        self.traverse_cluster_tree(query, cluster_tree, &mut candidates, k * 5);
844
845        // Calculate similarities for candidates
846        let mut results: Vec<_> = candidates
847            .into_iter()
848            .map(|idx| {
849                let entry = &self.positions[idx];
850                let similarity = self.cosine_similarity_cached(query, &entry.vector);
851                (&entry.vector, entry.evaluation, similarity)
852            })
853            .collect();
854
855        // Sort and return top k
856        results.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
857        results.truncate(k);
858
859        results
860    }
861
862    /// Traverse cluster tree to find candidate positions
863    fn traverse_cluster_tree(
864        &self,
865        query: &Array1<f32>,
866        node: &ClusterNode,
867        candidates: &mut Vec<usize>,
868        max_candidates: usize,
869    ) {
870        if candidates.len() >= max_candidates {
871            return;
872        }
873
874        // Calculate similarity to cluster centroid
875        let centroid_similarity = self.cosine_similarity_cached(query, &node.centroid);
876
877        // If query is far from this cluster, skip it
878        let distance_threshold = 0.1; // Adjust based on dataset characteristics
879        if centroid_similarity < distance_threshold {
880            return;
881        }
882
883        if node.children.is_empty() {
884            // Leaf node - add all positions
885            for &idx in &node.position_indices {
886                if candidates.len() < max_candidates {
887                    candidates.push(idx);
888                }
889            }
890        } else {
891            // Internal node - sort children by similarity and traverse best ones first
892            let mut child_similarities: Vec<_> = node
893                .children
894                .iter()
895                .enumerate()
896                .map(|(i, child)| {
897                    let similarity = self.cosine_similarity_cached(query, &child.centroid);
898                    (i, similarity)
899                })
900                .collect();
901
902            child_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
903
904            // Traverse children in order of similarity
905            for (child_idx, _) in child_similarities {
906                self.traverse_cluster_tree(
907                    query,
908                    &node.children[child_idx],
909                    candidates,
910                    max_candidates,
911                );
912            }
913        }
914    }
915
916    /// Cached cosine similarity calculation
917    fn cosine_similarity_cached(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
918        // Use optimized SIMD version with early termination checks
919        let a_norm_sq = SimdVectorOps::squared_norm(a);
920        let b_norm_sq = SimdVectorOps::squared_norm(b);
921        
922        if a_norm_sq == 0.0 || b_norm_sq == 0.0 {
923            return 0.0;
924        }
925        
926        let dot_product = SimdVectorOps::dot_product(a, b);
927        dot_product / (a_norm_sq.sqrt() * b_norm_sq.sqrt())
928    }
929
930    /// Force rebuild of cluster tree (useful after adding many positions)
931    pub fn rebuild_cluster_tree(&mut self) {
932        self.cluster_tree = None;
933        self.build_cluster_tree();
934    }
935
936    /// Get cluster tree statistics
937    pub fn cluster_tree_stats(&self) -> Option<ClusterTreeStats> {
938        self.cluster_tree.as_ref().map(|tree| {
939            let mut stats = ClusterTreeStats {
940                total_nodes: 0,
941                leaf_nodes: 0,
942                max_depth: 0,
943                avg_cluster_size: 0.0,
944                max_cluster_size: 0,
945            };
946
947            self.collect_cluster_stats(tree, 0, &mut stats);
948
949            if stats.leaf_nodes > 0 {
950                stats.avg_cluster_size = self.positions.len() as f32 / stats.leaf_nodes as f32;
951            }
952
953            stats
954        })
955    }
956
957    /// Recursively collect cluster statistics
958    fn collect_cluster_stats(
959        &self,
960        node: &ClusterNode,
961        depth: usize,
962        stats: &mut ClusterTreeStats,
963    ) {
964        stats.total_nodes += 1;
965        stats.max_depth = stats.max_depth.max(depth);
966        stats.max_cluster_size = stats.max_cluster_size.max(node.size);
967
968        if node.children.is_empty() {
969            stats.leaf_nodes += 1;
970        } else {
971            for child in &node.children {
972                self.collect_cluster_stats(child, depth + 1, stats);
973            }
974        }
975    }
976    
977    // ============ CACHE MANAGEMENT METHODS ============
978    
979    /// Generate a hash key for a query vector and k value
980    fn hash_query(&self, query: &Array1<f32>, k: usize) -> u64 {
981        use std::collections::hash_map::DefaultHasher;
982        use std::hash::{Hash, Hasher};
983        
984        let mut hasher = DefaultHasher::new();
985        
986        // Hash the query vector (sample key elements for performance)
987        for i in (0..query.len()).step_by(16) { // Sample every 16th element
988            ((query[i] * 1000.0) as i32).hash(&mut hasher);
989        }
990        k.hash(&mut hasher);
991        self.positions.len().hash(&mut hasher); // Include dataset size in hash
992        
993        hasher.finish()
994    }
995    
996    /// Check if we have a cached result for this query
997    fn get_cached_result(&mut self, query_hash: u64) -> Option<Vec<(Array1<f32>, f32, f32)>> {
998        let now = Instant::now();
999        
1000        if let Some(cached_entry) = self.result_cache.get(&query_hash) {
1001            if now.duration_since(cached_entry.timestamp) < self.cache_ttl {
1002                self.cache_hits += 1;
1003                return Some(cached_entry.results.clone());
1004            } else {
1005                // Remove expired entry
1006                self.result_cache.remove(&query_hash);
1007            }
1008        }
1009        
1010        self.cache_misses += 1;
1011        None
1012    }
1013    
1014    /// Cache search results for future lookups
1015    fn cache_search_result(&mut self, query_hash: u64, results: Vec<(Array1<f32>, f32, f32)>) {
1016        let now = Instant::now();
1017        
1018        self.result_cache.insert(query_hash, SearchResultCache {
1019            results,
1020            timestamp: now,
1021        });
1022        
1023        // Maintain cache size
1024        if self.result_cache.len() > self.max_cache_size / 10 {
1025            self.evict_oldest_result_cache_entries();
1026        }
1027    }
1028    
1029    /// Evict expired cache entries to maintain performance
1030    fn evict_expired_cache_entries(&mut self) {
1031        let now = Instant::now();
1032        
1033        // Evict expired similarity cache entries
1034        self.similarity_cache.retain(|_, (_, cached_time)| {
1035            now.duration_since(*cached_time) < self.cache_ttl
1036        });
1037        
1038        // Evict expired result cache entries
1039        self.result_cache.retain(|_, cached_result| {
1040            now.duration_since(cached_result.timestamp) < self.cache_ttl
1041        });
1042    }
1043    
1044    /// Evict oldest cache entries when cache is full (LRU eviction)
1045    fn evict_oldest_cache_entries(&mut self) {
1046        // Remove oldest 25% of similarity cache entries
1047        let entries_to_remove = self.similarity_cache.len() / 4;
1048        if entries_to_remove > 0 {
1049            let mut entries: Vec<_> = self.similarity_cache.iter().map(|(k, v)| (*k, *v)).collect();
1050            entries.sort_by_key(|(_, (_, time))| *time);
1051            
1052            for i in 0..entries_to_remove {
1053                if let Some((key, _)) = entries.get(i) {
1054                    self.similarity_cache.remove(key);
1055                }
1056            }
1057        }
1058    }
1059    
1060    /// Evict oldest result cache entries when cache is full (LRU eviction)
1061    fn evict_oldest_result_cache_entries(&mut self) {
1062        // Remove oldest 25% of result cache entries
1063        let entries_to_remove = self.result_cache.len() / 4;
1064        if entries_to_remove > 0 {
1065            let mut entries: Vec<_> = self.result_cache.iter().map(|(k, v)| (*k, v.timestamp)).collect();
1066            entries.sort_by_key(|(_, timestamp)| *timestamp);
1067            
1068            for i in 0..entries_to_remove {
1069                if let Some((key, _)) = entries.get(i) {
1070                    self.result_cache.remove(key);
1071                }
1072            }
1073        }
1074    }
1075    
1076    /// Get comprehensive cache statistics for performance monitoring
1077    pub fn get_cache_stats(&self) -> SimilarityCacheStats {
1078        let hit_ratio = if self.cache_hits + self.cache_misses > 0 {
1079            self.cache_hits as f32 / (self.cache_hits + self.cache_misses) as f32
1080        } else {
1081            0.0
1082        };
1083        
1084        SimilarityCacheStats {
1085            result_cache_size: self.result_cache.len(),
1086            similarity_cache_size: self.similarity_cache.len(),
1087            max_cache_size: self.max_cache_size,
1088            cache_ttl_secs: self.cache_ttl.as_secs(),
1089            cache_hits: self.cache_hits,
1090            cache_misses: self.cache_misses,
1091            hit_ratio,
1092        }
1093    }
1094    
1095    /// Clear all caches (useful for benchmarking or memory management)
1096    pub fn clear_caches(&mut self) {
1097        self.similarity_cache.clear();
1098        self.result_cache.clear();
1099        self.cache_hits = 0;
1100        self.cache_misses = 0;
1101    }
1102    
1103    /// Reset cache statistics while preserving cached data
1104    pub fn reset_cache_stats(&mut self) {
1105        self.cache_hits = 0;
1106        self.cache_misses = 0;
1107    }
1108}
1109
/// Statistics about the similarity search index
///
/// As exercised by `test_statistics` in this file, `statistics()` reports the
/// number of stored positions together with the mean, minimum and maximum of
/// their evaluations.
#[derive(Debug, Clone)]
pub struct SimilaritySearchStats {
    /// Number of positions in the index.
    pub count: usize,
    /// Mean evaluation over the counted positions.
    pub avg_evaluation: f32,
    /// Smallest evaluation among the counted positions.
    pub min_evaluation: f32,
    /// Largest evaluation among the counted positions.
    pub max_evaluation: f32,
}
1118
/// Statistics about the cluster tree
///
/// Filled in by `cluster_tree_stats`, which walks the tree from the root.
#[derive(Debug, Clone)]
pub struct ClusterTreeStats {
    /// Total number of nodes visited (internal nodes and leaves).
    pub total_nodes: usize,
    /// Number of nodes that have no children.
    pub leaf_nodes: usize,
    /// Depth of the deepest node, with the root at depth 0.
    pub max_depth: usize,
    /// Number of stored positions divided by `leaf_nodes` (`0.0` if there are no leaves).
    pub avg_cluster_size: f32,
    /// Largest `size` seen on any visited node, internal nodes included.
    pub max_cluster_size: usize,
}
1128
1129impl SimilaritySearch {
1130    /// Optimized search with early termination and cached computations
1131    pub fn search_optimized(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
1132        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
1133
1134        if self.positions.is_empty() {
1135            return Vec::new();
1136        }
1137
1138        // Pre-compute query norm once
1139        let query_norm_squared = SimdVectorOps::squared_norm(query);
1140        let query_norm = query_norm_squared.sqrt();
1141
1142        // For small k, use optimized heap management
1143        if k <= 10 && self.positions.len() > k * 10 {
1144            return self.search_with_bounded_heap(query, query_norm_squared, k);
1145        }
1146
1147        // For larger k or smaller datasets, use parallel approach with early termination
1148        self.search_parallel_optimized(query, query_norm, k)
1149    }
1150
1151    /// Search using bounded heap for small k values (memory efficient)
1152    fn search_with_bounded_heap(
1153        &self,
1154        query: &Array1<f32>,
1155        query_norm_squared: f32,
1156        k: usize,
1157    ) -> Vec<(Array1<f32>, f32, f32)> {
1158        let mut heap = BinaryHeap::with_capacity(k + 1);
1159        let mut min_similarity = f32::NEG_INFINITY;
1160
1161        for entry in &self.positions {
1162            // Early termination: skip if impossible to beat current worst
1163            if heap.len() == k && self.can_skip_entry(query, entry, min_similarity) {
1164                continue;
1165            }
1166
1167            let similarity = self.cosine_similarity_fast_uncached(query, query_norm_squared, entry);
1168
1169            let result = SearchResult {
1170                similarity,
1171                evaluation: entry.evaluation,
1172                vector: entry.vector.clone(),
1173            };
1174
1175            if heap.len() < k {
1176                if heap.is_empty() || similarity < min_similarity {
1177                    min_similarity = similarity;
1178                }
1179                heap.push(result);
1180            } else if similarity > min_similarity {
1181                heap.pop();  // Remove worst
1182                heap.push(result);
1183                // Update min_similarity
1184                min_similarity = heap.peek().map(|r| r.similarity).unwrap_or(f32::NEG_INFINITY);
1185            }
1186        }
1187
1188        // Convert to sorted results
1189        self.heap_to_sorted_results(heap)
1190    }
1191
1192    /// Parallel search with optimizations for larger k values
1193    fn search_parallel_optimized(
1194        &self,
1195        query: &Array1<f32>,
1196        query_norm: f32,
1197        k: usize,
1198    ) -> Vec<(Array1<f32>, f32, f32)> {
1199        // Use chunks to reduce memory allocation overhead
1200        let chunk_size = (self.positions.len() / rayon::current_num_threads()).max(1000);
1201        
1202        let mut results: Vec<_> = self
1203            .positions
1204            .par_chunks(chunk_size)
1205            .flat_map(|chunk| {
1206                chunk.par_iter().map(|entry| {
1207                    let entry_norm = entry.norm_squared.sqrt();
1208                    let similarity = self.cosine_similarity_ultra_fast(query, query_norm, entry, entry_norm);
1209                    (entry.vector.clone(), entry.evaluation, similarity)
1210                })
1211            })
1212            .collect();
1213
1214        // Use appropriate sorting strategy based on k vs n ratio
1215        if k * 10 < results.len() {
1216            // For small k, use full sort then truncate (simpler and still efficient)
1217            results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1218            results.truncate(k);
1219        } else {
1220            // Full sort for cases where k is large relative to n
1221            results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1222            results.truncate(k);
1223        }
1224
1225        results
1226    }
1227
1228    /// Early termination heuristic: can we skip this entry?
1229    fn can_skip_entry(&self, _query: &Array1<f32>, _entry: &PositionEntry, min_similarity: f32) -> bool {
1230        // Simple heuristic: if the entry's norm is too different, skip
1231        // This is a conservative approximation - could be made more sophisticated
1232        
1233        // For now, implement a basic check
1234        // In practice, you could use bounds based on vector norms and angles
1235        min_similarity > 0.95  // Only skip if we already have very high similarities
1236    }
1237
1238    /// Convert heap to sorted results vector
1239    fn heap_to_sorted_results(&self, mut heap: BinaryHeap<SearchResult>) -> Vec<(Array1<f32>, f32, f32)> {
1240        let mut results = Vec::with_capacity(heap.len());
1241        while let Some(result) = heap.pop() {
1242            results.push((result.vector, result.evaluation, result.similarity));
1243        }
1244        results.reverse(); // Heap gives us worst-first, we want best-first
1245        results
1246    }
1247
1248    /// Batch search optimization for multiple queries
1249    pub fn batch_search_optimized(
1250        &self,
1251        queries: &[Array1<f32>],
1252        k: usize,
1253    ) -> Vec<Vec<(Array1<f32>, f32, f32)>> {
1254        if queries.is_empty() || self.positions.is_empty() {
1255            return vec![Vec::new(); queries.len()];
1256        }
1257
1258        // Pre-compute norms for all queries
1259        let query_norms: Vec<f32> = queries
1260            .par_iter()
1261            .map(|q| SimdVectorOps::squared_norm(q).sqrt())
1262            .collect();
1263
1264        // Process queries in parallel
1265        queries
1266            .par_iter()
1267            .zip(query_norms.par_iter())
1268            .map(|(query, &query_norm)| {
1269                self.search_parallel_optimized(query, query_norm, k)
1270            })
1271            .collect()
1272    }
1273}
1274
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// A freshly created index holds no positions.
    #[test]
    fn test_similarity_search_creation() {
        let index = SimilaritySearch::new(100);
        assert!(index.is_empty());
        assert_eq!(index.size(), 0);
    }

    /// Searching with a stored vector returns that vector as the best match.
    #[test]
    fn test_add_and_search() {
        let mut index = SimilaritySearch::new(3);

        // Three mutually orthogonal unit vectors with distinct evaluations.
        let x_axis = Array1::from(vec![1.0, 0.0, 0.0]);
        let y_axis = Array1::from(vec![0.0, 1.0, 0.0]);
        let z_axis = Array1::from(vec![0.0, 0.0, 1.0]);

        index.add_position(x_axis.clone(), 1.0);
        index.add_position(y_axis, 0.5);
        index.add_position(z_axis, 0.0);
        assert_eq!(index.size(), 3);

        let hits = index.search(&x_axis, 2);
        assert_eq!(hits.len(), 2);

        // The top hit is the identical vector: similarity 1.0, evaluation 1.0.
        let best = &hits[0];
        assert!((best.2 - 1.0).abs() < 1e-6);
        assert!((best.1 - 1.0).abs() < 1e-6);
    }

    /// Cosine similarity is 1 for identical vectors and 0 for orthogonal ones.
    #[test]
    fn test_cosine_similarity() {
        let index = SimilaritySearch::new(2);

        let unit_x = Array1::from(vec![1.0, 0.0]);
        let unit_x_copy = Array1::from(vec![1.0, 0.0]);
        let unit_y = Array1::from(vec![0.0, 1.0]);

        let same = index.cosine_similarity(&unit_x, &unit_x_copy);
        assert!((same - 1.0).abs() < 1e-6);

        let orthogonal = index.cosine_similarity(&unit_x, &unit_y);
        assert!((orthogonal - 0.0).abs() < 1e-6);
    }

    /// `statistics` reports the count, mean, minimum and maximum evaluation.
    #[test]
    fn test_statistics() {
        let mut index = SimilaritySearch::new(2);

        let vector = Array1::from(vec![1.0, 0.0]);
        for &evaluation in &[1.0, 2.0, 3.0] {
            index.add_position(vector.clone(), evaluation);
        }

        let stats = index.statistics();
        assert_eq!(stats.count, 3);
        assert!((stats.avg_evaluation - 2.0).abs() < 1e-6);
        assert!((stats.min_evaluation - 1.0).abs() < 1e-6);
        assert!((stats.max_evaluation - 3.0).abs() < 1e-6);
    }
}