use crate::common::input::HHItem;
use crate::common::{CommonHeap, KeepSmallest};
use crate::{DataInput, HeapItem, hash_item64_seeded, hash64_seeded, input_to_owned};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Bounded collection of `HHItem`s (key plus count) with a hash-bucket index for
/// locating an item's heap slot by key.
///
/// The heap is built with `CommonHeap::new_min(k)`; the index is rebuilt from the
/// heap after every mutation performed through `update` / `update_heap_item`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct HHHeap {
/// Backing heap of tracked items, created via `CommonHeap::new_min`.
heap: CommonHeap<HHItem, KeepSmallest>,
/// Seeded 64-bit hash of a key -> every `(key, heap index)` pair whose key hashes
/// to that value. The stored key is compared on lookup, so colliding keys coexist.
positions: HashMap<u64, Vec<(HeapItem, usize)>>,
/// The capacity requested at construction; reported by `capacity()` and used to
/// decide whether a new key can be admitted without comparing against `peek()`.
k: usize,
}
impl HHHeap {
    /// Creates an empty tracker configured for `k` entries.
    ///
    /// `k` is forwarded to `CommonHeap::new_min`, used as the initial capacity of the
    /// position index, and stored so that [`HHHeap::capacity`] can report it.
    pub fn new(k: usize) -> Self {
        Self {
            heap: CommonHeap::new_min(k),
            positions: HashMap::with_capacity(k),
            k,
        }
    }

    /// Returns the heap index of `key`, or `None` when the key is not indexed.
    ///
    /// The key is hashed to select a bucket; each entry of that bucket is then compared
    /// against `key` for equality, so a hash collision alone never produces a match.
    pub fn find(&self, key: &DataInput) -> Option<usize> {
        let bucket = self.positions.get(&self.slot_for_input(key))?;
        for (stored, idx) in bucket {
            if stored == key {
                return Some(*idx);
            }
        }
        None
    }

    /// Same lookup as [`HHHeap::find`], but for a key that is already a `HeapItem`.
    pub fn find_heap_item(&self, key: &HeapItem) -> Option<usize> {
        let bucket = self.positions.get(&self.slot_for_item(key))?;
        for (stored, idx) in bucket {
            if stored == key {
                return Some(*idx);
            }
        }
        None
    }

    /// Sets the count of `key` to `count`, inserting the key if it is not tracked yet.
    ///
    /// * Tracked key: its count is overwritten and the heap is repaired at that slot.
    /// * Untracked key: it is pushed only if `should_accept_new(count)` allows it; the
    ///   owned key is built (via `input_to_owned`) only in that case.
    ///
    /// The position index is rebuilt whenever the heap was touched.
    ///
    /// Returns `true` on every path, including when a new key is turned away.
    /// NOTE(review): the return value therefore carries no information — confirm
    /// whether callers expect `false` for a rejected key before relying on it.
    pub fn update(&mut self, key: &DataInput, count: i64) -> bool {
        match self.find(key) {
            Some(idx) => {
                self.heap[idx].count = count;
                self.heap.update_at(idx);
            }
            None => {
                if !self.should_accept_new(count) {
                    // Heap untouched, so the index is still valid.
                    return true;
                }
                let owned_key = input_to_owned(key);
                self.heap.push(HHItem::create_item(owned_key, count));
            }
        }
        self.refresh_positions();
        true
    }

    /// `HeapItem`-keyed counterpart of [`HHHeap::update`]; the key is copied with
    /// `to_owned()` only when it actually gets inserted.
    ///
    /// Returns `true` on every path (see the note on [`HHHeap::update`]).
    pub fn update_heap_item(&mut self, key: &HeapItem, count: i64) -> bool {
        match self.find_heap_item(key) {
            Some(idx) => {
                self.heap[idx].count = count;
                self.heap.update_at(idx);
            }
            None => {
                if !self.should_accept_new(count) {
                    // Heap untouched, so the index is still valid.
                    return true;
                }
                self.heap.push(HHItem::create_item(key.to_owned(), count));
            }
        }
        self.refresh_positions();
        true
    }

