Skip to main content

oxirs_vec/
lsh_index.rs

1//! Locality-Sensitive Hashing (LSH) index for approximate nearest neighbour search.
2//!
3//! Uses random hyperplane projection with multiple hash tables to efficiently
4//! find approximate nearest neighbours in high-dimensional vector spaces.
5//! All randomness comes from a deterministic XorShift-64 PRNG.
6
7use std::collections::HashMap;
8
9// -------------------------------------------------------------------------
10// XorShift-64 deterministic PRNG
11// -------------------------------------------------------------------------
12
/// Deterministic XorShift-64 pseudo-random number generator (shift triple 13/7/17).
///
/// The entire state is a single `u64`; the same seed always reproduces the same
/// stream. Not suitable for cryptographic use.
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    /// Create a generator from `seed`.
    ///
    /// Zero is a fixed point of the xorshift step (it would emit zeros forever), so a
    /// zero seed is replaced by 1.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.max(1),
        }
    }

    /// Advance the state by one xorshift step and return the new state.
    fn next(&mut self) -> u64 {
        let a = self.state ^ (self.state << 13);
        let b = a ^ (a >> 7);
        self.state = b ^ (b << 17);
        self.state
    }

    /// Draw an `f64` in `[-1.0, 1.0]`.
    ///
    /// The raw draw is divided by `u64::MAX as f64`, which rounds to 2^64. Because the
    /// `u64 -> f64` conversion rounds to nearest, draws very close to `u64::MAX` become
    /// 2^64 and map to exactly 1.0. The upper bound is therefore reachable, though with
    /// negligible probability.
    fn next_f64_signed(&mut self) -> f64 {
        const DENOM: f64 = u64::MAX as f64;
        let unit = self.next() as f64 / DENOM;
        2.0 * unit - 1.0
    }
}
44
45// -------------------------------------------------------------------------
46// LshHasher
47// -------------------------------------------------------------------------
48
49/// A single LSH hasher using random hyperplane projection.
50///
51/// Each bit of the hash corresponds to one random hyperplane: a bit is 1 if
52/// the dot product with the random vector is non-negative, 0 otherwise.
53#[derive(Debug, Clone)]
54pub struct LshHasher {
55    /// Unit random vectors (one per hash bit).
56    pub random_vectors: Vec<Vec<f64>>,
57    /// Dimensionality of the input vectors.
58    pub dim: usize,
59}
60
61impl LshHasher {
62    /// Create a new hasher with `num_hashes` random unit vectors of dimension `dim`.
63    ///
64    /// Uses the provided XorShift state for deterministic generation.
65    fn new_with_rng(dim: usize, num_hashes: usize, rng: &mut XorShift64) -> Self {
66        let mut random_vectors = Vec::with_capacity(num_hashes);
67        for _ in 0..num_hashes {
68            let mut v: Vec<f64> = (0..dim).map(|_| rng.next_f64_signed()).collect();
69            normalize_vec(&mut v);
70            random_vectors.push(v);
71        }
72        Self {
73            random_vectors,
74            dim,
75        }
76    }
77
78    /// Hash a vector to a `u64` using sign bits of dot products.
79    pub fn hash(&self, v: &[f64]) -> u64 {
80        let mut h: u64 = 0;
81        for (bit, rv) in self.random_vectors.iter().enumerate() {
82            if bit >= 64 {
83                break;
84            }
85            let dot: f64 = v.iter().zip(rv.iter()).map(|(a, b)| a * b).sum();
86            if dot >= 0.0 {
87                h |= 1u64 << bit;
88            }
89        }
90        h
91    }
92}
93
/// Bucket map for one LSH table: each hash signature maps to the ids of the vectors
/// that were inserted under that signature.
pub type LshBucket = HashMap<u64, Vec<usize>>;
96
97// -------------------------------------------------------------------------
98// LshIndex
99// -------------------------------------------------------------------------
100
101/// Approximate nearest-neighbour index using Locality-Sensitive Hashing.
102pub struct LshIndex {
103    /// All inserted vectors (indexed by position in this Vec).
104    pub vectors: Vec<Vec<f64>>,
105    /// One bucket table per LSH table.
106    pub buckets: Vec<LshBucket>,
107    /// One hasher per LSH table.
108    pub hashers: Vec<LshHasher>,
109    /// Dimensionality of the vectors.
110    pub dim: usize,
111    /// Number of hash tables.
112    pub num_tables: usize,
113    /// Number of hash bits per table.
114    pub num_hashes: usize,
115}
116
117impl LshIndex {
118    /// Create a new LSH index.
119    ///
120    /// * `dim` — vector dimensionality
121    /// * `num_tables` — number of independent hash tables (more → higher recall)
122    /// * `num_hashes` — number of hash bits per table (more → higher precision)
123    /// * `seed` — XorShift-64 seed for reproducibility
124    pub fn new(dim: usize, num_tables: usize, num_hashes: usize, seed: u64) -> Self {
125        let mut rng = XorShift64::new(seed);
126        let mut hashers = Vec::with_capacity(num_tables);
127        let mut buckets = Vec::with_capacity(num_tables);
128        for _ in 0..num_tables {
129            hashers.push(LshHasher::new_with_rng(dim, num_hashes, &mut rng));
130            buckets.push(LshBucket::new());
131        }
132        Self {
133            vectors: Vec::new(),
134            buckets,
135            hashers,
136            dim,
137            num_tables,
138            num_hashes,
139        }
140    }
141
142    /// Insert a vector with the given id into all hash tables.
143    pub fn insert(&mut self, id: usize, vector: &[f64]) {
144        // Ensure storage is large enough
145        while self.vectors.len() <= id {
146            self.vectors.push(vec![]);
147        }
148        self.vectors[id] = vector.to_vec();
149
150        for (table_idx, hasher) in self.hashers.iter().enumerate() {
151            let h = hasher.hash(vector);
152            self.buckets[table_idx].entry(h).or_default().push(id);
153        }
154    }
155
156    /// Compute cosine similarity between two vectors.
157    ///
158    /// Returns 0.0 if either vector has zero magnitude.
159    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
160        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
161        let mag_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
162        let mag_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
163        if mag_a < f64::EPSILON || mag_b < f64::EPSILON {
164            return 0.0;
165        }
166        dot / (mag_a * mag_b)
167    }
168
169    /// Search for the top-k approximate nearest neighbours of `query`.
170    ///
171    /// Collects candidates from all tables, deduplicates them, computes cosine
172    /// similarity, and returns the top-k sorted by descending similarity.
173    /// If the query itself was indexed (same vector), it may appear in results.
174    pub fn search(&self, query: &[f64], k: usize) -> Vec<(usize, f64)> {
175        let mut candidate_set = std::collections::HashSet::new();
176
177        for (table_idx, hasher) in self.hashers.iter().enumerate() {
178            let h = hasher.hash(query);
179            if let Some(ids) = self.buckets[table_idx].get(&h) {
180                for &id in ids {
181                    candidate_set.insert(id);
182                }
183            }
184        }
185
186        let mut scored: Vec<(usize, f64)> = candidate_set
187            .into_iter()
188            .filter_map(|id| {
189                let v = self.vectors.get(id)?;
190                if v.is_empty() {
191                    return None;
192                }
193                Some((id, Self::cosine_similarity(query, v)))
194            })
195            .collect();
196
197        // Sort by similarity descending, then by id ascending for determinism
198        scored.sort_by(|a, b| {
199            b.1.partial_cmp(&a.1)
200                .unwrap_or(std::cmp::Ordering::Equal)
201                .then_with(|| a.0.cmp(&b.0))
202        });
203
204        scored.truncate(k);
205        scored
206    }
207
208    /// Return the number of indexed vectors.
209    pub fn len(&self) -> usize {
210        self.vectors.iter().filter(|v| !v.is_empty()).count()
211    }
212
213    /// Return true if no vectors have been indexed.
214    pub fn is_empty(&self) -> bool {
215        self.len() == 0
216    }
217
218    /// Remove all indexed vectors and clear all hash tables.
219    pub fn clear(&mut self) {
220        self.vectors.clear();
221        for bucket in &mut self.buckets {
222            bucket.clear();
223        }
224    }
225}
226
227// -------------------------------------------------------------------------
228// Helper: normalize a vector to unit length
229// -------------------------------------------------------------------------
230
/// Scale `v` in place to unit Euclidean length.
///
/// If the norm is not strictly greater than `f64::EPSILON` (zero vector, tiny
/// vectors, or a NaN norm), `v` is left untouched to avoid dividing by (near) zero.
fn normalize_vec(v: &mut [f64]) {
    let norm = v.iter().map(|&c| c * c).sum::<f64>().sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
239
240// -------------------------------------------------------------------------
241// Tests
242// -------------------------------------------------------------------------
243
#[cfg(test)]
mod tests {
    //! Unit tests for the PRNG, the hashing helpers and `LshIndex`.
    //!
    //! LSH is probabilistic, so these tests only assert properties that hold for every
    //! hyperplane draw: a vector and its power-of-two rescaling share a signature, and a
    //! search for an exactly stored vector always retrieves it.

    use super::*;

    /// Standard basis vector `e_axis` of length `dim`.
    fn unit_vec(dim: usize, axis: usize) -> Vec<f64> {
        let mut v = vec![0.0_f64; dim];
        v[axis] = 1.0;
        v
    }

    /// Small 4-d index shared by several tests (4 tables × 8 bits, seed 42).
    fn new_index() -> LshIndex {
        LshIndex::new(4, 4, 8, 42)
    }

    // ------ XorShift64 ------

    #[test]
    fn test_xorshift64_deterministic() {
        let mut rng1 = XorShift64::new(123);
        let mut rng2 = XorShift64::new(123);
        for _ in 0..100 {
            assert_eq!(rng1.next(), rng2.next());
        }
    }

    #[test]
    fn test_xorshift64_nonzero_seed() {
        let mut rng = XorShift64::new(0); // Should be initialised to 1 internally
        let v = rng.next();
        assert_ne!(v, 0);
    }

    #[test]
    fn test_xorshift64_known_value() {
        // Pin the first output for seed 1 (and seed 0, which is remapped to 1) so that
        // accidental changes to the generator, and hence to every index, are caught.
        assert_eq!(XorShift64::new(1).next(), 1_082_269_761);
        assert_eq!(XorShift64::new(0).next(), 1_082_269_761);
    }

    #[test]
    fn test_xorshift64_different_seeds() {
        let mut rng1 = XorShift64::new(1);
        let mut rng2 = XorShift64::new(2);
        let v1 = rng1.next();
        let v2 = rng2.next();
        assert_ne!(v1, v2);
    }

    #[test]
    fn test_xorshift64_signed_range() {
        let mut rng = XorShift64::new(7);
        for _ in 0..1000 {
            let x = rng.next_f64_signed();
            assert!((-1.0..=1.0).contains(&x), "out of range: {x}");
        }
    }

    // ------ normalize_vec ------

    #[test]
    fn test_normalize_vec_unit_length() {
        let mut v = vec![3.0_f64, 4.0_f64];
        normalize_vec(&mut v);
        let mag: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((mag - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_normalize_zero_vec_safe() {
        let mut v = vec![0.0_f64; 4];
        normalize_vec(&mut v); // should not panic
        assert_eq!(v, vec![0.0_f64; 4]);
    }

    // ------ LshHasher ------

    #[test]
    fn test_hasher_deterministic() {
        let mut rng = XorShift64::new(42);
        let h1 = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let hash1 = h1.hash(&v);

        // Same seed → same hash
        let mut rng2 = XorShift64::new(42);
        let h2 = LshHasher::new_with_rng(4, 8, &mut rng2);
        let hash2 = h2.hash(&v);

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hasher_similar_vectors_same_bucket() {
        let mut rng = XorShift64::new(42);
        let h = LshHasher::new_with_rng(4, 4, &mut rng);
        let v1 = vec![1.0_f64, 0.001, 0.001, 0.001];
        let v2 = vec![1.0_f64, 0.001, 0.001, 0.002];
        // Whether v1 and v2 collide depends on the hyperplane draw, so assert only what
        // always holds. With 4 hyperplanes no bit above bit 3 may be set. Scaling by a
        // power of two is exact in f64, so it never changes the sign of any dot product.
        for v in [&v1, &v2] {
            let base = h.hash(v);
            assert_eq!(base >> 4, 0, "only 4 hash bits may be set");
            let doubled: Vec<f64> = v.iter().map(|x| x * 2.0).collect();
            let halved: Vec<f64> = v.iter().map(|x| x * 0.5).collect();
            assert_eq!(h.hash(&doubled), base);
            assert_eq!(h.hash(&halved), base);
        }
    }

    #[test]
    fn test_hasher_opposite_vectors_different_bits() {
        let mut rng = XorShift64::new(99);
        let h = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let neg_v = vec![-1.0_f64, 0.0, 0.0, 0.0];
        let h1 = h.hash(&v);
        let h2 = h.hash(&neg_v);
        // Opposite vectors should hash differently
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_hasher_more_than_64_hashes_does_not_panic() {
        let mut rng = XorShift64::new(5);
        let h = LshHasher::new_with_rng(3, 70, &mut rng);
        assert_eq!(h.random_vectors.len(), 70);
        // Hyperplanes past the 64th are ignored instead of overflowing the bit shift.
        let _ = h.hash(&[1.0, -2.0, 0.5]);
    }

    // ------ cosine_similarity ------

    #[test]
    fn test_cosine_identical_vectors() {
        let v = vec![1.0_f64, 2.0, 3.0];
        let sim = LshIndex::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let v1 = vec![1.0_f64, 0.0, 0.0];
        let v2 = vec![0.0_f64, 1.0, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let v1 = vec![1.0_f64, 0.0];
        let v2 = vec![-1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!((sim + 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let v1 = vec![0.0_f64, 0.0];
        let v2 = vec![1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!((sim).abs() < 1e-9);
    }

    // ------ LshIndex construction ------

    #[test]
    fn test_index_new_dimensions() {
        let idx = LshIndex::new(8, 4, 16, 1);
        assert_eq!(idx.dim, 8);
        assert_eq!(idx.num_tables, 4);
        assert_eq!(idx.num_hashes, 16);
        assert_eq!(idx.hashers.len(), 4);
        assert_eq!(idx.buckets.len(), 4);
        for hasher in &idx.hashers {
            assert_eq!(hasher.dim, 8);
            assert_eq!(hasher.random_vectors.len(), 16);
            for rv in &hasher.random_vectors {
                assert_eq!(rv.len(), 8);
                let norm: f64 = rv.iter().map(|x| x * x).sum::<f64>().sqrt();
                assert!((norm - 1.0).abs() < 1e-9, "hyperplane not unit length");
            }
        }
    }

    #[test]
    fn test_index_empty() {
        let idx = new_index();
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    // ------ insert / len ------

    #[test]
    fn test_insert_single_vector() {
        let mut idx = new_index();
        idx.insert(0, &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_insert_multiple_vectors() {
        let mut idx = new_index();
        for i in 0..10 {
            idx.insert(i, &unit_vec(4, i % 4));
        }
        assert_eq!(idx.len(), 10);
    }

    // ------ search ------

    #[test]
    fn test_search_empty_index() {
        let idx = new_index();
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_exact_match() {
        let mut idx = LshIndex::new(4, 8, 16, 42);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(0, &v);
        let results = idx.search(&v, 1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_search_k_limits_results() {
        let mut idx = LshIndex::new(4, 8, 4, 77);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        for i in 0..5 {
            // Positive multiples of `v`: their dot product with any hyperplane has the
            // same sign as v's, so all five share v's bucket in every table.
            let mut vv = v.clone();
            vv[0] = 1.0 - i as f64 * 0.01;
            idx.insert(i, &vv);
        }
        let results = idx.search(&v, 2);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_search_returns_closer_vector() {
        let mut idx = LshIndex::new(2, 8, 16, 1);
        // v1 is close to query [1, 0]
        idx.insert(0, &[1.0_f64, 0.01]);
        // v2 is far from query
        idx.insert(1, &[0.0_f64, 1.0]);

        let results = idx.search(&[1.0_f64, 0.0], 2);
        // Retrieval of id 0 is probabilistic, but whenever it is retrieved it must
        // outrank id 1 (similarity ~0.99995 vs 0.0).
        if results.iter().any(|&(id, _)| id == 0) {
            assert_eq!(results[0].0, 0);
        }
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }

    #[test]
    fn test_search_sorted_descending() {
        let mut idx = LshIndex::new(4, 8, 16, 7);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.9_f64, 0.1, 0.0, 0.0]);
        idx.insert(2, &[0.5_f64, 0.5, 0.0, 0.0]);

        let query = [1.0_f64, 0.0, 0.0, 0.0];
        let results = idx.search(&query, 3);
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1, "Results not sorted descending");
        }
        // id 0 equals the query, so it is always retrieved and has the top score.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_search_k_greater_than_num_vectors() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.0_f64, 1.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 100);
        assert!(results.len() <= 2);
        assert_eq!(results[0].0, 0);
    }

    // ------ clear ------

    #[test]
    fn test_clear() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        assert!(idx.is_empty());
        assert!(idx.buckets.iter().all(|b| b.is_empty()));
        assert!(idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 1).is_empty());
    }

    #[test]
    fn test_clear_then_insert() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        idx.insert(0, &[0.0_f64, 1.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
        let results = idx.search(&[0.0_f64, 1.0, 0.0, 0.0], 1);
        assert_eq!(results[0].0, 0);
    }

    // ------ multi-table redundancy ------

    #[test]
    fn test_multi_table_improves_recall() {
        // With many tables, the target vector should almost certainly be found
        let mut idx = LshIndex::new(4, 16, 8, 2024);
        let target = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(42, &target);

        // Add some other vectors
        for i in 0..20 {
            let mut v = vec![0.0_f64; 4];
            v[i % 4] = 1.0;
            v[(i + 1) % 4] = 0.1;
            idx.insert(i, &v);
        }

        let results = idx.search(&target, 5);
        let found = results.iter().any(|(id, _)| *id == 42);
        assert!(found, "Target vector should be found with 16 tables");
        // It is an exact match (similarity 1.0), so nothing can outrank it.
        assert_eq!(results[0].0, 42);
    }

    #[test]
    fn test_high_dimensional_search() {
        let dim = 64;
        let mut idx = LshIndex::new(dim, 8, 16, 99);
        let target: Vec<f64> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
        idx.insert(0, &target);
        let results = idx.search(&target, 1);
        // An exact match hashes identically in every table, so it must be found.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_empty_after_inserts() {
        let mut idx = new_index();
        assert!(idx.is_empty());
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_results_contain_similarity() {
        let mut idx = LshIndex::new(4, 8, 16, 55);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert!(results[0].1 >= 0.0 && results[0].1 <= 1.0 + 1e-9);
    }
}