use super::DefaultEstimator;
use super::hyper_log_log::apply_correction;
use crate::traits::{
EstimationLogic, MergeEstimationLogic, SliceEstimationLogic, assert_backend_len,
};
use std::borrow::Borrow;
use std::hash::*;
use std::{fmt, marker::PhantomData};
/// A HyperLogLog estimation logic that stores every register in a full byte.
///
/// `T` is the type of the counted elements and `H` the hasher builder used to hash them.
/// `BETA` is forwarded as a const parameter to `apply_correction` when estimating
/// (presumably selecting the β bias-correction variant — see that function).
///
/// `Debug` and `PartialEq` are implemented by hand rather than derived: `T` only appears
/// inside `PhantomData`, and deriving would needlessly require `T: Debug` / `T: PartialEq`.
pub struct HyperLogLog8<T, H, const BETA: bool = true> {
    /// Hasher builder used to hash added elements.
    build_hasher: H,
    /// `num_regs - 1`: mask selecting the register index from the low bits of a hash.
    num_regs_minus_1: u64,
    /// Base-2 logarithm of the number of registers.
    log2_num_regs: u32,
    /// Number of registers, which is also the backend length in bytes.
    num_regs: usize,
    /// `alpha * num_regs^2`, the bias constant passed to the correction step.
    alpha_m_m: f64,
    _marker: PhantomData<T>,
}