    /// Borrows the tracked items as a slice, in whatever order
    /// `CommonHeap::as_slice` exposes them.
    pub fn heap(&self) -> &[HHItem] {
        self.heap.as_slice()
    }

    /// Dumps every tracked item to stdout (via `HHItem::print_item`) between a header
    /// and a footer line.
    pub fn print_heap(&self) {
        println!("======== Beginning of Heap ========");
        self.heap.iter().for_each(|entry| {
            entry.print_item();
        });
        println!("============ Heap Ends ============");
    }

    /// Removes every tracked item and empties the position index.
    /// The configured capacity `k` is left unchanged.
    pub fn clear(&mut self) {
        self.heap.clear();
        self.positions.clear();
    }

    /// Number of items currently stored in the heap.
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// `true` when the heap holds no items.
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Builds a new `HHHeap` as a clone of `other`.
    pub fn from_heap(other: &HHHeap) -> Self {
        other.clone()
    }

    /// The `k` this tracker was constructed with.
    pub fn capacity(&self) -> usize {
        self.k
    }

    /// Decides whether an untracked key with the given `count` may be pushed.
    ///
    /// While fewer than `k` items are stored, everything is admitted. Otherwise the
    /// candidate must have a count strictly greater than the item returned by
    /// `peek()`.
    #[inline]
    fn should_accept_new(&self, count: i64) -> bool {
        let has_room = self.heap.len() < self.k;
        if has_room {
            return true;
        }
        match self.heap.peek() {
            Some(weakest) => weakest.count < count,
            // Only reachable when the heap is empty yet not below `k` (i.e. `k == 0`).
            // NOTE(review): the candidate is then admitted; what `CommonHeap::push`
            // does at capacity 0 is not visible here — confirm.
            None => true,
        }
    }

    /// Rebuilds the position index from scratch.
    ///
    /// Every stored item is re-hashed and its key cloned into the matching bucket
    /// together with its current heap index, so the cost grows with the number of
    /// stored items. Called after each heap mutation because heap repairs can move
    /// items to different slots.
    fn refresh_positions(&mut self) {
        self.positions.clear();
        for (heap_index, entry) in self.heap.iter().enumerate() {
            let slot = self.slot_for_item(&entry.key);
            let bucket = self.positions.entry(slot).or_default();
            bucket.push((entry.key.clone(), heap_index));
        }
    }

    /// Bucket id for a borrowed input key (`hash64_seeded` with seed 0).
    #[inline]
    fn slot_for_input(&self, key: &DataInput) -> u64 {
        hash64_seeded(0, key)
    }

    /// Bucket id for an owned heap key (`hash_item64_seeded` with seed 0).
    /// Lookups through `find` rely on this agreeing with `slot_for_input` for equal keys.
    #[inline]
    fn slot_for_item(&self, key: &HeapItem) -> u64 {
        hash_item64_seeded(0, key)
    }
}
#[cfg(test)]
mod tests {
    // Covers both the `HHHeap` wrapper and the underlying `CommonHeap` behaviour it
    // relies on (ordering, bounded capacity, in-place updates, custom comparators).
    use super::*;
    use crate::{
        CommonHeap, CommonHeapOrder, DataInput, HeapItem, KeepLargest, KeepSmallest,
        common::input::HHItem,
    };

    /// Wraps a string slice in an owned `HeapItem::String` key.
    fn heap_item_from_str(value: &str) -> HeapItem {
        HeapItem::String(value.to_string())
    }

