atomic_hyperloglog/
lib.rs

//! # Atomic HyperLogLog
//!
//! A concurrent, fast, well-tested and fully safe HyperLogLog for Rust, with no dependencies
//! outside the standard library.
//!
//! ```
//! # use std::hash::BuildHasherDefault;
//! # use seahash::SeaHasher;
//! use atomic_hyperloglog::HyperLogLog;
//! let h = HyperLogLog::new(BuildHasherDefault::<SeaHasher>::default(), 12);
//! for n in 0..10_000 {
//!     h.add(n);
//! }
//! let est = h.cardinality();
//! assert!(10_000.0 * 0.95 <= est && est <= 10_000.0 * 1.05, "{est}");
//! ```

17use std::{
18    hash::{BuildHasher, Hash, Hasher},
19    sync::atomic::{AtomicU32, AtomicU64, Ordering},
20};
21
22struct Registers {
23    words: Box<[AtomicU32]>,
24    int_size: u32,
25}
26
27/// the bit length of registers
28const REGISTER_SIZE: u8 = 6;
29
30impl Registers {
31    pub fn new(len: usize, int_size: u32) -> Self {
32        let ints_per_word = u32::BITS / int_size;
33        let words = (len + ints_per_word as usize - 1) / ints_per_word as usize;
34        Self {
35            words: Vec::from_iter(std::iter::repeat_with(|| AtomicU32::new(0)).take(words))
36                .into_boxed_slice(),
37            int_size,
38        }
39    }
40
41    /// Increment the relevant register, given j and p (terms used in HLL paper).
42    /// 
43    /// Params:
44    /// - j, the index of the register
45    /// - p, `1 + leading zeros`
46    pub fn incr(&self, j: u64, p: u32) -> Option<(u32, u32)> {
47        let ints_per_word = (u32::BITS / self.int_size) as u64;
48        let word = (j / ints_per_word) as usize;
49        let offset = (j % ints_per_word) as u32 * self.int_size;
50
51        let mask = (1 << self.int_size) - 1;
52        let val = p & mask;
53
54        let mut old_word = self.words[word].load(Ordering::Relaxed);
55
56        loop {
57            let old_val = (old_word >> offset) & mask;
58            if old_val >= val {
59                return None;
60            }
61
62            let new_word = (old_word & !(mask << offset)) | (val << offset);
63
64            match self.words[word].compare_exchange(
65                old_word,
66                new_word,
67                Ordering::Relaxed,
68                Ordering::Relaxed,
69            ) {
70                Ok(_) => return Some((old_val, val)),
71                Err(val) => old_word = val,
72            };
73        }
74    }
75
76    /// merge two register sets, modifying self, updating counters as our data changes
77    pub fn merge(&self, other: &Self, counters: &Counters) {
78        assert_eq!(self.int_size, other.int_size);
79        assert_eq!(self.words.len(), other.words.len());
80
81        let ints_per_word = (u32::BITS / self.int_size) as u64;
82        let mask = (1 << self.int_size) - 1;
83
84
85        for w_idx in 0..self.words.len() {
86            let mut old_word = self.words[w_idx].load(Ordering::Relaxed);
87            let (mut reciprocal_adj, mut zero_count_adj);
88            loop {
89                reciprocal_adj = 0;
90                zero_count_adj = 0;
91                let mut their_word = other.words[w_idx].load(Ordering::Relaxed);
92                let mut our_word = old_word;
93                let mut new_word = 0;
94                for i in 0..ints_per_word {
95                    let their_val = their_word & mask;
96                    let our_val = our_word & mask;
97
98                    let new_val = if their_val > our_val {
99                        let old_recip = 1u64 << RECIP_PRECISION.saturating_sub(our_val);
100                        let new_recip = 1u64 << RECIP_PRECISION.saturating_sub(their_val);
101                        reciprocal_adj += old_recip - new_recip;
102                        zero_count_adj += (our_val == 0) as u64;
103                        their_val
104                    } else {
105                        our_val
106                    };
107
108                    new_word |= new_val << i * self.int_size as u64;
109                    their_word = their_word >> self.int_size;
110                    our_word = our_word >> self.int_size;
111                }
112                match self.words[w_idx].compare_exchange(old_word, new_word, Ordering::Relaxed, Ordering::Relaxed) {
113                    Ok(_) => break,
114                    Err(word) => {old_word = word}
115                }
116            }
117            counters.reciprical_sum.fetch_sub(reciprocal_adj, Ordering::Relaxed);
118            counters.zero_count.fetch_sub(zero_count_adj, Ordering::Relaxed);
119        }
120    }
121}
122
/// Number of fractional bits in the fixed-point reciprocal sum (47).
///
/// A register holding `M` contributes `2^(RECIP_PRECISION - M)` units, i.e. `2^-M`. With at
/// most `2^16` registers, the initial sum `m * 2^RECIP_PRECISION` is therefore at most
/// `2^63` and fits in a `u64`. The `16` below is that largest supported `b`.
const RECIP_PRECISION: u32 = u64::BITS - 1 - 16;

