Skip to main content

card_est_array/impls/
hyper_log_log.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Matteo Dell'Acqua
3 * SPDX-FileCopyrightText: 2025 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use num_traits::AsPrimitive;
9use std::hash::*;
10use std::{borrow::Borrow, f64::consts::LN_2};
11
12use crate::traits::Word;
13
// `DefaultWord` is the fixed-width unsigned integer with the same number of
// bits as a pointer on the compilation target (the fixed-size equivalent of
// `usize`). It is the default value of the word type parameter of
// `HyperLogLogBuilder`. Exactly one of the aliases below is compiled in.
#[cfg(target_pointer_width = "16")]
type DefaultWord = u16;
#[cfg(target_pointer_width = "32")]
type DefaultWord = u32;
#[cfg(target_pointer_width = "64")]
type DefaultWord = u64;
20
21use crate::traits::{EstimationLogic, MergeEstimationLogic, SliceEstimationLogic};
22
23use super::DefaultEstimator;
24
/// The type returned by the hash function.
///
/// This is the type of the value produced by [`BuildHasher::hash_one`]; the
/// masks that are combined with a hash (register-index mask and sentinel
/// mask) are stored with this type as well.
type HashResult = u64;
27
/// Estimator logic implementing the HyperLogLog algorithm.
///
/// Instances are built using [`HyperLogLogBuilder`], which provides convenient
/// ways to set the internal parameters.
///
/// Note that `T` can be any type satisfying the [`Hash`] trait. The parameter
/// `H` makes it possible to select a hashing algorithm, and `W` is the unsigned
/// type used to store backends.
///
/// An important constraint is that `W` must be able to represent exactly the
/// backend of an estimator. While usually `usize` will work (and it is the default
/// type chosen by [`new`](HyperLogLogBuilder::new)), with odd register sizes
/// and a small number of registers it might be necessary to select a smaller
/// type, resulting in slower merges. For example, using 16 5-bit registers one
/// needs to use `u16`, whereas for 16 6-bit registers `u32` will be sufficient.
#[derive(Debug, PartialEq)]
pub struct HyperLogLog<T, H, W> {
    /// The [`BuildHasher`] used to hash the elements added to an estimator.
    build_hasher: H,
    /// The size of a register, in bits.
    register_size: usize,
    /// The number of registers minus one; ANDed with a hash to obtain the
    /// index of the register to update.
    num_registers_minus_1: HashResult,
    /// The base-2 logarithm of the number of registers.
    log_2_num_registers: usize,
    /// A single-bit mask ORed into the shifted hash before counting its
    /// trailing zeros, which bounds the count so that it fits in a register.
    sentinel_mask: HashResult,
    /// The number of registers of an estimator.
    num_registers: usize,
    /// The number of words of type `W` in the backend of an estimator.
    pub(super) words_per_estimator: usize,
    /// The constant α multiplied by the square of the number of registers;
    /// it is the numerator of the raw estimate.
    alpha_m_m: f64,
    /// A backend-sized mask with the most significant bit of each register
    /// set; used when merging.
    msb_mask: Box<[W]>,
    /// A backend-sized mask with the least significant bit of each register
    /// set; used when merging.
    lsb_mask: Box<[W]>,
    /// Ties the logic to the element type `T` without storing any `T`.
    _marker: std::marker::PhantomData<T>,
}
57
58// We implement Clone manually because we do not want to require that T is
59// Clone.
60impl<T, H: Clone, W: Clone> Clone for HyperLogLog<T, H, W> {
61    fn clone(&self) -> Self {
62        Self {
63            build_hasher: self.build_hasher.clone(),
64            register_size: self.register_size,
65            num_registers_minus_1: self.num_registers_minus_1,
66            log_2_num_registers: self.log_2_num_registers,
67            sentinel_mask: self.sentinel_mask,
68            num_registers: self.num_registers,
69            words_per_estimator: self.words_per_estimator,
70            alpha_m_m: self.alpha_m_m,
71            msb_mask: self.msb_mask.clone(),
72            lsb_mask: self.lsb_mask.clone(),
73            _marker: std::marker::PhantomData,
74        }
75    }
76}
77
78impl<T, H: Clone, W: Word> HyperLogLog<T, H, W> {
79    /// Returns the value contained in a register of a given backend.
80    #[inline(always)]
81    fn get_register_unchecked(&self, backend: impl AsRef<[W]>, index: usize) -> W {
82        let backend = backend.as_ref();
83        let bits = W::BITS as usize;
84        let bit_width = self.register_size;
85        let mask = W::MAX >> (bits - bit_width);
86        let pos = index * bit_width;
87        let word_index = pos / bits;
88        let bit_index = pos % bits;
89
90        if bit_index + bit_width <= bits {
91            (unsafe { *backend.get_unchecked(word_index) } >> bit_index) & mask
92        } else {
93            ((unsafe { *backend.get_unchecked(word_index) } >> bit_index)
94                | (unsafe { *backend.get_unchecked(word_index + 1) } << (bits - bit_index)))
95                & mask
96        }
97    }
98
99    /// Sets the value contained in a register of a given backend.
100    #[inline(always)]
101    fn set_register_unchecked(&self, mut backend: impl AsMut<[W]>, index: usize, new_value: W) {
102        let backend = backend.as_mut();
103        let bits = W::BITS as usize;
104        let bit_width = self.register_size;
105        let mask = W::MAX >> (bits - bit_width);
106        let pos = index * bit_width;
107        let word_index = pos / bits;
108        let bit_index = pos % bits;
109
110        if bit_index + bit_width <= bits {
111            let mut word = unsafe { *backend.get_unchecked_mut(word_index) };
112            word &= !(mask << bit_index);
113            word |= new_value << bit_index;
114            unsafe { *backend.get_unchecked_mut(word_index) = word };
115        } else {
116            let mut word = unsafe { *backend.get_unchecked_mut(word_index) };
117            word &= (W::ONE << bit_index) - W::ONE;
118            word |= new_value << bit_index;
119            unsafe { *backend.get_unchecked_mut(word_index) = word };
120
121            let mut word = unsafe { *backend.get_unchecked_mut(word_index + 1) };
122            word &= !(mask >> (bits - bit_index));
123            word |= new_value >> (bits - bit_index);
124            unsafe { *backend.get_unchecked_mut(word_index + 1) = word };
125        }
126    }
127}
128
// Backends of this logic are slices of words, all of the same length.
impl<T: Hash, H: BuildHasher + Clone, W: Word + Into<u64>> SliceEstimationLogic<W>
    for HyperLogLog<T, H, W>
