use std::hash::{Hash, Hasher};
use packed_simd::{f64x8, u8x16};
use seahash::SeaHasher;
#[cfg(feature = "serde_support")]
use serde::{de::Deserializer, Deserialize, Serialize, Serializer};
#[cfg(feature = "serde_support")]
use crate::serde::{serialize_registers, CompressedRegistersVisitor};
use crate::{M, P};
/// A HyperLogLog cardinality sketch backed by `M` one-byte registers.
///
/// Each register holds the largest `rho` value observed among the items routed to it.
#[derive(Clone, Debug)]
pub struct HyperLogLog {
    /// Register array; an item's register index is taken from the low bits of its hash.
    pub registers: Box<[u8; M]>,
}
impl HyperLogLog {
#[inline(always)]
fn get_alpha() -> f64 {
match M {
16 => 0.673,
32 => 0.697,
64 => 0.709,
_ => 0.7213 / (1.0 + 1.079 / M as f64),
}
}
pub fn new() -> Self {
Self {
registers: Box::new([0; M]),
}
}
#[inline(always)]
pub fn add<T: Hash>(&mut self, item: T) {
let mut hasher = SeaHasher::new();
item.hash(&mut hasher);
let hashed_value = hasher.finish() as usize;
let j = hashed_value & (M - 1);
let w = hashed_value >> P;
let rho = w.leading_zeros() as u8 + 1;
self.registers[j] = std::cmp::max(self.registers[j], rho);
}
#[inline(always)]
pub fn estimate(&self) -> f64 {
let len = self.registers.len();
let simd_iteration_count = len / 8;
let mut z = f64x8::splat(0.0);
for i in 0..simd_iteration_count {
z += f64x8::new(
2f64.powi(-i32::from(self.registers[i * 8])),
2f64.powi(-i32::from(self.registers[i * 8 + 1])),
2f64.powi(-i32::from(self.registers[i * 8 + 2])),
2f64.powi(-i32::from(self.registers[i * 8 + 3])),
2f64.powi(-i32::from(self.registers[i * 8 + 4])),
2f64.powi(-i32::from(self.registers[i * 8 + 5])),
2f64.powi(-i32::from(self.registers[i * 8 + 6])),
2f64.powi(-i32::from(self.registers[i * 8 + 7])),
);
}
let rem = len % 8;
if rem != 0 {
let mut remainder = f64x8::splat(0.0);
for i in 0..rem {
remainder += f64x8::splat(
2f64.powi(-i32::from(self.registers[simd_iteration_count * 8 + i])),
);
}
z += remainder;
}
let raw_estimate = Self::get_alpha() * (M * M) as f64 / z.sum();
let num_zeros = self.registers.iter().filter(|&&val| val == 0).count();
if num_zeros > 0 {
return M as f64 * (M as f64 / num_zeros as f64).ln();
}
raw_estimate
}
#[inline(always)]
pub fn merge(&mut self, other: &HyperLogLog) {
const CHUNKS: usize = M / 16;
unsafe {
let self_regs =
std::slice::from_raw_parts_mut(self.registers.as_mut_ptr() as *mut u8x16, CHUNKS);
let other_regs =
std::slice::from_raw_parts(other.registers.as_ptr() as *const u8x16, CHUNKS);
for i in 0..CHUNKS {
self_regs[i] = self_regs[i].max(other_regs[i]);
}
}
for i in (CHUNKS * 16)..M {
self.registers[i] = std::cmp::max(self.registers[i], other.registers[i]);
}
}
}
impl Default for HyperLogLog {
fn default() -> Self {
Self::new()
}
}
impl From<[u8; M]> for HyperLogLog {
    /// Moves the given register array onto the heap and uses it as the sketch's registers.
    fn from(registers: [u8; M]) -> Self {
        Self {
            registers: Box::new(registers),
        }
    }
}
#[cfg(feature = "serde_support")]
impl Serialize for HyperLogLog {
    /// Serializes the sketch by passing its register array to `serialize_registers`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serialize_registers(&self.registers, serializer)
    }
}
#[cfg(feature = "serde_support")]
impl<'de> Deserialize<'de> for HyperLogLog {
    /// Deserializes a sketch from a map, using `CompressedRegistersVisitor` to rebuild it.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = CompressedRegistersVisitor::new();
        deserializer.deserialize_map(visitor)
    }
}
#[cfg(test)]
mod tests {
    use crate::HyperLogLog;
    use std::collections::HashSet;

    /// Asserts that `estimate` is within 5% of `expected`, reporting both values on failure.
    fn assert_within_5_percent(estimate: f64, expected: f64) {
        let error = (estimate - expected).abs();
        assert!(
            error < expected * 0.05,
            "estimate {} is not within 5% of {}",
            estimate,
            expected
        );
    }

    #[test]
    fn add_and_estimate_unique_elements() {
        let mut hll = HyperLogLog::new();
        for i in 0..10_000 {
            hll.add(&i);
        }
        assert_within_5_percent(hll.estimate(), 10_000.0);
    }

    #[test]
    fn estimate_with_duplicates() {
        let mut hll = HyperLogLog::new();
        // Re-adding the same items must not change the estimate.
        for _ in 0..100 {
            for i in 0..10_000 {
                hll.add(&i);
            }
        }
        assert_within_5_percent(hll.estimate(), 10_000.0);
    }

    #[test]
    fn compare_with_hashset() {
        let mut hll = HyperLogLog::new();
        let mut set = HashSet::new();
        for i in 0..10_000 {
            let item = format!("item_{}", i);
            hll.add(&item);
            set.insert(item);
        }
        assert_within_5_percent(hll.estimate(), set.len() as f64);
    }

    #[test]
    fn test_merge() {
        let mut hll1 = HyperLogLog::new();
        hll1.add(1);
        hll1.add(2);
        let mut hll2 = HyperLogLog::new();
        hll2.add(3);
        hll2.add(4);
        hll1.merge(&hll2);
        assert_eq!(hll1.estimate().round() as u32, 4);
    }
}