126/// A hyperloglog data structure, allowing count-distinct with limited memory overhead.
127/// Fully concurrent with relaxed-only ordering and zero-unsafe code.
128pub struct HyperLogLog<H: BuildHasher> {
129    registers: Registers,
130    counters: Counters,
131    b: u8,
132    hasher: H,
133}
134
/// Aggregates over all registers, kept as atomics so readers can get an estimate
/// without scanning the registers.
struct Counters {
    /// how many registers still hold zero
    zero_count: AtomicU64,
    /// sum of `2^-M` over every register value `M`, in fixed point with
    /// `RECIP_PRECISION` fractional bits
    reciprical_sum: AtomicU64,
}

140impl<H> HyperLogLog<H>
141where
142    H: BuildHasher,
143{
144    /// Create a new hyperloglog data structure
145    /// 
146    /// parameters: hasher: hash function, b = log_2{number of bins}
147    pub fn new(hasher: H, b: u8) -> Self {
148        assert!(4 <= b && b <= 16);
149
150        let m = 1 << b;
151        let registers = Registers::new(
152            m,
153            REGISTER_SIZE as u32
154        );
155
156        Self {
157            hasher,
158            registers,
159            counters: Counters {
160                reciprical_sum: AtomicU64::new((1u64 << RECIP_PRECISION) * m as u64),
161                zero_count: AtomicU64::new(m as u64),
162            },
163            b,
164        }
165    }
166
167    /// calculates the standard relative error for the given `b` parameter
168    pub fn stderr(&self) -> f64 {
169        let m = 1 << self.b;
170        1.04 / (m as f64).sqrt()
171    }
172
173    /// Add a value to the count
174    pub fn add<T: Hash>(&self, val: T) {
175        let mut hasher = self.hasher.build_hasher();
176        val.hash(&mut hasher);
177        let x = hasher.finish();
178
179        let j = x & ((1 << self.b) - 1);
180        let p = 1 + x.leading_zeros();
181
182        if let Some((old, new)) = self.registers.incr(j, p) {
183            let old_recip = 1u64 << RECIP_PRECISION.saturating_sub(old);
184            let new_recip = 1u64 << RECIP_PRECISION.saturating_sub(new);
185
186            self.counters.reciprical_sum.fetch_sub(old_recip - new_recip, Ordering::Relaxed);
187            if old == 0 {
188                self.counters.zero_count.fetch_sub(1, Ordering::Relaxed);
189            }
190        }
191    }
192
193    /// merge other's count into self
194    pub fn merge(&self, other: &Self) {
195        assert_eq!(self.b, other.b);
196        self.registers.merge(&other.registers, &self.counters);
197    }
198
199    /// Get the cardinality estimate
200    pub fn cardinality(&self) -> f64 {
201        fn inner(reciprical_sum: u64, zero_count: u64, b: u8) -> f64 {
202            let max = 2f64.powi(RECIP_PRECISION as i32 + b as i32);
203            let m = 1 << b;
204            let m_f64 = m as f64;
205
206            let z_recip = fixed_point_to_floating_point(reciprical_sum, RECIP_PRECISION as i32);
207            let a = match m {
208                16 => 0.673,
209                32 => 0.697,
210                64 => 0.709,
211                _ => 0.7213 / (1f64 + 1.079 / m_f64),
212            };
213            let e_unscaled = a / z_recip;
214            let e = e_unscaled * m_f64.powi(2); // the “raw” HyperLogLog estimate
215
216            if e_unscaled * m_f64 <= 2.5f64 {
217                // small range correction
218                if zero_count != 0 {
219                    let u: f64 = (b as f64) - (zero_count as f64).log2(); // u = log(m / V)
220                    return m_f64 * u;
221                }
222            } else if e / max > 30.0 {
223                // large range correction
224                return -max * (1f64 - (e / max)).log2();
225            }
226
227            e
228        }
229
230        inner(
231            self.counters.reciprical_sum.load(Ordering::Relaxed),
232            self.counters.zero_count.load(Ordering::Relaxed),
233            self.b,
234        )
235    }
236}
237
/// Convert an unsigned fixed-point number to an `f64`.
///
/// `fixed` is read as `fixed / 2^ones_place`, so `ones_place` is the number of fractional
/// bits. Only the 53 most significant bits are kept. Lower bits are truncated, not
/// rounded. Zero maps to `0.0`. There is no handling of results outside the normal `f64`
/// exponent range.
fn fixed_point_to_floating_point(fixed: u64, ones_place: i32) -> f64 {
    // explicitly stored fraction bits of an IEEE-754 double (52)
    const FRACTION_BITS: i32 = f64::MANTISSA_DIGITS as i32 - 1;
    const FRACTION_MASK: u64 = (1u64 << FRACTION_BITS) - 1;
    const EXPONENT_BIAS: i32 = 1023;

    if fixed == 0 {
        return 0.0;
    }

    // Position of the most significant set bit (0..=63), and how far it has to move to
    // land on the implicit leading 1 at bit 52.
    let top_bit = (u64::BITS - 1 - fixed.leading_zeros()) as i32;
    let shift = top_bit - FRACTION_BITS;
    let aligned = if shift >= 0 {
        fixed >> shift
    } else {
        fixed << -shift
    };
    let fraction = aligned & FRACTION_MASK;

    let exponent = FRACTION_BITS + shift - ones_place;
    let biased = (exponent + EXPONENT_BIAS) as u64;
    f64::from_bits((biased << FRACTION_BITS) | fraction)
}

