use crate::hash_composition::HashComposer;
use crate::priority_queue::TopKQueue;
use ahash::RandomState;
use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng};
use std::borrow::Borrow;
use std::clone::Clone;
use std::fmt::Debug;
use std::hash::Hash;
use thiserror::Error;
/// Number of `decay^count` thresholds precomputed for every sketch (counts
/// `0..DECAY_LOOKUP_SIZE`); thresholds for larger counts are derived from this table.
const DECAY_LOOKUP_SIZE: usize = 1024;
/// One cell of the sketch: the fingerprint of the item that currently owns it
/// and that item's (decayed) count. A count of 0 marks the bucket as free.
#[derive(Default, Clone, Debug)]
struct Bucket {
    /// Fingerprint of the owning item (0 in a fresh bucket).
    fingerprint: u64,
    /// Estimated count for the owning item.
    count: u64,
}
/// An item reported by [`TopK::list`] together with its estimated count.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TopKNode<T> {
    /// The tracked item.
    pub item: T,
    /// Estimated count for `item`.
    pub count: u64,
}
impl<T: Ord> Ord for TopKNode<T> {
    /// Orders nodes by descending `count`, breaking ties by ascending `item`.
    ///
    /// The tie-break keeps `Ord` consistent with the derived `PartialEq`/`Eq`,
    /// which compare both fields. Nodes with equal counts but different items
    /// therefore no longer compare as `Equal`, and sorting ties is deterministic.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .count
            .cmp(&self.count)
            .then_with(|| self.item.cmp(&other.item))
    }
}
impl<T: Ord> PartialOrd for TopKNode<T> {
    /// Delegates to the total order defined by `Ord`, so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Errors returned by [`TopK::merge`] when the two sketches were configured
/// differently.
// Every variant deliberately starts with `Incompatible`, which
// clippy::enum_variant_names would otherwise flag.
#[allow(clippy::enum_variant_names)]
#[derive(Error, Debug)]
pub enum HeavyKeeperError {
    /// The sketches have a different number of buckets per row.
    #[error("Incompatible width: self ({self_width}) != other ({other_width})")]
    IncompatibleWidth {
        self_width: usize,
        other_width: usize,
    },
    /// The sketches have a different number of rows.
    #[error("Incompatible depth: self ({self_depth}) != other ({other_depth})")]
    IncompatibleDepth {
        self_depth: usize,
        other_depth: usize,
    },
    /// The sketches use different decay bases (compared with `!=` on `f64`).
    #[error("Incompatible decay: self ({self_decay}) != other ({other_decay})")]
    IncompatibleDecay { self_decay: f64, other_decay: f64 },
    /// The sketches track a different number of top items (`k`).
    #[error("Incompatible top_items: self ({self_items}) != other ({other_items})")]
    IncompatibleTopItems {
        self_items: usize,
        other_items: usize,
    },
}
/// Errors returned by [`Builder::build`].
#[derive(Error, Debug)]
pub enum BuilderError {
    /// A mandatory setting (`k`, `width`, `depth` or `decay`) was never set;
    /// `field` holds its name.
    #[error("Missing required field: {field}")]
    MissingField { field: String },
}
/// HeavyKeeper sketch that tracks the `top_items` heaviest items.
///
/// Counts live in `depth` rows of `width` buckets; the current top-k entries
/// and their counts are kept in a separate priority queue.
pub struct TopK<T: Ord + Clone + Hash> {
    /// Capacity of the priority queue (the `k` in top-k).
    top_items: usize,
    /// Number of buckets per row.
    width: usize,
    /// `width - 1` when `width` is a power of two greater than 1, otherwise 0.
    width_mask: usize,
    /// Number of bucket rows.
    depth: usize,
    /// Base of the decay probability `decay^count`.
    decay: f64,
    /// `decay^count` scaled to the `u64` range for counts below `DECAY_LOOKUP_SIZE`.
    decay_thresholds: Vec<u64>,
    /// `depth` rows of `width` buckets.
    buckets: Vec<Vec<Bucket>>,
    /// Current top-k items with their counts.
    priority_queue: TopKQueue<T>,
    /// Hash state used to derive bucket indices and fingerprints.
    hasher: RandomState,
    /// Random source that drives probabilistic decay in `add`.
    random: Box<dyn RngCore + Send + Sync>,
}
/// Step-by-step configuration for [`TopK`]; finish with [`Builder::build`].
///
/// `k`, `width`, `depth` and `decay` are required; `seed`, `hasher` and `rng`
/// are optional.
pub struct Builder<T> {
    /// Number of top items to track (required).
    k: Option<usize>,
    /// Buckets per row (required).
    width: Option<usize>,
    /// Number of rows (required).
    depth: Option<usize>,
    /// Decay base (required).
    decay: Option<f64>,
    /// Seeds the hasher (unless `hasher` is set) and the default RNG (unless `rng` is set).
    seed: Option<u64>,
    /// Explicit hash state; takes precedence over `seed` for hashing.
    hasher: Option<RandomState>,
    /// Explicit RNG; takes precedence over `seed` for decay randomness.
    rng: Option<Box<dyn RngCore + Send + Sync>>,
    /// Ties the builder to the item type `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
/// Builds the lookup table `[decay^0, decay^1, ..., decay^(num_entries - 1)]`,
/// each entry scaled from `[0, 1]` to the `u64` range.
///
/// Float-to-int casts saturate, so `decay^0 = 1.0` maps to `u64::MAX`.
fn precompute_decay_thresholds(decay: f64, num_entries: usize) -> Vec<u64> {
    let scale = u64::MAX as f64;
    (0..num_entries)
        .map(|exponent| (decay.powf(exponent as f64) * scale) as u64)
        .collect()
}
impl<T: Ord + Clone + Hash> TopK<T> {
pub fn builder() -> Builder<T> {
Builder::new()
}
pub fn new(k: usize, width: usize, depth: usize, decay: f64) -> Self {
let seed = 12345; Self::with_seed(k, width, depth, decay, seed)
}
pub fn with_seed(k: usize, width: usize, depth: usize, decay: f64, seed: u64) -> Self {
let hasher = RandomState::with_seeds(seed, seed, seed, seed);
Self::with_hasher(k, width, depth, decay, hasher)
}
pub fn with_hasher(
k: usize,
width: usize,
depth: usize,
decay: f64,
hasher: RandomState,
) -> Self {
Self::with_components(
k,
width,
depth,
decay,
hasher,
Box::new(SmallRng::seed_from_u64(0)),
)
}
fn with_components(
k: usize,
width: usize,
depth: usize,
decay: f64,
hasher: RandomState,
rng: Box<dyn RngCore + Send + Sync>,
) -> Self {
let mut buckets = Vec::with_capacity(depth);
for _ in 0..depth {
buckets.push(vec![
Bucket {
fingerprint: 0,
count: 0
};
width
]);
}
let width_mask = if width > 1 && width.is_power_of_two() {
width - 1
} else {
0
};
Self {
top_items: k,
width,
width_mask,
depth,
decay,
decay_thresholds: precompute_decay_thresholds(decay, DECAY_LOOKUP_SIZE),
buckets,
priority_queue: TopKQueue::with_capacity_and_hasher(k, hasher.clone()),
hasher,
random: rng,
}
}
pub fn query<Q>(&self, item: &Q) -> bool
where
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
{
if self.priority_queue.get(item).is_some() {
return true;
}
let mut composer = HashComposer::new(&self.hasher, item);
let mut min_count = u64::MAX;
for i in 0..self.depth {
let bucket_idx = composer.next_bucket(self.width as u64, self.width_mask, i);
let bucket = &self.buckets[i][bucket_idx];
if bucket.fingerprint == composer.fingerprint() {
min_count = min_count.min(bucket.count);
}
}
min_count != u64::MAX
}
pub fn count<Q>(&self, item: &Q) -> u64
where
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
{
if let Some(count) = self.priority_queue.get(item) {
return count;
}
let mut composer = HashComposer::new(&self.hasher, item);
let mut min_count = u64::MAX;
for i in 0..self.depth {
let bucket_idx = composer.next_bucket(self.width as u64, self.width_mask, i);
let bucket = &self.buckets[i][bucket_idx];
if bucket.fingerprint == composer.fingerprint() {
min_count = min_count.min(bucket.count);
}
}
if min_count == u64::MAX {
0
} else {
min_count
}
}
#[cfg(test)]
pub fn bucket_count<Q>(&self, item: &Q) -> u64
where
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
{
let mut composer = HashComposer::new(&self.hasher, item);
let mut min_count = u64::MAX;
for i in 0..self.depth {
let bucket_idx = composer.next_bucket(self.width as u64, self.width_mask, i);
let bucket = &self.buckets[i][bucket_idx];
if bucket.fingerprint == composer.fingerprint() {
min_count = min_count.min(bucket.count);
}
}
if min_count == u64::MAX {
0
} else {
min_count
}
}
pub fn add<Q>(&mut self, item: &Q, increment: u64)
where
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
{
if increment == 0 {
return;
}
let mut composer = HashComposer::new(&self.hasher, item);
let mut max_count: u64 = 0;
for i in 0..self.depth {
let bucket_idx = composer.next_bucket(self.width as u64, self.width_mask, i);
let (matches, empty) = {
let bucket = &self.buckets[i][bucket_idx];
(
bucket.fingerprint == composer.fingerprint(),
bucket.count == 0u64,
)
};
if matches || empty {
let bucket = &mut self.buckets[i][bucket_idx];
bucket.fingerprint = composer.fingerprint();
bucket.count += increment;
max_count = std::cmp::max(max_count, bucket.count);
} else {
let mut remaining_incr = increment;
while remaining_incr > 0 {
let current_count = self.buckets[i][bucket_idx].count;
let decay_threshold = self.decay_threshold(current_count);
let rand = self.random.next_u64();
let bucket = &mut self.buckets[i][bucket_idx];
if rand < decay_threshold {
bucket.count = bucket.count.saturating_sub(1);
if bucket.count == 0 {
bucket.fingerprint = composer.fingerprint();
bucket.count = remaining_incr;
max_count = std::cmp::max(max_count, bucket.count);
break;
}
}
remaining_incr -= 1;
}
}
}
if let Some(current) = self.priority_queue.get(item) {
if max_count > current {
self.priority_queue.update_if_present(item, max_count);
}
return;
}
if self.priority_queue.is_full() && max_count <= self.priority_queue.min_count() {
return;
}
self.priority_queue.upsert(item.to_owned(), max_count);
}
fn decay_threshold(&self, count: u64) -> u64 {
if count < self.decay_thresholds.len() as u64 {
return self.decay_thresholds[count as usize];
}
let tbl = &self.decay_thresholds;
let last = tbl[tbl.len() - 1] as f64 / u64::MAX as f64;
let divisor = (tbl.len() - 1) as u64;
let q = (count / divisor) as f64;
let r = (count % divisor) as usize;
let rem_thr = tbl[r] as f64 / u64::MAX as f64;
((last.powf(q) * rem_thr) * u64::MAX as f64) as u64
}
pub fn list(&self) -> Vec<TopKNode<T>> {
let mut nodes = self
.priority_queue
.iter()
.map(|(item, count)| TopKNode {
item: item.clone(),
count,
})
.collect::<Vec<_>>();
nodes.sort();
nodes
}
pub fn merge(&mut self, other: &Self) -> Result<(), HeavyKeeperError> {
if self.width != other.width {
return Err(HeavyKeeperError::IncompatibleWidth {
self_width: self.width,
other_width: other.width,
});
}
if self.depth != other.depth {
return Err(HeavyKeeperError::IncompatibleDepth {
self_depth: self.depth,
other_depth: other.depth,
});
}
if self.decay != other.decay {
return Err(HeavyKeeperError::IncompatibleDecay {
self_decay: self.decay,
other_decay: other.decay,
});
}
if self.top_items != other.top_items {
return Err(HeavyKeeperError::IncompatibleTopItems {
self_items: self.top_items,
other_items: other.top_items,
});
}
for (self_row, other_row) in self.buckets.iter_mut().zip(other.buckets.iter()) {
for (self_bucket, other_bucket) in self_row.iter_mut().zip(other_row.iter()) {
if self_bucket.fingerprint == other_bucket.fingerprint {
self_bucket.count += other_bucket.count;
} else if self_bucket.count == 0 {
*self_bucket = other_bucket.clone();
}
}
}
for (item, count) in other.priority_queue.iter() {
let self_count = self.priority_queue.get(item).unwrap_or(0);
self.priority_queue.upsert(item.clone(), self_count + count);
}
Ok(())
}
}
impl<T: Ord + Clone + Hash + Debug> TopK<T> {
    /// Prints the sketch configuration, every non-empty bucket (heaviest first,
    /// ties in row-major order) and the tracked top-k entries to stdout.
    pub fn debug(&self) {
        println!("width: {}", self.width);
        println!("depth: {}", self.depth);
        println!("decay: {}", self.decay);
        println!("decay thresholds: {:?}", self.decay_thresholds);

        let mut occupied: Vec<(&Bucket, usize, usize)> = Vec::new();
        for (row_idx, row) in self.buckets.iter().enumerate() {
            for (col_idx, bucket) in row.iter().enumerate() {
                if bucket.count != 0 {
                    occupied.push((bucket, row_idx, col_idx));
                }
            }
        }
        // `sort_by_key` is stable, so equal counts keep their row-major order.
        occupied.sort_by_key(|&(bucket, _, _)| std::cmp::Reverse(bucket.count));
        for (bucket, row_idx, col_idx) in occupied {
            println!("Bucket at row {}, column {}: {:?}", row_idx, col_idx, bucket);
        }

        println!("priority_queue: ");
        for node in self.list() {
            println!("Node - Item: {:?}, Count: {}", node.item, node.count);
        }
    }

