use crate::{
DataInput, DefaultMatrixI32, DefaultMatrixI64, DefaultMatrixI128, DefaultXxHasher, FastPath,
FastPathHasher, FixedMatrix, MatrixFastHash, MatrixStorage, NitroTarget, QuickMatrixI64,
QuickMatrixI128, RegularPath, SketchHasher, Vector2D, hash64_seeded,
};
use rmp_serde::{
decode::Error as RmpDecodeError, encode::Error as RmpEncodeError, from_slice, to_vec_named,
};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use std::ops::Neg;
/// Number of hash rows used by the `Default` constructors.
const DEFAULT_ROW_NUM: usize = 3;
/// Number of counters per row used by the `Default` constructors.
const DEFAULT_COL_NUM: usize = 4_096;
/// Keeps only the low 32 bits of a 64-bit hash (used to pick a column).
const LOWER_32_MASK: u64 = 0x0000_0000_FFFF_FFFF;
/// Count sketch: a matrix of signed counters, one row per hash seed.
///
/// On insertion each row hashes the item to one column and adds either `+delta` or
/// `-delta`; the sign is also derived from the hash. A frequency estimate is the median,
/// across rows, of the counter multiplied by the same sign.
///
/// Type parameters:
/// - `S`: counter storage (defaults to `Vector2D<i32>`).
/// - `Mode`: marker (`RegularPath` or `FastPath`) that selects which `insert`/`estimate`
///   implementation is available.
/// - `H`: hasher used to derive columns and signs (defaults to `DefaultXxHasher`).
#[derive(Clone, Debug, Serialize, Deserialize)]
// Only `S` needs serde bounds; `Mode` and `H` appear solely in skipped PhantomData fields.
#[serde(bound(serialize = "S: Serialize", deserialize = "S: Deserialize<'de>"))]
pub struct Count<
    S: MatrixStorage = Vector2D<i32>,
    Mode = RegularPath,
    H: SketchHasher = DefaultXxHasher,
> {
    // The counter matrix; every update and query goes through this storage.
    counts: S,
    // Row/column counts captured at construction and serialized alongside the counters.
    // Note: `rows()`/`cols()` read the dimensions from `counts`, not from these fields.
    row: usize,
    col: usize,
    // Zero-sized marker for the insert/estimate mode; not serialized.
    #[serde(skip)]
    _mode: PhantomData<Mode>,
    // Zero-sized marker for the hasher type; not serialized.
    #[serde(skip)]
    _hasher: PhantomData<H>,
}
/// Numeric counter type that can back a count sketch.
///
/// Counters must be cheap to copy, support in-place addition, support negation (items
/// are added with a `+`/`-` sign), be constructible from `i32` literals (`0` and `1`),
/// and convert to `f64` for estimates.
pub trait CountSketchCounter: Copy + From<i32> + Neg<Output = Self> + std::ops::AddAssign {
    /// Converts the counter into an `f64` (may lose precision for very large values).
    fn to_f64(self) -> f64;
}

// The signed integer widths supported as counters all convert with a plain `as` cast.
macro_rules! impl_count_sketch_counter {
    ($($int:ty),+ $(,)?) => {
        $(
            impl CountSketchCounter for $int {
                fn to_f64(self) -> f64 {
                    self as f64
                }
            }
        )+
    };
}

