1use ndarray::Array2;
2use rand::Rng;
3use rand_distr::StandardNormal;
4
/// Hasher that owns a random projection matrix and the number of hash bits it was
/// built for.
///
/// `Debug` and `Clone` are always derived; serde `Serialize`/`Deserialize` are derived
/// only when the `persistence` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RandomProjectionHasher {
    /// Projection matrix; `new` builds it with shape `(num_hashes, dim)`.
    projection_matrix: Array2<f32>,
    /// Number of projections requested at construction; returned by `num_hashes()`.
    num_hashes: usize,
}
23
24impl RandomProjectionHasher {
25 pub fn new(dim: usize, num_hashes: usize, rng: &mut impl Rng) -> Self {
27 let data: Vec<f32> = (0..num_hashes * dim)
28 .map(|_| rng.sample(StandardNormal))
29 .collect();
30 let projection_matrix =
31 Array2::from_shape_vec((num_hashes, dim), data).expect("shape mismatch");
32 Self {
33 projection_matrix,
34 num_hashes,
35 }
36 }
37
38 pub fn hash_vector(&self, vector: &ndarray::ArrayView1<f32>) -> (u64, Vec<(usize, f32)>) {
43 let dots = self.projection_matrix.dot(vector);
45
46 let mut hash: u64 = 0;
47 let mut margins: Vec<(usize, f32)> = Vec::with_capacity(self.num_hashes);
48
49 for (i, &dot) in dots.iter().enumerate() {
50 if dot >= 0.0 {
51 hash |= 1u64 << i;
52 }
53 margins.push((i, dot.abs()));
54 }
55
56 margins.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
57 (hash, margins)
58 }
59
60 pub fn hash_vector_fast(&self, vector: &ndarray::ArrayView1<f32>) -> u64 {
62 let dots = self.projection_matrix.dot(vector);
63 let mut hash: u64 = 0;
64 for (i, &dot) in dots.iter().enumerate() {
65 if dot >= 0.0 {
66 hash |= 1u64 << i;
67 }
68 }
69 hash
70 }
71
72 pub fn num_hashes(&self) -> usize {
74 self.num_hashes
75 }
76}
77
/// Builds the list of bucket keys to probe for a query hash.
///
/// The first key is always `base_hash`. It is followed by one key per entry among
/// the first `num_probes` elements of `margins`, each equal to `base_hash` with the
/// bit at that entry's index flipped. Entries are consumed in the order given, so
/// pass `margins` sorted by ascending margin (as
/// `RandomProjectionHasher::hash_vector` returns them) to flip the least certain
/// bits first.
///
/// `num_probes` may exceed `margins.len()`; the result then contains
/// `1 + margins.len()` keys. The `f32` margin values themselves are not inspected.
///
/// Every bit index in `margins` must be below 64, the width of the key.
pub fn multi_probe_keys(
    base_hash: u64,
    margins: &[(usize, f32)],
    num_probes: usize,
) -> Vec<u64> {
    // Only as many probes as `margins` can supply are ever produced, so size the
    // buffer from that. This also keeps `num_probes == usize::MAX` ("probe
    // everything") from overflowing the capacity computation.
    let probe_count = num_probes.min(margins.len());

    let mut keys = Vec::with_capacity(probe_count + 1);
    keys.push(base_hash);
    keys.extend(
        margins
            .iter()
            .take(probe_count)
            .map(|&(bit_idx, _)| base_hash ^ (1u64 << bit_idx)),
    );
    keys
}
96
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Seed shared by all hasher tests so the projection matrix is reproducible.
    const SEED: u64 = 42;

    /// Builds a hasher from a freshly seeded `StdRng`.
    fn seeded_hasher(dim: usize, num_hashes: usize) -> RandomProjectionHasher {
        let mut rng = StdRng::seed_from_u64(SEED);
        RandomProjectionHasher::new(dim, num_hashes, &mut rng)
    }

    /// Hashing the same vector twice with one hasher gives the same key.
    #[test]
    fn test_deterministic_hash() {
        let hasher = seeded_hasher(4, 8);
        let input = array![1.0, 2.0, 3.0, 4.0];

        let first = hasher.hash_vector_fast(&input.view());
        let second = hasher.hash_vector_fast(&input.view());

        assert_eq!(first, second);
    }

    /// Two vectors that differ by 0.01 per component land in the same bucket for
    /// this seed.
    #[test]
    fn test_similar_vectors_likely_same_hash() {
        let hasher = seeded_hasher(4, 4);
        let original = array![1.0, 2.0, 3.0, 4.0];
        let nudged = array![1.01, 2.01, 3.01, 4.01];

        assert_eq!(
            hasher.hash_vector_fast(&original.view()),
            hasher.hash_vector_fast(&nudged.view())
        );
    }

    /// The base key comes first, followed by one bit-flipped key per requested probe.
    #[test]
    fn test_multi_probe_keys() {
        let margins: Vec<(usize, f32)> = vec![(0, 0.1), (2, 0.5), (1, 0.8), (3, 1.2)];

        let keys = multi_probe_keys(0b1010u64, &margins, 2);

        assert_eq!(keys.len(), 3);
        let expected: [u64; 3] = [0b1010, 0b1011, 0b1110];
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(keys[position], *want);
        }
    }
}