use serde::{Deserialize, Serialize};
use std::{
borrow::Borrow, cmp::max, convert::TryFrom, fmt, hash::{Hash, Hasher}, marker::PhantomData, ops
};
use twox_hash::XxHash;
use super::f64_to_usize;
use crate::traits::{Intersect, IntersectPlusUnionIsPlus, New, UnionAssign};
/// A count-min-sketch style table: `k_num` rows of counters of type `C`,
/// addressed by hashing keys of type `K`.
///
/// Each key selects one counter per row; reads combine the selected counters
/// with `Intersect::intersect` and writes update them (see `push`,
/// `union_assign` and `get`).
///
/// The serde bounds are spelled out explicitly: only `C` and
/// `<C as New>::Config` have to be (de)serializable.
#[derive(Serialize, Deserialize)]
#[serde(bound(
serialize = "C: Serialize, <C as New>::Config: Serialize",
deserialize = "C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
))]
pub struct CountMinSketch<K: ?Sized, C: New> {
// One inner `Vec` per row; `new` creates `k_num` rows of `width` counters each.
counters: Vec<Vec<C>>,
// `offsets`: per-row scratch buffer (length `k_num`) that `push` overwrites with
// the column chosen for the current key; it carries no state between calls.
// `mask`: `width - 1` for a power-of-two `width`; AND-ed with a hash to get a column.
offsets: Vec<usize>, mask: usize,
// Number of rows, as computed by `optimal_k_num` in `new`.
k_num: usize,
// Configuration handed to `C::new` whenever a counter is created or reset.
config: <C as New>::Config,
// Records the key type `K` in the type without storing any `K` value.
marker: PhantomData<fn(K)>,
}
impl<K: ?Sized, C> CountMinSketch<K, C>
where
    K: Hash,
    C: New + for<'a> UnionAssign<&'a C> + Intersect,
{
    /// Builds an empty sketch sized from the requested accuracy.
    ///
    /// `tolerance` determines the number of counters per row (see
    /// `optimal_width`) and `probability` the number of rows (see
    /// `optimal_k_num`). Every counter starts out as `C::new(&config)`, and
    /// `config` is kept so that `clear` can recreate the counters later.
    ///
    /// # Panics
    ///
    /// Panics if the computed width cannot be rounded up to a power of two
    /// that fits in `usize`. The `f64_to_usize` helper may impose further
    /// restrictions on the inputs (not visible from this impl).
    pub fn new(probability: f64, tolerance: f64, config: C::Config) -> Self {
        let width = Self::optimal_width(tolerance);
        let k_num = Self::optimal_k_num(probability);

        let mut counters: Vec<Vec<C>> = Vec::with_capacity(k_num);
        for _ in 0..k_num {
            let row: Vec<C> = (0..width).map(|_| C::new(&config)).collect();
            counters.push(row);
        }

        let offsets = vec![0; k_num];
        let mask = Self::mask(width);
        Self {
            counters,
            offsets,
            mask,
            k_num,
            config,
            marker: PhantomData,
        }
    }

    /// Adds `value` to the counters selected by `key` and returns the
    /// resulting estimate for `key`.
    ///
    /// Two strategies exist, chosen by `IntersectPlusUnionIsPlus::VAL`:
    ///
    /// * `VAL == true`: every selected counter is incremented by `value`, and
    ///   the `C::intersect` of the updated counters is returned.
    /// * `VAL == false`: the `C::intersect` of the selected counters is
    ///   computed first, `value` is added to that single result, and the
    ///   result is then merged back into each selected counter with
    ///   `union_assign`. The merged value is returned.
    ///
    /// # Panics
    ///
    /// Panics if `C::intersect` does not produce a value.
    pub fn push<Q: ?Sized, V: ?Sized>(&mut self, key: &Q, value: &V) -> C
    where
        Q: Hash,
        K: Borrow<Q>,
        C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
    {
        if <C as IntersectPlusUnionIsPlus>::VAL {
            // Direct update: bump each selected counter, then combine them.
            let columns = self.offsets(key);
            let bumped = self
                .counters
                .iter_mut()
                .zip(columns)
                .map(|(row, column)| {
                    row[column] += value;
                    &row[column]
                });
            return C::intersect(bumped).unwrap();
        }

        // The columns are needed twice (once to read, once to write), so they
        // are stashed in the `offsets` scratch buffer instead of re-hashing.
        let columns = self.offsets(key);
        for (slot, column) in self.offsets.iter_mut().zip(columns) {
            *slot = column;
        }

        let mut lowest = C::intersect(
            self.offsets
                .iter()
                .enumerate()
                .map(|(row_index, &column)| &self.counters[row_index][column]),
        )
        .unwrap();
        lowest += value;

        // Merge the new value into every selected counter.
        for (row, &column) in self.counters.iter_mut().zip(self.offsets.iter()) {
            row[column].union_assign(&lowest);
        }
        lowest
    }

    /// Merges `value` into every counter selected by `key` using
    /// `UnionAssign::union_assign`.
    pub fn union_assign<Q: ?Sized>(&mut self, key: &Q, value: &C)
    where
        Q: Hash,
        K: Borrow<Q>,
    {
        let columns = self.offsets(key);
        for (row, column) in self.counters.iter_mut().zip(columns) {
            row[column].union_assign(value);
        }
    }

    /// Returns the estimate for `key`: the `C::intersect` of the one counter
    /// per row that `key` hashes to.
    ///
    /// # Panics
    ///
    /// Panics if `C::intersect` does not produce a value.
    pub fn get<Q: ?Sized>(&self, key: &Q) -> C
    where
        Q: Hash,
        K: Borrow<Q>,
    {
        let columns = self.offsets(key);
        let selected = self
            .counters
            .iter()
            .zip(columns)
            .map(|(row, column)| &row[column]);
        C::intersect(selected).unwrap()
    }

    /// Resets every counter to a fresh `C::new(&config)`, keeping the
    /// dimensions of the sketch.
    pub fn clear(&mut self) {
        let config = &self.config;
        for row in self.counters.iter_mut() {
            for counter in row.iter_mut() {
                *counter = C::new(config);
            }
        }
    }

    /// Picks the number of counters per row: `2 / tolerance` rounded to the
    /// nearest integer, raised to at least 2, then rounded up to a power of
    /// two (which is what lets `mask` be used for indexing).
    ///
    /// # Panics
    ///
    /// Panics if the next power of two does not fit in `usize`.
    fn optimal_width(tolerance: f64) -> usize {
        let requested = f64_to_usize((2.0 / tolerance).round());
        let at_least_two = max(2, requested);
        at_least_two
            .checked_next_power_of_two()
            .expect("Width would be way too large")
    }

    /// Returns the bit mask `width - 1` used to reduce a hash to a column.
    ///
    /// # Panics
    ///
    /// Panics unless `width` is a power of two greater than 1.
    fn mask(width: usize) -> usize {
        assert!(width > 1);
        let mask = width - 1;
        // A power of two shares no set bits with its predecessor.
        assert_eq!(width & mask, 0);
        mask
    }

    /// Picks the number of rows: `ln(1 - probability) / ln(0.5)`, rounded
    /// down, but never fewer than one.
    ///
    /// NOTE(review): the row count is rounded down with `floor`; confirm that
    /// rounding down (rather than up) matches the intended probability bound.
    fn optimal_k_num(probability: f64) -> usize {
        let rows = ((1.0 - probability).ln() / 0.5_f64.ln()).floor();
        max(1, f64_to_usize(rows))
    }

    /// Yields the column index for `key` in each successive row.
    ///
    /// The iterator is the unbounded hash stream from `hashes`, with each
    /// hash reduced by `self.mask`; callers bound it by zipping with the rows.
    /// Only a copy of the mask is moved into the closure, so `self` can be
    /// mutated while the iterator is being consumed.
    fn offsets<Q: ?Sized>(&self, key: &Q) -> impl Iterator<Item = usize>
    where
        Q: Hash,
        K: Borrow<Q>,
    {
        let mask = self.mask;
        hashes(key).map(move |hash| {
            let mask = u64::try_from(mask).unwrap();
            usize::try_from(hash & mask).unwrap()
        })
    }
}
fn hashes<Q: ?Sized>(key: &Q) -> impl Iterator<Item = u64>
where
Q: Hash,
{
#[allow(missing_copy_implementations, missing_debug_implementations)]
struct X(XxHash);
impl Iterator for X {
type Item = u64;
fn next(&mut self) -> Option<Self::Item> {
let ret = self.0.finish();
self.0.write(&[123]);
Some(ret)
}
}
let mut hasher = XxHash::default();
key.hash(&mut hasher);
X(hasher)
}
impl<K: ?Sized, C: New + Clone> Clone for CountMinSketch<K, C> {
    /// Copies the counters, dimensions and configuration.
    ///
    /// The `offsets` scratch buffer is not copied: the clone receives a
    /// zero-filled buffer of the same length.
    fn clone(&self) -> Self {
        let Self {
            counters,
            offsets,
            mask,
            k_num,
            config,
            marker: _,
        } = self;
        Self {
            counters: counters.clone(),
            offsets: vec![0; offsets.len()],
            mask: *mask,
            k_num: *k_num,
            config: config.clone(),
            marker: PhantomData,
        }
    }
}
impl<K: ?Sized, C: New> fmt::Debug for CountMinSketch<K, C> {
    /// Writes only the type name; no fields are included in the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("CountMinSketch")
    }
}
#[cfg(test)]
mod tests {
    use super::CountMinSketch;