where
    u64: AsPrimitive<W>,
{
    /// Returns the number of words of type `W` in the backend of an
    /// estimator, as computed when the logic was built.
    #[inline(always)]
    fn backend_len(&self) -> usize {
        self.words_per_estimator
    }
}
139
140impl<T: Hash, H: BuildHasher + Clone, W: Word + Into<u64>> EstimationLogic for HyperLogLog<T, H, W>
141where
142    u64: AsPrimitive<W>,
143{
144    type Item = T;
145    type Backend = [W];
146    type Estimator<'a>
147        = DefaultEstimator<Self, &'a Self, Box<[W]>>
148    where
149        T: 'a,
150        W: 'a,
151        H: 'a;
152
153    fn new_estimator(&self) -> Self::Estimator<'_> {
154        Self::Estimator::new(
155            self,
156            vec![W::ZERO; self.words_per_estimator].into_boxed_slice(),
157        )
158    }
159
160    fn add(&self, backend: &mut Self::Backend, element: impl Borrow<T>) {
161        let x = self.build_hasher.hash_one(element.borrow());
162        let j = x & self.num_registers_minus_1;
163        let r =
164            ((x >> self.log_2_num_registers) | self.sentinel_mask).trailing_zeros() as HashResult;
165        let register = j as usize;
166
167        debug_assert!(r < (1 << self.register_size) - 1);
168        debug_assert!(register < self.num_registers);
169
170        let current_value = self.get_register_unchecked(&*backend, register);
171        let candidate_value = r + 1;
172        let new_value = std::cmp::max(current_value, candidate_value.as_());
173        if current_value != new_value {
174            self.set_register_unchecked(backend, register, new_value);
175        }
176    }
177
178    fn estimate(&self, backend: &[W]) -> f64 {
179        let mut harmonic_mean = 0.0;
180        let mut zeroes = 0;
181
182        for i in 0..self.num_registers {
183            let value: u64 = self.get_register_unchecked(backend, i).into();
184            if value == 0 {
185                zeroes += 1;
186            }
187            harmonic_mean += 1.0 / (1_u64 << value) as f64;
188        }
189
190        let mut estimate = self.alpha_m_m / harmonic_mean;
191        if zeroes != 0 && estimate < 2.5 * self.num_registers as f64 {
192            estimate = self.num_registers as f64 * (self.num_registers as f64 / zeroes as f64).ln();
193        }
194        estimate
195    }
196
197    #[inline(always)]
198    fn clear(&self, backend: &mut [W]) {
199        backend.fill(W::ZERO);
200    }
201
202    #[inline(always)]
203    fn set(&self, dst: &mut [W], src: &[W]) {
204        debug_assert_eq!(dst.len(), src.len());
205        dst.copy_from_slice(src);
206    }
207}
208
/// Helper for merge operations with [`HyperLogLog`] logic.
///
/// It holds the two scratch vectors used by the merge routine, so that they
/// can be reused across merges instead of being allocated each time. Both
/// are cleared and refilled at every merge.
pub struct HyperLogLogHelper<W> {
    /// Accumulator holding the intermediate results of the register-wise
    /// comparison.
    acc: Vec<W>,
    /// Temporary storage that ends up holding the register-wise selection
    /// mask.
    mask: Vec<W>,
}
214
impl<T: Hash, H: BuildHasher + Clone, W: Word + Into<u64>> MergeEstimationLogic
    for HyperLogLog<T, H, W>