impl_count_sketch_counter!(i32, i64, i128);
impl Default for Count<Vector2D<i32>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<Vector2D<i32>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<Vector2D<i64>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<Vector2D<i64>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<Vector2D<i128>, RegularPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<Vector2D<i128>, FastPath> {
fn default() -> Self {
Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM)
}
}
impl Default for Count<FixedMatrix, RegularPath> {
fn default() -> Self {
Count::from_storage(FixedMatrix::default())
}
}
impl Default for Count<FixedMatrix, FastPath> {
fn default() -> Self {
Count::from_storage(FixedMatrix::default())
}
}
impl Default for Count<DefaultMatrixI32, RegularPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI32::default())
}
}
impl Default for Count<DefaultMatrixI32, FastPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI32::default())
}
}
impl Default for Count<DefaultMatrixI64, RegularPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI64::default())
}
}
impl Default for Count<DefaultMatrixI64, FastPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI64::default())
}
}
impl Default for Count<DefaultMatrixI128, RegularPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI128::default())
}
}
impl Default for Count<DefaultMatrixI128, FastPath> {
fn default() -> Self {
Count::from_storage(DefaultMatrixI128::default())
}
}
impl Default for Count<QuickMatrixI64, RegularPath> {
fn default() -> Self {
Count::from_storage(QuickMatrixI64::default())
}
}
impl Default for Count<QuickMatrixI64, FastPath> {
fn default() -> Self {
Count::from_storage(QuickMatrixI64::default())
}
}
impl Default for Count<QuickMatrixI128, RegularPath> {
fn default() -> Self {
Count::from_storage(QuickMatrixI128::default())
}
}
impl Default for Count<QuickMatrixI128, FastPath> {
fn default() -> Self {
Count::from_storage(QuickMatrixI128::default())
}
}
impl<T, M, H: SketchHasher> Count<Vector2D<T>, M, H>
where
    T: CountSketchCounter,
{
    /// Creates a sketch backed by a `rows` x `cols` `Vector2D`, with every counter
    /// explicitly set to zero.
    pub fn with_dimensions(rows: usize, cols: usize) -> Self {
        let mut storage: Vector2D<T> = Vector2D::init(rows, cols);
        storage.fill(T::from(0));
        Self {
            counts: storage,
            row: rows,
            col: cols,
            _mode: PhantomData,
            _hasher: PhantomData,
        }
    }
}
impl<S, C, Mode, H: SketchHasher> Count<S, Mode, H>
where
    S: MatrixStorage<Counter = C>,
    C: CountSketchCounter,
{
    /// Wraps an existing counter matrix, recording its current dimensions.
    pub fn from_storage(counts: S) -> Self {
        let (row, col) = (counts.rows(), counts.cols());
        Self {
            counts,
            row,
            col,
            _mode: PhantomData,
            _hasher: PhantomData,
        }
    }

    /// Number of hash rows, as reported by the underlying storage.
    pub fn rows(&self) -> usize {
        self.counts.rows()
    }

    /// Number of counters per row, as reported by the underlying storage.
    pub fn cols(&self) -> usize {
        self.counts.cols()
    }

    /// Adds `other`'s counters into `self`, cell by cell.
    ///
    /// # Panics
    /// Panics if the two sketches do not have identical row and column counts.
    pub fn merge(&mut self, other: &Self) {
        let shape = (self.counts.rows(), self.counts.cols());
        assert_eq!(
            shape,
            (other.counts.rows(), other.counts.cols()),
            "dimension mismatch while merging CountMin sketches"
        );
        let (n_rows, n_cols) = shape;
        let cells = (0..n_rows).flat_map(|r| (0..n_cols).map(move |c| (r, c)));
        for (r, c) in cells {
            let incoming = other.counts.query_one_counter(r, c);
            self.counts
                .update_one_counter(r, c, |acc, v| *acc += v, incoming);
        }
    }

    /// Read-only access to the counter matrix.
    pub fn as_storage(&self) -> &S {
        &self.counts
    }

    /// Mutable access to the counter matrix.
    pub fn as_storage_mut(&mut self) -> &mut S {
        &mut self.counts
    }
}
impl<S, C, Mode, H: SketchHasher> Count<S, Mode, H>
where
    S: MatrixStorage<Counter = C> + Serialize,
    C: CountSketchCounter,
{
    /// Encodes the sketch as MessagePack, with struct fields written by name.
    ///
    /// # Errors
    /// Propagates any `rmp_serde` encoding error.
    pub fn serialize_to_bytes(&self) -> Result<Vec<u8>, RmpEncodeError> {
        let encoded = to_vec_named(self)?;
        Ok(encoded)
    }
}
impl<S, C, Mode, H: SketchHasher> Count<S, Mode, H>
where
    S: MatrixStorage<Counter = C> + for<'de> Deserialize<'de>,
    C: CountSketchCounter,
{
    /// Decodes a sketch from MessagePack bytes (the inverse of `serialize_to_bytes`).
    ///
    /// # Errors
    /// Propagates any `rmp_serde` decoding error.
    pub fn deserialize_from_bytes(bytes: &[u8]) -> Result<Self, RmpDecodeError> {
        let sketch: Self = from_slice(bytes)?;
        Ok(sketch)
    }
}
impl<S, C, H: SketchHasher> Count<S, RegularPath, H>
where
    S: MatrixStorage<Counter = C>,
    C: CountSketchCounter,
{
    /// Locates `value`'s counter in row `r`.
    ///
    /// Returns the column index and whether the item counts positively in that row.
    /// - The column is the low 32 bits of the row-seeded 64-bit hash, modulo `cols`.
    /// - The sign is positive when the hash's top bit is set.
    ///
    /// This is the single source of truth shared by `insert_many` and `estimate`.
    ///
    /// # Panics
    /// Panics if `cols` is zero (remainder by zero).
    fn regular_bucket_and_sign(r: usize, value: &DataInput, cols: usize) -> (usize, bool) {
        let hashed = H::hash64_seeded(r, value);
        let col = ((hashed & LOWER_32_MASK) as usize) % cols;
        let positive = (hashed >> 63) & 1 == 1;
        (col, positive)
    }

    /// Adds a single occurrence of `value` (equivalent to `insert_many(value, 1)`).
    pub fn insert(&mut self, value: &DataInput) {
        self.insert_many(value, C::from(1));
    }

    /// Adds `many` occurrences of `value`.
    ///
    /// In every row the selected counter receives `+many` or `-many`, depending on the
    /// row's hash sign.
    pub fn insert_many(&mut self, value: &DataInput, many: C) {
        let cols = self.counts.cols();
        for r in 0..self.counts.rows() {
            let (col, positive) = Self::regular_bucket_and_sign(r, value, cols);
            let delta = if positive { many } else { -many };
            self.counts
                .update_one_counter(r, col, |a, b| *a += b, delta);
        }
    }

    /// Estimates the frequency of `value`.
    ///
    /// Each row contributes its counter multiplied by the row's sign. The result is the
    /// median of those per-row values; with an even number of rows, the two middle values
    /// are averaged. A sketch with zero rows estimates `0.0`.
    pub fn estimate(&self, value: &DataInput) -> f64 {
        let cols = self.counts.cols();
        let mut estimates: Vec<f64> = (0..self.counts.rows())
            .map(|r| {
                let (col, positive) = Self::regular_bucket_and_sign(r, value, cols);
                let counter = self.counts.query_one_counter(r, col).to_f64();
                if positive { counter } else { -counter }
            })
            .collect();
        if estimates.is_empty() {
            return 0.0;
        }
        // Values come from integer counters, so they are always comparable (never NaN).
        estimates.sort_unstable_by(|a, b| {
            a.partial_cmp(b)
                .expect("counter-derived estimates are never NaN")
        });
        let mid = estimates.len() / 2;
        if estimates.len() % 2 == 1 {
            estimates[mid]
        } else {
            (estimates[mid - 1] + estimates[mid]) / 2.0
        }
    }
}
impl<S, H: SketchHasher> Count<S, FastPath, H>
where
    S: MatrixStorage + crate::FastPathHasher<H>,
    S::Counter: CountSketchCounter,
{
    /// Adds one occurrence of `value`.
    ///
    /// The item is hashed once via the storage's `FastPathHasher`. Each row's counter then
    /// gets `+1` or `-1`, depending on `sign_for_row`.
    #[inline(always)]
    pub fn insert(&mut self, value: &DataInput) {
        let hv = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        let unit = S::Counter::from(1);
        self.counts.fast_insert(
            |slot, amount, r| {
                *slot += if hv.sign_for_row(r) > 0 { *amount } else { -*amount };
            },
            unit,
            &hv,
        );
    }

    /// Adds `many` occurrences of `value`, signed per row as in `insert`.
    #[inline(always)]
    pub fn insert_many(&mut self, value: &DataInput, many: S::Counter) {
        let hv = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        self.counts.fast_insert(
            |slot, amount, r| {
                *slot += if hv.sign_for_row(r) > 0 { *amount } else { -*amount };
            },
            many,
            &hv,
        );
    }

    /// Estimates the frequency of `value`.
    ///
    /// This is the storage-computed median of the per-row counters, each multiplied by
    /// that row's sign.
    #[inline(always)]
    pub fn estimate(&self, value: &DataInput) -> f64 {
        let hv = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        self.counts.fast_query_median(&hv, |cell, r, h| {
            let magnitude = (*cell).to_f64();
            if h.sign_for_row(r) > 0 {
                magnitude
            } else {
                -magnitude
            }
        })
    }

    /// Like `insert`, but reuses a hash the caller has already computed.
    #[inline(always)]
    pub fn fast_insert_with_hash_value(&mut self, hashed_val: &H::HashType) {
        let unit = S::Counter::from(1);
        self.counts.fast_insert(
            |slot, amount, r| {
                *slot += if hashed_val.sign_for_row(r) > 0 { *amount } else { -*amount };
            },
            unit,
            hashed_val,
        );
    }

    /// Like `insert_many`, but reuses a hash the caller has already computed.
    #[inline(always)]
    pub fn fast_insert_many_with_hash_value(&mut self, hashed_val: &H::HashType, many: S::Counter) {
        self.counts.fast_insert(
            |slot, amount, r| {
                *slot += if hashed_val.sign_for_row(r) > 0 { *amount } else { -*amount };
            },
            many,
            hashed_val,
        );
    }

    /// Like `estimate`, but reuses a hash the caller has already computed.
    #[inline(always)]
    pub fn fast_estimate_with_hash(&self, hashed_val: &H::HashType) -> f64 {
        self.counts.fast_query_median(hashed_val, |cell, r, h| {
            let magnitude = (*cell).to_f64();
            if h.sign_for_row(r) > 0 {
                magnitude
            } else {
                -magnitude
            }
        })
    }
}
impl<M, H: SketchHasher> Count<Vector2D<i32>, M, H> {
    /// Prints each row of the counter matrix to stdout as `row <index>: [..]`.
    pub fn debug(&self) {
        (0..self.counts.rows()).for_each(|r| {
            let cells = self.counts.row_slice(r);
            println!("row {}: {:?}", r, &cells);
        });
    }
}
impl<H: SketchHasher> Count<Vector2D<i32>, FastPath, H> {
    /// Turns on Nitro sampling in the underlying storage with the given sampling rate.
    pub fn enable_nitro(&mut self, sampling_rate: f64) {
        self.counts.enable_nitro(sampling_rate);
    }
    /// Sampled insertion of one occurrence of `value`.
    ///
    /// Nitro treats the stream as a sequence of (item, row) slots and keeps a `to_skip`
    /// count of slots to pass over before the next update.
    /// - If this item's rows are all skipped, only the skip count is decreased.
    /// - Otherwise every selected row receives `sign * delta`, and a fresh skip is drawn
    ///   after each update.
    /// - Whatever part of the last skip extends past this item's rows is carried over to
    ///   the next item.
    ///
    /// NOTE(review): `delta` is presumably the per-update weight that compensates for the
    /// sampling rate, and `draw_geometric` presumably draws a geometric skip length.
    /// Confirm both in the storage's Nitro implementation.
    #[inline(always)]
    pub fn fast_insert_nitro(&mut self, value: &DataInput) {
        let rows = self.counts.rows();
        let delta = self.counts.nitro().delta;
        if self.counts.nitro().to_skip >= rows {
            // No row of this item is sampled: consume `rows` slots of the pending skip.
            self.counts.reduce_nitro_skip(rows);
        } else {
            // One 128-bit hash (seed 0) supplies both the per-row sign bits and the
            // value passed to `update_by_row`.
            let hashed = H::hash128_seeded(0, value);
            // First sampled row is the remaining skip.
            let mut r = self.counts.nitro().to_skip;
            loop {
                // Sign for row r is bit (127 - r) of the hash: 1 -> +1, 0 -> -1.
                // NOTE(review): `estimate` uses `sign_for_row`; confirm that it reads the
                // same bit so that Nitro-inserted updates are decoded correctly.
                let bit = (hashed >> (127 - r)) & 1;
                let sign = (bit << 1) as i32 - 1;
                // NOTE(review): `delta as i32` truncates if the delta exceeds i32 range.
                self.counts
                    .update_by_row(r, hashed, |a, b| *a += b, sign * (delta as i32));
                // Draw the number of slots to skip before the next update.
                self.counts.nitro_mut().draw_geometric();
                // Stop when the next sampled slot falls beyond this item's last row.
                if r + self.counts.nitro_mut().to_skip + 1 >= rows {
                    break;
                }
                r += self.counts.nitro_mut().to_skip + 1;
            }
            // Carry the remainder of the skip into the next item: slots after row r in
            // this item number (rows - r - 1), so the leftover is r + skip + 1 - rows.
            let temp = self.counts.get_nitro_skip();
            self.counts.update_nitro_skip((r + temp + 1) - rows);
        }
    }
}
impl<H: SketchHasher> NitroTarget for Count<Vector2D<i32>, FastPath, H> {
    /// Number of hash rows in the sketch.
    #[inline(always)]
    fn rows(&self) -> usize {
        self.counts.rows()
    }