    // Aliases fixing the counter type exercised by each test.
    type CountMinSketch8<K> = CountMinSketch<K, u8>;
    type CountMinSketch16<K> = CountMinSketch<K, u16>;
    type CountMinSketch64<K> = CountMinSketch<K, u64>;

    /// Pushes 300 increments for one key into `u8` counters and expects a
    /// panic. Ignored by default (presumably because the panic depends on
    /// integer overflow checks being enabled; confirm before un-ignoring).
    #[ignore]
    #[test]
    #[should_panic]
    fn test_overflow() {
        let mut sketch = CountMinSketch8::<&str>::new(0.95, 10.0 / 100.0, ());
        (0..300).for_each(|_| {
            let _ = sketch.push("key", &1);
        });
    }

    /// A single key pushed 300 times into `u16` counters reads back as 300.
    #[test]
    fn test_increment() {
        let mut sketch = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
        (0..300).for_each(|_| {
            let _ = sketch.push("key", &1);
        });
        let estimate = sketch.get("key");
        assert_eq!(estimate, 300);
    }

    /// One million pushes spread over 100 keys: every key's estimate must be
    /// at least 9 000.
    #[test]
    fn test_increment_multi() {
        let mut sketch = CountMinSketch64::<u64>::new(0.99, 2.0 / 100.0, ());
        for round in 0..1_000_000 {
            let _ = sketch.push(&(round % 100), &1);
        }
        for key in 0..100 {
            assert!(sketch.get(&key) >= 9_000);
        }
    }
}