use std::hash::{DefaultHasher, Hash, Hasher};
use crate::math;
use crate::{DeserializeError, FORMAT_VERSION, MAGIC, MergeError};
use crate::{MAX_P, MIN_P, T};
/// Value passed as the `d` argument to the `math` helpers and written into
/// the serialized header.
const D: u32 = 20;
/// Mask selecting the low 28 bits of a `u32`; every register is 28 bits wide.
const REGISTER_MASK: u32 = 0x0FFF_FFFF;
/// Serialized header length in bytes: 4 magic bytes, then one byte each for
/// the format version, `T`, `D` and the precision `p`.
const HEADER_LEN: usize = 8;
/// Bytes occupied by one pair of packed registers (2 x 28 bits = 56 bits).
const BYTES_PER_PAIR: usize = 7;
/// Sketch state: `2^p` registers of 28 bits each, packed two per 7 bytes,
/// together with the bookkeeping for a running martingale estimate.
#[derive(Debug, Clone)]
pub struct ExaLogLog {
    /// Precision parameter; the sketch holds `1 << p` registers.
    p: u32,
    /// Packed register bytes. Registers `2j` and `2j + 1` share the 7-byte
    /// chunk starting at `7 * j`; the even one uses the low 28 bits of that
    /// little-endian chunk and the odd one the high 28 bits.
    storage: Box<[u8]>,
    /// Running martingale estimate; grows by `1 / mu` whenever an insert
    /// changes a register. Set to NaN once `martingale_invalid` is raised by
    /// a merge or by deserialization.
    martingale: f64,
    /// Starts at `1.0` and shrinks by `h(old) - h(new)` on every register
    /// change (presumably the current state-change probability; confirm
    /// against `math::h`).
    mu: f64,
    /// `true` after a merge or after deserialization, when the martingale
    /// bookkeeping no longer describes the register contents.
    martingale_invalid: bool,
}
impl ExaLogLog {
pub fn new(p: u32) -> Self {
assert!(
(MIN_P..=MAX_P).contains(&p),
"precision p={p} out of range [{MIN_P}, {MAX_P}]"
);
let m = 1usize << p;
let storage_bytes = (m / 2) * BYTES_PER_PAIR;
Self {
p,
storage: vec![0u8; storage_bytes].into_boxed_slice(),
martingale: 0.0,
mu: 1.0,
martingale_invalid: false,
}
}
pub fn precision(&self) -> u32 {
self.p
}
pub fn num_registers(&self) -> usize {
1 << self.p
}
pub fn register_bytes(&self) -> usize {
self.storage.len()
}
pub fn d_parameter() -> u32 {
D
}
#[inline]
pub fn get_register(&self, i: usize) -> u32 {
debug_assert!(i < self.num_registers());
let chunk_off = (i >> 1) * BYTES_PER_PAIR;
if i & 1 == 0 {
let bytes = [
self.storage[chunk_off],
self.storage[chunk_off + 1],
self.storage[chunk_off + 2],
self.storage[chunk_off + 3],
];
u32::from_le_bytes(bytes) & REGISTER_MASK
} else {
let bytes = [
self.storage[chunk_off + 3],
self.storage[chunk_off + 4],
self.storage[chunk_off + 5],
self.storage[chunk_off + 6],
];
u32::from_le_bytes(bytes) >> 4
}
}
#[inline]
fn set_register(&mut self, i: usize, value: u32) {
debug_assert!(i < self.num_registers());
debug_assert!(
value <= REGISTER_MASK,
"register value {value:#x} exceeds 28 bits"
);
let chunk_off = (i >> 1) * BYTES_PER_PAIR;
let v = value & REGISTER_MASK;
if i & 1 == 0 {
self.storage[chunk_off] = v as u8;
self.storage[chunk_off + 1] = (v >> 8) as u8;
self.storage[chunk_off + 2] = (v >> 16) as u8;
let high4 = self.storage[chunk_off + 3] & 0xF0;
self.storage[chunk_off + 3] = high4 | ((v >> 24) as u8 & 0x0F);
} else {
let low4 = self.storage[chunk_off + 3] & 0x0F;
self.storage[chunk_off + 3] = low4 | ((v << 4) as u8 & 0xF0);
self.storage[chunk_off + 4] = (v >> 4) as u8;
self.storage[chunk_off + 5] = (v >> 12) as u8;
self.storage[chunk_off + 6] = (v >> 20) as u8;
}
}
pub fn iter_registers(&self) -> impl Iterator<Item = u32> + '_ {
(0..self.num_registers()).map(move |i| self.get_register(i))
}
pub fn add_hash(&mut self, hash: u64) {
let (i, k) = math::hash_to_register_k(hash, self.p);
let r = self.get_register(i);
let new_r = math::apply_insert(r, k, D);
if new_r != r {
if !self.martingale_invalid {
self.martingale += 1.0 / self.mu;
self.mu -= math::h(r, self.p, D) - math::h(new_r, self.p, D);
if self.mu < 1e-300 {
self.mu = 1e-300;
}
}
self.set_register(i, new_r);
}
}
pub fn add<H: Hash + ?Sized>(&mut self, item: &H) {
let mut hasher = DefaultHasher::new();
item.hash(&mut hasher);
self.add_hash(hasher.finish());
}
pub fn estimate(&self) -> f64 {
self.estimate_ml()
}
pub fn estimate_ml(&self) -> f64 {
let (alpha, beta) = math::compute_alpha_beta(self.iter_registers(), self.p, D);
math::solve_ml(alpha, &beta, self.p)
}
pub fn estimate_martingale(&self) -> Option<f64> {
if self.martingale_invalid {
None
} else {
Some(self.martingale)
}
}
pub fn merge(&mut self, other: &Self) -> Result<(), MergeError> {
if self.p != other.p {
return Err(MergeError::PrecisionMismatch {
lhs: self.p,
rhs: other.p,
});
}
for i in 0..self.num_registers() {
let merged = math::merge_register(self.get_register(i), other.get_register(i), D);
self.set_register(i, merged);
}
self.martingale_invalid = true;
self.martingale = f64::NAN;
self.mu = f64::NAN;
Ok(())
}
pub fn clear(&mut self) {
for b in self.storage.iter_mut() {
*b = 0;
}
self.martingale = 0.0;
self.mu = 1.0;
self.martingale_invalid = false;
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(HEADER_LEN + self.storage.len());
out.extend_from_slice(&MAGIC);
out.push(FORMAT_VERSION);
out.push(T as u8);
out.push(D as u8);
out.push(self.p as u8);
out.extend_from_slice(&self.storage);
out
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, DeserializeError> {
if bytes.len() < HEADER_LEN {
return Err(DeserializeError::TooShort {
got: bytes.len(),
need: HEADER_LEN,
});
}
if bytes[0..4] != MAGIC {
return Err(DeserializeError::BadMagic);
}
if bytes[4] != FORMAT_VERSION {
return Err(DeserializeError::UnsupportedVersion(bytes[4]));
}
let t = bytes[5];
let d = bytes[6];
if u32::from(t) != T || u32::from(d) != D {
return Err(DeserializeError::ParameterMismatch { t, d });
}
let p = bytes[7];
if !(MIN_P..=MAX_P).contains(&u32::from(p)) {
return Err(DeserializeError::InvalidPrecision(p));
}
let m = 1usize << p;
let expected_len = HEADER_LEN + (m / 2) * BYTES_PER_PAIR;
if bytes.len() != expected_len {
return Err(DeserializeError::LengthMismatch {
got: bytes.len(),
expected: expected_len,
});
}
let storage: Box<[u8]> = bytes[HEADER_LEN..].to_vec().into_boxed_slice();
Ok(Self {
p: u32::from(p),
storage,
martingale: f64::NAN,
mu: f64::NAN,
martingale_invalid: true,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::h;
    use crate::{DeserializeError, FORMAT_VERSION, MergeError};

    /// SplitMix64 mixing step, used as a cheap deterministic hash source.
    fn splitmix64(mut x: u64) -> u64 {
        x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
        x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        x ^ (x >> 31)
    }

    /// Every register must read back exactly the value written to it.
    #[test]
    fn pack_unpack_roundtrip_for_each_register() {
        let p = 6;
        let m = 1usize << p;
        // Index-dependent 28-bit pattern so neighbouring registers differ.
        let pattern =
            |i: usize| ((0xA5A5A5u32.wrapping_mul(i as u32 + 1)) ^ 0x0DEADBu32) & REGISTER_MASK;
        let mut s = ExaLogLog::new(p);
        for i in 0..m {
            s.set_register(i, pattern(i));
        }
        for i in 0..m {
            assert_eq!(
                s.get_register(i),
                pattern(i),
                "register {i} round-trip failed (m={m})"
            );
        }
    }

    /// Overwriting the even registers must leave the odd ones (which share a
    /// byte with them) intact.
    #[test]
    fn writing_register_does_not_disturb_neighbors() {
        let p = 8;
        let mut s = ExaLogLog::new(p);
        for i in 0..s.num_registers() {
            s.set_register(i, REGISTER_MASK);
        }
        for i in (0..s.num_registers()).step_by(2) {
            s.set_register(i, 0);
        }
        for i in 0..s.num_registers() {
            let expected = if i % 2 == 0 { 0 } else { REGISTER_MASK };
            assert_eq!(s.get_register(i), expected, "register {i}");
        }
    }

    #[test]
    fn empty_sketch_estimates_zero() {
        let s = ExaLogLog::new(12);
        assert_eq!(s.estimate(), 0.0);
    }

    /// Re-inserting the same hash must touch exactly one register.
    #[test]
    fn idempotent_inserts() {
        let mut s = ExaLogLog::new(12);
        for _ in 0..1000 {
            s.add_hash(0xDEAD_BEEF_CAFE_BABE);
        }
        let changed = s.iter_registers().filter(|&r| r != 0).count();
        assert_eq!(changed, 1);
    }

    /// Whenever an insert changes a register, `h` of that register must
    /// strictly decrease.
    #[test]
    fn h_strictly_decreases_on_real_state_change() {
        let p = 10;
        let mut s = ExaLogLog::new(p);
        // The snapshot taken after one insert is the "before" state of the
        // next, so only one full register scan is needed per iteration.
        let mut before: Vec<u32> = s.iter_registers().collect();
        for i in 0..200_000u64 {
            s.add_hash(splitmix64(i));
            let after: Vec<u32> = s.iter_registers().collect();
            for (j, (&old_r, &new_r)) in before.iter().zip(&after).enumerate() {
                if old_r != new_r {
                    let h_old = h(old_r, p, D);
                    let h_new = h(new_r, p, D);
                    assert!(h_new < h_old, "register {j}: h {h_old} → {h_new}");
                }
            }
            before = after;
        }
    }

    #[test]
    fn ml_estimate_within_error_bounds() {
        let p = 12;
        for &n in &[100u64, 1_000, 10_000, 100_000, 1_000_000] {
            let mut s = ExaLogLog::new(p);
            for i in 0..n {
                s.add_hash(splitmix64(i));
            }
            let est = s.estimate_ml();
            let rel_err = (est - n as f64).abs() / n as f64;
            assert!(rel_err < 0.05, "n={n}: est={est}, rel_err={rel_err}");
        }
    }

    #[test]
    fn ml_and_martingale_agree() {
        let p = 12;
        let n = 50_000u64;
        let mut s = ExaLogLog::new(p);
        for i in 0..n {
            s.add_hash(splitmix64(i));
        }
        let mart = s.estimate_martingale().unwrap();
        let ml = s.estimate_ml();
        let rel_diff = (mart - ml).abs() / n as f64;
        assert!(
            rel_diff < 0.02,
            "martingale={mart}, ml={ml}, rel_diff={rel_diff}"
        );
    }

    /// Merging two sketches must give the same registers as inserting both
    /// streams into one sketch.
    #[test]
    fn merge_disjoint_recovers_union() {
        let p = 12;
        let mut a = ExaLogLog::new(p);
        let mut b = ExaLogLog::new(p);
        let mut combined = ExaLogLog::new(p);
        for i in 0..50_000u64 {
            a.add_hash(splitmix64(i));
            combined.add_hash(splitmix64(i));
        }
        for i in 50_000..100_000u64 {
            b.add_hash(splitmix64(i));
            combined.add_hash(splitmix64(i));
        }
        a.merge(&b).unwrap();
        for i in 0..a.num_registers() {
            assert_eq!(
                a.get_register(i),
                combined.get_register(i),
                "register {i} differs"
            );
        }
        assert_eq!(a.estimate_martingale(), None);
        let est = a.estimate();
        let rel_err = (est - 100_000.0).abs() / 100_000.0;
        assert!(rel_err < 0.05, "est={est}, rel_err={rel_err}");
    }

    #[test]
    fn serialize_roundtrip() {
        let p = 12;
        let mut s = ExaLogLog::new(p);
        for i in 0..50_000u64 {
            s.add_hash(splitmix64(i));
        }
        let est = s.estimate_ml();
        let bytes = s.to_bytes();
        // Spelled out with literals on purpose: an independent check of the
        // wire format (8-byte header, 7 bytes per register pair).
        let expected_size = 8 + (1 << p) / 2 * 7;
        assert_eq!(bytes.len(), expected_size);
        let restored = ExaLogLog::from_bytes(&bytes).unwrap();
        for i in 0..restored.num_registers() {
            assert_eq!(restored.get_register(i), s.get_register(i));
        }
        assert!((restored.estimate_ml() - est).abs() < 1e-6);
        // The martingale state is not serialized.
        assert_eq!(restored.estimate_martingale(), None);
    }

    /// Checks the storage footprint of 3.5 bytes per register. The "43 %"
    /// figure in the name is not verified by this test.
    #[test]
    fn memory_is_43_percent_smaller_than_hll_6bit() {
        for p in [8u32, 10, 12, 14] {
            let m = 1usize << p;
            let s = ExaLogLog::new(p);
            assert_eq!(s.register_bytes(), m * 7 / 2);
        }
    }

    /// `clear` must zero every register and re-enable the martingale.
    #[test]
    fn clear_resets_registers_and_martingale() {
        let mut s = ExaLogLog::new(8);
        for i in 0..1_000u64 {
            s.add_hash(splitmix64(i));
        }
        s.clear();
        assert!(s.iter_registers().all(|r| r == 0));
        assert_eq!(s.estimate_martingale(), Some(0.0));
    }

    /// A precision mismatch is reported and leaves the receiver untouched.
    #[test]
    fn merge_rejects_precision_mismatch() {
        let mut a = ExaLogLog::new(8);
        let b = ExaLogLog::new(10);
        a.add_hash(splitmix64(1));
        let before: Vec<u32> = a.iter_registers().collect();
        assert!(matches!(
            a.merge(&b),
            Err(MergeError::PrecisionMismatch { lhs: 8, rhs: 10 })
        ));
        assert_eq!(a.iter_registers().collect::<Vec<_>>(), before);
        assert!(a.estimate_martingale().is_some());
    }

    /// Malformed inputs are rejected with the matching error variant.
    #[test]
    fn from_bytes_rejects_malformed_input() {
        let valid = ExaLogLog::new(8).to_bytes();

        assert!(matches!(
            ExaLogLog::from_bytes(&[]),
            Err(DeserializeError::TooShort {
                got: 0,
                need: HEADER_LEN
            })
        ));

        let mut bad_magic = valid.clone();
        bad_magic[0] ^= 0xFF;
        assert!(matches!(
            ExaLogLog::from_bytes(&bad_magic),
            Err(DeserializeError::BadMagic)
        ));

        let mut bad_version = valid.clone();
        bad_version[4] = FORMAT_VERSION.wrapping_add(1);
        assert!(matches!(
            ExaLogLog::from_bytes(&bad_version),
            Err(DeserializeError::UnsupportedVersion(_))
        ));

        let truncated = &valid[..valid.len() - 1];
        assert!(matches!(
            ExaLogLog::from_bytes(truncated),
            Err(DeserializeError::LengthMismatch { .. })
        ));
    }
}