use std::collections::hash_map::DefaultHasher;
use std::f64;
use std::fmt;
use std::hash::{BuildHasher, BuildHasherDefault, Hash};
use std::marker::PhantomData;
use num_traits::{CheckedAdd, One, Unsigned, Zero};
use crate::hash_utils::HashIterBuilder;
/// A Count-Min sketch: a table of `d` rows by `w` columns of counters used to estimate
/// how often elements of type `T` have been added.
///
/// Each added element increments one counter per row (its column in that row is chosen
/// by hashing); a point query reports the minimum of those counters.
///
/// * `C` is the counter type (defaults to `usize`).
/// * `B` builds the hashers used to pick columns (defaults to the standard library's
///   `DefaultHasher`).
pub struct CountMinSketch<T, C = usize, B = BuildHasherDefault<DefaultHasher>>
where
    T: Hash + ?Sized,
    C: CheckedAdd + Clone + One + Ord + Unsigned + Zero,
    B: BuildHasher + Clone + Eq,
{
    /// Counters stored row by row; row `i`, column `j` lives at index `i * w + j`.
    table: Vec<C>,
    /// Number of columns per row.
    w: usize,
    /// Number of rows.
    d: usize,
    /// Yields, for an element, its column position in each row.
    builder: HashIterBuilder<B>,
    /// Ties the sketch to `T` without storing a `T`; `fn() -> T` keeps auto traits such
    /// as `Send` independent of `T`.
    phantom: PhantomData<fn() -> T>,
}