#[cfg(test)]
mod tests {
    use seahash::SeaHasher;

    use super::*;

    #[test]
    fn fixed_to_float() {
        for n in [
            0u64,
            1,
            0x000f_ffff_ffff_ffff,
            0xffff_ffff_ffff_ffff,
            0x1000_0000_0000,
            0x1000_0000_0001,
            0x1000_1000_0001,
            0xffff_ffff_ffff,
            0xabcd_ef12_abcd_ef45,
        ] {
            let actual = fixed_point_to_floating_point(n, 64);
            let expected = n as f64 / (2.0f64).powi(64);
            // compare the absolute difference so errors in either direction are caught
            assert!((actual - expected).abs() < 0.001, "{actual} ≠ {expected}")
        }
    }

    /// A `BuildHasher` that hands out clones of one pre-seeded hasher.
    struct BuildHasherClone<H: Hasher + Clone>(H);
    impl<H: Hasher + Clone> BuildHasher for BuildHasherClone<H> {
        type Hasher = H;

        fn build_hasher(&self) -> Self::Hasher {
            self.0.clone()
        }
    }

    #[test]
    fn ten_thousand() {
        let b = 4;
        let m = 1 << b;
        let sterr = 1.04 / (m as f64).sqrt();
        let hll = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        assert_eq!(hll.cardinality(), 0f64);

        for n in 1..=10000 {
            hll.add(n);

            if n % 10 == 1 {
                let c = hll.cardinality();
                let rel_error = (c / n as f64) - 1.;
                let z = rel_error / sterr;

                assert!(
                    z.abs() <= 3.0,
                    "z was {z}, c: {c}, n: {n}, rel_er: {rel_error}"
                );
            }
        }
    }