    /// Test-only access to the private decay threshold computation.
    #[cfg(test)]
    pub(crate) fn decay_threshold_for_test(&self, count: u64) -> u64 {
        self.decay_threshold(count)
    }
}
impl<T: Ord + Clone + Hash> Default for Builder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Ord + Clone + Hash> Builder<T> {
pub fn new() -> Self {
Self {
k: None,
width: None,
depth: None,
decay: None,
seed: None,
hasher: None,
rng: None,
_phantom: std::marker::PhantomData,
}
}
pub fn k(mut self, k: usize) -> Self {
self.k = Some(k);
self
}
pub fn width(mut self, width: usize) -> Self {
self.width = Some(width);
self
}
pub fn depth(mut self, depth: usize) -> Self {
self.depth = Some(depth);
self
}
pub fn decay(mut self, decay: f64) -> Self {
self.decay = Some(decay);
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn hasher(mut self, hasher: RandomState) -> Self {
self.hasher = Some(hasher);
self
}
pub fn rng<R: RngCore + Send + Sync + 'static>(mut self, rng: R) -> Self {
self.rng = Some(Box::new(rng));
self
}
pub fn build(self) -> Result<TopK<T>, BuilderError> {
let k = self.k.ok_or_else(|| BuilderError::MissingField {
field: "k".to_string(),
})?;
let width = self.width.ok_or_else(|| BuilderError::MissingField {
field: "width".to_string(),
})?;
let depth = self.depth.ok_or_else(|| BuilderError::MissingField {
field: "depth".to_string(),
})?;
let decay = self.decay.ok_or_else(|| BuilderError::MissingField {
field: "decay".to_string(),
})?;
let hasher = self.hasher.unwrap_or_else(|| {
if let Some(seed) = self.seed {
RandomState::with_seeds(seed, seed, seed, seed)
} else {
RandomState::new()
}
});
let rng = self.rng.unwrap_or_else(|| {
if let Some(seed) = self.seed {
Box::new(SmallRng::seed_from_u64(seed))
} else {
Box::new(SmallRng::seed_from_u64(0))
}
});
Ok(TopK::with_components(k, width, depth, decay, hasher, rng))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::automock;

    // Mockable stand-in for the one RNG method the sketch needs; `automock`
    // generates `MockRngCoreTrait` from it.
    #[automock]
    trait RngCoreTrait {
        fn next_u64(&mut self) -> u64;
    }

    // Lets the mock be handed to `Builder::rng`; every `RngCore` method is
    // routed through the mocked `next_u64`.
    impl RngCore for MockRngCoreTrait {
        fn next_u32(&mut self) -> u32 {
            RngCoreTrait::next_u64(self) as u32
        }
        fn next_u64(&mut self) -> u64 {
            RngCoreTrait::next_u64(self)
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for chunk in dest.chunks_mut(8) {
                let value = RngCoreTrait::next_u64(self);
                for (i, byte) in chunk.iter_mut().enumerate() {
                    *byte = (value >> (i * 8)) as u8;
                }
            }
        }
    }

    #[test]
    fn test_new() {
        let k = 10;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        assert_eq!(topk.width, 100);
        assert_eq!(topk.depth, 5);
        assert_eq!(topk.decay, 0.9);
        assert_eq!(topk.buckets.len(), 5);
        assert_eq!(topk.buckets[0].len(), 100);
        assert_eq!(topk.priority_queue.len(), 0);
    }

    #[test]
    fn test_query() {
        let mut topk: TopK<Vec<u8>> = TopK::new(10, 100, 5, 0.9);
        let present = b"hello".to_vec();
        let absent = b"world".to_vec();
        topk.add(&present, 1);
        assert!(topk.query(&present), "Present item should be found");
        assert!(!topk.query(&absent), "Absent item should not be found");
    }

    #[test]
    fn test_count() {
        let k = 10;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item1 = b"lashin".to_vec();
        let item2 = b"ballynamoney".to_vec();
        let item3 = "पुष्पं अस्ति।".as_bytes().to_vec();
        topk.add(&item1, 8);
        assert_eq!(
            topk.count(&item1),
            8,
            "Count should match number of additions"
        );
        assert_eq!(
            topk.count(&item3),
            0,
            "Non-existent item should have count 0"
        );
        topk.add(&item2, 1337);
        assert_eq!(
            topk.count(&item2),
            1337,
            "Count should match number of additions"
        );
    }

    #[test]
    fn test_non_ascii_and_emoji() {
        let mut topk: TopK<Vec<u8>> = TopK::new(5, 100, 4, 0.9);
        let p = "पुष्पं अस्ति।".as_bytes().to_vec();
        let emoji = "🚀🌟".as_bytes().to_vec();
        let mixed = "Hello पुष्पं 🚀".as_bytes().to_vec();
        topk.add(&p, 1);
        topk.add(&emoji, 1);
        topk.add(&mixed, 1);
        assert!(topk.query(&p), "text should be found");
        assert!(topk.query(&emoji), "Emoji should be found");
        assert!(topk.query(&mixed), "Mixed content should be found");
        assert_eq!(topk.count(&p), 1, "text count should be 1");
        assert_eq!(topk.count(&emoji), 1, "Emoji count should be 1");
        assert_eq!(topk.count(&mixed), 1, "Mixed content count should be 1");
        topk.add(&p, 4);
        assert_eq!(topk.count(&p), 5, "text count should be 5");
        let items = topk.list();
        for node in items {
            let text = String::from_utf8_lossy(&node.item);
            println!("Item: {}, Count: {}", text, node.count);
        }
    }

    #[test]
    fn test_add_single_item() {
        let k = 1;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item = b"hello".to_vec();
        topk.add(&item, 1);
        let nodes = topk.list();
        assert_eq!(nodes.len(), 1, "Should have exactly one item");
        assert_eq!(nodes[0].count, 1, "Count should be 1");
        assert_eq!(nodes[0].item, item, "Item should match");
    }

    #[test]
    fn test_add_overwrite() {
        let k = 1;
        let width = 1;
        let depth = 1;
        let decay = 1.0;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        // Force every competing unit to decay the single shared bucket.
        topk.decay_thresholds.iter_mut().for_each(|v| *v = u64::MAX);
        let item1 = b"item1".to_vec();
        topk.add(&item1, 1000);
        let nodes = topk.list();
        assert_eq!(nodes.len(), 1, "Should have exactly one item");
        assert_eq!(nodes[0].count, 1000, "Invalid count");
        assert_eq!(nodes[0].item, item1, "Item should match");
        // The first 1000 units of item2 evict item1; the evicting unit plus the
        // 2000 unspent units (2001 total) are credited to item2.
        let item2 = b"item2".to_vec();
        topk.add(&item2, 3000);
        let nodes = topk.list();
        assert_eq!(nodes.len(), 1, "Should have exactly one item");
        assert_eq!(nodes[0].count, 2001, "Invalid count");
        assert_eq!(nodes[0].item, item2, "Item should match");
    }

    #[test]
    fn test_add_duplicate_items() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item1 = b"hello".to_vec();
        let item2 = b"world".to_vec();
        topk.add(&item1, 7);
        topk.add(&item2, 7);
        assert_eq!(topk.priority_queue.len(), k, "Should have exactly k items");
        // NOTE(review): the index assertions below assume TopKQueue iterates
        // equal-count items in insertion order — TODO confirm.
        let nodes = topk
            .priority_queue
            .iter()
            .map(|(item, count)| TopKNode {
                item: item.clone(),
                count,
            })
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 2, "Should have exactly two items");
        assert_eq!(nodes[0].count, 7, "First item should have count 7");
        assert_eq!(nodes[0].item, item1, "First item should match");
        assert_eq!(nodes[1].count, 7, "Second item should have count 7");
        assert_eq!(nodes[1].item, item2, "Second item should match");
    }

