// chess_vector_engine/lsh.rs

1use crate::utils::simd::SimdVectorOps;
2use ndarray::{Array1, Array2};
3use rand::Rng;
4use rayon::prelude::*;
5use std::collections::HashMap;
6
/// Locality Sensitive Hashing for approximate nearest neighbor search
///
/// Uses random-hyperplane hashing: each table projects a vector onto
/// `hash_size` random hyperplanes and the resulting sign pattern is the
/// bucket key. More tables raise recall at the cost of memory and time.
#[derive(Clone)]
pub struct LSH {
    /// Number of hash tables
    num_tables: usize,
    /// Number of hash functions per table (bits in each bucket key)
    hash_size: usize,
    /// Vector dimension
    #[allow(dead_code)]
    vector_dim: usize,
    /// Random hyperplanes for each hash table; each matrix is
    /// `hash_size x vector_dim`
    hyperplanes: Vec<Array2<f32>>,
    /// Hash tables storing (hash, vector_index) pairs; indices point into
    /// `stored_vectors`
    hash_tables: Vec<HashMap<Vec<bool>, Vec<usize>>>,
    /// Stored vectors for retrieval, addressed by insertion order
    stored_vectors: Vec<Array1<f32>>,
    /// Associated data (evaluations), parallel to `stored_vectors`
    stored_data: Vec<f32>,
}
26
27impl LSH {
28    /// Create a new LSH index with dynamic sizing
29    pub fn new(vector_dim: usize, num_tables: usize, hash_size: usize) -> Self {
30        Self::with_expected_size(vector_dim, num_tables, hash_size, 1000)
31    }
32
33    /// Create a new LSH index with expected dataset size for optimal memory allocation
34    pub fn with_expected_size(
35        vector_dim: usize,
36        num_tables: usize,
37        hash_size: usize,
38        expected_size: usize,
39    ) -> Self {
40        let mut rng = rand::thread_rng();
41        let mut hyperplanes = Vec::new();
42
43        // Generate random hyperplanes for each table
44        for _ in 0..num_tables {
45            let mut table_hyperplanes = Array2::zeros((hash_size, vector_dim));
46            for i in 0..hash_size {
47                for j in 0..vector_dim {
48                    table_hyperplanes[[i, j]] = rng.gen_range(-1.0..1.0);
49                }
50            }
51            hyperplanes.push(table_hyperplanes);
52        }
53
54        // Calculate optimal hash table capacity based on expected size
55        // Assume roughly 20% occupancy for good performance
56        let bucket_capacity = (expected_size as f32 / 5.0).ceil() as usize;
57        let optimal_capacity = bucket_capacity.max(100);
58
59        let hash_tables = vec![HashMap::with_capacity(optimal_capacity); num_tables];
60
61        Self {
62            num_tables,
63            hash_size,
64            vector_dim,
65            hyperplanes,
66            hash_tables,
67            stored_vectors: Vec::with_capacity(expected_size),
68            stored_data: Vec::with_capacity(expected_size),
69        }
70    }
71
72    /// Add a vector to the LSH index with dynamic resizing
73    pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
74        let index = self.stored_vectors.len();
75
76        // Check if we need to resize hash tables (when load factor > 0.75)
77        let current_load =
78            self.stored_vectors.len() as f32 / (self.hash_tables[0].capacity() as f32 * 0.2);
79        if current_load > 0.75 {
80            self.resize_hash_tables();
81        }
82
83        // Hash the vector in each table before storing (to avoid clone)
84        let mut hashes = Vec::with_capacity(self.num_tables);
85        for table_idx in 0..self.num_tables {
86            hashes.push(self.hash_vector(&vector, table_idx));
87        }
88
89        // Now store the vector and data
90        self.stored_vectors.push(vector);
91        self.stored_data.push(data);
92
93        // Insert into hash tables using pre-computed hashes
94        for (table_idx, hash) in hashes.into_iter().enumerate() {
95            self.hash_tables[table_idx]
96                .entry(hash)
97                .or_insert_with(|| Vec::with_capacity(8)) // Increased pre-allocation
98                .push(index);
99        }
100    }
101
102    /// Resize hash tables when they become too full
103    fn resize_hash_tables(&mut self) {
104        let new_capacity = (self.hash_tables[0].capacity() * 2).max(self.stored_vectors.len());
105
106        for table in &mut self.hash_tables {
107            // Reserve additional capacity to avoid frequent rehashing
108            table.reserve(new_capacity - table.capacity());
109        }
110    }
111
    /// Find approximate nearest neighbors
    ///
    /// Returns up to `k` `(vector, evaluation, cosine_similarity)` triples
    /// sorted by similarity descending. Candidates are gathered from the
    /// query's bucket in every table and then ranked by exact cosine
    /// similarity, so results are approximate: a true neighbor that shares
    /// no bucket with the query can be missed.
    pub fn query(&self, query_vector: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
        let mut candidates = std::collections::HashSet::new();
        // Cap candidate gathering at 25% of the index (but at least 10*k,
        // and never more than the index size) to bound the exact
        // similarity work below.
        let max_candidates = (self.stored_vectors.len() / 4)
            .max(k * 10)
            .min(self.stored_vectors.len());

        // Parallelize hash table queries when we have many tables
        if self.num_tables > 4 {
            // Each table lookup is independent; buckets are cloned so the
            // parallel workers share no mutable state.
            let candidate_sets: Vec<Vec<usize>> = (0..self.num_tables)
                .into_par_iter()
                .map(|table_idx| {
                    let hash = self.hash_vector(query_vector, table_idx);
                    if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                        bucket.clone()
                    } else {
                        Vec::new()
                    }
                })
                .collect();

            // Merge candidate sets sequentially (avoiding race conditions on HashSet)
            for candidate_set in candidate_sets {
                for idx in candidate_set {
                    candidates.insert(idx);
                    if candidates.len() >= max_candidates {
                        break;
                    }
                }
                if candidates.len() >= max_candidates {
                    break;
                }
            }
        } else {
            // Sequential collection for smaller numbers of tables
            for table_idx in 0..self.num_tables {
                if candidates.len() >= max_candidates {
                    break;
                }

                let hash = self.hash_vector(query_vector, table_idx);
                if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                    for &idx in bucket {
                        candidates.insert(idx);
                        if candidates.len() >= max_candidates {
                            break;
                        }
                    }
                }
            }
        }

        // Buckets were too sparse: pad with the first stored indices so we
        // can still return roughly k results. NOTE(review): this padding is
        // deterministic, so recall is biased toward early insertions.
        if candidates.len() < k * 3 && self.stored_vectors.len() > k * 3 {
            // Instead of random sampling, just take the first few indices
            let needed = k * 5;
            for idx in 0..needed.min(self.stored_vectors.len()) {
                candidates.insert(idx);
                if candidates.len() >= needed {
                    break;
                }
            }
        }

        // Calculate similarities for candidates in parallel for large candidate sets
        let mut results = if candidates.len() > 50 {
            candidates
                .par_iter()
                .map(|&idx| {
                    let stored_vector = &self.stored_vectors[idx];
                    let similarity = cosine_similarity(query_vector, stored_vector);
                    (stored_vector.clone(), self.stored_data[idx], similarity)
                })
                .collect()
        } else {
            // Sequential for smaller candidate sets
            let mut results = Vec::with_capacity(candidates.len());
            for &idx in &candidates {
                let stored_vector = &self.stored_vectors[idx];
                let similarity = cosine_similarity(query_vector, stored_vector);
                results.push((stored_vector.clone(), self.stored_data[idx], similarity));
            }
            results
        };

        // Sort by similarity (descending) and return top k.
        // NaN similarities compare as Equal, so they sort arbitrarily
        // instead of panicking.
        if results.len() > 100 {
            results.par_sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        } else {
            results.sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        }
        results.truncate(k);

        results
    }