    /// Applies a sampled update of `delta` to `row`.
    ///
    /// The sign comes from bit `127 - row` of the 128-bit hash (set means `+`, clear
    /// means `-`).
    #[inline(always)]
    fn update_row(&mut self, row: usize, hashed: u128, delta: u64) {
        let positive = (hashed >> (127 - row)) & 1 == 1;
        let magnitude = delta as i32;
        let signed = if positive { magnitude } else { -magnitude };
        self.counts
            .update_by_row(row, hashed, |a, b| *a += b, signed);
    }
}
use crate::octo_delta::{COUNT_PROMASK, CountDelta};
impl<S: MatrixStorage<Counter = i32>, H: SketchHasher> Count<S, RegularPath, H> {
    /// Adds one occurrence of `value` and promotes counters that reach the threshold.
    ///
    /// In each row the selected counter gets `+1` or `-1`. Once the counter's magnitude
    /// reaches `COUNT_PROMASK`, its value is passed to `emit` as a `CountDelta` and the
    /// local counter is reset to zero.
    ///
    /// Columns and signs are derived with the sketch's own hasher `H`, exactly like
    /// `insert`/`estimate`. Previously this used the crate-level `hash64_seeded`, which
    /// disagrees with them whenever `H` is not the default hasher.
    ///
    /// NOTE(review): `current as i8` assumes `COUNT_PROMASK <= 127` and that the counter
    /// only changed via this method. A counter previously raised by `insert`/`merge`
    /// could exceed the i8 range and be truncated. Confirm.
    #[inline(always)]
    pub fn insert_emit_delta(&mut self, value: &DataInput, emit: &mut impl FnMut(CountDelta)) {
        let rows = self.counts.rows();
        let cols = self.counts.cols();
        for r in 0..rows {
            let hashed = H::hash64_seeded(r, value);
            let col = ((hashed & LOWER_32_MASK) as usize) % cols;
            // Top hash bit set -> +1, clear -> -1 (same convention as `insert`).
            let sign: i32 = if ((hashed >> 63) & 1) == 1 { 1 } else { -1 };
            self.counts.increment_by_row(r, col, sign);
            let current = self.counts.query_one_counter(r, col);
            if current.unsigned_abs() >= COUNT_PROMASK as u32 {
                emit(CountDelta {
                    row: r as u16,
                    col: col as u16,
                    value: current as i8,
                });
                // The emitted amount now lives downstream; clear it locally.
                self.counts.update_one_counter(r, col, |c, _| *c = 0, ());
            }
        }
    }
}
impl<S, H: SketchHasher> Count<S, FastPath, H>
where
    S: MatrixStorage<Counter = i32> + FastPathHasher<H>,
{
    /// Fast-path variant of threshold-promoting insertion.
    ///
    /// The item is hashed once, and each row's counter moves by `sign_for_row`. A counter
    /// whose magnitude reaches `COUNT_PROMASK` is emitted as a `CountDelta` and then
    /// zeroed locally.
    #[inline(always)]
    pub fn insert_emit_delta(&mut self, value: &DataInput, emit: &mut impl FnMut(CountDelta)) {
        let hv = <S as FastPathHasher<H>>::hash_for_matrix(&self.counts, value);
        let (rows, cols) = (self.counts.rows(), self.counts.cols());
        let threshold = COUNT_PROMASK as u32;
        for r in 0..rows {
            let col = hv.col_for_row(r, cols);
            self.counts.increment_by_row(r, col, hv.sign_for_row(r));
            let current = self.counts.query_one_counter(r, col);
            if current.unsigned_abs() < threshold {
                continue;
            }
            emit(CountDelta {
                row: r as u16,
                col: col as u16,
                value: current as i8,
            });
            self.counts.update_one_counter(r, col, |c, _| *c = 0, ());
        }
    }
}
impl<S: MatrixStorage, Mode, H: SketchHasher> Count<S, Mode, H>
where
    S::Counter: Copy + std::ops::AddAssign + From<i32>,
{
    /// Adds a promoted `CountDelta` back into the counter at its (row, col) position.
    pub fn apply_delta(&mut self, delta: CountDelta) {
        let row = usize::from(delta.row);
        let col = usize::from(delta.col);
        let amount = S::Counter::from(i32::from(delta.value));
        self.counts.increment_by_row(row, col, amount);
    }
}
#[cfg(test)]
mod tests {
    // Unit and statistical tests for the count sketch (regular and fast paths).
    use super::*;
    use crate::test_utils::{
        all_counter_zero_i32, counter_index, sample_uniform_f64, sample_zipf_u64,
    };
    use crate::{DataInput, hash64_seeded};
    use std::collections::HashMap;

