1use fastrand::Rng;
2use std::collections::HashMap;
3
/// Random-hyperplane locality-sensitive hashing family.
///
/// Stores `num_tables * num_hyperplanes` hyperplane normals contiguously:
/// table `t` uses `hyperplanes[t * num_hyperplanes..(t + 1) * num_hyperplanes]`,
/// and each hyperplane in that run contributes one bit to the table's hash.
#[derive(Clone, Debug, Default)]
pub struct LSH {
    /// Hyperplane normals, grouped per table as described above.
    hyperplanes: Vec<Vec<f32>>,
    /// Number of independent hash tables.
    num_tables: usize,
    /// Number of hyperplanes (hash bits) per table.
    num_hyperplanes: usize,
}
14
15impl LSH {
16 pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
18 let mut rng = Rng::new();
19 let mut hyperplanes = Vec::new();
20
21 for _ in 0..(num_tables * num_hyperplanes) {
22 let mut plane = vec![0.0; dim];
23
24 for val in plane.iter_mut() {
29 *val = rng.f32() * 2.0 - 1.0;
30 }
31
32 let norm: f32 = plane.iter().map(|x| x * x).sum::<f32>().sqrt();
35 if norm > 0.0 {
36 for val in plane.iter_mut() {
37 *val /= norm;
38 }
39 }
40
41 hyperplanes.push(plane);
42 }
43
44 Self {
45 hyperplanes,
46 num_tables,
47 num_hyperplanes,
48 }
49 }
50
51 pub fn hash(&self, vector: &[f64], table_idx: usize) -> u64 {
53 let mut hash = 0u64;
54 let start = table_idx * self.num_hyperplanes;
55
56 for (i, hyperplane) in self.hyperplanes[start..start + self.num_hyperplanes]
57 .iter()
58 .enumerate()
59 {
60 let dot: f32 = vector
62 .iter()
63 .zip(hyperplane.iter())
64 .map(|(v, h)| (*v as f32) * h)
65 .sum();
66
67 if dot >= 0.0 {
69 hash |= 1u64 << i;
70 }
71 }
72
73 hash
74 }
75}
76
77#[derive(Clone, Default)]
81pub struct LSHIndex {
82 lsh: LSH,
83 tables: Vec<HashMap<u64, Vec<String>>>, }
85
86impl LSHIndex {
87 pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self {
89 let lsh = LSH::new(dim, num_tables, num_hyperplanes);
90 let tables = vec![HashMap::new(); num_tables];
91
92 Self { lsh, tables }
93 }
94
95 pub fn insert(&mut self, id: String, embedding: &[f64]) {
97 for table_idx in 0..self.lsh.num_tables {
98 let hash = self.lsh.hash(embedding, table_idx);
99 self.tables[table_idx]
100 .entry(hash)
101 .or_default()
102 .push(id.clone());
103 }
104 }
105
106 pub fn query(&self, embedding: &[f64]) -> Vec<String> {
108 use std::collections::HashSet;
109
110 let mut candidates = HashSet::new();
111
112 for table_idx in 0..self.lsh.num_tables {
114 let hash = self.lsh.hash(embedding, table_idx);
115
116 if let Some(ids) = self.tables[table_idx].get(&hash) {
117 candidates.extend(ids.iter().cloned());
118 }
119 }
120
121 candidates.into_iter().collect()
122 }
123
124 pub fn clear(&mut self) {
126 for table in self.tables.iter_mut() {
127 table.clear();
128 }
129 }
130}