use core::{hash::BuildHasher, marker::PhantomData};
use alloc::{vec, vec::Vec};
use crate::{hash::DefaultHashBuilder, Error};
/// Smallest accepted precision: `1 << 4` = 16 registers.
const MIN_PRECISION: u8 = 4;
/// Largest accepted precision: `1 << 18` = 262,144 registers.
const MAX_PRECISION: u8 = 18;
/// A HyperLogLog sketch estimating how many distinct `T` values have been inserted.
///
/// The sketch stores `2^precision` one-byte registers. No `T` values are ever stored;
/// `T` only types the [`insert`](HyperLogLog::insert) argument, so it may be unsized
/// (e.g. `HyperLogLog<str>`). `S` builds the hasher applied to inserted items.
///
/// `Clone` and `Debug` are implemented by hand below instead of derived: `derive`
/// would add `T: Clone` / `T: Debug` bounds (despite `T` living only in a
/// `PhantomData`), which makes e.g. `HyperLogLog<str>` impossible to clone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HyperLogLog<T: ?Sized, S = DefaultHashBuilder> {
    /// One register per bucket, holding the largest rank observed for that bucket.
    registers: Vec<u8>,
    /// Number of hash bits used to pick a register; construction sets
    /// `registers.len() == 1 << precision`.
    ///
    /// NOTE(review): serde deserialization does not re-check that invariant; a
    /// malformed payload with too few registers would make `insert` panic on indexing.
    precision: u8,
    /// Hasher builder. Skipped by serde, so deserialization rebuilds it via `Default`;
    /// presumably that default must hash identically to the serializing instance for
    /// the restored registers to stay meaningful — confirm for custom `S`.
    #[cfg_attr(feature = "serde", serde(skip))]
    hasher: S,
    /// Ties the sketch to `T` without storing one; `fn(&T)` keeps the type
    /// `Send`/`Sync` regardless of `T` and allows unsized `T`.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<fn(&T)>,
}

