// chess_vector_engine/lsh.rs

1use ndarray::{Array1, Array2};
2use rand::Rng;
3use rayon::prelude::*;
4use std::collections::HashMap;
5
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::*;
8#[cfg(target_arch = "x86_64")]
9use std::arch::x86_64::*;
10
/// Locality Sensitive Hashing for approximate nearest neighbor search.
///
/// Uses random-hyperplane (sign) hashing: each table maps a vector to a
/// `Vec<bool>` of `hash_size` sign bits, so similar vectors tend to land in
/// the same bucket in at least one table.
#[derive(Clone)]
pub struct LSH {
    /// Number of hash tables
    num_tables: usize,
    /// Number of hash functions (bits) per table
    hash_size: usize,
    /// Vector dimension expected by the hyperplanes
    #[allow(dead_code)]
    vector_dim: usize,
    /// Random hyperplanes for each hash table: one `hash_size x vector_dim` matrix per table
    hyperplanes: Vec<Array2<f32>>,
    /// Hash tables mapping a bit-pattern hash to indices into `stored_vectors`
    hash_tables: Vec<HashMap<Vec<bool>, Vec<usize>>>,
    /// Stored vectors for retrieval
    stored_vectors: Vec<Array1<f32>>,
    /// Associated data (evaluations), parallel to `stored_vectors`
    stored_data: Vec<f32>,
}
30
31impl LSH {
32    /// Create a new LSH index with dynamic sizing
33    pub fn new(vector_dim: usize, num_tables: usize, hash_size: usize) -> Self {
34        Self::with_expected_size(vector_dim, num_tables, hash_size, 1000)
35    }
36
37    /// Create a new LSH index with expected dataset size for optimal memory allocation
38    pub fn with_expected_size(
39        vector_dim: usize,
40        num_tables: usize,
41        hash_size: usize,
42        expected_size: usize,
43    ) -> Self {
44        let mut rng = rand::thread_rng();
45        let mut hyperplanes = Vec::new();
46
47        // Generate random hyperplanes for each table
48        for _ in 0..num_tables {
49            let mut table_hyperplanes = Array2::zeros((hash_size, vector_dim));
50            for i in 0..hash_size {
51                for j in 0..vector_dim {
52                    table_hyperplanes[[i, j]] = rng.gen_range(-1.0..1.0);
53                }
54            }
55            hyperplanes.push(table_hyperplanes);
56        }
57
58        // Calculate optimal hash table capacity based on expected size
59        // Assume roughly 20% occupancy for good performance
60        let bucket_capacity = (expected_size as f32 / 5.0).ceil() as usize;
61        let optimal_capacity = bucket_capacity.max(100);
62
63        let hash_tables = vec![HashMap::with_capacity(optimal_capacity); num_tables];
64
65        Self {
66            num_tables,
67            hash_size,
68            vector_dim,
69            hyperplanes,
70            hash_tables,
71            stored_vectors: Vec::with_capacity(expected_size),
72            stored_data: Vec::with_capacity(expected_size),
73        }
74    }
75
76    /// Add a vector to the LSH index with dynamic resizing
77    pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
78        let index = self.stored_vectors.len();
79
80        // Check if we need to resize hash tables (when load factor > 0.75)
81        let current_load =
82            self.stored_vectors.len() as f32 / (self.hash_tables[0].capacity() as f32 * 0.2);
83        if current_load > 0.75 {
84            self.resize_hash_tables();
85        }
86
87        // Hash the vector in each table before storing (to avoid clone)
88        let mut hashes = Vec::with_capacity(self.num_tables);
89        for table_idx in 0..self.num_tables {
90            hashes.push(self.hash_vector(&vector, table_idx));
91        }
92
93        // Now store the vector and data
94        self.stored_vectors.push(vector);
95        self.stored_data.push(data);
96
97        // Insert into hash tables using pre-computed hashes
98        for (table_idx, hash) in hashes.into_iter().enumerate() {
99            self.hash_tables[table_idx]
100                .entry(hash)
101                .or_insert_with(|| Vec::with_capacity(8)) // Increased pre-allocation
102                .push(index);
103        }
104    }
105
106    /// Resize hash tables when they become too full
107    fn resize_hash_tables(&mut self) {
108        let new_capacity = (self.hash_tables[0].capacity() * 2).max(self.stored_vectors.len());
109
110        for table in &mut self.hash_tables {
111            // Reserve additional capacity to avoid frequent rehashing
112            table.reserve(new_capacity - table.capacity());
113        }
114    }
115
    /// Find approximate nearest neighbors.
    ///
    /// Returns up to `k` `(vector, evaluation, cosine_similarity)` triples,
    /// sorted by similarity descending. Candidates are gathered from the
    /// query's bucket in each hash table and then ranked by exact cosine
    /// similarity — so the result is approximate only in which vectors are
    /// considered, not in how they are scored.
    pub fn query(&self, query_vector: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
        let mut candidates = std::collections::HashSet::new();
        // Bound candidate collection: ~25% of the index, at least 10*k,
        // never more than the number of stored vectors.
        let max_candidates = (self.stored_vectors.len() / 4)
            .max(k * 10)
            .min(self.stored_vectors.len());

        // Parallelize hash table queries when we have many tables
        if self.num_tables > 4 {
            // Collect candidates from hash tables in parallel; each task
            // returns an owned copy of its bucket.
            let candidate_sets: Vec<Vec<usize>> = (0..self.num_tables)
                .into_par_iter()
                .map(|table_idx| {
                    let hash = self.hash_vector(query_vector, table_idx);
                    if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                        bucket.clone()
                    } else {
                        Vec::new()
                    }
                })
                .collect();

            // Merge candidate sets sequentially (avoiding race conditions on HashSet)
            for candidate_set in candidate_sets {
                for idx in candidate_set {
                    candidates.insert(idx);
                    if candidates.len() >= max_candidates {
                        break;
                    }
                }
                if candidates.len() >= max_candidates {
                    break;
                }
            }
        } else {
            // Sequential collection for smaller numbers of tables
            for table_idx in 0..self.num_tables {
                if candidates.len() >= max_candidates {
                    break;
                }

                let hash = self.hash_vector(query_vector, table_idx);
                if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                    for &idx in bucket {
                        candidates.insert(idx);
                        if candidates.len() >= max_candidates {
                            break;
                        }
                    }
                }
            }
        }

        // If the buckets yielded too few candidates, pad with the first
        // stored indices so ~k results can still be returned.
        // NOTE(review): deterministic padding biases recall toward early
        // insertions — acceptable for a cheap fallback path.
        if candidates.len() < k * 3 && self.stored_vectors.len() > k * 3 {
            // Instead of random sampling, just take the first few indices
            let needed = k * 5;
            for idx in 0..needed.min(self.stored_vectors.len()) {
                candidates.insert(idx);
                if candidates.len() >= needed {
                    break;
                }
            }
        }

        // Calculate similarities for candidates in parallel for large candidate sets
        let mut results = if candidates.len() > 50 {
            candidates
                .par_iter()
                .map(|&idx| {
                    let stored_vector = &self.stored_vectors[idx];
                    let similarity = cosine_similarity(query_vector, stored_vector);
                    (stored_vector.clone(), self.stored_data[idx], similarity)
                })
                .collect()
        } else {
            // Sequential for smaller candidate sets
            let mut results = Vec::with_capacity(candidates.len());
            for &idx in &candidates {
                let stored_vector = &self.stored_vectors[idx];
                let similarity = cosine_similarity(query_vector, stored_vector);
                results.push((stored_vector.clone(), self.stored_data[idx], similarity));
            }
            results
        };

        // Sort by similarity (descending) and return top k.
        // partial_cmp is None only for NaN similarities; treat those as equal.
        if results.len() > 100 {
            results.par_sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        } else {
            results.sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        }
        results.truncate(k);

        results
    }
