1use crate::{Vector, VectorIndex};
10use anyhow::{anyhow, Result};
11use crate::random_utils::NormalSampler as Normal;
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::{HashMap, HashSet};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct LshConfig {
19 pub num_tables: usize,
21 pub num_hash_functions: usize,
23 pub lsh_family: LshFamily,
25 pub seed: u64,
27 pub multi_probe: bool,
29 pub num_probes: usize,
31}
32
33impl Default for LshConfig {
34 fn default() -> Self {
35 Self {
36 num_tables: 10,
37 num_hash_functions: 8,
38 lsh_family: LshFamily::RandomProjection,
39 seed: 42,
40 multi_probe: true,
41 num_probes: 3,
42 }
43 }
44}
45
46#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
48pub enum LshFamily {
49 RandomProjection,
51 MinHash,
53 SimHash,
55 PStable(f32), }
58
/// A locality-sensitive hash over `f32` slices that can be shared across threads.
trait HashFunction: Send + Sync {
    /// Maps `vector` to a 64-bit bucket key.
    fn hash(&self, vector: &[f32]) -> u64;

    /// Calls [`HashFunction::hash`] on `vector` exactly `num_hashes` times and
    /// returns the results in call order.
    fn hash_multi(&self, vector: &[f32], num_hashes: usize) -> Vec<u64> {
        let mut hashes = Vec::with_capacity(num_hashes);
        for _ in 0..num_hashes {
            hashes.push(self.hash(vector));
        }
        hashes
    }
}
69
/// Hash built from the signs of dot products with random directions.
struct RandomProjectionHash {
    /// One random direction per hash function; each contributes one output bit.
    projections: Vec<Vec<f32>>,
    /// Dimensionality the directions were generated for.
    dimensions: usize,
}
75
76impl RandomProjectionHash {
77 fn new(dimensions: usize, num_projections: usize, seed: u64) -> Self {
78 let mut rng = Random::seed(seed);
79 let normal = Normal::new(0.0, 1.0).unwrap();
80
81 let mut projections = Vec::with_capacity(num_projections);
82 for _ in 0..num_projections {
83 let projection: Vec<f32> = (0..dimensions).map(|_| normal.sample(&mut rng)).collect();
84 projections.push(projection);
85 }
86
87 Self {
88 projections,
89 dimensions,
90 }
91 }
92}
93
94impl HashFunction for RandomProjectionHash {
95 fn hash(&self, vector: &[f32]) -> u64 {
96 let mut hash_value = 0u64;
97
98 for (i, projection) in self.projections.iter().enumerate() {
99 use oxirs_core::simd::SimdOps;
101 let dot_product = f32::dot(vector, projection);
102
103 if dot_product > 0.0 {
105 hash_value |= 1 << (i % 64);
106 }
107 }
108
109 hash_value
110 }
111}
112
/// Family of linear hashes `x -> (a[i] * x + b[i]) % prime` used for min-wise hashing.
struct MinHashFunction {
    /// Multipliers, one per hash in the family.
    a: Vec<u64>,
    /// Additive offsets, index-aligned with `a`.
    b: Vec<u64>,
    /// Modulus shared by every hash in the family.
    prime: u64,
}
119
120impl MinHashFunction {
121 fn new(num_hashes: usize, seed: u64) -> Self {
122 let mut rng = Random::seed(seed);
123 let prime = 4294967311u64; let a: Vec<u64> = (0..num_hashes).map(|_| rng.gen_range(1..prime)).collect();
126 let b: Vec<u64> = (0..num_hashes).map(|_| rng.gen_range(0..prime)).collect();
127
128 Self { a, b, prime }
129 }
130
131 fn minhash_signature(&self, set_elements: &[u32]) -> Vec<u64> {
132 let mut signature = vec![u64::MAX; self.a.len()];
133
134 for &element in set_elements {
135 for (i, sig_val) in signature.iter_mut().enumerate().take(self.a.len()) {
136 let hash = (self.a[i] * element as u64 + self.b[i]) % self.prime;
137 *sig_val = (*sig_val).min(hash);
138 }
139 }
140
141 signature
142 }
143}
144
145impl HashFunction for MinHashFunction {
146 fn hash(&self, vector: &[f32]) -> u64 {
147 let threshold = 0.0;
149 let set_elements: Vec<u32> = vector
150 .iter()
151 .enumerate()
152 .filter(|&(_, &v)| v > threshold)
153 .map(|(i, _)| i as u32)
154 .collect();
155
156 let signature = self.minhash_signature(&set_elements);
157
158 let mut hash = 0u64;
160 for (i, &sig) in signature.iter().enumerate() {
161 hash ^= sig.rotate_left((i * 7) as u32);
162 }
163
164 hash
165 }
166}
167
/// Hash that emits one bit per random hyperplane.
struct SimHashFunction {
    /// Hyperplane normals; entry `i` decides bit `i` of the key.
    random_vectors: Vec<Vec<f32>>,
}
172
173impl SimHashFunction {
174 fn new(dimensions: usize, seed: u64) -> Self {
175 let mut rng = Random::seed(seed);
176 let normal = Normal::new(0.0, 1.0).unwrap();
177
178 let random_vectors: Vec<Vec<f32>> = (0..64)
179 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
180 .collect();
181
182 Self { random_vectors }
183 }
184}
185
186impl HashFunction for SimHashFunction {
187 fn hash(&self, vector: &[f32]) -> u64 {
188 let mut hash = 0u64;
189
190 for (i, random_vec) in self.random_vectors.iter().enumerate() {
191 let mut sum = 0.0;
193 for (j, &v) in vector.iter().enumerate() {
194 if j < random_vec.len() {
195 sum += v * random_vec[j];
196 }
197 }
198
199 if sum > 0.0 {
200 hash |= 1 << i;
201 }
202 }
203
204 hash
205 }
206}
207
/// Hash based on p-stable random projections with randomly shifted buckets.
struct PStableHash {
    /// One random direction per hash function.
    projections: Vec<Vec<f32>>,
    /// Per-projection shift, index-aligned with `projections`.
    offsets: Vec<f32>,
    /// Bucket width the shifted projection is divided by.
    width: f32,
    /// The `p` that selected the projection distribution.
    p: f32,
}
215
216impl PStableHash {
217 fn new(dimensions: usize, num_projections: usize, width: f32, p: f32, seed: u64) -> Self {
218 let mut rng = Random::seed(seed);
219
220 let projections: Vec<Vec<f32>> = if (p - 1.0).abs() < 0.1 {
222 (0..num_projections)
224 .map(|_| {
225 (0..dimensions)
226 .map(|_| {
227 let u: f32 = rng
228 .gen_range(-std::f32::consts::PI / 2.0..std::f32::consts::PI / 2.0);
229 u.tan()
230 })
231 .collect()
232 })
233 .collect()
234 } else if (p - 2.0).abs() < 0.1 {
235 let normal = Normal::new(0.0, 1.0).unwrap();
237 (0..num_projections)
238 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
239 .collect()
240 } else {
241 let normal = Normal::new(0.0, 1.0).unwrap();
243 (0..num_projections)
244 .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
245 .collect()
246 };
247
248 let offsets: Vec<f32> = (0..num_projections)
249 .map(|_| rng.gen_range(0.0..width))
250 .collect();
251
252 Self {
253 projections,
254 offsets,
255 width,
256 p,
257 }
258 }
259}
260
261impl HashFunction for PStableHash {
262 fn hash(&self, vector: &[f32]) -> u64 {
263 let mut hash = 0u64;
264
265 for (i, (projection, &offset)) in self.projections.iter().zip(&self.offsets).enumerate() {
266 use oxirs_core::simd::SimdOps;
267 let dot_product = f32::dot(vector, projection);
268 let bucket = ((dot_product + offset) / self.width).floor() as i32;
269
270 if bucket > 0 {
272 hash |= 1 << (i % 64);
273 }
274 }
275
276 hash
277 }
278}
279
280struct LshTable {
282 buckets: HashMap<u64, Vec<usize>>,
283 hash_function: Box<dyn HashFunction>,
284}
285
286impl LshTable {
287 fn new(hash_function: Box<dyn HashFunction>) -> Self {
288 Self {
289 buckets: HashMap::new(),
290 hash_function,
291 }
292 }
293
294 fn insert(&mut self, id: usize, vector: &[f32]) {
295 let hash = self.hash_function.hash(vector);
296 self.buckets.entry(hash).or_default().push(id);
297 }
298
299 fn query(&self, vector: &[f32]) -> Vec<usize> {
300 let hash = self.hash_function.hash(vector);
301 self.buckets.get(&hash).cloned().unwrap_or_default()
302 }
303
304 fn query_multi_probe(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
305 let main_hash = self.hash_function.hash(vector);
306 let mut candidates = HashSet::new();
307
308 if let Some(ids) = self.buckets.get(&main_hash) {
310 candidates.extend(ids);
311 }
312
313 for probe in 1..=num_probes {
315 for bit in 0..64 {
316 let probed_hash = main_hash ^ (1 << bit);
317 if let Some(ids) = self.buckets.get(&probed_hash) {
318 candidates.extend(ids);
319 }
320
321 if candidates.len() >= probe * 10 {
323 break;
324 }
325 }
326 }
327
328 candidates.into_iter().collect()
329 }
330}
331
332pub struct LshIndex {
334 config: LshConfig,
335 tables: Vec<LshTable>,
336 vectors: Vec<(String, Vector)>,
337 uri_to_id: HashMap<String, usize>,
338 dimensions: Option<usize>,
339}
340
341impl LshIndex {
342 pub fn new(config: LshConfig) -> Self {
344 let tables = Self::create_tables(&config, 0);
345
346 Self {
347 config,
348 tables,
349 vectors: Vec::new(),
350 uri_to_id: HashMap::new(),
351 dimensions: None,
352 }
353 }
354
355 fn create_tables(config: &LshConfig, dimensions: usize) -> Vec<LshTable> {
357 let mut tables = Vec::with_capacity(config.num_tables);
358
359 for table_idx in 0..config.num_tables {
360 let seed = config.seed.wrapping_add(table_idx as u64);
361
362 let hash_function: Box<dyn HashFunction> = match config.lsh_family {
363 LshFamily::RandomProjection => Box::new(RandomProjectionHash::new(
364 dimensions,
365 config.num_hash_functions,
366 seed,
367 )),
368 LshFamily::MinHash => {
369 Box::new(MinHashFunction::new(config.num_hash_functions, seed))
370 }
371 LshFamily::SimHash => Box::new(SimHashFunction::new(dimensions, seed)),
372 LshFamily::PStable(p) => {
373 Box::new(PStableHash::new(
374 dimensions,
375 config.num_hash_functions,
376 1.0, p,
378 seed,
379 ))
380 }
381 };
382
383 tables.push(LshTable::new(hash_function));
384 }
385
386 tables
387 }
388
389 fn rebuild_tables(&mut self) {
391 if let Some(dims) = self.dimensions {
392 self.tables = Self::create_tables(&self.config, dims);
393
394 for (id, (_, vector)) in self.vectors.iter().enumerate() {
396 let vector_f32 = vector.as_f32();
397 for table in &mut self.tables {
398 table.insert(id, &vector_f32);
399 }
400 }
401 }
402 }
403
404 fn query_candidates(&self, vector: &[f32]) -> Vec<usize> {
406 let mut candidates = HashSet::new();
407
408 if self.config.multi_probe {
409 for table in &self.tables {
411 let table_candidates = table.query_multi_probe(vector, self.config.num_probes);
412 candidates.extend(table_candidates);
413 }
414 } else {
415 for table in &self.tables {
417 let table_candidates = table.query(vector);
418 candidates.extend(table_candidates);
419 }
420 }
421
422 candidates.into_iter().collect()
423 }
424
425 pub fn stats(&self) -> LshStats {
427 let avg_bucket_size = if self.tables.is_empty() {
428 0.0
429 } else {
430 let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
431 let total_items: usize = self
432 .tables
433 .iter()
434 .map(|t| t.buckets.values().map(|v| v.len()).sum::<usize>())
435 .sum();
436
437 if total_buckets > 0 {
438 total_items as f64 / total_buckets as f64
439 } else {
440 0.0
441 }
442 };
443
444 LshStats {
445 num_vectors: self.vectors.len(),
446 num_tables: self.tables.len(),
447 avg_bucket_size,
448 memory_usage: self.estimate_memory_usage(),
449 }
450 }
451
452 fn estimate_memory_usage(&self) -> usize {
453 let vector_memory =
454 self.vectors.len() * (std::mem::size_of::<String>() + std::mem::size_of::<Vector>());
455
456 let table_memory: usize = self
457 .tables
458 .iter()
459 .map(|t| {
460 t.buckets.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>())
461 + t.buckets
462 .values()
463 .map(|v| v.len() * std::mem::size_of::<usize>())
464 .sum::<usize>()
465 })
466 .sum();
467
468 vector_memory + table_memory
469 }
470}
471
472impl VectorIndex for LshIndex {
473 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
474 if self.dimensions.is_none() {
476 self.dimensions = Some(vector.dimensions);
477 self.rebuild_tables();
478 } else if Some(vector.dimensions) != self.dimensions {
479 return Err(anyhow!(
480 "Vector dimensions ({}) don't match index dimensions ({:?})",
481 vector.dimensions,
482 self.dimensions
483 ));
484 }
485
486 let id = self.vectors.len();
487 let vector_f32 = vector.as_f32();
488
489 for table in &mut self.tables {
491 table.insert(id, &vector_f32);
492 }
493
494 self.uri_to_id.insert(uri.clone(), id);
495 self.vectors.push((uri, vector));
496
497 Ok(())
498 }
499
500 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
501 if self.vectors.is_empty() {
502 return Ok(Vec::new());
503 }
504
505 let query_f32 = query.as_f32();
506 let candidates = self.query_candidates(&query_f32);
507
508 let mut results: Vec<(usize, f32)> = candidates
510 .into_iter()
511 .filter_map(|id| {
512 self.vectors.get(id).map(|(_, vec)| {
513 let vec_f32 = vec.as_f32();
514 let distance = match self.config.lsh_family {
515 LshFamily::RandomProjection | LshFamily::SimHash => {
516 use oxirs_core::simd::SimdOps;
518 f32::cosine_distance(&query_f32, &vec_f32)
519 }
520 LshFamily::MinHash => {
521 let threshold = 0.0;
523 let set1: HashSet<usize> = query_f32
524 .iter()
525 .enumerate()
526 .filter(|&(_, &v)| v > threshold)
527 .map(|(i, _)| i)
528 .collect();
529 let set2: HashSet<usize> = vec_f32
530 .iter()
531 .enumerate()
532 .filter(|&(_, &v)| v > threshold)
533 .map(|(i, _)| i)
534 .collect();
535
536 let intersection = set1.intersection(&set2).count();
537 let union = set1.union(&set2).count();
538
539 if union > 0 {
540 1.0 - (intersection as f32 / union as f32)
541 } else {
542 1.0
543 }
544 }
545 LshFamily::PStable(p) => {
546 use oxirs_core::simd::SimdOps;
548 if (p - 1.0).abs() < 0.1 {
549 f32::manhattan_distance(&query_f32, &vec_f32)
550 } else if (p - 2.0).abs() < 0.1 {
551 f32::euclidean_distance(&query_f32, &vec_f32)
552 } else {
553 query_f32
555 .iter()
556 .zip(&vec_f32)
557 .map(|(a, b)| (a - b).abs().powf(p))
558 .sum::<f32>()
559 .powf(1.0 / p)
560 }
561 }
562 };
563 (id, distance)
564 })
565 })
566 .collect();
567
568 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
570 results.truncate(k);
571
572 Ok(results
574 .into_iter()
575 .map(|(id, dist)| (self.vectors[id].0.clone(), dist))
576 .collect())
577 }
578
579 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
580 if self.vectors.is_empty() {
581 return Ok(Vec::new());
582 }
583
584 let query_f32 = query.as_f32();
585 let candidates = self.query_candidates(&query_f32);
586
587 let mut results: Vec<(String, f32)> = candidates
589 .into_iter()
590 .filter_map(|id| {
591 self.vectors.get(id).and_then(|(uri, vec)| {
592 let vec_f32 = vec.as_f32();
593 let distance = match self.config.lsh_family {
594 LshFamily::RandomProjection | LshFamily::SimHash => {
595 use oxirs_core::simd::SimdOps;
596 f32::cosine_distance(&query_f32, &vec_f32)
597 }
598 LshFamily::MinHash => {
599 let threshold_val = 0.0;
601 let set1: HashSet<usize> = query_f32
602 .iter()
603 .enumerate()
604 .filter(|&(_, &v)| v > threshold_val)
605 .map(|(i, _)| i)
606 .collect();
607 let set2: HashSet<usize> = vec_f32
608 .iter()
609 .enumerate()
610 .filter(|&(_, &v)| v > threshold_val)
611 .map(|(i, _)| i)
612 .collect();
613
614 let intersection = set1.intersection(&set2).count();
615 let union = set1.union(&set2).count();
616
617 if union > 0 {
618 1.0 - (intersection as f32 / union as f32)
619 } else {
620 1.0
621 }
622 }
623 LshFamily::PStable(p) => {
624 use oxirs_core::simd::SimdOps;
625 if (p - 1.0).abs() < 0.1 {
626 f32::manhattan_distance(&query_f32, &vec_f32)
627 } else if (p - 2.0).abs() < 0.1 {
628 f32::euclidean_distance(&query_f32, &vec_f32)
629 } else {
630 query_f32
631 .iter()
632 .zip(&vec_f32)
633 .map(|(a, b)| (a - b).abs().powf(p))
634 .sum::<f32>()
635 .powf(1.0 / p)
636 }
637 }
638 };
639
640 if distance <= threshold {
641 Some((uri.clone(), distance))
642 } else {
643 None
644 }
645 })
646 })
647 .collect();
648
649 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
650 Ok(results)
651 }
652
653 fn get_vector(&self, uri: &str) -> Option<&Vector> {
654 self.uri_to_id
655 .get(uri)
656 .and_then(|&id| self.vectors.get(id))
657 .map(|(_, v)| v)
658 }
659}
660
/// Summary figures reported by `LshIndex::stats`.
#[derive(Debug, Clone)]
pub struct LshStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Mean number of ids per bucket over all tables; `0.0` without buckets.
    pub avg_bucket_size: f64,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with seed 42 and the given knobs.
    fn make_config(
        lsh_family: LshFamily,
        num_tables: usize,
        num_hash_functions: usize,
        multi_probe: bool,
        num_probes: usize,
    ) -> LshConfig {
        LshConfig {
            num_tables,
            num_hash_functions,
            lsh_family,
            seed: 42,
            multi_probe,
            num_probes,
        }
    }

    /// Returns a 100-dimensional vector holding `1.0` at every index in
    /// `hot` and `0.0` elsewhere.
    fn indicator(hot: &[usize]) -> Vec<f32> {
        let mut values = vec![0.0; 100];
        for &idx in hot {
            values[idx] = 1.0;
        }
        values
    }

    #[test]
    fn test_random_projection_lsh() {
        let mut index = LshIndex::new(make_config(
            LshFamily::RandomProjection,
            5,
            4,
            false,
            0,
        ));

        let v1 = Vector::new(vec![1.0, 0.0, 0.0]);
        let entries = vec![
            ("v1", v1.clone()),
            ("v2", Vector::new(vec![0.0, 1.0, 0.0])),
            ("v3", Vector::new(vec![0.0, 0.0, 1.0])),
            // Nearly parallel to v1.
            ("v_similar", Vector::new(vec![0.9, 0.1, 0.0])),
        ];
        for (uri, vector) in entries {
            index.insert(uri.to_string(), vector).unwrap();
        }

        let results = index.search_knn(&v1, 2).unwrap();

        assert!(results.len() <= 2);
        // The query itself or its near neighbour has to be among the hits.
        assert!(results
            .iter()
            .any(|(uri, _)| uri == "v1" || uri == "v_similar"));
    }

    #[test]
    fn test_minhash_lsh() {
        let mut index = LshIndex::new(make_config(LshFamily::MinHash, 3, 64, false, 0));

        let v1 = indicator(&[0, 10, 20]);
        // Shares two of three elements with v1.
        let v2 = indicator(&[0, 10, 30]);
        // Disjoint from v1.
        let v3 = indicator(&[50, 60, 70]);

        index
            .insert("v1".to_string(), Vector::new(v1.clone()))
            .unwrap();
        index.insert("v2".to_string(), Vector::new(v2)).unwrap();
        index.insert("v3".to_string(), Vector::new(v3)).unwrap();

        let results = index.search_knn(&Vector::new(v1), 2).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].0, "v1");
        if let Some((second_uri, _)) = results.get(1) {
            assert_eq!(second_uri, "v2");
        }
    }

    #[test]
    fn test_multi_probe_lsh() {
        let mut index = LshIndex::new(make_config(
            LshFamily::RandomProjection,
            3,
            4,
            true,
            2,
        ));

        // 50 unit vectors evenly spread around the circle in the xy-plane.
        for i in 0..50 {
            let angle = i as f32 * std::f32::consts::PI / 25.0;
            let vec = Vector::new(vec![angle.cos(), angle.sin(), 0.0]);
            index.insert(format!("v{i}"), vec).unwrap();
        }

        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        // Distances must come back in non-decreasing order.
        assert!(results.windows(2).all(|pair| pair[0].1 <= pair[1].1));
    }
}