use crate::traits::{FrequencySketch, HeavyHitters, MergeError, Sketch};
use core::hash::Hash;
use std::collections::HashMap;
use std::vec::Vec;
/// One monitored slot of a [`SpaceSaving`] sketch.
#[derive(Clone, Debug)]
struct Counter<T> {
    /// The item currently occupying this slot.
    item: T,
    /// Estimated number of occurrences of `item`. Once the slot has been
    /// recycled this includes the count inherited from the evicted item.
    count: u64,
    /// How much of `count` was inherited when `item` took over this slot
    /// (0 if the slot was filled while the sketch still had free room).
    error: u64,
}

impl<T> Counter<T> {
    /// Builds a slot for `item` with the given estimate and error bound.
    fn new(item: T, count: u64, error: u64) -> Self {
        Self { item, count, error }
    }
}

/// Space-Saving heavy-hitters sketch that monitors at most `capacity`
/// distinct items.
///
/// Each monitored item carries an estimated count and an error bound. When the
/// sketch is full, an unseen item takes over the slot with the smallest count
/// and inherits that count as its error, so `estimate - error` is a lower
/// bound and `estimate` an upper bound on the item's true count (as long as
/// no counter has saturated at `u64::MAX`).
#[derive(Clone, Debug)]
pub struct SpaceSaving<T: Hash + Eq + Clone + core::fmt::Debug> {
    /// Maximum number of distinct items monitored at once.
    capacity: usize,
    /// Maps each monitored item to its slot index in `counters`.
    item_to_index: HashMap<T, usize>,
    /// Monitored slots; never longer than `capacity`.
    counters: Vec<Counter<T>>,
    /// Sum of all weights passed to `add_count` (saturating).
    total_count: u64,
    /// Number of `add` / `add_count` calls.
    num_updates: u64,
}

impl<T: Hash + Eq + Clone + core::fmt::Debug> SpaceSaving<T> {
    /// Creates an empty sketch that monitors at most `capacity` distinct items.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self {
            capacity,
            item_to_index: HashMap::with_capacity(capacity),
            counters: Vec::with_capacity(capacity),
            total_count: 0,
            num_updates: 0,
        }
    }

    /// Maximum number of distinct items this sketch monitors.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of distinct items currently monitored (at most `capacity`).
    pub fn num_tracked(&self) -> usize {
        self.counters.len()
    }

    /// Total weight of all updates seen so far (saturates at `u64::MAX`).
    pub fn total_count(&self) -> u64 {
        self.total_count
    }

    /// Records a single occurrence of `item`.
    pub fn add(&mut self, item: T) {
        self.add_count(item, 1);
    }

    /// Records `count` occurrences of `item`.
    ///
    /// If `item` is already monitored its estimate grows by `count`. Otherwise
    /// it takes a free slot. If there is no free slot, it replaces the slot
    /// with the smallest count, inheriting that count as its error bound.
    ///
    /// A zero-weight update is counted as an update but does not change which
    /// items are monitored. All counts saturate at `u64::MAX` instead of
    /// overflowing.
    pub fn add_count(&mut self, item: T, count: u64) {
        self.num_updates = self.num_updates.saturating_add(1);
        if count == 0 {
            // No evidence about `item`: inserting it (or evicting another item
            // to make room) would only discard information.
            return;
        }
        self.total_count = self.total_count.saturating_add(count);

        if let Some(&idx) = self.item_to_index.get(&item) {
            let slot = &mut self.counters[idx];
            slot.count = slot.count.saturating_add(count);
            return;
        }

        if self.counters.len() < self.capacity {
            let idx = self.counters.len();
            self.counters.push(Counter::new(item.clone(), count, 0));
            self.item_to_index.insert(item, idx);
        } else {
            let min_idx = self.find_min_index();
            let min_count = self.counters[min_idx].count;
            // Swap the newcomer into the minimum slot and take ownership of the
            // evicted counter, so its item can be unmapped without cloning.
            let evicted = core::mem::replace(
                &mut self.counters[min_idx],
                Counter::new(item.clone(), min_count.saturating_add(count), min_count),
            );
            self.item_to_index.remove(&evicted.item);
            self.item_to_index.insert(item, min_idx);
        }
    }

    /// Index of the slot with the smallest count (the first such slot on ties).
    ///
    /// Only called when the sketch is full, so `counters` is non-empty there;
    /// the `0` fallback exists only to keep the function total.
    fn find_min_index(&self) -> usize {
        self.counters
            .iter()
            .enumerate()
            .min_by_key(|(_, c)| c.count)
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Estimated count of `item`, or 0 if it is not currently monitored.
    pub fn estimate(&self, item: &T) -> u64 {
        self.item_to_index
            .get(item)
            .map(|&idx| self.counters[idx].count)
            .unwrap_or(0)
    }

    /// Error bound recorded for `item` (how much of its estimate was
    /// inherited on eviction), or 0 if it is not monitored.
    pub fn error(&self, item: &T) -> u64 {
        self.item_to_index
            .get(item)
            .map(|&idx| self.counters[idx].error)
            .unwrap_or(0)
    }

    /// Lower bound on `item`'s true count (`estimate - error`), or 0 if it is
    /// not monitored.
    pub fn guaranteed_count(&self, item: &T) -> u64 {
        self.item_to_index
            .get(item)
            .map(|&idx| {
                let c = &self.counters[idx];
                c.count.saturating_sub(c.error)
            })
            .unwrap_or(0)
    }

    /// Whether `item` currently occupies a slot.
    pub fn contains(&self, item: &T) -> bool {
        self.item_to_index.contains_key(item)
    }
}
impl<T: Hash + Eq + Clone + core::fmt::Debug> Sketch for SpaceSaving<T> {
    type Item = T;