216
217    /// Hash a vector using hyperplanes for a specific table
218    fn hash_vector(&self, vector: &Array1<f32>, table_idx: usize) -> Vec<bool> {
219        let hyperplanes = &self.hyperplanes[table_idx];
220
221        // Use ndarray's optimized matrix-vector multiplication
222        let dot_products = hyperplanes.dot(vector);
223
224        // Convert to hash bits in one pass
225        dot_products.iter().map(|&x| x >= 0.0).collect()
226    }
227
228    /// Get statistics about the index
229    pub fn stats(&self) -> LSHStats {
230        let mut bucket_sizes = Vec::new();
231        let mut total_buckets = 0;
232        let mut non_empty_buckets = 0;
233
234        for table in &self.hash_tables {
235            // Use checked_shl to avoid overflow - if hash_size is too large, use max value
236            let buckets_per_table = 1_usize
237                .checked_shl(self.hash_size as u32)
238                .unwrap_or(usize::MAX);
239            total_buckets += buckets_per_table;
240            non_empty_buckets += table.len();
241
242            for bucket in table.values() {
243                bucket_sizes.push(bucket.len());
244            }
245        }
246
247        bucket_sizes.sort();
248        let median_bucket_size = if bucket_sizes.is_empty() {
249            0.0
250        } else {
251            bucket_sizes[bucket_sizes.len() / 2] as f32
252        };
253
254        let avg_bucket_size = if bucket_sizes.is_empty() {
255            0.0
256        } else {
257            bucket_sizes.iter().sum::<usize>() as f32 / bucket_sizes.len() as f32
258        };
259
260        LSHStats {
261            num_vectors: self.stored_vectors.len(),
262            num_tables: self.num_tables,
263            hash_size: self.hash_size,
264            total_buckets,
265            non_empty_buckets,
266            avg_bucket_size,
267            median_bucket_size,
268            max_bucket_size: bucket_sizes.last().copied().unwrap_or(0),
269        }
270    }
271
    /// Save LSH configuration and hash functions to database.
    ///
    /// Persists the hyperplanes (flattened row by row across all tables,
    /// in table order) and the current bucket assignments. Existing bucket
    /// rows are cleared first, so the database ends up mirroring this
    /// index exactly.
    pub fn save_to_database(
        &self,
        db: &crate::persistence::Database,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::persistence::{LSHHashFunction, LSHTableData};

        // Convert hyperplanes to serializable format: one LSHHashFunction
        // per matrix row, widened to f64 for storage.
        let mut hash_functions = Vec::new();
        for hyperplane_matrix in &self.hyperplanes {
            for row in hyperplane_matrix.rows() {
                hash_functions.push(LSHHashFunction {
                    random_vector: row.to_vec().iter().map(|&x| x as f64).collect(),
                    threshold: 0.0, // We use zero threshold for hyperplane hashing
                });
            }
        }

        let config = LSHTableData {
            num_tables: self.num_tables,
            num_hash_functions: self.hash_size,
            vector_dim: self.vector_dim,
            hash_functions,
        };

        db.save_lsh_config(&config)?;

        // Clear existing bucket data and save new buckets
        db.clear_lsh_buckets()?;

        // Save hash bucket assignments (this maps positions to buckets).
        // The bit-vector hash is serialized as a '0'/'1' string key.
        for (table_idx, table) in self.hash_tables.iter().enumerate() {
            for (hash_bits, indices) in table {
                let hash_string = hash_bits
                    .iter()
                    .map(|&b| if b { '1' } else { '0' })
                    .collect::<String>();

                for &position_idx in indices {
                    db.save_lsh_bucket(table_idx, &hash_string, position_idx as i64)?;
                }
            }
        }

        Ok(())
    }
