Skip to main content

rig/vector_store/
lsh.rs

1use fastrand::Rng;
2use std::collections::HashMap;
3
4#[cfg(test)]
5fn lsh_rng() -> Rng {
6    Rng::with_seed(0x5eed_fade_cafe_beef)
7}
8
9#[cfg(not(test))]
10fn lsh_rng() -> Rng {
11    Rng::new()
12}
13
/// Locality Sensitive Hashing (LSH) with random projection.
///
/// Uses random hyperplanes to hash similar vectors into the same buckets for efficient
/// approximate nearest neighbor search. See
/// <https://www.pinecone.io/learn/series/faiss/locality-sensitive-hashing-random-projection/>
/// for details on how LSH works.
///
/// Each table's hash is a `u64` holding one bit per hyperplane, so at most 64 hyperplanes per
/// table fit in a hash.
#[derive(Clone, Debug, Default)]
pub struct LSH {
    /// Hyperplane normal vectors, stored table by table: table `t` owns the
    /// `num_hyperplanes` entries starting at index `t * num_hyperplanes`.
    hyperplanes: Vec<Vec<f32>>,
    /// Number of independent hash tables.
    num_tables: usize,
    /// Number of hyperplanes (hash bits) per table.
    num_hyperplanes: usize,
}
24
25impl LSH {
26    /// Create a new LSH instance.
27    pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
28        let mut rng = lsh_rng();
29        let mut hyperplanes = Vec::new();
30
31        for _ in 0..(num_tables * num_hyperplanes) {
32            let mut plane = vec![0.0; dim];
33
34            // Generate random values in [-1, 1] to ensure uniform distribution across all directions
35            // before normalization. This guarantees that after normalization to unit vectors, the
36            // hyperplanes are uniformly distributed across the unit sphere, which is essential for
37            // LSH to maintain good locality-sensitive hashing properties.
38            for val in plane.iter_mut() {
39                *val = rng.f32() * 2.0 - 1.0;
40            }
41
42            // Normalize to unit vector so the dot product reflects only direction, ensuring
43            // the hash correctly identifies which side of the hyperplane each point lies on.
44            let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
45            if norm > 0.0 {
46                for val in plane.iter_mut() {
47                    *val /= norm;
48                }
49            }
50
51            hyperplanes.push(plane);
52        }
53
54        Self {
55            hyperplanes,
56            num_tables,
57            num_hyperplanes,
58        }
59    }
60
61    /// Compute hash for a vector in a specific table
62    pub fn hash(&self, vector: &[f64], table_idx: usize) -> u64 {
63        let mut hash = 0u64;
64        let start = table_idx * self.num_hyperplanes;
65
66        for (i, hyperplane) in self
67            .hyperplanes
68            .get(start..start + self.num_hyperplanes)
69            .unwrap_or(&[])
70            .iter()
71            .enumerate()
72        {
73            // Dot product (convert f64 to f32)
74            let dot: f32 = vector
75                .iter()
76                .zip(hyperplane.iter())
77                .map(|(v, h)| (*v as f32) * h)
78                .sum();
79
80            // Set bit if positive
81            if dot >= 0.0 {
82                hash |= 1u64 << i;
83            }
84        }
85
86        hash
87    }
88}
89
90/// LSH Index for document IDs.
91/// Stores document IDs in a hashmap of hash values to document IDs.
92/// This allows for efficient lookup of document IDs by hash value.
93#[derive(Clone, Default)]
94pub struct LSHIndex {
95    lsh: LSH,
96    tables: Vec<HashMap<u64, Vec<String>>>, // Hash -> document IDs
97}
98
99impl LSHIndex {
100    /// Create a new LSHIndex.
101    pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
102        let lsh = LSH::new(dim, num_tables, num_hyperplanes);
103        let tables = vec![HashMap::new(); num_tables];
104
105        Self { lsh, tables }
106    }
107
108    /// Insert a document ID with its embedding
109    pub fn insert(&mut self, id: String, embedding: &[f64]) {
110        for table_idx in 0..self.lsh.num_tables {
111            let hash = self.lsh.hash(embedding, table_idx);
112            if let Some(table) = self.tables.get_mut(table_idx) {
113                table.entry(hash).or_default().push(id.clone());
114            }
115        }
116    }
117
118    /// Query for candidate document IDs
119    pub fn query(&self, embedding: &[f64]) -> Vec<String> {
120        use std::collections::HashSet;
121
122        let mut candidates = HashSet::new();
123
124        // Collect candidates from all tables
125        for table_idx in 0..self.lsh.num_tables {
126            let hash = self.lsh.hash(embedding, table_idx);
127
128            if let Some(ids) = self
129                .tables
130                .get(table_idx)
131                .and_then(|table| table.get(&hash))
132            {
133                candidates.extend(ids.iter().cloned());
134            }
135        }
136
137        candidates.into_iter().collect()
138    }
139
140    /// Clear all tables
141    pub fn clear(&mut self) {
142        for table in self.tables.iter_mut() {
143            table.clear();
144        }
145    }
146}