use std::hash::{Hash, Hasher};
use packed_simd::{f64x8, u32x2, u8x16};
use seahash::SeaHasher;
#[cfg(feature = "serde_support")]
use crate::serde::{serialize_registers, CompressedRegistersVisitor};
#[cfg(feature = "serde_support")]
use serde::{de::Deserializer, Deserialize, Serialize, Serializer};
use crate::{ALPHA, EMPTY_REGISTERS, M, P};
/// HyperLogLog++ cardinality sketch.
///
/// Stores one `u8` register per bucket, `M` buckets in total (`M` comes from the crate
/// root). The array is boxed, so the sketch value itself is a single pointer.
#[derive(Clone, Debug)]
pub struct HyperLogLogPlusPlus {
    /// Largest rank observed per bucket; a value of 0 marks a bucket never updated.
    pub registers: Box<[u8; M]>,
}
impl HyperLogLogPlusPlus {
#[inline(always)]
pub fn new() -> Self {
Self {
registers: Box::new(unsafe { EMPTY_REGISTERS.clone() }),
}
}
#[inline(always)]
pub fn add<T: Hash>(&mut self, item: T) {
let mut h = SeaHasher::default();
item.hash(&mut h);
let mut hash = h.finish();
for _ in 0..2 {
let vec_hash = u32x2::new(
(hash & (M as u64 - 1)) as u32,
((hash >> 32) & (M as u64 - 1)) as u32,
);
let vec_w = u32x2::new((hash >> P) as u32, (hash >> (32 + P)) as u32);
let vec_rank = vec_w.min_element().leading_zeros() as u8 + 1;
let max_index = vec_hash.extract(0) as usize;
if self.registers[max_index] < vec_rank {
self.registers[max_index] = vec_rank;
}
hash = hash.wrapping_shr(64);
}
}
#[inline(always)]
pub fn estimate(&self) -> f64 {
let mut acc_sum = f64x8::splat(0.0);
let len = self.registers.len();
let simd_iteration_count = len / 8;
for i in 0..simd_iteration_count {
let chunk = self.registers[i * 8..(i + 1) * 8]
.iter()
.map(|&x| x as f64)
.collect::<Vec<f64>>();
let vector = f64x8::from_slice_unaligned(&chunk);
acc_sum += f64x8::splat(2.0).powf(-vector);
}
let rem = len % 8;
if rem > 0 {
let chunk = self.registers[len - rem..]
.iter()
.map(|&x| x as f64)
.collect::<Vec<f64>>();
let vector = f64x8::from_slice_unaligned(&chunk);
acc_sum += f64x8::splat(2.0).powf(-vector);
}
let harmonic_mean: f64 = 1.0 / acc_sum.sum();
let approx_cardinality: f64 = ALPHA * (M * M) as f64 * harmonic_mean;
let zero_reg_count: f64 = self.registers.iter().filter(|&rank| *rank == 0).count() as f64;
if approx_cardinality <= 2.5 * M as f64 && zero_reg_count > 0.0 {
M as f64 * (M as f64 / zero_reg_count).ln()
} else {
approx_cardinality
}
}
#[inline(always)]
pub fn merge(&mut self, other: &HyperLogLogPlusPlus) {
const CHUNKS: usize = M / 16;
unsafe {
let self_regs =
std::slice::from_raw_parts_mut(self.registers.as_mut_ptr() as *mut u8x16, CHUNKS);
let other_regs =
std::slice::from_raw_parts(other.registers.as_ptr() as *const u8x16, CHUNKS);
for i in 0..CHUNKS {
self_regs[i] = self_regs[i].max(other_regs[i]);
}
}
for i in (CHUNKS * 16)..M {
self.registers[i] = std::cmp::max(self.registers[i], other.registers[i]);
}
}
}
impl Default for HyperLogLogPlusPlus {
fn default() -> Self {
Self::new()
}
}
impl From<[u8; M]> for HyperLogLogPlusPlus {
    /// Wraps an existing register array as a sketch, boxing it as-is without validation.
    fn from(registers: [u8; M]) -> Self {
        let boxed = Box::new(registers);
        Self { registers: boxed }
    }
}
#[cfg(feature = "serde_support")]
impl Serialize for HyperLogLogPlusPlus {
    /// Serializes only the register array, delegating the encoding to `serialize_registers`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let registers = &self.registers;
        serialize_registers(registers, serializer)
    }
}
#[cfg(feature = "serde_support")]
impl<'de> Deserialize<'de> for HyperLogLogPlusPlus {
    /// Reads the registers as a map, driven by `CompressedRegistersVisitor`.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = CompressedRegistersVisitor::new();
        deserializer.deserialize_map(visitor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nanorand::Rng;

    #[test]
    fn test_add_and_estimate() {
        let mut hllpp = HyperLogLogPlusPlus::new();
        for i in 0..1000 {
            hllpp.add(i);
        }
        let estimate = hllpp.estimate();
        assert!(
            (950..1050).contains(&(estimate as usize)),
            "Estimate {} out of expected range",
            estimate
        );
    }

    #[test]
    fn test_add_same_value_multiple_times() {
        let mut hllpp = HyperLogLogPlusPlus::new();
        for _ in 0..1000 {
            hllpp.add(500);
        }
        let estimate = hllpp.estimate();
        assert_eq!(
            estimate as usize, 1,
            "Estimate should be 1 for identical elements"
        );
    }

    #[test]
    fn test_empty_estimate() {
        let hllpp = HyperLogLogPlusPlus::new();
        let estimate = hllpp.estimate();
        assert_eq!(estimate, 0.0, "Empty HLL++ should estimate to 0");
    }

    #[test]
    fn test_random_values() {
        let mut hllpp = HyperLogLogPlusPlus::new();
        let mut rng = nanorand::tls_rng();
        let mut unique_values = std::collections::HashSet::new();
        for _ in 0..100_000 {
            let val = rng.generate_range(0..50_000);
            unique_values.insert(val);
            hllpp.add(val);
        }
        let estimate = hllpp.estimate();
        let expected = unique_values.len() as f64;
        // The value is reported in the failure message instead of a leftover `dbg!`.
        assert!(
            expected * 0.9 <= estimate && estimate <= expected * 1.1,
            "Estimate {} out of expected range around {}",
            estimate,
            expected
        );
    }

    #[test]
    fn test_large_number_of_values() {
        let mut hllpp = HyperLogLogPlusPlus::new();
        for i in 0..500_000 {
            hllpp.add(i);
        }
        let estimate = hllpp.estimate();
        assert!(
            (490_000..510_000).contains(&(estimate as usize)),
            "Estimate {} out of expected range",
            estimate
        );
    }

    #[test]
    fn test_merge() {
        let mut hll1 = HyperLogLogPlusPlus::new();
        hll1.add(1);
        hll1.add(2);
        let mut hll2 = HyperLogLogPlusPlus::new();
        hll2.add(3);
        hll2.add(4);
        hll1.merge(&hll2);
        assert_eq!(hll1.estimate().round() as u32, 4);
    }

    #[test]
    fn test_merge_with_copy_is_idempotent() {
        let mut hllpp = HyperLogLogPlusPlus::new();
        for i in 0..1000 {
            hllpp.add(i);
        }
        let snapshot = hllpp.clone();
        hllpp.merge(&snapshot);
        assert_eq!(&hllpp.registers[..], &snapshot.registers[..]);
    }

    #[test]
    fn test_merge_takes_elementwise_max() {
        let mut left = HyperLogLogPlusPlus::new();
        let mut right = HyperLogLogPlusPlus::new();
        for i in 0..600 {
            left.add(i);
        }
        for i in 300..900 {
            right.add(i);
        }

        let mut merged = left.clone();
        merged.merge(&right);
        for ((&m, &l), &r) in merged
            .registers
            .iter()
            .zip(left.registers.iter())
            .zip(right.registers.iter())
        {
            assert_eq!(m, l.max(r));
        }

        // Merging is order-independent.
        let mut reversed = right.clone();
        reversed.merge(&left);
        assert_eq!(&merged.registers[..], &reversed.registers[..]);
    }
}