Skip to main content

card_est_array/impls/
hyper_log_log.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Matteo Dell'Acqua
3 * SPDX-FileCopyrightText: 2025 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use anyhow::{Result, ensure};
9use common_traits::{CastableFrom, CastableInto, Number, UpcastableInto};
10use std::hash::*;
11use std::{borrow::Borrow, f64::consts::LN_2};
12use sux::{bits::BitFieldVec, traits::Word};
13use value_traits::slices::{SliceByValue, SliceByValueMut};
14
15use crate::traits::{EstimationLogic, MergeEstimationLogic, SliceEstimationLogic};
16
17use super::DefaultEstimator;
18
/// The type returned by the hash function.
///
/// Items are hashed with [`BuildHasher::hash_one`], which yields a `u64`.
type HashResult = u64;

/// Estimator logic implementing the HyperLogLog algorithm.
///
/// Instances are built using [`HyperLogLogBuilder`], which provides convenient
/// ways to set the internal parameters.
///
/// Note that `T` can be any type satisfying the [`Hash`] trait. The parameter
/// `H` makes it possible to select a hashing algorithm, and `W` is the unsigned
/// type used to store backends.
///
/// An important constraint is that `W` must be able to represent exactly the
/// backend of an estimator. While usually `usize` will work (and it is the default
/// type chosen by [`new`](HyperLogLogBuilder::new)), with odd register sizes
/// and a small number of registers it might be necessary to select a smaller
/// type, resulting in slower merges. For example, using 16 5-bit registers one
/// needs to use `u16`, whereas for 16 6-bit registers `u32` will be sufficient.
///
/// Registers are packed back to back in the backend, `register_size` bits
/// each: bit `p` of the backend is bit `p % W::BITS` of word `p / W::BITS`,
/// so a register may straddle two consecutive words.
///
/// The derived [`Debug`] and [`PartialEq`] implementations require `T`, `H`
/// and `W` to implement those traits, even though `T` is only a marker.
#[derive(Debug, PartialEq)]
pub struct HyperLogLog<T, H, W> {
    /// Builds the hasher used to hash items.
    build_hasher: H,
    /// Number of bits in each register.
    register_size: usize,
    /// The number of registers minus one; masking a hash with it selects a
    /// register.
    num_registers_minus_1: HashResult,
    /// Base-2 logarithm of the number of registers per estimator.
    log_2_num_registers: usize,
    /// A single bit OR-ed into the shifted hash before counting trailing
    /// zeros, so that the count (and thus the register value) stays bounded.
    sentinel_mask: HashResult,
    /// Number of registers per estimator (a power of two).
    num_registers: usize,
    /// Number of `W` words backing a single estimator.
    pub(super) words_per_estimator: usize,
    /// The constant α (chosen from the number of registers) multiplied by the
    /// square of the number of registers.
    alpha_m_m: f64,
    /// One estimator's worth of words with only the highest bit of each
    /// register set; used by merges.
    msb_mask: Box<[W]>,
    /// One estimator's worth of words with only the lowest bit of each
    /// register set; used by merges.
    lsb_mask: Box<[W]>,
    _marker: std::marker::PhantomData<T>,
}
51
52// We implement Clone manually because we do not want to require that T is
53// Clone.
54impl<T, H: Clone, W: Clone> Clone for HyperLogLog<T, H, W> {
55    fn clone(&self) -> Self {
56        Self {
57            build_hasher: self.build_hasher.clone(),
58            register_size: self.register_size,
59            num_registers_minus_1: self.num_registers_minus_1,
60            log_2_num_registers: self.log_2_num_registers,
61            sentinel_mask: self.sentinel_mask,
62            num_registers: self.num_registers,
63            words_per_estimator: self.words_per_estimator,
64            alpha_m_m: self.alpha_m_m,
65            msb_mask: self.msb_mask.clone(),
66            lsb_mask: self.lsb_mask.clone(),
67            _marker: std::marker::PhantomData,
68        }
69    }
70}
71
impl<T, H: Clone, W: Word> HyperLogLog<T, H, W> {
    /// Returns the value contained in a register of a given backend.
    ///
    /// Register `index` occupies backend bits `index * register_size ..
    /// (index + 1) * register_size`, where bit `p` is bit `p % W::BITS` of
    /// word `p / W::BITS`; the register may span two consecutive words.
    ///
    /// No bounds checking is performed: the caller must ensure that `index`
    /// is less than the number of registers and that `backend` contains every
    /// word covering that register.
    #[inline(always)]
    fn get_register_unchecked(&self, backend: impl AsRef<[W]>, index: usize) -> W {
        let backend = backend.as_ref();
        let bit_width = self.register_size;
        // The lowest `bit_width` bits set.
        let mask = W::MAX >> (W::BITS - bit_width);
        let pos = index * bit_width;
        let word_index = pos / W::BITS;
        let bit_index = pos % W::BITS;

        if bit_index + bit_width <= W::BITS {
            // SAFETY: relies on the caller's guarantee that the word holding
            // register `index` lies within `backend`.
            (unsafe { *backend.get_unchecked(word_index) } >> bit_index) & mask
        } else {
            // The register straddles two words: its low bits are the top bits
            // of `word_index`, its high bits the bottom bits of `word_index + 1`.
            // SAFETY: as above; both words hold bits of register `index`.
            ((unsafe { *backend.get_unchecked(word_index) } >> bit_index)
                | (unsafe { *backend.get_unchecked(word_index + 1) } << (W::BITS - bit_index)))
                & mask
        }
    }