    #[test]
    fn count_child_insert_emits_at_threshold() {
        let mut child = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
        let key = DataInput::U64(99);
        let mut deltas: Vec<CountDelta> = Vec::new();
        for _ in 0..200 {
            child.insert_emit_delta(&key, &mut |d| deltas.push(d));
        }
        assert!(
            deltas.len() >= 3,
            "expected at least one promoted delta per row"
        );
    }

    /// Expected per-row sign for `key` under the default hasher (top hash bit set -> +1).
    fn counter_sign(row: usize, key: &DataInput) -> i32 {
        let hash = hash64_seeded(row, key);
        if (hash >> 63) & 1 == 1 { 1 } else { -1 }
    }

    fn run_zipf_stream(
        rows: usize,
        cols: usize,
        domain: usize,
        exponent: f64,
        samples: usize,
        seed: u64,
    ) -> (Count, HashMap<u64, i32>) {
        let mut truth = HashMap::<u64, i32>::new();
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(rows, cols);
        for value in sample_zipf_u64(domain, exponent, samples, seed) {
            let key = DataInput::U64(value);
            sketch.insert(&key);
            *truth.entry(value).or_insert(0) += 1;
        }
        (sketch, truth)
    }

    fn run_zipf_stream_fast(
        rows: usize,
        cols: usize,
        domain: usize,
        exponent: f64,
        samples: usize,
        seed: u64,
    ) -> (Count<Vector2D<i32>, FastPath>, HashMap<u64, u64>) {
        let mut truth = HashMap::<u64, u64>::new();
        let mut sketch = Count::<Vector2D<i32>, FastPath>::with_dimensions(rows, cols);
        for value in sample_zipf_u64(domain, exponent, samples, seed) {
            let key = DataInput::U64(value);
            sketch.insert(&key);
            *truth.entry(value).or_insert(0) += 1;
        }
        (sketch, truth)
    }

