superbit/simhash/
superbit.rs

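// Super-Bit SimHash: an `L`-bit signature assembled from `L / r` blocks, each mixed through its own `r × r` orthonormal matrix.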
2use super::SimHashBits;
3use super::sim_hasher::SimHasher;
4use core::marker::PhantomData;
5use std::hash::Hash;
6use std::hash::Hasher;
7use xxhash_rust::xxh3::Xxh3;
8
9pub struct SuperBitSimHash<H, S, const L: usize>
10where
11    H: SimHasher<T = u64>,
12    S: SimHashBits,
13{
14    hasher: H,
15    r: usize,                // superbit depth (block size)
16    m: usize,                // number of blocks = L / r
17    q_blocks: Vec<Vec<f32>>, // each is r×r row-major orthonormal
18    seed: u64,
19    _phantom: PhantomData<S>, // keep S "used" at the type level
20}
21
22impl<H, S, const L: usize> SuperBitSimHash<H, S, L>
23where
24    H: SimHasher<T = u64>,
25    S: SimHashBits,
26{
27    pub fn new(hasher: H, r: usize, seed: u64) -> Self {
28        assert!(r > 0 && L % r == 0, "r must divide L");
29        let m = L / r;
30        let mut q_blocks = Vec::with_capacity(m);
31        for b in 0..m {
32            q_blocks.push(Self::make_orthonormal_block(
33                r,
34                seed ^ ((b as u64) * 0x9E37_79B9),
35            ));
36        }
37        Self {
38            hasher,
39            r,
40            m,
41            q_blocks,
42            seed,
43            _phantom: PhantomData,
44        }
45    }
46
47    // Classical (stable enough for small r) Gram–Schmidt on a random r×r
48    fn make_orthonormal_block(r: usize, seed: u64) -> Vec<f32> {
49        let mut mat = vec![0f32; r * r];
50        // fill with pseudo-random N(0,1)-ish via hashed Box–Muller
51        let mut k = 0u64;
52        for i in 0..r {
53            for j in 0..r {
54                let u1 = Self::u01(seed, k ^ ((i as u64) << 16) ^ (j as u64));
55                let u2 = Self::u01(seed, k.wrapping_mul(0x9E37_79B97F4A7C15) ^ 0xBF58_476D);
56                k = k.wrapping_add(1);
57                let r2 = (-2.0f32 * u1.ln()).sqrt();
58                let th = 2.0f32 * std::f32::consts::PI * u2;
59                mat[i * r + j] = r2 * th.cos(); // one normal; good enough for r<=32
60            }
61        }
62        // Gram–Schmidt orthogonal
63        for j in 0..r {
64            // subtract projections
65            for p in 0..j {
66                let mut dot = 0f32;
67                for i in 0..r {
68                    dot += mat[i * r + j] * mat[i * r + p];
69                }
70                for i in 0..r {
71                    mat[i * r + j] -= dot * mat[i * r + p];
72                }
73            }
74            // normalize
75            let mut n = 0f32;
76            for i in 0..r {
77                n += mat[i * r + j] * mat[i * r + j];
78            }
79            let n = n.sqrt().max(1e-12);
80            for i in 0..r {
81                mat[i * r + j] /= n;
82            }
83        }
84        mat
85    }
86
87    fn u01(seed: u64, x: u64) -> f32 {
88        // xxhash-rust API: use `with_seed`
89        let mut h = Xxh3::with_seed(seed);
90        h.update(&x.to_le_bytes());
91        let v = h.finish();
92        // map 53-bit mantissa to (0,1); avoid exact 0/1
93        ((v >> 11) as f32 + 0.5) * (1.0 / ((1u64 << 53) as f32))
94    }
95
96    /// Unweighted items: treat each item as weight 1.0 (see classic SimHash).
97    pub fn create_signature<U>(&self, iter: impl Iterator<Item = U>) -> S
98    where
99        U: Hash,
100    {
101        self.create_signature_weighted(iter.map(|u| (u, 1.0f32)))
102    }
103
104    /// Weighted variant (useful for TF/IDF or MS intensities after L2 norm).
105    /// Weighted variant (useful for TF/IDF or MS intensities after L2 norm).
106    pub fn create_signature_weighted<U>(&self, iter: impl Iterator<Item = (U, f32)>) -> S
107    where
108        U: Hash,
109    {
110        let mut counts = vec![0f32; L];
111        // Reuse a single buffer for g to avoid per-block allocations.
112        let mut g = vec![0f32; self.r];
113
114        for (item, w) in iter {
115            if w == 0.0 {
116                continue;
117            }
118
119            // Stable per-item 64-bit base using the generic hasher H
120            let base: u64 = self.hasher.hash(&item);
121
122            // for each block, build g in {+1,-1}^r and accumulate Q_b * g
123            for b in 0..self.m {
124                let qb = &self.q_blocks[b]; // row-major r×r
125
126                // Seed a tiny PRNG (SplitMix64) once per (item, block)
127                let mut s = self.seed ^ base ^ ((b as u64) << 32) ^ 0x9E37_79B9_7F4A_7C15;
128
129                // Fill g[j] ∈ {+1,-1} using SplitMix64
130                for j in 0..self.r {
131                    s = Self::splitmix64(s);
132                    g[j] = if (s >> 63) == 0 { 1.0 } else { -1.0 };
133                }
134
135                // u = Q_b * g (r dot-products, row-major is cache-friendly here)
136                let off = b * self.r;
137                for row in 0..self.r {
138                    let mut acc = 0f32;
139                    let row_off = row * self.r;
140                    // dot(row, g)
141                    for col in 0..self.r {
142                        acc += qb[row_off + col] * g[col];
143                    }
144                    counts[off + row] += w * acc;
145                }
146            }
147        }
148
149        // threshold -> bits
150        let mut out = S::zero();
151        for i in 0..L {
152            if counts[i] > 0.0 {
153                out |= S::one() << i;
154            }
155        }
156        out
157    }
158
159    fn splitmix64(mut x: u64) -> u64 {
160        x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
161        let mut z = x;
162        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
163        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
164        z ^ (z >> 31)
165    }
166}
167
168#[cfg(test)]
169mod tests {
170    use super::super::{BitArray, SimHashBits};
171    use super::*;
172    use crate::simhash::sim_hasher::Xxh3Hasher64;
173    use rand::Rng;
174    use rand::SeedableRng;
175    use rand::rngs::StdRng;
176
177    // cargo test --release superbit_simhash_bitarray -- --nocapture
178    #[test]
179    fn superbit_simhash_bitarray() {
180        type Bits = BitArray<16>; // 16×64 = 1024 bits
181        const L: usize = 1024;
182        const R: usize = 32; // block size; L % R == 0
183        const N: usize = 100_000;
184
185        let mut rng = StdRng::seed_from_u64(12345);
186        let data1: Vec<u8> = (0..N).map(|_| rng.gen_range(0..=1)).collect();
187        let mut data2 = data1.clone();
188        for i in (0..N).step_by(4) {
189            data2[i] ^= 1;
190        }
191
192        // Ground-truth cosine/angle from the two bit-vectors treated as {0,1}^N
193        let (mut dot, mut n1, mut n2) = (0f64, 0f64, 0f64);
194        for i in 0..N {
195            let x = data1[i] as f64;
196            let y = data2[i] as f64;
197            dot += x * y;
198            n1 += x * x;
199            n2 += y * y;
200        }
201        let cosine = (dot / (n1.sqrt() * n2.sqrt())).clamp(-1.0, 1.0);
202        let theta = cosine.acos();
203        let p_bit = theta / std::f64::consts::PI; // P(bit differs) for SRP
204
205        let mean = p_bit * L as f64;
206        let sigma = (L as f64 * p_bit * (1.0 - p_bit)).sqrt();
207        let low = (mean - 3.0 * sigma).floor().max(0.0) as usize; // 3σ band for robustness
208        let high = (mean + 3.0 * sigma).ceil().min(L as f64) as usize;
209
210        let sb = SuperBitSimHash::<Xxh3Hasher64, Bits, L>::new(Xxh3Hasher64::new(), R, 0xDEAD_BEEF);
211        let h1 = sb.create_signature_weighted((0..N).map(|i| (i as u64, data1[i] as f32)));
212        let h2 = sb.create_signature_weighted((0..N).map(|i| (i as u64, data2[i] as f32)));
213        let hd = h1.hamming_distance(&h2);
214
215        eprintln!(
216            "SuperBit SimHash (L={L}, R={R}, N={N}): HD = {}, expected ≈ {}–{} (p_bit ≈ {:.3}, mean {:.1}, σ {:.1})",
217            hd, low, high, p_bit, mean, sigma
218        );
219        assert!((low..=high).contains(&hd));
220    }
221    // cargo test --release superbit_vs_classic_weighted_accuracy -- --nocapture
222    #[test]
223    fn superbit_vs_classic_weighted_accuracy() {
224        use crate::simhash::sim_hash::SimHash;
225        use crate::simhash::sim_hasher::{Xxh3Hasher64, Xxh3Hasher128};
226        use rand::rngs::StdRng;
227        use rand::{Rng, SeedableRng};
228
229        // Signature + block sizes
230        const L: usize = 256; // signature length
231        const R: usize = 16; // SuperBit block size, must divide L
232        const D: usize = 4096; // dimensionality (many bins/features)
233        const REPS: usize = 5; // average across a few draws for stability
234
235        let mut rng = StdRng::seed_from_u64(2025_11_12);
236
237        // SuperBit (uses 64-bit item hasher) over u128 signature
238        let sb = SuperBitSimHash::<Xxh3Hasher64, u128, L>::new(Xxh3Hasher64::new(), R, 0xD00D_F00D);
239        // Classic SimHash (weighted) must hash to u128 for L=256 bits
240        let sh = SimHash::<Xxh3Hasher128, u128, L>::new(Xxh3Hasher128::new());
241
242        let mut mae_sb = 0.0f64;
243        let mut mae_sh = 0.0f64;
244
245        for _ in 0..REPS {
246            // Construct two *positive-weighted* dense vectors (MS-like):
247            // a ~ Exp(1), b = a with ~15% multiplicative jitter + sparse spikes
248            let mut a = vec![0f32; D];
249            let mut b = vec![0f32; D];
250            for i in 0..D {
251                let u1: f32 = rng.r#gen::<f32>().clamp(1e-7, 1.0 - 1e-7);
252                let u2: f32 = rng.r#gen::<f32>().clamp(1e-7, 1.0 - 1e-7);
253                let exp_a = -u1.ln(); // exponential(1)
254                let jitter = 1.0 + 0.15 * (rng.r#gen::<f32>() - 0.5);
255                let spike = if rng.r#gen::<f32>() < 0.01 {
256                    0.5 * (-u2.ln())
257                } else {
258                    0.0
259                };
260                a[i] = exp_a;
261                b[i] = (exp_a * jitter + spike).max(0.0);
262            }
263
264            // Ground-truth cosine / angle
265            let (mut dot, mut na, mut nb) = (0f64, 0f64, 0f64);
266            for i in 0..D {
267                let (ai, bi) = (a[i] as f64, b[i] as f64);
268                dot += ai * bi;
269                na += ai * ai;
270                nb += bi * bi;
271            }
272            let cosine = (dot / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0);
273            let theta = cosine.acos();
274
275            // SuperBit (weighted)
276            let sig_sb_a = sb.create_signature_weighted((0..D).map(|i| (i as u64, a[i])));
277            let sig_sb_b = sb.create_signature_weighted((0..D).map(|i| (i as u64, b[i])));
278            let hd_sb = sig_sb_a.hamming_distance(&sig_sb_b);
279            let th_sb = std::f64::consts::PI * (hd_sb as f64) / (L as f64);
280            mae_sb += (th_sb - theta).abs();
281
282            // Classic SimHash (weighted) — relies on your added weighted API
283            let sig_sh_a = sh.create_signature_weighted((0..D).map(|i| (i as u64, a[i])));
284            let sig_sh_b = sh.create_signature_weighted((0..D).map(|i| (i as u64, b[i])));
285            let hd_sh = sig_sh_a.hamming_distance(&sig_sh_b);
286            let th_sh = std::f64::consts::PI * (hd_sh as f64) / (L as f64);
287            mae_sh += (th_sh - theta).abs();
288        }
289
290        mae_sb /= REPS as f64;
291        mae_sh /= REPS as f64;
292
293        eprintln!(
294            "Weighted accuracy (angle MAE): SuperBit={:.6}, Classic={:.6}",
295            mae_sb, mae_sh
296        );
297
298        // Expect SuperBit's orthogonal hyperplanes to reduce variance => lower error.
299        // Allow a small tolerance for determinism; SuperBit should be strictly better here.
300        assert!(
301            mae_sb <= mae_sh * 0.9 + 1e-6,
302            "SuperBit did not improve enough: SB={:.6}, SH={:.6}",
303            mae_sb,
304            mae_sh
305        );
306    }
307}