1use crate::random_utils::NormalSampler as Normal;
10use crate::{Vector, VectorIndex};
11use anyhow::{anyhow, Result};
12#[allow(unused_imports)]
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct LshConfig {
20 pub num_tables: usize,
22 pub num_hash_functions: usize,
24 pub lsh_family: LshFamily,
26 pub seed: u64,
28 pub multi_probe: bool,
30 pub num_probes: usize,
32}
33
34impl Default for LshConfig {
35 fn default() -> Self {
36 Self {
37 num_tables: 10,
38 num_hash_functions: 8,
39 lsh_family: LshFamily::RandomProjection,
40 seed: 42,
41 multi_probe: true,
42 num_probes: 3,
43 }
44 }
45}
46
47#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
49pub enum LshFamily {
50 RandomProjection,
52 MinHash,
54 SimHash,
56 PStable(f32), }
59
/// A hash over dense `f32` vectors that can be shared across threads.
trait HashFunction: Send + Sync {
    /// Maps `vector` to a 64-bit bucket key.
    fn hash(&self, vector: &[f32]) -> u64;

    /// Invokes [`hash`](Self::hash) `num_hashes` times on the same vector and
    /// returns the results in call order.
    fn hash_multi(&self, vector: &[f32], num_hashes: usize) -> Vec<u64> {
        let mut keys = Vec::with_capacity(num_hashes);
        for _ in 0..num_hashes {
            keys.push(self.hash(vector));
        }
        keys
    }
}
70
71struct RandomProjectionHash {
73 projections: Vec<Vec<f32>>,
74 dimensions: usize,
75}
76
77impl RandomProjectionHash {
78 fn new(dimensions: usize, num_projections: usize, seed: u64) -> Self {
79 let mut rng = Random::seed(seed);
80 let normal = Normal::new(0.0, 1.0).unwrap();
81
82 let mut projections = Vec::with_capacity(num_projections);
83 for _ in 0..num_projections {
84 let projection: Vec<f32> = (0..dimensions).map(|_| normal.sample(&mut rng)).collect();
85 projections.push(projection);
86 }
87
88 Self {
89 projections,
90 dimensions,
91 }
92 }
93}
94
95impl HashFunction for RandomProjectionHash {
96 fn hash(&self, vector: &[f32]) -> u64 {
97 let mut hash_value = 0u64;
98
99 for (i, projection) in self.projections.iter().enumerate() {
100 use oxirs_core::simd::SimdOps;
102 let dot_product = f32::dot(vector, projection);
103
104 if dot_product > 0.0 {
106 hash_value |= 1 << (i % 64);
107 }
108 }
109
110 hash_value
111 }
112}
113
114struct MinHashFunction {
116 a: Vec<u64>,
117 b: Vec<u64>,
118 prime: u64,
119}
120
121impl MinHashFunction {
122 fn new(num_hashes: usize, seed: u64) -> Self {
123 let mut rng = Random::seed(seed);
124 let prime = 4294967311u64; let a: Vec<u64> = (0..num_hashes)
127 .map(|_| rng.random_range(1, prime))
128 .collect();
129 let b: Vec<u64> = (0..num_hashes)
130 .map(|_| rng.random_range(0, prime))
131 .collect();
132
133 Self { a, b, prime }
134 }
135
136 fn minhash_signature(&self, set_elements: &[u32]) -> Vec<u64> {
137 let mut signature = vec![u64::MAX; self.a.len()];
138
139 for &element in set_elements {
140 for (i, sig_val) in signature.iter_mut().enumerate().take(self.a.len()) {
141 let hash = (self.a[i] * element as u64 + self.b[i]) % self.prime;
142 *sig_val = (*sig_val).min(hash);
143 }
144 }
145
146 signature
147 }
148}
149
150impl HashFunction for MinHashFunction {
151 fn hash(&self, vector: &[f32]) -> u64 {
152 let threshold = 0.0;
154 let set_elements: Vec<u32> = vector
155 .iter()
156 .enumerate()
157 .filter(|&(_, &v)| v > threshold)
158 .map(|(i, _)| i as u32)
159 .collect();
160
161 let signature = self.minhash_signature(&set_elements);
162
163 let mut hash = 0u64;
165 for (i, &sig) in signature.iter().enumerate() {
166 hash ^= sig.rotate_left((i * 7) as u32);
167 }
168
169 hash
170 }
171}
172
173struct SimHashFunction {
175 random_vectors: Vec<Vec<f32>>,
176}
177
178impl SimHashFunction {
179 fn new(dimensions: usize, seed: u64) -> Self {
180 let mut rng = Random::seed(seed);
181 let normal = Normal::new(0.0, 1.0).unwrap();
182
183 let random_vectors: Vec<Vec<f32>> = (0..64)
184 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
185 .collect();
186
187 Self { random_vectors }
188 }
189}
190
191impl HashFunction for SimHashFunction {
192 fn hash(&self, vector: &[f32]) -> u64 {
193 let mut hash = 0u64;
194
195 for (i, random_vec) in self.random_vectors.iter().enumerate() {
196 let mut sum = 0.0;
198 for (j, &v) in vector.iter().enumerate() {
199 if j < random_vec.len() {
200 sum += v * random_vec[j];
201 }
202 }
203
204 if sum > 0.0 {
205 hash |= 1 << i;
206 }
207 }
208
209 hash
210 }
211}
212
213struct PStableHash {
215 projections: Vec<Vec<f32>>,
216 offsets: Vec<f32>,
217 width: f32,
218 p: f32,
219}
220
221impl PStableHash {
222 fn new(dimensions: usize, num_projections: usize, width: f32, p: f32, seed: u64) -> Self {
223 let mut rng = Random::seed(seed);
224
225 let projections: Vec<Vec<f32>> = if (p - 1.0).abs() < 0.1 {
227 (0..num_projections)
229 .map(|_| {
230 (0..dimensions)
231 .map(|_| {
232 let u: f32 = rng
233 .gen_range(-std::f32::consts::PI / 2.0..std::f32::consts::PI / 2.0);
234 u.tan()
235 })
236 .collect()
237 })
238 .collect()
239 } else if (p - 2.0).abs() < 0.1 {
240 let normal = Normal::new(0.0, 1.0).unwrap();
242 (0..num_projections)
243 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
244 .collect()
245 } else {
246 let normal = Normal::new(0.0, 1.0).unwrap();
248 (0..num_projections)
249 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
250 .collect()
251 };
252
253 let offsets: Vec<f32> = (0..num_projections)
254 .map(|_| rng.gen_range(0.0..width))
255 .collect();
256
257 Self {
258 projections,
259 offsets,
260 width,
261 p,
262 }
263 }
264}
265
266impl HashFunction for PStableHash {
267 fn hash(&self, vector: &[f32]) -> u64 {
268 let mut hash = 0u64;
269
270 for (i, (projection, &offset)) in self.projections.iter().zip(&self.offsets).enumerate() {
271 use oxirs_core::simd::SimdOps;
272 let dot_product = f32::dot(vector, projection);
273 let bucket = ((dot_product + offset) / self.width).floor() as i32;
274
275 if bucket > 0 {
277 hash |= 1 << (i % 64);
278 }
279 }
280
281 hash
282 }
283}
284
285struct LshTable {
287 buckets: HashMap<u64, Vec<usize>>,
288 hash_function: Box<dyn HashFunction>,
289}
290
291impl LshTable {
292 fn new(hash_function: Box<dyn HashFunction>) -> Self {
293 Self {
294 buckets: HashMap::new(),
295 hash_function,
296 }
297 }
298
299 fn insert(&mut self, id: usize, vector: &[f32]) {
300 let hash = self.hash_function.hash(vector);
301 self.buckets.entry(hash).or_default().push(id);
302 }
303
304 fn query(&self, vector: &[f32]) -> Vec<usize> {
305 let hash = self.hash_function.hash(vector);
306 self.buckets.get(&hash).cloned().unwrap_or_default()
307 }
308
309 fn query_multi_probe(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
310 let main_hash = self.hash_function.hash(vector);
311 let mut candidates = HashSet::new();
312
313 if let Some(ids) = self.buckets.get(&main_hash) {
315 candidates.extend(ids);
316 }
317
318 for probe in 1..=num_probes {
320 for bit in 0..64 {
321 let probed_hash = main_hash ^ (1 << bit);
322 if let Some(ids) = self.buckets.get(&probed_hash) {
323 candidates.extend(ids);
324 }
325
326 if candidates.len() >= probe * 10 {
328 break;
329 }
330 }
331 }
332
333 candidates.into_iter().collect()
334 }
335}
336
337pub struct LshIndex {
339 config: LshConfig,
340 tables: Vec<LshTable>,
341 vectors: Vec<(String, Vector)>,
342 uri_to_id: HashMap<String, usize>,
343 dimensions: Option<usize>,
344}
345
346impl LshIndex {
347 pub fn new(config: LshConfig) -> Self {
349 let tables = Self::create_tables(&config, 0);
350
351 Self {
352 config,
353 tables,
354 vectors: Vec::new(),
355 uri_to_id: HashMap::new(),
356 dimensions: None,
357 }
358 }
359
360 fn create_tables(config: &LshConfig, dimensions: usize) -> Vec<LshTable> {
362 let mut tables = Vec::with_capacity(config.num_tables);
363
364 for table_idx in 0..config.num_tables {
365 let seed = config.seed.wrapping_add(table_idx as u64);
366
367 let hash_function: Box<dyn HashFunction> = match config.lsh_family {
368 LshFamily::RandomProjection => Box::new(RandomProjectionHash::new(
369 dimensions,
370 config.num_hash_functions,
371 seed,
372 )),
373 LshFamily::MinHash => {
374 Box::new(MinHashFunction::new(config.num_hash_functions, seed))
375 }
376 LshFamily::SimHash => Box::new(SimHashFunction::new(dimensions, seed)),
377 LshFamily::PStable(p) => {
378 Box::new(PStableHash::new(
379 dimensions,
380 config.num_hash_functions,
381 1.0, p,
383 seed,
384 ))
385 }
386 };
387
388 tables.push(LshTable::new(hash_function));
389 }
390
391 tables
392 }
393
394 fn rebuild_tables(&mut self) {
396 if let Some(dims) = self.dimensions {
397 self.tables = Self::create_tables(&self.config, dims);
398
399 for (id, (_, vector)) in self.vectors.iter().enumerate() {
401 let vector_f32 = vector.as_f32();
402 for table in &mut self.tables {
403 table.insert(id, &vector_f32);
404 }
405 }
406 }
407 }
408
409 fn query_candidates(&self, vector: &[f32]) -> Vec<usize> {
411 let mut candidates = HashSet::new();
412
413 if self.config.multi_probe {
414 for table in &self.tables {
416 let table_candidates = table.query_multi_probe(vector, self.config.num_probes);
417 candidates.extend(table_candidates);
418 }
419 } else {
420 for table in &self.tables {
422 let table_candidates = table.query(vector);
423 candidates.extend(table_candidates);
424 }
425 }
426
427 candidates.into_iter().collect()
428 }
429
430 pub fn stats(&self) -> LshStats {
432 let avg_bucket_size = if self.tables.is_empty() {
433 0.0
434 } else {
435 let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
436 let total_items: usize = self
437 .tables
438 .iter()
439 .map(|t| t.buckets.values().map(|v| v.len()).sum::<usize>())
440 .sum();
441
442 if total_buckets > 0 {
443 total_items as f64 / total_buckets as f64
444 } else {
445 0.0
446 }
447 };
448
449 LshStats {
450 num_vectors: self.vectors.len(),
451 num_tables: self.tables.len(),
452 avg_bucket_size,
453 memory_usage: self.estimate_memory_usage(),
454 }
455 }
456
457 fn estimate_memory_usage(&self) -> usize {
458 let vector_memory =
459 self.vectors.len() * (std::mem::size_of::<String>() + std::mem::size_of::<Vector>());
460
461 let table_memory: usize = self
462 .tables
463 .iter()
464 .map(|t| {
465 t.buckets.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>())
466 + t.buckets
467 .values()
468 .map(|v| v.len() * std::mem::size_of::<usize>())
469 .sum::<usize>()
470 })
471 .sum();
472
473 vector_memory + table_memory
474 }
475}
476
477impl VectorIndex for LshIndex {
478 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
479 if self.dimensions.is_none() {
481 self.dimensions = Some(vector.dimensions);
482 self.rebuild_tables();
483 } else if Some(vector.dimensions) != self.dimensions {
484 return Err(anyhow!(
485 "Vector dimensions ({}) don't match index dimensions ({:?})",
486 vector.dimensions,
487 self.dimensions
488 ));
489 }
490
491 let id = self.vectors.len();
492 let vector_f32 = vector.as_f32();
493
494 for table in &mut self.tables {
496 table.insert(id, &vector_f32);
497 }
498
499 self.uri_to_id.insert(uri.clone(), id);
500 self.vectors.push((uri, vector));
501
502 Ok(())
503 }
504
505 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
506 if self.vectors.is_empty() {
507 return Ok(Vec::new());
508 }
509
510 let query_f32 = query.as_f32();
511 let candidates = self.query_candidates(&query_f32);
512
513 let mut results: Vec<(usize, f32)> = candidates
515 .into_iter()
516 .filter_map(|id| {
517 self.vectors.get(id).map(|(_, vec)| {
518 let vec_f32 = vec.as_f32();
519 let distance = match self.config.lsh_family {
520 LshFamily::RandomProjection | LshFamily::SimHash => {
521 use oxirs_core::simd::SimdOps;
523 f32::cosine_distance(&query_f32, &vec_f32)
524 }
525 LshFamily::MinHash => {
526 let threshold = 0.0;
528 let set1: HashSet<usize> = query_f32
529 .iter()
530 .enumerate()
531 .filter(|&(_, &v)| v > threshold)
532 .map(|(i, _)| i)
533 .collect();
534 let set2: HashSet<usize> = vec_f32
535 .iter()
536 .enumerate()
537 .filter(|&(_, &v)| v > threshold)
538 .map(|(i, _)| i)
539 .collect();
540
541 let intersection = set1.intersection(&set2).count();
542 let union = set1.union(&set2).count();
543
544 if union > 0 {
545 1.0 - (intersection as f32 / union as f32)
546 } else {
547 1.0
548 }
549 }
550 LshFamily::PStable(p) => {
551 use oxirs_core::simd::SimdOps;
553 if (p - 1.0).abs() < 0.1 {
554 f32::manhattan_distance(&query_f32, &vec_f32)
555 } else if (p - 2.0).abs() < 0.1 {
556 f32::euclidean_distance(&query_f32, &vec_f32)
557 } else {
558 query_f32
560 .iter()
561 .zip(&vec_f32)
562 .map(|(a, b)| (a - b).abs().powf(p))
563 .sum::<f32>()
564 .powf(1.0 / p)
565 }
566 }
567 };
568 (id, distance)
569 })
570 })
571 .collect();
572
573 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
575 results.truncate(k);
576
577 Ok(results
579 .into_iter()
580 .map(|(id, dist)| (self.vectors[id].0.clone(), dist))
581 .collect())
582 }
583
584 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
585 if self.vectors.is_empty() {
586 return Ok(Vec::new());
587 }
588
589 let query_f32 = query.as_f32();
590 let candidates = self.query_candidates(&query_f32);
591
592 let mut results: Vec<(String, f32)> = candidates
594 .into_iter()
595 .filter_map(|id| {
596 self.vectors.get(id).and_then(|(uri, vec)| {
597 let vec_f32 = vec.as_f32();
598 let distance = match self.config.lsh_family {
599 LshFamily::RandomProjection | LshFamily::SimHash => {
600 use oxirs_core::simd::SimdOps;
601 f32::cosine_distance(&query_f32, &vec_f32)
602 }
603 LshFamily::MinHash => {
604 let threshold_val = 0.0;
606 let set1: HashSet<usize> = query_f32
607 .iter()
608 .enumerate()
609 .filter(|&(_, &v)| v > threshold_val)
610 .map(|(i, _)| i)
611 .collect();
612 let set2: HashSet<usize> = vec_f32
613 .iter()
614 .enumerate()
615 .filter(|&(_, &v)| v > threshold_val)
616 .map(|(i, _)| i)
617 .collect();
618
619 let intersection = set1.intersection(&set2).count();
620 let union = set1.union(&set2).count();
621
622 if union > 0 {
623 1.0 - (intersection as f32 / union as f32)
624 } else {
625 1.0
626 }
627 }
628 LshFamily::PStable(p) => {
629 use oxirs_core::simd::SimdOps;
630 if (p - 1.0).abs() < 0.1 {
631 f32::manhattan_distance(&query_f32, &vec_f32)
632 } else if (p - 2.0).abs() < 0.1 {
633 f32::euclidean_distance(&query_f32, &vec_f32)
634 } else {
635 query_f32
636 .iter()
637 .zip(&vec_f32)
638 .map(|(a, b)| (a - b).abs().powf(p))
639 .sum::<f32>()
640 .powf(1.0 / p)
641 }
642 }
643 };
644
645 if distance <= threshold {
646 Some((uri.clone(), distance))
647 } else {
648 None
649 }
650 })
651 })
652 .collect();
653
654 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
655 Ok(results)
656 }
657
658 fn get_vector(&self, uri: &str) -> Option<&Vector> {
659 self.uri_to_id
660 .get(uri)
661 .and_then(|&id| self.vectors.get(id))
662 .map(|(_, v)| v)
663 }
664}
665
/// Size statistics reported for an LSH index.
#[derive(Debug, Clone)]
pub struct LshStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Mean number of ids per bucket across all tables (`0.0` without buckets).
    pub avg_bucket_size: f64,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