    /// Inserts uniformly sampled f64 values; `truth` is keyed by each value's bit pattern.
    fn run_uniform_stream(
        rows: usize,
        cols: usize,
        min: f64,
        max: f64,
        samples: usize,
        seed: u64,
    ) -> (Count, HashMap<u64, u64>) {
        let mut truth = HashMap::<u64, u64>::new();
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(rows, cols);
        for value in sample_uniform_f64(min, max, samples, seed) {
            let key = DataInput::F64(value);
            sketch.insert(&key);
            *truth.entry(value.to_bits()).or_insert(0) += 1;
        }
        (sketch, truth)
    }

    /// Fast-path counterpart of `run_uniform_stream`; `truth` keys are f64 bit patterns.
    fn run_uniform_stream_fast(
        rows: usize,
        cols: usize,
        min: f64,
        max: f64,
        samples: usize,
        seed: u64,
    ) -> (Count<Vector2D<i32>, FastPath>, HashMap<u64, u64>) {
        let mut truth = HashMap::<u64, u64>::new();
        let mut sketch = Count::<Vector2D<i32>, FastPath>::with_dimensions(rows, cols);
        for value in sample_uniform_f64(min, max, samples, seed) {
            let key = DataInput::F64(value);
            sketch.insert(&key);
            *truth.entry(value.to_bits()).or_insert(0) += 1;
        }
        (sketch, truth)
    }

