use crate::math;
use crate::traits::{CardinalitySketch, ErrorBounds, MergeError, Sketch};
use xxhash_rust::xxh3::xxh3_64;
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// HyperLogLog sketch state: a precision parameter, one byte-sized register
/// per bucket, and a running count of insert calls.
#[derive(Clone, Debug)]
pub struct HyperLogLog {
/// Number of leading hash bits used to pick a register; the constructor
/// allocates `1 << precision` registers.
precision: u8,
/// One register per bucket, each holding the largest rank (`rho`) recorded
/// for that bucket so far; `0` means the bucket has not been hit.
registers: Vec<u8>,
/// Number of insert calls performed (duplicates included), plus the counts
/// of any sketches merged into this one.
num_inserts: u64,
}
impl HyperLogLog {
    /// Creates an empty sketch with `1 << precision` zeroed registers.
    ///
    /// # Panics
    ///
    /// Panics if `precision` is outside `4..=18`.
    pub fn new(precision: u8) -> Self {
        assert!(
            (4..=18).contains(&precision),
            "precision must be between 4 and 18"
        );
        let m = 1usize << precision;
        Self {
            precision,
            registers: vec![0u8; m],
            num_inserts: 0,
        }
    }

    /// Creates an empty sketch whose precision is chosen by
    /// `super::precision_for_error(target_error)`.
    ///
    /// # Panics
    ///
    /// Panics (via [`HyperLogLog::new`]) if the chosen precision is outside
    /// `4..=18`.
    pub fn with_error(target_error: f64) -> Self {
        let precision = super::precision_for_error(target_error);
        Self::new(precision)
    }

    /// Returns the precision this sketch was built with.
    pub fn precision(&self) -> u8 {
        self.precision
    }

    /// Returns the number of registers (`1 << precision`).
    pub fn num_registers(&self) -> usize {
        self.registers.len()
    }

    /// Inserts a string item by hashing its UTF-8 bytes.
    pub fn insert(&mut self, item: &str) {
        self.insert_bytes(item.as_bytes());
    }

    /// Inserts a byte-slice item by hashing it with `xxh3_64`.
    pub fn insert_bytes(&mut self, bytes: &[u8]) {
        let hash = xxh3_64(bytes);
        self.insert_hash(hash);
    }

    /// Inserts an already-computed 64-bit hash.
    ///
    /// The top `precision` bits select the register; the rank `rho` is the
    /// position of the first set bit in the remaining bits (leading zeros + 1).
    /// The register keeps the maximum rank it has seen.
    pub fn insert_hash(&mut self, hash: u64) {
        // Saturate rather than overflow: a plain `+= 1` panics in
        // overflow-checked builds and wraps otherwise, and a wrap to zero would
        // make `estimate()` report an empty sketch.
        self.num_inserts = self.num_inserts.saturating_add(1);

        let p = u32::from(self.precision);
        let idx = (hash >> (64 - p)) as usize;
        // Shift the index bits out so only the remaining `64 - p` bits are ranked.
        let w = hash << p;
        // When all remaining bits are zero, `leading_zeros()` is 64; cap the
        // rank at `64 - p + 1`.
        let max_rho = (64 - p + 1) as u8;
        let rho = ((w.leading_zeros() + 1) as u8).min(max_rho);
        if rho > self.registers[idx] {
            self.registers[idx] = rho;
        }
    }

    /// Computes the uncorrected estimate `alpha * m^2 / sum(2^-register)`.
    fn raw_estimate(&self) -> f64 {
        let m = self.registers.len() as f64;
        let sum: f64 = self
            .registers
            .iter()
            .map(|&r| math::exp2(-(r as f64)))
            .sum();
        let alpha = self.alpha_m();
        alpha * m * m / sum
    }