674
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with the seed shared by all tests. Passing
    /// `Some(n)` enables multi-probe with `n` probes; `None` disables it.
    fn test_config(
        lsh_family: LshFamily,
        num_tables: usize,
        num_hash_functions: usize,
        probes: Option<usize>,
    ) -> LshConfig {
        LshConfig {
            num_tables,
            num_hash_functions,
            lsh_family,
            seed: 42,
            multi_probe: probes.is_some(),
            num_probes: probes.unwrap_or(0),
        }
    }

    /// A 100-dimensional vector that is 1.0 at `ones` and 0.0 elsewhere.
    fn indicator(ones: &[usize]) -> Vector {
        let mut values = vec![0.0; 100];
        for &position in ones {
            values[position] = 1.0;
        }
        Vector::new(values)
    }

    #[test]
    fn test_random_projection_lsh() {
        let mut index = LshIndex::new(test_config(LshFamily::RandomProjection, 5, 4, None));

        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let entries = vec![
            ("v1", query.clone()),
            ("v2", Vector::new(vec![0.0, 1.0, 0.0])),
            ("v3", Vector::new(vec![0.0, 0.0, 1.0])),
            // Nearly parallel to v1.
            ("v_similar", Vector::new(vec![0.9, 0.1, 0.0])),
        ];
        for (uri, vector) in entries {
            index.insert(uri.to_string(), vector).unwrap();
        }

        let results = index.search_knn(&query, 2).unwrap();

        assert!(results.len() <= 2);
        assert!(results
            .iter()
            .any(|(uri, _)| uri == "v1" || uri == "v_similar"));
    }

    #[test]
    fn test_minhash_lsh() {
        let mut index = LshIndex::new(test_config(LshFamily::MinHash, 3, 64, None));

        // v2 shares two of three positions with v1; v3 shares none.
        let query = indicator(&[0, 10, 20]);
        index.insert("v1".to_string(), query.clone()).unwrap();
        index
            .insert("v2".to_string(), indicator(&[0, 10, 30]))
            .unwrap();
        index
            .insert("v3".to_string(), indicator(&[50, 60, 70]))
            .unwrap();

        let results = index.search_knn(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].0, "v1");
        if let Some(second) = results.get(1) {
            assert_eq!(second.0, "v2");
        }
    }

    #[test]
    fn test_multi_probe_lsh() {
        let mut index = LshIndex::new(test_config(LshFamily::RandomProjection, 3, 4, Some(2)));

        // Fifty unit vectors spread evenly around the circle in the xy-plane.
        for i in 0..50 {
            let angle = i as f32 * std::f32::consts::PI / 25.0;
            index
                .insert(
                    format!("v{i}"),
                    Vector::new(vec![angle.cos(), angle.sin(), 0.0]),
                )
                .unwrap();
        }

        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        // Results must come back ordered by non-decreasing distance.
        assert!(results.windows(2).all(|pair| pair[0].1 <= pair[1].1));
    }
}