1use std::{
18 hash::{BuildHasher, Hash, Hasher},
19 sync::atomic::{AtomicU32, AtomicU64, Ordering},
20};
21
/// Fixed-length array of small unsigned registers packed into `u32` words.
///
/// Each register is `int_size` bits wide and `u32::BITS / int_size` registers share
/// one word; high bits of a word that cannot hold a whole register stay unused.
/// The words are atomic so registers can be raised through a shared reference.
struct Registers {
    /// Packed storage; register `j` lives in word `j / (u32::BITS / int_size)`.
    words: Box<[AtomicU32]>,
    /// Width of each register, in bits.
    int_size: u32,
}

/// Width in bits of every HyperLogLog register, so a register holds values `0..=63`.
const REGISTER_SIZE: u8 = 6;

30impl Registers {
31 pub fn new(len: usize, int_size: u32) -> Self {
32 let ints_per_word = u32::BITS / int_size;
33 let words = (len + ints_per_word as usize - 1) / ints_per_word as usize;
34 Self {
35 words: Vec::from_iter(std::iter::repeat_with(|| AtomicU32::new(0)).take(words))
36 .into_boxed_slice(),
37 int_size,
38 }
39 }
40
41 pub fn incr(&self, j: u64, p: u32) -> Option<(u32, u32)> {
47 let ints_per_word = (u32::BITS / self.int_size) as u64;
48 let word = (j / ints_per_word) as usize;
49 let offset = (j % ints_per_word) as u32 * self.int_size;
50
51 let mask = (1 << self.int_size) - 1;
52 let val = p & mask;
53
54 let mut old_word = self.words[word].load(Ordering::Relaxed);
55
56 loop {
57 let old_val = (old_word >> offset) & mask;
58 if old_val >= val {
59 return None;
60 }
61
62 let new_word = (old_word & !(mask << offset)) | (val << offset);
63
64 match self.words[word].compare_exchange(
65 old_word,
66 new_word,
67 Ordering::Relaxed,
68 Ordering::Relaxed,
69 ) {
70 Ok(_) => return Some((old_val, val)),
71 Err(val) => old_word = val,
72 };
73 }
74 }
75
76 pub fn merge(&self, other: &Self, counters: &Counters) {
78 assert_eq!(self.int_size, other.int_size);
79 assert_eq!(self.words.len(), other.words.len());
80
81 let ints_per_word = (u32::BITS / self.int_size) as u64;
82 let mask = (1 << self.int_size) - 1;
83
84
85 for w_idx in 0..self.words.len() {
86 let mut old_word = self.words[w_idx].load(Ordering::Relaxed);
87 let (mut reciprocal_adj, mut zero_count_adj);
88 loop {
89 reciprocal_adj = 0;
90 zero_count_adj = 0;
91 let mut their_word = other.words[w_idx].load(Ordering::Relaxed);
92 let mut our_word = old_word;
93 let mut new_word = 0;
94 for i in 0..ints_per_word {
95 let their_val = their_word & mask;
96 let our_val = our_word & mask;
97
98 let new_val = if their_val > our_val {
99 let old_recip = 1u64 << RECIP_PRECISION.saturating_sub(our_val);
100 let new_recip = 1u64 << RECIP_PRECISION.saturating_sub(their_val);
101 reciprocal_adj += old_recip - new_recip;
102 zero_count_adj += (our_val == 0) as u64;
103 their_val
104 } else {
105 our_val
106 };
107
108 new_word |= new_val << i * self.int_size as u64;
109 their_word = their_word >> self.int_size;
110 our_word = our_word >> self.int_size;
111 }
112 match self.words[w_idx].compare_exchange(old_word, new_word, Ordering::Relaxed, Ordering::Relaxed) {
113 Ok(_) => break,
114 Err(word) => {old_word = word}
115 }
116 }
117 counters.reciprical_sum.fetch_sub(reciprocal_adj, Ordering::Relaxed);
118 counters.zero_count.fetch_sub(zero_count_adj, Ordering::Relaxed);
119 }
120 }
121}
122
/// Number of fractional bits in the fixed-point reciprocal sum kept in `Counters`.
///
/// A register holding `k` contributes `2^-k`, stored as `1 << (RECIP_PRECISION - k)`.
/// The shift amount saturates at zero, so every `k >= RECIP_PRECISION` contributes
/// the same smallest unit. `1 << RECIP_PRECISION` therefore represents 1.0.
const RECIP_PRECISION: u32 = 47;

