1use crate::random_utils::NormalSampler as Normal;
10use crate::{Vector, VectorIndex};
11use anyhow::{anyhow, Result};
12#[allow(unused_imports)]
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16
/// Configuration for an [`LshIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LshConfig {
    /// Number of independent hash tables the index builds.
    pub num_tables: usize,
    /// Number of hash functions (projections) per table.
    ///
    /// Passed to the random-projection, MinHash and p-stable constructors;
    /// the SimHash constructor does not take it.
    pub num_hash_functions: usize,
    /// Hash family used for bucketing; also selects the distance used to rank candidates.
    pub lsh_family: LshFamily,
    /// Base random seed. Table `i` is seeded with `seed.wrapping_add(i)`.
    pub seed: u64,
    /// When `true`, candidate lookup uses the multi-probe query instead of the exact-bucket query.
    pub multi_probe: bool,
    /// Probing budget handed to the multi-probe query (only read when `multi_probe` is `true`).
    pub num_probes: usize,
}
33
34impl Default for LshConfig {
35 fn default() -> Self {
36 Self {
37 num_tables: 10,
38 num_hash_functions: 8,
39 lsh_family: LshFamily::RandomProjection,
40 seed: 42,
41 multi_probe: true,
42 num_probes: 3,
43 }
44 }
45}
46
/// Hash family used by an [`LshIndex`].
///
/// The family decides both which hash function each table uses and which
/// distance ranks the candidates at query time.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LshFamily {
    /// Sign of random projections; candidates are ranked by cosine distance.
    RandomProjection,
    /// MinHash over the set of positions holding a positive value; candidates
    /// are ranked by Jaccard distance of those sets.
    MinHash,
    /// 64-bit sign hash; candidates are ranked by cosine distance.
    SimHash,
    /// Projection hash carrying the norm parameter `p`: values within 0.1 of
    /// 1.0 rank by Manhattan distance, within 0.1 of 2.0 by Euclidean
    /// distance, anything else by the general `p`-norm.
    PStable(f32),
}
59
/// A locality-sensitive hash function mapping a vector to a 64-bit bucket key.
trait HashFunction: Send + Sync {
    /// Hashes `vector` to a bucket key.
    fn hash(&self, vector: &[f32]) -> u64;

