simhash/superbit.rs

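// Super-Bit SimHash: weighted random-projection signatures whose L bits are split into
// blocks of r bits, each block's projections mixed by a seeded random orthonormal r×r matrix.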
2use super::SimHashBits;
3use super::sim_hasher::SimHasher;
4use core::marker::PhantomData;
5use std::hash::Hash;
6use std::hash::Hasher;
7use xxhash_rust::xxh3::Xxh3;
8
9pub struct SuperBitSimHash<H, S, const L: usize>
10where
11    H: SimHasher<T = u64>,
12    S: SimHashBits,
13{
14    hasher: H,
15    r: usize,                // superbit depth (block size)
16    m: usize,                // number of blocks = L / r
17    q_blocks: Vec<Vec<f32>>, // each is r×r row-major orthonormal
18    seed: u64,
19    _phantom: PhantomData<S>, // keep S "used" at the type level
20}
21
22impl<H, S, const L: usize> SuperBitSimHash<H, S, L>
23where
24    H: SimHasher<T = u64>,
25    S: SimHashBits,
26{
27    pub fn new(hasher: H, r: usize, seed: u64) -> Self {
28        assert!(r > 0 && L % r == 0, "r must divide L");
29        let m = L / r;
30        let mut q_blocks = Vec::with_capacity(m);
31        for b in 0..m {
32            q_blocks.push(Self::make_orthonormal_block(
33                r,
34                seed ^ ((b as u64) * 0x9E37_79B9),
35            ));
36        }
37        Self {
38            hasher,
39            r,
40            m,
41            q_blocks,
42            seed,
43            _phantom: PhantomData,
44        }
45    }
46
47    // Classical (stable enough for small r) Gram–Schmidt on a random r×r
48    fn make_orthonormal_block(r: usize, seed: u64) -> Vec<f32> {
49        let mut mat = vec![0f32; r * r];
50        // fill with pseudo-random N(0,1)-ish via hashed Box–Muller
51        let mut k = 0u64;
52        for i in 0..r {
53            for j in 0..r {
54                let u1 = Self::u01(seed, k ^ ((i as u64) << 16) ^ (j as u64));
55                let u2 = Self::u01(seed, k.wrapping_mul(0x9E37_79B97F4A7C15) ^ 0xBF58_476D);
56                k = k.wrapping_add(1);
57                let r2 = (-2.0f32 * u1.ln()).sqrt();
58                let th = 2.0f32 * std::f32::consts::PI * u2;
59                mat[i * r + j] = r2 * th.cos(); // one normal; good enough for r<=32
60            }
61        }
62        // Gram–Schmidt orthogonal
63        for j in 0..r {
64            // subtract projections
65            for p in 0..j {
66                let mut dot = 0f32;
67                for i in 0..r {
68                    dot += mat[i * r + j] * mat[i * r + p];
69                }
70                for i in 0..r {
71                    mat[i * r + j] -= dot * mat[i * r + p];
72                }
73            }
74            // normalize
75            let mut n = 0f32;
76            for i in 0..r {
77                n += mat[i * r + j] * mat[i * r + j];
78            }
79            let n = n.sqrt().max(1e-12);
80            for i in 0..r {
81                mat[i * r + j] /= n;
82            }
83        }
84        mat
85    }
86
87    #[inline]
88    fn u01(seed: u64, x: u64) -> f32 {
89        // xxhash-rust API: use `with_seed`
90        let mut h = Xxh3::with_seed(seed);
91        h.update(&x.to_le_bytes());
92        let v = h.finish();
93        // map 53-bit mantissa to (0,1); avoid exact 0/1
94        ((v >> 11) as f32 + 0.5) * (1.0 / ((1u64 << 53) as f32))
95    }
96
97    /// Unweighted items: treat each item as weight 1.0 (like classic SimHash).
98    pub fn create_signature<U>(&self, iter: impl Iterator<Item = U>) -> S
99    where
100        U: Hash,
101    {
102        self.create_signature_weighted(iter.map(|u| (u, 1.0f32)))
103    }
104
105    /// Weighted variant (useful for TF/IDF or MS intensities after L2 norm).
106    /// Weighted variant (useful for TF/IDF or MS intensities after L2 norm).
107    pub fn create_signature_weighted<U>(&self, iter: impl Iterator<Item = (U, f32)>) -> S
108    where
109        U: Hash,
110    {
111        let mut counts = vec![0f32; L];
112        // Reuse a single buffer for g to avoid per-block allocations.
113        let mut g = vec![0f32; self.r];
114
115        for (item, w) in iter {
116            if w == 0.0 {
117                continue;
118            }
119
120            // Stable per-item 64-bit base using the generic hasher H
121            let base: u64 = self.hasher.hash(&item);
122
123            // for each block, build g in {+1,-1}^r and accumulate Q_b * g
124            for b in 0..self.m {
125                let qb = &self.q_blocks[b]; // row-major r×r
126
127                // Seed a tiny PRNG (SplitMix64) once per (item, block)
128                let mut s = self.seed ^ base ^ ((b as u64) << 32) ^ 0x9E37_79B9_7F4A_7C15;
129
130                // Fill g[j] ∈ {+1,-1} using SplitMix64
131                for j in 0..self.r {
132                    s = Self::splitmix64(s);
133                    g[j] = if (s >> 63) == 0 { 1.0 } else { -1.0 };
134                }
135
136                // u = Q_b * g (r dot-products, row-major is cache-friendly here)
137                let off = b * self.r;
138                for row in 0..self.r {
139                    let mut acc = 0f32;
140                    let row_off = row * self.r;
141                    // dot(row, g)
142                    for col in 0..self.r {
143                        acc += qb[row_off + col] * g[col];
144                    }
145                    counts[off + row] += w * acc;
146                }
147            }
148        }
149
150        // threshold -> bits
151        let mut out = S::zero();
152        for i in 0..L {
153            if counts[i] > 0.0 {
154                out |= S::one() << i;
155            }
156        }
157        out
158    }
159
160    fn splitmix64(mut x: u64) -> u64 {
161        x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
162        let mut z = x;
163        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
164        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
165        z ^ (z >> 31)
166    }
167}
168
169#[cfg(test)]
170mod tests {
    // Statistical checks that SuperBit signatures track the angle between the
    // underlying vectors. All RNGs and hasher seeds below are fixed.
171    use super::super::{BitArray, SimHashBits};
172    use super::*;
173    use crate::simhash::sim_hasher::Xxh3Hasher64;
174    use rand::Rng;
175    use rand::SeedableRng;
176    use rand::rngs::StdRng;
177
178    // cargo test --release superbit_simhash_bitarray -- --nocapture
179    #[test]
180    fn superbit_simhash_bitarray() {
        // Builds a random 0/1 vector of length N and a copy with every 4th entry
        // flipped, computes their exact angle θ, and checks that the signature
        // Hamming distance lies within mean ± 3σ of Binomial(L, θ/π).
        // Entries equal to 0 become zero-weight items, which
        // create_signature_weighted skips.
        // NOTE(review): the binomial band assumes independent bits; SuperBit bits
        // may be correlated, so treat the band as a heuristic.
181        type Bits = BitArray<16>; // 16×64 = 1024 bits
182        const L: usize = 1024;
183        const R: usize = 32; // block size; L % R == 0
184        const N: usize = 100_000;
185
186        let mut rng = StdRng::seed_from_u64(12345);
187        let data1: Vec<u8> = (0..N).map(|_| rng.gen_range(0..=1)).collect();
188        let mut data2 = data1.clone();
189        for i in (0..N).step_by(4) {
190            data2[i] ^= 1;
191        }
192
193        // Ground-truth cosine/angle from the two bit-vectors treated as {0,1}^N
194        let (mut dot, mut n1, mut n2) = (0f64, 0f64, 0f64);
195        for i in 0..N {
196            let x = data1[i] as f64;
197            let y = data2[i] as f64;
198            dot += x * y;
199            n1 += x * x;
200            n2 += y * y;
201        }
202        let cosine = (dot / (n1.sqrt() * n2.sqrt())).clamp(-1.0, 1.0);
203        let theta = cosine.acos();
204        let p_bit = theta / std::f64::consts::PI; // P(bit differs) for SRP
205
206        let mean = p_bit * L as f64;
207        let sigma = (L as f64 * p_bit * (1.0 - p_bit)).sqrt();
208        let low = (mean - 3.0 * sigma).floor().max(0.0) as usize; // 3σ band for robustness
209        let high = (mean + 3.0 * sigma).ceil().min(L as f64) as usize;
210
211        let sb = SuperBitSimHash::<Xxh3Hasher64, Bits, L>::new(Xxh3Hasher64::new(), R, 0xDEAD_BEEF);
212        let h1 = sb.create_signature_weighted((0..N).map(|i| (i as u64, data1[i] as f32)));
213        let h2 = sb.create_signature_weighted((0..N).map(|i| (i as u64, data2[i] as f32)));
214        let hd = h1.hamming_distance(&h2);
215
216        eprintln!(
217            "SuperBit SimHash (L={L}, R={R}, N={N}): HD = {}, expected ≈ {}–{} (p_bit ≈ {:.3}, mean {:.1}, σ {:.1})",
218            hd, low, high, p_bit, mean, sigma
219        );
220        assert!((low..=high).contains(&hd));
221    }
222    // cargo test --release superbit_vs_classic_weighted_accuracy -- --nocapture
223    #[test]
224    fn superbit_vs_classic_weighted_accuracy() {
        // Compares the mean absolute error of the angle estimate θ̂ = π·HD/L for
        // SuperBit vs classic weighted SimHash over REPS random vector pairs.
225        use crate::simhash::sim_hash::SimHash;
226        use crate::simhash::sim_hasher::{Xxh3Hasher64, Xxh3Hasher128};
227        use rand::rngs::StdRng;
228        use rand::{Rng, SeedableRng};
229
230        // Signature + block sizes
        // NOTE(review): L = 256 but both signatures below are `u128`, which holds
        // only 128 bits. Setting bit i >= 128 via `S::one() << i` overflows if
        // `SimHashBits` for u128 uses the standard `Shl`: it panics with overflow
        // checks on (debug builds) and masks the shift amount otherwise, folding
        // bits 128..255 onto 0..127. Consider L = 128 or a 256-bit type such as
        // `BitArray<4>`, then re-check the 0.9 threshold below.
231        const L: usize = 256; // signature length
232        const R: usize = 16; // SuperBit block size, must divide L
233        const D: usize = 4096; // dimensionality (many bins/features)
234        const REPS: usize = 5; // average across a few draws for stability
235
236        let mut rng = StdRng::seed_from_u64(2025_11_12);
237
238        // SuperBit (uses 64-bit item hasher) over u128 signature
239        let sb = SuperBitSimHash::<Xxh3Hasher64, u128, L>::new(Xxh3Hasher64::new(), R, 0xD00D_F00D);
240        // Classic SimHash over a u128 signature with a 128-bit item hasher
241        let sh = SimHash::<Xxh3Hasher128, u128, L>::new(Xxh3Hasher128::new());
242
243        let mut mae_sb = 0.0f64;
244        let mut mae_sh = 0.0f64;
245
246        for _ in 0..REPS {
247            // Construct two *positive-weighted* dense vectors (MS-like):
248            // a ~ Exp(1), b = a with ~15% multiplicative jitter + sparse spikes
249            let mut a = vec![0f32; D];
250            let mut b = vec![0f32; D];
251            for i in 0..D {
252                let u1: f32 = rng.r#gen::<f32>().clamp(1e-7, 1.0 - 1e-7);
253                let u2: f32 = rng.r#gen::<f32>().clamp(1e-7, 1.0 - 1e-7);
254                let exp_a = -u1.ln(); // exponential(1)
255                let jitter = 1.0 + 0.15 * (rng.r#gen::<f32>() - 0.5);
256                let spike = if rng.r#gen::<f32>() < 0.01 {
257                    0.5 * (-u2.ln())
258                } else {
259                    0.0
260                };
261                a[i] = exp_a;
262                b[i] = (exp_a * jitter + spike).max(0.0);
263            }
264
265            // Ground-truth cosine / angle
266            let (mut dot, mut na, mut nb) = (0f64, 0f64, 0f64);
267            for i in 0..D {
268                let (ai, bi) = (a[i] as f64, b[i] as f64);
269                dot += ai * bi;
270                na += ai * ai;
271                nb += bi * bi;
272            }
273            let cosine = (dot / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0);
274            let theta = cosine.acos();
275
276            // SuperBit (weighted)
277            let sig_sb_a = sb.create_signature_weighted((0..D).map(|i| (i as u64, a[i])));
278            let sig_sb_b = sb.create_signature_weighted((0..D).map(|i| (i as u64, b[i])));
279            let hd_sb = sig_sb_a.hamming_distance(&sig_sb_b);
280            let th_sb = std::f64::consts::PI * (hd_sb as f64) / (L as f64);
281            mae_sb += (th_sb - theta).abs();
282
283            // Classic SimHash (weighted) via SimHash::create_signature_weighted
284            let sig_sh_a = sh.create_signature_weighted((0..D).map(|i| (i as u64, a[i])));
285            let sig_sh_b = sh.create_signature_weighted((0..D).map(|i| (i as u64, b[i])));
286            let hd_sh = sig_sh_a.hamming_distance(&sig_sh_b);
287            let th_sh = std::f64::consts::PI * (hd_sh as f64) / (L as f64);
288            mae_sh += (th_sh - theta).abs();
289        }
290
291        mae_sb /= REPS as f64;
292        mae_sh /= REPS as f64;
293
294        eprintln!(
295            "Weighted accuracy (angle MAE): SuperBit={:.6}, Classic={:.6}",
296            mae_sb, mae_sh
297        );
298
299        // Requires SuperBit's MAE to be at least 10% below classic SimHash's;
300        // the 1e-6 term only absorbs floating-point noise.
        // NOTE(review): with fixed seeds and only REPS = 5 draws this compares two
        // noisy estimates; a failure after changing seeds, L or R may be
        // statistical rather than a regression.
301        assert!(
302            mae_sb <= mae_sh * 0.9 + 1e-6,
303            "SuperBit did not improve enough: SB={:.6}, SH={:.6}",
304            mae_sb,
305            mae_sh
306        );
307    }
308}