126pub struct HyperLogLog<H: BuildHasher> {
129 registers: Registers,
130 counters: Counters,
131 b: u8,
132 hasher: H,
133}
134
/// Running aggregates over the registers, so an estimate never has to scan them.
struct Counters {
    /// Sum of `2^-value` over all registers, in fixed point with
    /// `RECIP_PRECISION` fractional bits.
    reciprical_sum: AtomicU64,
    /// How many registers still hold zero.
    zero_count: AtomicU64,
}

140impl<H> HyperLogLog<H>
141where
142 H: BuildHasher,
143{
144 pub fn new(hasher: H, b: u8) -> Self {
148 assert!(4 <= b && b <= 16);
149
150 let m = 1 << b;
151 let registers = Registers::new(
152 m,
153 REGISTER_SIZE as u32
154 );
155
156 Self {
157 hasher,
158 registers,
159 counters: Counters {
160 reciprical_sum: AtomicU64::new((1u64 << RECIP_PRECISION) * m as u64),
161 zero_count: AtomicU64::new(m as u64),
162 },
163 b,
164 }
165 }
166
167 pub fn stderr(&self) -> f64 {
169 let m = 1 << self.b;
170 1.04 / (m as f64).sqrt()
171 }
172
173 pub fn add<T: Hash>(&self, val: T) {
175 let mut hasher = self.hasher.build_hasher();
176 val.hash(&mut hasher);
177 let x = hasher.finish();
178
179 let j = x & ((1 << self.b) - 1);
180 let p = 1 + x.leading_zeros();
181
182 if let Some((old, new)) = self.registers.incr(j, p) {
183 let old_recip = 1u64 << RECIP_PRECISION.saturating_sub(old);
184 let new_recip = 1u64 << RECIP_PRECISION.saturating_sub(new);
185
186 self.counters.reciprical_sum.fetch_sub(old_recip - new_recip, Ordering::Relaxed);
187 if old == 0 {
188 self.counters.zero_count.fetch_sub(1, Ordering::Relaxed);
189 }
190 }
191 }
192
193 pub fn merge(&self, other: &Self) {
195 assert_eq!(self.b, other.b);
196 self.registers.merge(&other.registers, &self.counters);
197 }
198
199 pub fn cardinality(&self) -> f64 {
201 fn inner(reciprical_sum: u64, zero_count: u64, b: u8) -> f64 {
202 let max = 2f64.powi(RECIP_PRECISION as i32 + b as i32);
203 let m = 1 << b;
204 let m_f64 = m as f64;
205
206 let z_recip = fixed_point_to_floating_point(reciprical_sum, RECIP_PRECISION as i32);
207 let a = match m {
208 16 => 0.673,
209 32 => 0.697,
210 64 => 0.709,
211 _ => 0.7213 / (1f64 + 1.079 / m_f64),
212 };
213 let e_unscaled = a / z_recip;
214 let e = e_unscaled * m_f64.powi(2); if e_unscaled * m_f64 <= 2.5f64 {
217 if zero_count != 0 {
219 let u: f64 = (b as f64) - (zero_count as f64).log2(); return m_f64 * u;
221 }
222 } else if e / max > 30.0 {
223 return -max * (1f64 - (e / max)).log2();
225 }
226
227 e
228 }
229
230 inner(
231 self.counters.reciprical_sum.load(Ordering::Relaxed),
232 self.counters.zero_count.load(Ordering::Relaxed),
233 self.b,
234 )
235 }
236}
237
/// Converts an unsigned fixed-point number to the nearest `f64`.
///
/// `fixed` is read as `fixed / 2^ones_place`, i.e. bit `ones_place` carries the
/// weight 1.0 (zero maps to `0.0`). The `u64 -> f64` cast rounds to nearest, ties
/// to even. Dividing by the power of two `2^ones_place` adds no further rounding
/// while the result stays in the normal `f64` range. Results outside that range
/// degrade gracefully (to subnormals, zero or infinity) instead of corrupting the
/// exponent bits.
fn fixed_point_to_floating_point(fixed: u64, ones_place: i32) -> f64 {
    fixed as f64 / 2f64.powi(ones_place)
}

#[cfg(test)]
mod tests {
    use seahash::SeaHasher;

    use super::*;

    /// `BuildHasher` that hands out clones of one fixed hasher, so every value is
    /// hashed from the same starting state and results are reproducible.
    struct BuildHasherClone<H: Hasher + Clone>(H);

