oxify_vector/
lsh.rs

1//! Locality Sensitive Hashing (LSH) for Approximate Nearest Neighbor Search
2//!
3//! LSH is an ANN algorithm that uses hash functions to map similar vectors
4//! to the same buckets with high probability. This implementation uses
5//! random projection LSH for cosine similarity.
6//!
7//! ## Features
8//!
9//! - **Random Projection LSH**: Hash functions based on random hyperplanes
10//! - **Multi-table Hashing**: Multiple hash tables for better recall
11//! - **Multi-probe Search**: Query nearby buckets to improve accuracy
12//! - **Configurable Parameters**: num_tables, num_bits, num_probes
13//!
14//! ## Algorithm
15//!
16//! 1. Generate random projection vectors (hyperplanes)
17//! 2. For each vector, compute hash by checking which side of hyperplanes it's on
18//! 3. Store vectors in hash buckets
19//! 4. At query time, hash the query and retrieve candidates from matching buckets
20//! 5. Optionally probe nearby buckets (flip hash bits) for better recall
21//! 6. Rank candidates by actual similarity
22//!
23//! ## Example
24//!
25//! ```rust
26//! use oxify_vector::lsh::{LshIndex, LshConfig};
27//! use std::collections::HashMap;
28//!
29//! # fn example() -> anyhow::Result<()> {
30//! // Create embeddings
31//! let mut embeddings = HashMap::new();
32//! for i in 0..1000 {
33//!     let vec = vec![i as f32 * 0.01, (i * 2) as f32 * 0.01, (i * 3) as f32 * 0.01];
34//!     embeddings.insert(format!("doc{}", i), vec);
35//! }
36//!
37//! // Build LSH index
38//! let config = LshConfig::default();
39//! let mut index = LshIndex::new(config);
40//! index.build(&embeddings)?;
41//!
42//! // Search
43//! let query = vec![0.5, 1.0, 1.5];
44//! let results = index.search(&query, 10)?;
45//! # Ok(())
46//! # }
47//! ```
48
49use anyhow::{anyhow, Result};
50use rand::rngs::StdRng;
51use rand::{Rng, SeedableRng};
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56use crate::simd::cosine_similarity_simd;
57use crate::types::SearchResult;
58
59/// LSH configuration
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct LshConfig {
62    /// Number of hash tables (more tables = better recall, more memory)
63    pub num_tables: usize,
64    /// Number of bits per hash (hash length)
65    pub num_bits: usize,
66    /// Number of probes (how many nearby buckets to check)
67    pub num_probes: usize,
68    /// Random seed for reproducibility
69    pub seed: u64,
70}
71
72impl Default for LshConfig {
73    fn default() -> Self {
74        Self {
75            num_tables: 10,
76            num_bits: 16,
77            num_probes: 3,
78            seed: 42,
79        }
80    }
81}
82
83impl LshConfig {
84    /// Create config optimized for high recall
85    pub fn high_recall() -> Self {
86        Self {
87            num_tables: 20,
88            num_bits: 20,
89            num_probes: 10,
90            seed: 42,
91        }
92    }
93
94    /// Create config optimized for speed (low memory/probes)
95    pub fn fast() -> Self {
96        Self {
97            num_tables: 5,
98            num_bits: 12,
99            num_probes: 1,
100            seed: 42,
101        }
102    }
103
104    /// Create config optimized for memory efficiency
105    pub fn memory_efficient() -> Self {
106        Self {
107            num_tables: 5,
108            num_bits: 10,
109            num_probes: 5,
110            seed: 42,
111        }
112    }
113}
114
/// Bucket key of a hash table: one bit per hyperplane, packed into a `u64`.
type HashValue = u64;
117
118/// A single hash table
119#[derive(Debug, Clone)]
120struct HashTable {
121    /// Random projection vectors (hyperplanes)
122    projections: Vec<Vec<f32>>,
123    /// Buckets: hash_value -> list of vector indices
124    buckets: HashMap<HashValue, Vec<usize>>,
125}
126
127impl HashTable {
128    fn new(num_bits: usize, dimensions: usize, rng: &mut impl Rng) -> Self {
129        // Generate random projection vectors
130        let projections: Vec<Vec<f32>> = (0..num_bits)
131            .map(|_| {
132                (0..dimensions)
133                    .map(|_| rng.random_range(-1.0..1.0))
134                    .collect()
135            })
136            .collect();
137
138        Self {
139            projections,
140            buckets: HashMap::new(),
141        }
142    }
143
144    /// Compute hash value for a vector
145    fn hash(&self, vector: &[f32]) -> HashValue {
146        let mut hash_val: HashValue = 0;
147
148        for (i, projection) in self.projections.iter().enumerate() {
149            // Dot product with projection vector
150            let dot: f32 = vector
151                .iter()
152                .zip(projection.iter())
153                .map(|(v, p)| v * p)
154                .sum();
155
156            // Set bit if dot product is positive
157            if dot > 0.0 {
158                hash_val |= 1u64 << i;
159            }
160        }
161
162        hash_val
163    }
164
165    /// Insert a vector index into the hash table
166    fn insert(&mut self, vector: &[f32], index: usize) {
167        let hash_val = self.hash(vector);
168        self.buckets.entry(hash_val).or_default().push(index);
169    }
170
171    /// Query the hash table and return candidate indices
172    fn query(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
173        let hash_val = self.hash(vector);
174        let mut candidates = Vec::new();
175
176        // Get exact matches
177        if let Some(bucket) = self.buckets.get(&hash_val) {
178            candidates.extend(bucket);
179        }
180
181        // Multi-probe: flip bits to probe nearby buckets
182        if num_probes > 1 {
183            for probe in 1..num_probes.min(self.projections.len()) {
184                // Flip the probe-th bit
185                let flipped_hash = hash_val ^ (1u64 << probe);
186                if let Some(bucket) = self.buckets.get(&flipped_hash) {
187                    candidates.extend(bucket);
188                }
189            }
190        }
191
192        candidates
193    }
194}
195
196/// LSH Index for approximate nearest neighbor search
197#[derive(Debug, Clone)]
198pub struct LshIndex {
199    config: LshConfig,
200    tables: Vec<HashTable>,
201    vectors: Vec<Vec<f32>>,
202    entity_ids: Vec<String>,
203    dimensions: usize,
204    is_built: bool,
205}
206
207impl LshIndex {
208    /// Create a new LSH index
209    pub fn new(config: LshConfig) -> Self {
210        info!(
211            "Initialized LSH index: num_tables={}, num_bits={}, num_probes={}",
212            config.num_tables, config.num_bits, config.num_probes
213        );
214
215        Self {
216            config,
217            tables: Vec::new(),
218            vectors: Vec::new(),
219            entity_ids: Vec::new(),
220            dimensions: 0,
221            is_built: false,
222        }
223    }
224
225    /// Build LSH index from embeddings
226    pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
227        if embeddings.is_empty() {
228            return Err(anyhow!("Cannot build index from empty embeddings"));
229        }
230
231        info!("Building LSH index for {} entities", embeddings.len());
232
233        // Get dimensions from first vector
234        self.dimensions = embeddings.values().next().unwrap().len();
235
236        // Validate all vectors have same dimension
237        for (id, vec) in embeddings {
238            if vec.len() != self.dimensions {
239                return Err(anyhow!(
240                    "Dimension mismatch for entity {}: expected {}, got {}",
241                    id,
242                    self.dimensions,
243                    vec.len()
244                ));
245            }
246        }
247
248        // Store vectors and entity IDs
249        self.vectors.clear();
250        self.entity_ids.clear();
251        for (id, vec) in embeddings {
252            self.vectors.push(vec.clone());
253            self.entity_ids.push(id.clone());
254        }
255
256        // Initialize random number generator with seed
257        let mut rng = StdRng::seed_from_u64(self.config.seed);
258
259        // Create hash tables
260        self.tables.clear();
261        for table_idx in 0..self.config.num_tables {
262            debug!(
263                "Building hash table {}/{}",
264                table_idx + 1,
265                self.config.num_tables
266            );
267
268            let mut table = HashTable::new(self.config.num_bits, self.dimensions, &mut rng);
269
270            // Insert all vectors into this table
271            for (idx, vector) in self.vectors.iter().enumerate() {
272                table.insert(vector, idx);
273            }
274
275            self.tables.push(table);
276        }
277
278        self.is_built = true;
279        info!("LSH index built successfully");
280        Ok(())
281    }
282
283    /// Search for k nearest neighbors
284    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
285        if !self.is_built {
286            return Err(anyhow!("Index not built. Call build() first"));
287        }
288
289        if query.len() != self.dimensions {
290            return Err(anyhow!(
291                "Query dimension mismatch: expected {}, got {}",
292                self.dimensions,
293                query.len()
294            ));
295        }
296
297        debug!("LSH search for k={}", k);
298
299        // Collect candidate indices from all tables
300        let mut candidate_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
301
302        for table in &self.tables {
303            let candidates = table.query(query, self.config.num_probes);
304            candidate_set.extend(candidates);
305        }
306
307        debug!("Found {} unique candidates", candidate_set.len());
308
309        // Compute actual similarities for candidates
310        let mut scored_candidates: Vec<(usize, f32)> = candidate_set
311            .into_iter()
312            .map(|idx| {
313                let similarity = cosine_similarity_simd(query, &self.vectors[idx]);
314                (idx, similarity)
315            })
316            .collect();
317
318        // Sort by similarity (descending)
319        scored_candidates
320            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
321
322        // Take top k and convert to SearchResult
323        let results: Vec<SearchResult> = scored_candidates
324            .into_iter()
325            .take(k)
326            .enumerate()
327            .map(|(rank, (idx, score))| SearchResult {
328                entity_id: self.entity_ids[idx].clone(),
329                score,
330                distance: 1.0 - score,
331                rank: rank + 1,
332            })
333            .collect();
334
335        debug!("Returning {} results", results.len());
336        Ok(results)
337    }
338
339    /// Get number of vectors in index
340    pub fn len(&self) -> usize {
341        self.vectors.len()
342    }
343
344    /// Check if index is empty
345    pub fn is_empty(&self) -> bool {
346        self.vectors.is_empty()
347    }
348
349    /// Get index statistics
350    pub fn stats(&self) -> LshStats {
351        let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
352        let avg_bucket_size: f32 = if total_buckets > 0 {
353            let total_entries: usize = self
354                .tables
355                .iter()
356                .flat_map(|t| t.buckets.values())
357                .map(|b| b.len())
358                .sum();
359            total_entries as f32 / total_buckets as f32
360        } else {
361            0.0
362        };
363
364        let max_bucket_size: usize = self
365            .tables
366            .iter()
367            .flat_map(|t| t.buckets.values())
368            .map(|b| b.len())
369            .max()
370            .unwrap_or(0);
371
372        LshStats {
373            num_vectors: self.vectors.len(),
374            num_tables: self.tables.len(),
375            num_bits: self.config.num_bits,
376            total_buckets,
377            avg_bucket_size,
378            max_bucket_size,
379            dimensions: self.dimensions,
380        }
381    }
382}
383
384/// LSH index statistics
385#[derive(Debug, Clone, Serialize, Deserialize)]
386pub struct LshStats {
387    /// Number of vectors in index
388    pub num_vectors: usize,
389    /// Number of hash tables
390    pub num_tables: usize,
391    /// Number of bits per hash
392    pub num_bits: usize,
393    /// Total number of buckets across all tables
394    pub total_buckets: usize,
395    /// Average bucket size
396    pub avg_bucket_size: f32,
397    /// Maximum bucket size
398    pub max_bucket_size: usize,
399    /// Vector dimensions
400    pub dimensions: usize,
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic embeddings of length `dims`, keyed `doc0`, `doc1`, ...
    fn create_test_embeddings(n: usize, dims: usize) -> HashMap<String, Vec<f32>> {
        (0..n)
            .map(|i| {
                let vec: Vec<f32> = (0..dims).map(|d| ((i * d) as f32 * 0.01).sin()).collect();
                (format!("doc{}", i), vec)
            })
            .collect()
    }

    #[test]
    fn test_lsh_build() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.is_empty());
        assert!(index.build(&embeddings).is_ok());
        assert_eq!(index.len(), 100);
        assert!(!index.is_empty());
        assert!(index.is_built);
    }

    #[test]
    fn test_lsh_search() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        index.build(&embeddings).unwrap();

        // Identical to the embedding of "doc1", so its own bucket is never empty.
        let query: Vec<f32> = (0..64).map(|d| (d as f32 * 0.01).sin()).collect();
        let results = index.search(&query, 10).unwrap();

        // LSH may not always find exactly k results due to its probabilistic nature
        assert!(!results.is_empty());
        assert!(results.len() <= 10);

        // No adjacent pair may be out of order (checks the whole list, not just the ends)
        assert!(results.windows(2).all(|w| {
            w[0].score.partial_cmp(&w[1].score) != Some(std::cmp::Ordering::Less)
        }));

        // Ranks are 1-based and sequential
        assert!(results.iter().enumerate().all(|(i, r)| r.rank == i + 1));
    }

    #[test]
    fn test_lsh_search_zero_k() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        index.build(&embeddings).unwrap();

        let query: Vec<f32> = (0..64).map(|d| (d as f32 * 0.01).sin()).collect();
        assert!(index.search(&query, 0).unwrap().is_empty());
    }

    #[test]
    fn test_lsh_empty_embeddings() {
        let embeddings = HashMap::new();
        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.build(&embeddings).is_err());
    }

    #[test]
    fn test_lsh_dimension_mismatch() {
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![1.0, 2.0, 3.0]);
        embeddings.insert("doc2".to_string(), vec![1.0, 2.0]); // Wrong dimension

        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.build(&embeddings).is_err());
    }

    #[test]
    fn test_lsh_search_before_build() {
        let index = LshIndex::new(LshConfig::default());
        let query = vec![1.0, 2.0, 3.0];
        assert!(index.search(&query, 10).is_err());
    }

    #[test]
    fn test_lsh_query_dimension_mismatch() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        index.build(&embeddings).unwrap();

        let wrong_query = vec![1.0, 2.0]; // Wrong dimension
        assert!(index.search(&wrong_query, 10).is_err());
    }

    #[test]
    fn test_lsh_stats() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        index.build(&embeddings).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 100);
        assert_eq!(stats.num_tables, 10);
        assert_eq!(stats.dimensions, 64);
        assert!(stats.total_buckets > 0);
        assert!(stats.avg_bucket_size > 0.0);
        // The largest bucket can be neither empty nor larger than the data set
        assert!(stats.max_bucket_size >= 1);
        assert!(stats.max_bucket_size <= 100);
    }

    #[test]
    fn test_lsh_config_presets() {
        let high_recall = LshConfig::high_recall();
        assert_eq!(high_recall.num_tables, 20);
        assert_eq!(high_recall.num_probes, 10);

        let fast = LshConfig::fast();
        assert_eq!(fast.num_tables, 5);
        assert_eq!(fast.num_probes, 1);

        let memory = LshConfig::memory_efficient();
        assert_eq!(memory.num_tables, 5);
        assert_eq!(memory.num_bits, 10);
    }

    #[test]
    fn test_hash_table_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let table = HashTable::new(8, 3, &mut rng);

        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![1.0, 2.0, 3.0];
        let vec3 = vec![-1.0, -2.0, -3.0];

        // Same vectors should have same hash
        assert_eq!(table.hash(&vec1), table.hash(&vec2));

        // Negating a vector negates every dot product, so the hashes can only coincide
        // if all eight dot products are exactly zero.
        let hash1 = table.hash(&vec1);
        let hash3 = table.hash(&vec3);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_multiprobe_increases_candidates() {
        let embeddings = create_test_embeddings(50, 32);

        // Build with 1 probe
        let config_1probe = LshConfig {
            num_tables: 5,
            num_bits: 10,
            num_probes: 1,
            seed: 42,
        };
        let mut index_1probe = LshIndex::new(config_1probe);
        index_1probe.build(&embeddings).unwrap();

        // Build with 5 probes
        let config_5probe = LshConfig {
            num_tables: 5,
            num_bits: 10,
            num_probes: 5,
            seed: 42,
        };
        let mut index_5probe = LshIndex::new(config_5probe);
        index_5probe.build(&embeddings).unwrap();

        let query: Vec<f32> = (0..32).map(|d| (d as f32 * 0.02).cos()).collect();

        let results_1probe = index_1probe.search(&query, 20).unwrap();
        let results_5probe = index_5probe.search(&query, 20).unwrap();

        // Both indexes share seed and data, and the 5-probe search visits every bucket
        // the 1-probe search visits, so it can never return fewer results.
        assert!(results_5probe.len() >= results_1probe.len());
    }
}