318
319    /// Load LSH configuration from database and rebuild hash tables
320    pub fn load_from_database(
321        db: &crate::persistence::Database,
322        positions: &[(Array1<f32>, f32)],
323    ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
324        let config = match db.load_lsh_config()? {
325            Some(config) => config,
326            None => return Ok(None),
327        };
328
329        // Reconstruct hyperplanes from saved hash functions
330        let mut hyperplanes = Vec::new();
331        let functions_per_table = config.num_hash_functions;
332
333        for table_idx in 0..config.num_tables {
334            let start_idx = table_idx * functions_per_table;
335            let end_idx = start_idx + functions_per_table;
336
337            if end_idx <= config.hash_functions.len() {
338                let mut table_hyperplanes = Array2::zeros((functions_per_table, config.vector_dim));
339
340                for (func_idx, hash_func) in
341                    config.hash_functions[start_idx..end_idx].iter().enumerate()
342                {
343                    for (dim_idx, &value) in hash_func.random_vector.iter().enumerate() {
344                        if dim_idx < config.vector_dim {
345                            table_hyperplanes[[func_idx, dim_idx]] = value as f32;
346                        }
347                    }
348                }
349
350                hyperplanes.push(table_hyperplanes);
351            }
352        }
353
354        // Create LSH with loaded configuration
355        let mut lsh = Self {
356            num_tables: config.num_tables,
357            hash_size: config.num_hash_functions,
358            vector_dim: config.vector_dim,
359            hyperplanes,
360            hash_tables: vec![HashMap::with_capacity(positions.len().max(100)); config.num_tables],
361            stored_vectors: Vec::new(),
362            stored_data: Vec::new(),
363        };
364
365        // Rebuild the index with provided positions
366        for (vector, evaluation) in positions {
367            lsh.add_vector(vector.clone(), *evaluation);
368        }
369
370        Ok(Some(lsh))
371    }
372
373    /// Create LSH from database or return None if no saved configuration exists
374    pub fn from_database_or_new(
375        db: &crate::persistence::Database,
376        positions: &[(Array1<f32>, f32)],
377        vector_dim: usize,
378        num_tables: usize,
379        hash_size: usize,
380    ) -> Result<Self, Box<dyn std::error::Error>> {
381        match Self::load_from_database(db, positions)? {
382            Some(lsh) => {
383                println!(
384                    "Loaded LSH configuration from database with {} vectors",
385                    lsh.stored_vectors.len()
386                );
387                Ok(lsh)
388            }
389            None => {
390                println!("No saved LSH configuration found, creating new LSH index");
391                let mut lsh =
392                    Self::with_expected_size(vector_dim, num_tables, hash_size, positions.len());
393                for (vector, evaluation) in positions {
394                    lsh.add_vector(vector.clone(), *evaluation);
395                }
396                Ok(lsh)
397            }
398        }
399    }
400}
401
/// LSH performance statistics
#[derive(Debug)]
pub struct LSHStats {
    /// Number of vectors stored in the index
    pub num_vectors: usize,
    /// Number of hash tables
    pub num_tables: usize,
    /// Number of hash bits per table
    pub hash_size: usize,
    /// Theoretical bucket count across all tables (2^hash_size per table)
    pub total_buckets: usize,
    /// Buckets that actually contain at least one vector
    pub non_empty_buckets: usize,
    /// Mean size of the non-empty buckets
    pub avg_bucket_size: f32,
    /// Median size of the non-empty buckets
    pub median_bucket_size: f32,
    /// Size of the largest bucket (0 when the index is empty)
    pub max_bucket_size: usize,
}
414
415/// Calculate cosine similarity between two vectors (SIMD optimized)
416fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
417    let a_slice = a.as_slice().unwrap();
418    let b_slice = b.as_slice().unwrap();
419
420    let dot_product = simd_dot_product(a_slice, b_slice);
421    let norm_a_sq = simd_dot_product(a_slice, a_slice);
422    let norm_b_sq = simd_dot_product(b_slice, b_slice);
423
424    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
425        0.0
426    } else {
427        dot_product / (norm_a_sq * norm_b_sq).sqrt()
428    }
429}
430
/// SIMD-optimized dot product calculation.
///
/// Dispatches at runtime to the widest instruction set available
/// (AVX2 > SSE4.1 on x86_64, NEON on aarch64), falling back to a scalar
/// loop. All implementations operate on the common prefix of length
/// `a.len().min(b.len())`.
#[inline]
fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified at runtime.
            return unsafe { avx2_dot_product(a, b) };
        } else if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 support was just verified at runtime.
            return unsafe { sse_dot_product(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON support was just verified at runtime.
            return unsafe { neon_dot_product(a, b) };
        }
    }

    // Fallback to scalar implementation
    scalar_dot_product(a, b)
}
453
/// AVX2 dot product over the common prefix of `a` and `b`.
///
/// # Safety
/// The CPU must support AVX2; callers check this via
/// `is_x86_feature_detected!("avx2")` before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Only the shared prefix of both slices is processed.
    let len = a.len().min(b.len());
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    // 8 f32 lanes per AVX2 register.
    // SAFETY: i + 8 <= len <= a.len() and b.len(), so the unaligned loads
    // stay in bounds.
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vmul = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, vmul);
        i += 8;
    }

    // Horizontal reduction of the vector accumulator.
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    // Scalar tail for the remaining < 8 elements.
    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}