    /// Records one occurrence of `item` (clones it into the sketch).
    fn update(&mut self, item: &T) {
        self.add(item.clone());
    }

    /// Merging is not supported by this implementation.
    ///
    /// # Errors
    ///
    /// Always returns [`MergeError::IncompatibleConfig`].
    fn merge(&mut self, _other: &Self) -> Result<(), MergeError> {
        Err(MergeError::IncompatibleConfig {
            expected: "Space-Saving does not support merge".into(),
            found: "merge attempted".into(),
        })
    }

    /// Resets the sketch to empty; allocated capacity is retained.
    fn clear(&mut self) {
        self.item_to_index.clear();
        self.counters.clear();
        self.total_count = 0;
        self.num_updates = 0;
    }

    /// Approximate memory footprint in bytes.
    ///
    /// Counts the struct itself, the counter slots and the item index. Hash
    /// table bucket/control overhead and any heap data owned by `T` (e.g. the
    /// bytes of a `String`) are not included, so this is a lower bound.
    fn size_bytes(&self) -> usize {
        let counters = self.counters.capacity() * core::mem::size_of::<Counter<T>>();
        // The index holds one (item, slot) pair per monitored item.
        let index = self.item_to_index.capacity() * core::mem::size_of::<(T, usize)>();
        core::mem::size_of::<Self>() + counters + index
    }

