rig/vector_store/
lsh.rs

1use fastrand::Rng;
2use std::collections::HashMap;
3
/// Locality-sensitive hashing (LSH) with random projection.
///
/// Every hash bit records on which side of a random hyperplane a vector lies, so vectors that
/// point in similar directions tend to receive the same hash and end up in the same bucket.
/// This enables efficient approximate nearest neighbor search. See
/// <https://www.pinecone.io/learn/series/faiss/locality-sensitive-hashing-random-projection/>
/// for details on how LSH works.
#[derive(Clone, Debug, Default)]
pub struct LSH {
    /// Hyperplane normals of all tables stored back to back: table `t` uses the entries
    /// `t * num_hyperplanes .. (t + 1) * num_hyperplanes`.
    hyperplanes: Vec<Vec<f32>>,
    /// Number of independent hash tables.
    num_tables: usize,
    /// Number of hyperplanes (hash bits) per table.
    num_hyperplanes: usize,
}
14
15impl LSH {
16    /// Create a new LSH instance.
17    pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
18        let mut rng = Rng::new();
19        let mut hyperplanes = Vec::new();
20
21        for _ in 0..(num_tables * num_hyperplanes) {
22            let mut plane = vec![0.0; dim];
23
24            // Generate random values in [-1, 1] to ensure uniform distribution across all directions
25            // before normalization. This guarantees that after normalization to unit vectors, the
26            // hyperplanes are uniformly distributed across the unit sphere, which is essential for
27            // LSH to maintain good locality-sensitive hashing properties.
28            for val in plane.iter_mut() {
29                *val = rng.f32() * 2.0 - 1.0;
30            }
31
32            // Normalize to unit vector so the dot product reflects only direction, ensuring
33            // the hash correctly identifies which side of the hyperplane each point lies on.
34            let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
35            if norm > 0.0 {
36                for val in plane.iter_mut() {
37                    *val /= norm;
38                }
39            }
40
41            hyperplanes.push(plane);
42        }
43
44        Self {
45            hyperplanes,
46            num_tables,
47            num_hyperplanes,
48        }
49    }
50
51    /// Compute hash for a vector in a specific table
52    pub fn hash(&self, vector: &[f64], table_idx: usize) -> u64 {
53        let mut hash = 0u64;
54        let start = table_idx * self.num_hyperplanes;
55
56        for (i, hyperplane) in self.hyperplanes[start..start + self.num_hyperplanes]
57            .iter()
58            .enumerate()
59        {
60            // Dot product (convert f64 to f32)
61            let dot: f32 = vector
62                .iter()
63                .zip(hyperplane.iter())
64                .map(|(v, h)| (*v as f32) * h)
65                .sum();
66
67            // Set bit if positive
68            if dot >= 0.0 {
69                hash |= 1u64 << i;
70            }
71        }
72
73        hash
74    }
75}
76
77/// LSH Index for document IDs.
78/// Stores document IDs in a hashmap of hash values to document IDs.
79/// This allows for efficient lookup of document IDs by hash value.
80#[derive(Clone, Default)]
81pub struct LSHIndex {
82    lsh: LSH,
83    tables: Vec<HashMap<u64, Vec<String>>>, // Hash -> document IDs
84}
85
86impl LSHIndex {
87    /// Create a new LSHIndex.
88    pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
89        let lsh = LSH::new(dim, num_tables, num_hyperplanes);
90        let tables = vec![HashMap::new(); num_tables];
91
92        Self { lsh, tables }
93    }
94
95    /// Insert a document ID with its embedding
96    pub fn insert(&mut self, id: String, embedding: &[f64]) {
97        for table_idx in 0..self.lsh.num_tables {
98            let hash = self.lsh.hash(embedding, table_idx);
99            self.tables[table_idx]
100                .entry(hash)
101                .or_default()
102                .push(id.clone());
103        }
104    }
105
106    /// Query for candidate document IDs
107    pub fn query(&self, embedding: &[f64]) -> Vec<String> {
108        use std::collections::HashSet;
109
110        let mut candidates = HashSet::new();
111
112        // Collect candidates from all tables
113        for table_idx in 0..self.lsh.num_tables {
114            let hash = self.lsh.hash(embedding, table_idx);
115
116            if let Some(ids) = self.tables[table_idx].get(&hash) {
117                candidates.extend(ids.iter().cloned());
118            }
119        }
120
121        candidates.into_iter().collect()
122    }
123
124    /// Clear all tables
125    pub fn clear(&mut self) {
126        for table in self.tables.iter_mut() {
127            table.clear();
128        }
129    }
130}