where
    u64: AsPrimitive<W>,
{
    type Helper = HyperLogLogHelper<W>;

    /// Creates a helper whose two scratch vectors have the length of a
    /// backend.
    fn new_helper(&self) -> Self::Helper {
        HyperLogLogHelper {
            acc: vec![W::ZERO; self.words_per_estimator],
            mask: vec![W::ZERO; self.words_per_estimator],
        }
    }

    /// Merges the backend `src` into the backend `dst`, using `helper` as
    /// scratch space.
    ///
    /// The work is delegated to the broadword merge routine, which receives
    /// the precomputed most-significant-bit and least-significant-bit masks
    /// of this logic.
    #[inline(always)]
    fn merge_with_helper(&self, dst: &mut [W], src: &[W], helper: &mut Self::Helper) {
        merge_hyperloglog_bitwise(
            dst,
            src,
            self.msb_mask.as_ref(),
            self.lsb_mask.as_ref(),
            &mut helper.acc,
            &mut helper.mask,
            self.register_size,
        );
    }
}
242
/// Builds a [`HyperLogLog`] cardinality-estimator logic.
///
/// `H` is the type of the [`BuildHasher`] that will be handed to the logic,
/// and `W` is the word type used to store backends.
#[derive(Debug, Clone)]
pub struct HyperLogLogBuilder<H, W = DefaultWord> {
    /// The [`BuildHasher`] that will be moved into the logic.
    build_hasher: H,
    /// The base-2 logarithm of the number of registers per estimator.
    log_2_num_registers: usize,
    /// An upper bound on the number of distinct elements; it determines the
    /// register size.
    num_elements: usize,
    /// Records the word type `W` without storing any value of that type.
    _marker: std::marker::PhantomData<W>,
}
251
impl HyperLogLogBuilder<BuildHasherDefault<DefaultHasher>> {
    /// Creates a new builder for a [`HyperLogLog`] logic with the default word
    /// type (the fixed-size equivalent of `usize`).
    ///
    /// The builder starts with a [`BuildHasherDefault`] based on
    /// [`DefaultHasher`] and with 16 registers per estimator (that is, a
    /// base-2 logarithm of the number of registers equal to 4).
    ///
    /// # Arguments
    ///
    /// * `num_elements`: an upper bound on the number of distinct elements.
    ///
    /// # Panics
    ///
    /// If `num_elements` is zero.
    pub const fn new(num_elements: usize) -> Self {
        assert!(
            num_elements > 0,
            "the upper bound on the number of distinct elements must be positive"
        );
        Self {
            build_hasher: BuildHasherDefault::new(),
            log_2_num_registers: 4,
            num_elements,
            _marker: std::marker::PhantomData,
        }
    }
}
272
/// Returns the name of the widest unsigned integer type, among `u128`, `u64`,
/// `u32` and `u16`, whose width in bits divides `bits`, or `u8` if none does.
///
/// The result is used to suggest a suitable word type in the panic message
/// emitted when an estimator cannot be aligned to the chosen word type.
fn min_alignment(bits: usize) -> String {
    // Candidates from the widest to the narrowest: the first match wins.
    const CANDIDATES: [(usize, &str); 4] = [(128, "u128"), (64, "u64"), (32, "u32"), (16, "u16")];

    CANDIDATES
        .iter()
        .find(|&&(width, _)| bits % width == 0)
        .map_or("u8", |&(_, name)| name)
        .to_string()
}
287
288impl HyperLogLog<(), (), ()> {
289    /// Returns the logarithm of the number of registers per estimator that are
290    /// necessary to attain a given relative standard deviation.
291    ///
292    /// # Arguments
293    /// * `rsd`: the relative standard deviation to be attained.
294    pub fn log_2_num_of_registers(rsd: f64) -> usize {
295        ((1.106 / rsd).powi(2)).log2().ceil() as usize
296    }
297
298    /// Returns the relative standard deviation corresponding to a given number
299    /// of registers per estimator.
300    ///
301    /// # Arguments
302    ///
303    /// * `log_2_num_registers`: the logarithm of the number of registers per
304    ///   estimator.
305    pub fn rel_std(log_2_num_registers: usize) -> f64 {
306        let tmp = match log_2_num_registers {
307            4 => 1.106,
308            5 => 1.070,
309            6 => 1.054,
310            7 => 1.046,
311            _ => 1.04,
312        };
313        tmp / ((1 << log_2_num_registers) as f64).sqrt()
314    }
315
316    /// Returns the register size in bits, given an upper bound on the number of
317    /// distinct elements.
318    ///
319    /// # Arguments
320    /// * `num_elements`: an upper bound on the number of distinct elements.
321    pub fn register_size(num_elements: usize) -> usize {
322        std::cmp::max(
323            5,
324            (((num_elements as f64).ln() / LN_2) / LN_2).ln().ceil() as usize,
325        )
326    }
327}
328
impl<H, W: Word> HyperLogLogBuilder<H, W> {
    /// Sets the desired relative standard deviation.
    ///
    /// ## Note
    ///
    /// This is a high-level alternative to [`Self::log_2_num_reg`]. Calling one
    /// after the other invalidates the work done by the first one.
    ///
    /// # Arguments
    /// * `rsd`: the relative standard deviation to be attained.
    ///
    /// # Panics
    ///
    /// If the resulting number of registers is less than 16 (i.e., `rsd` is
    /// too large).
    pub fn rsd(self, rsd: f64) -> Self {
        self.log_2_num_reg(HyperLogLog::log_2_num_of_registers(rsd))
    }

