// superbit/simhash/fast_sim_hash.rs

// Faithful Rust port of Dynatrace's "10× faster SimHash" bit-hack.
// Original post:
// https://www.dynatrace.com/engineering/blog/speeding-up-simhash-by-10x-using-a-bit-hack/
// Java reference implementation:
// https://github.com/dynatrace-oss/hash4j/blob/main/src/main/java/com/dynatrace/hash4j/similarity/FastSimHashPolicy_v1.java
//
// Implementation notes:
// - the temporary counters are unsigned and packed several to a `u64`
// - a single `Vec<u64>` holds all temporary lanes
6
7use core::marker::PhantomData;
8use rand_core::{RngCore, SeedableRng};
9use rand_xoshiro::Xoroshiro128PlusPlus;
10use std::hash::Hash;
11
12use super::SimHashBits;
13use super::sim_hasher::SimHasher;
14
/// Returns how many `u64` words are needed to hold `L` packed counters when
/// each word carries `2^B` of them, i.e. `L / 2^B` rounded up.
///
/// `B` must not exceed 6; otherwise the subtraction `6 - B` underflows.
const fn tmp_len<const L: usize, const B: usize>() -> usize {
    // `63 >> (6 - B)` equals `2^B - 1`, the usual addend for rounding up.
    let round_up: usize = 63 >> (6 - B);
    (L + round_up) >> B
}
19
/// Builds a mask with only the lowest bit of every packed counter set, i.e.
/// one set bit every `2^(6 - B)` positions (`0x0101_0101_0101_0101` for `B = 3`).
///
/// Each round doubles the number of set bits by OR-ing in a shifted copy of
/// the mask; the shift distance halves every round (32, 16, 8, …).
const fn bulk_mask<const B: usize>() -> u64 {
    let mut mask: u64 = 1;
    let mut round: usize = 0;
    loop {
        if round == B {
            break mask;
        }
        mask |= mask << (1 << (5 - round));
        round += 1;
    }
}
30
31pub struct FastSimHash<H, S, const L: usize, const BULK: usize = 3>
32where
33    H: SimHasher<T = u64>,
34    S: SimHashBits,
35{
36    hasher: H,
37    _p: PhantomData<S>,
38}
39
40impl<H, S, const L: usize, const BULK: usize> FastSimHash<H, S, L, BULK>
41where
42    H: SimHasher<T = u64>,
43    S: SimHashBits,
44{
45    // constants identical to the Java reference
46    // const PER_LANE: usize = 1 << BULK;
47    const BITS_PER_COUNTER: usize = 1 << (6 - BULK); // 8, 4, 2, …
48    const BULK_MASK: u64 = bulk_mask::<BULK>(); // …001001…
49    const TMP_LIMIT: u64 = (1u64 << 1 << (Self::BITS_PER_COUNTER - 1)) - 1; // 255/15/3/1
50
51    #[inline(always)]
52    const fn tmp_len() -> usize {
53        tmp_len::<L, BULK>()
54    }
55
56    pub fn new(hasher: H) -> Self {
57        assert!(
58            L <= S::bit_length(),
59            "signature length too large for container"
60        );
61        Self {
62            hasher,
63            _p: PhantomData,
64        }
65    }
66
67    pub fn create_signature<T, U>(&self, iter: T) -> S
68    where
69        T: IntoIterator<Item = U>,
70        U: Hash,
71    {
72        let mut counts: [i32; L] = [0; L];
73        let mut tmp: Vec<u64> = vec![0; Self::tmp_len()];
74
75        let num_chunks = tmp.len() >> (6 - BULK); // full 64-bit groups
76        let num_tail_lanes = tmp.len() & (0x3f >> BULK);
77
78        let mut processed_block = 0u64;
79        let mut total_elements = 0u32;
80
81        // per-feature loop
82        for feat in iter {
83            total_elements += 1;
84
85            let seed = self.hasher.hash(&feat);
86            let mut rng = Xoroshiro128PlusPlus::seed_from_u64(seed);
87
88            // full 64-bit chunks (update 8 counters when BULK=3)
89            for h in 0..num_chunks {
90                let rnd = rng.next_u64();
91                let off = h << (6 - BULK);
92                for j in 0..Self::BITS_PER_COUNTER {
93                    tmp[off + j] += (rnd >> j) & Self::BULK_MASK; // *add only 1-bits*
94                }
95            }
96            // tail lanes
97            if num_tail_lanes != 0 {
98                let rnd = rng.next_u64();
99                let off = num_chunks << (6 - BULK);
100                for j in 0..num_tail_lanes {
101                    tmp[off + j] += (rnd >> j) & Self::BULK_MASK;
102                }
103            }
104
105            processed_block += 1;
106            if processed_block == Self::TMP_LIMIT {
107                Self::flush::<L, BULK>(&mut counts, &mut tmp);
108                processed_block = 0;
109            }
110        }
111        Self::flush::<L, BULK>(&mut counts, &mut tmp);
112
113        // threshold to signature
114        let limit = (total_elements >> 1) as i32;
115        let mut sig = S::zero();
116        for i in 0..L {
117            if counts[i] + ((i & (!total_elements as usize & 1)) as i32) > limit {
118                sig |= S::one() << i;
119            }
120        }
121        sig
122    }
123
124    // merge packed tmp counters into final signed counts
125    #[inline(always)]
126    fn flush<const LL: usize, const B: usize>(acc: &mut [i32; LL], tmp: &mut [u64]) {
127        let per: usize = 1 << B;
128        let width: usize = 1 << (6 - B);
129        let mask: u64 = (1u64 << width) - 1;
130
131        // full bundles
132        let full = LL >> B;
133        for h in 0..full {
134            let t = core::mem::take(&mut tmp[h]);
135            let off = h << B;
136            for g in 0..per {
137                acc[off + g] += ((t >> (g * width)) & mask) as i32;
138            }
139        }
140        // tail bundle
141        for h in full..tmp.len() {
142            let t = core::mem::take(&mut tmp[h]);
143            let off = h << B;
144            for g in 0..(LL - off) {
145                acc[off + g] += ((t >> (g * width)) & mask) as i32;
146            }
147        }
148    }
149
150    /// Weighted variant (simple path): uses signed float counters.
151    /// This keeps per-feature random bits from Xoroshiro but does not use the
152    /// packed-counter bit-hack. Still quite fast; O(#features * L/64).
153    pub fn create_signature_weighted<T, U, W>(&self, iter: T) -> S
154    where
155        T: IntoIterator<Item = (U, W)>,
156        U: std::hash::Hash,
157        W: Into<f32> + Copy,
158    {
159        let mut counts = [0f32; L];
160
161        for (feat, w0) in iter {
162            let w = w0.into();
163
164            // Seed RNG from the hasher just like the fast path
165            let seed = self.hasher.hash(&feat);
166            let mut rng = Xoroshiro128PlusPlus::seed_from_u64(seed);
167
168            // Consume bits in 64-bit blocks
169            let mut idx = 0usize;
170            while idx < L {
171                let rnd = rng.next_u64();
172                let take = (L - idx).min(64);
173                // For each bit in this 64-bit block
174                for b in 0..take {
175                    if ((rnd >> b) & 1) == 1 {
176                        counts[idx + b] += w; // contribute +w for bit=1
177                    } else {
178                        counts[idx + b] -= w; // contribute -w for bit=0
179                    }
180                }
181                idx += take;
182            }
183        }
184
185        // Threshold by sign
186        let mut sig = S::zero();
187        for i in 0..L {
188            if counts[i] > 0.0 {
189                sig |= S::one() << i;
190            }
191        }
192        sig
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::super::{BitArray, SimHashBits};
    use super::*;
    use crate::simhash::sim_hasher::Xxh3Hasher64;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::time::Instant;

    /// Checks that the Hamming distance between the signatures of two related
    /// binary vectors lies within two standard deviations of the value
    /// predicted from their cosine similarity.
    // cargo test --release fast_simhash_bitarray -- --nocapture
    #[test]
    fn fast_simhash_bitarray() {
        type Bits = BitArray<16>;
        const L: usize = 1024;
        const N: usize = 100_000;

        // Random 0/1 vector, plus a copy with every fourth entry flipped.
        let mut rng = StdRng::seed_from_u64(42);
        let data1: Vec<u8> = (0..N).map(|_| rng.gen_range(0..=1)).collect();
        let data2: Vec<u8> = data1
            .iter()
            .enumerate()
            .map(|(i, &bit)| if i % 4 == 0 { bit ^ 1 } else { bit })
            .collect();

        // Ground-truth cosine and angle.
        let (dot, n1, n2) = data1.iter().zip(&data2).fold(
            (0f64, 0f64, 0f64),
            |(dot, n1, n2), (&a, &b)| {
                let (x, y) = (a as f64, b as f64);
                (dot + x * y, n1 + x * x, n2 + y * y)
            },
        );
        let cosine = (dot / (n1.sqrt() * n2.sqrt())).clamp(-1.0, 1.0);
        let theta = cosine.acos();
        let p_bit = theta / std::f64::consts::PI; // Charikar: P(bit differs)

        // 2-σ acceptance band for the Hamming distance.
        let mean = p_bit * L as f64;
        let sigma = (L as f64 * p_bit * (1.0 - p_bit)).sqrt();
        let low = (mean - 2.0 * sigma).round() as usize;
        let high = (mean + 2.0 * sigma).round() as usize;

        let fsh = FastSimHash::<Xxh3Hasher64, Bits, L>::new(Xxh3Hasher64::new());
        let t1 = Instant::now();
        let h1 = fsh.create_signature(data1.iter().enumerate().map(|(i, &bit)| (i as u64, bit)));
        let h2 = fsh.create_signature(data2.iter().enumerate().map(|(i, &bit)| (i as u64, bit)));
        let hd = h1.hamming_distance(&h2);
        let dur = t1.elapsed();
        println!("fast SimHash: {:?}", dur);
        println!("HD = {hd}, expected ≈ {low}–{high}  (p_bit ≈ {:.3})", p_bit);
        println!("expected {:.3}", mean);
        assert!((low..=high).contains(&hd));
    }
}
248}