    #[test]
    fn million() {
        let b = 4;
        let m = 1 << b;
        let sterr = 1.04 / (m as f64).sqrt();
        let hll = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        assert_eq!(hll.cardinality(), 0f64);

        for n in 1..=1_000_000 {
            hll.add(n);

            if n % 100_000 == 0 {
                let c = hll.cardinality();
                let rel_error = (c / n as f64) - 1.;
                let z = rel_error / sterr;

                assert!(
                    z.abs() <= 3.0,
                    "z was {z}, c: {c}, n: {n}, rel_er: {rel_error}"
                );
            }
        }
    }

    #[test]
    fn merging_small() {
        let b = 8;

        let hll1 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        let hll2 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        let hll3 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);

        hll1.add(1);
        hll2.add(2);
        hll3.add(3);

        hll1.merge(&hll2);
        hll2.merge(&hll3);

        assert_eq!(hll1.cardinality(), hll2.cardinality());
        assert_ne!(hll2.cardinality(), hll3.cardinality());
    }

    #[test]
    fn merging() {
        let b = 8;
        let m = 1 << b;
        let sterr = 1.04 / (m as f64).sqrt();

        let hll1 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        let hll2 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        let hll3 = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);

        assert_eq!(hll1.cardinality(), 0f64);

        for n in 1..=1_000_000 {
            hll1.add(n);
            hll2.add(n);
            hll3.add(!n);
        }

        assert_eq!(hll1.cardinality(), hll2.cardinality());
        assert_ne!(hll1.cardinality(), hll3.cardinality());

        for n in 1_000_000..=2_000_000 {
            hll2.add(n);
        }

        assert_ne!(hll1.cardinality(), hll2.cardinality());

        hll1.merge(&hll2);

        assert_eq!(hll1.cardinality(), hll2.cardinality());

        let expected = hll2.cardinality() + hll3.cardinality();
        hll2.merge(&hll3);
        let error = (hll2.cardinality() - expected) / expected;
        let z = error / (sterr.powi(2) * 2.0).sqrt();
        assert!(z <= 1.0, "should be within 1 margin of error of difference after merging");
    }

    #[test]
    fn million_b8() {
        let b = 8;
        let m = 1 << b;
        let sterr = 1.04 / (m as f64).sqrt();
        let hll = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        assert_eq!(hll.cardinality(), 0f64);

        for n in 1..=1_000_000 {
            hll.add(n);

            if n % 100_000 == 0 {
                let c = hll.cardinality();
                let rel_error = (c / n as f64) - 1.;
                let z = rel_error / sterr;

                assert!(
                    z.abs() <= 3.0,
                    "z was {z}, c: {c}, n: {n}, rel_er: {rel_error}"
                );
            }
        }
    }

    #[test]
    fn million_b16() {
        let b = 16;
        let m = 1 << b;
        let sterr = 1.04 / (m as f64).sqrt();
        let hll = HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b);
        assert_eq!(hll.cardinality(), 0f64);

        for n in 1..=1_000_000 {
            hll.add(n);

            if n % 250_000 == 0 {
                let c = hll.cardinality();
                let rel_error = (c / n as f64) - 1.;
                let z = rel_error / sterr;

                assert!(
                    z.abs() <= 4.0,
                    "z was {z}, c: {c}, n: {n}, rel_er: {rel_error}"
                );
            }
        }
    }
}