    /// Number of `add` / `add_count` / `update` calls since creation or the
    /// last `clear`.
    fn count(&self) -> u64 {
        self.num_updates
    }
}
impl<T: Hash + Eq + Clone + core::fmt::Debug> FrequencySketch for SpaceSaving<T> {
fn estimate_frequency(&self, item: &T) -> u64 {
self.estimate(item)
}
}
impl<T: Hash + Eq + Clone + core::fmt::Debug> HeavyHitters for SpaceSaving<T> {
fn heavy_hitters(&self, threshold: f64) -> Vec<(T, u64)> {
let min_count = (threshold * self.total_count as f64) as u64;
self.counters
.iter()
.filter(|c| c.count >= min_count)
.map(|c| (c.item.clone(), c.count))
.collect()
}
fn top_k(&self, k: usize) -> Vec<(T, u64)> {
let mut items: Vec<_> = self
.counters
.iter()
.map(|c| (c.item.clone(), c.count))
.collect();
items.sort_by(|a, b| b.1.cmp(&a.1));
items.truncate(k);
items
}
}
#[cfg(feature = "serde")]
impl<T> serde::Serialize for SpaceSaving<T>
where
    T: Hash + Eq + Clone + core::fmt::Debug + serde::Serialize,
{
    /// Serializes as a struct named `SpaceSaving` with fields `capacity`,
    /// `total_count` and `items`. `items` is a sequence of
    /// `(item, count, error)` triples in slot order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let mut entries = Vec::with_capacity(self.counters.len());
        for slot in &self.counters {
            entries.push((&slot.item, slot.count, slot.error));
        }

        let mut out = serializer.serialize_struct("SpaceSaving", 3)?;
        out.serialize_field("capacity", &self.capacity)?;
        out.serialize_field("total_count", &self.total_count)?;
        out.serialize_field("items", &entries)?;
        out.end()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated additions of the same item accumulate.
    #[test]
    fn test_basic() {
        let apple = "apple".to_string();
        let banana = "banana".to_string();
        let mut sketch = SpaceSaving::<String>::new(10);
        sketch.add(apple.clone());
        sketch.add(apple.clone());
        sketch.add(banana.clone());
        assert!(sketch.estimate(&apple) >= 2);
        assert!(sketch.estimate(&banana) >= 1);
    }

    /// A fresh sketch reports nothing.
    #[test]
    fn test_empty() {
        let sketch = SpaceSaving::<String>::new(10);
        assert_eq!(sketch.estimate(&String::from("anything")), 0);
        assert_eq!(sketch.total_count(), 0);
    }

    /// `top_k` ranks by descending estimate and truncates to `k`.
    #[test]
    fn test_top_k() {
        let mut sketch = SpaceSaving::<&str>::new(10);
        for (fruit, times) in [("apple", 100), ("banana", 50), ("cherry", 25)] {
            (0..times).for_each(|_| sketch.add(fruit));
        }
        let top = sketch.top_k(2);
        assert_eq!(top.len(), 2);
        let names: Vec<&str> = top.iter().map(|&(name, _)| name).collect();
        assert_eq!(names, ["apple", "banana"]);
    }

    /// Items above a 5% share are reported as heavy hitters.
    #[test]
    fn test_heavy_hitters() {
        let mut sketch = SpaceSaving::<&str>::new(10);
        (0..100).for_each(|_| sketch.add("apple"));
        (0..10).for_each(|_| sketch.add("banana"));
        sketch.add("cherry");
        let heavy = sketch.heavy_hitters(0.05);
        let found = |name: &str| heavy.iter().any(|&(item, _)| item == name);
        assert!(found("apple"));
        assert!(found("banana"));
    }

    /// A full sketch replaces slots instead of growing.
    #[test]
    fn test_replacement() {
        let mut sketch = SpaceSaving::<i32>::new(3);
        for n in 1..=3 {
            sketch.add(n);
        }
        assert_eq!(sketch.num_tracked(), 3);
        sketch.add(4);
        assert_eq!(sketch.num_tracked(), 3);
    }

    /// Only added items are reported as monitored.
    #[test]
    fn test_contains() {
        let mut sketch = SpaceSaving::<&str>::new(10);
        sketch.add("apple");
        assert!(sketch.contains(&"apple"));
        assert!(!sketch.contains(&"banana"));
    }

    /// Merging is rejected with an error.
    #[test]
    fn test_merge_not_supported() {
        let mut left = SpaceSaving::<&str>::new(10);
        let right = SpaceSaving::<&str>::new(10);
        assert!(left.merge(&right).is_err());
    }

    /// The lower bound never exceeds the estimate, even after eviction.
    #[test]
    fn test_guaranteed_count() {
        let mut sketch = SpaceSaving::<&str>::new(3);
        let letters = ["a", "b", "c", "d"];
        letters.iter().for_each(|&l| sketch.add(l));
        for item in letters {
            let (est, guar) = (sketch.estimate(&item), sketch.guaranteed_count(&item));
            assert!(
                guar <= est,
                "guaranteed {} > estimate {} for {}",
                guar,
                est,
                item
            );
        }
    }

    /// The most frequent item of a Zipf-like stream stays monitored.
    #[test]
    fn test_zipf_distribution() {
        let mut sketch = SpaceSaving::<i32>::new(10);
        for rank in 1..=100 {
            (0..1000 / rank).for_each(|_| sketch.add(rank));
        }
        let top = sketch.top_k(5);
        assert!(!top.is_empty());
        assert!(sketch.contains(&1));
    }

    /// `clear` forgets all items and totals.
    #[test]
    fn test_clear() {
        let mut sketch = SpaceSaving::<&str>::new(10);
        sketch.add("apple");
        sketch.add("banana");
        sketch.clear();
        assert_eq!(sketch.num_tracked(), 0);
        assert_eq!(sketch.total_count(), 0);
        assert!(!sketch.contains(&"apple"));
    }
}