use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::common::NumStdDev;
use crate::error::Error;
use crate::hll::estimator::HipEstimator;
use crate::hll::get_slot;
use crate::hll::get_value;
/// Register array that stores one byte per slot, together with the cached
/// values (zero count and estimator state) that are derived from those bytes.
#[derive(Debug, Clone, PartialEq)]
pub struct Array8 {
/// Base-2 logarithm of the register count; `new` and `deserialize` both
/// allocate `1 << lg_config_k` bytes.
lg_config_k: u8,
/// The registers themselves, one `u8` per slot.
bytes: Box<[u8]>,
/// Cached number of registers whose value is 0.
num_zeros: u32,
/// Estimator state that is updated whenever a register value increases.
estimator: HipEstimator,
}
impl Array8 {
/// Creates an array of `2^lg_config_k` registers, all initialised to zero.
///
/// The zero-register counter starts at the full register count and a fresh
/// `HipEstimator` is created for the same `lg_config_k`.
pub fn new(lg_config_k: u8) -> Self {
    let register_count: u32 = 1 << lg_config_k;
    let registers = vec![0u8; register_count as usize].into_boxed_slice();
    Self {
        lg_config_k,
        bytes: registers,
        num_zeros: register_count,
        estimator: HipEstimator::new(lg_config_k),
    }
}
/// Returns the value stored in the register at `slot`.
///
/// # Panics
///
/// Panics if `slot` is not a valid index into the register array.
#[inline]
pub fn get(&self, slot: u32) -> u8 {
    let index = slot as usize;
    self.bytes[index]
}
/// Overwrites the register at `slot` with `value`.
///
/// Only the byte is written; the zero counter and the estimator are not
/// touched here.
///
/// # Panics
///
/// Panics if `slot` is not a valid index into the register array.
#[inline]
fn put(&mut self, slot: u32, value: u8) {
    let register = &mut self.bytes[slot as usize];
    *register = value;
}
/// Applies `coupon` to the register array.
///
/// The target slot is `get_slot(coupon)` masked down to the low
/// `lg_config_k` bits and the candidate value is `get_value(coupon)`.
/// Nothing happens unless the candidate is strictly larger than the value
/// already stored. Otherwise the estimator is informed of the transition,
/// the register is overwritten, and the zero counter is decremented when
/// the register was previously zero.
pub fn update(&mut self, coupon: u32) {
    let slot_mask = (1 << self.lg_config_k) - 1;
    let slot = get_slot(coupon) & slot_mask;
    let incoming = get_value(coupon);
    let current = self.get(slot);

    // Registers only ever grow; smaller or equal candidates are ignored.
    if incoming <= current {
        return;
    }

    // The estimator must see the old value, so notify it before writing.
    self.estimator.update(self.lg_config_k, current, incoming);
    self.put(slot, incoming);
    if current == 0 {
        self.num_zeros -= 1;
    }
}
/// Returns the estimate produced by the underlying estimator for the
/// current `lg_config_k` and zero-register count.
///
/// The middle argument is a constant 0 (presumably the "current minimum"
/// register value -- confirm against `HipEstimator::estimate`).
pub fn estimate(&self) -> f64 {
    let zero_registers = self.num_zeros;
    self.estimator.estimate(self.lg_config_k, 0, zero_registers)
}
/// Returns the estimator's upper bound for the requested number of
/// standard deviations.
///
/// Delegates to `HipEstimator::upper_bound` with this array's
/// `lg_config_k`, a constant 0 and the cached zero-register count.
pub fn upper_bound(&self, num_std_dev: NumStdDev) -> f64 {
    let zero_registers = self.num_zeros;
    let estimator = &self.estimator;
    estimator.upper_bound(self.lg_config_k, 0, zero_registers, num_std_dev)
}
/// Returns the estimator's lower bound for the requested number of
/// standard deviations.
///
/// Delegates to `HipEstimator::lower_bound` with this array's
/// `lg_config_k`, a constant 0 and the cached zero-register count.
pub fn lower_bound(&self, num_std_dev: NumStdDev) -> f64 {
    let zero_registers = self.num_zeros;
    let estimator = &self.estimator;
    estimator.lower_bound(self.lg_config_k, 0, zero_registers, num_std_dev)
}
/// Replaces the estimator's HIP accumulator with `value`.
pub fn set_hip_accum(&mut self, value: f64) {
    let estimator = &mut self.estimator;
    estimator.set_hip_accum(value);
}
/// Returns `true` when the cached zero counter equals the total number of
/// registers (`2^lg_config_k`).
pub fn is_empty(&self) -> bool {
    let register_count: u32 = 1 << self.lg_config_k;
    register_count == self.num_zeros
}
/// Borrows the raw register bytes, one byte per slot.
pub(super) fn values(&self) -> &[u8] {
    &self.bytes[..]
}
/// Returns the configured number of registers, `2^lg_config_k`.
pub(super) fn num_registers(&self) -> usize {
    let one: usize = 1;
    one << self.lg_config_k
}
/// Returns the estimator's current HIP accumulator.
pub(super) fn hip_accum(&self) -> f64 {
    let estimator = &self.estimator;
    estimator.hip_accum()
}
/// Writes `value` straight into the register at `slot`.
///
/// The zero counter and estimator are left as they were; the cached values
/// are only refreshed by `rebuild_estimator_from_registers`.
///
/// # Panics
///
/// Panics if `slot` is not a valid index into the register array.
pub(super) fn set_register(&mut self, slot: usize, value: u8) {
    let register = &mut self.bytes[slot];
    *register = value;
}
/// Recomputes the cached values from the register bytes and then marks the
/// estimator as out of order.
pub(super) fn rebuild_estimator_from_registers(&mut self) {
    self.rebuild_cached_values();

    let estimator = &mut self.estimator;
    estimator.set_out_of_order(true);
}
/// Merges `src` into this array slot by slot, keeping the larger value in
/// each register.
///
/// Afterwards the cached values are rebuilt and the estimator is marked as
/// out of order.
///
/// # Panics
///
/// Panics if `src` does not have the same length as this array.
pub(super) fn merge_array_same_lgk(&mut self, src: &[u8]) {
    assert_eq!(
        src.len(),
        self.bytes.len(),
        "Source and destination must have same lg_k"
    );

    // Lengths are equal, so zipping visits every register exactly once.
    for (register, &incoming) in self.bytes.iter_mut().zip(src) {
        *register = (*register).max(incoming);
    }

    self.rebuild_cached_values();
    self.estimator.set_out_of_order(true);
}
/// Merges a larger source array into this one.
///
/// Every source slot is folded onto the destination slot obtained by
/// keeping only its low `lg_config_k` bits, and the destination register
/// keeps the maximum of everything folded onto it. Afterwards the cached
/// values are rebuilt and the estimator is marked as out of order.
///
/// # Panics
///
/// Panics if `src_lg_k` is not strictly greater than this array's
/// `lg_config_k`, or if `src.len()` is not `2^src_lg_k`.
pub(super) fn merge_array_with_downsample(&mut self, src: &[u8], src_lg_k: u8) {
    assert!(
        src_lg_k > self.lg_config_k,
        "Source lg_k must be greater than destination lg_k for downsampling"
    );
    assert_eq!(
        src.len(),
        1 << src_lg_k,
        "Source length must match 2^src_lg_k"
    );

    let fold_mask: u32 = (1 << self.lg_config_k) - 1;
    for (src_slot, &incoming) in src.iter().enumerate() {
        let register = &mut self.bytes[(src_slot as u32 & fold_mask) as usize];
        *register = (*register).max(incoming);
    }

    self.rebuild_cached_values();
    self.estimator.set_out_of_order(true);
}
/// Recomputes every cached quantity that is derived from the register bytes.
///
/// A single pass over the registers refreshes:
/// * `num_zeros`: the number of registers whose value is 0;
/// * the estimator's `kxq0`: the sum of `2^-v` over registers with `v < 32`
///   (a zero register contributes `2^0 = 1`);
/// * the estimator's `kxq1`: the sum of `2^-v` over registers with `v >= 32`.
///
/// Registers are plain bytes, so any value in `0..=255` may be present (for
/// example after `set_register` or `deserialize`). `2^-v` is therefore built
/// without an integer shift: `1u64 << v` overflows for `v >= 64`, which
/// panics in debug builds and yields a wrong power of two in release builds.
fn rebuild_cached_values(&mut self) {
    // Returns 2^-exponent exactly for every possible `u8` exponent.
    fn inv_pow2(exponent: u8) -> f64 {
        // A binary64 with a zero mantissa and biased exponent `1023 - e`
        // is exactly 2^-e. For e in 0..=255 the biased exponent stays in
        // 768..=1023, i.e. always a normal number.
        const EXPONENT_BIAS: u64 = 1023;
        const MANTISSA_BITS: u32 = 52;
        f64::from_bits((EXPONENT_BIAS - u64::from(exponent)) << MANTISSA_BITS)
    }

    let mut zero_count = 0usize;
    let mut kxq0_sum = 0.0;
    let mut kxq1_sum = 0.0;

    for &val in self.bytes.iter() {
        if val == 0 {
            zero_count += 1;
        }
        // Small exponents and large exponents are accumulated separately.
        if val < 32 {
            kxq0_sum += inv_pow2(val);
        } else {
            kxq1_sum += inv_pow2(val);
        }
    }

    self.num_zeros = zero_count as u32;
    self.estimator.set_kxq0(kxq0_sum);
    self.estimator.set_kxq1(kxq1_sum);
}
pub fn deserialize(
mut cursor: SketchSlice,
lg_config_k: u8,
compact: bool,
ooo: bool,
) -> Result<Self, Error> {
fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
move |_| Error::insufficient_data(tag)
}
let k = 1usize << lg_config_k;
let hip_accum = cursor.read_f64_le().map_err(make_error("hip_accum"))?;
let kxq0 = cursor.read_f64_le().map_err(make_error("kxq0"))?;
let kxq1 = cursor.read_f64_le().map_err(make_error("kxq1"))?;
let num_zeros = cursor.read_u32_le().map_err(make_error("num_zeros"))?;
let _aux_count = cursor.read_u32_le().map_err(make_error("aux_count"))?;
let mut data = vec![0u8; k];
if !compact {
cursor.read_exact(&mut data).map_err(make_error("data"))?;
} else {
cursor.advance(k as u64);
}
let mut estimator = HipEstimator::new(lg_config_k);
estimator.set_hip_accum(hip_accum);
estimator.set_kxq0(kxq0);
estimator.set_kxq1(kxq1);
estimator.set_out_of_order(ooo);
Ok(Self {
lg_config_k,
bytes: data.into_boxed_slice(),
num_zeros,
estimator,
})
}
/// Serializes this array into a freshly allocated byte vector.
///
/// Layout, in write order: eight single-byte preamble fields
/// (`HLL_PREINTS`, `SERIAL_VERSION`, `HLL_FAMILY_ID`, `lg_config_k`, a zero
/// byte, the flags byte, another zero byte, and the mode byte for
/// `CUR_MODE_HLL`/`TGT_HLL8`), then `hip_accum`, `kxq0`, `kxq1` as
/// little-endian `f64`, then `num_zeros` and a zero as little-endian `u32`,
/// and finally the register bytes.
///
/// The `lg_config_k` argument is what is written to the preamble and used
/// to size the output buffer; the register bytes written are always this
/// array's own.
pub fn serialize(&self, lg_config_k: u8) -> Vec<u8> {
    use crate::hll::serialization::*;

    let register_count = 1 << lg_config_k;
    let capacity = HLL_PREAMBLE_SIZE + register_count as usize;
    let mut out = SketchBytes::with_capacity(capacity);

    // Preamble: eight single-byte fields.
    out.write_u8(HLL_PREINTS);
    out.write_u8(SERIAL_VERSION);
    out.write_u8(HLL_FAMILY_ID);
    out.write_u8(lg_config_k);
    out.write_u8(0);
    let flags = if self.estimator.is_out_of_order() {
        OUT_OF_ORDER_FLAG_MASK
    } else {
        0
    };
    out.write_u8(flags);
    out.write_u8(0);
    out.write_u8(encode_mode_byte(CUR_MODE_HLL, TGT_HLL8));

    // Estimator state and cached counters.
    let estimator = &self.estimator;
    out.write_f64_le(estimator.hip_accum());
    out.write_f64_le(estimator.kxq0());
    out.write_f64_le(estimator.kxq1());
    out.write_u32_le(self.num_zeros);
    out.write_u32_le(0);

    // Register payload.
    out.write(&self.bytes);
    out.into_bytes()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hll::coupon;
    use crate::hll::pack_coupon;

    /// A freshly created array reports zero for every probed slot.
    #[test]
    fn test_array8_basic() {
        let registers = Array8::new(10);
        for slot in [0u32, 100, 1023] {
            assert_eq!(registers.get(slot), 0);
        }
    }

    /// Values written with `put` are read back unchanged by `get`.
    #[test]
    fn test_get_set() {
        let mut registers = Array8::new(4);
        let pattern = |slot: u32| (slot * 17) as u8;

        for slot in 0..16u32 {
            registers.put(slot, pattern(slot));
        }
        for slot in 0..16u32 {
            assert_eq!(registers.get(slot), pattern(slot));
        }

        // Overwrite a few slots with the extremes and a mid-range value.
        let overwrites = [(0u32, 0u8), (1, 127), (2, 255)];
        for (slot, value) in overwrites {
            registers.put(slot, value);
        }
        for (slot, value) in overwrites {
            assert_eq!(registers.get(slot), value);
        }
    }

    /// `update` only ever raises a register.
    #[test]
    fn test_update_basic() {
        let mut registers = Array8::new(4);

        // First write into an empty register.
        registers.update(pack_coupon(0, 5));
        assert_eq!(registers.get(0), 5);

        // A smaller value is ignored.
        registers.update(pack_coupon(0, 3));
        assert_eq!(registers.get(0), 5);

        // A larger value replaces the old one.
        registers.update(pack_coupon(0, 42));
        assert_eq!(registers.get(0), 42);

        // A different slot is independent.
        registers.update(pack_coupon(1, 63));
        assert_eq!(registers.get(1), 63);
    }

    /// The estimate starts at zero and lands in a plausible range after
    /// 10,000 distinct updates.
    #[test]
    fn test_hip_estimator() {
        let mut registers = Array8::new(10);
        assert_eq!(registers.estimate(), 0.0);

        for item in 0..10_000u32 {
            registers.update(coupon(item));
        }

        let estimate = registers.estimate();
        assert!(estimate > 0.0, "Estimate should be positive");
        assert!(estimate.is_finite(), "Estimate should be finite");
        assert!(estimate > 1_000.0, "Estimate seems too low");
        assert!(estimate < 100_000.0, "Estimate seems too high");
    }

    /// Every possible byte value survives a put/get round trip.
    #[test]
    fn test_full_value_range() {
        let mut registers = Array8::new(8);
        for value in u8::MIN..=u8::MAX {
            registers.put(u32::from(value), value);
        }
        for value in u8::MIN..=u8::MAX {
            assert_eq!(registers.get(u32::from(value)), value);
        }
    }

    /// Large register values can be stored and read directly.
    #[test]
    fn test_high_value_direct() {
        let mut registers = Array8::new(6);
        let test_values: [u8; 6] = [16, 32, 64, 128, 200, 255];

        for (slot, value) in (0u32..).zip(test_values) {
            registers.put(slot, value);
            assert_eq!(registers.get(slot), value);
        }
        // Re-check after all writes to make sure nothing was clobbered.
        for (slot, value) in (0u32..).zip(test_values) {
            assert_eq!(registers.get(slot), value);
        }
    }

    /// Values below 32 feed kxq0, values of 32 and above feed kxq1.
    #[test]
    fn test_kxq_register_split() {
        let mut registers = Array8::new(8);
        registers.update(pack_coupon(0, 10));
        registers.update(pack_coupon(1, 50));

        let kxq0 = registers.estimator.kxq0();
        let kxq1 = registers.estimator.kxq1();
        assert!(kxq0 < 256.0, "kxq0 should have decreased");
        assert!(kxq1 > 0.0, "kxq1 should be positive");
        assert!(
            kxq1 < 1e-10,
            "kxq1 should be very small (1/2^50 ≈ 8.9e-16)"
        );
    }

    /// `values` exposes the raw register bytes.
    #[test]
    fn test_values_access() {
        let mut registers = Array8::new(4);
        let written = [(0u32, 10u8), (5, 25), (15, 63)];
        for (slot, value) in written {
            registers.put(slot, value);
        }

        let snapshot = registers.values();
        assert_eq!(snapshot.len(), 16);
        for (slot, value) in written {
            assert_eq!(snapshot[slot as usize], value);
        }
        // An untouched slot is still zero.
        assert_eq!(snapshot[1], 0);
    }

    /// Same-size merge keeps the per-slot maximum and refreshes cached state.
    #[test]
    fn test_merge_array_same_lgk() {
        let mut dst = Array8::new(4);
        let mut src = Array8::new(4);

        for (slot, value) in [(0u32, 10u8), (1, 20), (2, 30)] {
            dst.put(slot, value);
        }
        for (slot, value) in [(1u32, 15u8), (2, 35), (3, 40)] {
            src.put(slot, value);
        }

        dst.merge_array_same_lgk(src.values());

        assert_eq!(dst.get(0), 10, "dst[0] unchanged");
        assert_eq!(dst.get(1), 20, "dst[1] kept max value");
        assert_eq!(dst.get(2), 35, "dst[2] updated to larger value");
        assert_eq!(dst.get(3), 40, "dst[3] got new value");
        assert!(dst.estimator.is_out_of_order());
        assert_eq!(dst.num_zeros, 12);
    }

    /// Downsampling folds source slots onto `slot & mask` and keeps the max.
    #[test]
    fn test_merge_array_with_downsample() {
        let mut dst = Array8::new(4);
        let mut src = Array8::new(5);

        dst.put(0, 10);
        dst.put(1, 20);
        // Source slots 0 and 16 fold onto dst slot 0; 1 and 17 onto dst slot 1.
        for (slot, value) in [(0u32, 15u8), (16, 25), (1, 18), (17, 30)] {
            src.put(slot, value);
        }

        dst.merge_array_with_downsample(src.values(), 5);

        assert_eq!(dst.get(0), 25, "dst[0] = max(10, 15, 25)");
        assert_eq!(dst.get(1), 30, "dst[1] = max(20, 18, 30)");
        assert!(dst.estimator.is_out_of_order());
    }

    /// Same-size merge rejects a source of a different length.
    #[test]
    #[should_panic(expected = "Source and destination must have same lg_k")]
    fn test_merge_same_lgk_panics_on_size_mismatch() {
        let mut dst = Array8::new(4);
        let src = Array8::new(5);
        dst.merge_array_same_lgk(src.values());
    }

    /// Downsampling merge rejects a source that is not larger.
    #[test]
    #[should_panic(expected = "Source lg_k must be greater")]
    fn test_merge_downsample_panics_if_not_downsampling() {
        let mut dst = Array8::new(5);
        let src = Array8::new(4);
        dst.merge_array_with_downsample(src.values(), 4);
    }

    /// Rebuilding recomputes the zero counter from the register bytes.
    #[test]
    fn test_rebuild_cached_values() {
        let mut registers = Array8::new(4);
        for (slot, value) in [(0u32, 10u8), (1, 20), (2, 30)] {
            registers.put(slot, value);
        }

        // Corrupt the cached counter on purpose; the rebuild must fix it.
        registers.num_zeros = 999;
        registers.rebuild_cached_values();
        assert_eq!(registers.num_zeros, 13);
    }

    /// Merging an ascending ramp with a descending ramp yields the per-slot max.
    #[test]
    fn test_merge_preserves_max_semantics() {
        let mut dst = Array8::new(4);
        let mut src = Array8::new(4);

        for i in 0..16u32 {
            dst.put(i, i as u8);
        }
        for i in 0..16u32 {
            src.put(i, (15 - i) as u8);
        }

        dst.merge_array_same_lgk(src.values());

        for i in 0..16u32 {
            let expected = (i as u8).max((15 - i) as u8);
            assert_eq!(
                dst.get(i),
                expected,
                "slot {} should be max({}, {}) = {}",
                i,
                i,
                15 - i,
                expected
            );
        }
    }
}