    impl<H: Hasher + Clone> BuildHasher for BuildHasherClone<H> {
        type Hasher = H;

        fn build_hasher(&self) -> Self::Hasher {
            self.0.clone()
        }
    }

    /// Creates a sketch with `2^b` registers that hashes with `SeaHasher::new()`.
    fn sketch(b: u8) -> HyperLogLog<BuildHasherClone<SeaHasher>> {
        HyperLogLog::new(BuildHasherClone(SeaHasher::new()), b)
    }

    /// Expected relative standard error of a sketch with `2^b` registers.
    fn std_error(b: u8) -> f64 {
        1.04 / f64::from(1u32 << b).sqrt()
    }

    /// Adds `1..=count` to a fresh sketch and, at every `n` where `check(n)` holds,
    /// asserts that the estimate lies within `z_max` standard errors of `n`.
    fn assert_tracks_count(b: u8, count: i32, z_max: f64, check: impl Fn(i32) -> bool) {
        let sterr = std_error(b);
        let hll = sketch(b);
        assert_eq!(hll.cardinality(), 0f64);

        for n in 1..=count {
            hll.add(n);

            if check(n) {
                let c = hll.cardinality();
                let rel_error = (c / n as f64) - 1.;
                let z = rel_error / sterr;

                assert!(
                    z.abs() <= z_max,
                    "z was {z}, c: {c}, n: {n}, rel_er: {rel_error}"
                );
            }
        }
    }

    #[test]
    fn fixed_to_float() {
        for n in [
            0u64,
            1,
            0x000f_ffff_ffff_ffff,
            0xffff_ffff_ffff_ffff,
            0x1000_0000_0000,
            0x1000_0000_0001,
            0x1000_1000_0001,
            0xffff_ffff_ffff,
            0xabcd_ef12_abcd_ef45,
        ] {
            let actual = fixed_point_to_floating_point(n, 64);
            let expected = n as f64 / (2.0f64).powi(64);
            // Two-sided, relative check: both values are within ~1 ulp of the exact result.
            assert!(
                (actual - expected).abs() <= expected * 1e-15,
                "{actual} ≠ {expected}"
            )
        }
    }

    #[test]
    fn ten_thousand() {
        assert_tracks_count(4, 10_000, 3.0, |n| n % 10 == 1);
    }

    #[test]
    fn million() {
        assert_tracks_count(4, 1_000_000, 3.0, |n| n % 100_000 == 0);
    }

    #[test]
    fn merging_small() {
        let b = 8;

        let hll1 = sketch(b);
        let hll2 = sketch(b);
        let hll3 = sketch(b);

        hll1.add(1);
        hll2.add(2);
        hll3.add(3);

        hll1.merge(&hll2);
        hll2.merge(&hll3);

        assert_eq!(hll1.cardinality(), hll2.cardinality());
        assert_ne!(hll2.cardinality(), hll3.cardinality());
    }

    #[test]
    fn merging() {
        let b = 8;
        let sterr = std_error(b);

        let hll1 = sketch(b);
        let hll2 = sketch(b);
        let hll3 = sketch(b);

        assert_eq!(hll1.cardinality(), 0f64);

        for n in 1..=1_000_000 {
            hll1.add(n);
            hll2.add(n);
            hll3.add(!n);
        }

        assert_eq!(hll1.cardinality(), hll2.cardinality());
        assert_ne!(hll1.cardinality(), hll3.cardinality());

        for n in 1_000_000..=2_000_000 {
            hll2.add(n);
        }

        assert_ne!(hll1.cardinality(), hll2.cardinality());

        hll1.merge(&hll2);

        assert_eq!(hll1.cardinality(), hll2.cardinality());

        let expected = hll2.cardinality() + hll3.cardinality();
        hll2.merge(&hll3);
        let error = (hll2.cardinality() - expected) / expected;
        let z = error / (sterr.powi(2) * 2.0).sqrt();
        // NOTE(review): one-sided bound; an underestimate of any size passes.
        assert!(z <= 1.0, "should be within 1 margin of error of difference after merging");
    }

    #[test]
    fn million_b8() {
        assert_tracks_count(8, 1_000_000, 3.0, |n| n % 100_000 == 0);
    }

    #[test]
    fn million_b16() {
        assert_tracks_count(16, 1_000_000, 4.0, |n| n % 250_000 == 0);
    }
}