212
213    /// Hash a vector using hyperplanes for a specific table
214    fn hash_vector(&self, vector: &Array1<f32>, table_idx: usize) -> Vec<bool> {
215        let hyperplanes = &self.hyperplanes[table_idx];
216
217        // Use ndarray's optimized matrix-vector multiplication
218        let dot_products = hyperplanes.dot(vector);
219
220        // Convert to hash bits in one pass
221        dot_products.iter().map(|&x| x >= 0.0).collect()
222    }
223
224    /// Get statistics about the index
225    pub fn stats(&self) -> LSHStats {
226        let mut bucket_sizes = Vec::new();
227        let mut total_buckets = 0;
228        let mut non_empty_buckets = 0;
229
230        for table in &self.hash_tables {
231            // Use checked_shl to avoid overflow - if hash_size is too large, use max value
232            let buckets_per_table = 1_usize
233                .checked_shl(self.hash_size as u32)
234                .unwrap_or(usize::MAX);
235            total_buckets += buckets_per_table;
236            non_empty_buckets += table.len();
237
238            for bucket in table.values() {
239                bucket_sizes.push(bucket.len());
240            }
241        }
242
243        bucket_sizes.sort();
244        let median_bucket_size = if bucket_sizes.is_empty() {
245            0.0
246        } else {
247            bucket_sizes[bucket_sizes.len() / 2] as f32
248        };
249
250        let avg_bucket_size = if bucket_sizes.is_empty() {
251            0.0
252        } else {
253            bucket_sizes.iter().sum::<usize>() as f32 / bucket_sizes.len() as f32
254        };
255
256        LSHStats {
257            num_vectors: self.stored_vectors.len(),
258            num_tables: self.num_tables,
259            hash_size: self.hash_size,
260            total_buckets,
261            non_empty_buckets,
262            avg_bucket_size,
263            median_bucket_size,
264            max_bucket_size: bucket_sizes.last().copied().unwrap_or(0),
265        }
266    }
267
268    /// Save LSH configuration and hash functions to database
269    pub fn save_to_database(
270        &self,
271        db: &crate::persistence::Database,
272    ) -> Result<(), Box<dyn std::error::Error>> {
273        use crate::persistence::{LSHHashFunction, LSHTableData};
274
275        // Convert hyperplanes to serializable format
276        let mut hash_functions = Vec::new();
277        for hyperplane_matrix in &self.hyperplanes {
278            for row in hyperplane_matrix.rows() {
279                hash_functions.push(LSHHashFunction {
280                    random_vector: row.to_vec().iter().map(|&x| x as f64).collect(),
281                    threshold: 0.0, // We use zero threshold for hyperplane hashing
282                });
283            }
284        }
285
286        let config = LSHTableData {
287            num_tables: self.num_tables,
288            num_hash_functions: self.hash_size,
289            vector_dim: self.vector_dim,
290            hash_functions,
291        };
292
293        db.save_lsh_config(&config)?;
294
295        // Clear existing bucket data and save new buckets
296        db.clear_lsh_buckets()?;
297
298        // Save hash bucket assignments (this maps positions to buckets)
299        for (table_idx, table) in self.hash_tables.iter().enumerate() {
300            for (hash_bits, indices) in table {
301                let hash_string = hash_bits
302                    .iter()
303                    .map(|&b| if b { '1' } else { '0' })
304                    .collect::<String>();
305
306                for &position_idx in indices {
307                    db.save_lsh_bucket(table_idx, &hash_string, position_idx as i64)?;
308                }
309            }
310        }
311
312        Ok(())
313    }
314
315    /// Load LSH configuration from database and rebuild hash tables
316    pub fn load_from_database(
317        db: &crate::persistence::Database,
318        positions: &[(Array1<f32>, f32)],
319    ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
320        let config = match db.load_lsh_config()? {
321            Some(config) => config,
322            None => return Ok(None),
323        };
324
325        // Reconstruct hyperplanes from saved hash functions
326        let mut hyperplanes = Vec::new();
327        let functions_per_table = config.num_hash_functions;
328
329        for table_idx in 0..config.num_tables {
330            let start_idx = table_idx * functions_per_table;
331            let end_idx = start_idx + functions_per_table;
332
333            if end_idx <= config.hash_functions.len() {
334                let mut table_hyperplanes = Array2::zeros((functions_per_table, config.vector_dim));
335
336                for (func_idx, hash_func) in
337                    config.hash_functions[start_idx..end_idx].iter().enumerate()
338                {
339                    for (dim_idx, &value) in hash_func.random_vector.iter().enumerate() {
340                        if dim_idx < config.vector_dim {
341                            table_hyperplanes[[func_idx, dim_idx]] = value as f32;
342                        }
343                    }
344                }
345
346                hyperplanes.push(table_hyperplanes);
347            }
348        }
349
350        // Create LSH with loaded configuration
351        let mut lsh = Self {
352            num_tables: config.num_tables,
353            hash_size: config.num_hash_functions,
354            vector_dim: config.vector_dim,
355            hyperplanes,
356            hash_tables: vec![HashMap::with_capacity(positions.len().max(100)); config.num_tables],
357            stored_vectors: Vec::new(),
358            stored_data: Vec::new(),
359        };
360
361        // Rebuild the index with provided positions
362        for (vector, evaluation) in positions {
363            lsh.add_vector(vector.clone(), *evaluation);
364        }
365
366        Ok(Some(lsh))
367    }
368
369    /// Create LSH from database or return None if no saved configuration exists
370    pub fn from_database_or_new(
371        db: &crate::persistence::Database,
372        positions: &[(Array1<f32>, f32)],
373        vector_dim: usize,
374        num_tables: usize,
375        hash_size: usize,
376    ) -> Result<Self, Box<dyn std::error::Error>> {
377        match Self::load_from_database(db, positions)? {
378            Some(lsh) => {
379                println!(
380                    "Loaded LSH configuration from database with {} vectors",
381                    lsh.stored_vectors.len()
382                );
383                Ok(lsh)
384            }
385            None => {
386                println!("No saved LSH configuration found, creating new LSH index");
387                let mut lsh =
388                    Self::with_expected_size(vector_dim, num_tables, hash_size, positions.len());
389                for (vector, evaluation) in positions {
390                    lsh.add_vector(vector.clone(), *evaluation);
391                }
392                Ok(lsh)
393            }
394        }
395    }
396}
397
/// LSH performance statistics
///
/// Snapshot of index shape and bucket occupancy, produced by `LSH::stats`.
#[derive(Debug)]
pub struct LSHStats {
    /// Number of vectors currently stored in the index.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Number of hash functions (bits) per table.
    pub hash_size: usize,
    /// Theoretical bucket count across all tables (2^hash_size per table).
    pub total_buckets: usize,
    /// Number of buckets that actually contain at least one vector.
    pub non_empty_buckets: usize,
    /// Mean size of the non-empty buckets.
    pub avg_bucket_size: f32,
    /// Median size of the non-empty buckets.
    pub median_bucket_size: f32,
    /// Largest bucket size observed.
    pub max_bucket_size: usize,
}
410
/// Calculate cosine similarity between two vectors (SIMD optimized)
///
/// Thin wrapper over `SimdVectorOps::cosine_similarity`, kept as a free
/// function so call sites in this module stay short.
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    SimdVectorOps::cosine_similarity(a, b)
}
415
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_lsh_creation() {
        // Constructor must record the requested geometry.
        let index = LSH::new(128, 4, 8);
        assert_eq!(index.num_tables, 4);
        assert_eq!(index.hash_size, 8);
        assert_eq!(index.vector_dim, 128);
    }

    #[test]
    fn test_lsh_add_and_query() {
        let mut index = LSH::new(4, 2, 4);

        // Two orthogonal basis vectors plus one close to the first.
        let basis_x = Array1::from(vec![1.0, 0.0, 0.0, 0.0]);
        let basis_y = Array1::from(vec![0.0, 1.0, 0.0, 0.0]);
        let near_x = Array1::from(vec![1.0, 0.1, 0.0, 0.0]);

        index.add_vector(basis_x.clone(), 1.0);
        index.add_vector(basis_y, 2.0);
        index.add_vector(near_x, 1.1);

        // Querying with an indexed vector should surface similar vectors.
        let neighbors = index.query(&basis_x, 2);
        assert!(!neighbors.is_empty());

        // Best match is basis_x itself or near_x, so similarity is high.
        assert!(neighbors[0].2 > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let unit_x = Array1::from(vec![1.0, 0.0, 0.0]);
        let unit_x_copy = Array1::from(vec![1.0, 0.0, 0.0]);
        let unit_y = Array1::from(vec![0.0, 1.0, 0.0]);

        // Identical vectors → 1.0; orthogonal vectors → 0.0.
        assert!((cosine_similarity(&unit_x, &unit_x_copy) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&unit_x, &unit_y) - 0.0).abs() < 1e-6);
    }
}