impl<T, H: fmt::Debug, const BETA: bool> fmt::Debug for HyperLogLog8<T, H, BETA> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same field order and names as `#[derive(Debug)]` would produce.
        f.debug_struct("HyperLogLog8")
            .field("build_hasher", &self.build_hasher)
            .field("num_regs_minus_1", &self.num_regs_minus_1)
            .field("log2_num_regs", &self.log2_num_regs)
            .field("num_regs", &self.num_regs)
            .field("alpha_m_m", &self.alpha_m_m)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T, H: PartialEq, const BETA: bool> PartialEq for HyperLogLog8<T, H, BETA> {
    fn eq(&self, other: &Self) -> bool {
        // `_marker` is zero-sized and always equal, so it is not compared.
        self.build_hasher == other.build_hasher
            && self.num_regs_minus_1 == other.num_regs_minus_1
            && self.log2_num_regs == other.log2_num_regs
            && self.num_regs == other.num_regs
            && self.alpha_m_m == other.alpha_m_m
    }
}
/// Hand-written so that cloning only requires `H: Clone`, never `T: Clone`.
impl<T, H: Clone, const BETA: bool> Clone for HyperLogLog8<T, H, BETA> {
    fn clone(&self) -> Self {
        let Self {
            build_hasher,
            num_regs_minus_1,
            log2_num_regs,
            num_regs,
            alpha_m_m,
            ..
        } = self;
        Self {
            build_hasher: build_hasher.clone(),
            num_regs_minus_1: *num_regs_minus_1,
            log2_num_regs: *log2_num_regs,
            num_regs: *num_regs,
            alpha_m_m: *alpha_m_m,
            _marker: PhantomData,
        }
    }
}
impl<T, H, const BETA: bool> HyperLogLog8<T, H, BETA> {
    /// Returns the base-2 logarithm of the number of registers per estimator.
    ///
    /// This is a plain field read, so it places no bounds on `T` or `H`.
    pub fn log2_num_regs(&self) -> u32 {
        self.log2_num_regs
    }
}
impl<T, H, const BETA: bool> SliceEstimationLogic<u8> for HyperLogLog8<T, H, BETA>
where
    T: Hash,
    H: BuildHasher + Clone,
{
    /// A backend is a slice holding exactly one byte per register.
    #[inline(always)]
    fn backend_len(&self) -> usize {
        self.num_regs
    }
}
impl<T, H, const BETA: bool> EstimationLogic for HyperLogLog8<T, H, BETA>
where
    T: Hash,
    H: BuildHasher + Clone,
{
    type Item = T;
    /// One `u8` per register.
    type Backend = [u8];
    type Estimator<'a>
        = DefaultEstimator<Self, &'a Self, Box<[u8]>>
    where
        T: 'a,
        H: 'a;

    /// Allocates a new estimator whose registers are all zero.
    fn new_estimator(&self) -> Self::Estimator<'_> {
        let registers = vec![0u8; self.num_regs].into_boxed_slice();
        Self::Estimator::new(self, registers)
    }

    /// Hashes `element` and raises the register it selects to at least its rank.
    fn add(&self, backend: &mut Self::Backend, element: impl Borrow<T>) {
        assert_backend_len!(self, backend);
        let hash = self.build_hasher.hash_one(element.borrow());
        // The low `log2_num_regs` bits select the register. Rotating right moves them to
        // the top, so trailing zeros are counted on the remaining hash bits first.
        let index = (hash & self.num_regs_minus_1) as usize;
        let rank = hash.rotate_right(self.log2_num_regs).trailing_zeros() as u8 + 1;
        debug_assert!(index < self.num_regs);
        let slot = &mut backend[index];
        *slot = (*slot).max(rank);
    }

    /// Computes the cardinality estimate for `backend`.
    ///
    /// Accumulates `2^-register` over all registers (in order) and counts the zero
    /// registers, then delegates to `apply_correction` with this logic's parameters.
    fn estimate(&self, backend: &[u8]) -> f64 {
        assert_backend_len!(self, backend);
        let (inverse_sum, empty) = backend.iter().fold(
            (0.0_f64, 0_usize),
            |(sum, empty), &reg| {
                // Biased exponent `1023 - reg` with a zero mantissa is exactly 2^-reg.
                let term = f64::from_bits((1023 - u64::from(reg)) << 52);
                (sum + term, empty + usize::from(reg == 0))
            },
        );
        apply_correction::<BETA>(
            inverse_sum,
            empty,
            self.num_regs,
            self.log2_num_regs,
            self.alpha_m_m,
        )
    }

    /// Resets every register to zero.
    #[inline(always)]
    fn clear(&self, backend: &mut [u8]) {
        backend.fill(0);
    }

    /// Copies the registers of `src` into `dst` (`copy_from_slice` panics if the lengths differ).
    #[inline(always)]
    fn set(&self, dst: &mut [u8], src: &[u8]) {
        debug_assert_eq!(dst.len(), src.len());
        dst.copy_from_slice(src);
    }
}
impl<T: Hash, H: BuildHasher + Clone, const BETA: bool> MergeEstimationLogic
    for HyperLogLog8<T, H, BETA>
{
    /// Merging byte registers needs no scratch state.
    type Helper = ();

    fn new_helper(&self) -> Self::Helper {}

    /// Merges `src` into `dst` by keeping, register by register, the larger value.
    ///
    /// # Panics
    ///
    /// Panics if `dst` and `src` have different lengths.
    #[inline(always)]
    fn merge_with_helper(&self, dst: &mut [u8], src: &[u8], _helper: &mut Self::Helper) {
        // This must be a real check, not a debug-only one: the SIMD kernels behind
        // `merge_max_bytes` walk both slices through raw pointers, so mismatched lengths
        // must never reach them in release builds.
        assert_eq!(
            dst.len(),
            src.len(),
            "cannot merge backends of different lengths"
        );
        merge_max_bytes(dst, src);
    }
}
impl<T, H, const BETA: bool> fmt::Display for HyperLogLog8<T, H, BETA> {
    /// Summarizes accuracy and memory footprint. With one byte per register, the byte
    /// count per estimator equals the register count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rsd_percent = super::HyperLogLog::rel_std(self.log2_num_regs) * 100.0;
        let registers = self.num_regs;
        let bytes = registers;
        write!(
            f,
            "HyperLogLog8 with relative standard deviation: {}% ({} registers/estimator, 8 bits/register, {} bytes/estimator)",
            rsd_percent, registers, bytes,
        )
    }
}
/// Builder for [`HyperLogLog8`] estimation logics.
///
/// `H` is the hasher builder handed over to the built logic (which needs
/// `H: BuildHasher + Clone` to add and estimate). `BETA` is carried over unchanged as the
/// logic's `BETA` parameter. Start from `HyperLogLog8Builder::new` (or `Default`), adjust
/// with the setters, then call `build`.
#[derive(Debug, Clone)]
pub struct HyperLogLog8Builder<H, const BETA: bool = true> {
// Hasher builder moved into the built logic.
build_hasher: H,
// Base-2 logarithm of the number of registers per estimator; the setter enforces >= 4.
log2_num_regs: u32,
}
impl HyperLogLog8Builder<BuildHasherDefault<DefaultHasher>> {
    /// Creates a builder with the smallest accepted size (2^4 = 16 registers per estimator),
    /// the standard library's `DefaultHasher`, and the default `BETA = true`.
    pub const fn new() -> Self {
        const MIN_LOG2_NUM_REGS: u32 = 4;
        Self {
            log2_num_regs: MIN_LOG2_NUM_REGS,
            build_hasher: BuildHasherDefault::new(),
        }
    }
}
impl Default for HyperLogLog8Builder<BuildHasherDefault<DefaultHasher>> {
fn default() -> Self {
Self::new()
}
}
impl<H, const BETA: bool> HyperLogLog8Builder<H, BETA> {
    /// Sets the number of registers from a target relative standard deviation `rsd`,
    /// as computed by `HyperLogLog::log2_num_of_registers`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::log2_num_regs`].
    pub fn rsd(self, rsd: f64) -> Self {
        self.log2_num_regs(super::HyperLogLog::log2_num_of_registers(rsd))
    }

