use core::hash::{BuildHasher, Hash};
use alloc::vec::Vec;
use crate::{count_min::CountMinSketch, hash::DefaultHashBuilder, Error};
/// A tracked candidate: an item together with the count recorded for it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Entry<T> {
/// The tracked item.
key: T,
/// Sketch estimate captured the last time this item was inserted.
count: u64,
}
/// Tracker for the items with the highest estimated counts.
///
/// Every inserted item is counted in a [`CountMinSketch`]; alongside it, a
/// list of at most `k` candidate entries is kept, each remembering the
/// estimate observed at its most recent insertion.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// Explicit serde bounds: (de)serialization requires the sketch type itself
// and `T` to be (de)serializable.
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "CountMinSketch<T, S>: serde::Serialize, T: serde::Serialize",
deserialize = "CountMinSketch<T, S>: serde::Deserialize<'de>, T: serde::Deserialize<'de>"
))
)]
pub struct TopK<T, S = DefaultHashBuilder> {
/// Frequency sketch that receives every inserted item.
sketch: CountMinSketch<T, S>,
/// Currently tracked candidates, in no particular order.
entries: Vec<Entry<T>>,
/// Maximum number of candidates to track; the constructors reject zero.
k: usize,
}
impl<T> TopK<T, DefaultHashBuilder>
where
    T: Hash + Eq,
{
    /// Creates an empty tracker that uses the crate's default hash builder.
    ///
    /// `k` is the maximum number of items to track. `epsilon` and `delta`
    /// are handed on unchanged, together with a `DefaultHashBuilder`, to
    /// `with_hasher`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `with_hasher`.
    pub fn new(k: usize, epsilon: f64, delta: f64) -> Result<Self, Error> {
        let hasher = DefaultHashBuilder;
        Self::with_hasher(k, epsilon, delta, hasher)
    }
}
impl<T, S> TopK<T, S>
where
T: Hash + Eq,
S: BuildHasher,
{
pub fn with_hasher(k: usize, epsilon: f64, delta: f64, hasher: S) -> Result<Self, Error> {
if k == 0 {
return Err(Error::InvalidParameter {
param: "k",
reason: "must be greater than zero",
});
}
let sketch = CountMinSketch::with_hasher(epsilon, delta, hasher)?;
Ok(Self {
sketch,
entries: Vec::with_capacity(k),
k,
})
}
pub fn insert(&mut self, item: T) {
self.sketch.increment(&item);
let estimate = self.sketch.estimate(&item);
if let Some(entry) = self.entries.iter_mut().find(|entry| entry.key == item) {
entry.count = estimate;
return;
}
if self.entries.len() < self.k {
self.entries.push(Entry {
key: item,
count: estimate,
});
return;
}
if let Some((min_index, min_count)) = self
.entries
.iter()
.enumerate()
.map(|(index, entry)| (index, entry.count))
.min_by_key(|&(_, count)| count)
{
if estimate > min_count {
self.entries[min_index] = Entry {
key: item,
count: estimate,
};
}
}
}
#[must_use]
pub fn estimate(&self, item: &T) -> u64 {
self.sketch.estimate(item)
}
#[must_use]
pub fn top(&self) -> Vec<(&T, u64)> {
let mut ranked: Vec<(&T, u64)> = self
.entries
.iter()
.map(|entry| (&entry.key, entry.count))
.collect();
ranked.sort_unstable_by_key(|&(_, count)| core::cmp::Reverse(count));
ranked
}
#[inline]
#[must_use]
pub fn k(&self) -> usize {
self.k
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn clear(&mut self) {
self.sketch.clear();
self.entries.clear();
}
}
#[cfg(test)]
mod tests {
    // Unwrapping is acceptable in tests: a failed construction should fail
    // the test immediately.
    #![allow(clippy::unwrap_used)]

    use super::*;

    // Inserts `item` into `top` exactly `times` times.
    fn insert_repeated<T>(top: &mut TopK<T>, item: T, times: usize)
    where
        T: Hash + Eq + Copy,
    {
        for _ in 0..times {
            top.insert(item);
        }
    }

    // A capacity of zero is rejected up front.
    #[test]
    fn test_new_rejects_zero_k() {
        let result = TopK::<&str>::new(0, 0.01, 0.01);
        assert!(matches!(result, Err(Error::InvalidParameter { .. })));
    }

    // Invalid sketch parameters surface as a parameter error.
    #[test]
    fn test_new_rejects_bad_sketch_params() {
        let result = TopK::<&str>::new(5, 0.0, 0.01);
        assert!(matches!(result, Err(Error::InvalidParameter { .. })));
    }

    // With four distinct items and k = 3, the three heaviest are reported
    // in descending order.
    #[test]
    fn test_tracks_most_frequent() {
        let mut top = TopK::new(3, 0.001, 0.001).unwrap();
        let workload = [("apple", 100), ("banana", 50), ("cherry", 10), ("date", 1)];
        for &(fruit, occurrences) in &workload {
            insert_repeated(&mut top, fruit, occurrences);
        }

        let ranked = top.top();
        assert_eq!(ranked.len(), 3);
        let expected = ["apple", "banana", "cherry"];
        for (slot, want) in ranked.iter().zip(expected.iter()) {
            assert_eq!(slot.0, want);
        }
    }

    // Estimates never undercount an inserted item and are zero for an item
    // that was never inserted.
    #[test]
    fn test_counts_are_estimated() {
        let mut top = TopK::new(5, 0.0001, 0.0001).unwrap();
        insert_repeated(&mut top, "x", 42);

        let seen = top.estimate(&"x");
        assert!(seen >= 42);
        let unseen = top.estimate(&"never-seen");
        assert_eq!(unseen, 0);
    }

    // The number of tracked entries never exceeds k.
    #[test]
    fn test_len_capped_at_k() {
        let mut top = TopK::new(2, 0.01, 0.01).unwrap();
        (0..100u32).for_each(|value| top.insert(value));

        assert_eq!(top.len(), 2);
        assert!(top.len() <= top.k());
    }

    // Once full, a newcomer that overtakes the weakest entry replaces it.
    #[test]
    fn test_eviction_replaces_minimum() {
        let mut top = TopK::new(2, 0.0001, 0.0001).unwrap();
        insert_repeated(&mut top, "a", 10);
        insert_repeated(&mut top, "b", 5);
        insert_repeated(&mut top, "c", 8);

        let keys: Vec<&str> = top.top().into_iter().map(|(key, _)| *key).collect();
        assert!(keys.contains(&"a"));
        assert!(keys.contains(&"c"));
        assert!(!keys.contains(&"b"), "low-frequency item should be evicted");
    }

    // `clear` drops both the tracked entries and the sketch counts.
    #[test]
    fn test_clear() {
        let mut top = TopK::new(3, 0.01, 0.01).unwrap();
        top.insert("x");
        assert!(!top.is_empty());

        top.clear();

        assert!(top.is_empty());
        assert_eq!(top.estimate(&"x"), 0);
    }

    // `top` yields counts in non-increasing order regardless of the order
    // in which items were inserted.
    #[test]
    fn test_top_is_sorted_descending() {
        let mut top = TopK::new(4, 0.0001, 0.0001).unwrap();
        insert_repeated(&mut top, "low", 3);
        insert_repeated(&mut top, "high", 30);
        insert_repeated(&mut top, "mid", 15);

        let ranked = top.top();
        for pair in ranked.windows(2) {
            let (earlier, later) = (pair[0].1, pair[1].1);
            assert!(earlier >= later, "top() not sorted descending");
        }
    }
}