    /// Returns the `alpha` constant for the current register count: fixed
    /// values for 16, 32 and 64 registers, and `0.7213 / (1 + 1.079 / m)`
    /// otherwise.
    fn alpha_m(&self) -> f64 {
        let m = self.registers.len();
        match m {
            16 => 0.673,
            32 => 0.697,
            64 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / m as f64),
        }
    }

    /// Counts registers that are still zero.
    fn count_zeros(&self) -> usize {
        self.registers.iter().filter(|&&r| r == 0).count()
    }

    /// Computes the linear-counting estimate `m * ln(m / zeros)`.
    ///
    /// The only caller in this impl passes `zeros > 0`.
    fn linear_counting(&self, zeros: usize) -> f64 {
        let m = self.registers.len() as f64;
        m * math::ln(m / zeros as f64)
    }

    /// Applies the small-range correction to a raw estimate.
    ///
    /// With `threshold = 2.5 * m`: if `raw <= threshold`, at least one
    /// register is zero, and the linear-counting estimate is also
    /// `<= threshold`, the linear-counting estimate is returned. In every
    /// other case `raw` is returned unchanged.
    fn bias_correction(&self, raw: f64) -> f64 {
        let m = self.registers.len() as f64;
        let threshold = 2.5 * m;
        if raw <= threshold {
            let zeros = self.count_zeros();
            if zeros > 0 {
                let lc = self.linear_counting(zeros);
                if lc <= threshold {
                    return lc;
                }
            }
        }
        raw
    }
}
impl Sketch for HyperLogLog {
    type Item = [u8];

    /// Records `item` by delegating to `insert_bytes`.
    fn update(&mut self, item: &[u8]) {
        self.insert_bytes(item);
    }

    /// Merges `other` into `self` by taking the register-wise maximum and
    /// adding the insert counts.
    ///
    /// # Errors
    ///
    /// Returns `MergeError::IncompatibleConfig` when the two sketches have
    /// different precisions; `self` is left untouched in that case.
    fn merge(&mut self, other: &Self) -> Result<(), MergeError> {
        if self.precision != other.precision {
            return Err(MergeError::IncompatibleConfig {
                expected: format!("precision={}", self.precision),
                found: format!("precision={}", other.precision),
            });
        }
        for (mine, &theirs) in self.registers.iter_mut().zip(other.registers.iter()) {
            *mine = (*mine).max(theirs);
        }
        // Saturate instead of overflowing: a plain `+=` panics in
        // overflow-checked builds and wraps otherwise, and a wrapped count of
        // zero would make `estimate()` treat a populated sketch as empty.
        self.num_inserts = self.num_inserts.saturating_add(other.num_inserts);
        Ok(())
    }

    /// Zeroes every register and resets the insert count.
    fn clear(&mut self) {
        self.registers.fill(0);
        self.num_inserts = 0;
    }

    /// Returns the register count (one byte each) plus `size_of::<Self>()`.
    fn size_bytes(&self) -> usize {
        self.registers.len() + core::mem::size_of::<Self>()
    }

    /// Returns the number of recorded inserts (not the distinct-count estimate).
    fn count(&self) -> u64 {
        self.num_inserts
    }
}
impl CardinalitySketch for HyperLogLog {
    /// Returns the estimated number of distinct items, or `0.0` when no
    /// insert has been recorded.
    fn estimate(&self) -> f64 {
        match self.num_inserts {
            0 => 0.0,
            _ => self.bias_correction(self.raw_estimate()),
        }
    }

    /// Builds an interval around the estimate using a z-score picked from a
    /// step table on `confidence` (>= 0.99, >= 0.95, >= 0.90, >= 0.80, else 1.0).
    /// The lower bound is clamped at zero.
    fn error_bounds(&self, confidence: f64) -> ErrorBounds {
        let point = self.estimate();
        let std_err = self.relative_error();

        let z_score = if confidence >= 0.99 {
            2.576
        } else if confidence >= 0.95 {
            1.96
        } else if confidence >= 0.90 {
            1.645
        } else if confidence >= 0.80 {
            1.282
        } else {
            1.0
        };

        let half_width = z_score * std_err * point;
        let lower = (point - half_width).max(0.0);
        let upper = point + half_width;
        ErrorBounds::new(lower, point, upper, confidence)
    }