    /// Sets the value contained in a register of a given backend.
    ///
    /// Uses the same layout as [`get_register_unchecked`](Self::get_register_unchecked);
    /// all other bits of the touched words are preserved.
    ///
    /// No bounds checking is performed: the caller must ensure that `index`
    /// is less than the number of registers and that `backend` contains every
    /// word covering that register. `new_value` is not masked, so it must fit
    /// in `register_size` bits, or neighbouring registers would be corrupted.
    #[inline(always)]
    fn set_register_unchecked(&self, mut backend: impl AsMut<[W]>, index: usize, new_value: W) {
        let backend = backend.as_mut();
        let bit_width = self.register_size;
        // The lowest `bit_width` bits set.
        let mask = W::MAX >> (W::BITS - bit_width);
        let pos = index * bit_width;
        let word_index = pos / W::BITS;
        let bit_index = pos % W::BITS;

        if bit_index + bit_width <= W::BITS {
            // SAFETY: relies on the caller's guarantee that the word holding
            // register `index` lies within `backend`.
            let mut word = unsafe { *backend.get_unchecked_mut(word_index) };
            // Clear the register's bits, then write the new value in place.
            word &= !(mask << bit_index);
            word |= new_value << bit_index;
            unsafe { *backend.get_unchecked_mut(word_index) = word };
        } else {
            // The register straddles two words.
            // SAFETY: as above; both words hold bits of register `index`.
            let mut word = unsafe { *backend.get_unchecked_mut(word_index) };
            // Keep only the bits below the register; its low bits go on top.
            word &= (W::ONE << bit_index) - W::ONE;
            word |= new_value << bit_index;
            unsafe { *backend.get_unchecked_mut(word_index) = word };

            let mut word = unsafe { *backend.get_unchecked_mut(word_index + 1) };
            // Clear the register's high bits at the bottom of the next word and
            // write them.
            word &= !(mask >> (W::BITS - bit_index));
            word |= new_value >> (W::BITS - bit_index);
            unsafe { *backend.get_unchecked_mut(word_index + 1) = word };
        }
    }
}
120
121impl<
122    T: Hash,
123    H: BuildHasher + Clone,
124    W: Word + UpcastableInto<HashResult> + CastableFrom<HashResult>,
125> SliceEstimationLogic<W> for HyperLogLog<T, H, W>
126{
127    #[inline(always)]
128    fn backend_len(&self) -> usize {
129        self.words_per_estimator
130    }
131}
132
133impl<
134    T: Hash,
135    H: BuildHasher + Clone,
136    W: Word + UpcastableInto<HashResult> + CastableFrom<HashResult>,
137> EstimationLogic for HyperLogLog<T, H, W>
138{
139    type Item = T;
140    type Backend = [W];
141    type Estimator<'a>
142        = DefaultEstimator<Self, &'a Self, Box<[W]>>
143    where
144        T: 'a,
145        W: 'a,
146        H: 'a;
147
148    fn new_estimator(&self) -> Self::Estimator<'_> {
149        Self::Estimator::new(
150            self,
151            vec![W::ZERO; self.words_per_estimator].into_boxed_slice(),
152        )
153    }
154
155    fn add(&self, backend: &mut Self::Backend, element: impl Borrow<T>) {
156        let x = self.build_hasher.hash_one(element.borrow());
157        let j = x & self.num_registers_minus_1;
158        let r =
159            ((x >> self.log_2_num_registers) | self.sentinel_mask).trailing_zeros() as HashResult;
160        let register = j as usize;
161
162        debug_assert!(r < (1 << self.register_size) - 1);
163        debug_assert!(register < self.num_registers);
164
165        let current_value = self.get_register_unchecked(&*backend, register);
166        let candidate_value = r + 1;
167        let new_value = std::cmp::max(current_value, candidate_value.cast());
168        if current_value != new_value {
169            self.set_register_unchecked(backend, register, new_value);
170        }
171    }
172
173    fn estimate(&self, backend: &[W]) -> f64 {
174        let mut harmonic_mean = 0.0;
175        let mut zeroes = 0;
176
177        for i in 0..self.num_registers {
178            let value: u64 = self.get_register_unchecked(backend, i).upcast();
179            if value == 0 {
180                zeroes += 1;
181            }
182            harmonic_mean += 1.0 / (1_u64 << value) as f64;
183        }
184
185        let mut estimate = self.alpha_m_m / harmonic_mean;
186        if zeroes != 0 && estimate < 2.5 * self.num_registers as f64 {
187            estimate = self.num_registers as f64 * (self.num_registers as f64 / zeroes as f64).ln();
188        }
189        estimate
190    }
191
192    #[inline(always)]
193    fn clear(&self, backend: &mut [W]) {
194        backend.fill(W::ZERO);
195    }
196
197    #[inline(always)]
198    fn set(&self, dst: &mut [W], src: &[W]) {
199        debug_assert_eq!(dst.len(), src.len());
200        dst.copy_from_slice(src);
201    }
202}
203
/// Helper for merge operations with [`HyperLogLog`] logic.
///
/// Holds two scratch buffers that each merge clears and refills, so their
/// allocations are reused across merges.
pub struct HyperLogLogHelper<W> {
    /// Scratch buffer holding the register-by-register comparison result.
    acc: Vec<W>,
    /// Scratch buffer holding temporaries and then the selection mask.
    mask: Vec<W>,
}
209
210impl<
211    T: Hash,
212    H: BuildHasher + Clone,
213    W: Word + UpcastableInto<HashResult> + CastableFrom<HashResult>,
214> MergeEstimationLogic for HyperLogLog<T, H, W>
215{
216    type Helper = HyperLogLogHelper<W>;
217
218    fn new_helper(&self) -> Self::Helper {
219        HyperLogLogHelper {
220            acc: vec![W::ZERO; self.words_per_estimator],
221            mask: vec![W::ZERO; self.words_per_estimator],
222        }
223    }
224
225    #[inline(always)]
226    fn merge_with_helper(&self, dst: &mut [W], src: &[W], helper: &mut Self::Helper) {
227        merge_hyperloglog_bitwise(
228            dst,
229            src,
230            self.msb_mask.as_ref(),
231            self.lsb_mask.as_ref(),
232            &mut helper.acc,
233            &mut helper.mask,
234            self.register_size,
235        );
236    }
237}
238
/// Builds a [`HyperLogLog`] cardinality-estimator logic.
///
/// `W` is the word type of the backends of the resulting logic (`usize` by
/// default).
#[derive(Debug, Clone)]
pub struct HyperLogLogBuilder<H, W = usize> {
    /// The [`BuildHasher`] handed over to the logic.
    build_hasher: H,
    /// Base-2 logarithm of the number of registers per estimator.
    log_2_num_registers: usize,
    /// Upper bound on the number of distinct elements; determines the
    /// register size.
    n: usize,
    /// Records the backend word type `W`.
    _marker: std::marker::PhantomData<W>,
}
247
248impl HyperLogLogBuilder<BuildHasherDefault<DefaultHasher>, usize> {
249    /// Creates a new builder for a [`HyperLogLog`] logic with a default word
250    /// type of `usize`.
251    pub fn new(n: usize) -> Self {
252        Self {
253            build_hasher: BuildHasherDefault::default(),
254            log_2_num_registers: 4,
255            n,
256            _marker: std::marker::PhantomData,
257        }
258    }
259}
260
/// Returns the name of the widest unsigned integer type (up to `u128`) whose
/// bit width divides `bits`; `bits == 0` yields `"u128"`.
fn min_alignment(bits: usize) -> String {
    // 2^k divides `bits` iff `bits` has at least k trailing zeros.
    let name = match bits.trailing_zeros() {
        0..=3 => "u8",
        4 => "u16",
        5 => "u32",
        6 => "u64",
        _ => "u128",
    };
    name.to_string()
}
275
276impl HyperLogLog<(), (), ()> {
277    /// Returns the logarithm of the number of registers per estimator that are
278    /// necessary to attain a given relative standard deviation.
279    ///
280    /// # Arguments
281    /// * `rsd`: the relative standard deviation to be attained.
282    pub fn log_2_num_of_registers(rsd: f64) -> usize {
283        ((1.106 / rsd).pow(2.0)).log2().ceil() as usize
284    }
285
286    /// Returns the relative standard deviation corresponding to a given number
287    /// of registers per estimator.
288    ///
289    /// # Arguments
290    ///
291    /// * `log_2_num_registers`: the logarithm of the number of registers per
292    ///   estimator.
293    pub fn rel_std(log_2_num_registers: usize) -> f64 {
294        let tmp = match log_2_num_registers {
295            4 => 1.106,
296            5 => 1.070,
297            6 => 1.054,
298            7 => 1.046,
299            _ => 1.04,
300        };
301        tmp / ((1 << log_2_num_registers) as f64).sqrt()
302    }
303
304    /// Returns the register size in bits, given an upper bound on the number of
305    /// distinct elements.
306    ///
307    /// # Arguments
308    /// * `n`: an upper bound on the number of distinct elements.
309    pub fn register_size(n: usize) -> usize {
310        std::cmp::max(5, (((n as f64).ln() / LN_2) / LN_2).ln().ceil() as usize)
311    }
312}
313
314impl<H, W: Word> HyperLogLogBuilder<H, W> {
315    /// Sets the desired relative standard deviation.
316    ///
317    /// ## Note
318    ///
319    /// This is a high-level alternative to [`Self::log_2_num_reg`]. Calling one
320    /// after the other invalidates the work done by the first one.
321    ///
322    /// # Arguments
323    /// * `rsd`: the relative standard deviation to be attained.
324    pub fn rsd(self, rsd: f64) -> Self {
325        self.log_2_num_reg(HyperLogLog::log_2_num_of_registers(rsd))
326    }
327
328    /// Sets the base-2 logarithm of the number of registers.
329    ///
330    /// ## Note
331    /// This is a low-level alternative to [`Self::rsd`]. Calling one after the
332    /// other invalidates the work done by the first one.
333    ///
334    /// # Arguments
335    /// * `log_2_num_registers`: the logarithm of the number of registers per
336    ///   estimator.
337    pub fn log_2_num_reg(mut self, log_2_num_registers: usize) -> Self {
338        self.log_2_num_registers = log_2_num_registers;
339        self
340    }
341
342    /// Sets the type `W` to use to represent backends.
343    ///
344    /// See the [`logic documentation`](HyperLogLog) for the limitations on the
345    /// choice of `W2`.
346    pub fn word_type<W2>(self) -> HyperLogLogBuilder<H, W2> {
347        HyperLogLogBuilder {
348            n: self.n,
349            build_hasher: self.build_hasher,
350            log_2_num_registers: self.log_2_num_registers,
351            _marker: std::marker::PhantomData,
352        }
353    }
354
355    /// Sets the upper bound on the number of elements.
356    pub fn num_elements(mut self, n: usize) -> Self {
357        self.n = n;
358        self
359    }
360
361    /// Sets the [`BuildHasher`] to use.
362    ///
363    /// Using this method you can select a specific hasher based on one or more
364    /// seeds.
365    pub fn build_hasher<H2>(self, build_hasher: H2) -> HyperLogLogBuilder<H2, W> {
366        HyperLogLogBuilder {
367            n: self.n,
368            log_2_num_registers: self.log_2_num_registers,
369            build_hasher,
370            _marker: std::marker::PhantomData,
371        }
372    }
373
374    /// Builds the logic.
375    ///
376    /// The type of objects the estimators keep track of is defined here by `T`,
377    /// but it is usually inferred by the compiler.
378    ///
379    /// # Errors
380    ///
381    /// Errors will be caused by consistency checks (at least 16 registers per
382    /// estimator, backend bits exactly divisible by `W::BITS`)
383    pub fn build<T>(self) -> Result<HyperLogLog<T, H, W>> {
384        let log_2_num_registers = self.log_2_num_registers;
385        let num_elements = self.n;
386
387        ensure!(
388            num_elements > 0,
389            "the upper bound on the number of distinct elements must be positive"
390        );
391
392        // This ensures estimators are at least 16-bit-aligned.
393        ensure!(
394            log_2_num_registers >= 4,
395            "the logarithm of the number of registers per estimator should be at least 4; got {}",
396            log_2_num_registers
397        );
398
399        let number_of_registers = 1 << log_2_num_registers;
400        let register_size = HyperLogLog::register_size(num_elements);
401        let sentinel_mask = 1 << ((1 << register_size) - 2);
402        let alpha = match log_2_num_registers {
403            4 => 0.673,
404            5 => 0.697,
405            6 => 0.709,
406            _ => 0.7213 / (1.0 + 1.079 / number_of_registers as f64),
407        };
408        let num_registers_minus_1 = (number_of_registers - 1) as HashResult;
409
410        let est_size_in_bits = number_of_registers * register_size;
411
412        // This ensures estimators are always aligned to W
413        ensure!(
414            est_size_in_bits % W::BITS == 0,
415            "W should allow estimator backends to be aligned. Use {} or smaller unsigned integer types",
416            min_alignment(est_size_in_bits)
417        );
418        let est_size_in_words = est_size_in_bits / W::BITS;
419
420        let mut msb = BitFieldVec::new(register_size, number_of_registers);
421        let mut lsb = BitFieldVec::new(register_size, number_of_registers);
422        let msb_w = W::ONE << (register_size - 1);
423        let lsb_w = W::ONE;
424        for i in 0..number_of_registers {
425            msb.set_value(i, msb_w);
426            lsb.set_value(i, lsb_w);
427        }
428
429        Ok(HyperLogLog {
430            num_registers: number_of_registers,
431            num_registers_minus_1,
432            log_2_num_registers,
433            register_size,
434            alpha_m_m: alpha * (number_of_registers as f64).powi(2),
435            sentinel_mask,
436            build_hasher: self.build_hasher,
437            msb_mask: msb.as_slice().into(),
438            lsb_mask: lsb.as_slice().into(),
439            words_per_estimator: est_size_in_words,
440            _marker: std::marker::PhantomData,
441        })
442    }
443}
444
445impl<T, H, W> std::fmt::Display for HyperLogLog<T, H, W> {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        write!(
448            f,
449            "HyperLogLog with relative standard deviation: {}% ({} registers/estimator, {} bits/register, {} bytes/estimator)",
450            100.0 * HyperLogLog::rel_std(self.log_2_num_registers),
451            self.num_registers,
452            self.register_size,
453            (self.num_registers * self.register_size) / 8
454        )
455    }
456}
457
458/// Performs a multiple precision subtraction, leaving the result in the first operand.
459/// The operands MUST have the same length.
460///
461/// # Arguments
462/// * `x`: the first operand. This will contain the final result.
463/// * `y`: the second operand that will be subtracted from `x`.
464#[inline(always)]
465pub(super) fn subtract<W: Word>(x: &mut [W], y: &[W]) {
466    debug_assert_eq!(x.len(), y.len());
467    let mut borrow = false;
468
469    for (x_word, &y) in x.iter_mut().zip(y.iter()) {
470        let mut x = *x_word;
471        if !borrow {
472            borrow = x < y;
473        } else if x != W::ZERO {
474            x = x.wrapping_sub(W::ONE);
475            borrow = x < y;
476        } else {
477            x = x.wrapping_sub(W::ONE);
478        }
479        *x_word = x.wrapping_sub(y);
480    }
481}
482
/// Merges the estimator `y` into the estimator `x` using word-parallel
/// bitwise operations and multiple-precision subtractions, without looping
/// over individual registers.
///
/// After the call, each register of `x` holds the maximum of its previous
/// value and the corresponding register of `y`.
///
/// # Arguments
/// * `x`: the destination backend; it will contain the result.
/// * `y`: the source backend.
/// * `msb_mask`: H_r, with only the highest bit of each register set.
/// * `lsb_mask`: L_r, with only the lowest bit of each register set.
/// * `acc`, `mask`: scratch buffers; they are cleared and refilled, so only
///   their allocations are reused.
/// * `register_size`: the number r of bits in each register.
///
/// All four slices must have the same length (checked only in debug builds).
///
/// # Panics
///
/// Panics if `x` is empty.
fn merge_hyperloglog_bitwise<W: Word>(
    mut x: impl AsMut<[W]>,
    y: impl AsRef<[W]>,
    msb_mask: impl AsRef<[W]>,
    lsb_mask: impl AsRef<[W]>,
    acc: &mut Vec<W>,
    mask: &mut Vec<W>,
    register_size: usize,
) {
    let x = x.as_mut();
    let y = y.as_ref();
    let msb_mask = msb_mask.as_ref();
    let lsb_mask = lsb_mask.as_ref();

    debug_assert_eq!(x.len(), y.len());
    debug_assert_eq!(x.len(), msb_mask.len());
    debug_assert_eq!(x.len(), lsb_mask.len());

    let register_size_minus_1 = register_size - 1;
    let num_words_minus_1 = x.len() - 1;
    let shift_register_size_minus_1 = W::BITS - register_size_minus_1;

    acc.clear();
    mask.clear();

    /* We work in two phases. Let H_r (msb_mask) be the mask with the
     * highest bit of each register (of size r) set, and L_r (lsb_mask)
     * be the mask with the lowest bit of each register set.
     * We describe the algorithm on a single word.
     *
     * In the first phase we perform an unsigned strict register-by-register
     * comparison of x and y, using the formula
     *
     * z = ((((y | H_r) - (x & !H_r)) | (y ^ x)) ^ (y | !x)) & H_r
     *
     * Then, we generate a register-by-register mask of all ones or
     * all zeroes, depending on the result of the comparison, using the
     * formula
     *
     * (((z >> r-1 | H_r) - L_r) | H_r) ^ z
     *
     * At that point, it is trivial to select from x and y the right values.
     */

    // We load y | H_r into the accumulator.
    acc.extend(
        y.iter()
            .zip(msb_mask)
            .map(|(&y_word, &msb_word)| y_word | msb_word),
    );

    // We load x & !H_r into mask as temporary storage.
    mask.extend(
        x.iter()
            .zip(msb_mask)
            .map(|(&x_word, &msb_word)| x_word & !msb_word),
    );

    // We subtract x & !H_r, using mask as temporary storage.
    // Every register of y | H_r is at least 2^(r-1), which exceeds every
    // register of x & !H_r, so no borrow crosses a register boundary and the
    // multiple-precision subtraction acts register by register.
    subtract(acc, mask);

    // We OR with y ^ x, XOR with (y | !x), and finally AND with H_r.
    // Afterwards acc holds z: the highest bit of a register is set exactly
    // when that register of x is strictly greater than the one of y, and all
    // other bits are zero.
    acc.iter_mut()
        .zip(x.iter())
        .zip(y.iter())
        .zip(msb_mask.iter())
        .for_each(|(((acc_word, &x_word), &y_word), &msb_word)| {
            *acc_word = ((*acc_word | (y_word ^ x_word)) ^ (y_word | !x_word)) & msb_word
        });

    // We shift by register_size - 1 places and put the result into mask.
    // This moves each flag bit of z to the lowest bit of its register (taking
    // bits from the next word where a register straddles a word boundary),
    // and OR-ing H_r makes each register hold H_r + flag.
    {
        let (mask_last, mask_slice) = mask.split_last_mut().unwrap();
        let (&msb_last, msb_slice) = msb_mask.split_last().unwrap();
        mask_slice
            .iter_mut()
            .zip(acc[0..num_words_minus_1].iter())
            .zip(acc[1..].iter())
            .zip(msb_slice.iter())
            .rev()
            .for_each(|(((mask_word, &acc_word), &next_acc_word), &msb_word)| {
                // W is always unsigned so the shift is always with a 0
                *mask_word = (acc_word >> register_size_minus_1)
                    | (next_acc_word << shift_register_size_minus_1)
                    | msb_word
            });
        // The last word has no following word to take bits from.
        *mask_last = (acc[num_words_minus_1] >> register_size_minus_1) | msb_last;
    }

    // We subtract L_r from mask.
    // Each register becomes H_r - 1 + flag, which is never negative, so
    // again no borrow crosses a register boundary.
    subtract(mask, lsb_mask);

    // We OR with H_r and XOR with the accumulator.
    // Afterwards mask is all ones in the registers where y >= x and all zeros
    // in the registers where x > y.
    mask.iter_mut()
        .zip(msb_mask.iter())
        .zip(acc.iter())
        .for_each(|((mask_word, &msb_word), &acc_word)| {
            *mask_word = (*mask_word | msb_word) ^ acc_word
        });

    // Finally, we use mask to select the right bits from x and y and store the result.
    // x ^ ((x ^ y) & mask) takes y where mask is set and keeps x elsewhere,
    // i.e., it stores the register-wise maximum.
    x.iter_mut()
        .zip(y.iter())
        .zip(mask.iter())
        .for_each(|((x_word, &y_word), &mask_word)| {
            *x_word = *x_word ^ ((*x_word ^ y_word) & mask_word);
        });
}