use rmp_serde::{
decode::Error as RmpDecodeError, encode::Error as RmpEncodeError, from_slice, to_vec_named,
};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::octo_delta::{CM_PROMASK, CmDelta};
use crate::{
DataInput, DefaultMatrixI32, DefaultMatrixI64, DefaultMatrixI128, DefaultXxHasher, FastPath,
FastPathHasher, FixedMatrix, MatrixFastHash, MatrixStorage, NitroTarget, QuickMatrixI64,
QuickMatrixI128, RegularPath, SketchHasher, Vector2D, hash64_seeded,
};
// Default sketch geometry: 3 hash rows x 4096 counter columns.
const DEFAULT_ROW_NUM: usize = 3;
const DEFAULT_COL_NUM: usize = 4096;
// Geometry used by quickstart-configured sketches elsewhere in the crate.
pub const QUICKSTART_ROW_NUM: usize = 5;
pub const QUICKSTART_COL_NUM: usize = 2048;
// Keeps only the low 32 bits of a 64-bit hash before the column modulo.
const LOWER_32_MASK: u64 = (1u64 << 32) - 1;
/// Count-Min sketch parameterized over its counter storage `S`, an
/// insert-path marker `Mode` (`RegularPath` hashes per row; `FastPath`
/// derives all rows from one hash), and a hasher `H`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "S: Serialize", deserialize = "S: Deserialize<'de>"))]
pub struct CountMin<
    S: MatrixStorage = Vector2D<i32>,
    Mode = RegularPath,
    H: SketchHasher = DefaultXxHasher,
