use crate::utils::simd::SimdVectorOps;
use ndarray::{Array1, Array2};
use rand::Rng;
use rayon::prelude::*;
use std::collections::HashMap;

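/// Locality-sensitive hashing index for approximate nearest-neighbor search over
/// `f32` vectors, using random-hyperplane (sign-of-dot-product) hashes suited to
/// cosine similarity.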
#[derive(Clone)]
pub struct LSH {
    num_tables: usize,
    hash_size: usize,
    #[allow(dead_code)]
    vector_dim: usize,
    hyperplanes: Vec<Array2<f32>>,
    hash_tables: Vec<HashMap<Vec<bool>, Vec<usize>>>,
    stored_vectors: Vec<Array1<f32>>,
    stored_data: Vec<f32>,
}

impl LSH {
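    /// Creates an LSH index with a default expected size of 1000 vectors.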
    pub fn new(vector_dim: usize, num_tables: usize, hash_size: usize) -> Self {
        Self::with_expected_size(vector_dim, num_tables, hash_size, 1000)
    }

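    /// Creates an LSH index sized for `expected_size` vectors: one matrix of
    /// `hash_size` random hyperplanes per table, plus pre-allocated hash tables.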
    pub fn with_expected_size(
        vector_dim: usize,
        num_tables: usize,
        hash_size: usize,
        expected_size: usize,
    ) -> Self {
        let mut rng = rand::thread_rng();
        let mut hyperplanes = Vec::new();

        for _ in 0..num_tables {
            let mut table_hyperplanes = Array2::zeros((hash_size, vector_dim));
            for i in 0..hash_size {
                for j in 0..vector_dim {
                    table_hyperplanes[[i, j]] = rng.gen_range(-1.0..1.0);
                }
            }
            hyperplanes.push(table_hyperplanes);
        }

        let bucket_capacity = (expected_size as f32 / 5.0).ceil() as usize;
        let optimal_capacity = bucket_capacity.max(100);

        let hash_tables = vec![HashMap::with_capacity(optimal_capacity); num_tables];

        Self {
            num_tables,
            hash_size,
            vector_dim,
            hyperplanes,
            hash_tables,
            stored_vectors: Vec::with_capacity(expected_size),
            stored_data: Vec::with_capacity(expected_size),
        }
    }

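    /// Inserts a vector and its associated value, adding its index to the
    /// matching bucket in every hash table. Tables are grown first if the
    /// estimated load factor exceeds 0.75.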
    pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
        let index = self.stored_vectors.len();

        let current_load =
            self.stored_vectors.len() as f32 / (self.hash_tables[0].capacity() as f32 * 0.2);
        if current_load > 0.75 {
            self.resize_hash_tables();
        }

        let mut hashes = Vec::with_capacity(self.num_tables);
        for table_idx in 0..self.num_tables {
            hashes.push(self.hash_vector(&vector, table_idx));
        }

        self.stored_vectors.push(vector);
        self.stored_data.push(data);

        for (table_idx, hash) in hashes.into_iter().enumerate() {
            self.hash_tables[table_idx]
                .entry(hash)
                .or_insert_with(|| Vec::with_capacity(8))
                .push(index);
        }
    }

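    /// Reserves additional bucket capacity, roughly doubling each hash table.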
    fn resize_hash_tables(&mut self) {
        let new_capacity = (self.hash_tables[0].capacity() * 2).max(self.stored_vectors.len());

        for table in &mut self.hash_tables {
            // saturating_sub guards against underflow if a table already has more
            // capacity than the target computed from the first table.
            table.reserve(new_capacity.saturating_sub(table.capacity()));
        }
    }

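    /// Returns up to `k` stored vectors ranked by cosine similarity to the query.
    /// Candidates are gathered from the matching bucket in every table (in
    /// parallel when there are more than four tables); if too few candidates are
    /// found, a prefix of the stored vectors is scanned as a fallback.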
    pub fn query(&self, query_vector: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
        let mut candidates = std::collections::HashSet::new();
        let max_candidates = (self.stored_vectors.len() / 4)
            .max(k * 10)
            .min(self.stored_vectors.len());

        if self.num_tables > 4 {
            let candidate_sets: Vec<Vec<usize>> = (0..self.num_tables)
                .into_par_iter()
                .map(|table_idx| {
                    let hash = self.hash_vector(query_vector, table_idx);
                    if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                        bucket.clone()
                    } else {
                        Vec::new()
                    }
                })
                .collect();

            for candidate_set in candidate_sets {
                for idx in candidate_set {
                    candidates.insert(idx);
                    if candidates.len() >= max_candidates {
                        break;
                    }
                }
                if candidates.len() >= max_candidates {
                    break;
                }
            }
        } else {
            for table_idx in 0..self.num_tables {
                if candidates.len() >= max_candidates {
                    break;
                }

                let hash = self.hash_vector(query_vector, table_idx);
                if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                    for &idx in bucket {
                        candidates.insert(idx);
                        if candidates.len() >= max_candidates {
                            break;
                        }
                    }
                }
            }
        }

        if candidates.len() < k * 3 && self.stored_vectors.len() > k * 3 {
            let needed = k * 5;
            for idx in 0..needed.min(self.stored_vectors.len()) {
                candidates.insert(idx);
                if candidates.len() >= needed {
                    break;
                }
            }
        }

        let mut results = if candidates.len() > 50 {
            candidates
                .par_iter()
                .map(|&idx| {
                    let stored_vector = &self.stored_vectors[idx];
                    let similarity = cosine_similarity(query_vector, stored_vector);
                    (stored_vector.clone(), self.stored_data[idx], similarity)
                })
                .collect()
        } else {
            let mut results = Vec::with_capacity(candidates.len());
            for &idx in &candidates {
                let stored_vector = &self.stored_vectors[idx];
                let similarity = cosine_similarity(query_vector, stored_vector);
                results.push((stored_vector.clone(), self.stored_data[idx], similarity));
            }
            results
        };

        if results.len() > 100 {
            results.par_sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        } else {
            results.sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        }
        results.truncate(k);

        results
    }

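    /// Hash signature for one table: the sign bit of the vector's dot product
    /// with each of that table's hyperplanes.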
    fn hash_vector(&self, vector: &Array1<f32>, table_idx: usize) -> Vec<bool> {
        let hyperplanes = &self.hyperplanes[table_idx];

        let dot_products = hyperplanes.dot(vector);

        dot_products.iter().map(|&x| x >= 0.0).collect()
    }

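    /// Summarizes bucket occupancy across all hash tables.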
    pub fn stats(&self) -> LSHStats {
        let mut bucket_sizes = Vec::new();
        let mut total_buckets = 0;
        let mut non_empty_buckets = 0;

        for table in &self.hash_tables {
            let buckets_per_table = 1_usize
                .checked_shl(self.hash_size as u32)
                .unwrap_or(usize::MAX);
            total_buckets += buckets_per_table;
            non_empty_buckets += table.len();

            for bucket in table.values() {
                bucket_sizes.push(bucket.len());
            }
        }

        bucket_sizes.sort();
        let median_bucket_size = if bucket_sizes.is_empty() {
            0.0
        } else {
            bucket_sizes[bucket_sizes.len() / 2] as f32
        };

        let avg_bucket_size = if bucket_sizes.is_empty() {
            0.0
        } else {
            bucket_sizes.iter().sum::<usize>() as f32 / bucket_sizes.len() as f32
        };

        LSHStats {
            num_vectors: self.stored_vectors.len(),
            num_tables: self.num_tables,
            hash_size: self.hash_size,
            total_buckets,
            non_empty_buckets,
            avg_bucket_size,
            median_bucket_size,
            max_bucket_size: bucket_sizes.last().copied().unwrap_or(0),
        }
    }

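    /// Persists the hyperplane configuration and the current bucket assignments
    /// through the persistence layer.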
    pub fn save_to_database(
        &self,
        db: &crate::persistence::Database,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::persistence::{LSHHashFunction, LSHTableData};

        let mut hash_functions = Vec::new();
        for hyperplane_matrix in &self.hyperplanes {
            for row in hyperplane_matrix.rows() {
                hash_functions.push(LSHHashFunction {
                    random_vector: row.iter().map(|&x| x as f64).collect(),
                    threshold: 0.0,
                });
            }
        }

        let config = LSHTableData {
            num_tables: self.num_tables,
            num_hash_functions: self.hash_size,
            vector_dim: self.vector_dim,
            hash_functions,
        };

        db.save_lsh_config(&config)?;

        db.clear_lsh_buckets()?;

        for (table_idx, table) in self.hash_tables.iter().enumerate() {
            for (hash_bits, indices) in table {
                let hash_string = hash_bits
                    .iter()
                    .map(|&b| if b { '1' } else { '0' })
                    .collect::<String>();

                for &position_idx in indices {
                    db.save_lsh_bucket(table_idx, &hash_string, position_idx as i64)?;
                }
            }
        }

        Ok(())
    }

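    /// Rebuilds the index from a saved configuration, re-inserting `positions`.
    /// Returns `Ok(None)` if no configuration has been stored.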
    pub fn load_from_database(
        db: &crate::persistence::Database,
        positions: &[(Array1<f32>, f32)],
    ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
        let config = match db.load_lsh_config()? {
            Some(config) => config,
            None => return Ok(None),
        };

        let mut hyperplanes = Vec::new();
        let functions_per_table = config.num_hash_functions;

        for table_idx in 0..config.num_tables {
            let start_idx = table_idx * functions_per_table;
            let end_idx = start_idx + functions_per_table;

            if end_idx <= config.hash_functions.len() {
                let mut table_hyperplanes = Array2::zeros((functions_per_table, config.vector_dim));

                for (func_idx, hash_func) in
                    config.hash_functions[start_idx..end_idx].iter().enumerate()
                {
                    for (dim_idx, &value) in hash_func.random_vector.iter().enumerate() {
                        if dim_idx < config.vector_dim {
                            table_hyperplanes[[func_idx, dim_idx]] = value as f32;
                        }
                    }
                }

                hyperplanes.push(table_hyperplanes);
            }
        }

        let mut lsh = Self {
            num_tables: config.num_tables,
            hash_size: config.num_hash_functions,
            vector_dim: config.vector_dim,
            hyperplanes,
            hash_tables: vec![HashMap::with_capacity(positions.len().max(100)); config.num_tables],
            stored_vectors: Vec::new(),
            stored_data: Vec::new(),
        };

        for (vector, evaluation) in positions {
            lsh.add_vector(vector.clone(), *evaluation);
        }

        Ok(Some(lsh))
    }

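    /// Loads a saved index if one exists; otherwise builds a new one sized for
    /// `positions` and populates it.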
    pub fn from_database_or_new(
        db: &crate::persistence::Database,
        positions: &[(Array1<f32>, f32)],
        vector_dim: usize,
        num_tables: usize,
        hash_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        match Self::load_from_database(db, positions)? {
            Some(lsh) => {
                println!(
                    "Loaded LSH configuration from database with {} vectors",
                    lsh.stored_vectors.len()
                );
                Ok(lsh)
            }
            None => {
                println!("No saved LSH configuration found, creating new LSH index");
                let mut lsh =
                    Self::with_expected_size(vector_dim, num_tables, hash_size, positions.len());
                for (vector, evaluation) in positions {
                    lsh.add_vector(vector.clone(), *evaluation);
                }
                Ok(lsh)
            }
        }
    }
}

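/// Bucket-occupancy statistics reported by [`LSH::stats`].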
#[derive(Debug)]
pub struct LSHStats {
    pub num_vectors: usize,
    pub num_tables: usize,
    pub hash_size: usize,
    pub total_buckets: usize,
    pub non_empty_buckets: usize,
    pub avg_bucket_size: f32,
    pub median_bucket_size: f32,
    pub max_bucket_size: usize,
}

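/// Cosine similarity, delegated to the shared `SimdVectorOps` kernel.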
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    SimdVectorOps::cosine_similarity(a, b)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_lsh_creation() {
        let lsh = LSH::new(128, 4, 8);
        assert_eq!(lsh.num_tables, 4);
        assert_eq!(lsh.hash_size, 8);
        assert_eq!(lsh.vector_dim, 128);
    }

    #[test]
    fn test_lsh_add_and_query() {
        let mut lsh = LSH::new(4, 2, 4);

        let vec1 = Array1::from(vec![1.0, 0.0, 0.0, 0.0]);
        let vec2 = Array1::from(vec![0.0, 1.0, 0.0, 0.0]);
        let vec3 = Array1::from(vec![1.0, 0.1, 0.0, 0.0]);

        lsh.add_vector(vec1.clone(), 1.0);
        lsh.add_vector(vec2, 2.0);
        lsh.add_vector(vec3, 1.1);

        let results = lsh.query(&vec1, 2);
        assert!(!results.is_empty());

        assert!(results[0].2 > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from(vec![1.0, 0.0, 0.0]);
        let b = Array1::from(vec![1.0, 0.0, 0.0]);
        let c = Array1::from(vec![0.0, 1.0, 0.0]);

        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 1e-6);
    }
}