use crate::{DoubleHasher, HashIter, SipHasherBuilder};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::hash::BuildHasher;
use std::hash::Hash;
use std::marker::PhantomData;
/// A policy for turning the counters selected for one item into a single count estimate.
///
/// `CountMinSketch::count` calls [`CountStrategy::get_estimate`] with the sketch's running
/// total of inserted values (`items`), its dimensions (`rows`, `cols`), and an iterator that
/// yields the item's counter from each row in turn.
pub trait CountStrategy {
    /// Combines the per-row counters yielded by `iter` into an estimated count.
    fn get_estimate(
        items: i64,
        rows: usize,
        cols: usize,
        iter: ItemValueIter<'_>,
    ) -> i64;
}
/// Estimates a count as the smallest of the item's per-row counters.
///
/// # Panics
///
/// Panics if the iterator yields nothing (a sketch with zero rows).
pub struct CountMinStrategy;

impl CountStrategy for CountMinStrategy {
    fn get_estimate(_items: i64, _rows: usize, _cols: usize, iter: ItemValueIter<'_>) -> i64 {
        match iter.min() {
            Some(smallest) => smallest,
            None => panic!("Expected `CountMinSketch` to be non-empty."),
        }
    }
}
/// Estimates a count as the arithmetic mean of the item's per-row counters.
///
/// The counters are summed, divided by `rows`, and rounded to the nearest integer with
/// `f64::round` (halves away from zero). With zero rows this computes `0.0 / 0.0`, and the
/// resulting NaN casts to 0.
pub struct CountMeanStrategy;

impl CountStrategy for CountMeanStrategy {
    fn get_estimate(_items: i64, rows: usize, _cols: usize, iter: ItemValueIter<'_>) -> i64 {
        let total: i64 = iter.sum();
        let mean = total as f64 / rows as f64;
        mean.round() as i64
    }
}
/// Estimates a count by correcting each row's counter for collision noise.
///
/// For every per-row counter `value`, the weight of all other items (`items - value`) is
/// assumed to be spread evenly over the row's other `cols - 1` columns. That share, rounded
/// up, is subtracted from `value`. The estimate is the smaller of the lower median of these
/// corrected values and the plain minimum counter.
///
/// # Panics
///
/// Panics if the iterator yields nothing (a sketch with zero rows).
pub struct CountMedianBiasStrategy;

impl CountStrategy for CountMedianBiasStrategy {
    fn get_estimate(items: i64, _rows: usize, cols: usize, iter: ItemValueIter<'_>) -> i64 {
        // Walk the item's hash sequence once and reuse the counters, instead of cloning the
        // iterator and hashing/looking up every row a second time.
        let values: Vec<i64> = iter.collect();
        let min_count = values
            .iter()
            .copied()
            .min()
            .expect("Expected `CountMinSketch` to be non-empty.");

        // With a single column there are no "other" columns to spread noise over: the
        // denominator `cols - 1` would be zero. Every row's only counter then holds the full
        // total, so the minimum is already the answer.
        if cols <= 1 {
            return min_count;
        }

        let other_columns = (cols - 1) as f64;
        let mut items_with_bias: Vec<i64> = values
            .iter()
            .map(|&value| value - ((items - value) as f64 / other_columns).ceil() as i64)
            .collect();
        items_with_bias.sort_unstable();
        // Lower median for an even number of rows.
        let median_count = items_with_bias[(items_with_bias.len() - 1) / 2];
        i64::min(min_count, median_count)
    }
}
/// A count-min sketch: a `rows` x `cols` grid of signed counters used to estimate how much
/// weight has been inserted for each item.
///
/// Type parameters:
/// - `T`: the [`CountStrategy`] that `count` uses to combine an item's per-row counters.
/// - `U`: the item type; `insert`/`remove`/`count` accept any `V` with `U: Borrow<V>`.
/// - `B`: the hash builder type of the two underlying hashers (defaults to `SipHasherBuilder`).
///
/// NOTE: with the `serde` feature enabled, the derived impls serialize these fields by name
/// and in declaration order, so renaming or reordering them changes the serialized format.
#[cfg_attr(
feature = "serde",
derive(Deserialize, Serialize),
serde(crate = "serde_crate")
)]
pub struct CountMinSketch<T, U, B = SipHasherBuilder> {
/// Number of rows; one hash of an item selects one counter per row.
rows: usize,
/// Number of counters in each row; a hash selects column `hash % cols`.
cols: usize,
/// Running sum of every `value` passed to `insert` (and negated by `remove`).
items: i64,
/// Counters stored row-major: row `r`, column `c` lives at index `r * cols + c`.
grid: Vec<i64>,
/// Produces the per-row hash sequence for an item.
hasher: DoubleHasher<U, B>,
/// Ties the strategy `T` and item type `U` to the sketch without storing values of them.
_marker: PhantomData<(T, U)>,
}
impl<T, U> CountMinSketch<T, U> {
    /// Constructs an empty sketch with `rows` hash rows of `cols` counters each, hashing items
    /// with `DoubleHasher::new()` (the default `SipHasherBuilder` hasher type).
    ///
    /// The dimensions are not validated. With zero columns (and at least one row), inserting
    /// or counting panics. With zero rows, estimates depend on the strategy `T`.
    pub fn new(rows: usize, cols: usize) -> Self {
        CountMinSketch {
            rows,
            cols,
            items: 0,
            grid: vec![0; rows * cols],
            hasher: DoubleHasher::new(),
            _marker: PhantomData,
        }
    }