    /// Sets the base-2 logarithm of the number of registers.
    ///
    /// ## Note
    /// This is a low-level alternative to [`Self::rsd`]. Calling one after the
    /// other invalidates the work done by the first one.
    ///
    /// # Arguments
    /// * `log_2_num_registers`: the logarithm of the number of registers per
    ///   estimator.
    ///
    /// # Panics
    ///
    /// If `log_2_num_registers` is less than 4.
    pub const fn log_2_num_reg(mut self, log_2_num_registers: usize) -> Self {
        assert!(
            log_2_num_registers >= 4,
            "the logarithm of the number of registers per estimator should be at least 4"
        );
        self.log_2_num_registers = log_2_num_registers;
        self
    }

    /// Sets the type `W` to use to represent backends.
    ///
    /// Note that the returned builder will have a different type if `W2` is
    /// different from `W`.
    ///
    /// See the [`logic documentation`](HyperLogLog) for the limitations on the
    /// choice of `W2`.
    pub fn word_type<W2>(self) -> HyperLogLogBuilder<H, W2> {
        HyperLogLogBuilder {
            num_elements: self.num_elements,
            build_hasher: self.build_hasher,
            log_2_num_registers: self.log_2_num_registers,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the upper bound on the number of elements.
    ///
    /// # Arguments
    ///
    /// * `num_elements`: an upper bound on the number of distinct elements.
    ///
    /// # Panics
    ///
    /// If `num_elements` is zero.
    pub const fn num_elements(mut self, num_elements: usize) -> Self {
        assert!(
            num_elements > 0,
            "the upper bound on the number of distinct elements must be positive"
        );
        self.num_elements = num_elements;
        self
    }

    /// Sets the [`BuildHasher`] to use.
    ///
    /// Using this method you can select a specific hasher based on one or more
    /// seeds.
    ///
    /// Note that the returned builder will have a different type if `H2` is
    /// different from `H`.
    pub fn build_hasher<H2>(self, build_hasher: H2) -> HyperLogLogBuilder<H2, W> {
        HyperLogLogBuilder {
            num_elements: self.num_elements,
            log_2_num_registers: self.log_2_num_registers,
            build_hasher,
            _marker: std::marker::PhantomData,
        }
    }

    /// Builds the logic.
    ///
    /// The type of objects the estimators keep track of is defined here by `T`,
    /// but it is usually inferred by the compiler.
    ///
    /// The register size is derived from the upper bound on the number of
    /// distinct elements; the number of registers comes from
    /// [`Self::log_2_num_reg`] or [`Self::rsd`].
    ///
    /// # Panics
    ///
    /// If the estimator size in bits is not divisible by the bit width of `W`.
    pub fn build<T>(self) -> HyperLogLog<T, H, W> {
        let bits = W::BITS as usize;
        let log_2_num_registers = self.log_2_num_registers;
        let num_elements = self.num_elements;
        let number_of_registers = 1 << log_2_num_registers;
        let register_size = HyperLogLog::register_size(num_elements);
        // A single bit in position 2^register_size - 2: ORing it into a
        // hash bounds the number of trailing zeros that can be counted.
        let sentinel_mask = 1 << ((1 << register_size) - 2);
        // The constant α of the estimator: tabulated for 16, 32 and 64
        // registers, given by a closed formula otherwise.
        let alpha = match log_2_num_registers {
            4 => 0.673,
            5 => 0.697,
            6 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / number_of_registers as f64),
        };
        // The number of registers is a power of two, so this is also a mask
        // selecting the lowest log_2_num_registers bits.
        let num_registers_minus_1 = (number_of_registers - 1) as HashResult;

        let est_size_in_bits = number_of_registers * register_size;

        // This ensures estimators are always aligned to W
        assert!(
            est_size_in_bits % bits == 0,
            "W should allow estimator backends to be aligned. Use {} or smaller unsigned integer types",
            min_alignment(est_size_in_bits)
        );
        let est_size_in_words = est_size_in_bits / bits;

        // Backend-sized masks with the highest and the lowest bit of each
        // register set, respectively; they are used when merging.
        let msb_mask = build_register_mask(
            est_size_in_words,
            register_size,
            W::ONE << (register_size - 1),
        );
        let lsb_mask = build_register_mask(est_size_in_words, register_size, W::ONE);

        HyperLogLog {
            num_registers: number_of_registers,
            num_registers_minus_1,
            log_2_num_registers,
            register_size,
            alpha_m_m: alpha * (number_of_registers as f64).powi(2),
            sentinel_mask,
            build_hasher: self.build_hasher,
            msb_mask,
            lsb_mask,
            words_per_estimator: est_size_in_words,
            _marker: std::marker::PhantomData,
        }
    }
}
468
469impl<T, H, W> std::fmt::Display for HyperLogLog<T, H, W> {
470    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        write!(
472            f,
473            "HyperLogLog with relative standard deviation: {}% ({} registers/estimator, {} bits/register, {} bytes/estimator)",
474            100.0 * HyperLogLog::rel_std(self.log_2_num_registers),
475            self.num_registers,
476            self.register_size,
477            (self.num_registers * self.register_size) / 8
478        )
479    }
480}
481
482/// Builds a mask of `num_words` words by repeating a `register_size`-bit
483/// pattern across all register positions.
484fn build_register_mask<W: Word>(num_words: usize, register_size: usize, pattern: W) -> Box<[W]> {
485    let bits = W::BITS as usize;
486    let total_bits = num_words * bits;
487    let mut result = vec![W::ZERO; num_words];
488    let mut bit_pos = 0;
489    while bit_pos < total_bits {
490        let word_index = bit_pos / bits;
491        let bit_index = bit_pos % bits;
492        result[word_index] |= pattern << bit_index;
493        if bit_index + register_size > bits && word_index + 1 < num_words {
494            result[word_index + 1] |= pattern >> (bits - bit_index);
495        }
496        bit_pos += register_size;
497    }
498    result.into_boxed_slice()
499}
500
501/// Performs a multiple precision subtraction, leaving the result in the first operand.
502/// The operands MUST have the same length.
503///
504/// # Arguments
505/// * `x`: the first operand. This will contain the final result.
506/// * `y`: the second operand that will be subtracted from `x`.
507#[inline(always)]
508pub(super) fn subtract<W: Word>(x: &mut [W], y: &[W]) {
509    debug_assert_eq!(x.len(), y.len());
510    let mut borrow = false;
511
512    for (x_word, &y) in x.iter_mut().zip(y.iter()) {
513        let mut x = *x_word;
514        if !borrow {
515            borrow = x < y;
516        } else if x != W::ZERO {
517            x = x.wrapping_sub(W::ONE);
518            borrow = x < y;
519        } else {
520            x = x.wrapping_sub(W::ONE);
521        }
522        *x_word = x.wrapping_sub(y);
523    }
524}
525
/// Merges the backend `y` into the backend `x` using broadword operations,
/// that is, working on whole words rather than on single registers.
///
/// A register-by-register comparison of `x` and `y` is turned into a mask of
/// all ones or all zeroes per register, which is then used to pick, for each
/// register, the value from `x` or the one from `y` (presumably the larger
/// one, as expected of a HyperLogLog merge; see the derivation in the body).
///
/// All slices must have the same length and must not be empty (the function
/// panics on empty slices).
///
/// # Arguments
/// * `x`: the first backend. This will contain the final result.
/// * `y`: the second backend.
/// * `msb_mask`: the mask with the highest bit of each register set.
/// * `lsb_mask`: the mask with the lowest bit of each register set.
/// * `acc`: scratch vector; it is cleared and refilled, so its content on
///   entry is irrelevant.
/// * `mask`: scratch vector, handled like `acc`.
/// * `register_size`: the size of a register, in bits.
fn merge_hyperloglog_bitwise<W: Word>(
    mut x: impl AsMut<[W]>,
    y: impl AsRef<[W]>,
    msb_mask: impl AsRef<[W]>,
    lsb_mask: impl AsRef<[W]>,
    acc: &mut Vec<W>,
    mask: &mut Vec<W>,
    register_size: usize,
) {
    let x = x.as_mut();
    let y = y.as_ref();
    let msb_mask = msb_mask.as_ref();
    let lsb_mask = lsb_mask.as_ref();

    debug_assert_eq!(x.len(), y.len());
    debug_assert_eq!(x.len(), msb_mask.len());
    debug_assert_eq!(x.len(), lsb_mask.len());

    let register_size_minus_1 = register_size - 1;
    let num_words_minus_1 = x.len() - 1;
    // Left shift that brings the lowest bits of the next word into the
    // highest positions of the current one during the multiword right shift.
    let shift_register_size_minus_1 = W::BITS as usize - register_size_minus_1;

    // The scratch vectors are refilled from scratch below.
    acc.clear();
    mask.clear();

    /* We work in two phases. Let H_r (msb_mask) be the mask with the
     * highest bit of each register (of size r) set, and L_r (lsb_mask)
     * be the mask with the lowest bit of each register set.
     * We describe the algorithm on a single word.
     *
     * In the first phase we perform an unsigned strict register-by-register
     * comparison of x and y, using the formula
     *
     * z = ((((y | H_r) - (x & !H_r)) | (y ^ x)) ^ (y | !x)) & H_r
     *
     * Then, we generate a register-by-register mask of all ones or
     * all zeroes, depending on the result of the comparison, using the
     * formula
     *
     * (((z >> r-1 | H_r) - L_r) | H_r) ^ z
     *
     * At that point, it is trivial to select from x and y the right values.
     */

    // We load y | H_r into the accumulator.
    acc.extend(
        y.iter()
            .zip(msb_mask)
            .map(|(&y_word, &msb_word)| y_word | msb_word),
    );

    // We load x & !H_r into mask as temporary storage.
    mask.extend(
        x.iter()
            .zip(msb_mask)
            .map(|(&x_word, &msb_word)| x_word & !msb_word),
    );

    // We subtract x & !H_r, using mask as temporary storage
    subtract(acc, mask);

    // We OR with y ^ x, XOR with (y | !x), and finally AND with H_r.
    acc.iter_mut()
        .zip(x.iter())
        .zip(y.iter())
        .zip(msb_mask.iter())
        .for_each(|(((acc_word, &x_word), &y_word), &msb_word)| {
            *acc_word = ((*acc_word | (y_word ^ x_word)) ^ (y_word | !x_word)) & msb_word
        });

    // We shift by register_size - 1 places and put the result into mask.
    {
        // The last word has no successor to take bits from, so it is
        // handled separately after the others.
        let (mask_last, mask_slice) = mask.split_last_mut().unwrap();
        let (&msb_last, msb_slice) = msb_mask.split_last().unwrap();
        mask_slice
            .iter_mut()
            .zip(acc[0..num_words_minus_1].iter())
            .zip(acc[1..].iter())
            .zip(msb_slice.iter())
            .rev()
            .for_each(|(((mask_word, &acc_word), &next_acc_word), &msb_word)| {
                // W is always unsigned so the shift is always with a 0
                *mask_word = (acc_word >> register_size_minus_1)
                    | (next_acc_word << shift_register_size_minus_1)
                    | msb_word
            });
        *mask_last = (acc[num_words_minus_1] >> register_size_minus_1) | msb_last;
    }

    // We subtract L_r from mask.
    subtract(mask, lsb_mask);

    // We OR with H_r and XOR with the accumulator.
    mask.iter_mut()
        .zip(msb_mask.iter())
        .zip(acc.iter())
        .for_each(|((mask_word, &msb_word), &acc_word)| {
            *mask_word = (*mask_word | msb_word) ^ acc_word
        });

    // Finally, we use mask to select the right bits from x and y and store the result.
    // Where a mask bit is set the bit of y is taken, otherwise the bit of x is kept.
    x.iter_mut()
        .zip(y.iter())
        .zip(mask.iter())
        .for_each(|((x_word, &y_word), &mask_word)| {
            *x_word = *x_word ^ ((*x_word ^ y_word) & mask_word);
        });
}