    #[test]
    fn default_initializes_expected_dimensions() {
        let cs = Count::<Vector2D<i32>, RegularPath>::default();
        assert_eq!(cs.rows(), 3);
        assert_eq!(cs.cols(), 4096);
        all_counter_zero_i32(cs.as_storage());
    }

    #[test]
    fn with_dimensions_uses_custom_sizes() {
        let cs = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 17);
        assert_eq!(cs.rows(), 3);
        assert_eq!(cs.cols(), 17);
        let storage = cs.as_storage();
        for row in 0..cs.rows() {
            assert!(
                storage.row_slice(row).iter().all(|&value| value == 0),
                "expected row {} to be zero-initialized, got {:?}",
                row,
                storage.row_slice(row)
            );
        }
    }

    #[test]
    fn insert_updates_signed_counters_per_row() {
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
        let key = DataInput::Str("alpha");
        sketch.insert(&key);
        for row in 0..sketch.rows() {
            let idx = counter_index(row, &key, sketch.cols());
            let expected = counter_sign(row, &key);
            assert_eq!(
                sketch.counts.query_one_counter(row, idx),
                expected,
                "row {row} counter mismatch"
            );
        }
    }

    #[test]
    fn fast_insert_produces_consistent_estimates() {
        let mut fast = Count::<Vector2D<i32>, FastPath>::with_dimensions(4, 128);
        let keys = vec![
            DataInput::Str("alpha"),
            DataInput::Str("beta"),
            DataInput::Str("gamma"),
            DataInput::Str("delta"),
            DataInput::Str("epsilon"),
        ];
        for key in &keys {
            fast.insert(key);
        }
        for key in &keys {
            let estimate = fast.estimate(key);
            assert!(
                (estimate - 1.0).abs() < f64::EPSILON,
                "fast estimate for key {key:?} should be 1.0, got {estimate}"
            );
        }
    }

    #[test]
    fn insert_produces_consistent_estimates() {
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
        let keys = vec![
            DataInput::Str("alpha"),
            DataInput::Str("beta"),
            DataInput::Str("gamma"),
            DataInput::Str("delta"),
            DataInput::Str("epsilon"),
        ];
        for key in &keys {
            sketch.insert(key);
        }
        for key in &keys {
            let estimate = sketch.estimate(key);
            assert!(
                (estimate - 1.0).abs() < f64::EPSILON,
                "estimate for key {key:?} should be 1.0, got {estimate}"
            );
        }
    }

    #[test]
    fn estimate_recovers_frequency_for_repeated_key() {
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 64);
        let key = DataInput::Str("theta");
        let repeats = 37;
        for _ in 0..repeats {
            sketch.insert(&key);
        }
        let estimate = sketch.estimate(&key);
        assert!(
            (estimate - repeats as f64).abs() < f64::EPSILON,
            "expected estimate {repeats}, got {estimate}"
        );
    }

    #[test]
    fn fast_path_recovers_repeated_insertions() {
        let mut sketch = Count::<Vector2D<i32>, FastPath>::with_dimensions(4, 256);
        let keys = vec![
            DataInput::Str("alpha"),
            DataInput::Str("beta"),
            DataInput::Str("gamma"),
            DataInput::Str("delta"),
            DataInput::Str("epsilon"),
        ];
        for _ in 0..5 {
            for key in &keys {
                sketch.insert(key);
            }
        }
        for key in &keys {
            let estimate = sketch.estimate(key);
            assert!(
                (estimate - 5.0).abs() < f64::EPSILON,
                "fast estimate for key {key:?} should be 5.0, got {estimate}"
            );
        }
    }

    #[test]
    fn merge_adds_counters_element_wise() {
        let mut left = Count::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
        let mut right = Count::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
        let key = DataInput::Str("delta");
        left.insert(&key);
        right.insert(&key);
        right.insert(&key);
        let left_indices: Vec<_> = (0..left.rows())
            .map(|row| counter_index(row, &key, left.cols()))
            .collect();
        left.merge(&right);
        for (row, idx) in left_indices.into_iter().enumerate() {
            let expected = counter_sign(row, &key) * 3;
            assert_eq!(left.as_storage().query_one_counter(row, idx), expected);
        }
    }

    #[test]
    #[should_panic(expected = "dimension mismatch while merging CountMin sketches")]
    fn merge_requires_matching_dimensions() {
        let mut left = Count::<Vector2D<i32>, RegularPath>::with_dimensions(2, 32);
        let right = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 32);
        left.merge(&right);
    }

    #[test]
    fn zipf_stream_stays_within_twenty_percent_for_most_keys() {
        let (sketch, truth) = run_zipf_stream(5, 8192, 8192, 1.1, 200_000, 0x5eed_c0de);
        let mut within_tolerance = 0usize;
        for (&value, &count) in &truth {
            let estimate = sketch.estimate(&DataInput::U64(value));
            let rel_error = ((estimate - count as f64).abs()) / (count as f64);
            if rel_error < 0.20 {
                within_tolerance += 1;
            }
        }
        let total = truth.len();
        let accuracy = within_tolerance as f64 / total as f64;
        assert!(
            accuracy >= 0.70,
            "Only {:.2}% of keys within tolerance ({} of {}); expected at least 70%",
            accuracy * 100.0,
            within_tolerance,
            total
        );
    }

    #[test]
    fn cs_regular_path_correctness() {
        let mut sk = Count::<Vector2D<i32>, RegularPath>::default();
        for i in 0..10 {
            sk.insert(&DataInput::I32(i));
        }
        let storage = sk.as_storage();
        let rows = storage.rows();
        let cols = storage.cols();
        let mut expected_once = vec![0_i32; rows * cols];
        for i in 0..10 {
            let value = DataInput::I32(i);
            for r in 0..rows {
                let hashed = hash64_seeded(r, &value);
                let col = ((hashed & LOWER_32_MASK) as usize) % cols;
                let bit = ((hashed >> 63) & 1) as i32;
                let sign_bit = -(1 - 2 * bit);
                let idx = r * cols + col;
                expected_once[idx] += sign_bit;
            }
        }
        assert_eq!(storage.as_slice(), expected_once.as_slice());
        for i in 0..10 {
            sk.insert(&DataInput::I32(i));
        }
        let expected_twice: Vec<i32> = expected_once.iter().map(|v| v * 2).collect();
        assert_eq!(sk.as_storage().as_slice(), expected_twice.as_slice());
        for i in 0..10 {
            let estimate = sk.estimate(&DataInput::I32(i));
            assert!(
                (estimate - 2.0).abs() < f64::EPSILON,
                "estimate for {i} should be 2.0, but get {estimate}"
            );
        }
    }

    #[test]
    fn cs_fast_path_correctness() {
        let mut sk = Count::<Vector2D<i32>, FastPath>::default();
        for i in 0..10 {
            sk.insert(&DataInput::I32(i));
        }
        let storage = sk.as_storage();
        let rows = storage.rows();
        let cols = storage.cols();
        let mask_bits = storage.get_mask_bits();
        let mask = (1u128 << mask_bits) - 1;
        let mut expected_once = vec![0_i32; rows * cols];
        for i in 0..10 {
            let value = DataInput::I32(i);
            let hash = <Vector2D<i32> as FastPathHasher<DefaultXxHasher>>::hash_for_matrix(
                storage, &value,
            );
            for row in 0..rows {
                let hashed = hash.row_hash(row, mask_bits, mask);
                let col = (hashed % cols as u128) as usize;
                let idx = row * cols + col;
                expected_once[idx] += hash.sign_for_row(row);
            }
        }
        assert_eq!(storage.as_slice(), expected_once.as_slice());
    }

    #[test]
    fn cs_error_bound_zipf() {
        let (sk, truth) = run_zipf_stream(
            DEFAULT_ROW_NUM,
            DEFAULT_COL_NUM,
            8192,
            1.1,
            200_000,
            0x5eed_c0de,
        );
        let epsilon = std::f64::consts::E / DEFAULT_COL_NUM as f64;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let error_bound = epsilon * 200_000_f64;
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            let est = sk.estimate(&DataInput::U64(*key));
            if (est - (*truth.get(key).unwrap() as f64)).abs() < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
        let (sk, truth) = run_zipf_stream_fast(
            DEFAULT_ROW_NUM,
            DEFAULT_COL_NUM,
            8192,
            1.1,
            200_000,
            0x5eed_c0de,
        );
        let epsilon = std::f64::consts::E / DEFAULT_COL_NUM as f64;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let error_bound = epsilon * 200_000_f64;
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            let est = sk.estimate(&DataInput::U64(*key));
            if (est - (*truth.get(key).unwrap() as f64)).abs() < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
    }

    #[test]
    fn cs_error_bound_uniform() {
        let (sk, truth) = run_uniform_stream(
            DEFAULT_ROW_NUM,
            DEFAULT_COL_NUM,
            100.0,
            1000.0,
            200_000,
            0x5eed_c0de,
        );
        let epsilon = (std::f64::consts::E / DEFAULT_COL_NUM as f64).sqrt();
        let l2_norm = truth
            .values()
            .map(|&c| (c as f64).powi(2))
            .sum::<f64>()
            .sqrt();
        let error_bound = epsilon * l2_norm;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            // `truth` is keyed by the f64 bit pattern; query with the same `F64` input that
            // was inserted (querying `U64(bits)` would estimate a never-inserted key).
            let est = sk.estimate(&DataInput::F64(f64::from_bits(*key)));
            if (est - (*truth.get(key).unwrap() as f64)).abs() < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
        let (sk, truth) = run_uniform_stream_fast(
            DEFAULT_ROW_NUM,
            DEFAULT_COL_NUM,
            100.0,
            1000.0,
            200_000,
            0x5eed_c0de,
        );
        let epsilon = std::f64::consts::E / DEFAULT_COL_NUM as f64;
        let delta = 1.0 / std::f64::consts::E.powi(DEFAULT_ROW_NUM as i32);
        let error_bound = epsilon * 200_000_f64;
        let keys = truth.keys();
        let correct_lower_bound = keys.len() as f64 * (1.0 - delta);
        let mut within_count = 0;
        for key in keys {
            // Same key reconstruction as above: estimate the value that was actually inserted.
            let est = sk.estimate(&DataInput::F64(f64::from_bits(*key)));
            if (est - (*truth.get(key).unwrap() as f64)).abs() < error_bound {
                within_count += 1;
            }
        }
        assert!(
            within_count as f64 > correct_lower_bound,
            "in-bound items number {within_count} not greater than expected amount {correct_lower_bound}"
        );
    }

    #[test]
    fn count_sketch_round_trip_serialization() {
        let mut sketch = Count::<Vector2D<i32>, RegularPath>::with_dimensions(3, 8);
        sketch.insert(&DataInput::U64(42));
        sketch.insert(&DataInput::U64(7));
        let encoded = sketch.serialize_to_bytes().expect("serialize Count");
        assert!(!encoded.is_empty());
        let decoded = Count::<Vector2D<i32>, RegularPath>::deserialize_from_bytes(&encoded)
            .expect("deserialize Count");
        assert_eq!(sketch.rows(), decoded.rows());
        assert_eq!(sketch.cols(), decoded.cols());
        assert_eq!(
            sketch.as_storage().as_slice(),
            decoded.as_storage().as_slice()
        );
    }
}