use crate::error::CacheError;
use rand::{RngExt, SeedableRng, rngs::StdRng};
use std::{
fmt::{Debug, Formatter},
sync::atomic::{AtomicU8, Ordering},
time::{SystemTime, UNIX_EPOCH},
};
/// Number of counter rows in a [`CountMinSketch`], each hashed with its own
/// seed; an estimate is the minimum counter across these rows.
const DEPTH: usize = 4;
/// A fixed-size row of 4-bit saturating counters, packed two per byte.
///
/// Counter `i` lives in byte `i / 2`: even indices occupy the low nibble and
/// odd indices the high nibble. Every operation uses relaxed atomics on the
/// individual bytes.
pub(crate) struct CountMinRow(Vec<AtomicU8>);

impl CountMinRow {
    /// Allocates `width` zeroed bytes, i.e. room for `2 * width` counters.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn new(width: u64) -> Self {
        let bytes = width as usize;
        Self((0..bytes).map(|_| AtomicU8::new(0)).collect())
    }

    /// Reads counter `i` (a value in `0..=15`).
    ///
    /// Panics if `i / 2` is outside the row.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn get(&self, i: u64) -> u8 {
        let shift = (i & 1) * 4;
        let packed = self.0[(i / 2) as usize].load(Ordering::Relaxed);
        (packed >> shift) & 0x0f
    }

    /// Adds one to counter `i` unless it is already saturated at 15.
    ///
    /// Panics if `i / 2` is outside the row.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn increment(&self, i: u64) {
        let shift = ((i & 1) * 4) as u8;
        let slot = &self.0[(i / 2) as usize];
        // `fetch_update` re-runs the closure until its CAS succeeds; returning
        // `None` for a saturated nibble leaves the byte untouched.
        let _ = slot.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |packed| {
            let counter = (packed >> shift) & 0x0f;
            (counter != 0x0f).then(|| packed + (1 << shift))
        });
    }

    /// Halves every counter in the row (rounding down).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn reset(&self) {
        self.0.iter().for_each(|slot| {
            // Shifting the whole byte halves both nibbles at once; `& 0x77`
            // discards the bit that slides from the high nibble into the low one.
            let _ = slot.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |packed| {
                Some((packed >> 1) & 0x77)
            });
        });
    }

    /// Sets every counter in the row back to zero.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn clear(&self) {
        self.0.iter().for_each(|slot| slot.store(0, Ordering::Relaxed));
    }
}
/// Renders every counter in the row, lowest index first, each as a zero-padded
/// two-digit decimal followed by a space (e.g. `01 15 00 02 `).
impl Debug for CountMinRow {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter instead of accumulating a `String`
        // with one `format!` allocation per counter; each byte is loaded once
        // and yields its low (even index) then high (odd index) nibble.
        for cell in &self.0 {
            let packed = cell.load(Ordering::Relaxed);
            write!(f, "{:02} {:02} ", packed & 0x0f, packed >> 4)?;
        }
        Ok(())
    }
}
/// A Count-Min sketch built from [`DEPTH`] rows of 4-bit saturating counters.
///
/// Every row hashes a key with its own seed, and [`CountMinSketch::estimate`]
/// reports the smallest of the key's counters across rows. Hash collisions can
/// only inflate an estimate; it can fall below the true count only through
/// counter saturation at 15 or the halving performed by [`CountMinSketch::reset`].
pub(crate) struct CountMinSketch {
    /// One counter row per seed.
    rows: [CountMinRow; DEPTH],
    /// Per-row seeds; `seed | 1` is the odd multiplier of that row's hash.
    seeds: [u64; DEPTH],
    /// Counters per row minus one; counters per row is a power of two `>= 2`.
    mask: u64,
}