    /// Feeding five keys with counts 1..=5 into a 3-slot tracker keeps counts 3, 4, 5.
    #[test]
    fn heap_retains_top_k_items_by_count() {
        let mut tracker = HHHeap::new(3);
        for count in 1..=5_i64 {
            let key_item = heap_item_from_str(&format!("key-{count}"));
            tracker.update_heap_item(&key_item, count);
        }

        assert_eq!(tracker.heap.len(), 3);
        let mut counts: Vec<i64> = tracker.heap.iter().map(|item| item.count).collect();
        counts.sort_unstable();
        assert_eq!(counts, vec![3, 4, 5]);
    }

    /// Re-submitting the same key overwrites its count instead of adding entries.
    #[test]
    fn update_count_increments_existing_entry() {
        let mut tracker = HHHeap::new(4);
        let alpha = heap_item_from_str("alpha");
        for count in 1..=3 {
            tracker.update_heap_item(&alpha, count);
        }

        let idx = tracker.find_heap_item(&alpha).expect("alpha present");
        assert_eq!(tracker.heap[idx].count, 3);
    }

    /// `clear` leaves the backing heap empty.
    #[test]
    fn clean_resets_heap_state() {
        let mut tracker = HHHeap::new(2);
        tracker.update_heap_item(&heap_item_from_str("a"), 5);
        tracker.update_heap_item(&heap_item_from_str("b"), 6);
        assert_eq!(tracker.heap.len(), 2);

        tracker.clear();
        assert!(tracker.heap.is_empty());
    }

    /// A min-ordered heap pops its values in ascending order.
    #[test]
    fn test_min_heap_basic() {
        let mut heap = CommonHeap::<i32, KeepSmallest>::new_min(5);
        for value in [5, 3, 7, 1] {
            heap.push(value);
        }

        assert_eq!(heap.peek(), Some(&1));
        for expected in [1, 3, 5, 7] {
            assert_eq!(heap.pop(), Some(expected));
        }
        assert_eq!(heap.pop(), None);
    }

    /// A max-ordered heap pops its values in descending order.
    #[test]
    fn test_max_heap_basic() {
        let mut heap = CommonHeap::<i32, KeepLargest>::new_max(5);
        for value in [5, 3, 7, 1] {
            heap.push(value);
        }

        assert_eq!(heap.peek(), Some(&7));
        for expected in [7, 5, 3, 1] {
            assert_eq!(heap.pop(), Some(expected));
        }
        assert_eq!(heap.pop(), None);
    }

    /// Pushing past the capacity never grows the heap; the three largest values remain.
    #[test]
    fn test_bounded_heap_capacity() {
        let mut heap = CommonHeap::<i32, KeepSmallest>::new_min(3);
        for value in [5, 3, 7] {
            heap.push(value);
        }
        assert_eq!(heap.len(), 3);

        heap.push(1);
        assert_eq!(heap.len(), 3);
        heap.push(10);
        assert_eq!(heap.len(), 3);

        // Drain by popping until the heap reports it is empty.
        let mut drained: Vec<i32> = std::iter::from_fn(|| heap.pop()).collect();
        drained.sort();
        assert_eq!(drained, vec![5, 7, 10]);
    }

    /// Overwriting a slot and calling `update_at` restores the heap order.
    #[test]
    fn test_update_at() {
        let mut heap = CommonHeap::<i32, KeepSmallest>::new_min(5);
        for value in [10, 20, 5] {
            heap.push(value);
        }

        heap[1] = 3;
        heap.update_at(1);
        assert_eq!(heap.peek(), Some(&3));
    }

    /// `HHItem`s are ordered such that the lowest count sits at the top of a min heap.
    #[test]
    fn test_custom_struct_with_ord() {
        let mut heap = CommonHeap::<HHItem, KeepSmallest>::new_min(3);
        for (label, count) in [("five", 5), ("three", 3), ("seven", 7)] {
            heap.push(HHItem::new(DataInput::String(label.to_owned()), count));
        }

        assert_eq!(heap.peek().map(|item| item.count), Some(3));
    }