480
/// SSE4.1 dot product over the common prefix of `a` and `b`.
///
/// # Safety
/// The CPU must support SSE4.1; callers check this via
/// `is_x86_feature_detected!("sse4.1")` before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn sse_dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Only the shared prefix of both slices is processed.
    let len = a.len().min(b.len());
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // 4 f32 lanes per SSE register.
    // SAFETY: i + 4 <= len <= a.len() and b.len(), so the unaligned loads
    // stay in bounds.
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vmul = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, vmul);
        i += 4;
    }

    // Horizontal reduction of the vector accumulator.
    let mut result = [0.0f32; 4];
    _mm_storeu_ps(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    // Scalar tail for the remaining < 4 elements.
    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}
507
/// NEON dot product over the common prefix of `a` and `b`.
///
/// # Safety
/// The CPU must support NEON; callers check this via
/// `is_aarch64_feature_detected!("neon")` before calling.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Only the shared prefix of both slices is processed.
    let len = a.len().min(b.len());
    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    // 4 f32 lanes per NEON register.
    // SAFETY: i + 4 <= len <= a.len() and b.len(), so the loads stay in
    // bounds.
    while i + 4 <= len {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let vmul = vmulq_f32(va, vb);
        sum = vaddq_f32(sum, vmul);
        i += 4;
    }

    // Horizontal reduction of the vector accumulator.
    let mut result = [0.0f32; 4];
    vst1q_f32(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    // Scalar tail for the remaining < 4 elements.
    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}