> {
    // Backing counter matrix; one row per hash function.
    counts: S,
    // Cached row count (mirrors counts.rows()).
    row: usize,
    // Cached column count (mirrors counts.cols()).
    col: usize,
    // Zero-sized path marker; not serialized.
    #[serde(skip)]
    _mode: PhantomData<Mode>,
    // Zero-sized hasher marker; not serialized.
    #[serde(skip)]
    _hasher: PhantomData<H>,
}
impl Default for CountMin<Vector2D<i32>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<i32>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<i64>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<i64>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<i128>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<i128>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<f64>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for CountMin<Vector2D<f64>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
// Defaults for the fixed/preconfigured matrix storages: the sketch simply
// adopts the storage's own default geometry via `from_storage`.
impl Default for CountMin<FixedMatrix, RegularPath> {
    fn default() -> Self {
        let storage = FixedMatrix::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<FixedMatrix, FastPath> {
    fn default() -> Self {
        let storage = FixedMatrix::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI32, RegularPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI32::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI32, FastPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI32::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<QuickMatrixI64, RegularPath> {
    fn default() -> Self {
        let storage = QuickMatrixI64::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<QuickMatrixI64, FastPath> {
    fn default() -> Self {
        let storage = QuickMatrixI64::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<QuickMatrixI128, RegularPath> {
    fn default() -> Self {
        let storage = QuickMatrixI128::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<QuickMatrixI128, FastPath> {
    fn default() -> Self {
        let storage = QuickMatrixI128::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI64, RegularPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI64::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI64, FastPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI64::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI128, RegularPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI128::default();
        CountMin::from_storage(storage)
    }
}
impl Default for CountMin<DefaultMatrixI128, FastPath> {
    fn default() -> Self {
        let storage = DefaultMatrixI128::default();
        CountMin::from_storage(storage)
    }
}
impl<T, M, H: SketchHasher> CountMin<Vector2D<T>, M, H>
where
    T: Copy + Default + std::ops::AddAssign,
{
    /// Allocates a `rows` x `cols` sketch and zeroes every counter to
    /// `T::default()`.
    pub fn with_dimensions(rows: usize, cols: usize) -> Self {
        let mut sketch = Self {
            counts: Vector2D::init(rows, cols),
            row: rows,
            col: cols,
            _mode: PhantomData,
            _hasher: PhantomData,
        };
        // Reset counters explicitly regardless of what `init` left behind.
        sketch.counts.fill(T::default());
        sketch
    }
}
impl<S: MatrixStorage, Mode, H: SketchHasher> CountMin<S, Mode, H> {
    /// Wraps an existing counter matrix, caching its dimensions.
    pub fn from_storage(counts: S) -> Self {
        let (row, col) = (counts.rows(), counts.cols());
        Self {
            counts,
            row,
            col,
            _mode: PhantomData,
            _hasher: PhantomData,
        }
    }
    /// Number of hash rows.
    #[inline(always)]
    pub fn rows(&self) -> usize {
        self.counts.rows()
    }
    /// Number of counter columns per row.
    #[inline(always)]
    pub fn cols(&self) -> usize {
        self.counts.cols()
    }
    /// Read-only access to the backing counter matrix.
    pub fn as_storage(&self) -> &S {
        &self.counts
    }
    /// Mutable access to the backing counter matrix.
    pub fn as_storage_mut(&mut self) -> &mut S {
        &mut self.counts
    }
    /// Element-wise addition of `other`'s counters into `self`.
    ///
    /// # Panics
    /// Panics when the two sketches do not share the same dimensions.
    pub fn merge(&mut self, other: &Self) {
        let dims = (self.counts.rows(), self.counts.cols());
        assert_eq!(
            dims,
            (other.counts.rows(), other.counts.cols()),
            "dimension mismatch while merging CountMin sketches"
        );
        let (rows, cols) = dims;
        for r in 0..rows {
            for c in 0..cols {
                let addend = other.counts.query_one_counter(r, c);
                self.counts.increment_by_row(r, c, addend);
            }
        }
    }
}
impl<S: MatrixStorage + Serialize, Mode, H: SketchHasher> CountMin<S, Mode, H> {
    /// Encodes the sketch as MessagePack with named fields (`rmp_serde`);
    /// the `_mode`/`_hasher` markers are skipped by the serde attributes.
    pub fn serialize_to_bytes(&self) -> Result<Vec<u8>, RmpEncodeError> {
        to_vec_named(self)
    }
}
impl<S: MatrixStorage + for<'de> Deserialize<'de>, Mode, H: SketchHasher> CountMin<S, Mode, H> {
    /// Decodes a sketch previously produced by `serialize_to_bytes`.
    pub fn deserialize_from_bytes(bytes: &[u8]) -> Result<Self, RmpDecodeError> {
        from_slice(bytes)
    }
}
impl<S: MatrixStorage, H: SketchHasher> CountMin<S, RegularPath, H>
where
    S::Counter: Copy + PartialOrd + From<i32> + std::ops::AddAssign,
{
    /// Inserts `value` once: in every row, one counter (picked from the low
    /// 32 bits of a per-row seeded hash, modulo the column count) is bumped.
    #[inline(always)]
    pub fn insert(&mut self, value: &DataInput) {
        // Delegate to insert_many with weight 1 to keep the hashing logic in
        // exactly one place.
        self.insert_many(value, S::Counter::from(1));
    }
    /// Inserts `value` with weight `many` (adds `many` to one counter per row).
    #[inline(always)]
    pub fn insert_many(&mut self, value: &DataInput, many: S::Counter) {
        let rows = self.counts.rows();
        let cols = self.counts.cols();
        for r in 0..rows {
            let hashed = H::hash64_seeded(r, value);
            let col = ((hashed & LOWER_32_MASK) as usize) % cols;
            self.counts.increment_by_row(r, col, many);
        }
    }
    /// Inserts every value in the slice once.
    #[inline(always)]
    pub fn bulk_insert(&mut self, values: &[DataInput]) {
        for value in values {
            self.insert(value);
        }
    }
    /// Inserts every `(value, weight)` pair.
    #[inline(always)]
    pub fn bulk_insert_many(&mut self, values: &[(DataInput, S::Counter)]) {
        for (value, many) in values {
            self.insert_many(value, *many);
        }
    }
    /// Point estimate for `value`: the minimum counter across all rows
    /// (classic Count-Min upper-bound estimate).
    ///
    /// Fix: the minimum is now seeded from the first row's counter instead of
    /// an `i32::MAX` sentinel, which silently capped estimates for counter
    /// types wider than `i32` (e.g. `i64` counters above `i32::MAX`).
    /// A zero-row sketch estimates 0.
    #[inline(always)]
    pub fn estimate(&self, value: &DataInput) -> S::Counter {
        let rows = self.counts.rows();
        let cols = self.counts.cols();
        let mut min: Option<S::Counter> = None;
        for r in 0..rows {
            let hashed = H::hash64_seeded(r, value);
            let col = ((hashed & LOWER_32_MASK) as usize) % cols;
            let v = self.counts.query_one_counter(r, col);
            min = match min {
                Some(m) if m < v => Some(m),
                _ => Some(v),
            };
        }
        min.unwrap_or_else(|| S::Counter::from(0))
    }
}
/// Count-Min sketch with `f64` counters on the regular (per-row hash) path.
pub type CountMinF64<H = DefaultXxHasher> = CountMin<Vector2D<f64>, RegularPath, H>;
impl<S: MatrixStorage<Counter = i32>, H: SketchHasher> CountMin<S, RegularPath, H> {
    /// Regular-path insert that also reports a `CmDelta` whenever a counter
    /// crosses a multiple of `CM_PROMASK`, so a parent sketch can stay in
    /// sync by replaying the emitted deltas via `apply_delta`.
    ///
    /// Fix: hash with `H::hash64_seeded` (the configured hasher) instead of
    /// the free `hash64_seeded`, so emitted deltas address the same cells as
    /// `insert`/`estimate` for any hasher `H`, not only the default one.
    #[inline(always)]
    pub fn insert_emit_delta(&mut self, value: &DataInput, emit: &mut impl FnMut(CmDelta)) {
        let rows = self.counts.rows();
        let cols = self.counts.cols();
        for r in 0..rows {
            let hashed = H::hash64_seeded(r, value);
            let col = ((hashed & LOWER_32_MASK) as usize) % cols;
            self.counts.increment_by_row(r, col, 1);
            let current = self.counts.query_one_counter(r, col);
            // One delta per CM_PROMASK increments of this cell.
            if current % CM_PROMASK as i32 == 0 {
                emit(CmDelta {
                    row: r as u16,
                    col: col as u16,
                    value: CM_PROMASK,
                });
            }
        }
    }
}
impl<S, H: SketchHasher> CountMin<S, FastPath, H>
where
    S: MatrixStorage<Counter = i32> + FastPathHasher<H>,
{
    /// Fast-path insert that reports a `CmDelta` each time a counter crosses
    /// a multiple of `CM_PROMASK`. The key is hashed once and the per-row
    /// columns are derived from that single hash.
    #[inline(always)]
    pub fn insert_emit_delta(&mut self, value: &DataInput, emit: &mut impl FnMut(CmDelta)) {
        let key_hash = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        let (rows, cols) = (self.counts.rows(), self.counts.cols());
        for row in 0..rows {
            let col = key_hash.col_for_row(row, cols);
            self.counts.increment_by_row(row, col, 1);
            let updated = self.counts.query_one_counter(row, col);
            if updated % CM_PROMASK as i32 == 0 {
                emit(CmDelta {
                    row: row as u16,
                    col: col as u16,
                    value: CM_PROMASK,
                });
            }
        }
    }
}
impl<S: MatrixStorage, Mode, H: SketchHasher> CountMin<S, Mode, H>
where
    S::Counter: Copy + std::ops::AddAssign + From<i32>,
{
    /// Applies a replicated `CmDelta` by bumping the addressed counter.
    pub fn apply_delta(&mut self, delta: CmDelta) {
        let (r, c) = (delta.row as usize, delta.col as usize);
        let amount = S::Counter::from(delta.value as i32);
        self.counts.increment_by_row(r, c, amount);
    }
}
impl<S, H: SketchHasher> CountMin<S, FastPath, H>
where
    S: MatrixStorage + crate::FastPathHasher<H>,
    S::Counter: Copy + PartialOrd + From<i32> + std::ops::AddAssign,
{
    /// Hashes `value` once and bumps one counter per row.
    #[inline(always)]
    pub fn insert(&mut self, value: &DataInput) {
        let key_hash = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        let one = S::Counter::from(1);
        self.counts.fast_insert(|slot, inc, _| *slot += *inc, one, &key_hash);
    }
    /// Hashes `value` once and adds `many` to one counter per row.
    #[inline(always)]
    pub fn insert_many(&mut self, value: &DataInput, many: S::Counter) {
        let key_hash = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        self.counts.fast_insert(|slot, inc, _| *slot += *inc, many, &key_hash);
    }
    /// Inserts every value in the slice once.
    #[inline(always)]
    pub fn bulk_insert(&mut self, values: &[DataInput]) {
        values.iter().for_each(|v| self.insert(v));
    }
    /// Inserts every `(value, weight)` pair.
    #[inline(always)]
    pub fn bulk_insert_many(&mut self, values: &[(DataInput, S::Counter)]) {
        for (value, many) in values {
            self.insert_many(value, *many);
        }
    }
    /// Point estimate for `value`: the minimum counter across rows.
    #[inline(always)]
    pub fn estimate(&self, value: &DataInput) -> S::Counter {
        let key_hash = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        self.counts.fast_query_min(&key_hash, |slot, _, _| *slot)
    }
}
impl<S, H: SketchHasher> CountMin<S, FastPath, H>
where
    S: MatrixStorage,
    S::Counter: Copy + PartialOrd + From<i32> + std::ops::AddAssign,
{
    /// Inserts once from a precomputed hash (skips hashing the key again).
    #[inline(always)]
    pub fn fast_insert_with_hash_value(&mut self, hashed_val: &H::HashType) {
        let one = S::Counter::from(1);
        self.counts.fast_insert(|slot, inc, _| *slot += *inc, one, hashed_val);
    }
    /// Inserts a weighted count from a precomputed hash.
    #[inline(always)]
    pub fn fast_insert_many_with_hash_value(&mut self, hashed_val: &H::HashType, many: S::Counter) {
        self.counts.fast_insert(|slot, inc, _| *slot += *inc, many, hashed_val);
    }
    /// Inserts each precomputed hash once.
    #[inline(always)]
    pub fn bulk_insert_with_hashes(&mut self, hashes: &[H::HashType]) {
        hashes
            .iter()
            .for_each(|h| self.fast_insert_with_hash_value(h));
    }
    /// Inserts each `(hash, weight)` pair.
    #[inline(always)]
    pub fn bulk_insert_many_with_hashes(&mut self, hashes: &[(H::HashType, S::Counter)]) {
        for (h, many) in hashes {
            self.fast_insert_many_with_hash_value(h, *many);
        }
    }
    /// Point estimate (row-wise minimum) from a precomputed hash.
    #[inline(always)]
    pub fn fast_estimate_with_hash(&self, hashed_val: &H::HashType) -> S::Counter {
        self.counts.fast_query_min(hashed_val, |slot, _, _| *slot)
    }
}
impl<H: SketchHasher> CountMin<Vector2D<i32>, FastPath, H> {
    /// Enables nitro (sampled) updates on the backing matrix with the given
    /// sampling rate.
    pub fn enable_nitro(&mut self, sampling_rate: f64) {
        self.counts.enable_nitro(sampling_rate);
    }
    /// Disables nitro sampling on the backing matrix.
    pub fn disable_nitro(&mut self) {
        self.counts.disable_nitro();
    }
    /// Sampled insert: rows are skipped according to the current nitro skip
    /// budget, and a touched row is bumped by the nitro `delta` instead of 1.
    #[inline(always)]
    pub fn fast_insert_nitro(&mut self, value: &DataInput) {
        let rows = self.counts.rows();
        let delta = self.counts.nitro().delta as i32;
        if self.counts.nitro().to_skip >= rows {
            // Every row of this insert is skipped; consume `rows` from the
            // skip budget and do no counter work.
            self.counts.reduce_nitro_skip(rows);
        } else {
            // The skip budget lands inside this insert: update exactly the
            // row it points at, then draw the next gap.
            let hashed = H::hash128_seeded(0, value);
            let r = self.counts.nitro().to_skip;
            self.counts.update_by_row(r, hashed, |a, b| *a += b, delta);
            self.counts.nitro_mut().draw_geometric();
            let temp = self.counts.get_nitro_skip();
            // Carry the remainder of the new gap into the next call.
            // NOTE(review): this usize subtraction assumes r + temp + 1 >= rows,
            // presumably guaranteed by draw_geometric's minimum gap — confirm,
            // otherwise it underflows and panics in debug builds.
            self.counts.update_nitro_skip((r + temp + 1) - rows);
        }
    }
    /// Estimate under nitro sampling: the median of the per-row counters for
    /// `value`, returned as f64.
    /// NOTE(review): the raw counter values are not rescaled here by the
    /// sampling rate — verify the caller compensates if an unbiased count is
    /// expected.
    pub fn nitro_estimate(&self, value: &DataInput) -> f64 {
        let hashed_val = <Vector2D<i32> as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        self.counts
            .fast_query_median(&hashed_val, |val, _, _| (*val) as f64)
    }
}
/// Lets the shared nitro driver treat this sketch as an update target.
impl<H: SketchHasher> NitroTarget for CountMin<Vector2D<i32>, FastPath, H> {
    #[inline(always)]
    fn rows(&self) -> usize {
        self.counts.rows()
    }
    #[inline(always)]
    fn update_row(&mut self, row: usize, hashed: u128, delta: u64) {
        // NOTE(review): the u64 -> i32 cast truncates deltas above i32::MAX.
        let step = delta as i32;
        self.counts.update_by_row(row, hashed, |slot, inc| *slot += inc, step);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{
all_counter_zero_i32, counter_index, sample_uniform_f64, sample_zipf_u64,
};
use crate::{DataInput, hash64_seeded};
use core::f64;
use std::collections::HashMap;
#[test]
fn countmin_insert_emit_delta_emits_at_threshold_and_resets_period() {
    let mut sketch = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
    let key = DataInput::U64(42);
    let mut emitted: Vec<CmDelta> = Vec::new();
    // Drive the key's counters to one below the emission threshold.
    for _ in 1..CM_PROMASK {
        sketch.insert_emit_delta(&key, &mut |d| emitted.push(d));
    }
    assert!(
        emitted.is_empty(),
        "regular CMS worker path should not emit before threshold"
    );
    // Crossing the threshold emits exactly one delta per row.
    sketch.insert_emit_delta(&key, &mut |d| emitted.push(d));
    assert_eq!(
        emitted.len(),
        3,
        "should emit one delta per row at threshold"
    );
    assert!(emitted.iter().all(|d| d.value == CM_PROMASK));
    // The next period behaves identically: silent until the threshold.
    for _ in 1..CM_PROMASK {
        sketch.insert_emit_delta(&key, &mut |d| emitted.push(d));
    }
    assert_eq!(emitted.len(), 3, "no second emission before next threshold");
    sketch.insert_emit_delta(&key, &mut |d| emitted.push(d));
    assert_eq!(emitted.len(), 6, "should emit again on next threshold");
}
#[test]
fn countmin_apply_delta_increments_parent_counter() {
    let mut parent = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
    // Replay a single worker delta addressed at (row 1, col 5).
    parent.apply_delta(CmDelta {
        row: 1,
        col: 5,
        value: CM_PROMASK,
    });
    let observed = parent.as_storage().query_one_counter(1, 5);
    assert_eq!(observed, CM_PROMASK as i32);
}
/// Feeds a Zipf-distributed u64 stream into a regular-path sketch and returns
/// it together with the exact per-key frequencies.
fn run_zipf_stream(
    rows: usize,
    cols: usize,
    domain: usize,
    exponent: f64,
    samples: usize,
    seed: u64,
) -> (CountMin<Vector2D<i32>, RegularPath>, HashMap<u64, i32>) {
    let mut sketch = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(rows, cols);
    let mut exact = HashMap::<u64, i32>::new();
    for value in sample_zipf_u64(domain, exponent, samples, seed) {
        sketch.insert(&DataInput::U64(value));
        *exact.entry(value).or_default() += 1;
    }
    (sketch, exact)
}
/// Same stream as `run_zipf_stream`, but exercising the fast insert path.
fn run_zipf_stream_fast(
    rows: usize,
    cols: usize,
    domain: usize,
    exponent: f64,
    samples: usize,
    seed: u64,
) -> (CountMin<Vector2D<i32>, FastPath>, HashMap<u64, i32>) {
    let mut sketch = CountMin::<Vector2D<i32>, FastPath>::with_dimensions(rows, cols);
    let mut exact = HashMap::<u64, i32>::new();
    for value in sample_zipf_u64(domain, exponent, samples, seed) {
        sketch.insert(&DataInput::U64(value));
        *exact.entry(value).or_default() += 1;
    }
    (sketch, exact)
}
/// Feeds a uniform f64 stream into a regular-path sketch; exact frequencies
/// are keyed by the value's bit pattern.
fn run_uniform_stream(
    rows: usize,
    cols: usize,
    min: f64,
    max: f64,
    samples: usize,
    seed: u64,
) -> (CountMin<Vector2D<i32>, RegularPath>, HashMap<u64, i32>) {
    let mut sketch = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(rows, cols);
    let mut exact = HashMap::<u64, i32>::new();
    for value in sample_uniform_f64(min, max, samples, seed) {
        sketch.insert(&DataInput::F64(value));
        *exact.entry(value.to_bits()).or_default() += 1;
    }
    (sketch, exact)
}
/// Same stream as `run_uniform_stream`, but exercising the fast insert path.
fn run_uniform_stream_fast(
    rows: usize,
    cols: usize,
    min: f64,
    max: f64,
    samples: usize,
    seed: u64,
) -> (CountMin<Vector2D<i32>, FastPath>, HashMap<u64, i32>) {
    let mut sketch = CountMin::<Vector2D<i32>, FastPath>::with_dimensions(rows, cols);
    let mut exact = HashMap::<u64, i32>::new();
    for value in sample_uniform_f64(min, max, samples, seed) {
        sketch.insert(&DataInput::F64(value));
        *exact.entry(value.to_bits()).or_default() += 1;
    }
    (sketch, exact)
}
#[test]
fn dimension_test() {
    // Default geometry is 3 x 4096 with all counters zeroed.
    let default_sketch = CountMin::<Vector2D<i32>, RegularPath>::default();
    assert_eq!(default_sketch.rows(), 3);
    assert_eq!(default_sketch.cols(), 4096);
    all_counter_zero_i32(default_sketch.as_storage());
    // Custom geometry is honored exactly and counters start zeroed too.
    let custom_sketch = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 17);
    assert_eq!(custom_sketch.rows(), 3);
    assert_eq!(custom_sketch.cols(), 17);
    all_counter_zero_i32(custom_sketch.as_storage());
}
#[test]
fn fast_insert_same_estimate() {
    let mut slow = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
    let mut fast = CountMin::<Vector2D<i32>, FastPath>::with_dimensions(3, 64);
    let keys = [
        DataInput::Str("alpha"),
        DataInput::Str("beta"),
        DataInput::Str("gamma"),
        DataInput::Str("delta"),
        DataInput::Str("epsilon"),
    ];
    // Feed both paths the same stream, then compare their estimates.
    for key in &keys {
        slow.insert(key);
        fast.insert(key);
    }
    for key in &keys {
        assert_eq!(
            slow.estimate(key),
            fast.estimate(key),
            "fast path should match standard insert for key {key:?}"
        );
    }
}
#[test]
fn merge_adds_counters_element_wise() {
    let mut left = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
    let mut right = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
    let key = DataInput::Str("delta");
    // One insert on the left, two on the right => 3 per row after merging.
    left.insert(&key);
    for _ in 0..2 {
        right.insert(&key);
    }
    // Record where the key lands in each row before merging.
    let cols = left.cols();
    let hit_columns: Vec<_> = (0..left.rows())
        .map(|row| counter_index(row, &key, cols))
        .collect();
    left.merge(&right);
    for (row, idx) in hit_columns.into_iter().enumerate() {
        assert_eq!(left.as_storage().query_one_counter(row, idx), 3);
    }
}
#[test]
#[should_panic(expected = "dimension mismatch while merging CountMin sketches")]
fn merge_requires_matching_dimensions() {
    // 2x32 vs 3x32 must be rejected by the merge assertion.
    let mut two_rows = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
    let three_rows = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 32);
    two_rows.merge(&three_rows);
}
#[test]
fn cm_regular_path_correctness() {
    // Insert keys 0..10 once each into a default (3 x 4096) regular sketch.
    let mut sk = CountMin::<Vector2D<i32>, RegularPath>::default();
    for i in 0..10 {
        sk.insert(&DataInput::I32(i));
    }
    let storage = sk.as_storage();
    let rows = storage.rows();
    let cols = storage.cols();
    // Independently rebuild the expected counter matrix using the same
    // per-row scheme the regular path uses: hash64_seeded, keep the low 32
    // bits, modulo the column count.
    let mut expected_once = vec![0_i32; rows * cols];
    for i in 0..10 {
        let value = DataInput::I32(i);
        for r in 0..rows {
            let hashed = hash64_seeded(r, &value);
            let col = ((hashed & LOWER_32_MASK) as usize) % cols;
            let idx = r * cols + col;
            expected_once[idx] += 1;
        }
    }
    assert_eq!(storage.as_slice(), expected_once.as_slice());
    // A second identical round must exactly double every counter.
    for i in 0..10 {
        sk.insert(&DataInput::I32(i));
    }
    let expected_twice: Vec<i32> = expected_once.iter().map(|v| v * 2).collect();
    assert_eq!(sk.as_storage().as_slice(), expected_twice.as_slice());
    // The row-wise minimum for each key should equal its true count of 2.
    for i in 0..10 {
        assert_eq!(
            sk.estimate(&DataInput::I32(i)),
            2,
            "estimate for {i} should be 2, but get {}",
            sk.estimate(&DataInput::I32(i))
        )
    }
}
#[test]
fn cm_fast_path_correctness() {
    // Insert keys 0..10 once each into a default fast-path sketch.
    let mut sk = CountMin::<Vector2D<i32>, FastPath>::default();
    for i in 0..10 {
        sk.insert(&DataInput::I32(i));
    }
    let storage = sk.as_storage();
    let rows = storage.rows();
    let cols = storage.cols();
    let mask_bits = storage.get_mask_bits();
    let mask = (1u64 << mask_bits) - 1;
    // Rebuild the expected matrix: the fast path hashes once with seed 0 and
    // slices `mask_bits` bits per row out of that single 64-bit hash.
    let mut expected_once = vec![0_i32; rows * cols];
    for i in 0..10 {
        let value = DataInput::I32(i);
        let hash = hash64_seeded(0, &value);
        for row in 0..rows {
            let hashed = (hash >> (mask_bits as usize * row)) & mask;
            let col = (hashed as usize) % cols;
            let idx = row * cols + col;
            expected_once[idx] += 1;
        }
    }
    assert_eq!(storage.as_slice(), expected_once.as_slice());
}
#[test]
fn cm_error_bound_zipf() {
    // Classic CM guarantees for w = DEFAULT_COL_NUM, d = DEFAULT_ROW_NUM:
    // epsilon = e / w, delta = e^-d; at least a (1 - delta) fraction of keys
    // must estimate within epsilon * N of truth (N = 200_000 samples).
    // Fix: the identical bound check was copy-pasted for the regular and
    // fast paths; it is factored into one local checker.
    fn check(truth: &HashMap<u64, i32>, estimate: impl Fn(u64) -> i32) {
        let epsilon = std::f64::consts::E / DEFAULT_COL_NUM as f64;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let error_bound = epsilon * 200_000_f64;
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            let est = estimate(*key);
            if (est.abs_diff(*truth.get(key).unwrap()) as f64) < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
    }
    let (sk, truth) = run_zipf_stream(
        DEFAULT_ROW_NUM,
        DEFAULT_COL_NUM,
        8192,
        1.1,
        200_000,
        0x5eed_c0de,
    );
    check(&truth, |key| sk.estimate(&DataInput::U64(key)));
    let (sk, truth) = run_zipf_stream_fast(
        DEFAULT_ROW_NUM,
        DEFAULT_COL_NUM,
        8192,
        1.1,
        200_000,
        0x5eed_c0de,
    );
    check(&truth, |key| sk.estimate(&DataInput::U64(key)));
}
#[test]
fn cm_error_bound_uniform() {
    // Same CM bound as the zipf test: at least (1 - delta) of keys must
    // estimate within epsilon * N of truth.
    // Fix: the identical bound check was copy-pasted for the regular and
    // fast paths; it is factored into one local checker.
    fn check(truth: &HashMap<u64, i32>, estimate: impl Fn(u64) -> i32) {
        let epsilon = std::f64::consts::E / DEFAULT_COL_NUM as f64;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let error_bound = epsilon * 200_000_f64;
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            let est = estimate(*key);
            if (est.abs_diff(*truth.get(key).unwrap()) as f64) < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
    }
    // NOTE(review): estimates query DataInput::U64(bits) although the stream
    // inserted DataInput::F64(value); this only matches if both variants hash
    // the raw bits identically — confirm against the DataInput hashing impl.
    let (sk, truth) = run_uniform_stream(
        DEFAULT_ROW_NUM,
        DEFAULT_COL_NUM,
        100.0,
        1000.0,
        200_000,
        0x5eed_c0de,
    );
    check(&truth, |key| sk.estimate(&DataInput::U64(key)));
    let (sk, truth) = run_uniform_stream_fast(
        DEFAULT_ROW_NUM,
        DEFAULT_COL_NUM,
        100.0,
        1000.0,
        200_000,
        0x5eed_c0de,
    );
    check(&truth, |key| sk.estimate(&DataInput::U64(key)));
}
#[test]
fn count_min_round_trip_serialization() {
    let mut sketch = CountMin::<Vector2D<i32>, RegularPath>::with_dimensions(3, 8);
    for key in [DataInput::U64(42), DataInput::U64(7)] {
        sketch.insert(&key);
    }
    // Encode, then decode from the produced bytes.
    let encoded = sketch.serialize_to_bytes().expect("serialize CountMin");
    assert!(!encoded.is_empty());
    let decoded = CountMin::<Vector2D<i32>, RegularPath>::deserialize_from_bytes(&encoded)
        .expect("deserialize CountMin");
    // Geometry and every counter must survive the round trip.
    assert_eq!(sketch.rows(), decoded.rows());
    assert_eq!(sketch.cols(), decoded.cols());
    assert_eq!(
        sketch.as_storage().as_slice(),
        decoded.as_storage().as_slice()
    );
}
}