    /// Sets the base-2 logarithm of the number of registers per estimator.
    ///
    /// # Panics
    ///
    /// Panics if `log2_num_regs` is less than 4, or if `2^log2_num_regs` does not fit in
    /// a `usize`.
    pub const fn log2_num_regs(mut self, log2_num_regs: u32) -> Self {
        assert!(
            log2_num_regs >= 4,
            "the logarithm of the number of registers per estimator should be at least 4"
        );
        // `build` computes `1usize << log2_num_regs`. An oversized value would panic there
        // in debug builds and silently wrap the shift amount in release builds.
        assert!(
            log2_num_regs < usize::BITS,
            "the logarithm of the number of registers per estimator should be less than the number of bits of a usize"
        );
        self.log2_num_regs = log2_num_regs;
        self
    }

    /// Switches the `BETA` const parameter of the logic that will be built.
    pub fn beta<const BETA2: bool>(self) -> HyperLogLog8Builder<H, BETA2> {
        HyperLogLog8Builder {
            build_hasher: self.build_hasher,
            log2_num_regs: self.log2_num_regs,
        }
    }

    /// Replaces the hasher builder used by the logic that will be built.
    pub fn build_hasher<H2>(self, build_hasher: H2) -> HyperLogLog8Builder<H2, BETA> {
        HyperLogLog8Builder {
            log2_num_regs: self.log2_num_regs,
            build_hasher,
        }
    }

