chess_vector_engine/
ann.rs

1use ndarray::Array1;
2use std::cmp::Ordering;
3
4/// Approximate Nearest Neighbor search using multiple strategies
5pub struct ANNIndex {
6    /// All stored vectors
7    vectors: Vec<Array1<f32>>,
8    /// Associated data (evaluations)
9    data: Vec<f32>,
10    /// LSH index for fast approximate search
11    lsh: Option<crate::lsh::LSH>,
12    /// Use random projections for dimensionality reduction
13    use_random_projections: bool,
14    /// Random projection matrix (if enabled)
15    projection_matrix: Option<Array2<f32>>,
16    /// Projected dimension
17    projected_dim: usize,
18    /// Original vector dimension
19    vector_dim: usize,
20}
21
22/// Search result with similarity score
23#[derive(Debug, Clone)]
24pub struct ANNResult {
25    pub vector: Array1<f32>,
26    pub data: f32,
27    pub similarity: f32,
28}
29
30impl PartialEq for ANNResult {
31    fn eq(&self, other: &Self) -> bool {
32        self.similarity == other.similarity
33    }
34}
35
36impl Eq for ANNResult {}
37
38impl PartialOrd for ANNResult {
39    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for ANNResult {
45    fn cmp(&self, other: &Self) -> Ordering {
46        // Reverse for max-heap
47        other
48            .similarity
49            .partial_cmp(&self.similarity)
50            .unwrap_or(Ordering::Equal)
51    }
52}
53
54impl ANNIndex {
55    /// Create a new ANN index
56    pub fn new(vector_dim: usize) -> Self {
57        Self {
58            vectors: Vec::new(),
59            data: Vec::new(),
60            lsh: None,
61            use_random_projections: false,
62            projection_matrix: None,
63            projected_dim: vector_dim / 4, // Default to 1/4 of original dimension
64            vector_dim,
65        }
66    }
67
68    /// Enable LSH indexing
69    pub fn with_lsh(mut self, num_tables: usize, hash_size: usize) -> Self {
70        self.lsh = Some(crate::lsh::LSH::new(self.vector_dim, num_tables, hash_size));
71        self
72    }
73
74    /// Enable random projections for dimensionality reduction
75    pub fn with_random_projections(mut self, projected_dim: usize) -> Self {
76        self.use_random_projections = true;
77        self.projected_dim = projected_dim;
78        self
79    }
80
81    /// Add a vector to the index
82    pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
83        // Initialize random projection matrix if needed
84        if self.use_random_projections && self.projection_matrix.is_none() {
85            self.init_random_projections(vector.len());
86        }
87
88        self.vectors.push(vector.clone());
89        self.data.push(data);
90
91        // Add to LSH if enabled
92        if let Some(ref mut lsh) = self.lsh {
93            lsh.add_vector(vector, data);
94        }
95    }
96
97    /// Search for approximate nearest neighbors
98    pub fn search(
99        &self,
100        query: &Array1<f32>,
101        k: usize,
102        strategy: SearchStrategy,
103    ) -> Vec<ANNResult> {
104        match strategy {
105            SearchStrategy::LSH => self.search_lsh(query, k),
106            SearchStrategy::RandomProjection => self.search_random_projection(query, k),
107            SearchStrategy::Hybrid => self.search_hybrid(query, k),
108            SearchStrategy::Exact => self.search_exact(query, k),
109        }
110    }
111
112    /// LSH-based search
113    fn search_lsh(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
114        if let Some(ref lsh) = self.lsh {
115            lsh.query(query, k)
116                .into_iter()
117                .map(|(vec, data, sim)| ANNResult {
118                    vector: vec,
119                    data,
120                    similarity: sim,
121                })
122                .collect()
123        } else {
124            self.search_exact(query, k)
125        }
126    }
127
128    /// Random projection-based search
129    fn search_random_projection(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
130        if !self.use_random_projections || self.projection_matrix.is_none() {
131            return self.search_exact(query, k);
132        }
133
134        let proj_matrix = self.projection_matrix.as_ref().unwrap();
135        let proj_query = self.project_vector(query, proj_matrix);
136
137        let mut results: Vec<_> = self
138            .vectors
139            .iter()
140            .zip(self.data.iter())
141            .map(|(vec, &data)| {
142                let proj_vec = self.project_vector(vec, proj_matrix);
143                let similarity = cosine_similarity(&proj_query, &proj_vec);
144                ANNResult {
145                    vector: vec.clone(),
146                    data,
147                    similarity,
148                }
149            })
150            .collect();
151
152        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
153        results.truncate(k);
154        results
155    }
156
157    /// Hybrid search combining multiple strategies
158    fn search_hybrid(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
159        let mut candidate_indices = std::collections::HashSet::new();
160        let mut results = Vec::new();
161
162        // Get candidates from LSH
163        if let Some(ref lsh) = self.lsh {
164            let lsh_results = lsh.query(query, k * 2);
165            for (vec, _data, _) in lsh_results {
166                // Find the index of this vector in our stored vectors
167                for (idx, stored_vec) in self.vectors.iter().enumerate() {
168                    if vectors_approximately_equal(&vec, stored_vec) {
169                        candidate_indices.insert(idx);
170                        break;
171                    }
172                }
173            }
174        }
175
176        // Get candidates from random projection
177        if self.use_random_projections {
178            let rp_results = self.search_random_projection(query, k * 2);
179            for result in rp_results {
180                // Find the index of this vector
181                for (idx, stored_vec) in self.vectors.iter().enumerate() {
182                    if vectors_approximately_equal(&result.vector, stored_vec) {
183                        candidate_indices.insert(idx);
184                        break;
185                    }
186                }
187            }
188        }
189
190        // If we don't have enough candidates, add some random ones
191        if candidate_indices.len() < k * 3 {
192            for idx in 0..(k * 3).min(self.vectors.len()) {
193                candidate_indices.insert(idx);
194            }
195        }
196
197        // Re-rank candidates using exact similarity
198        for &idx in &candidate_indices {
199            let vec = &self.vectors[idx];
200            let data = self.data[idx];
201            let similarity = cosine_similarity(query, vec);
202            results.push(ANNResult {
203                vector: vec.clone(),
204                data,
205                similarity,
206            });
207        }
208
209        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
210        results.truncate(k);
211        results
212    }
213
214    /// Exact search (brute force)
215    fn search_exact(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
216        let mut results: Vec<_> = self
217            .vectors
218            .iter()
219            .zip(self.data.iter())
220            .map(|(vec, &data)| {
221                let similarity = cosine_similarity(query, vec);
222                ANNResult {
223                    vector: vec.clone(),
224                    data,
225                    similarity,
226                }
227            })
228            .collect();
229
230        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
231        results.truncate(k);
232        results
233    }
234
235    /// Initialize random projection matrix
236    fn init_random_projections(&mut self, input_dim: usize) {
237        use rand::Rng;
238        let mut rng = rand::thread_rng();
239
240        // Use the provided input_dim (should match self.vector_dim)
241        assert_eq!(
242            input_dim, self.vector_dim,
243            "Input dimension should match vector dimension"
244        );
245
246        let mut matrix_data = Vec::with_capacity(self.projected_dim * input_dim);
247        for _ in 0..(self.projected_dim * input_dim) {
248            matrix_data.push(rng.gen_range(-1.0..1.0));
249        }
250
251        self.projection_matrix = Some(
252            Array2::from_shape_vec((self.projected_dim, input_dim), matrix_data)
253                .expect("Failed to create projection matrix"),
254        );
255    }
256
257    /// Project a vector to lower dimension
258    fn project_vector(&self, vector: &Array1<f32>, proj_matrix: &Array2<f32>) -> Array1<f32> {
259        let mut result = Array1::zeros(self.projected_dim);
260        for i in 0..self.projected_dim {
261            let dot_product: f32 = vector
262                .iter()
263                .zip(proj_matrix.row(i).iter())
264                .map(|(v, p)| v * p)
265                .sum();
266            result[i] = dot_product;
267        }
268        result
269    }
270
271    /// Get statistics about the index
272    pub fn stats(&self) -> ANNStats {
273        ANNStats {
274            num_vectors: self.vectors.len(),
275            vector_dim: if self.vectors.is_empty() {
276                0
277            } else {
278                self.vectors[0].len()
279            },
280            has_lsh: self.lsh.is_some(),
281            has_random_projections: self.use_random_projections,
282            projected_dim: if self.use_random_projections {
283                Some(self.projected_dim)
284            } else {
285                None
286            },
287        }
288    }
289}
290
/// Search strategies for ANN.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare strategies or use
/// them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchStrategy {
    /// Query the index's LSH structure; exact search if LSH is not enabled.
    LSH,
    /// Rank by cosine similarity in the randomly projected space; exact
    /// search if projections are disabled or not yet initialised.
    RandomProjection,
    /// Union LSH and random-projection candidates (padded with the first
    /// stored vectors when there are too few) and re-rank them by exact
    /// cosine similarity.
    Hybrid,
    /// Brute-force cosine similarity against every stored vector
    /// (useful as a baseline for comparison).
    Exact,
}
303
/// Statistics about the ANN index, as produced by `ANNIndex::stats`.
///
/// Derives `Clone`/`PartialEq`/`Eq` so snapshots can be kept and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ANNStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Length of the first stored vector, or `0` when the index is empty.
    pub vector_dim: usize,
    /// Whether LSH indexing is enabled.
    pub has_lsh: bool,
    /// Whether random projections are enabled.
    pub has_random_projections: bool,
    /// Target projection dimension; `Some` only when random projections are enabled.
    pub projected_dim: Option<usize>,
}
313
314/// Calculate cosine similarity between two vectors
315fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
316    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
317    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
318    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
319
320    if norm_a == 0.0 || norm_b == 0.0 {
321        0.0
322    } else {
323        dot_product / (norm_a * norm_b)
324    }
325}
326
327/// Check if two vectors are approximately equal
328fn vectors_approximately_equal(a: &Array1<f32>, b: &Array1<f32>) -> bool {
329    if a.len() != b.len() {
330        return false;
331    }
332
333    let threshold = 1e-6;
334    for (x, y) in a.iter().zip(b.iter()) {
335        if (x - y).abs() > threshold {
336            return false;
337        }
338    }
339    true
340}
341
342use ndarray::Array2;
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Shorthand for a four-component test vector.
    fn vec4(components: [f32; 4]) -> Array1<f32> {
        Array1::from(components.to_vec())
    }

    #[test]
    fn test_ann_index_creation() {
        let empty = ANNIndex::new(128);
        assert!(empty.vectors.is_empty());
        assert!(!empty.use_random_projections);
        assert!(empty.lsh.is_none());
    }

    #[test]
    fn test_ann_with_lsh() {
        assert!(ANNIndex::new(128).with_lsh(4, 8).lsh.is_some());
    }

    #[test]
    fn test_ann_with_random_projections() {
        let projected = ANNIndex::new(128).with_random_projections(32);
        assert!(projected.use_random_projections);
        assert_eq!(projected.projected_dim, 32);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = ANNIndex::new(4);
        let target = vec4([1.0, 0.0, 0.0, 0.0]);

        index.add_vector(target.clone(), 1.0);
        index.add_vector(vec4([0.0, 1.0, 0.0, 0.0]), 2.0);
        index.add_vector(vec4([1.0, 0.1, 0.0, 0.0]), 1.1);

        let hits = index.search(&target, 2, SearchStrategy::Exact);
        assert_eq!(hits.len(), 2);
        // The query itself is stored, so the top hit should be near-identical.
        assert!(hits[0].similarity > 0.9);
    }

    #[test]
    fn test_search_strategies() {
        let mut index = ANNIndex::new(4).with_lsh(2, 4).with_random_projections(2);
        let query = vec4([1.0, 0.0, 0.0, 0.0]);
        index.add_vector(query.clone(), 1.0);

        // Every strategy must return at least the single stored vector.
        let strategies = [
            SearchStrategy::Exact,
            SearchStrategy::LSH,
            SearchStrategy::RandomProjection,
            SearchStrategy::Hybrid,
        ];
        for &strategy in &strategies {
            let hits = index.search(&query, 1, strategy);
            assert!(!hits.is_empty(), "{:?} returned no results", strategy);
        }
    }
}