use ndarray::{Array1, Array2};
use rand::Rng;
use rayon::prelude::*;
use std::collections::HashMap;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

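/// Locality-sensitive hashing index over `f32` vectors.
///
/// Each table hashes a vector against `hash_size` random hyperplanes; the sign
/// of each projection becomes one bit of the bucket key, so similar vectors
/// tend to collide in at least one of the `num_tables` tables.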
#[derive(Clone)]
pub struct LSH {
    num_tables: usize,
    hash_size: usize,
    #[allow(dead_code)]
    vector_dim: usize,
    hyperplanes: Vec<Array2<f32>>,
    hash_tables: Vec<HashMap<Vec<bool>, Vec<usize>>>,
    stored_vectors: Vec<Array1<f32>>,
    stored_data: Vec<f32>,
}

impl LSH {
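    /// Creates an index sized for a default expectation of 1000 vectors.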
    pub fn new(vector_dim: usize, num_tables: usize, hash_size: usize) -> Self {
        Self::with_expected_size(vector_dim, num_tables, hash_size, 1000)
    }

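    /// Creates an index whose hash tables are pre-sized for `expected_size`
    /// vectors, assuming roughly five vectors per bucket.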
    pub fn with_expected_size(
        vector_dim: usize,
        num_tables: usize,
        hash_size: usize,
        expected_size: usize,
    ) -> Self {
        let mut rng = rand::thread_rng();
        let mut hyperplanes = Vec::new();

        for _ in 0..num_tables {
            let mut table_hyperplanes = Array2::zeros((hash_size, vector_dim));
            for i in 0..hash_size {
                for j in 0..vector_dim {
                    table_hyperplanes[[i, j]] = rng.gen_range(-1.0..1.0);
                }
            }
            hyperplanes.push(table_hyperplanes);
        }

        let bucket_capacity = (expected_size as f32 / 5.0).ceil() as usize;
        let optimal_capacity = bucket_capacity.max(100);

        let hash_tables = vec![HashMap::with_capacity(optimal_capacity); num_tables];

        Self {
            num_tables,
            hash_size,
            vector_dim,
            hyperplanes,
            hash_tables,
            stored_vectors: Vec::with_capacity(expected_size),
            stored_data: Vec::with_capacity(expected_size),
        }
    }

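    /// Inserts a vector and its associated value, adding its index to the
    /// matching bucket of every hash table. Tables are grown once the
    /// estimated load (about five entries per bucket) exceeds 0.75.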
    pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
        let index = self.stored_vectors.len();

        let current_load =
            self.stored_vectors.len() as f32 / (self.hash_tables[0].capacity() as f32 * 0.2);
        if current_load > 0.75 {
            self.resize_hash_tables();
        }

        let mut hashes = Vec::with_capacity(self.num_tables);
        for table_idx in 0..self.num_tables {
            hashes.push(self.hash_vector(&vector, table_idx));
        }

        self.stored_vectors.push(vector);
        self.stored_data.push(data);

        for (table_idx, hash) in hashes.into_iter().enumerate() {
            self.hash_tables[table_idx]
                .entry(hash)
                .or_insert_with(|| Vec::with_capacity(8))
                .push(index);
        }
    }

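    /// Expands each hash table toward a new capacity of double the first
    /// table's current capacity, or the number of stored vectors, whichever
    /// is larger.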
    fn resize_hash_tables(&mut self) {
        let new_capacity = (self.hash_tables[0].capacity() * 2).max(self.stored_vectors.len());

        for table in &mut self.hash_tables {
            // saturating_sub keeps a table that already grew past the target
            // from underflowing the subtraction.
            table.reserve(new_capacity.saturating_sub(table.capacity()));
        }
    }

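    /// Returns up to `k` results as `(vector, value, cosine similarity)`
    /// tuples, sorted by descending similarity. Candidates come from the
    /// matching bucket of each table (queried in parallel when there are more
    /// than four tables) and are padded with the first stored indices when
    /// too few collisions are found.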
    pub fn query(&self, query_vector: &Array1<f32>, k: usize) -> Vec<(Array1<f32>, f32, f32)> {
        let mut candidates = std::collections::HashSet::new();
        let max_candidates = (self.stored_vectors.len() / 4)
            .max(k * 10)
            .min(self.stored_vectors.len());

        if self.num_tables > 4 {
            let candidate_sets: Vec<Vec<usize>> = (0..self.num_tables)
                .into_par_iter()
                .map(|table_idx| {
                    let hash = self.hash_vector(query_vector, table_idx);
                    if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                        bucket.clone()
                    } else {
                        Vec::new()
                    }
                })
                .collect();

            for candidate_set in candidate_sets {
                for idx in candidate_set {
                    candidates.insert(idx);
                    if candidates.len() >= max_candidates {
                        break;
                    }
                }
                if candidates.len() >= max_candidates {
                    break;
                }
            }
        } else {
            for table_idx in 0..self.num_tables {
                if candidates.len() >= max_candidates {
                    break;
                }

                let hash = self.hash_vector(query_vector, table_idx);
                if let Some(bucket) = self.hash_tables[table_idx].get(&hash) {
                    for &idx in bucket {
                        candidates.insert(idx);
                        if candidates.len() >= max_candidates {
                            break;
                        }
                    }
                }
            }
        }

        if candidates.len() < k * 3 && self.stored_vectors.len() > k * 3 {
            let needed = k * 5;
            for idx in 0..needed.min(self.stored_vectors.len()) {
                candidates.insert(idx);
                if candidates.len() >= needed {
                    break;
                }
            }
        }

        let mut results = if candidates.len() > 50 {
            candidates
                .par_iter()
                .map(|&idx| {
                    let stored_vector = &self.stored_vectors[idx];
                    let similarity = cosine_similarity(query_vector, stored_vector);
                    (stored_vector.clone(), self.stored_data[idx], similarity)
                })
                .collect()
        } else {
            let mut results = Vec::with_capacity(candidates.len());
            for &idx in &candidates {
                let stored_vector = &self.stored_vectors[idx];
                let similarity = cosine_similarity(query_vector, stored_vector);
                results.push((stored_vector.clone(), self.stored_data[idx], similarity));
            }
            results
        };

        if results.len() > 100 {
            results.par_sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        } else {
            results.sort_unstable_by(|a, b| {
                b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
            });
        }
        results.truncate(k);

        results
    }

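    /// Computes the bucket key for one table: the sign of the vector's
    /// projection onto each of that table's random hyperplanes.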
    fn hash_vector(&self, vector: &Array1<f32>, table_idx: usize) -> Vec<bool> {
        let hyperplanes = &self.hyperplanes[table_idx];

        let dot_products = hyperplanes.dot(vector);

        dot_products.iter().map(|&x| x >= 0.0).collect()
    }

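    /// Summarizes bucket occupancy; `total_buckets` counts the theoretical
    /// 2^`hash_size` buckets per table.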
    pub fn stats(&self) -> LSHStats {
        let mut bucket_sizes = Vec::new();
        let mut total_buckets = 0;
        let mut non_empty_buckets = 0;

        for table in &self.hash_tables {
            let buckets_per_table = 1_usize
                .checked_shl(self.hash_size as u32)
                .unwrap_or(usize::MAX);
            total_buckets += buckets_per_table;
            non_empty_buckets += table.len();

            for bucket in table.values() {
                bucket_sizes.push(bucket.len());
            }
        }

        bucket_sizes.sort();
        let median_bucket_size = if bucket_sizes.is_empty() {
            0.0
        } else {
            bucket_sizes[bucket_sizes.len() / 2] as f32
        };

        let avg_bucket_size = if bucket_sizes.is_empty() {
            0.0
        } else {
            bucket_sizes.iter().sum::<usize>() as f32 / bucket_sizes.len() as f32
        };

        LSHStats {
            num_vectors: self.stored_vectors.len(),
            num_tables: self.num_tables,
            hash_size: self.hash_size,
            total_buckets,
            non_empty_buckets,
            avg_bucket_size,
            median_bucket_size,
            max_bucket_size: bucket_sizes.last().copied().unwrap_or(0),
        }
    }

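    /// Persists the hyperplane configuration and the bucket membership of
    /// every stored vector, clearing any previously saved buckets first.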
    pub fn save_to_database(
        &self,
        db: &crate::persistence::Database,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::persistence::{LSHHashFunction, LSHTableData};

        let mut hash_functions = Vec::new();
        for hyperplane_matrix in &self.hyperplanes {
            for row in hyperplane_matrix.rows() {
                hash_functions.push(LSHHashFunction {
                    random_vector: row.to_vec().iter().map(|&x| x as f64).collect(),
                    threshold: 0.0,
                });
            }
        }

        let config = LSHTableData {
            num_tables: self.num_tables,
            num_hash_functions: self.hash_size,
            vector_dim: self.vector_dim,
            hash_functions,
        };

        db.save_lsh_config(&config)?;

        db.clear_lsh_buckets()?;

        for (table_idx, table) in self.hash_tables.iter().enumerate() {
            for (hash_bits, indices) in table {
                let hash_string = hash_bits
                    .iter()
                    .map(|&b| if b { '1' } else { '0' })
                    .collect::<String>();

                for &position_idx in indices {
                    db.save_lsh_bucket(table_idx, &hash_string, position_idx as i64)?;
                }
            }
        }

        Ok(())
    }

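    /// Rebuilds an index from a saved hyperplane configuration, re-inserting
    /// the supplied `(vector, evaluation)` pairs. Returns `Ok(None)` when no
    /// configuration has been saved yet.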
    pub fn load_from_database(
        db: &crate::persistence::Database,
        positions: &[(Array1<f32>, f32)],
    ) -> Result<Option<Self>, Box<dyn std::error::Error>> {
        let config = match db.load_lsh_config()? {
            Some(config) => config,
            None => return Ok(None),
        };

        let mut hyperplanes = Vec::new();
        let functions_per_table = config.num_hash_functions;

        for table_idx in 0..config.num_tables {
            let start_idx = table_idx * functions_per_table;
            let end_idx = start_idx + functions_per_table;

            if end_idx <= config.hash_functions.len() {
                let mut table_hyperplanes = Array2::zeros((functions_per_table, config.vector_dim));

                for (func_idx, hash_func) in
                    config.hash_functions[start_idx..end_idx].iter().enumerate()
                {
                    for (dim_idx, &value) in hash_func.random_vector.iter().enumerate() {
                        if dim_idx < config.vector_dim {
                            table_hyperplanes[[func_idx, dim_idx]] = value as f32;
                        }
                    }
                }

                hyperplanes.push(table_hyperplanes);
            }
        }

        let mut lsh = Self {
            num_tables: config.num_tables,
            hash_size: config.num_hash_functions,
            vector_dim: config.vector_dim,
            hyperplanes,
            hash_tables: vec![HashMap::with_capacity(positions.len().max(100)); config.num_tables],
            stored_vectors: Vec::new(),
            stored_data: Vec::new(),
        };

        for (vector, evaluation) in positions {
            lsh.add_vector(vector.clone(), *evaluation);
        }

        Ok(Some(lsh))
    }

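    /// Loads a saved index if one exists, otherwise builds a fresh index from
    /// `positions` with the given parameters.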
    pub fn from_database_or_new(
        db: &crate::persistence::Database,
        positions: &[(Array1<f32>, f32)],
        vector_dim: usize,
        num_tables: usize,
        hash_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        match Self::load_from_database(db, positions)? {
            Some(lsh) => {
                println!(
                    "Loaded LSH configuration from database with {} vectors",
                    lsh.stored_vectors.len()
                );
                Ok(lsh)
            }
            None => {
                println!("No saved LSH configuration found, creating new LSH index");
                let mut lsh =
                    Self::with_expected_size(vector_dim, num_tables, hash_size, positions.len());
                for (vector, evaluation) in positions {
                    lsh.add_vector(vector.clone(), *evaluation);
                }
                Ok(lsh)
            }
        }
    }
}

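/// Occupancy statistics reported by [`LSH::stats`].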
#[derive(Debug)]
pub struct LSHStats {
    pub num_vectors: usize,
    pub num_tables: usize,
    pub hash_size: usize,
    pub total_buckets: usize,
    pub non_empty_buckets: usize,
    pub avg_bucket_size: f32,
    pub median_bucket_size: f32,
    pub max_bucket_size: usize,
}

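/// Cosine similarity of two contiguous `f32` vectors, built on the SIMD dot
/// product below. Returns 0.0 if either vector has zero norm.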
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    let a_slice = a.as_slice().unwrap();
    let b_slice = b.as_slice().unwrap();

    let dot_product = simd_dot_product(a_slice, b_slice);
    let norm_a_sq = simd_dot_product(a_slice, a_slice);
    let norm_b_sq = simd_dot_product(b_slice, b_slice);

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        0.0
    } else {
        dot_product / (norm_a_sq * norm_b_sq).sqrt()
    }
}

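/// Dot product with runtime SIMD dispatch: AVX2, then SSE4.1 on x86_64, NEON
/// on aarch64, and an unrolled scalar loop everywhere else.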
#[inline]
fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2_dot_product(a, b) };
        } else if is_x86_feature_detected!("sse4.1") {
            return unsafe { sse_dot_product(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { neon_dot_product(a, b) };
        }
    }

    scalar_dot_product(a, b)
}

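/// AVX2 dot product: eight lanes per iteration with a scalar tail.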
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vmul = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, vmul);
        i += 8;
    }

    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}

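/// SSE4.1 dot product: four lanes per iteration with a scalar tail.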
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn sse_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vmul = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, vmul);
        i += 4;
    }

    let mut result = [0.0f32; 4];
    _mm_storeu_ps(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}

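/// NEON dot product: four lanes per iteration with a scalar tail.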
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    while i + 4 <= len {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let vmul = vmulq_f32(va, vb);
        sum = vaddq_f32(sum, vmul);
        i += 4;
    }

    let mut result = [0.0f32; 4];
    vst1q_f32(result.as_mut_ptr(), sum);
    let mut final_sum = result.iter().sum::<f32>();

    while i < len {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}

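/// Portable fallback: scalar dot product unrolled four elements at a time.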
#[inline]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let mut sum = 0.0f32;
    let mut i = 0;

    while i + 4 <= len {
        sum += a[i] * b[i] + a[i + 1] * b[i + 1] + a[i + 2] * b[i + 2] + a[i + 3] * b[i + 3];
        i += 4;
    }

    while i < len {
        sum += a[i] * b[i];
        i += 1;
    }

    sum
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_lsh_creation() {
        let lsh = LSH::new(128, 4, 8);
        assert_eq!(lsh.num_tables, 4);
        assert_eq!(lsh.hash_size, 8);
        assert_eq!(lsh.vector_dim, 128);
    }

    #[test]
    fn test_lsh_add_and_query() {
        let mut lsh = LSH::new(4, 2, 4);

        let vec1 = Array1::from(vec![1.0, 0.0, 0.0, 0.0]);
        let vec2 = Array1::from(vec![0.0, 1.0, 0.0, 0.0]);
        let vec3 = Array1::from(vec![1.0, 0.1, 0.0, 0.0]);

        lsh.add_vector(vec1.clone(), 1.0);
        lsh.add_vector(vec2, 2.0);
        lsh.add_vector(vec3, 1.1);

        let results = lsh.query(&vec1, 2);
        assert!(!results.is_empty());

        // The best match should be vec1 itself (or the nearly identical vec3).
        assert!(results[0].2 > 0.8);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from(vec![1.0, 0.0, 0.0]);
        let b = Array1::from(vec![1.0, 0.0, 0.0]);
        let c = Array1::from(vec![0.0, 1.0, 0.0]);

        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 1e-6);
    }
}