use serde::{Deserialize, Serialize};
use crate::message_pack_format::{Error as MsgPackError, MessagePackCodec};
use crate::sketches::countminsketch::CountMin;
use crate::{DataInput, FastPath, Vector2D};
/// Count-Min sketch backend used throughout this module: a `CountMin` whose
/// counters live in a `Vector2D<f64>`, parameterized by `FastPath`.
pub type SketchlibCms = CountMin<Vector2D<f64>, FastPath>;
/// Creates an empty backend sketch with `row_num` rows and `col_num` columns
/// (all counters start at zero, as the creation test below checks).
pub fn new_sketchlib_cms(row_num: usize, col_num: usize) -> SketchlibCms {
    SketchlibCms::with_dimensions(row_num, col_num)
}
pub fn sketchlib_cms_from_matrix(
row_num: usize,
col_num: usize,
sketch: &[Vec<f64>],
) -> SketchlibCms {
let matrix = Vector2D::from_fn(row_num, col_num, |r, c| {
sketch
.get(r)
.and_then(|row| row.get(c))
.copied()
.unwrap_or(0.0)
});
SketchlibCms::from_storage(matrix)
}
/// Copies the backend's counters into an owned row-major matrix whose shape
/// is the storage's `rows()` x `cols()`.
///
/// Any position for which the storage yields no value is reported as `0.0`.
pub fn matrix_from_sketchlib_cms(inner: &SketchlibCms) -> Vec<Vec<f64>> {
    let storage: &Vector2D<f64> = inner.as_storage();
    let (row_count, col_count) = (storage.rows(), storage.cols());
    (0..row_count)
        .map(|r| {
            (0..col_count)
                .map(|c| storage.get(r, c).copied().unwrap_or(0.0))
                .collect::<Vec<f64>>()
        })
        .collect()
}
/// Adds `value` to the backend counters for `key`.
///
/// Only finite, strictly positive values are applied. Zero, negative, NaN and
/// infinite values are ignored. NaN and infinity must be rejected explicitly:
/// `NaN <= 0.0` is `false`, so a plain `<= 0.0` guard would let NaN through.
/// Once added to a counter, NaN/infinity absorb every later addition and
/// corrupt that counter permanently.
pub fn sketchlib_cms_update(inner: &mut SketchlibCms, key: &str, value: f64) {
    if !value.is_finite() || value <= 0.0 {
        return;
    }
    inner.insert_many(&DataInput::String(key.to_owned()), value);
}
/// Returns the backend's `estimate` for the string key `key`.
pub fn sketchlib_cms_query(inner: &SketchlibCms, key: &str) -> f64 {
    let input = DataInput::String(String::from(key));
    inner.estimate(&input)
}
/// Sparse, additive update to a [`CountMinSketch`] counter matrix.
///
/// Consumed by [`CountMinSketch::apply_delta`], which adds each cell's count to
/// the matching counter.
#[derive(Debug, Clone, Default)]
pub struct CountMinSketchDelta {
    /// Declared row count (presumably of the sketch that produced the delta).
    /// Not checked by `apply_delta`, which only bounds-checks individual cells.
    pub rows: u32,
    /// Declared column count; like `rows`, not checked by `apply_delta`.
    pub cols: u32,
    /// `(row, col, count)` triples; `count` is added to counter `[row][col]`.
    /// Duplicate positions accumulate.
    pub cells: Vec<(u32, u32, i64)>,
    /// NOTE(review): presumably per-row L1 totals (the envelope test computes
    /// per-row sums under this name). Not read by `apply_delta`.
    pub l1: Vec<f64>,
    /// NOTE(review): presumably per-row L2 (sum-of-squares) totals. Not read by
    /// `apply_delta`.
    pub l2: Vec<f64>,
}
/// Count-Min sketch over string keys with `f64` counters.
///
/// `rows`/`cols` are the dimensions every constructor passes to the backend.
/// They drive the dimension checks in `merge`/`apply_delta` and the zero-size
/// guards in `update`/`estimate`.
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    /// Number of counter rows ("depth" in `aggregate_count`).
    pub rows: usize,
    /// Number of counters per row ("width" in `aggregate_count`).
    pub cols: usize,
    /// Backing sketch. NOTE(review): `rows`/`cols` are public and can be
    /// mutated independently of this backend; callers must keep them in sync.
    pub(crate) backend: SketchlibCms,
}
impl CountMinSketch {
    /// Creates an all-zero sketch with `rows` x `cols` counters.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            backend: new_sketchlib_cms(rows, cols),
        }
    }

    /// Number of counter rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of counters per row.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns an owned row-major copy of the counter matrix.
    pub fn sketch(&self) -> Vec<Vec<f64>> {
        matrix_from_sketchlib_cms(&self.backend)
    }

    /// Rebuilds a sketch from a legacy `rows` x `cols` counter matrix.
    ///
    /// Cells missing from `sketch` are treated as `0.0`; cells outside
    /// `rows` x `cols` are ignored.
    pub fn from_legacy_matrix(sketch: Vec<Vec<f64>>, rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            backend: sketchlib_cms_from_matrix(rows, cols, &sketch),
        }
    }

    /// Adds `value` to the counters for `key`.
    ///
    /// The update is ignored when the sketch has a zero dimension, or when
    /// `value` is not finite and strictly positive. NaN needs the explicit
    /// check because `NaN <= 0.0` is `false`. Once added, NaN/infinity absorb
    /// every later addition and corrupt the affected counters permanently.
    pub fn update(&mut self, key: &str, value: f64) {
        if !value.is_finite() || value <= 0.0 || self.rows == 0 || self.cols == 0 {
            return;
        }
        self.backend
            .insert_many(&DataInput::String(key.to_owned()), value);
    }

    /// Returns the estimated total for `key`, or `0.0` for a zero-sized sketch.
    pub fn estimate(&self, key: &str) -> f64 {
        if self.rows == 0 || self.cols == 0 {
            return 0.0;
        }
        self.backend.estimate(&DataInput::String(key.to_owned()))
    }

    /// Merges `other` into `self` using the backend's `merge`.
    ///
    /// # Errors
    /// Returns an error if the dimensions differ; `self` is unchanged then.
    pub fn merge(
        &mut self,
        other: &CountMinSketch,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        if self.rows != other.rows || self.cols != other.cols {
            return Err(format!(
                "CountMinSketch dimension mismatch: self={}x{}, other={}x{}",
                self.rows, self.cols, other.rows, other.cols
            )
            .into());
        }
        self.backend.merge(&other.backend);
        Ok(())
    }

    /// Merges all `accumulators` into a fresh sketch of their common shape.
    ///
    /// # Errors
    /// Returns an error if `accumulators` is empty or their dimensions differ.
    pub fn merge_refs(
        accumulators: &[&Self],
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        if accumulators.is_empty() {
            return Err("No accumulators to merge".into());
        }
        let rows = accumulators[0].rows;
        let cols = accumulators[0].cols;
        // Validate every input before merging anything.
        for acc in accumulators {
            if acc.rows != rows || acc.cols != cols {
                return Err(
                    "Cannot merge CountMinSketch accumulators with different dimensions".into(),
                );
            }
        }
        let mut merged = CountMinSketch::new(rows, cols);
        for acc in accumulators {
            merged.backend.merge(&acc.backend);
        }
        Ok(merged)
    }

    /// Adds each `(row, col, count)` cell of `delta` to the matching counter.
    ///
    /// All cells are bounds-checked before anything is modified, so on error
    /// the sketch is unchanged. `delta.rows`, `delta.cols`, `delta.l1` and
    /// `delta.l2` are not consulted. The whole matrix is copied out and the
    /// backend rebuilt, so the cost is O(rows * cols) however few cells the
    /// delta touches.
    ///
    /// # Errors
    /// Returns an error if any cell lies outside the `rows` x `cols` matrix.
    pub fn apply_delta(
        &mut self,
        delta: &CountMinSketchDelta,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        for (row, col, _) in &delta.cells {
            let r = *row as usize;
            let c = *col as usize;
            if r >= self.rows || c >= self.cols {
                return Err(format!(
                    "CountMinSketchDelta cell ({r},{c}) out of range (matrix={}x{})",
                    self.rows, self.cols
                )
                .into());
            }
        }
        let mut matrix = self.sketch();
        for (row, col, d_count) in &delta.cells {
            matrix[*row as usize][*col as usize] += *d_count as f64;
        }
        self.backend = sketchlib_cms_from_matrix(self.rows, self.cols, &matrix);
        Ok(())
    }

    /// Builds a `depth` x `width` sketch from `(key, value)` pairs and returns
    /// it MessagePack-encoded.
    ///
    /// Each value is added as that key's weight, with the same filtering as
    /// [`update`](Self::update). Pairs are formed with `zip`, so if `keys` and
    /// `values` differ in length, the surplus elements are ignored.
    ///
    /// Returns `None` if `keys` is empty or encoding fails.
    pub fn aggregate_count(
        depth: usize,
        width: usize,
        keys: &[&str],
        values: &[f64],
    ) -> Option<Vec<u8>> {
        if keys.is_empty() {
            return None;
        }
        let mut sketch = Self::new(depth, width);
        for (key, &value) in keys.iter().zip(values.iter()) {
            sketch.update(key, value);
        }
        sketch.to_msgpack().ok()
    }

    /// Identical to [`aggregate_count`](Self::aggregate_count): values are
    /// summed per key.
    pub fn aggregate_sum(
        depth: usize,
        width: usize,
        keys: &[&str],
        values: &[f64],
    ) -> Option<Vec<u8>> {
        Self::aggregate_count(depth, width, keys, values)
    }
}
/// Serialized form of a [`CountMinSketch`]: the full counter matrix plus its
/// declared dimensions. Field names are `sketch`, `row_num` and `col_num`.
///
/// NOTE(review): the codec below uses `rmp_serde::to_vec`, which presumably
/// encodes structs positionally. Field order may therefore be part of the wire
/// format; verify before reordering or inserting fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CountMinSketchWire {
    /// Row-major counter matrix.
    pub sketch: Vec<Vec<f64>>,
    /// Number of rows; named `row_num` when serialized.
    #[serde(rename = "row_num")]
    pub rows: usize,
    /// Number of columns; named `col_num` when serialized.
    #[serde(rename = "col_num")]
    pub cols: usize,
}
impl MessagePackCodec for CountMinSketch {
fn to_msgpack(&self) -> Result<Vec<u8>, MsgPackError> {
let wire = CountMinSketchWire {
sketch: self.sketch(),
rows: self.rows,
cols: self.cols,
};
Ok(rmp_serde::to_vec(&wire)?)
}
fn from_msgpack(bytes: &[u8]) -> Result<Self, MsgPackError> {
let wire: CountMinSketchWire = rmp_serde::from_slice(bytes)?;
let backend = sketchlib_cms_from_matrix(wire.rows, wire.cols, &wire.sketch);
Ok(Self {
rows: wire.rows,
cols: wire.cols,
backend,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_min_sketch_creation() {
        let cms = CountMinSketch::new(4, 1000);
        assert_eq!(cms.rows, 4);
        assert_eq!(cms.cols, 1000);
        let sketch = cms.sketch();
        assert_eq!(sketch.len(), 4);
        assert_eq!(sketch[0].len(), 1000);
        for row in &sketch {
            for &value in row {
                assert_eq!(value, 0.0);
            }
        }
    }

    #[test]
    fn test_count_min_sketch_update() {
        let mut cms = CountMinSketch::new(2, 10);
        cms.update("key1", 1.0);
        let result = cms.estimate("key1");
        assert!(result >= 1.0);
    }

    #[test]
    fn test_count_min_sketch_query_empty() {
        let cms = CountMinSketch::new(2, 10);
        assert_eq!(cms.estimate("anything"), 0.0);
    }

    #[test]
    fn test_count_min_sketch_merge() {
        let mut sketch1 = vec![vec![0.0; 3]; 2];
        sketch1[0][0] = 5.0;
        sketch1[1][2] = 10.0;
        let mut cms1 = CountMinSketch::from_legacy_matrix(sketch1, 2, 3);
        let mut sketch2 = vec![vec![0.0; 3]; 2];
        sketch2[0][0] = 3.0;
        sketch2[0][1] = 7.0;
        let cms2 = CountMinSketch::from_legacy_matrix(sketch2, 2, 3);
        cms1.merge(&cms2).unwrap();
        let merged_sketch = cms1.sketch();
        assert_eq!(merged_sketch[0][0], 8.0);
        assert_eq!(merged_sketch[0][1], 7.0);
        assert_eq!(merged_sketch[1][2], 10.0);
    }

    #[test]
    fn test_count_min_sketch_merge_dimension_mismatch() {
        let mut cms1 = CountMinSketch::new(2, 3);
        let cms2 = CountMinSketch::new(3, 3);
        assert!(cms1.merge(&cms2).is_err());
    }

    #[test]
    fn test_count_min_sketch_msgpack_round_trip() {
        let mut cms = CountMinSketch::new(4, 256);
        cms.update("apple", 5.0);
        cms.update("banana", 3.0);
        cms.update("apple", 2.0);
        let bytes = cms.to_msgpack().unwrap();
        let deserialized = CountMinSketch::from_msgpack(&bytes).unwrap();
        assert_eq!(deserialized.rows, 4);
        assert_eq!(deserialized.cols, 256);
        assert!(deserialized.estimate("apple") >= 7.0);
        assert!(deserialized.estimate("banana") >= 3.0);
    }

    #[test]
    fn test_aggregate_count() {
        let keys = ["a", "b", "a"];
        let values = [1.0, 2.0, 3.0];
        let bytes = CountMinSketch::aggregate_count(4, 100, &keys, &values).unwrap();
        let cms = CountMinSketch::from_msgpack(&bytes).unwrap();
        assert!(cms.estimate("a") >= 4.0);
        assert!(cms.estimate("b") >= 2.0);
    }

    #[test]
    fn test_aggregate_count_empty() {
        assert!(CountMinSketch::aggregate_count(4, 100, &[], &[]).is_none());
    }

    #[test]
    fn test_apply_delta_additive() {
        let mut cms = CountMinSketch::from_legacy_matrix(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
            2,
            3,
        );
        let delta = CountMinSketchDelta {
            rows: 2,
            cols: 3,
            cells: vec![(0, 0, 10), (1, 2, 100)],
            l1: vec![],
            l2: vec![],
        };
        cms.apply_delta(&delta).unwrap();
        assert_eq!(
            cms.sketch(),
            vec![vec![11.0, 2.0, 3.0], vec![4.0, 5.0, 106.0]]
        );
    }

    #[test]
    fn test_apply_delta_matches_full_merge() {
        let base = CountMinSketch::from_legacy_matrix(vec![vec![1.0, 2.0], vec![3.0, 4.0]], 2, 2);
        let addition =
            CountMinSketch::from_legacy_matrix(vec![vec![10.0, 0.0], vec![0.0, 20.0]], 2, 2);
        let mut via_merge = base.clone();
        via_merge.merge(&addition).unwrap();
        let delta = CountMinSketchDelta {
            rows: 2,
            cols: 2,
            cells: vec![(0, 0, 10), (1, 1, 20)],
            l1: vec![],
            l2: vec![],
        };
        let mut via_delta = base;
        via_delta.apply_delta(&delta).unwrap();
        assert_eq!(via_delta.sketch(), via_merge.sketch());
    }

    #[test]
    fn test_apply_delta_out_of_range() {
        let mut cms = CountMinSketch::new(2, 3);
        let delta = CountMinSketchDelta {
            rows: 2,
            cols: 3,
            cells: vec![(5, 0, 1)],
            l1: vec![],
            l2: vec![],
        };
        assert!(cms.apply_delta(&delta).is_err());
    }

    #[test]
    fn test_update_then_envelope_matches_sketchlib_go_bytes() {
        use crate::proto::sketchlib::{
            CountMinState, CounterType, SketchEnvelope, sketch_envelope::SketchState,
        };
        use prost::Message;
        let rows = 4usize;
        let cols = 2048usize;
        let mut sk = CountMinSketch::new(rows, cols);
        for i in 0..50u64 {
            let key = format!("flow-{}", i % 10);
            sk.update(&key, 1.0);
        }
        let matrix = sk.sketch();
        let mut counts_int: Vec<i64> = Vec::with_capacity(rows * cols);
        let mut l1: Vec<f64> = Vec::with_capacity(rows);
        let mut l2: Vec<f64> = Vec::with_capacity(rows);
        for row in matrix.iter().take(rows) {
            let mut row_l1 = 0.0f64;
            let mut row_l2 = 0.0f64;
            for &cell in row.iter().take(cols) {
                counts_int.push(cell as i64);
                row_l1 += cell;
                row_l2 += cell * cell;
            }
            l1.push(row_l1);
            l2.push(row_l2);
        }
        let state = CountMinState {
            rows: rows as u32,
            cols: cols as u32,
            counter_type: CounterType::Int64 as i32,
            counts_int,
            counts_float: Vec::new(),
            sum_counts: Vec::new(),
            sum2_counts: Vec::new(),
            l1,
            l2,
        };
        let envelope = SketchEnvelope {
            format_version: 1,
            producer: None,
            hash_spec: None,
            sketch_state: Some(SketchState::CountMin(state)),
        };
        let mut got = Vec::with_capacity(envelope.encoded_len());
        envelope.encode(&mut got).expect("prost encode");
        const GOLDEN_HEX: &str = include_str!("../../sketches/testdata/cms_envelope_golden.hex");
        let want = decode_hex_cms(GOLDEN_HEX);
        assert_eq!(
            got.len(),
            want.len(),
            "CMS envelope length differs: got {} bytes, want {} bytes",
            got.len(),
            want.len(),
        );
        assert_eq!(
            got, want,
            "CMS envelope bytes diverge from sketchlib-go golden"
        );
    }

    #[test]
    fn test_decode_hex_cms_helper() {
        assert_eq!(decode_hex_cms("  0aFf10\n"), vec![0x0a, 0xff, 0x10]);
        assert!(decode_hex_cms("").is_empty());
    }

    #[test]
    #[should_panic(expected = "odd length")]
    fn test_decode_hex_cms_rejects_odd_length() {
        decode_hex_cms("abc");
    }

    /// Decodes a hex string (surrounding whitespace trimmed) into bytes.
    ///
    /// Panics with a descriptive message on odd length or non-hex characters,
    /// so a corrupted golden file fails loudly instead of with an opaque
    /// index-out-of-bounds.
    fn decode_hex_cms(s: &str) -> Vec<u8> {
        let s = s.trim();
        let pairs = s.as_bytes().chunks_exact(2);
        assert!(
            pairs.remainder().is_empty(),
            "hex string has odd length {}",
            s.len()
        );
        pairs
            .map(|pair| (hex_nibble_cms(pair[0]) << 4) | hex_nibble_cms(pair[1]))
            .collect()
    }

    /// Converts one ASCII hex digit (either case) to its 4-bit value; panics
    /// on any other byte.
    fn hex_nibble_cms(c: u8) -> u8 {
        match c {
            b'0'..=b'9' => c - b'0',
            b'a'..=b'f' => c - b'a' + 10,
            b'A'..=b'F' => c - b'A' + 10,
            _ => panic!("non-hex byte {}", c as char),
        }
    }
}