534
/// Scalar dot product over the common prefix of `a` and `b`.
///
/// Processes four elements per step (same floating-point association as
/// the manual unrolled loop: the four products are summed as one
/// expression before being added to the accumulator), then handles the
/// remainder one element at a time.
#[inline]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let shared_len = a.len().min(b.len());
    let (head_a, head_b) = (&a[..shared_len], &b[..shared_len]);

    let mut quads_a = head_a.chunks_exact(4);
    let mut quads_b = head_b.chunks_exact(4);

    let mut acc = 0.0f32;
    // Unrolled main loop: four multiply-adds per iteration.
    for (qa, qb) in (&mut quads_a).zip(&mut quads_b) {
        acc += qa[0] * qb[0] + qa[1] * qb[1] + qa[2] * qb[2] + qa[3] * qb[3];
    }

    // Leftover elements (fewer than four).
    for (x, y) in quads_a.remainder().iter().zip(quads_b.remainder()) {
        acc += x * y;
    }

    acc
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// A freshly constructed index records its configuration verbatim.
    #[test]
    fn test_lsh_creation() {
        let index = LSH::new(128, 4, 8);
        assert_eq!(index.num_tables, 4);
        assert_eq!(index.hash_size, 8);
        assert_eq!(index.vector_dim, 128);
    }

    /// Inserting vectors and querying returns a highly similar neighbor first.
    #[test]
    fn test_lsh_add_and_query() {
        let mut index = LSH::new(4, 2, 4);

        let unit_x = Array1::from(vec![1.0, 0.0, 0.0, 0.0]);
        let unit_y = Array1::from(vec![0.0, 1.0, 0.0, 0.0]);
        let near_x = Array1::from(vec![1.0, 0.1, 0.0, 0.0]); // close to unit_x

        index.add_vector(unit_x.clone(), 1.0);
        index.add_vector(unit_y, 2.0);
        index.add_vector(near_x, 1.1);

        let results = index.query(&unit_x, 2);
        assert!(!results.is_empty());

        // Best match should be unit_x itself or near_x: similarity near 1.
        assert!(results[0].2 > 0.8);
    }

    /// Cosine similarity is 1 for identical and 0 for orthogonal vectors.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = Array1::from(vec![1.0, 0.0, 0.0]);
        let x_axis_copy = Array1::from(vec![1.0, 0.0, 0.0]);
        let y_axis = Array1::from(vec![0.0, 1.0, 0.0]);

        assert!((cosine_similarity(&x_axis, &x_axis_copy) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&x_axis, &y_axis) - 0.0).abs() < 1e-6);
    }
}