chess_vector_engine/
similarity_search.rs

1#![allow(clippy::type_complexity)]
2use crate::gpu_acceleration::GPUAccelerator;
3use ndarray::{Array1, Array2};
4use rayon::prelude::*;
5use std::cmp::Ordering;
6use std::collections::BinaryHeap;
7
8#[cfg(target_arch = "aarch64")]
9use std::arch::aarch64::*;
10#[cfg(target_arch = "x86_64")]
11use std::arch::x86_64::*;
12
13/// Entry in the similarity search index
14#[derive(Debug, Clone)]
15pub struct PositionEntry {
16    pub vector: Array1<f32>,
17    pub evaluation: f32,
18    pub norm_squared: f32,
19}
20
21/// Result from similarity search (reference-based)
22#[derive(Debug)]
23pub struct SearchResultRef<'a> {
24    pub similarity: f32,
25    pub evaluation: f32,
26    pub vector: &'a Array1<f32>,
27}
28
29/// Result from similarity search (owned)
30#[derive(Debug, Clone)]
31pub struct SearchResult {
32    pub similarity: f32,
33    pub evaluation: f32,
34    pub vector: Array1<f32>,
35}
36
37impl PartialEq for SearchResult {
38    fn eq(&self, other: &Self) -> bool {
39        self.similarity == other.similarity
40    }
41}
42
43impl Eq for SearchResult {}
44
45impl PartialOrd for SearchResult {
46    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47        Some(self.cmp(other))
48    }
49}
50
51impl Ord for SearchResult {
52    fn cmp(&self, other: &Self) -> Ordering {
53        // Reverse ordering for max-heap behavior in BinaryHeap
54        other
55            .similarity
56            .partial_cmp(&self.similarity)
57            .unwrap_or(Ordering::Equal)
58    }
59}
60
61/// Similarity search engine for chess positions
62#[derive(Clone)]
63pub struct SimilaritySearch {
64    /// All stored positions
65    positions: Vec<PositionEntry>,
66    /// Dimension of vectors
67    vector_size: usize,
68}
69
70impl SimilaritySearch {
71    /// Create a new similarity search engine
72    pub fn new(vector_size: usize) -> Self {
73        Self {
74            positions: Vec::new(),
75            vector_size,
76        }
77    }
78
79    /// Add a position to the search index
80    pub fn add_position(&mut self, vector: Array1<f32>, evaluation: f32) {
81        assert_eq!(vector.len(), self.vector_size, "Vector size mismatch");
82
83        let norm_squared =
84            self.simd_dot_product(vector.as_slice().unwrap(), vector.as_slice().unwrap());
85
86        self.positions.push(PositionEntry {
87            vector,
88            evaluation,
89            norm_squared,
90        });
91    }
92
93    /// Search for k most similar positions with references (memory efficient)
94    pub fn search_ref(&self, query: &Array1<f32>, k: usize) -> Vec<(&Array1<f32>, f32, f32)> {
95        // Note: GPU search not supported for reference version due to lifetime constraints
96        // Fall back to CPU-based search methods
97
98        // Use parallel CPU search for medium datasets
99        if self.positions.len() > 100 {
100            self.parallel_search_ref(query, k)
101        } else {
102            self.sequential_search_ref(query, k)
103        }
104    }
105
106    /// Search for k most similar positions (automatically chooses best method: GPU > parallel > sequential)
107    pub fn search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
108        let gpu_accelerator = GPUAccelerator::global();
109
110        // Use GPU acceleration for large datasets if available
111        if gpu_accelerator.is_gpu_enabled() && self.positions.len() > 500 {
112            match self.gpu_accelerated_search(query, k) {
113                Ok(results) => return results,
114                Err(e) => {
115                    println!("GPU search failed ({}), falling back to CPU", e);
116                }
117            }
118        }
119
120        // Fall back to parallel CPU search for medium datasets
121        if self.positions.len() > 100 {
122            self.parallel_search(query, k)
123        } else {
124            self.sequential_search(query, k)
125        }
126    }
127
128    /// GPU-accelerated similarity search for large datasets
129    pub fn gpu_accelerated_search(
130        &self,
131        query: &Array1<f32>,
132        k: usize,
133    ) -> Result<Vec<(Array1<f32>, f32, f32)>, Box<dyn std::error::Error>> {
134        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
135
136        if self.positions.is_empty() {
137            return Ok(Vec::new());
138        }
139
140        let gpu_accelerator = GPUAccelerator::global();
141
142        // Prepare vectors matrix for GPU computation
143        let mut vectors_data = Vec::with_capacity(self.positions.len() * self.vector_size);
144        for entry in &self.positions {
145            vectors_data.extend_from_slice(entry.vector.as_slice().unwrap());
146        }
147
148        let vectors_matrix =
149            Array2::from_shape_vec((self.positions.len(), self.vector_size), vectors_data)?;
150
151        // Compute similarities on GPU
152        let similarities = gpu_accelerator.cosine_similarity_batch(query, &vectors_matrix)?;
153
154        // Find top-k results
155        let mut indexed_similarities: Vec<(usize, f32)> = similarities
156            .iter()
157            .enumerate()
158            .map(|(i, &sim)| (i, sim))
159            .collect();
160
161        // Sort by similarity (descending)
162        indexed_similarities
163            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164
165        // Take top-k and prepare results
166        let mut results = Vec::new();
167        for (idx, similarity) in indexed_similarities.into_iter().take(k) {
168            let entry = &self.positions[idx];
169            results.push((entry.vector.clone(), entry.evaluation, similarity));
170        }
171
172        Ok(results)
173    }
174
175    /// Sequential search implementation with references (memory efficient)
176    pub fn sequential_search_ref(
177        &self,
178        query: &Array1<f32>,
179        k: usize,
180    ) -> Vec<(&Array1<f32>, f32, f32)> {
181        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
182
183        if self.positions.is_empty() {
184            return Vec::new();
185        }
186
187        let query_norm_squared = query.dot(query);
188
189        // Collect all similarities with indices
190        let mut indexed_similarities: Vec<(usize, f32)> = self
191            .positions
192            .iter()
193            .enumerate()
194            .map(|(idx, entry)| {
195                let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
196                (idx, similarity)
197            })
198            .collect();
199
200        // Sort by similarity (descending)
201        indexed_similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
202
203        // Take top k and return references
204        indexed_similarities
205            .into_iter()
206            .take(k)
207            .map(|(idx, similarity)| {
208                let entry = &self.positions[idx];
209                (&entry.vector, entry.evaluation, similarity)
210            })
211            .collect()
212    }
213
214    /// Sequential search implementation (for small datasets)
215    pub fn sequential_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
216        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
217
218        if self.positions.is_empty() {
219            return Vec::new();
220        }
221
222        let query_norm_squared = query.dot(query);
223
224        // Use a min-heap to keep track of top-k results
225        let mut heap = BinaryHeap::new();
226
227        for entry in &self.positions {
228            let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
229
230            let result = SearchResult {
231                similarity,
232                evaluation: entry.evaluation,
233                vector: entry.vector.clone(),
234            };
235
236            if heap.len() < k {
237                heap.push(result);
238            } else if similarity > heap.peek().unwrap().similarity {
239                heap.pop();
240                heap.push(result);
241            }
242        }
243
244        // Convert heap to sorted vector (highest similarity first)
245        let mut results = Vec::new();
246        while let Some(result) = heap.pop() {
247            results.push((result.vector, result.evaluation, result.similarity));
248        }
249
250        // Reverse to get highest similarity first
251        results.reverse();
252        results
253    }
254
255    /// Parallel search implementation with references (memory efficient)
256    pub fn parallel_search_ref(
257        &self,
258        query: &Array1<f32>,
259        k: usize,
260    ) -> Vec<(&Array1<f32>, f32, f32)> {
261        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
262
263        if self.positions.is_empty() {
264            return Vec::new();
265        }
266
267        let query_norm_squared = query.dot(query);
268
269        // Calculate similarities in parallel with indices
270        let mut indexed_similarities: Vec<(usize, f32)> = self
271            .positions
272            .par_iter()
273            .enumerate()
274            .map(|(idx, entry)| {
275                let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
276                (idx, similarity)
277            })
278            .collect();
279
280        // Sort by similarity (descending) and take top k
281        indexed_similarities
282            .par_sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
283        indexed_similarities.truncate(k);
284
285        // Return references instead of clones
286        indexed_similarities
287            .into_iter()
288            .map(|(idx, similarity)| {
289                let entry = &self.positions[idx];
290                (&entry.vector, entry.evaluation, similarity)
291            })
292            .collect()
293    }
294
295    /// Parallel search implementation (for larger datasets)
296    pub fn parallel_search(&self, query: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
297        assert_eq!(query.len(), self.vector_size, "Query vector size mismatch");
298
299        if self.positions.is_empty() {
300            return Vec::new();
301        }
302
303        let query_norm_squared = query.dot(query);
304
305        // Calculate similarities in parallel
306        let mut results: Vec<_> = self
307            .positions
308            .par_iter()
309            .map(|entry| {
310                let similarity = self.cosine_similarity_fast(query, query_norm_squared, entry);
311                (entry.vector.clone(), entry.evaluation, similarity)
312            })
313            .collect();
314
315        // Sort by similarity (descending) and take top k
316        results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
317        results.truncate(k);
318
319        results
320    }
321
322    /// Brute force search (for small datasets or comparison)
323    pub fn brute_force_search(
324        &self,
325        query: &Array1<f32>,
326        k: usize,
327    ) -> Vec<(Array1<f32>, f32, f32)> {
328        let mut results: Vec<_> = if self.positions.len() > 100 {
329            // Use parallel processing for larger datasets
330            self.positions
331                .par_iter()
332                .map(|entry| {
333                    let similarity = self.cosine_similarity(query, &entry.vector);
334                    (entry.vector.clone(), entry.evaluation, similarity)
335                })
336                .collect()
337        } else {
338            // Use sequential processing for smaller datasets
339            self.positions
340                .iter()
341                .map(|entry| {
342                    let similarity = self.cosine_similarity(query, &entry.vector);
343                    (entry.vector.clone(), entry.evaluation, similarity)
344                })
345                .collect()
346        };
347
348        // Sort by similarity (descending)
349        if results.len() > 1000 {
350            results.par_sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
351        } else {
352            results.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
353        }
354
355        // Take top k
356        results.truncate(k);
357        results
358    }
359
360    /// Calculate cosine similarity between query vector and a position entry (SIMD optimized)
361    fn cosine_similarity_fast(
362        &self,
363        query: &Array1<f32>,
364        query_norm_squared: f32,
365        entry: &PositionEntry,
366    ) -> f32 {
367        let dot_product =
368            self.simd_dot_product(query.as_slice().unwrap(), entry.vector.as_slice().unwrap());
369
370        if query_norm_squared == 0.0 || entry.norm_squared == 0.0 {
371            0.0
372        } else {
373            dot_product / (query_norm_squared.sqrt() * entry.norm_squared.sqrt())
374        }
375    }
376
377    /// SIMD-optimized dot product calculation
378    #[inline]
379    fn simd_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
380        #[cfg(target_arch = "x86_64")]
381        {
382            if is_x86_feature_detected!("avx2") {
383                return unsafe { self.avx2_dot_product(a, b) };
384            } else if is_x86_feature_detected!("sse4.1") {
385                return unsafe { self.sse_dot_product(a, b) };
386            }
387        }
388
389        #[cfg(target_arch = "aarch64")]
390        {
391            if std::arch::is_aarch64_feature_detected!("neon") {
392                return unsafe { self.neon_dot_product(a, b) };
393            }
394        }
395
396        // Fallback to scalar implementation
397        self.scalar_dot_product(a, b)
398    }
399
400    #[cfg(target_arch = "x86_64")]
401    #[target_feature(enable = "avx2")]
402    unsafe fn avx2_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
403        let len = a.len().min(b.len());
404        let mut sum = _mm256_setzero_ps();
405        let mut i = 0;
406
407        // Process 8 floats at a time with AVX2
408        while i + 8 <= len {
409            let va = _mm256_loadu_ps(a.as_ptr().add(i));
410            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
411            let vmul = _mm256_mul_ps(va, vb);
412            sum = _mm256_add_ps(sum, vmul);
413            i += 8;
414        }
415
416        // Horizontal sum of the AVX2 register
417        let mut result = [0.0f32; 8];
418        _mm256_storeu_ps(result.as_mut_ptr(), sum);
419        let mut final_sum = result.iter().sum::<f32>();
420
421        // Handle remaining elements
422        while i < len {
423            final_sum += a[i] * b[i];
424            i += 1;
425        }
426
427        final_sum
428    }
429
430    #[cfg(target_arch = "x86_64")]
431    #[target_feature(enable = "sse4.1")]
432    unsafe fn sse_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
433        let len = a.len().min(b.len());
434        let mut sum = _mm_setzero_ps();
435        let mut i = 0;
436
437        // Process 4 floats at a time with SSE
438        while i + 4 <= len {
439            let va = _mm_loadu_ps(a.as_ptr().add(i));
440            let vb = _mm_loadu_ps(b.as_ptr().add(i));
441            let vmul = _mm_mul_ps(va, vb);
442            sum = _mm_add_ps(sum, vmul);
443            i += 4;
444        }
445
446        // Horizontal sum of the SSE register
447        let mut result = [0.0f32; 4];
448        _mm_storeu_ps(result.as_mut_ptr(), sum);
449        let mut final_sum = result.iter().sum::<f32>();
450
451        // Handle remaining elements
452        while i < len {
453            final_sum += a[i] * b[i];
454            i += 1;
455        }
456
457        final_sum
458    }
459
460    #[cfg(target_arch = "aarch64")]
461    #[target_feature(enable = "neon")]
462    unsafe fn neon_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
463        let len = a.len().min(b.len());
464        let mut sum = vdupq_n_f32(0.0);
465        let mut i = 0;
466
467        // Process 4 floats at a time with NEON
468        while i + 4 <= len {
469            let va = vld1q_f32(a.as_ptr().add(i));
470            let vb = vld1q_f32(b.as_ptr().add(i));
471            let vmul = vmulq_f32(va, vb);
472            sum = vaddq_f32(sum, vmul);
473            i += 4;
474        }
475
476        // Horizontal sum of the NEON register
477        let mut result = [0.0f32; 4];
478        vst1q_f32(result.as_mut_ptr(), sum);
479        let mut final_sum = result.iter().sum::<f32>();
480
481        // Handle remaining elements
482        while i < len {
483            final_sum += a[i] * b[i];
484            i += 1;
485        }
486
487        final_sum
488    }
489
490    #[inline]
491    fn scalar_dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
492        let len = a.len().min(b.len());
493        let mut sum = 0.0f32;
494
495        // Unroll loop for better performance
496        let mut i = 0;
497        while i + 4 <= len {
498            sum += a[i] * b[i] + a[i + 1] * b[i + 1] + a[i + 2] * b[i + 2] + a[i + 3] * b[i + 3];
499            i += 4;
500        }
501
502        // Handle remaining elements
503        while i < len {
504            sum += a[i] * b[i];
505            i += 1;
506        }
507
508        sum
509    }
510
511    /// Calculate cosine similarity between two vectors (fallback method)
512    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
513        let dot_product = a.dot(b);
514        let norm_a = a.dot(a).sqrt();
515        let norm_b = b.dot(b).sqrt();
516
517        if norm_a == 0.0 || norm_b == 0.0 {
518            0.0
519        } else {
520            dot_product / (norm_a * norm_b)
521        }
522    }
523
524    /// Calculate Euclidean distance between two vectors
525    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
526        (a - b).mapv(|x| x * x).sum().sqrt()
527    }
528
529    /// Search using Euclidean distance (alternative to cosine similarity)
530    pub fn search_by_distance(
531        &self,
532        query: &Array1<f32>,
533        k: usize,
534    ) -> Vec<(Array1<f32>, f32, f32)> {
535        let mut results: Vec<_> = self
536            .positions
537            .iter()
538            .map(|entry| {
539                let distance = self.euclidean_distance(query, &entry.vector);
540                (entry.vector.clone(), entry.evaluation, distance)
541            })
542            .collect();
543
544        // Sort by distance (ascending - smaller distance = more similar)
545        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
546
547        // Take top k
548        results.truncate(k);
549        results
550    }
551
552    /// Get number of positions in the index
553    pub fn size(&self) -> usize {
554        self.positions.len()
555    }
556
557    /// Check if the index is empty
558    pub fn is_empty(&self) -> bool {
559        self.positions.is_empty()
560    }
561
562    /// Clear all positions
563    pub fn clear(&mut self) {
564        self.positions.clear();
565    }
566
567    /// Get statistics about the stored vectors
568    pub fn statistics(&self) -> SimilaritySearchStats {
569        if self.positions.is_empty() {
570            return SimilaritySearchStats {
571                count: 0,
572                avg_evaluation: 0.0,
573                min_evaluation: 0.0,
574                max_evaluation: 0.0,
575            };
576        }
577
578        let evaluations: Vec<f32> = self.positions.iter().map(|p| p.evaluation).collect();
579        let sum: f32 = evaluations.iter().sum();
580        let avg = sum / evaluations.len() as f32;
581        let min = evaluations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
582        let max = evaluations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
583
584        SimilaritySearchStats {
585            count: self.positions.len(),
586            avg_evaluation: avg,
587            min_evaluation: min,
588            max_evaluation: max,
589        }
590    }
591
592    /// Get all stored positions (for LSH indexing)
593    pub fn get_all_positions(&self) -> Vec<(Array1<f32>, f32)> {
594        self.positions
595            .iter()
596            .map(|entry| (entry.vector.clone(), entry.evaluation))
597            .collect()
598    }
599
600    /// Get position vector by reference to avoid cloning
601    pub fn get_position_ref(&self, index: usize) -> Option<(&Array1<f32>, f32)> {
602        self.positions
603            .get(index)
604            .map(|entry| (&entry.vector, entry.evaluation))
605    }
606
607    /// Get all positions as references (memory efficient iterator)
608    pub fn iter_positions(&self) -> impl Iterator<Item = (&Array1<f32>, f32)> {
609        self.positions
610            .iter()
611            .map(|entry| (&entry.vector, entry.evaluation))
612    }
613}
614
/// Summary of the evaluations held by a similarity search index.
#[derive(Debug, Clone)]
pub struct SimilaritySearchStats {
    /// Number of stored positions.
    pub count: usize,
    /// Arithmetic mean of the stored evaluations.
    pub avg_evaluation: f32,
    /// Smallest stored evaluation.
    pub min_evaluation: f32,
    /// Largest stored evaluation.
    pub max_evaluation: f32,
}
623
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Absolute tolerance used for every float comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// True when `actual` is within `TOLERANCE` of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// A freshly created index reports zero entries.
    #[test]
    fn test_similarity_search_creation() {
        let index = SimilaritySearch::new(100);
        assert_eq!(index.size(), 0);
        assert!(index.is_empty());
    }

    /// Searching for a stored unit vector returns that vector first.
    #[test]
    fn test_add_and_search() {
        let mut index = SimilaritySearch::new(3);

        // Three mutually orthogonal unit vectors with distinct evaluations.
        let x_axis = Array1::from(vec![1.0, 0.0, 0.0]);
        let y_axis = Array1::from(vec![0.0, 1.0, 0.0]);
        let z_axis = Array1::from(vec![0.0, 0.0, 1.0]);

        index.add_position(x_axis.clone(), 1.0);
        index.add_position(y_axis, 0.5);
        index.add_position(z_axis, 0.0);

        assert_eq!(index.size(), 3);

        let hits = index.search(&x_axis, 2);
        assert_eq!(hits.len(), 2);

        // The best hit is the query itself: similarity 1.0, evaluation 1.0.
        let (_, evaluation, similarity) = &hits[0];
        assert!(close_to(*similarity, 1.0));
        assert!(close_to(*evaluation, 1.0));
    }

    /// Identical vectors score 1.0 and orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let index = SimilaritySearch::new(2);

        let base = Array1::from(vec![1.0, 0.0]);
        let same = Array1::from(vec![1.0, 0.0]);
        let orthogonal = Array1::from(vec![0.0, 1.0]);

        assert!(close_to(index.cosine_similarity(&base, &same), 1.0));
        assert!(close_to(index.cosine_similarity(&base, &orthogonal), 0.0));
    }

    /// Count, mean, minimum and maximum reflect the stored evaluations.
    #[test]
    fn test_statistics() {
        let mut index = SimilaritySearch::new(2);

        let vector = Array1::from(vec![1.0, 0.0]);
        index.add_position(vector.clone(), 1.0);
        index.add_position(vector.clone(), 2.0);
        index.add_position(vector, 3.0);

        let stats = index.statistics();
        assert_eq!(stats.count, 3);
        assert!(close_to(stats.avg_evaluation, 2.0));
        assert!(close_to(stats.min_evaluation, 1.0));
        assert!(close_to(stats.max_evaluation, 3.0));
    }
}