    /// Top-k on a raw `CommonHeap<HHItem, _>`: counts 3..=5 survive and stay findable.
    #[test]
    fn test_topk_use_case() {
        let mut heap = CommonHeap::<HHItem, KeepSmallest>::new_min(3);
        for i in 1..=5 {
            heap.push(HHItem::new(DataInput::String(format!("key-{i}")), i));
        }
        assert_eq!(heap.len(), 3);

        let mut counts: Vec<i64> = heap.iter().map(|item| item.count).collect();
        counts.sort_unstable();
        assert_eq!(counts, vec![3, 4, 5]);

        let wanted = HeapItem::String("key-4".to_owned());
        let found = heap.iter().find(|item| item.key == wanted);
        assert_eq!(found.map(|item| item.count), Some(4));
    }

    /// The zero-sized comparators add nothing to the heap's size: it is exactly a
    /// `Vec` plus one `usize`.
    #[test]
    fn test_heap_size() {
        use std::mem::size_of;

        let vec_size = size_of::<Vec<u64>>();
        let heap_min_size = size_of::<CommonHeap<u64, KeepSmallest>>();
        let heap_max_size = size_of::<CommonHeap<u64, KeepLargest>>();
        println!("Vec<u64> size: {vec_size}");
        println!("Heap<u64, MinHeap> size: {heap_min_size}");
        println!("Heap<u64, MaxHeap> size: {heap_max_size}");

        let expected = vec_size + size_of::<usize>();
        assert_eq!(heap_min_size, expected);
        assert_eq!(heap_max_size, expected);
    }

    /// A user-supplied `CommonHeapOrder` can drive the same top-k behaviour.
    #[test]
    fn test_topk_with_custom_comparator() {
        #[derive(Clone)]
        struct CompareByCount;

        impl CommonHeapOrder<HHItem> for CompareByCount {
            fn should_swap(&self, parent: &HHItem, child: &HHItem) -> bool {
                child.count < parent.count
            }

            fn should_replace_root(&self, root: &HHItem, new_value: &HHItem) -> bool {
                new_value.count > root.count
            }
        }

        let mut heap = CommonHeap::<HHItem, CompareByCount>::with_capacity(3, CompareByCount);
        // The last two pushes arrive once the heap is already full; per the assertions
        // below, the size stays at 3 and the smallest surviving count is 5.
        for (label, count) in [("a", 5), ("b", 3), ("c", 7), ("d", 1), ("e", 10)] {
            heap.push(HHItem::new(DataInput::String(label.to_owned()), count));
        }

        assert_eq!(heap.len(), 3);
        let min_count = heap.peek().map(|item| item.count);
        assert_eq!(min_count, Some(5));
    }

    /// Emulates `HHHeap::update` with a linear scan instead of the position index.
    #[test]
    fn test_exact_topk_heap_replacement() {
        /// Overwrites the count of `key` if present, otherwise pushes a new item.
        fn find_and_update(heap: &mut CommonHeap<HHItem, KeepSmallest>, key: &str, count: i64) {
            let wanted = HeapItem::String(key.to_owned());
            let idx_opt = heap.iter().position(|item| item.key == wanted);
            match idx_opt {
                Some(idx) => {
                    heap[idx].count = count;
                    heap.update_at(idx);
                }
                None => {
                    heap.push(HHItem::new(DataInput::Str(key), count));
                }
            }
        }

        let mut heap = CommonHeap::<HHItem, KeepSmallest>::new_min(3);
        for i in 1..=5 {
            let key = format!("key-{i}");
            find_and_update(&mut heap, &key, i);
        }
        assert_eq!(heap.len(), 3);

        let mut counts: Vec<i64> = heap.iter().map(|item| item.count).collect();
        counts.sort_unstable();
        assert_eq!(counts, vec![3, 4, 5]);

        let wanted = HeapItem::String("key-4".to_owned());
        let found = heap.iter().find(|item| item.key == wanted);
        assert_eq!(found.map(|item| item.count), Some(4));

        heap.clear();
        assert!(heap.is_empty());
    }
}