    /// Returns `num_hashes` hash values for `vector`.
    ///
    /// The default implementation simply calls [`HashFunction::hash`]
    /// `num_hashes` times on the same input and collects the results in order.
    fn hash_multi(&self, vector: &[f32], num_hashes: usize) -> Vec<u64> {
        let mut hashes = Vec::with_capacity(num_hashes);
        for _ in 0..num_hashes {
            hashes.push(self.hash(vector));
        }
        hashes
    }
}
70
/// Sign-of-random-projection hash: one output bit per projection vector.
struct RandomProjectionHash {
    /// Projection vectors; each has `dimensions` components.
    projections: Vec<Vec<f32>>,
    /// Dimensionality given at construction (stored, but not read by the hash itself).
    dimensions: usize,
}
76
77impl RandomProjectionHash {
78 fn new(dimensions: usize, num_projections: usize, seed: u64) -> Self {
79 let mut rng = Random::seed(seed);
80 let normal =
81 Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
82
83 let mut projections = Vec::with_capacity(num_projections);
84 for _ in 0..num_projections {
85 let projection: Vec<f32> = (0..dimensions).map(|_| normal.sample(&mut rng)).collect();
86 projections.push(projection);
87 }
88
89 Self {
90 projections,
91 dimensions,
92 }
93 }
94}
95
96impl HashFunction for RandomProjectionHash {
97 fn hash(&self, vector: &[f32]) -> u64 {
98 let mut hash_value = 0u64;
99
100 for (i, projection) in self.projections.iter().enumerate() {
101 use oxirs_core::simd::SimdOps;
103 let dot_product = f32::dot(vector, projection);
104
105 if dot_product > 0.0 {
107 hash_value |= 1 << (i % 64);
108 }
109 }
110
111 hash_value
112 }
113}
114
/// MinHash built from universal hash functions `h_i(x) = (a[i] * x + b[i]) mod prime`.
struct MinHashFunction {
    /// Multipliers, one per hash function, drawn from `1..prime`.
    a: Vec<u64>,
    /// Offsets, one per hash function, drawn from `0..prime`.
    b: Vec<u64>,
    /// Modulus shared by all hash functions.
    prime: u64,
}
121
122impl MinHashFunction {
123 fn new(num_hashes: usize, seed: u64) -> Self {
124 let mut rng = Random::seed(seed);
125 let prime = 4294967311u64; let a: Vec<u64> = (0..num_hashes)
128 .map(|_| rng.random_range(1..prime))
129 .collect();
130 let b: Vec<u64> = (0..num_hashes)
131 .map(|_| rng.random_range(0..prime))
132 .collect();
133
134 Self { a, b, prime }
135 }
136
137 fn minhash_signature(&self, set_elements: &[u32]) -> Vec<u64> {
138 let mut signature = vec![u64::MAX; self.a.len()];
139
140 for &element in set_elements {
141 for (i, sig_val) in signature.iter_mut().enumerate().take(self.a.len()) {
142 let hash = (self.a[i] * element as u64 + self.b[i]) % self.prime;
143 *sig_val = (*sig_val).min(hash);
144 }
145 }
146
147 signature
148 }
149}
150
151impl HashFunction for MinHashFunction {
152 fn hash(&self, vector: &[f32]) -> u64 {
153 let threshold = 0.0;
155 let set_elements: Vec<u32> = vector
156 .iter()
157 .enumerate()
158 .filter(|&(_, &v)| v > threshold)
159 .map(|(i, _)| i as u32)
160 .collect();
161
162 let signature = self.minhash_signature(&set_elements);
163
164 let mut hash = 0u64;
166 for (i, &sig) in signature.iter().enumerate() {
167 hash ^= sig.rotate_left((i * 7) as u32);
168 }
169
170 hash
171 }
172}
173
/// SimHash: a 64-bit sign hash with one random vector per output bit.
struct SimHashFunction {
    /// Random vectors; the constructor creates exactly 64 of them, one per bit.
    random_vectors: Vec<Vec<f32>>,
}
178
179impl SimHashFunction {
180 fn new(dimensions: usize, seed: u64) -> Self {
181 let mut rng = Random::seed(seed);
182 let normal =
183 Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
184
185 let random_vectors: Vec<Vec<f32>> = (0..64)
186 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
187 .collect();
188
189 Self { random_vectors }
190 }
191}
192
193impl HashFunction for SimHashFunction {
194 fn hash(&self, vector: &[f32]) -> u64 {
195 let mut hash = 0u64;
196
197 for (i, random_vec) in self.random_vectors.iter().enumerate() {
198 let mut sum = 0.0;
200 for (j, &v) in vector.iter().enumerate() {
201 if j < random_vec.len() {
202 sum += v * random_vec[j];
203 }
204 }
205
206 if sum > 0.0 {
207 hash |= 1 << i;
208 }
209 }
210
211 hash
212 }
213}
214
/// Projection hash for the `PStable` family: each projection contributes one
/// bit derived from a shifted, width-scaled dot product.
struct PStableHash {
    /// Projection vectors, one per hash function.
    projections: Vec<Vec<f32>>,
    /// Per-projection offsets drawn from `0.0..width`.
    offsets: Vec<f32>,
    /// Bucket width the shifted projection is divided by.
    width: f32,
    /// Norm parameter given at construction (stored, but not read by the hash itself).
    p: f32,
}
222
223impl PStableHash {
224 fn new(dimensions: usize, num_projections: usize, width: f32, p: f32, seed: u64) -> Self {
225 let mut rng = Random::seed(seed);
226
227 let projections: Vec<Vec<f32>> = if (p - 1.0).abs() < 0.1 {
229 (0..num_projections)
231 .map(|_| {
232 (0..dimensions)
233 .map(|_| {
234 let u: f32 = rng
235 .gen_range(-std::f32::consts::PI / 2.0..std::f32::consts::PI / 2.0);
236 u.tan()
237 })
238 .collect()
239 })
240 .collect()
241 } else if (p - 2.0).abs() < 0.1 {
242 let normal =
244 Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
245 (0..num_projections)
246 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
247 .collect()
248 } else {
249 let normal =
251 Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
252 (0..num_projections)
253 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
254 .collect()
255 };
256
257 let offsets: Vec<f32> = (0..num_projections)
258 .map(|_| rng.gen_range(0.0..width))
259 .collect();
260
261 Self {
262 projections,
263 offsets,
264 width,
265 p,
266 }
267 }
268}
269
270impl HashFunction for PStableHash {
271 fn hash(&self, vector: &[f32]) -> u64 {
272 let mut hash = 0u64;
273
274 for (i, (projection, &offset)) in self.projections.iter().zip(&self.offsets).enumerate() {
275 use oxirs_core::simd::SimdOps;
276 let dot_product = f32::dot(vector, projection);
277 let bucket = ((dot_product + offset) / self.width).floor() as i32;
278
279 if bucket > 0 {
281 hash |= 1 << (i % 64);
282 }
283 }
284
285 hash
286 }
287}
288
/// A single LSH hash table: one hash function plus the buckets it fills.
struct LshTable {
    /// Bucket key (hash value) -> ids of the vectors hashed into that bucket.
    buckets: HashMap<u64, Vec<usize>>,
    /// Hash function used for both insertion and lookup.
    hash_function: Box<dyn HashFunction>,
}
294
295impl LshTable {
296 fn new(hash_function: Box<dyn HashFunction>) -> Self {
297 Self {
298 buckets: HashMap::new(),
299 hash_function,
300 }
301 }
302
303 fn insert(&mut self, id: usize, vector: &[f32]) {
304 let hash = self.hash_function.hash(vector);
305 self.buckets.entry(hash).or_default().push(id);
306 }
307
308 fn query(&self, vector: &[f32]) -> Vec<usize> {
309 let hash = self.hash_function.hash(vector);
310 self.buckets.get(&hash).cloned().unwrap_or_default()
311 }
312
313 fn query_multi_probe(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
314 let main_hash = self.hash_function.hash(vector);
315 let mut candidates = HashSet::new();
316
317 if let Some(ids) = self.buckets.get(&main_hash) {
319 candidates.extend(ids);
320 }
321
322 for probe in 1..=num_probes {
324 for bit in 0..64 {
325 let probed_hash = main_hash ^ (1 << bit);
326 if let Some(ids) = self.buckets.get(&probed_hash) {
327 candidates.extend(ids);
328 }
329
330 if candidates.len() >= probe * 10 {
332 break;
333 }
334 }
335 }
336
337 candidates.into_iter().collect()
338 }
339}
340
/// Approximate nearest-neighbour index based on locality-sensitive hashing.
pub struct LshIndex {
    /// Configuration the index was created with.
    config: LshConfig,
    /// Hash tables, `config.num_tables` of them.
    tables: Vec<LshTable>,
    /// Stored `(uri, vector)` pairs; a vector's position here is the id kept in the table buckets.
    vectors: Vec<(String, Vector)>,
    /// URI -> position in `vectors`.
    uri_to_id: HashMap<String, usize>,
    /// Dimensionality of the indexed vectors; `None` until the first insert.
    dimensions: Option<usize>,
}
349
350impl LshIndex {
351 pub fn new(config: LshConfig) -> Self {
353 let tables = Self::create_tables(&config, 0);
354
355 Self {
356 config,
357 tables,
358 vectors: Vec::new(),
359 uri_to_id: HashMap::new(),
360 dimensions: None,
361 }
362 }
363
364 fn create_tables(config: &LshConfig, dimensions: usize) -> Vec<LshTable> {
366 let mut tables = Vec::with_capacity(config.num_tables);
367
368 for table_idx in 0..config.num_tables {
369 let seed = config.seed.wrapping_add(table_idx as u64);
370
371 let hash_function: Box<dyn HashFunction> = match config.lsh_family {
372 LshFamily::RandomProjection => Box::new(RandomProjectionHash::new(
373 dimensions,
374 config.num_hash_functions,
375 seed,
376 )),
377 LshFamily::MinHash => {
378 Box::new(MinHashFunction::new(config.num_hash_functions, seed))
379 }
380 LshFamily::SimHash => Box::new(SimHashFunction::new(dimensions, seed)),
381 LshFamily::PStable(p) => {
382 Box::new(PStableHash::new(
383 dimensions,
384 config.num_hash_functions,
385 1.0, p,
387 seed,
388 ))
389 }
390 };
391
392 tables.push(LshTable::new(hash_function));
393 }
394
395 tables
396 }
397
398 fn rebuild_tables(&mut self) {
400 if let Some(dims) = self.dimensions {
401 self.tables = Self::create_tables(&self.config, dims);
402
403 for (id, (_, vector)) in self.vectors.iter().enumerate() {
405 let vector_f32 = vector.as_f32();
406 for table in &mut self.tables {
407 table.insert(id, &vector_f32);
408 }
409 }
410 }
411 }
412
413 fn query_candidates(&self, vector: &[f32]) -> Vec<usize> {
415 let mut candidates = HashSet::new();
416
417 if self.config.multi_probe {
418 for table in &self.tables {
420 let table_candidates = table.query_multi_probe(vector, self.config.num_probes);
421 candidates.extend(table_candidates);
422 }
423 } else {
424 for table in &self.tables {
426 let table_candidates = table.query(vector);
427 candidates.extend(table_candidates);
428 }
429 }
430
431 candidates.into_iter().collect()
432 }
433
434 pub fn stats(&self) -> LshStats {
436 let avg_bucket_size = if self.tables.is_empty() {
437 0.0
438 } else {
439 let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
440 let total_items: usize = self
441 .tables
442 .iter()
443 .map(|t| t.buckets.values().map(|v| v.len()).sum::<usize>())
444 .sum();
445
446 if total_buckets > 0 {
447 total_items as f64 / total_buckets as f64
448 } else {
449 0.0
450 }
451 };
452
453 LshStats {
454 num_vectors: self.vectors.len(),
455 num_tables: self.tables.len(),
456 avg_bucket_size,
457 memory_usage: self.estimate_memory_usage(),
458 }
459 }
460
461 fn estimate_memory_usage(&self) -> usize {
462 let vector_memory =
463 self.vectors.len() * (std::mem::size_of::<String>() + std::mem::size_of::<Vector>());
464
465 let table_memory: usize = self
466 .tables
467 .iter()
468 .map(|t| {
469 t.buckets.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>())
470 + t.buckets
471 .values()
472 .map(|v| v.len() * std::mem::size_of::<usize>())
473 .sum::<usize>()
474 })
475 .sum();
476
477 vector_memory + table_memory
478 }
479}
480
481impl VectorIndex for LshIndex {
482 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
483 if self.dimensions.is_none() {
485 self.dimensions = Some(vector.dimensions);
486 self.rebuild_tables();
487 } else if Some(vector.dimensions) != self.dimensions {
488 return Err(anyhow!(
489 "Vector dimensions ({}) don't match index dimensions ({:?})",
490 vector.dimensions,
491 self.dimensions
492 ));
493 }
494
495 let id = self.vectors.len();
496 let vector_f32 = vector.as_f32();
497
498 for table in &mut self.tables {
500 table.insert(id, &vector_f32);
501 }
502
503 self.uri_to_id.insert(uri.clone(), id);
504 self.vectors.push((uri, vector));
505
506 Ok(())
507 }
508
509 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
510 if self.vectors.is_empty() {
511 return Ok(Vec::new());
512 }
513
514 let query_f32 = query.as_f32();
515 let candidates = self.query_candidates(&query_f32);
516
517 let mut results: Vec<(usize, f32)> = candidates
519 .into_iter()
520 .filter_map(|id| {
521 self.vectors.get(id).map(|(_, vec)| {
522 let vec_f32 = vec.as_f32();
523 let distance = match self.config.lsh_family {
524 LshFamily::RandomProjection | LshFamily::SimHash => {
525 use oxirs_core::simd::SimdOps;
527 f32::cosine_distance(&query_f32, &vec_f32)
528 }
529 LshFamily::MinHash => {
530 let threshold = 0.0;
532 let set1: HashSet<usize> = query_f32
533 .iter()
534 .enumerate()
535 .filter(|&(_, &v)| v > threshold)
536 .map(|(i, _)| i)
537 .collect();
538 let set2: HashSet<usize> = vec_f32
539 .iter()
540 .enumerate()
541 .filter(|&(_, &v)| v > threshold)
542 .map(|(i, _)| i)
543 .collect();
544
545 let intersection = set1.intersection(&set2).count();
546 let union = set1.union(&set2).count();
547
548 if union > 0 {
549 1.0 - (intersection as f32 / union as f32)
550 } else {
551 1.0
552 }
553 }
554 LshFamily::PStable(p) => {
555 use oxirs_core::simd::SimdOps;
557 if (p - 1.0).abs() < 0.1 {
558 f32::manhattan_distance(&query_f32, &vec_f32)
559 } else if (p - 2.0).abs() < 0.1 {
560 f32::euclidean_distance(&query_f32, &vec_f32)
561 } else {
562 query_f32
564 .iter()
565 .zip(&vec_f32)
566 .map(|(a, b)| (a - b).abs().powf(p))
567 .sum::<f32>()
568 .powf(1.0 / p)
569 }
570 }
571 };
572 (id, distance)
573 })
574 })
575 .collect();
576
577 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
579 results.truncate(k);
580
581 Ok(results
583 .into_iter()
584 .map(|(id, dist)| (self.vectors[id].0.clone(), dist))
585 .collect())
586 }
587
588 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
589 if self.vectors.is_empty() {
590 return Ok(Vec::new());
591 }
592
593 let query_f32 = query.as_f32();
594 let candidates = self.query_candidates(&query_f32);
595
596 let mut results: Vec<(String, f32)> = candidates
598 .into_iter()
599 .filter_map(|id| {
600 self.vectors.get(id).and_then(|(uri, vec)| {
601 let vec_f32 = vec.as_f32();
602 let distance = match self.config.lsh_family {
603 LshFamily::RandomProjection | LshFamily::SimHash => {
604 use oxirs_core::simd::SimdOps;
605 f32::cosine_distance(&query_f32, &vec_f32)
606 }
607 LshFamily::MinHash => {
608 let threshold_val = 0.0;
610 let set1: HashSet<usize> = query_f32
611 .iter()
612 .enumerate()
613 .filter(|&(_, &v)| v > threshold_val)
614 .map(|(i, _)| i)
615 .collect();
616 let set2: HashSet<usize> = vec_f32
617 .iter()
618 .enumerate()
619 .filter(|&(_, &v)| v > threshold_val)
620 .map(|(i, _)| i)
621 .collect();
622
623 let intersection = set1.intersection(&set2).count();
624 let union = set1.union(&set2).count();
625
626 if union > 0 {
627 1.0 - (intersection as f32 / union as f32)
628 } else {
629 1.0
630 }
631 }
632 LshFamily::PStable(p) => {
633 use oxirs_core::simd::SimdOps;
634 if (p - 1.0).abs() < 0.1 {
635 f32::manhattan_distance(&query_f32, &vec_f32)
636 } else if (p - 2.0).abs() < 0.1 {
637 f32::euclidean_distance(&query_f32, &vec_f32)
638 } else {
639 query_f32
640 .iter()
641 .zip(&vec_f32)
642 .map(|(a, b)| (a - b).abs().powf(p))
643 .sum::<f32>()
644 .powf(1.0 / p)
645 }
646 }
647 };
648
649 if distance <= threshold {
650 Some((uri.clone(), distance))
651 } else {
652 None
653 }
654 })
655 })
656 .collect();
657
658 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
659 Ok(results)
660 }
661
662 fn get_vector(&self, uri: &str) -> Option<&Vector> {
663 self.uri_to_id
664 .get(uri)
665 .and_then(|&id| self.vectors.get(id))
666 .map(|(_, v)| v)
667 }
668}
669
/// Summary statistics reported by [`LshIndex::stats`].
#[derive(Debug, Clone)]
pub struct LshStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Stored ids per bucket, averaged over all buckets of all tables (0.0 when there are none).
    pub avg_bucket_size: f64,
    /// Rough memory estimate in bytes.
    pub memory_usage: usize,
}
678
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test configuration; every test uses the fixed seed 42.
    fn config_with(
        lsh_family: LshFamily,
        num_tables: usize,
        num_hash_functions: usize,
        multi_probe: bool,
        num_probes: usize,
    ) -> LshConfig {
        LshConfig {
            num_tables,
            num_hash_functions,
            lsh_family,
            seed: 42,
            multi_probe,
            num_probes,
        }
    }

    /// Three orthogonal unit vectors plus one close to the first: a query for
    /// the first must surface it or its near-duplicate.
    #[test]
    fn test_random_projection_lsh() {
        let mut index = LshIndex::new(config_with(
            LshFamily::RandomProjection,
            5,
            4,
            false,
            0,
        ));

        let entries: Vec<(&str, Vec<f32>)> = vec![
            ("v1", vec![1.0, 0.0, 0.0]),
            ("v2", vec![0.0, 1.0, 0.0]),
            ("v3", vec![0.0, 0.0, 1.0]),
            ("v_similar", vec![0.9, 0.1, 0.0]),
        ];
        for (uri, components) in entries {
            index
                .insert(uri.to_string(), Vector::new(components))
                .unwrap();
        }

        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert!(results.len() <= 2);
        let found_close = results
            .iter()
            .any(|(uri, _)| uri == "v1" || uri == "v_similar");
        assert!(found_close);
    }

    /// Sparse binary vectors: the query equals `v1`, shares two of three
    /// active positions with `v2` and none with `v3`.
    #[test]
    fn test_minhash_lsh() {
        let mut index = LshIndex::new(config_with(LshFamily::MinHash, 3, 64, false, 0));

        let sparse = |active: &[usize]| -> Vec<f32> {
            let mut values = vec![0.0f32; 100];
            for &position in active {
                values[position] = 1.0;
            }
            values
        };

        let v1 = sparse(&[0, 10, 20]);
        let v2 = sparse(&[0, 10, 30]);
        let v3 = sparse(&[50, 60, 70]);

        index
            .insert("v1".to_string(), Vector::new(v1.clone()))
            .unwrap();
        index.insert("v2".to_string(), Vector::new(v2)).unwrap();
        index.insert("v3".to_string(), Vector::new(v3)).unwrap();

        let results = index.search_knn(&Vector::new(v1), 2).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].0, "v1");
        if let Some(second) = results.get(1) {
            assert_eq!(second.0, "v2");
        }
    }

    /// Fifty vectors spread around the unit circle: multi-probe search must
    /// return five results in non-decreasing distance order.
    #[test]
    fn test_multi_probe_lsh() {
        let mut index = LshIndex::new(config_with(
            LshFamily::RandomProjection,
            3,
            4,
            true,
            2,
        ));

        for i in 0..50 {
            let angle = i as f32 * std::f32::consts::PI / 25.0;
            let point = Vector::new(vec![angle.cos(), angle.sin(), 0.0]);
            index.insert(format!("v{i}"), point).unwrap();
        }

        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1);
        }
    }
}