oxirs_vec/
lsh.rs

//! Locality Sensitive Hashing (LSH) for approximate nearest neighbor search.
//!
//! This module implements the following LSH families, plus multi-probe querying:
//! - Random projection LSH for cosine similarity
//! - MinHash for Jaccard similarity
//! - SimHash (64-bit random-hyperplane signatures)
//! - P-stable LSH for Lp distances
//! - Multi-probe lookups (probing neighbouring buckets) for improved recall
8
9use crate::{Vector, VectorIndex};
10use anyhow::{anyhow, Result};
11use crate::random_utils::NormalSampler as Normal;
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::{HashMap, HashSet};
15
16/// Configuration for LSH index
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LshConfig {
19    /// Number of hash tables (L parameter)
20    pub num_tables: usize,
21    /// Number of hash functions per table (K parameter)
22    pub num_hash_functions: usize,
23    /// LSH family to use
24    pub lsh_family: LshFamily,
25    /// Random seed for reproducibility
26    pub seed: u64,
27    /// Enable multi-probe LSH
28    pub multi_probe: bool,
29    /// Number of probes for multi-probe LSH
30    pub num_probes: usize,
31}
32
33impl Default for LshConfig {
34    fn default() -> Self {
35        Self {
36            num_tables: 10,
37            num_hash_functions: 8,
38            lsh_family: LshFamily::RandomProjection,
39            seed: 42,
40            multi_probe: true,
41            num_probes: 3,
42        }
43    }
44}
45
46/// LSH family types
47#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
48pub enum LshFamily {
49    /// Random projection for cosine similarity
50    RandomProjection,
51    /// MinHash for Jaccard similarity
52    MinHash,
53    /// SimHash for binary similarity
54    SimHash,
55    /// P-stable distributions for Lp distance
56    PStable(f32), // p value
57}
58
/// Common interface of the LSH hash families stored in an `LshTable`.
trait HashFunction: Send + Sync {
    /// Maps `vector` to the bucket key it belongs to.
    fn hash(&self, vector: &[f32]) -> u64;

