use alloc::vec::Vec;
use core::hash::Hash;
use std::collections::HashMap;
use crate::error::{RcfError, RcfResult};
/// Number of counters allocated by [`SpaceSaving::with_default_capacity`].
pub const DEFAULT_CAPACITY: usize = 1 << 7;
/// Counters kept for one tracked key in a [`SpaceSaving`] summary.
///
/// NOTE: field order is part of the serialized form for positional serde
/// formats (e.g. postcard, exercised in this file's tests); do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HeavyHitterEntry {
    /// Weight attributed to the key: every weight observed since the key was
    /// last inserted, plus the `error` it inherited on insertion. Barring
    /// `u64` saturation this never undercounts the key's true weight.
    pub estimate: u64,
    /// Estimate of the entry that was evicted to make room for this key, or
    /// 0 if the key was inserted into a free slot. `estimate - error` is a
    /// lower bound on the key's true weight.
    pub error: u64,
}
/// One row of the ranking returned by [`SpaceSaving::top_k`].
///
/// NOTE: field order is part of the serialized form for positional serde
/// formats; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HeavyHitter<K> {
    /// 0-based position in the `top_k` result (saturates at `u32::MAX`).
    pub rank: u32,
    /// The tracked key.
    pub key: K,
    /// Estimated weight of `key`; see [`HeavyHitterEntry::estimate`].
    pub estimate: u64,
    /// Maximum overcount in `estimate`; see [`HeavyHitterEntry::error`].
    pub error: u64,
}
/// Space-Saving summary that keeps approximate weights for at most
/// `capacity` distinct keys.
///
/// When an untracked key arrives while the summary is full, the tracked key
/// with the smallest estimate is evicted and the newcomer inherits that
/// estimate as its `error` (see `observe_weighted`).
// NOTE(review): the `K` bound sits on the struct (not only on the impl),
// presumably so the serde derive can deserialize the `HashMap<K, _>` field —
// keep it while the `serde` feature exists.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpaceSaving<K>
where
    K: Hash + Eq + Clone,
{
    /// Currently tracked keys and their counters.
    counts: HashMap<K, HeavyHitterEntry>,
    /// Maximum number of keys `observe_weighted` keeps in `counts`;
    /// `new` rejects 0.
    capacity: usize,
    /// Saturating sum of every non-zero weight observed since construction
    /// or the last `reset`, including weight of keys later evicted.
    total: u64,
}
impl<K> SpaceSaving<K>
where
    K: Hash + Eq + Clone,
{
    /// Creates an empty summary that tracks at most `capacity` distinct keys.
    ///
    /// Room for `capacity` entries is reserved up front.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] if `capacity` is zero, or if the
    /// map cannot reserve room for `capacity` entries (for example because
    /// the requested table size overflows).
    pub fn new(capacity: usize) -> RcfResult<Self> {
        if capacity == 0 {
            return Err(RcfError::InvalidConfig(
                alloc::string::ToString::to_string("SpaceSaving: capacity must be > 0").into(),
            ));
        }
        let mut counts: HashMap<K, HeavyHitterEntry> = HashMap::new();
        // `try_reserve` reports an oversized request as an error instead of
        // panicking/aborting the way `HashMap::with_capacity` would, which
        // matters because this constructor is already fallible.
        counts.try_reserve(capacity).map_err(|_| {
            RcfError::InvalidConfig(
                alloc::string::ToString::to_string("SpaceSaving: capacity too large to allocate")
                    .into(),
            )
        })?;
        Ok(Self {
            counts,
            capacity,
            total: 0,
        })
    }

    /// Creates a summary with [`DEFAULT_CAPACITY`] slots.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SpaceSaving::new`].
    pub fn with_default_capacity() -> RcfResult<Self> {
        Self::new(DEFAULT_CAPACITY)
    }

    /// Maximum number of keys this summary tracks at once.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of keys currently tracked.
    #[must_use]
    pub fn len(&self) -> usize {
        self.counts.len()
    }

    /// `true` when no key is tracked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.counts.is_empty()
    }

    /// Saturating sum of all non-zero weights observed since construction or
    /// the last [`reset`](Self::reset), including weight of evicted keys.
    #[must_use]
    pub fn total(&self) -> u64 {
        self.total
    }

    /// Returns `total / capacity` (integer division).
    ///
    /// For a summary built through `observe*`, every tracked entry's `error`
    /// is at most this value (barring `u64` saturation). Returns 0 when
    /// `capacity` is zero, which `new` rejects but a deserialized value
    /// could carry.
    #[must_use]
    pub fn error_bound(&self) -> u64 {
        if self.capacity == 0 {
            return 0;
        }
        self.total / (self.capacity as u64)
    }

    /// Records a single occurrence of `key` (weight 1).
    #[inline]
    pub fn observe(&mut self, key: K) {
        self.observe_weighted(key, 1);
    }

    /// Records `weight` occurrences of `key`; a zero weight is ignored.
    ///
    /// For a non-zero weight, `total` grows by `weight` (saturating), then:
    /// * a tracked key has `weight` added to its estimate;
    /// * an untracked key takes a free slot with `estimate = weight`, `error = 0`;
    /// * otherwise the key with the smallest estimate `m` is evicted and the
    ///   new key enters with `estimate = m + weight` and `error = m`.
    #[inline]
    pub fn observe_weighted(&mut self, key: K, weight: u64) {
        if weight == 0 {
            return;
        }
        self.total = self.total.saturating_add(weight);
        if let Some(entry) = self.counts.get_mut(&key) {
            entry.estimate = entry.estimate.saturating_add(weight);
            return;
        }
        if self.counts.len() < self.capacity {
            self.counts.insert(
                key,
                HeavyHitterEntry {
                    estimate: weight,
                    error: 0,
                },
            );
            return;
        }
        // Full: replace the minimum. `find_min` yields `None` only for an empty
        // map, which here implies `capacity == 0` (e.g. via deserialization);
        // the weight is then counted in `total` but not tracked.
        if let Some((min_key, min_entry)) = self.find_min() {
            self.counts.remove(&min_key);
            let boosted = HeavyHitterEntry {
                estimate: min_entry.estimate.saturating_add(weight),
                error: min_entry.estimate,
            };
            self.counts.insert(key, boosted);
        }
    }

    /// Returns the counters for `key`, or `None` if it is not currently
    /// tracked (never observed, or evicted).
    #[must_use]
    pub fn estimate(&self, key: &K) -> Option<HeavyHitterEntry> {
        self.counts.get(key).copied()
    }

    /// Returns up to `n` tracked keys ordered by descending `estimate`.
    ///
    /// Entries with equal estimates are ordered by ascending `error`, so the
    /// one with the larger guaranteed weight (`estimate - error`) comes first;
    /// any remaining ties follow the map's iteration order and are
    /// unspecified. `rank` is the 0-based position in the result (saturating
    /// at `u32::MAX`). Only the returned keys are cloned.
    #[must_use]
    pub fn top_k(&self, n: usize) -> Vec<HeavyHitter<K>> {
        // Rank borrowed pairs so that keys dropped by `truncate` are never cloned.
        let mut ranked: Vec<(&K, &HeavyHitterEntry)> = self.counts.iter().collect();
        ranked.sort_unstable_by(|(_, a), (_, b)| {
            b.estimate
                .cmp(&a.estimate)
                .then_with(|| a.error.cmp(&b.error))
        });
        ranked.truncate(n);
        ranked
            .into_iter()
            .enumerate()
            .map(|(idx, (key, entry))| HeavyHitter {
                rank: u32::try_from(idx).unwrap_or(u32::MAX),
                key: key.clone(),
                estimate: entry.estimate,
                error: entry.error,
            })
            .collect()
    }

    /// Iterates over all tracked keys and their counters in map order.
    pub fn iter(&self) -> impl Iterator<Item = (&K, &HeavyHitterEntry)> {
        self.counts.iter()
    }

    /// Forgets every tracked key and sets `total` back to zero; `capacity`
    /// is kept.
    pub fn reset(&mut self) {
        self.counts.clear();
        self.total = 0;
    }

    /// Returns a clone of the key with the smallest estimate together with its
    /// counters, or `None` when nothing is tracked. Ties go to the first such
    /// key in map iteration order. This is a linear scan over all tracked keys.
    fn find_min(&self) -> Option<(K, HeavyHitterEntry)> {
        self.counts
            .iter()
            .min_by_key(|(_, e)| e.estimate)
            .map(|(k, e)| (k.clone(), *e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_zero_capacity() {
        assert!(SpaceSaving::<u32>::new(0).is_err());
    }

    #[test]
    fn exact_counts_within_capacity() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for i in 0..5_u32 {
            for _ in 0..=u64::from(i) {
                ss.observe(i);
            }
        }
        let top = ss.top_k(5);
        assert_eq!(top.len(), 5);
        for hh in &top {
            assert_eq!(hh.error, 0);
        }
        assert_eq!(top[0].key, 4);
        assert_eq!(top[0].estimate, 5);
    }

    #[test]
    fn heavy_hitter_always_retained() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for _ in 0..1_000 {
            ss.observe(0_u32);
        }
        for i in 1..2_001_u32 {
            ss.observe(i);
        }
        let h = ss
            .top_k(8)
            .into_iter()
            .find(|hh| hh.key == 0)
            .expect("heavy hitter retained");
        assert!(h.estimate >= 1_000);
    }

    #[test]
    fn error_bound_sandwiches_true_count() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(16).unwrap();
        for i in 0..100_u32 {
            for _ in 0..10 {
                ss.observe(i);
            }
        }
        // Every key's true count is exactly 10, so each tracked entry must
        // satisfy `estimate - error <= 10 <= estimate`.
        for hh in ss.top_k(16) {
            assert!(hh.estimate >= hh.error);
            let lower = hh.estimate - hh.error;
            assert!(lower <= 10, "lower={lower}");
            assert!(hh.estimate >= 10, "estimate={}", hh.estimate);
        }
    }

    #[test]
    fn estimate_returns_none_for_untracked() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(2).unwrap();
        ss.observe(1);
        ss.observe(2);
        for _ in 0..5 {
            ss.observe(3);
        }
        assert!(ss.estimate(&3).is_some());
        assert!(ss.estimate(&100).is_none());
    }

    #[test]
    fn eviction_inherits_min_estimate_as_error() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(2).unwrap();
        ss.observe_weighted(1, 5);
        ss.observe_weighted(2, 1);
        // Full: key 2 holds the minimum (1), so key 3 replaces it and
        // inherits that estimate as its error.
        ss.observe_weighted(3, 4);
        assert!(ss.estimate(&2).is_none());
        assert_eq!(
            ss.estimate(&3),
            Some(HeavyHitterEntry {
                estimate: 5,
                error: 1
            })
        );
        assert_eq!(
            ss.estimate(&1),
            Some(HeavyHitterEntry {
                estimate: 5,
                error: 0
            })
        );
        assert_eq!(ss.total(), 10);
    }

    #[test]
    fn estimates_sum_to_total_and_errors_stay_bounded() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(5).unwrap();
        for i in 0..200_u32 {
            ss.observe_weighted(i % 17, u64::from(i % 7) + 1);
        }
        // 17 distinct keys overflow 5 slots, so the summary must be full.
        assert_eq!(ss.len(), ss.capacity());
        // Each observation adds exactly its weight to the sum of estimates
        // (directly, into a free slot, or on top of an evicted minimum).
        let sum: u64 = ss.iter().map(|(_, e)| e.estimate).sum();
        assert_eq!(sum, ss.total());
        let bound = ss.error_bound();
        for (_, e) in ss.iter() {
            assert!(e.error <= e.estimate);
            assert!(e.error <= bound, "error={} bound={bound}", e.error);
        }
    }

    #[test]
    fn weighted_observe_accumulates() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        ss.observe_weighted(7, 1_000);
        ss.observe_weighted(7, 500);
        let h = ss.estimate(&7).expect("tracked");
        assert_eq!(h.estimate, 1_500);
        assert_eq!(ss.total(), 1_500);
    }

    #[test]
    fn zero_weight_is_noop() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        ss.observe_weighted(1, 0);
        assert!(ss.is_empty());
        assert_eq!(ss.total(), 0);
    }

    #[test]
    fn error_bound_grows_linearly() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(10).unwrap();
        for i in 0..1_000_u32 {
            ss.observe(i);
        }
        assert_eq!(ss.error_bound(), 100);
    }

    #[test]
    fn top_k_ranks_descending() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for (key, count) in [(1_u32, 100_u64), (2, 50), (3, 25), (4, 10)] {
            for _ in 0..count {
                ss.observe(key);
            }
        }
        let top = ss.top_k(4);
        assert_eq!(top[0].key, 1);
        assert_eq!(top[1].key, 2);
        assert_eq!(top[2].key, 3);
        assert_eq!(top[3].key, 4);
        assert_eq!(top[0].rank, 0);
        assert_eq!(top[3].rank, 3);
    }

    #[test]
    fn top_k_clamps_to_len() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        ss.observe(1);
        assert_eq!(ss.top_k(10).len(), 1);
        assert_eq!(ss.top_k(0).len(), 0);
    }

    #[test]
    fn reset_clears_everything() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        for i in 0..100_u32 {
            ss.observe(i);
        }
        ss.reset();
        assert!(ss.is_empty());
        assert_eq!(ss.total(), 0);
        assert_eq!(ss.top_k(4).len(), 0);
    }

    #[test]
    fn byte_key_roundtrip() {
        let mut ss: SpaceSaving<[u8; 16]> = SpaceSaving::new(4).unwrap();
        let k = [0x01_u8; 16];
        for _ in 0..10 {
            ss.observe(k);
        }
        assert_eq!(ss.estimate(&k).unwrap().estimate, 10);
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn postcard_roundtrip_preserves_top_k() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for i in 0..20_u32 {
            for _ in 0..=u64::from(i) {
                ss.observe(i);
            }
        }
        let bytes = postcard::to_allocvec(&ss).expect("serde ok");
        let back: SpaceSaving<u32> = postcard::from_bytes(&bytes).expect("serde ok");
        assert_eq!(back.capacity(), ss.capacity());
        assert_eq!(back.total(), ss.total());
        // The tracked estimates are pairwise distinct for this stream, so the
        // ranking is fully determined and whole rows (including `rank`) can be
        // compared. Comparing the vectors also catches a length mismatch that
        // a `zip` over the two would silently ignore.
        assert_eq!(ss.top_k(8), back.top_k(8));
    }
}