Skip to main content

superbit/
hash.rs

1use ndarray::Array2;
2use rand::Rng;
3use rand_distr::StandardNormal;
4
/// A random-projection hash family for one hash table.
///
/// Uses sign-of-random-projection (SimHash / hyperplane LSH) to map vectors
/// to bit signatures. Each bit corresponds to the sign of the dot product
/// with a random Gaussian vector.
///
/// Projections are stored as a (num_hashes x dim) matrix so that all dot
/// products can be computed in a single matrix-vector multiply.
///
/// Signatures are packed into a `u64` (bit `i` is set when the `i`-th dot
/// product is non-negative), so at most 64 projections can be represented.
///
/// With the `persistence` feature enabled this type derives serde
/// `Serialize`/`Deserialize`. The field names (and, for positional formats,
/// their order) are therefore part of the persisted representation; changing
/// them breaks compatibility with previously saved hashers.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RandomProjectionHasher {
    /// (num_hashes x dim) matrix -- each row is one projection vector.
    projection_matrix: Array2<f32>,
    /// Number of projection vectors, i.e. bits in each signature.
    /// Expected to equal `projection_matrix.nrows()`.
    /// NOTE(review): nothing re-checks this invariant for a deserialized value.
    num_hashes: usize,
}
23
24impl RandomProjectionHasher {
25    /// Create a new hasher with `num_hashes` random projection vectors of dimension `dim`.
26    pub fn new(dim: usize, num_hashes: usize, rng: &mut impl Rng) -> Self {
27        let data: Vec<f32> = (0..num_hashes * dim)
28            .map(|_| rng.sample(StandardNormal))
29            .collect();
30        let projection_matrix =
31            Array2::from_shape_vec((num_hashes, dim), data).expect("shape mismatch");
32        Self {
33            projection_matrix,
34            num_hashes,
35        }
36    }
37
38    /// Compute the hash key for a vector, along with margin information for multi-probe.
39    ///
40    /// Returns `(hash_key, margins)` where margins is a vec of `(bit_index, |dot_product|)`
41    /// sorted by ascending margin (most uncertain bits first).
42    pub fn hash_vector(&self, vector: &ndarray::ArrayView1<f32>) -> (u64, Vec<(usize, f32)>) {
43        // Single matrix-vector multiply: dots = projection_matrix * vector
44        let dots = self.projection_matrix.dot(vector);
45
46        let mut hash: u64 = 0;
47        let mut margins: Vec<(usize, f32)> = Vec::with_capacity(self.num_hashes);
48
49        for (i, &dot) in dots.iter().enumerate() {
50            if dot >= 0.0 {
51                hash |= 1u64 << i;
52            }
53            margins.push((i, dot.abs()));
54        }
55
56        margins.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
57        (hash, margins)
58    }
59
60    /// Compute just the hash key (fast path, no margin data).
61    pub fn hash_vector_fast(&self, vector: &ndarray::ArrayView1<f32>) -> u64 {
62        let dots = self.projection_matrix.dot(vector);
63        let mut hash: u64 = 0;
64        for (i, &dot) in dots.iter().enumerate() {
65            if dot >= 0.0 {
66                hash |= 1u64 << i;
67            }
68        }
69        hash
70    }
71
72    /// Number of hash functions (bits in the signature).
73    pub fn num_hashes(&self) -> usize {
74        self.num_hashes
75    }
76}
77
/// Generate multi-probe hash keys by flipping the most uncertain bits.
///
/// Given `base_hash` and margin info as `(bit_index, margin)` pairs sorted
/// ascending by margin (most uncertain bit first), returns the base key
/// followed by one key per probed bit. Each probe key differs from
/// `base_hash` in exactly that bit.
///
/// At most `num_probes` bits are probed; fewer if `margins` is shorter. Large
/// values such as `usize::MAX` therefore mean "probe every available bit"
/// and do not over-allocate.
///
/// # Panics
///
/// Panics if a probed `bit_index` is 64 or larger, since it cannot address a
/// bit of the `u64` key.
pub fn multi_probe_keys(
    base_hash: u64,
    margins: &[(usize, f32)],
    num_probes: usize,
) -> Vec<u64> {
    // Cap before allocating: `1 + num_probes` could overflow, or reserve a
    // huge buffer, when `num_probes` far exceeds the number of bits.
    let probed = &margins[..num_probes.min(margins.len())];

    let mut keys = Vec::with_capacity(1 + probed.len());
    keys.push(base_hash);
    keys.extend(probed.iter().map(|&(bit_idx, _)| {
        assert!(
            bit_idx < u64::BITS as usize,
            "bit index {} out of range for a 64-bit hash key",
            bit_idx
        );
        base_hash ^ (1u64 << bit_idx)
    }));

    keys
}
96
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    #[test]
    fn test_deterministic_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 8, &mut rng);
        let v = array![1.0, 2.0, 3.0, 4.0];
        let h1 = hasher.hash_vector_fast(&v.view());
        let h2 = hasher.hash_vector_fast(&v.view());
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_similar_vectors_likely_same_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 4, &mut rng);
        let v1 = array![1.0, 2.0, 3.0, 4.0];
        let v2 = array![1.01, 2.01, 3.01, 4.01];
        let h1 = hasher.hash_vector_fast(&v1.view());
        let h2 = hasher.hash_vector_fast(&v2.view());
        // Nearby vectors collide with high (not certain) probability in general.
        // This assertion holds for the fixed seed above; if the RNG stream ever
        // changes (e.g. a `rand` upgrade), it may need a different seed.
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_positive_scaling_preserves_hash() {
        // Sign-of-projection hashing depends only on direction: scaling by a
        // power of two scales every dot product exactly, so no sign can change.
        let mut rng = StdRng::seed_from_u64(7);
        let hasher = RandomProjectionHasher::new(4, 16, &mut rng);
        let v = array![1.0f32, -2.0, 3.0, 0.5];
        let scaled = v.mapv(|x| x * 2.0);
        assert_eq!(
            hasher.hash_vector_fast(&v.view()),
            hasher.hash_vector_fast(&scaled.view())
        );
    }

    #[test]
    fn test_hash_vector_matches_fast_path_and_sorts_margins() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 8, &mut rng);
        let v = array![0.3f32, -1.0, 2.5, 0.0];

        let (hash, margins) = hasher.hash_vector(&v.view());
        assert_eq!(hash, hasher.hash_vector_fast(&v.view()));

        // One non-negative margin per bit, in ascending order.
        assert_eq!(margins.len(), hasher.num_hashes());
        assert!(margins.iter().all(|&(_, m)| m >= 0.0));
        assert!(margins.windows(2).all(|w| w[0].1 <= w[1].1));

        // Every bit index appears exactly once.
        let mut bits: Vec<usize> = margins.iter().map(|&(i, _)| i).collect();
        bits.sort_unstable();
        assert_eq!(bits, (0..hasher.num_hashes()).collect::<Vec<_>>());
    }

    #[test]
    fn test_multi_probe_keys() {
        let base = 0b1010u64;
        let margins = vec![(0, 0.1), (2, 0.5), (1, 0.8), (3, 1.2)];
        let keys = multi_probe_keys(base, &margins, 2);
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0], 0b1010); // base
        assert_eq!(keys[1], 0b1011); // flip bit 0
        assert_eq!(keys[2], 0b1110); // flip bit 2
    }

    #[test]
    fn test_multi_probe_keys_caps_at_available_margins() {
        let margins: Vec<(usize, f32)> = vec![(1, 0.2), (0, 0.4)];
        let keys = multi_probe_keys(0b00, &margins, 10);
        assert_eq!(keys, vec![0b00, 0b10, 0b01]);
    }
}