    /// Builds the estimation logic for elements of type `T`.
    pub fn build<T>(self) -> HyperLogLog8<T, H, BETA> {
        let log2_num_regs = self.log2_num_regs;
        let num_regs = 1usize << log2_num_regs;
        // Bias-correction constants alpha_m from the HyperLogLog paper (Flajolet et al.,
        // 2007): tabulated for m = 16, 32, 64, closed-form approximation otherwise.
        let alpha = match log2_num_regs {
            4 => 0.673,
            5 => 0.697,
            6 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / num_regs as f64),
        };
        HyperLogLog8 {
            num_regs,
            num_regs_minus_1: (num_regs - 1) as u64,
            log2_num_regs,
            alpha_m_m: alpha * (num_regs as f64).powi(2),
            build_hasher: self.build_hasher,
            _marker: PhantomData,
        }
    }
}
/// Portable byte-wise maximum: raises each byte of `dst` to the matching byte of `src`,
/// over the common prefix of the two slices.
// Only referenced on some targets (32-bit x86 without SSE2, and architectures with no SIMD
// kernel here), so other builds would otherwise warn about dead code.
#[allow(dead_code)]
fn merge_max_bytes_scalar(dst: &mut [u8], src: &[u8]) {
    dst.iter_mut().zip(src).for_each(|(d, &s)| {
        if s > *d {
            *d = s;
        }
    });
}
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Byte-wise maximum of `src` into `dst` using 128-bit SSE2 vectors.
///
/// Operates on the common prefix of the two slices (`min(dst.len(), src.len())` bytes).
/// Whole 16-byte chunks go through `_mm_max_epu8` (an unsigned compare) and any trailing
/// bytes are merged one at a time. Bytes of `dst` past the common prefix are untouched,
/// so this is interchangeable with the portable scalar kernel.
///
/// # Safety
///
/// The caller must ensure that the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn merge_max_bytes_sse2(dst: &mut [u8], src: &[u8]) {
    // Never touch memory past the shorter slice, even in release builds.
    let len = dst.len().min(src.len());
    let simd_len = len - len % 16;
    let dst_ptr = dst.as_mut_ptr();
    let src_ptr = src.as_ptr();
    let mut i = 0;
    while i < simd_len {
        // SAFETY: `i` is a multiple of 16 below `simd_len`, so `i + 16 <= simd_len <= len`
        // and `len` bounds both slices; the unaligned loads/stores stay in bounds. SSE2
        // availability is the caller's contract.
        unsafe {
            let a = _mm_loadu_si128(dst_ptr.add(i) as *const __m128i);
            let b = _mm_loadu_si128(src_ptr.add(i) as *const __m128i);
            _mm_storeu_si128(dst_ptr.add(i) as *mut __m128i, _mm_max_epu8(a, b));
        }
        i += 16;
    }
    // Fewer than 16 bytes left: finish them without SIMD.
    for (d, &s) in dst[simd_len..len].iter_mut().zip(&src[simd_len..len]) {
        *d = (*d).max(s);
    }
}
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Byte-wise maximum of `src` into `dst` using 128-bit NEON vectors.
///
/// Operates on the common prefix of the two slices (`min(dst.len(), src.len())` bytes).
/// Whole 16-byte chunks go through `vmaxq_u8` and any trailing bytes are merged one at a
/// time. Bytes of `dst` past the common prefix are untouched, so this is interchangeable
/// with the portable scalar kernel.
///
/// # Safety
///
/// The caller must ensure that the running CPU supports NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn merge_max_bytes_neon(dst: &mut [u8], src: &[u8]) {
    // Never touch memory past the shorter slice, even in release builds.
    let len = dst.len().min(src.len());
    let simd_len = len - len % 16;
    let dst_ptr = dst.as_mut_ptr();
    let src_ptr = src.as_ptr();
    let mut i = 0;
    while i < simd_len {
        // SAFETY: `i` is a multiple of 16 below `simd_len`, so `i + 16 <= simd_len <= len`
        // and `len` bounds both slices; the 16-byte loads/stores stay in bounds. NEON
        // availability is the caller's contract.
        unsafe {
            let a = vld1q_u8(dst_ptr.add(i));
            let b = vld1q_u8(src_ptr.add(i));
            vst1q_u8(dst_ptr.add(i), vmaxq_u8(a, b));
        }
        i += 16;
    }
    // Fewer than 16 bytes left: finish them without SIMD.
    for (d, &s) in dst[simd_len..len].iter_mut().zip(&src[simd_len..len]) {
        *d = (*d).max(s);
    }
}
/// Merges `src` into `dst` (byte-wise maximum). SSE2 is part of the x86_64 baseline, so the
/// vector kernel is used unconditionally, with no runtime detection.
#[cfg(target_arch = "x86_64")]
fn merge_max_bytes(dst: &mut [u8], src: &[u8]) {
    // SAFETY: every x86_64 CPU implements SSE2, which satisfies the kernel's target feature.
    unsafe {
        merge_max_bytes_sse2(dst, src);
    }
}
/// Merges `src` into `dst` (byte-wise maximum) on 32-bit x86. It uses the SSE2 kernel when
/// the running CPU reports SSE2 and the portable loop otherwise.
#[cfg(target_arch = "x86")]
fn merge_max_bytes(dst: &mut [u8], src: &[u8]) {
    if !is_x86_feature_detected!("sse2") {
        return merge_max_bytes_scalar(dst, src);
    }
    // SAFETY: SSE2 support was just confirmed at runtime.
    unsafe { merge_max_bytes_sse2(dst, src) }
}
/// Merges `src` into `dst` (byte-wise maximum) with the NEON kernel, without a runtime
/// feature check.
#[cfg(target_arch = "aarch64")]
fn merge_max_bytes(dst: &mut [u8], src: &[u8]) {
    // SAFETY: NEON is enabled by default on Rust's standard AArch64 targets.
    // NOTE(review): `*-softfloat` AArch64 targets lack NEON; confirm they are unsupported.
    unsafe {
        merge_max_bytes_neon(dst, src);
    }
}
/// Merges `src` into `dst` (byte-wise maximum) on architectures that have no dedicated SIMD
/// kernel in this module.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
fn merge_max_bytes(dst: &mut [u8], src: &[u8]) {
    merge_max_bytes_scalar(dst, src);
}