    /// Returns `1.04 / sqrt(m)` where `m` is the register count.
    fn relative_error(&self) -> f64 {
        let register_count = self.registers.len() as f64;
        1.04 / math::sqrt(register_count)
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for HyperLogLog {
    /// Writes the sketch as a three-field struct named `HyperLogLog`:
    /// `precision`, `registers`, `num_inserts`, in that order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let HyperLogLog {
            precision,
            registers,
            num_inserts,
        } = self;

        let mut out = serializer.serialize_struct("HyperLogLog", 3)?;
        out.serialize_field("precision", precision)?;
        out.serialize_field("registers", registers)?;
        out.serialize_field("num_inserts", num_inserts)?;
        out.end()
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for HyperLogLog {
    /// Reads a sketch from its `precision` / `registers` / `num_inserts`
    /// fields and validates it before constructing the value.
    ///
    /// # Errors
    ///
    /// Returns a custom deserialization error when:
    /// - `precision` is outside `4..=18`,
    /// - `registers` does not contain exactly `1 << precision` entries, or
    /// - any register exceeds `64 - precision + 1`, the largest rank that
    ///   `insert_hash` can store.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(serde::Deserialize)]
        struct HllData {
            precision: u8,
            registers: Vec<u8>,
            num_inserts: u64,
        }

        // Fully-qualified call so resolution does not depend on the
        // `Deserialize` trait having been imported into this module.
        let data = <HllData as serde::Deserialize>::deserialize(deserializer)?;

        if !(4..=18).contains(&data.precision) {
            return Err(serde::de::Error::custom(
                "precision must be between 4 and 18",
            ));
        }

        let expected_len = 1usize << data.precision;
        if data.registers.len() != expected_len {
            return Err(serde::de::Error::custom(format!(
                "invalid register length: expected {}, got {}",
                expected_len,
                data.registers.len()
            )));
        }

        // `precision` is already known to be in 4..=18, so this stays in
        // 47..=61 and cannot overflow a `u8`.
        let max_rho = 64 - data.precision + 1;
        if let Some(&bad) = data.registers.iter().find(|&&r| r > max_rho) {
            return Err(serde::de::Error::custom(format!(
                "invalid register value: {} exceeds maximum {} for precision {}",
                bad, max_rho, data.precision
            )));
        }

        Ok(HyperLogLog {
            precision: data.precision,
            registers: data.registers,
            num_inserts: data.num_inserts,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 10 000 distinct strings at precision 12 estimate strictly between
    /// 9 000 and 11 000.
    #[test]
    fn test_basic() {
        let mut sketch = HyperLogLog::new(12);
        (0..10000).for_each(|n| sketch.insert(&format!("item_{}", n)));
        let got = sketch.estimate();
        assert!(got > 9000.0 && got < 11000.0);
    }

    /// A fresh sketch estimates exactly zero.
    #[test]
    fn test_empty() {
        let sketch = HyperLogLog::new(12);
        assert_eq!(sketch.estimate(), 0.0);
    }

    /// Re-inserting one item many times still estimates roughly one.
    #[test]
    fn test_duplicates() {
        let mut sketch = HyperLogLog::new(12);
        (0..10000).for_each(|_| sketch.insert("same_item"));
        let got = sketch.estimate();
        assert!((0.5..=2.0).contains(&got));
    }

    /// Merging two disjoint 5 000-item sketches estimates about 10 000 and
    /// exceeds either input's estimate.
    #[test]
    fn test_merge() {
        let mut left = HyperLogLog::new(12);
        let mut right = HyperLogLog::new(12);
        (0..5000).for_each(|n| left.insert(&format!("a_{}", n)));
        (0..5000).for_each(|n| right.insert(&format!("b_{}", n)));

        let left_before = left.estimate();
        let right_before = right.estimate();

        left.merge(&right).unwrap();
        let combined = left.estimate();

        assert!(combined > left_before);
        assert!(combined > right_before);
        assert!(combined > 9000.0 && combined < 11000.0);
    }

    /// Sketches with different precisions refuse to merge.
    #[test]
    fn test_merge_incompatible() {
        let mut twelve = HyperLogLog::new(12);
        let fourteen = HyperLogLog::new(14);
        let outcome = twelve.merge(&fourteen);
        assert!(outcome.is_err());
    }

    /// Precision 14 reports itself and allocates 2^14 registers.
    #[test]
    fn test_precision() {
        let sketch = HyperLogLog::new(14);
        assert_eq!(sketch.precision(), 14);
        assert_eq!(sketch.num_registers(), 16384);
    }

    /// The 95% bounds bracket the estimate and overlap the true count's
    /// neighbourhood.
    #[test]
    fn test_error_bounds() {
        let mut sketch = HyperLogLog::new(14);
        (0..100000).for_each(|n| sketch.insert(&format!("item_{}", n)));

        let interval = sketch.error_bounds(0.95);
        assert!(interval.lower < interval.estimate);
        assert!(interval.estimate < interval.upper);
        assert!(interval.lower < 110000.0);
        assert!(interval.upper > 90000.0);
    }

    /// 100 distinct items estimate strictly between 80 and 120.
    #[test]
    fn test_small_cardinalities() {
        let mut sketch = HyperLogLog::new(12);
        (0..100).for_each(|n| sketch.insert(&format!("item_{}", n)));
        let got = sketch.estimate();
        assert!(got > 80.0 && got < 120.0);
    }

    /// `clear` returns both the estimate and the insert count to zero.
    #[test]
    fn test_clear() {
        let mut sketch = HyperLogLog::new(12);
        (0..1000).for_each(|n| sketch.insert(&format!("item_{}", n)));
        assert!(sketch.estimate() > 0.0);

        sketch.clear();
        assert_eq!(sketch.estimate(), 0.0);
        assert_eq!(sketch.count(), 0);
    }

    /// A 1% target error yields a precision of at least 13.
    #[test]
    fn test_with_error() {
        let sketch = HyperLogLog::with_error(0.01);
        assert!(sketch.precision() >= 13);
    }

    /// A hash whose non-index bits are all zero must not store a rank above
    /// `64 - 14 + 1`, and the resulting estimate stays small.
    #[test]
    fn test_rho_edge_case_all_zeros() {
        let mut sketch = HyperLogLog::new(14);
        // Top 14 bits set, every remaining bit clear.
        let hash_with_zero_suffix = 0xFFFC_0000_0000_0000u64;
        sketch.insert_hash(hash_with_zero_suffix);

        let max_rho: u8 = 64 - 14 + 1;
        let slot = (hash_with_zero_suffix >> (64 - 14)) as usize;
        let stored = sketch.registers[slot];
        assert!(
            stored <= max_rho,
            "rho {} exceeds max valid rho {}",
            stored,
            max_rho
        );

        let got = sketch.estimate();
        assert!((0.5..=5.0).contains(&got), "estimate={}", got);
    }

    /// For several precisions, a hash with only the index bits set never
    /// stores a rank above `64 - precision + 1`.
    #[test]
    fn test_rho_various_precisions() {
        for precision in [4u8, 10, 14, 18] {
            let mut sketch = HyperLogLog::new(precision);
            let shift = 64 - precision;
            // All index bits set, all remaining bits clear.
            let hash = ((1u64 << precision) - 1) << shift;
            sketch.insert_hash(hash);

            let max_rho = (64 - precision + 1) as u8;
            let slot = (hash >> shift) as usize;
            let stored = sketch.registers[slot];
            assert!(
                stored <= max_rho,
                "precision={}: rho {} exceeds max {}",
                precision,
                stored,
                max_rho
            );
        }
    }
}