    /// Calls `hash(vector)` `num_hashes` times and collects the results.
    ///
    /// NOTE(review): `hash` receives no per-call index, so for the deterministic hash
    /// functions in this module every element is the same value.
    fn hash_multi(&self, vector: &[f32], num_hashes: usize) -> Vec<u64> {
        std::iter::repeat_with(|| self.hash(vector))
            .take(num_hashes)
            .collect()
    }
}
69
/// Random-projection hash for cosine similarity: one signature bit per projection.
struct RandomProjectionHash {
    /// Random projection vectors, one per signature bit.
    projections: Vec<Vec<f32>>,
    /// Length of each projection vector.
    dimensions: usize,
}
75
76impl RandomProjectionHash {
77    fn new(dimensions: usize, num_projections: usize, seed: u64) -> Self {
78        let mut rng = Random::seed(seed);
79        let normal = Normal::new(0.0, 1.0).unwrap();
80
81        let mut projections = Vec::with_capacity(num_projections);
82        for _ in 0..num_projections {
83            let projection: Vec<f32> = (0..dimensions).map(|_| normal.sample(&mut rng)).collect();
84            projections.push(projection);
85        }
86
87        Self {
88            projections,
89            dimensions,
90        }
91    }
92}
93
94impl HashFunction for RandomProjectionHash {
95    fn hash(&self, vector: &[f32]) -> u64 {
96        let mut hash_value = 0u64;
97
98        for (i, projection) in self.projections.iter().enumerate() {
99            // Compute dot product
100            use oxirs_core::simd::SimdOps;
101            let dot_product = f32::dot(vector, projection);
102
103            // Set bit if positive
104            if dot_product > 0.0 {
105                hash_value |= 1 << (i % 64);
106            }
107        }
108
109        hash_value
110    }
111}
112
/// MinHash for Jaccard similarity, built from universal hashes `(a * x + b) mod prime`.
struct MinHashFunction {
    /// Multipliers, one per hash function.
    a: Vec<u64>,
    /// Additive offsets, one per hash function (same length as `a`).
    b: Vec<u64>,
    /// Modulus shared by every hash function.
    prime: u64,
}
119
120impl MinHashFunction {
121    fn new(num_hashes: usize, seed: u64) -> Self {
122        let mut rng = Random::seed(seed);
123        let prime = 4294967311u64; // Large prime
124
125        let a: Vec<u64> = (0..num_hashes).map(|_| rng.gen_range(1..prime)).collect();
126        let b: Vec<u64> = (0..num_hashes).map(|_| rng.gen_range(0..prime)).collect();
127
128        Self { a, b, prime }
129    }
130
131    fn minhash_signature(&self, set_elements: &[u32]) -> Vec<u64> {
132        let mut signature = vec![u64::MAX; self.a.len()];
133
134        for &element in set_elements {
135            for (i, sig_val) in signature.iter_mut().enumerate().take(self.a.len()) {
136                let hash = (self.a[i] * element as u64 + self.b[i]) % self.prime;
137                *sig_val = (*sig_val).min(hash);
138            }
139        }
140
141        signature
142    }
143}
144
145impl HashFunction for MinHashFunction {
146    fn hash(&self, vector: &[f32]) -> u64 {
147        // Convert vector to set of indices where value > threshold
148        let threshold = 0.0;
149        let set_elements: Vec<u32> = vector
150            .iter()
151            .enumerate()
152            .filter(|&(_, &v)| v > threshold)
153            .map(|(i, _)| i as u32)
154            .collect();
155
156        let signature = self.minhash_signature(&set_elements);
157
158        // Combine signature into single hash
159        let mut hash = 0u64;
160        for (i, &sig) in signature.iter().enumerate() {
161            hash ^= sig.rotate_left((i * 7) as u32);
162        }
163
164        hash
165    }
166}
167
/// SimHash: a 64-bit signature built from the signs of weighted sums against 64 random
/// vectors.
struct SimHashFunction {
    /// One random vector per signature bit.
    random_vectors: Vec<Vec<f32>>,
}
172
173impl SimHashFunction {
174    fn new(dimensions: usize, seed: u64) -> Self {
175        let mut rng = Random::seed(seed);
176        let normal = Normal::new(0.0, 1.0).unwrap();
177
178        let random_vectors: Vec<Vec<f32>> = (0..64)
179            .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
180            .collect();
181
182        Self { random_vectors }
183    }
184}
185
186impl HashFunction for SimHashFunction {
187    fn hash(&self, vector: &[f32]) -> u64 {
188        let mut hash = 0u64;
189
190        for (i, random_vec) in self.random_vectors.iter().enumerate() {
191            // Weighted sum
192            let mut sum = 0.0;
193            for (j, &v) in vector.iter().enumerate() {
194                if j < random_vec.len() {
195                    sum += v * random_vec[j];
196                }
197            }
198
199            if sum > 0.0 {
200                hash |= 1 << i;
201            }
202        }
203
204        hash
205    }
206}
207
/// P-stable LSH for Lp distance: quantised random projections with random offsets.
struct PStableHash {
    /// Random projection vectors, one per signature bit.
    projections: Vec<Vec<f32>>,
    /// Offsets added to each projection before quantisation (one per projection).
    offsets: Vec<f32>,
    /// Bucket width used for quantisation.
    width: f32,
    /// Exponent `p` this hash was built for (used when sampling projections).
    p: f32,
}
215
216impl PStableHash {
217    fn new(dimensions: usize, num_projections: usize, width: f32, p: f32, seed: u64) -> Self {
218        let mut rng = Random::seed(seed);
219
220        // Use Cauchy distribution for L1, Normal for L2
221        let projections: Vec<Vec<f32>> = if (p - 1.0).abs() < 0.1 {
222            // L1 distance - use Cauchy distribution
223            (0..num_projections)
224                .map(|_| {
225                    (0..dimensions)
226                        .map(|_| {
227                            let u: f32 = rng
228                                .gen_range(-std::f32::consts::PI / 2.0..std::f32::consts::PI / 2.0);
229                            u.tan()
230                        })
231                        .collect()
232                })
233                .collect()
234        } else if (p - 2.0).abs() < 0.1 {
235            // L2 distance - use Normal distribution
236            let normal = Normal::new(0.0, 1.0).unwrap();
237            (0..num_projections)
238                .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
239                .collect()
240        } else {
241            // General case - approximate with Normal
242            let normal = Normal::new(0.0, 1.0).unwrap();
243            (0..num_projections)
244                .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
245                .collect()
246        };
247
248        let offsets: Vec<f32> = (0..num_projections)
249            .map(|_| rng.gen_range(0.0..width))
250            .collect();
251
252        Self {
253            projections,
254            offsets,
255            width,
256            p,
257        }
258    }
259}
260
261impl HashFunction for PStableHash {
262    fn hash(&self, vector: &[f32]) -> u64 {
263        let mut hash = 0u64;
264
265        for (i, (projection, &offset)) in self.projections.iter().zip(&self.offsets).enumerate() {
266            use oxirs_core::simd::SimdOps;
267            let dot_product = f32::dot(vector, projection);
268            let bucket = ((dot_product + offset) / self.width).floor() as i32;
269
270            // Map bucket to bit position
271            if bucket > 0 {
272                hash |= 1 << (i % 64);
273            }
274        }
275
276        hash
277    }
278}
279
280/// LSH table storing hash buckets
281struct LshTable {
282    buckets: HashMap<u64, Vec<usize>>,
283    hash_function: Box<dyn HashFunction>,
284}
285
286impl LshTable {
287    fn new(hash_function: Box<dyn HashFunction>) -> Self {
288        Self {
289            buckets: HashMap::new(),
290            hash_function,
291        }
292    }
293
294    fn insert(&mut self, id: usize, vector: &[f32]) {
295        let hash = self.hash_function.hash(vector);
296        self.buckets.entry(hash).or_default().push(id);
297    }
298
299    fn query(&self, vector: &[f32]) -> Vec<usize> {
300        let hash = self.hash_function.hash(vector);
301        self.buckets.get(&hash).cloned().unwrap_or_default()
302    }
303
304    fn query_multi_probe(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
305        let main_hash = self.hash_function.hash(vector);
306        let mut candidates = HashSet::new();
307
308        // Add exact match
309        if let Some(ids) = self.buckets.get(&main_hash) {
310            candidates.extend(ids);
311        }
312
313        // Probe nearby buckets by flipping bits
314        for probe in 1..=num_probes {
315            for bit in 0..64 {
316                let probed_hash = main_hash ^ (1 << bit);
317                if let Some(ids) = self.buckets.get(&probed_hash) {
318                    candidates.extend(ids);
319                }
320
321                // Stop if we have enough probes
322                if candidates.len() >= probe * 10 {
323                    break;
324                }
325            }
326        }
327
328        candidates.into_iter().collect()
329    }
330}
331
332/// LSH index implementation
333pub struct LshIndex {
334    config: LshConfig,
335    tables: Vec<LshTable>,
336    vectors: Vec<(String, Vector)>,
337    uri_to_id: HashMap<String, usize>,
338    dimensions: Option<usize>,
339}
340
341impl LshIndex {
342    /// Create a new LSH index
343    pub fn new(config: LshConfig) -> Self {
344        let tables = Self::create_tables(&config, 0);
345
346        Self {
347            config,
348            tables,
349            vectors: Vec::new(),
350            uri_to_id: HashMap::new(),
351            dimensions: None,
352        }
353    }
354
355    /// Create hash tables based on configuration
356    fn create_tables(config: &LshConfig, dimensions: usize) -> Vec<LshTable> {
357        let mut tables = Vec::with_capacity(config.num_tables);
358
359        for table_idx in 0..config.num_tables {
360            let seed = config.seed.wrapping_add(table_idx as u64);
361
362            let hash_function: Box<dyn HashFunction> = match config.lsh_family {
363                LshFamily::RandomProjection => Box::new(RandomProjectionHash::new(
364                    dimensions,
365                    config.num_hash_functions,
366                    seed,
367                )),
368                LshFamily::MinHash => {
369                    Box::new(MinHashFunction::new(config.num_hash_functions, seed))
370                }
371                LshFamily::SimHash => Box::new(SimHashFunction::new(dimensions, seed)),
372                LshFamily::PStable(p) => {
373                    Box::new(PStableHash::new(
374                        dimensions,
375                        config.num_hash_functions,
376                        1.0, // Default width
377                        p,
378                        seed,
379                    ))
380                }
381            };
382
383            tables.push(LshTable::new(hash_function));
384        }
385
386        tables
387    }
388
389    /// Rebuild tables with known dimensions
390    fn rebuild_tables(&mut self) {
391        if let Some(dims) = self.dimensions {
392            self.tables = Self::create_tables(&self.config, dims);
393
394            // Re-insert all vectors
395            for (id, (_, vector)) in self.vectors.iter().enumerate() {
396                let vector_f32 = vector.as_f32();
397                for table in &mut self.tables {
398                    table.insert(id, &vector_f32);
399                }
400            }
401        }
402    }
403
404    /// Query for approximate nearest neighbors
405    fn query_candidates(&self, vector: &[f32]) -> Vec<usize> {
406        let mut candidates = HashSet::new();
407
408        if self.config.multi_probe {
409            // Multi-probe LSH
410            for table in &self.tables {
411                let table_candidates = table.query_multi_probe(vector, self.config.num_probes);
412                candidates.extend(table_candidates);
413            }
414        } else {
415            // Standard LSH
416            for table in &self.tables {
417                let table_candidates = table.query(vector);
418                candidates.extend(table_candidates);
419            }
420        }
421
422        candidates.into_iter().collect()
423    }
424
425    /// Get statistics about the index
426    pub fn stats(&self) -> LshStats {
427        let avg_bucket_size = if self.tables.is_empty() {
428            0.0
429        } else {
430            let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
431            let total_items: usize = self
432                .tables
433                .iter()
434                .map(|t| t.buckets.values().map(|v| v.len()).sum::<usize>())
435                .sum();
436
437            if total_buckets > 0 {
438                total_items as f64 / total_buckets as f64
439            } else {
440                0.0
441            }
442        };
443
444        LshStats {
445            num_vectors: self.vectors.len(),
446            num_tables: self.tables.len(),
447            avg_bucket_size,
448            memory_usage: self.estimate_memory_usage(),
449        }
450    }
451
452    fn estimate_memory_usage(&self) -> usize {
453        let vector_memory =
454            self.vectors.len() * (std::mem::size_of::<String>() + std::mem::size_of::<Vector>());
455
456        let table_memory: usize = self
457            .tables
458            .iter()
459            .map(|t| {
460                t.buckets.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>())
461                    + t.buckets
462                        .values()
463                        .map(|v| v.len() * std::mem::size_of::<usize>())
464                        .sum::<usize>()
465            })
466            .sum();
467
468        vector_memory + table_memory
469    }
470}
471
472impl VectorIndex for LshIndex {
473    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
474        // Initialize dimensions if needed
475        if self.dimensions.is_none() {
476            self.dimensions = Some(vector.dimensions);
477            self.rebuild_tables();
478        } else if Some(vector.dimensions) != self.dimensions {
479            return Err(anyhow!(
480                "Vector dimensions ({}) don't match index dimensions ({:?})",
481                vector.dimensions,
482                self.dimensions
483            ));
484        }
485
486        let id = self.vectors.len();
487        let vector_f32 = vector.as_f32();
488
489        // Insert into all tables
490        for table in &mut self.tables {
491            table.insert(id, &vector_f32);
492        }
493
494        self.uri_to_id.insert(uri.clone(), id);
495        self.vectors.push((uri, vector));
496
497        Ok(())
498    }
499
500    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
501        if self.vectors.is_empty() {
502            return Ok(Vec::new());
503        }
504
505        let query_f32 = query.as_f32();
506        let candidates = self.query_candidates(&query_f32);
507
508        // Compute exact distances for candidates
509        let mut results: Vec<(usize, f32)> = candidates
510            .into_iter()
511            .filter_map(|id| {
512                self.vectors.get(id).map(|(_, vec)| {
513                    let vec_f32 = vec.as_f32();
514                    let distance = match self.config.lsh_family {
515                        LshFamily::RandomProjection | LshFamily::SimHash => {
516                            // Cosine distance
517                            use oxirs_core::simd::SimdOps;
518                            f32::cosine_distance(&query_f32, &vec_f32)
519                        }
520                        LshFamily::MinHash => {
521                            // Jaccard distance
522                            let threshold = 0.0;
523                            let set1: HashSet<usize> = query_f32
524                                .iter()
525                                .enumerate()
526                                .filter(|&(_, &v)| v > threshold)
527                                .map(|(i, _)| i)
528                                .collect();
529                            let set2: HashSet<usize> = vec_f32
530                                .iter()
531                                .enumerate()
532                                .filter(|&(_, &v)| v > threshold)
533                                .map(|(i, _)| i)
534                                .collect();
535
536                            let intersection = set1.intersection(&set2).count();
537                            let union = set1.union(&set2).count();
538
539                            if union > 0 {
540                                1.0 - (intersection as f32 / union as f32)
541                            } else {
542                                1.0
543                            }
544                        }
545                        LshFamily::PStable(p) => {
546                            // Lp distance
547                            use oxirs_core::simd::SimdOps;
548                            if (p - 1.0).abs() < 0.1 {
549                                f32::manhattan_distance(&query_f32, &vec_f32)
550                            } else if (p - 2.0).abs() < 0.1 {
551                                f32::euclidean_distance(&query_f32, &vec_f32)
552                            } else {
553                                // General Minkowski distance
554                                query_f32
555                                    .iter()
556                                    .zip(&vec_f32)
557                                    .map(|(a, b)| (a - b).abs().powf(p))
558                                    .sum::<f32>()
559                                    .powf(1.0 / p)
560                            }
561                        }
562                    };
563                    (id, distance)
564                })
565            })
566            .collect();
567
568        // Sort by distance and take top k
569        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
570        results.truncate(k);
571
572        // Convert to final result format
573        Ok(results
574            .into_iter()
575            .map(|(id, dist)| (self.vectors[id].0.clone(), dist))
576            .collect())
577    }
578
579    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
580        if self.vectors.is_empty() {
581            return Ok(Vec::new());
582        }
583
584        let query_f32 = query.as_f32();
585        let candidates = self.query_candidates(&query_f32);
586
587        // Filter candidates by threshold
588        let mut results: Vec<(String, f32)> = candidates
589            .into_iter()
590            .filter_map(|id| {
591                self.vectors.get(id).and_then(|(uri, vec)| {
592                    let vec_f32 = vec.as_f32();
593                    let distance = match self.config.lsh_family {
594                        LshFamily::RandomProjection | LshFamily::SimHash => {
595                            use oxirs_core::simd::SimdOps;
596                            f32::cosine_distance(&query_f32, &vec_f32)
597                        }
598                        LshFamily::MinHash => {
599                            // Jaccard distance
600                            let threshold_val = 0.0;
601                            let set1: HashSet<usize> = query_f32
602                                .iter()
603                                .enumerate()
604                                .filter(|&(_, &v)| v > threshold_val)
605                                .map(|(i, _)| i)
606                                .collect();
607                            let set2: HashSet<usize> = vec_f32
608                                .iter()
609                                .enumerate()
610                                .filter(|&(_, &v)| v > threshold_val)
611                                .map(|(i, _)| i)
612                                .collect();
613
614                            let intersection = set1.intersection(&set2).count();
615                            let union = set1.union(&set2).count();
616
617                            if union > 0 {
618                                1.0 - (intersection as f32 / union as f32)
619                            } else {
620                                1.0
621                            }
622                        }
623                        LshFamily::PStable(p) => {
624                            use oxirs_core::simd::SimdOps;
625                            if (p - 1.0).abs() < 0.1 {
626                                f32::manhattan_distance(&query_f32, &vec_f32)
627                            } else if (p - 2.0).abs() < 0.1 {
628                                f32::euclidean_distance(&query_f32, &vec_f32)
629                            } else {
630                                query_f32
631                                    .iter()
632                                    .zip(&vec_f32)
633                                    .map(|(a, b)| (a - b).abs().powf(p))
634                                    .sum::<f32>()
635                                    .powf(1.0 / p)
636                            }
637                        }
638                    };
639
640                    if distance <= threshold {
641                        Some((uri.clone(), distance))
642                    } else {
643                        None
644                    }
645                })
646            })
647            .collect();
648
649        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
650        Ok(results)
651    }
652
653    fn get_vector(&self, uri: &str) -> Option<&Vector> {
654        self.uri_to_id
655            .get(uri)
656            .and_then(|&id| self.vectors.get(id))
657            .map(|(_, v)| v)
658    }
659}
660
/// Snapshot of an `LshIndex`'s size and bucket occupancy.
#[derive(Debug, Clone)]
pub struct LshStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Mean number of ids per non-empty bucket, across all tables.
    pub avg_bucket_size: f64,
    /// Rough memory estimate in bytes.
    pub memory_usage: usize,
}
669
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_projection_lsh() {
        let config = LshConfig {
            num_tables: 5,
            num_hash_functions: 4,
            lsh_family: LshFamily::RandomProjection,
            seed: 42,
            multi_probe: false,
            num_probes: 0,
        };

        let mut index = LshIndex::new(config);

        let v1 = Vector::new(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::new(vec![0.0, 1.0, 0.0]);
        let v3 = Vector::new(vec![0.0, 0.0, 1.0]);
        let v_similar = Vector::new(vec![0.9, 0.1, 0.0]); // Similar to v1

        index.insert("v1".to_string(), v1.clone()).unwrap();
        index.insert("v2".to_string(), v2.clone()).unwrap();
        index.insert("v3".to_string(), v3.clone()).unwrap();
        index
            .insert("v_similar".to_string(), v_similar.clone())
            .unwrap();

        let results = index.search_knn(&v1, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        // The query is an indexed vector: it hashes into its own bucket in every table,
        // so it is always a candidate and its own nearest neighbour.
        assert_eq!(results[0].0, "v1");
        assert!(results
            .iter()
            .any(|(uri, _)| uri == "v1" || uri == "v_similar"));
    }

    #[test]
    fn test_minhash_lsh() {
        let config = LshConfig {
            num_tables: 3,
            num_hash_functions: 64,
            lsh_family: LshFamily::MinHash,
            seed: 42,
            multi_probe: false,
            num_probes: 0,
        };

        let mut index = LshIndex::new(config);

        // Sparse binary vectors.
        let mut v1 = vec![0.0; 100];
        v1[0] = 1.0;
        v1[10] = 1.0;
        v1[20] = 1.0;

        let mut v2 = vec![0.0; 100];
        v2[0] = 1.0;
        v2[10] = 1.0;
        v2[30] = 1.0; // Jaccard overlap with v1: 2 of 4

        let mut v3 = vec![0.0; 100];
        v3[50] = 1.0;
        v3[60] = 1.0;
        v3[70] = 1.0; // No overlap with v1

        index
            .insert("v1".to_string(), Vector::new(v1.clone()))
            .unwrap();
        index.insert("v2".to_string(), Vector::new(v2)).unwrap();
        index.insert("v3".to_string(), Vector::new(v3)).unwrap();

        let results = index.search_knn(&Vector::new(v1), 2).unwrap();

        // v1 must come first; if v2 is a candidate it must be second (it overlaps v1).
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "v1");
        if results.len() > 1 {
            assert_eq!(results[1].0, "v2");
        }
    }

    #[test]
    fn test_multi_probe_lsh() {
        let config = LshConfig {
            num_tables: 3,
            num_hash_functions: 4,
            lsh_family: LshFamily::RandomProjection,
            seed: 42,
            multi_probe: true,
            num_probes: 2,
        };

        let mut index = LshIndex::new(config);

        // 50 unit vectors spread around a circle in the xy-plane.
        for i in 0..50 {
            let angle = i as f32 * std::f32::consts::PI / 25.0;
            let vec = Vector::new(vec![angle.cos(), angle.sin(), 0.0]);
            index.insert(format!("v{i}"), vec).unwrap();
        }

        // [1, 0, 0] is exactly v0 (angle 0).
        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, "v0");
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1);
        }
    }

    #[test]
    fn test_dimension_mismatch_on_insert_is_rejected() {
        let mut index = LshIndex::new(LshConfig::default());
        index
            .insert("a".to_string(), Vector::new(vec![1.0, 0.0]))
            .unwrap();

        let err = index.insert("b".to_string(), Vector::new(vec![1.0, 0.0, 0.0]));
        assert!(err.is_err());
        assert!(index.get_vector("b").is_none());
        assert!(index.get_vector("a").is_some());
        assert_eq!(index.stats().num_vectors, 1);
    }

    #[test]
    fn test_empty_index_returns_no_results() {
        let index = LshIndex::new(LshConfig::default());
        let query = Vector::new(vec![1.0, 0.0, 0.0]);

        assert!(index.search_knn(&query, 3).unwrap().is_empty());
        assert!(index.search_threshold(&query, 1.0).unwrap().is_empty());
        assert!(index.get_vector("missing").is_none());

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 0);
        assert_eq!(stats.avg_bucket_size, 0.0);
    }
}