1use fastrand::Rng;
2use std::collections::HashMap;
3
4#[cfg(test)]
5fn lsh_rng() -> Rng {
6 Rng::with_seed(0x5eed_fade_cafe_beef)
7}
8
9#[cfg(not(test))]
10fn lsh_rng() -> Rng {
11 Rng::new()
12}
13
/// Random-hyperplane hasher shared by a group of hash tables.
///
/// The hyperplanes of all tables are stored back to back in one list: table
/// `t` owns the `num_hyperplanes` consecutive entries starting at index
/// `t * num_hyperplanes`.
#[derive(Clone, Debug, Default)]
pub struct LSH {
    /// All hyperplane normal vectors, grouped by table.
    hyperplanes: Vec<Vec<f32>>,
    /// Number of hash tables the hyperplanes are split into.
    num_tables: usize,
    /// Number of hyperplanes owned by each table.
    num_hyperplanes: usize,
}
24
25impl LSH {
26 pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
28 let mut rng = lsh_rng();
29 let mut hyperplanes = Vec::new();
30
31 for _ in 0..(num_tables * num_hyperplanes) {
32 let mut plane = vec![0.0; dim];
33
34 for val in plane.iter_mut() {
39 *val = rng.f32() * 2.0 - 1.0;
40 }
41
42 let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
45 if norm > 0.0 {
46 for val in plane.iter_mut() {
47 *val /= norm;
48 }
49 }
50
51 hyperplanes.push(plane);
52 }
53
54 Self {
55 hyperplanes,
56 num_tables,
57 num_hyperplanes,
58 }
59 }
60
61 pub fn hash(&self, vector: &[f64], table_idx: usize) -> u64 {
63 let mut hash = 0u64;
64 let start = table_idx * self.num_hyperplanes;
65
66 for (i, hyperplane) in self
67 .hyperplanes
68 .get(start..start + self.num_hyperplanes)
69 .unwrap_or(&[])
70 .iter()
71 .enumerate()
72 {
73 let dot: f32 = vector
75 .iter()
76 .zip(hyperplane.iter())
77 .map(|(v, h)| (*v as f32) * h)
78 .sum();
79
80 if dot >= 0.0 {
82 hash |= 1u64 << i;
83 }
84 }
85
86 hash
87 }
88}
89
90#[derive(Clone, Default)]
94pub struct LSHIndex {
95 lsh: LSH,
96 tables: Vec<HashMap<u64, Vec<String>>>, }
98
99impl LSHIndex {
100 pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
102 let lsh = LSH::new(dim, num_tables, num_hyperplanes);
103 let tables = vec![HashMap::new(); num_tables];
104
105 Self { lsh, tables }
106 }
107
108 pub fn insert(&mut self, id: String, embedding: &[f64]) {
110 for table_idx in 0..self.lsh.num_tables {
111 let hash = self.lsh.hash(embedding, table_idx);
112 if let Some(table) = self.tables.get_mut(table_idx) {
113 table.entry(hash).or_default().push(id.clone());
114 }
115 }
116 }
117
118 pub fn query(&self, embedding: &[f64]) -> Vec<String> {
120 use std::collections::HashSet;
121
122 let mut candidates = HashSet::new();
123
124 for table_idx in 0..self.lsh.num_tables {
126 let hash = self.lsh.hash(embedding, table_idx);
127
128 if let Some(ids) = self
129 .tables
130 .get(table_idx)
131 .and_then(|table| table.get(&hash))
132 {
133 candidates.extend(ids.iter().cloned());
134 }
135 }
136
137 candidates.into_iter().collect()
138 }
139
140 pub fn clear(&mut self) {
142 for table in self.tables.iter_mut() {
143 table.clear();
144 }
145 }
146}