    /// Constructs an empty sketch sized from an error factor `epsilon` and a probability
    /// bound `delta`, using the default hashers.
    ///
    /// The sketch gets `cols = ceil(e / epsilon)` counters per row and
    /// `rows = ceil(ln(1 / delta))` rows. Consequently `e / cols <= epsilon` and
    /// `e^-rows <= delta`.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not a positive, finite number, or if `delta` is not strictly
    /// between 0 and 1. Outside those ranges the formulas above yield zero rows or columns,
    /// or an absurdly large grid.
    pub fn from_error(epsilon: f64, delta: f64) -> Self {
        assert!(
            epsilon > 0.0 && epsilon.is_finite(),
            "`epsilon` must be a positive, finite number"
        );
        assert!(
            delta > 0.0 && delta < 1.0,
            "`delta` must be strictly between 0 and 1"
        );
        let rows = (1.0 / delta).ln().ceil() as usize;
        let cols = (1.0_f64.exp() / epsilon).ceil() as usize;
        Self::new(rows, cols)
    }
}
impl<T, U, B> CountMinSketch<T, U, B>
where
    T: CountStrategy,
    B: BuildHasher,
{
    /// Constructs an empty sketch with `rows` hash rows of `cols` counters each, hashing items
    /// with the two given hash builders.
    ///
    /// The dimensions are not validated; see [`insert`](Self::insert) and
    /// [`count`](Self::count) for what happens when one of them is zero.
    pub fn with_hashers(rows: usize, cols: usize, hash_builders: [B; 2]) -> Self {
        CountMinSketch {
            rows,
            cols,
            items: 0,
            grid: vec![0; rows * cols],
            hasher: DoubleHasher::with_hashers(hash_builders),
            _marker: PhantomData,
        }
    }

    /// Constructs an empty sketch sized from an error factor `epsilon` and a probability
    /// bound `delta`, hashing items with the two given hash builders.
    ///
    /// The sketch gets `cols = ceil(e / epsilon)` counters per row and
    /// `rows = ceil(ln(1 / delta))` rows. Consequently [`confidence`](Self::confidence) is at
    /// most `epsilon` and [`error`](Self::error) is at most `delta`.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not a positive, finite number, or if `delta` is not strictly
    /// between 0 and 1. Outside those ranges the formulas above yield zero rows or columns,
    /// or an absurdly large grid.
    pub fn from_error_with_hashers(epsilon: f64, delta: f64, hash_builders: [B; 2]) -> Self {
        assert!(
            epsilon > 0.0 && epsilon.is_finite(),
            "`epsilon` must be a positive, finite number"
        );
        assert!(
            delta > 0.0 && delta < 1.0,
            "`delta` must be strictly between 0 and 1"
        );
        let rows = (1.0 / delta).ln().ceil() as usize;
        let cols = (1.0_f64.exp() / epsilon).ceil() as usize;
        Self::with_hashers(rows, cols, hash_builders)
    }

    /// Adds `value` to the tracked count of `item`.
    ///
    /// The running total of all inserted values grows by `value`. In each row, the counter at
    /// column `hash % cols` (using that row's hash of `item`) also grows by `value`. `value`
    /// may be negative; [`remove`](Self::remove) relies on this.
    ///
    /// # Panics
    ///
    /// Panics if the sketch has at least one row but zero columns.
    pub fn insert<V>(&mut self, item: &V, value: i64)
    where
        U: Borrow<V>,
        V: Hash + ?Sized,
    {
        self.items += value;
        for (row, hash) in self.hasher.hash(item).take(self.rows).enumerate() {
            let col = (hash % self.cols as u64) as usize;
            self.grid[row * self.cols + col] += value;
        }
    }

    /// Subtracts `value` from the tracked count of `item`; equivalent to
    /// `insert(item, -value)`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`insert`](Self::insert). Negating `i64::MIN` also overflows, which
    /// panics in debug builds.
    pub fn remove<V>(&mut self, item: &V, value: i64)
    where
        U: Borrow<V>,
        V: Hash + ?Sized,
    {
        self.insert(item, -value);
    }

    /// Returns the estimated count of `item`.
    ///
    /// The strategy `T` receives four inputs and its result is returned:
    /// - the item's counter in each row, selected exactly as in [`insert`](Self::insert);
    /// - the running total of inserted values;
    /// - the number of rows;
    /// - the number of columns.
    ///
    /// # Panics
    ///
    /// Panics if the sketch has at least one row but zero columns. For a sketch with zero
    /// rows the outcome depends on `T` (`CountMinStrategy` panics).
    pub fn count<V>(&self, item: &V) -> i64
    where
        U: Borrow<V>,
        V: Hash + ?Sized,
    {
        let iter = ItemValueIter {
            row: 0,
            rows: self.rows,
            cols: self.cols,
            hash_iter: self.hasher.hash(item),
            grid: &self.grid,
        };
        T::get_estimate(self.items, self.rows, self.cols, iter)
    }

    /// Resets every counter and the running total to zero, keeping dimensions and hashers.
    pub fn clear(&mut self) {
        for value in self.grid.iter_mut() {
            *value = 0;
        }
        self.items = 0;
    }

    /// Returns the number of rows (hash functions) in the sketch.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Returns the number of counters in each row.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns `e / cols`, which is at most the `epsilon` passed to `from_error`.
    ///
    /// NOTE(review): despite its name this is the error-factor (epsilon) parameter, not a
    /// confidence level; [`error`](Self::error) returns the `delta` bound. The names are kept
    /// for API compatibility.
    pub fn confidence(&self) -> f64 {
        1.0_f64.exp() / self.cols as f64
    }

    /// Returns `e^-rows`, which is at most the `delta` passed to `from_error`.
    ///
    /// NOTE(review): see [`confidence`](Self::confidence) regarding the naming of these two
    /// accessors.
    pub fn error(&self) -> f64 {
        1.0_f64 / (self.rows as f64).exp()
    }

    /// Returns the two hash builders used to hash items.
    pub fn hashers(&self) -> &[B; 2] {
        self.hasher.hashers()
    }
}
/// Iterator over the counters selected for one item.
///
/// For each row in order it yields the counter at column `hash % cols` of that row, where
/// `hash` is the next value from the item's hash sequence. It is built by
/// `CountMinSketch::count` and handed to a [`CountStrategy`].
#[derive(Clone)]
pub struct ItemValueIter<'a> {
    /// Index of the next row to read.
    row: usize,
    /// Number of rows in the sketch; iteration stops once `row` reaches it.
    rows: usize,
    /// Number of counters per row.
    cols: usize,
    /// The sketch's counters, stored row-major (`row * cols + col`).
    grid: &'a [i64],
    /// The item's remaining hashes; one is consumed per row.
    hash_iter: HashIter,
}