impl CountMinSketch {
    /// Creates a sketch with at least `ctrs` counters per row.
    ///
    /// `ctrs` is rounded up to the next power of two, and to no fewer than 2 so
    /// that every row owns at least one packed byte.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::InvalidCountMinWidth`] if `ctrs` is 0 or too large
    /// to round up to a power of two within a `u64`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn new(ctrs: u64) -> Result<Self, CacheError> {
        if ctrs < 1 {
            return Err(CacheError::InvalidCountMinWidth(ctrs));
        }
        // A width of 1 would give `1 / 2 == 0` bytes per row, so the very first
        // increment would index out of bounds; clamp to 2 counters (one byte).
        let ctrs = ctrs
            .checked_next_power_of_two()
            .ok_or_else(|| CacheError::InvalidCountMinWidth(ctrs))?
            .max(2);
        // NOTE(review): seeded from wall-clock seconds, so sketches created
        // within the same second share their per-row seeds.
        let mut source = StdRng::seed_from_u64(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        );
        Ok(Self {
            // Two 4-bit counters are packed into each byte of a row.
            rows: std::array::from_fn(|_| CountMinRow::new(ctrs / 2)),
            seeds: std::array::from_fn(|_| source.random::<u64>()),
            mask: ctrs - 1,
        })
    }

    /// Maps `hashed` to a counter index for the row seeded with `seed`.
    ///
    /// Multiply-shift hashing: keep the top `log2(mask + 1)` bits of
    /// `hashed * (seed | 1)`. Those bits depend on all bits of `hashed`, and
    /// the random odd multiplier differs per row, so keys that collide in one
    /// row are generally separated in the others. (XOR-ing with the seed and
    /// masking preserves equality of the low bits, which repeats exactly the
    /// same collisions in every row and makes the extra rows useless.)
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn index(&self, hashed: u64, seed: u64) -> u64 {
        // `mask == 2^k - 1`, so `leading_zeros() == 64 - k`; `checked_shr` maps
        // a degenerate `k == 0` (a shift by 64) to index 0 instead of panicking.
        hashed
            .wrapping_mul(seed | 1)
            .checked_shr(self.mask.leading_zeros())
            .unwrap_or(0)
    }

    /// Bumps the counter selected for `hashed` in every row (saturating at 15).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn increment(&self, hashed: u64) {
        for (row, seed) in self.rows.iter().zip(self.seeds) {
            row.increment(self.index(hashed, seed));
        }
    }

    /// Estimates how often `hashed` was incremented: the minimum of its
    /// counters across all rows, in `0..=15`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn estimate(&self, hashed: u64) -> i64 {
        self.rows
            .iter()
            .zip(self.seeds)
            .map(|(row, seed)| row.get(self.index(hashed, seed)))
            .min()
            .map_or(0, i64::from)
    }

    /// Halves every counter in every row, ageing out old frequency data.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn reset(&self) {
        self.rows.iter().for_each(|row| row.reset())
    }

    /// Sets every counter in every row back to zero.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn clear(&self) {
        self.rows.iter().for_each(|row| row.clear())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_count_min_row() {
        let cmr = CountMinRow::new(8);
        cmr.increment(0);
        assert_eq!(cmr.0[0].load(Ordering::Relaxed), 0x01);
        assert_eq!(cmr.get(0), 1);
        assert_eq!(cmr.get(1), 0);
        cmr.increment(1);
        assert_eq!(cmr.0[0].load(Ordering::Relaxed), 0x11);
        assert_eq!(cmr.get(0), 1);
        assert_eq!(cmr.get(1), 1);
        (0..14).for_each(|_| cmr.increment(1));
        assert_eq!(cmr.0[0].load(Ordering::Relaxed), 0xf1);
        assert_eq!(cmr.get(1), 15);
        assert_eq!(cmr.get(0), 1);
        (0..3).for_each(|_| {
            cmr.increment(1);
            assert_eq!(cmr.0[0].load(Ordering::Relaxed), 0xf1);
        });
        cmr.reset();
        assert_eq!(cmr.0[0].load(Ordering::Relaxed), 0x70);
    }

    #[test]
    fn test_count_min_sketch() {
        let s = CountMinSketch::new(5).unwrap();
        assert_eq!(7u64, s.mask);
    }

    #[test]
    fn test_count_min_sketch_rejects_zero_width() {
        assert!(CountMinSketch::new(0).is_err());
    }

    #[test]
    fn test_count_min_sketch_increment() {
        let s = CountMinSketch::new(16).unwrap();
        s.increment(1);
        s.increment(5);
        s.increment(9);
        // Rows use different random seeds, so (with overwhelming probability)
        // at least one row lays the three keys out differently from row 0.
        let first = format!("{:?}", s.rows[0]);
        assert!(
            s.rows[1..].iter().any(|row| format!("{:?}", row) != first),
            "every row of the sketch is identical"
        );
    }

    #[test]
    fn test_count_min_sketch_estimate() {
        let s = CountMinSketch::new(16).unwrap();
        s.increment(1);
        s.increment(1);
        assert_eq!(s.estimate(1), 2);
        assert_eq!(s.estimate(0), 0);
    }

    #[test]
    fn test_count_min_sketch_reset() {
        let s = CountMinSketch::new(16).unwrap();
        s.increment(1);
        s.increment(1);
        s.increment(1);
        s.increment(1);
        s.reset();
        assert_eq!(s.estimate(1), 2);
    }

    #[test]
    fn test_count_min_sketch_clear() {
        let s = CountMinSketch::new(16).unwrap();
        (0..16).for_each(|i| s.increment(i));
        s.clear();
        (0..16).for_each(|i| assert_eq!(s.estimate(i), 0));
    }

    #[test]
    fn test_count_min_row_concurrent_increment_independent_keys() {
        use std::{sync::Arc, thread};
        let cmr = Arc::new(CountMinRow::new(16));
        let mut handles = Vec::new();
        for k in 0..16u64 {
            let cmr = cmr.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..1000 {
                    cmr.increment(k);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        for k in 0..16u64 {
            assert_eq!(cmr.get(k), 15, "key {} should be saturated", k);
        }
    }

    #[test]
    fn test_count_min_row_concurrent_increment_same_key() {
        use std::{sync::Arc, thread};
        let cmr = Arc::new(CountMinRow::new(8));
        let mut handles = Vec::new();
        for _ in 0..16 {
            let cmr = cmr.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..50 {
                    cmr.increment(3);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(cmr.get(3), 15);
        assert_eq!(cmr.get(2), 0);
    }

    #[test]
    fn test_count_min_row_concurrent_increment_and_reset() {
        use std::{
            sync::{Arc, atomic::AtomicBool},
            thread,
        };
        let cmr = Arc::new(CountMinRow::new(16));
        let stop = Arc::new(AtomicBool::new(false));
        let cmr_inc = cmr.clone();
        let stop_inc = stop.clone();
        let inc = thread::spawn(move || {
            while !stop_inc.load(Ordering::Relaxed) {
                cmr_inc.increment(5);
            }
        });
        let cmr_rst = cmr.clone();
        let stop_rst = stop.clone();
        let rst = thread::spawn(move || {
            while !stop_rst.load(Ordering::Relaxed) {
                cmr_rst.reset();
            }
        });
        thread::sleep(std::time::Duration::from_millis(50));
        stop.store(true, Ordering::Relaxed);
        inc.join().unwrap();
        rst.join().unwrap();
        // Only key 5 (the high nibble of byte 2) is ever incremented. Its bits
        // must not leak into key 4 (the low nibble of the same byte) when
        // `reset` shifts the byte right, and every other counter must stay zero.
        for i in (0..32u64).filter(|&i| i != 5) {
            assert_eq!(cmr.get(i), 0, "key {} should never be touched", i);
        }
    }
}