    #[test]
    fn test_add_more_items_than_capacity() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let items = [
            b"hello".to_vec(),
            b"world".to_vec(),
            b"ballynamoney".to_vec(),
            b"lane".to_vec(),
        ];
        for item in &items {
            topk.add(item, 1);
        }
        let nodes = topk.list();
        assert_eq!(nodes.len(), 2, "Should maintain capacity limit");
        let mut counts = nodes.iter().map(|node| node.count).collect::<Vec<_>>();
        counts.sort_unstable();
        assert_eq!(counts, vec![1, 1], "All items should have count 1");
    }

    #[test]
    fn test_add_with_different_decay() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.5;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let items = [
            b"hello".to_vec(),
            b"world".to_vec(),
            b"ballynamoney".to_vec(),
            b"lane".to_vec(),
            b"pear tree".to_vec(),
        ];
        for item in &items {
            topk.add(item, 1);
        }
        let nodes = topk.list();
        assert_eq!(nodes.len(), 2, "Should maintain capacity limit");
        let mut counts = nodes.iter().map(|node| node.count).collect::<Vec<_>>();
        counts.sort_unstable();
        assert_eq!(counts, vec![1, 1], "All items should have count 1");
    }

    #[test]
    fn test_add_empty_input() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let nodes = topk.list();
        assert_eq!(nodes.len(), 0, "Should have no items");
    }

    #[test]
    fn test_add_varied_input() {
        let k = 10;
        let width = 2000;
        let depth = 20;
        let decay = 0.98;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        // item{i} is added i + 1 times, one unit at a time.
        let mut items_with_frequencies = Vec::new();
        for i in 0..100 {
            let item = format!("item{}", i);
            let frequency = i + 1;
            items_with_frequencies.push((item, frequency));
        }
        for (item, frequency) in items_with_frequencies.iter() {
            let item_bytes = item.as_bytes().to_vec();
            for _ in 0..*frequency {
                topk.add(&item_bytes, 1);
            }
        }
        assert_eq!(
            topk.priority_queue.len(),
            k,
            "Priority queue should contain exactly k items"
        );
        let top_items = topk
            .priority_queue
            .iter()
            .map(|(item, count)| TopKNode {
                item: std::str::from_utf8(item).unwrap().to_string().into_bytes(),
                count,
            })
            .collect::<Vec<_>>();
        let expected_top_items = (90..100)
            .map(|i| format!("item{}", i).into_bytes())
            .collect::<Vec<_>>();
        // The sketch is probabilistic, so allow a couple of misses.
        let mut found = 0;
        for expected_item in expected_top_items.iter() {
            if top_items.iter().any(|node| &node.item == expected_item) {
                found += 1;
            } else {
                println!(
                    "Warning: Expected item {} not in top-k",
                    std::str::from_utf8(expected_item).unwrap()
                );
            }
        }
        assert!(
            found >= 8,
            "At least 8 of the top 10 items should be in top-k"
        );
    }

    #[test]
    fn test_large_number_of_duplicates() {
        let k = 10;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item = b"test_item".to_vec();
        let num_additions = 1000;
        topk.add(&item, num_additions);
        assert_eq!(
            topk.count(&item),
            num_additions,
            "Count should match number of additions"
        );
    }

    #[test]
    fn test_multiple_distinct_items() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item1 = b"item1".to_vec();
        let item2 = b"item2".to_vec();
        let num_additions_item1 = 500;
        let num_additions_item2 = 499;
        topk.add(&item1, num_additions_item1);
        topk.add(&item2, num_additions_item2);
        assert_eq!(
            topk.count(&item1),
            num_additions_item1,
            "Count should match number of additions for item1"
        );
        assert_eq!(
            topk.count(&item2),
            num_additions_item2,
            "Count should match number of additions for item2"
        );
        assert!(topk.query(&item1), "item1 should be in top-k");
        assert!(topk.query(&item2), "item2 should be in top-k");
    }

    #[test]
    fn test_insertion_into_empty_buckets() {
        let k = 5;
        let width = 10;
        let depth = 4;
        let decay = 0.5;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let item = b"new_flow".to_vec();
        topk.add(&item, 1);
        // NOTE(review): assumes HashComposer's fingerprint equals
        // `hasher.hash_one(item)` — TODO confirm against hash_composition.
        let item_hash = topk.hasher.hash_one(&item);
        assert!(
            topk.buckets.iter().any(|row| row
                .iter()
                .any(|bucket| bucket.fingerprint == item_hash && bucket.count == 1)),
            "Item should be inserted into an empty bucket with count 1"
        );
        assert!(topk.query(&item), "Item should be in priority queue");
    }

    #[test]
    fn test_add_identical_frequencies() {
        let k = 10;
        let width = 1000;
        let depth = 10;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let frequency = 5;
        for i in 0..100 {
            let item = format!("item{}", i);
            let item_bytes = item.as_bytes().to_vec();
            topk.add(&item_bytes, frequency);
        }
        assert_eq!(
            topk.priority_queue.len(),
            k,
            "Priority queue should contain exactly k items"
        );
        for node in topk.list() {
            assert_eq!(
                node.count, frequency,
                "All items should have the same frequency"
            );
        }
    }

    #[test]
    fn test_small_k_value() {
        let k = 2;
        let width = 1000;
        let depth = 10;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        for i in 0..3 {
            let item = format!("item{}", i);
            let item_bytes = item.as_bytes().to_vec();
            topk.add(&item_bytes, i + 1);
        }
        assert_eq!(
            topk.priority_queue.len(),
            k,
            "Priority queue should contain exactly k items"
        );
        let top_items = topk
            .priority_queue
            .iter()
            .map(|(item, count)| TopKNode {
                item: std::str::from_utf8(item).unwrap().to_string().into_bytes(),
                count,
            })
            .collect::<Vec<_>>();
        // item1 (count 2) and item2 (count 3) should displace item0 (count 1).
        let expected_top_items = (1..3)
            .map(|i| format!("item{}", i).into_bytes())
            .collect::<Vec<_>>();
        for expected_item in expected_top_items.iter() {
            assert!(
                top_items.iter().any(|node| &node.item == expected_item),
                "Expected item {} to be in top-k",
                std::str::from_utf8(expected_item).unwrap()
            );
        }
    }

    #[test]
    fn test_count_with_sketch() {
        let k = 2;
        let width = 100;
        let depth = 5;
        let decay = 0.9;
        let mut topk: TopK<Vec<u8>> = TopK::new(k, width, depth, decay);
        let items = [
            b"item1".to_vec(),
            b"item2".to_vec(),
            b"item3".to_vec(),
            b"item4".to_vec(),
        ];
        topk.add(&items[0], 1);
        topk.add(&items[1], 1);
        topk.add(&items[2], 2);
        topk.add(&items[3], 5);
        // Items evicted from the k = 2 queue must still be answered by the sketch.
        assert_eq!(topk.count(&items[0]), 1, "Count should be 1");
        assert_eq!(topk.count(&items[1]), 1, "Count should be 1");
        assert_eq!(topk.count(&items[2]), 2, "Count should be 2");
        assert_eq!(topk.count(&items[3]), 5, "Count should be 5");
    }

    #[test]
    fn test_merge_basic() {
        let seed = 12345;
        let mut hk1 = TopK::with_seed(3, 100, 5, 0.9, seed);
        let mut hk2 = TopK::with_seed(3, 100, 5, 0.9, seed);
        let items = [b"item1".to_vec(), b"item2".to_vec(), b"item3".to_vec()];
        hk1.add(&items[0], 5);
        hk1.add(&items[1], 3);
        hk2.add(&items[0], 4);
        hk2.add(&items[2], 6);
        hk1.merge(&hk2).unwrap();
        assert_eq!(
            hk1.count(&items[0]),
            9,
            "Count should be sum of both instances"
        );
        assert_eq!(hk1.count(&items[1]), 3, "Count should be preserved");
        assert_eq!(hk1.count(&items[2]), 6, "Count should be preserved");
    }

    #[test]
    fn test_merge_incompatible_width() {
        let mut hk1: TopK<Vec<u8>> = TopK::with_seed(3, 100, 5, 0.9, 12345);
        let hk2 = TopK::with_seed(3, 50, 5, 0.9, 12345);
        match hk1.merge(&hk2) {
            Err(HeavyKeeperError::IncompatibleWidth {
                self_width,
                other_width,
            }) => {
                assert_eq!(self_width, 100, "Self width should be 100");
                assert_eq!(other_width, 50, "Other width should be 50");
            }
            _ => panic!("Expected Width error"),
        }
    }

    #[test]
    fn test_merge_incompatible_depth() {
        let mut hk1: TopK<Vec<u8>> = TopK::with_seed(3, 100, 5, 0.9, 12345);
        let hk2 = TopK::with_seed(3, 100, 4, 0.9, 12345);
        match hk1.merge(&hk2) {
            Err(HeavyKeeperError::IncompatibleDepth {
                self_depth,
                other_depth,
            }) => {
                assert_eq!(self_depth, 5, "Self depth should be 5");
                assert_eq!(other_depth, 4, "Other depth should be 4");
            }
            _ => panic!("Expected Depth error"),
        }
    }

    #[test]
    fn test_merge_with_overlapping_items() {
        let seed = 12345;
        let mut hk1 = TopK::with_seed(3, 100, 5, 0.9, seed);
        let mut hk2 = TopK::with_seed(3, 100, 5, 0.9, seed);
        let items = [b"common".to_vec(), b"unique1".to_vec(), b"unique2".to_vec()];
        hk1.add(&items[0], 5);
        hk2.add(&items[0], 5);
        hk1.add(&items[1], 1);
        hk2.add(&items[2], 1);
        hk1.merge(&hk2).unwrap();
        assert_eq!(
            hk1.count(&items[0]),
            10,
            "Common item count should be doubled"
        );
        assert_eq!(
            hk1.count(&items[1]),
            1,
            "Unique item count should be preserved"
        );
        assert_eq!(
            hk1.count(&items[2]),
            1,
            "Unique item count should be preserved"
        );
    }

    #[test]
    fn test_decay_logic_with_mock_rng() {
        // The mock RNG always returns 0 and every threshold is forced to
        // u64::MAX, so each competing unit decays the incumbent by exactly one.
        let mut mock_rng = MockRngCoreTrait::new();
        mock_rng.expect_next_u64().times(1..).return_const(0u64);
        let topk = TopK::<Vec<u8>>::builder()
            .k(1)
            .width(1)
            .depth(1)
            .decay(0.9)
            .rng(mock_rng)
            .build()
            .unwrap();
        let item1 = b"item1".to_vec();
        let item2 = b"item2".to_vec();
        let large_count = 9999u64;
        let mut topk = topk;
        topk.decay_thresholds.iter_mut().for_each(|threshold| {
            *threshold = u64::MAX;
        });
        topk.add(&item1, large_count);
        assert_eq!(topk.count(&item1), large_count);
        let decay_iterations = 1000;
        let mut last_count = topk.bucket_count(&item1);
        for _ in 0..decay_iterations {
            topk.add(&item2, 1);
            let new_count = topk.bucket_count(&item1);
            if new_count == 0 {
                assert!(
                    !topk.query(&item1),
                    "item1 should be evicted if count is zero"
                );
                break;
            } else {
                assert!(
                    new_count < last_count,
                    "Bucket count should decrease with each decay"
                );
                last_count = new_count;
            }
        }
    }

    #[test]
    fn test_decay_and_eviction() {
        // Same forced-decay setup as above, but checking one step at a time.
        let mut mock_rng = MockRngCoreTrait::new();
        mock_rng.expect_next_u64().times(1..).return_const(0u64);
        let topk = TopK::<Vec<u8>>::builder()
            .k(1)
            .width(1)
            .depth(1)
            .decay(0.9)
            .rng(mock_rng)
            .build()
            .unwrap();
        let mut topk = topk;
        topk.decay_thresholds.iter_mut().for_each(|threshold| {
            *threshold = u64::MAX;
        });
        let item1 = b"item1".to_vec();
        let item2 = b"item2".to_vec();
        let start_count = 10;
        topk.add(&item1, start_count);
        assert_eq!(topk.count(&item1), start_count);
        let fp1 = crate::hash_composition::HashComposer::new(&topk.hasher, &item1).fingerprint();
        let fp2 = crate::hash_composition::HashComposer::new(&topk.hasher, &item2).fingerprint();
        println!("item1 fingerprint: {}", fp1);
        println!("item2 fingerprint: {}", fp2);
        println!("Initial state:");
        println!(" item1 count: {}", topk.count(&item1));
        println!(" item1 query: {}", topk.query(&item1));
        println!(" item2 count: {}", topk.count(&item2));
        println!(" item2 query: {}", topk.query(&item2));
        let before = topk.bucket_count(&item1);
        println!("Before adding item2: item1 bucket count = {}", before);
        topk.add(&item2, 1);
        let after = topk.bucket_count(&item1);
        println!("After adding item2: item1 bucket count = {}", after);
        println!("Final state:");
        println!(" item1 count: {}", topk.count(&item1));
        println!(" item1 query: {}", topk.query(&item1));
        println!(" item2 count: {}", topk.count(&item2));
        println!(" item2 query: {}", topk.query(&item2));
        assert_eq!(
            after,
            before - 1,
            "Bucket count should decrement by 1 after first decay"
        );
        assert!(
            topk.query(&item1),
            "Item1 should still be in the bucket after decay"
        );
        assert_eq!(
            topk.bucket_count(&item1),
            9,
            "Item1 bucket count should be 9 after decay"
        );
        assert!(!topk.query(&item2), "Item2 should not be in the bucket yet");
        assert_eq!(
            topk.bucket_count(&item2),
            0,
            "Item2 bucket count should still be 0"
        );
        topk.add(&item2, 1);
        let final_count = topk.bucket_count(&item1);
        println!("After second decay: item1 bucket count = {}", final_count);
        assert!(
            final_count < 9,
            "Item1 bucket count should continue to decrease with more decays"
        );
    }

    #[test]
    fn test_builder_missing_fields() {
        let result = TopK::<Vec<u8>>::builder()
            .width(100)
            .depth(5)
            .decay(0.9)
            .build();
        assert!(matches!(result, Err(BuilderError::MissingField { field }) if field == "k"));
        let result = TopK::<Vec<u8>>::builder().k(10).depth(5).decay(0.9).build();
        assert!(matches!(result, Err(BuilderError::MissingField { field }) if field == "width"));
        let result = TopK::<Vec<u8>>::builder()
            .k(10)
            .width(100)
            .decay(0.9)
            .build();
        assert!(matches!(result, Err(BuilderError::MissingField { field }) if field == "depth"));
        let result = TopK::<Vec<u8>>::builder().k(10).width(100).depth(5).build();
        assert!(matches!(result, Err(BuilderError::MissingField { field }) if field == "decay"));
    }

    #[test]
    fn test_send_sync_issue() {
        use std::collections::HashMap;
        use std::sync::{Arc, Mutex};
        use std::thread;
        type Id = String;
        type IdTopK = Arc<Mutex<HashMap<Id, TopK<String>>>>;
        let topk_map: IdTopK = Arc::new(Mutex::new(HashMap::new()));
        // Compiling this proves TopK<String> can be moved into another thread.
        // Joining ensures a panic inside the thread fails the test instead of
        // being silently dropped.
        thread::spawn(move || {
            let mut map = topk_map.lock().unwrap();
            let topk = TopK::new(10, 100, 5, 0.9);
            map.insert("test".to_string(), topk);
        })
        .join()
        .expect("worker thread panicked");
    }

    #[test]
    fn test_borrow() {
        // Borrowed forms (&str for String, &[u8] for Vec<u8>) must be accepted.
        let mut topk: TopK<String> = TopK::new(10, 100, 5, 0.9);
        let item: &str = "foo";
        topk.add(item, 1);
        assert!(topk.query(item));
        assert_eq!(topk.count(item), 1);
        let mut topk: TopK<Vec<u8>> = TopK::new(10, 100, 5, 0.9);
        let item: &[u8] = b"foo";
        topk.add(item, 1);
        assert!(topk.query(item));
        assert_eq!(topk.count(item), 1);
    }

    #[test]
    fn test_decay_threshold_no_usize_truncation_for_large_count() {
        let topk: TopK<Vec<u8>> = TopK::new(10, 100, 5, 0.9);
        let huge: u64 = (u32::MAX as u64) + 5000;
        let thr = topk.decay_threshold_for_test(huge);
        assert!(
            thr < u64::MAX / 2,
            "expected ~0 threshold for huge count, got {thr}"
        );
    }

    #[test]
    fn test_decay_threshold_no_powi_i32_overflow_for_huge_count() {
        let topk: TopK<Vec<u8>> = TopK::new(10, 100, 5, 0.9);
        let huge: u64 = (i32::MAX as u64) * 2048;
        let thr = topk.decay_threshold_for_test(huge);
        assert!(
            thr < u64::MAX / 2,
            "expected ~0 threshold for huge count, got {thr}"
        );
    }

    #[test]
    fn test_decay_probability_scaling_fix() {
        // With decay = 1.0 the threshold is u64::MAX, so a draw of 2^63 must
        // trigger decay; a threshold mis-scaled to [0, 1] would not.
        let mut mock_rng = MockRngCoreTrait::new();
        mock_rng.expect_next_u64().times(1..).return_const(1u64 << 63);
        let mut topk = TopK::<Vec<u8>>::builder()
            .k(1)
            .width(1)
            .depth(1)
            .decay(1.0)
            .rng(mock_rng)
            .build()
            .unwrap();
        let item1 = b"item1".to_vec();
        let item2 = b"item2".to_vec();
        topk.add(&item1, 1);
        assert_eq!(topk.bucket_count(&item1), 1);
        assert_eq!(topk.bucket_count(&item2), 0);
        topk.add(&item2, 1);
        assert_eq!(topk.bucket_count(&item1), 0);
        assert_eq!(topk.bucket_count(&item2), 1);
    }
}