impl Iterator for ItemValueIter<'_> {
    type Item = i64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.row >= self.rows {
            return None;
        }
        let hash = self.hash_iter.next()?;
        let col = (hash % self.cols as u64) as usize;
        let value = self.grid[self.row * self.cols + col];
        self.row += 1;
        Some(value)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // At most one value per remaining row. The lower bound stays 0 because nothing here
        // guarantees that the hash iterator never runs out early.
        (0, Some(self.rows.saturating_sub(self.row)))
    }
}
// Unit tests shared by every estimation strategy: the macro below stamps out one identical
// test module per `name: Strategy` pair listed at the bottom.
#[cfg(test)]
mod tests {
macro_rules! count_min_sketch_tests {
($($name:ident: $strategy:ident,)*) => {
$(
mod $name {
// `super::super` is the module that defines `CountMinSketch` and the strategies.
use super::super::{CountMinSketch, $strategy};
use crate::SipHasherBuilder;
// new(3, 28): e / 28 is about 0.097 and e^-3 is about 0.0498, so both bounds hold.
#[test]
fn test_new() {
let cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::new(3, 28);
assert_eq!(cms.cols(), 28);
assert_eq!(cms.rows(), 3);
assert!(cms.confidence() <= 0.1);
assert!(cms.error() <= 0.05);
}
// from_error(0.1, 0.05): cols = ceil(e / 0.1) = 28 and rows = ceil(ln 20) = 3.
#[test]
fn test_from_error() {
let cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::from_error(0.1, 0.05);
assert_eq!(cms.cols(), 28);
assert_eq!(cms.rows(), 3);
assert!(cms.confidence() <= 0.1);
assert!(cms.error() <= 0.05);
}
// A single item in an otherwise empty sketch has no collisions, so every strategy
// should report its exact count.
#[test]
fn test_insert() {
let mut cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::from_error(0.1, 0.05);
cms.insert("foo", 3);
assert_eq!(cms.count("foo"), 3);
}
// Removing the same amount that was inserted returns every touched counter to zero.
#[test]
fn test_remove() {
let mut cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::from_error(0.1, 0.05);
cms.insert("foo", 3);
cms.remove("foo", 3);
assert_eq!(cms.count("foo"), 0);
}
#[test]
fn test_clear() {
let mut cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::from_error(0.1, 0.05);
cms.insert("foo", 3);
cms.clear();
assert_eq!(cms.count("foo"), 0);
}
// Round-trips through bincode and checks estimates, dimensions, the raw private
// state (`items`, `grid`) and the hashers all survive unchanged.
#[cfg(feature = "serde")]
#[test]
fn test_ser_de() {
let mut cms = CountMinSketch::<$strategy, String, SipHasherBuilder>::from_error(0.1, 0.05);
cms.insert("foo", 3);
let serialized_cms = bincode::serialize(&cms).unwrap();
let de_cms: CountMinSketch<$strategy, String> = bincode::deserialize(&serialized_cms).unwrap();
assert_eq!(cms.count("foo"), de_cms.count("foo"));
assert_eq!(cms.rows(), de_cms.rows());
assert_eq!(cms.cols(), de_cms.cols());
assert_eq!(cms.items, de_cms.items);
assert_eq!(cms.grid, de_cms.grid);
assert_eq!(cms.hashers(), de_cms.hashers());
}
}
)*
}
}
count_min_sketch_tests!(
count_min_strategy: CountMinStrategy,
count_mean_strategy: CountMeanStrategy,
count_median_bias_strategy: CountMedianBiasStrategy,
);
}