use std::hash::Hash;
use std::hash::Hasher;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::assert::ensure_preamble_longs_in;
use crate::codec::assert::ensure_serial_version_is;
use crate::codec::assert::insufficient_data;
use crate::codec::family::Family;
use crate::countmin::CountMinValue;
use crate::countmin::UnsignedCountMinValue;
use crate::countmin::serialization::FLAGS_IS_EMPTY;
use crate::countmin::serialization::LONG_SIZE_BYTES;
use crate::countmin::serialization::PREAMBLE_LONGS_SHORT;
use crate::countmin::serialization::SERIAL_VERSION;
use crate::error::Error;
use crate::hash::DEFAULT_UPDATE_SEED;
use crate::hash::MurmurHash3X64128;
use crate::hash::compute_seed_hash;
/// Exclusive upper bound on the total number of counters (`num_hashes * num_buckets`)
/// accepted when a sketch is constructed or deserialized.
const MAX_TABLE_ENTRIES: usize = 1 << 30;
/// Count-Min sketch over counters of type `T`.
///
/// The counters form a table of `num_hashes` rows by `num_buckets` columns, stored
/// row-major in `counts`. Each row uses its own hash seed derived from `seed`.
#[derive(Debug, Clone, PartialEq)]
pub struct CountMinSketch<T: CountMinValue> {
/// Number of rows (one hash function per row).
num_hashes: u8,
/// Number of counters per row.
num_buckets: u32,
/// Base seed from which the per-row hash seeds are derived.
seed: u64,
/// `compute_seed_hash(seed)`; written on serialize and compared on deserialize.
seed_hash: u16,
/// Running total of the absolute weights applied through updates and merges.
total_weight: T,
/// Flat counter table; the entry for (`row`, `bucket`) is at `row * num_buckets + bucket`.
counts: Vec<T>,
/// Per-row hash seeds, one entry per hash function.
hash_seeds: Vec<u64>,
}
impl<T: CountMinValue> CountMinSketch<T> {
/// Creates an empty sketch with `num_hashes` rows and `num_buckets` buckets per row,
/// seeded with `DEFAULT_UPDATE_SEED`.
///
/// # Panics
///
/// Panics if `num_hashes` is 0, `num_buckets` is below 3, or their product is not
/// below `MAX_TABLE_ENTRIES` (see `entries_for_config`).
pub fn new(num_hashes: u8, num_buckets: u32) -> Self {
Self::with_seed(num_hashes, num_buckets, DEFAULT_UPDATE_SEED)
}
/// Creates an empty sketch with the given dimensions and an explicit hash `seed`.
///
/// # Panics
///
/// Panics if `num_hashes` is 0, `num_buckets` is below 3, or their product is not
/// below `MAX_TABLE_ENTRIES` (see `entries_for_config`).
pub fn with_seed(num_hashes: u8, num_buckets: u32, seed: u64) -> Self {
// Validate the configuration first; this panics on an invalid one.
let entries = entries_for_config(num_hashes, num_buckets);
Self::make(num_hashes, num_buckets, seed, entries)
}
/// Returns the number of hash functions (rows) in the sketch.
pub fn num_hashes(&self) -> u8 {
self.num_hashes
}
/// Returns the number of buckets (counters) per row.
pub fn num_buckets(&self) -> u32 {
self.num_buckets
}
/// Returns the base seed this sketch was created with.
pub fn seed(&self) -> u64 {
self.seed
}
/// Returns the running total weight maintained by updates and merges.
pub fn total_weight(&self) -> T {
self.total_weight
}
/// Returns the relative error of the sketch, computed as `e / num_buckets`.
pub fn relative_error(&self) -> f64 {
std::f64::consts::E / self.num_buckets as f64
}
/// Returns `true` when the total weight equals `T::ZERO`.
pub fn is_empty(&self) -> bool {
self.total_weight == T::ZERO
}
/// Suggests a number of buckets that achieves the requested relative error.
///
/// The suggestion is `ceil(e / relative_error)`, raised to 3 when it would be smaller.
/// Sketch construction rejects fewer than 3 buckets, and 3 buckets already give a
/// relative error (`e / 3`) no larger than any request that would produce fewer.
///
/// A `relative_error` of exactly `0.0` yields `u32::MAX`, because the infinite
/// quotient saturates in the float-to-integer cast.
///
/// # Panics
///
/// Panics if `relative_error` is negative or NaN.
pub fn suggest_num_buckets(relative_error: f64) -> u32 {
    assert!(relative_error >= 0.0, "relative_error must be at least 0");
    let suggested = (std::f64::consts::E / relative_error).ceil();
    // Never suggest a width the constructors would reject.
    suggested.max(3.0) as u32
}
/// Suggests a number of hash functions that achieves the requested confidence.
///
/// The suggestion is `ceil(ln(1 / (1 - confidence)))`, clamped to `1..=127`.
/// The lower clamp matters for `confidence == 0.0`, where the formula gives 0,
/// a value sketch construction rejects. A confidence of exactly `1.0` returns 127.
///
/// # Panics
///
/// Panics if `confidence` is outside `[0.0, 1.0]` or is NaN.
pub fn suggest_num_hashes(confidence: f64) -> u8 {
    assert!(
        (0.0..=1.0).contains(&confidence),
        "confidence must be between 0 and 1.0 (inclusive)"
    );
    if confidence == 1.0 {
        return 127;
    }
    // `confidence` is in [0, 1) here, so the logarithm is finite and non-negative.
    let hashes = (1.0 / (1.0 - confidence)).ln().ceil();
    hashes.clamp(1.0, 127.0) as u8
}
/// Records one occurrence of `item`, i.e. an update with weight `T::ONE`.
pub fn update<I: Hash>(&mut self, item: I) {
self.update_with_weight(item, T::ONE);
}
/// Adds `weight` to the counter selected for `item` in every row.
///
/// A zero weight is a no-op. Otherwise the absolute value of `weight` is added to
/// the total weight, while the signed `weight` is added to the counters.
pub fn update_with_weight<I: Hash>(&mut self, item: I, weight: T) {
    if weight == T::ZERO {
        return;
    }
    self.total_weight = self.total_weight.add(weight.abs());
    let width = self.num_buckets as usize;
    for (row, &row_seed) in self.hash_seeds.iter().enumerate() {
        // Row-major layout: skip `row` full rows, then index within the row.
        let slot = row * width + self.bucket_index(&item, row_seed);
        self.counts[slot] = self.counts[slot].add(weight);
    }
}
/// Returns the estimate for `item`: the smallest of the counters it maps to,
/// one per row, starting from `T::MAX`.
pub fn estimate<I: Hash>(&self, item: I) -> T {
    let width = self.num_buckets as usize;
    self.hash_seeds
        .iter()
        .enumerate()
        .map(|(row, &row_seed)| self.counts[row * width + self.bucket_index(&item, row_seed)])
        .fold(T::MAX, |smallest, candidate| {
            if candidate < smallest {
                candidate
            } else {
                smallest
            }
        })
}
/// Returns the lower bound for `item`, which is the estimate itself.
pub fn lower_bound<I: Hash>(&self, item: I) -> T {
self.estimate(item)
}
/// Returns the upper bound for `item`: the estimate plus
/// `relative_error() * total_weight`, converted back to `T`.
pub fn upper_bound<I: Hash>(&self, item: I) -> T {
    let base = self.estimate(item);
    let slack = self.relative_error() * self.total_weight.to_f64();
    base.add(T::from_f64(slack))
}
/// Merges `other` into this sketch by adding its counters and total weight.
///
/// # Panics
///
/// Panics if `other` is the same object as `self`, or if the two sketches differ
/// in number of hashes, number of buckets, seed, or counter table length.
pub fn merge(&mut self, other: &CountMinSketch<T>) {
    assert!(
        !std::ptr::eq(self, other),
        "Cannot merge a sketch with itself."
    );
    assert_eq!(self.num_hashes, other.num_hashes);
    assert_eq!(self.num_buckets, other.num_buckets);
    assert_eq!(self.seed, other.seed);
    assert_eq!(self.counts.len(), other.counts.len());
    // Lengths were just checked equal, so the zip visits every counter.
    for (mine, theirs) in self.counts.iter_mut().zip(other.counts.iter()) {
        *mine = (*mine).add(*theirs);
    }
    self.total_weight = self.total_weight.add(other.total_weight);
}
/// Serializes the sketch to bytes.
///
/// Layout, in write order: preamble longs (u8), serial version (u8), family id (u8),
/// flags (u8), an unused u32, `num_buckets` (u32 LE), `num_hashes` (u8),
/// seed hash (u16 LE) and an unused u8. A non-empty sketch then appends the total
/// weight followed by every counter, each written via `to_bytes()`. An empty sketch
/// sets `FLAGS_IS_EMPTY` and stops after the header.
pub fn serialize(&self) -> Vec<u8> {
    let empty = self.is_empty();
    let preamble_bytes = PREAMBLE_LONGS_SHORT as usize * LONG_SIZE_BYTES;
    // Total weight plus one slot per counter, only when there is data to write.
    let body_bytes = if empty {
        0
    } else {
        (1 + self.counts.len()) * LONG_SIZE_BYTES
    };
    let mut out = SketchBytes::with_capacity(preamble_bytes + body_bytes);

    // First preamble long.
    out.write_u8(PREAMBLE_LONGS_SHORT);
    out.write_u8(SERIAL_VERSION);
    out.write_u8(Family::COUNTMIN.id);
    out.write_u8(if empty { FLAGS_IS_EMPTY } else { 0 });
    out.write_u32_le(0);

    // Second preamble long.
    out.write_u32_le(self.num_buckets);
    out.write_u8(self.num_hashes);
    debug_assert_eq!(self.seed_hash, compute_seed_hash(self.seed));
    out.write_u16_le(self.seed_hash);
    out.write_u8(0);

    if !empty {
        out.write(&self.total_weight.to_bytes());
        for value in &self.counts {
            out.write(&value.to_bytes());
        }
    }
    out.into_bytes()
}
/// Deserializes a sketch from `bytes`, expecting it to have been built with
/// `DEFAULT_UPDATE_SEED`.
///
/// # Errors
///
/// Returns any error produced by `deserialize_with_seed`.
pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
Self::deserialize_with_seed(bytes, DEFAULT_UPDATE_SEED)
}
/// Deserializes a sketch from `bytes`, checking it against the given `seed`.
///
/// Reads the fields in the order `serialize` writes them. If the `FLAGS_IS_EMPTY`
/// bit is set, an all-zero sketch is returned without reading any payload. Nothing
/// checks for bytes left over after the last counter.
///
/// # Errors
///
/// Returns an error when the input ends before a required field, when the family id,
/// serial version or preamble-longs value is not the expected one, when the stored
/// seed hash differs from `compute_seed_hash(seed)`, when the stored dimensions are
/// rejected by `entries_for_config_checked`, or when `T::try_from_bytes` rejects a value.
pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self, Error> {
// Reads one 8-byte value and converts it to `T`; `tag` names the field in the
// insufficient-data error.
fn read_value<T: CountMinValue>(
cursor: &mut SketchSlice<'_>,
tag: &'static str,
) -> Result<T, Error> {
let mut bs = [0u8; 8];
cursor.read_exact(&mut bs).map_err(insufficient_data(tag))?;
T::try_from_bytes(bs)
}
let mut cursor = SketchSlice::new(bytes);
// First 8 bytes: preamble longs, serial version, family id, flags, 4 unused bytes.
let preamble_longs = cursor
.read_u8()
.map_err(insufficient_data("preamble_longs"))?;
let serial_version = cursor
.read_u8()
.map_err(insufficient_data("serial_version"))?;
let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?;
let flags = cursor.read_u8().map_err(insufficient_data("flags"))?;
cursor
.read_u32_le()
.map_err(insufficient_data("<unused>"))?;
// Validate the header only after all of its first 8 bytes have been read.
Family::COUNTMIN.validate_id(family_id)?;
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
ensure_preamble_longs_in(&[PREAMBLE_LONGS_SHORT], preamble_longs)?;
// Next 8 bytes: num_buckets, num_hashes, seed hash, 1 unused byte.
let num_buckets = cursor
.read_u32_le()
.map_err(insufficient_data("num_buckets"))?;
let num_hashes = cursor.read_u8().map_err(insufficient_data("num_hashes"))?;
let seed_hash = cursor
.read_u16_le()
.map_err(insufficient_data("seed_hash"))?;
cursor.read_u8().map_err(insufficient_data("unused8"))?;
// The caller-supplied seed must hash to the value stored in the image.
let expected_seed_hash = compute_seed_hash(seed);
if seed_hash != expected_seed_hash {
return Err(Error::deserial(format!(
"incompatible seed hash: expected {expected_seed_hash}, got {seed_hash}",
)));
}
// Non-panicking validation of the dimensions, since they come from the input.
let entries = entries_for_config_checked(num_hashes, num_buckets)?;
// NOTE(review): the counter table is allocated from the header-declared size
// before the payload is known to be present.
let mut sketch = Self::make(num_hashes, num_buckets, seed, entries);
if (flags & FLAGS_IS_EMPTY) != 0 {
return Ok(sketch);
}
// Payload: total weight, then every counter in storage order.
sketch.total_weight = read_value(&mut cursor, "total_weight")?;
for count in &mut sketch.counts {
*count = read_value(&mut cursor, "counts")?;
}
Ok(sketch)
}
fn make(num_hashes: u8, num_buckets: u32, seed: u64, entries: usize) -> Self {
let counts = vec![T::ZERO; entries];
let seed_hash = compute_seed_hash(seed);
let hash_seeds = make_hash_seeds(seed, num_hashes);
CountMinSketch {
num_hashes,
num_buckets,
seed,
seed_hash,
total_weight: T::ZERO,
counts,
hash_seeds,
}
}
/// Hashes `item` with a hasher seeded by `seed` and maps the first half of the
/// 128-bit result to a bucket index in `0..num_buckets`.
fn bucket_index<I: Hash>(&self, item: &I, seed: u64) -> usize {
    let mut state = MurmurHash3X64128::with_seed(seed);
    item.hash(&mut state);
    let low = state.finish128().0;
    let modulus = u64::from(self.num_buckets);
    (low % modulus) as usize
}
}
impl<T: UnsignedCountMinValue> CountMinSketch<T> {
/// Applies `halve()` to every counter and to the total weight.
pub fn halve(&mut self) {
    self.counts
        .iter_mut()
        .for_each(|cell| *cell = cell.halve());
    self.total_weight = self.total_weight.halve();
}
/// Applies `decay(decay)` to every counter and to the total weight.
///
/// # Panics
///
/// Panics unless `decay` lies in `(0, 1]`.
pub fn decay(&mut self, decay: f64) {
    assert!(decay > 0.0 && decay <= 1.0, "decay must be within (0, 1]");
    self.counts
        .iter_mut()
        .for_each(|cell| *cell = cell.decay(decay));
    self.total_weight = self.total_weight.decay(decay);
}
}
/// Validates a sketch configuration and returns the total number of counters.
///
/// # Panics
///
/// Panics if `num_hashes` is 0, if `num_buckets` is below 3, if the product
/// overflows `usize`, or if the product is not below `MAX_TABLE_ENTRIES`.
fn entries_for_config(num_hashes: u8, num_buckets: u32) -> usize {
    assert!(num_hashes > 0, "num_hashes must be at least 1");
    assert!(num_buckets >= 3, "num_buckets must be at least 3");
    match (num_hashes as usize).checked_mul(num_buckets as usize) {
        Some(total) => {
            assert!(
                total < MAX_TABLE_ENTRIES,
                "num_hashes * num_buckets must be < {}",
                MAX_TABLE_ENTRIES
            );
            total
        }
        None => panic!("num_hashes * num_buckets overflows usize"),
    }
}
fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) -> Result<usize, Error> {
if num_hashes == 0 {
return Err(Error::deserial("num_hashes must be at least 1"));
}
if num_buckets < 3 {
return Err(Error::deserial("num_buckets must be at least 3"));
}
let entries = (num_hashes as usize)
.checked_mul(num_buckets as usize)
.ok_or_else(|| Error::deserial("num_hashes * num_buckets overflows usize"))?;
if entries >= MAX_TABLE_ENTRIES {
return Err(Error::deserial(format!(
"num_hashes * num_buckets must be < {MAX_TABLE_ENTRIES}",
)));
}
Ok(entries)
}
/// Derives one hash seed per row: for each row index `i` in `0..num_hashes`, hashes
/// the little-endian bytes of `i` (widened to `u64`) with a hasher seeded by `seed`
/// and keeps the first half of the 128-bit result.
fn make_hash_seeds(seed: u64, num_hashes: u8) -> Vec<u64> {
    (0..num_hashes)
        .map(|row| {
            let mut hasher = MurmurHash3X64128::with_seed(seed);
            hasher.write(&u64::from(row).to_le_bytes());
            hasher.finish128().0
        })
        .collect()
}