// Implemented by hand rather than with `#[derive(Clone)]`: the derive would require
// `T: Clone` (and therefore `T: Sized`), which needlessly prevents cloning sketches over
// unsized element types such as `str`, even though no `T` value is ever stored.
impl<T, C, B> Clone for CountMinSketch<T, C, B>
where
    T: Hash + ?Sized,
    C: CheckedAdd + Clone + One + Ord + Unsigned + Zero,
    B: BuildHasher + Clone + Eq,
{
    fn clone(&self) -> Self {
        Self {
            table: self.table.clone(),
            w: self.w,
            d: self.d,
            builder: self.builder.clone(),
            phantom: PhantomData,
        }
    }
}
impl<T, C> CountMinSketch<T, C>
where
    T: Hash + ?Sized,
    C: CheckedAdd + Clone + One + Ord + Unsigned + Zero,
{
    /// Creates an empty sketch with `w` columns and `d` rows, hashing with the standard
    /// library's `DefaultHasher`.
    ///
    /// See `with_params_and_hasher` for the conditions under which construction panics.
    pub fn with_params(w: usize, d: usize) -> Self {
        // `Self` pins the hasher type to `BuildHasherDefault<DefaultHasher>`, so
        // `Default::default()` produces exactly that builder.
        Self::with_params_and_hasher(w, d, Default::default())
    }

    /// Creates an empty sketch sized from the point-query parameters `epsilon` and
    /// `delta`, hashing with the standard library's `DefaultHasher`.
    ///
    /// See `with_point_query_properties_and_hasher` for how the dimensions are derived
    /// and which inputs are rejected.
    pub fn with_point_query_properties(epsilon: f64, delta: f64) -> Self {
        Self::with_point_query_properties_and_hasher(epsilon, delta, Default::default())
    }
}
impl<T, C, B> CountMinSketch<T, C, B>
where
    T: Hash + ?Sized,
    C: CheckedAdd + Clone + One + Ord + Unsigned + Zero,
    B: BuildHasher + Clone + Eq,
{
    /// Creates an empty sketch with `w` columns and `d` rows that hashes with
    /// `buildhasher`.
    ///
    /// # Panics
    ///
    /// Panics if `w` or `d` is zero, or if `w * d` overflows `usize`.
    pub fn with_params_and_hasher(w: usize, d: usize, buildhasher: B) -> Self {
        // Reject degenerate shapes up front. With `w == 0` the table is empty, so every
        // table lookup would be out of bounds. With `d == 0` there are no rows, so
        // `query_point` would have nothing to take a minimum over and `add_n` would store
        // nothing.
        assert!(w > 0, "number of columns (w) must be greater than 0");
        assert!(d > 0, "number of rows (d) must be greater than 0");
        let len = w
            .checked_mul(d)
            .expect("table size (w * d) overflows usize");
        Self {
            table: vec![C::zero(); len],
            w,
            d,
            builder: HashIterBuilder::new(w, d, buildhasher),
            phantom: PhantomData,
        }
    }

    /// Creates an empty sketch sized from point-query parameters, hashing with
    /// `buildhasher`.
    ///
    /// The dimensions are `w = ceil(e / epsilon)` and `d = ceil(ln(1 / delta))`.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not greater than 0 (this includes NaN). Panics if `delta`
    /// is not strictly between 0 and 1. Panics if `epsilon` is positive infinity, which
    /// would yield `w == 0`. Panics if the resulting table is too large to allocate.
    pub fn with_point_query_properties_and_hasher(
        epsilon: f64,
        delta: f64,
        buildhasher: B,
    ) -> Self {
        assert!(epsilon > 0., "epsilon must be greater than 0");
        assert!(
            delta > 0. && delta < 1.,
            "delta ({}) must be greater than 0 and smaller than 1",
            delta
        );
        // `as usize` saturates on huge values, so a tiny `epsilon` produces a very large
        // `w`. That then fails loudly when sizing or allocating the table.
        let w = (f64::consts::E / epsilon).ceil() as usize;
        let d = (1. / delta).ln().ceil() as usize;
        Self::with_params_and_hasher(w, d, buildhasher)
    }

    /// Number of columns per row.
    pub fn w(&self) -> usize {
        self.w
    }

    /// Number of rows.
    pub fn d(&self) -> usize {
        self.d
    }

    /// The hash builder used to place elements into columns.
    pub fn buildhasher(&self) -> &B {
        self.builder.buildhasher()
    }

    /// Returns `true` if every counter is zero.
    pub fn is_empty(&self) -> bool {
        self.table.iter().all(|x| x.is_zero())
    }

    /// Adds one occurrence of `obj` and returns its updated estimated count.
    ///
    /// # Panics
    ///
    /// Panics if a counter overflows `C`.
    pub fn add(&mut self, obj: &T) -> C {
        self.add_n(obj, &C::one())
    }

    /// Adds `n` occurrences of `obj` and returns its updated estimated count.
    ///
    /// Every row's counter for `obj` grows by `n`. The returned estimate is therefore the
    /// smallest previous counter plus `n`.
    ///
    /// # Panics
    ///
    /// Panics if a counter overflows `C`. Rows visited before the overflowing one have
    /// already been incremented when the panic occurs.
    pub fn add_n(&mut self, obj: &T, n: &C) -> C {
        let mut min_before: Option<C> = None;
        for (row, col) in self.builder.iter_for(obj).enumerate() {
            let cell = &mut self.table[row * self.w + col];
            let before = cell.clone();
            *cell = before
                .checked_add(n)
                .expect("counter overflow while adding");
            min_before = Some(match min_before {
                Some(current_min) => current_min.min(before),
                None => before,
            });
        }
        // Cannot overflow in practice: the smallest counter plus `n` was just stored
        // successfully in its own cell.
        min_before
            .unwrap_or_else(C::zero)
            .checked_add(n)
            .expect("counter overflow while adding")
    }

    /// Returns the estimated count of `obj`, which is the minimum of its counters across
    /// all rows.
    ///
    /// Counters only grow until `clear` is called. The estimate is therefore never below
    /// the total amount added for `obj`, including amounts merged in from other sketches.
    pub fn query_point(&self, obj: &T) -> C {
        self.builder
            .iter_for(obj)
            .enumerate()
            .map(|(row, col)| self.table[row * self.w + col].clone())
            .min()
            .expect("sketch has at least one row")
    }

    /// Adds every counter of `other` into `self`. Afterwards, `self` estimates the
    /// combined counts of both sketches.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions or the hash builders differ, or if a counter overflows.
    /// On overflow, `self` is left unchanged.
    pub fn merge(&mut self, other: &Self) {
        assert_eq!(
            self.d, other.d,
            "number of rows (d) must be equal (left={}, right={})",
            self.d, other.d
        );
        assert_eq!(
            self.w, other.w,
            "number of columns (w) must be equal (left={}, right={})",
            self.w, other.w
        );
        assert!(
            self.buildhasher() == other.buildhasher(),
            "buildhasher must be equal"
        );
        // Build the summed table before replacing the old one, so that an overflow panic
        // leaves `self` untouched.
        let merged: Vec<C> = self
            .table
            .iter()
            .zip(&other.table)
            .map(|(mine, theirs)| {
                mine.checked_add(theirs)
                    .expect("counter overflow while merging")
            })
            .collect();
        self.table = merged;
    }

    /// Resets every counter to zero, keeping the dimensions and the hash builder.
    pub fn clear(&mut self) {
        // The table already has `w * d` cells, so overwrite them in place instead of
        // reallocating.
        for cell in self.table.iter_mut() {
            *cell = C::zero();
        }
    }
}
impl<T> fmt::Debug for CountMinSketch<T>
where
T: Hash + ?Sized,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "CountMinSketch {{ w: {}, d: {} }}", self.w, self.d)
}
}
impl<T> Extend<T> for CountMinSketch<T>
where
T: Hash,
{
fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {
for elem in iter {
self.add(&elem);
}
}
}
#[cfg(test)]
mod tests {
use super::CountMinSketch;
use crate::{
hash_utils::BuildHasherSeeded,
test_util::{assert_send, NotSend},
};
#[test]
fn getter() {
let cms = CountMinSketch::<u64>::with_params(10, 20);
assert_eq!(cms.w(), 10);
assert_eq!(cms.d(), 20);
cms.buildhasher();
}
#[test]
fn properties() {
let cms = CountMinSketch::<u64>::with_point_query_properties(0.01, 0.1);
assert_eq!(cms.w(), 272);
assert_eq!(cms.d(), 3);
}
#[test]
#[should_panic(expected = "epsilon must be greater than 0")]
fn properties_panics_epsilon0() {
CountMinSketch::<u64>::with_point_query_properties(0., 0.1);
}
#[test]
#[should_panic(expected = "delta (0) must be greater than 0 and smaller than 1")]
fn properties_panics_delta0() {
CountMinSketch::<u64>::with_point_query_properties(0.01, 0.);
}
#[test]
#[should_panic(expected = "delta (1) must be greater than 0 and smaller than 1")]
fn properties_panics_delta1() {
CountMinSketch::<u64>::with_point_query_properties(0.01, 1.);
}
#[test]
fn empty() {
let cms = CountMinSketch::<u64>::with_params(10, 10);
assert_eq!(cms.query_point(&1u64), 0);
assert!(cms.is_empty());
}
#[test]
fn add_1() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
assert_eq!(cms.add(&1), 1);
assert_eq!(cms.query_point(&1), 1);
assert_eq!(cms.query_point(&2), 0);
}
#[test]
fn add_2() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
assert_eq!(cms.add(&1), 1);
assert_eq!(cms.add(&1), 2);
assert_eq!(cms.query_point(&1), 2);
assert_eq!(cms.query_point(&2), 0);
}
#[test]
fn add_2_1a() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
assert_eq!(cms.add(&1), 1);
assert_eq!(cms.add(&2), 1);
assert_eq!(cms.add(&1), 2);
assert_eq!(cms.query_point(&1), 2);
assert_eq!(cms.query_point(&2), 1);
assert_eq!(cms.query_point(&3), 0);
}
#[test]
fn add_2_1b() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
assert_eq!(cms.add_n(&1, &2), 2);
assert_eq!(cms.add(&2), 1);
assert_eq!(cms.query_point(&1), 2);
assert_eq!(cms.query_point(&2), 1);
assert_eq!(cms.query_point(&3), 0);
}
#[test]
fn merge() {
let mut cms1 = CountMinSketch::<u64>::with_params(10, 10);
let mut cms2 = CountMinSketch::<u64>::with_params(10, 10);
cms1.add_n(&1, &1);
cms1.add_n(&2, &2);
assert_eq!(cms1.query_point(&1), 1);
assert_eq!(cms1.query_point(&2), 2);
assert_eq!(cms1.query_point(&3), 0);
assert_eq!(cms1.query_point(&4), 0);
cms2.add_n(&2, &20);
cms2.add_n(&3, &30);
assert_eq!(cms2.query_point(&1), 0);
assert_eq!(cms2.query_point(&2), 20);
assert_eq!(cms2.query_point(&3), 30);
assert_eq!(cms2.query_point(&4), 0);
cms1.merge(&cms2);
assert_eq!(cms1.query_point(&1), 1);
assert_eq!(cms1.query_point(&2), 22);
assert_eq!(cms1.query_point(&3), 30);
assert_eq!(cms1.query_point(&4), 0);
}
#[test]
#[should_panic(expected = "number of columns (w) must be equal (left=10, right=20)")]
fn merge_panics_w() {
let mut cms1 = CountMinSketch::<u64>::with_params(10, 10);
let cms2 = CountMinSketch::<u64>::with_params(20, 10);
cms1.merge(&cms2);
}
#[test]
#[should_panic(expected = "number of rows (d) must be equal (left=10, right=20)")]
fn merge_panics_d() {
let mut cms1 = CountMinSketch::<u64>::with_params(10, 10);
let cms2 = CountMinSketch::<u64>::with_params(10, 20);
cms1.merge(&cms2);
}
#[test]
#[should_panic(expected = "buildhasher must be equal")]
fn merge_panics_buildhasher() {
let mut cms1 = CountMinSketch::<u64, usize, BuildHasherSeeded>::with_params_and_hasher(
10,
10,
BuildHasherSeeded::new(0),
);
let cms2 = CountMinSketch::<u64, usize, BuildHasherSeeded>::with_params_and_hasher(
10,
10,
BuildHasherSeeded::new(1),
);
cms1.merge(&cms2);
}
#[test]
fn clear() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
cms.add(&1);
assert_eq!(cms.query_point(&1), 1);
assert!(!cms.is_empty());
cms.clear();
assert_eq!(cms.query_point(&1), 0);
assert!(cms.is_empty());
}
#[test]
fn clone() {
let mut cms1 = CountMinSketch::<u64>::with_params(10, 10);
cms1.add(&1);
assert_eq!(cms1.query_point(&1), 1);
assert_eq!(cms1.query_point(&2), 0);
let cms2 = cms1.clone();
assert_eq!(cms2.query_point(&1), 1);
assert_eq!(cms2.query_point(&2), 0);
cms1.add(&1);
assert_eq!(cms1.query_point(&1), 2);
assert_eq!(cms1.query_point(&2), 0);
assert_eq!(cms2.query_point(&1), 1);
assert_eq!(cms2.query_point(&2), 0);
}
#[test]
fn debug() {
let cms = CountMinSketch::<u64>::with_params(10, 20);
assert_eq!(format!("{:?}", cms), "CountMinSketch { w: 10, d: 20 }");
}
#[test]
fn extend() {
let mut cms = CountMinSketch::<u64>::with_params(10, 10);
cms.extend(vec![1, 1]);
assert_eq!(cms.query_point(&1), 2);
assert_eq!(cms.query_point(&2), 0);
}
#[test]
fn add_unsized() {
let mut cms = CountMinSketch::<str, usize>::with_params(10, 10);
assert_eq!(cms.add("test"), 1);
assert_eq!(cms.query_point("test"), 1);
assert_eq!(cms.query_point("foo"), 0);
}
#[test]
fn send() {
let cms = CountMinSketch::<NotSend>::with_params(10, 10);
assert_send(&cms);
}
}