impl<T: ?Sized, S: Clone> Clone for HyperLogLog<T, S> {
    fn clone(&self) -> Self {
        Self {
            registers: self.registers.clone(),
            precision: self.precision,
            hasher: self.hasher.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T: ?Sized, S: core::fmt::Debug> core::fmt::Debug for HyperLogLog<T, S> {
    // Mirrors the derived output field-for-field so existing Debug text is unchanged.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("HyperLogLog")
            .field("registers", &self.registers)
            .field("precision", &self.precision)
            .field("hasher", &self.hasher)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<T: ?Sized> HyperLogLog<T, DefaultHashBuilder> {
pub fn new(precision: u8) -> Result<Self, Error> {
Self::with_hasher(precision, DefaultHashBuilder)
}
pub fn with_error_rate(error: f64) -> Result<Self, Error> {
if !(error.is_finite() && error > 0.0 && error < 1.0) {
return Err(Error::InvalidParameter {
param: "error",
reason: "must be a finite value in the open interval (0.0, 1.0)",
});
}
let registers = libm::pow(1.04 / error, 2.0);
let raw = libm::ceil(libm::log2(registers)) as i64;
let precision = raw.clamp(i64::from(MIN_PRECISION), i64::from(MAX_PRECISION)) as u8;
Self::new(precision)
}
}
impl<T: ?Sized, S: BuildHasher> HyperLogLog<T, S> {
    /// Creates an empty sketch with `2^precision` registers and the given hasher builder.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] when `precision` lies outside `4..=18`.
    pub fn with_hasher(precision: u8, hasher: S) -> Result<Self, Error> {
        if !(MIN_PRECISION..=MAX_PRECISION).contains(&precision) {
            return Err(Error::InvalidParameter {
                param: "precision",
                reason: "must be in the range 4..=18",
            });
        }
        let num_registers = 1usize << precision;
        Ok(Self {
            registers: vec![0u8; num_registers],
            precision,
            hasher,
            _marker: PhantomData,
        })
    }

    /// Adds `item` to the sketch.
    ///
    /// The top `precision` bits of the 64-bit hash select a register, which keeps the
    /// maximum rank (1 + leading zeros of the remaining bits) seen so far. Inserting an
    /// item that hashes the same as an earlier one leaves the sketch unchanged.
    pub fn insert(&mut self, item: &T)
    where
        T: core::hash::Hash,
    {
        let hash = self.hasher.hash_one(item);
        let p = u32::from(self.precision);
        // `p` is in 4..=18, so the shift is in range and `index < 2^p == registers.len()`.
        let index = (hash >> (64 - p)) as usize;
        // Shift out the index bits and fill the vacated low bits with ones so that
        // `leading_zeros` is at most `64 - p`; the rank therefore never exceeds `65 - p`
        // and always fits in a `u8`.
        let remainder = (hash << p) | ((1u64 << p) - 1);
        let rank = (remainder.leading_zeros() + 1) as u8;
        let slot = &mut self.registers[index];
        if rank > *slot {
            *slot = rank;
        }
    }

    /// Estimates the number of distinct items inserted (or merged in).
    ///
    /// Computes the raw estimate `alpha(m) * m^2 / sum(2^-register)`. When that estimate
    /// is at most `2.5 * m` and at least one register is still zero, it falls back to
    /// linear counting, `m * ln(m / zeros)`. An empty sketch returns 0.
    #[must_use]
    pub fn count(&self) -> u64 {
        let m = self.registers.len() as f64;
        let mut sum = 0.0f64;
        let mut zeros = 0u64;
        for &register in &self.registers {
            sum += libm::exp2(-f64::from(register));
            if register == 0 {
                zeros += 1;
            }
        }
        let raw = alpha(self.registers.len()) * m * m / sum;
        if raw <= 2.5 * m && zeros > 0 {
            // Small-range correction: linear counting over the still-empty registers.
            let linear = m * libm::log(m / zeros as f64);
            return libm::round(linear) as u64;
        }
        libm::round(raw) as u64
    }

    /// Returns `true` when every register is zero.
    ///
    /// Every insert writes a rank of at least 1, so this holds exactly when nothing has
    /// been inserted or merged in since construction or the last [`clear`](Self::clear).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.registers.iter().all(|&register| register == 0)
    }

    /// Returns the precision, i.e. the number of hash bits used to select a register;
    /// the sketch has `2^precision` registers.
    #[inline]
    #[must_use]
    pub fn precision(&self) -> u8 {
        self.precision
    }

    /// Resets every register to zero, keeping the precision, hasher and allocation.
    pub fn clear(&mut self) {
        self.registers.fill(0);
    }

    /// Merges `other` into `self` by taking the register-wise maximum, so that `self`
    /// afterwards estimates the cardinality of the union of both inputs.
    ///
    /// Only the precisions are compared; the hashers are not. Presumably both sketches
    /// must hash identically for the result to be meaningful — this is not checked.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IncompatibleParameters`] when the precisions differ; `self` is
    /// left unchanged in that case.
    pub fn merge(&mut self, other: &Self) -> Result<(), Error> {
        if self.precision != other.precision {
            return Err(Error::IncompatibleParameters);
        }
        for (dst, &src) in self.registers.iter_mut().zip(other.registers.iter()) {
            *dst = (*dst).max(src);
        }
        Ok(())
    }
}
/// Bias-correction constant `alpha(m)` applied to the raw estimate in `count`.
///
/// Register counts of 16, 32 and 64 use fixed tabulated values; every other count uses
/// the closed form `0.7213 / (1 + 1.079 / m)`.
fn alpha(m: usize) -> f64 {
    const TABULATED: [(usize, f64); 3] = [(16, 0.673), (32, 0.697), (64, 0.709)];
    TABULATED
        .iter()
        .find_map(|&(size, value)| (size == m).then_some(value))
        .unwrap_or_else(|| 0.7213 / (1.0 + 1.079 / m as f64))
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Builds a precision-14 sketch containing every value yielded by `values`.
    fn sketch_of(values: impl IntoIterator<Item = u32>) -> HyperLogLog<u32> {
        let mut sketch = HyperLogLog::<u32>::new(14).unwrap();
        for value in values {
            sketch.insert(&value);
        }
        sketch
    }

    #[test]
    fn test_new_rejects_out_of_range_precision() {
        for bad_precision in [3u8, 19] {
            let outcome = HyperLogLog::<&str>::new(bad_precision);
            assert!(matches!(outcome, Err(Error::InvalidParameter { .. })));
        }
    }

    #[test]
    fn test_with_error_rate_picks_precision() {
        let one_percent = HyperLogLog::<&str>::with_error_rate(0.01).unwrap();
        assert_eq!(one_percent.precision(), 14);

        let capped = HyperLogLog::<&str>::with_error_rate(0.0001).unwrap();
        assert_eq!(capped.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_empty_counts_zero() {
        let fresh = HyperLogLog::<u32>::new(14).unwrap();
        assert!(fresh.is_empty());
        assert_eq!(fresh.count(), 0);
    }

    #[test]
    fn test_small_cardinality_is_exact_ish() {
        let estimate = sketch_of(0..10).count();
        assert!(
            (9..=11).contains(&estimate),
            "estimate {estimate} off for n=10"
        );
    }

    #[test]
    fn test_large_cardinality_within_error() {
        let n = 100_000u32;
        let estimate = sketch_of(0..n).count();
        let truth = f64::from(n);
        let error = (estimate as f64 - truth).abs() / truth;
        assert!(
            error < 0.03,
            "relative error {error} too high (est {estimate})"
        );
    }

    #[test]
    fn test_idempotent_inserts() {
        let mut repeated = HyperLogLog::<str>::new(14).unwrap();
        (0..1_000).for_each(|_| repeated.insert("same"));
        assert_eq!(repeated.count(), 1);
    }

    #[test]
    fn test_clear() {
        let mut sketch = sketch_of(0..100);
        sketch.clear();
        assert!(sketch.is_empty());
        assert_eq!(sketch.count(), 0);
    }

    #[test]
    fn test_merge_estimates_union() {
        let mut combined = sketch_of(0..1_000);
        let overlapping = sketch_of(500..1_500);
        combined.merge(&overlapping).unwrap();
        let estimate = combined.count();
        assert!(
            (1_400..=1_600).contains(&estimate),
            "union estimate {estimate}"
        );
    }

    #[test]
    fn test_merge_rejects_incompatible() {
        let mut target = HyperLogLog::<u32>::new(14).unwrap();
        let mismatched = HyperLogLog::<u32>::new(12).unwrap();
        assert_eq!(
            target.merge(&mismatched),
            Err(Error::IncompatibleParameters)
        );
    }
}