use crate::countminsketch::CountMinSketch;
use std::cmp::Ordering;
use std::collections::hash_map::Entry;
use std::collections::{BTreeSet, HashMap};
use std::fmt;
use std::hash::Hash;
use std::rc::Rc;
/// An element of the ordered set backing [`CMSHeap`]: a shared handle to a
/// tracked object together with the count it is currently ranked by.
///
/// Entries are ordered by `n` first and by the object itself second, so the
/// first element of a `BTreeSet<TreeEntry<T>>` is the tracked object with the
/// lowest count (ties broken by `T`'s own ordering). Equality agrees with that
/// ordering: two entries are equal only if both the count and the object match.
#[derive(Debug)]
struct TreeEntry<T>
where
    T: Eq + Ord,
{
    /// The tracked object; the same `Rc` is also used as the key in the
    /// heap's count map, so cloning an entry never clones `T`.
    obj: Rc<T>,
    /// The count this entry is ranked by.
    n: usize,
}

// Written by hand instead of derived so that cloning does not require
// `T: Clone`; only the `Rc` handle is duplicated.
impl<T> Clone for TreeEntry<T>
where
    T: Eq + Ord,
{
    fn clone(&self) -> Self {
        Self {
            obj: Rc::clone(&self.obj),
            n: self.n,
        }
    }
}

impl<T> PartialEq for TreeEntry<T>
where
    T: Eq + Ord,
{
    /// Compares both the count and the object, matching [`Ord::cmp`]
    /// (`Ord` requires `a == b` exactly when `a.cmp(&b) == Equal`).
    fn eq(&self, other: &Self) -> bool {
        self.n == other.n && self.obj == other.obj
    }
}

impl<T> Eq for TreeEntry<T> where T: Eq + Ord {}

impl<T> PartialOrd for TreeEntry<T>
where
    T: Eq + Ord,
{
    /// Canonical delegation to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for TreeEntry<T>
where
    T: Eq + Ord,
{
    /// Orders by count, then by object, so equal counts still yield a strict
    /// total order and distinct objects never collide inside a `BTreeSet`.
    fn cmp(&self, other: &Self) -> Ordering {
        self.n
            .cmp(&other.n)
            .then_with(|| self.obj.cmp(&other.obj))
    }
}
/// A bounded tracker of the (approximately) most frequent objects.
///
/// Every object passed to `add` is recorded in the [`CountMinSketch`]. At most
/// `k` objects are tracked explicitly. Once `k` objects are tracked, an
/// untracked object replaces the lowest-ranked tracked object only when the
/// count returned by the sketch is strictly greater than that object's count.
///
/// Tracked objects are held through `Rc`, so this type is neither `Send` nor
/// `Sync`, and a cloned heap shares the tracked objects' allocations with the
/// original.
#[derive(Clone)]
pub struct CMSHeap<T>
where
    T: Clone + Eq + Hash + Ord,
{
    /// Sketch that every added object is recorded in.
    cms: CountMinSketch<T>,
    /// Tracked objects and their current counts; holds the same objects as `tree`.
    obj2count: HashMap<Rc<T>, usize>,
    /// Tracked objects ordered by `(count, object)`; the first entry is the
    /// eviction candidate.
    tree: BTreeSet<TreeEntry<T>>,
    /// Maximum number of tracked objects (always greater than 0).
    k: usize,
}
impl<T> CMSHeap<T>
where
T: Clone + Eq + Hash + Ord,
{
pub fn new(k: usize, cms: CountMinSketch<T>) -> Self {
assert!(k > 0, "k must be greater than 0");
Self {
cms,
obj2count: HashMap::new(),
tree: BTreeSet::new(),
k,
}
}
pub fn k(&self) -> usize {
self.k
}
pub fn add(&mut self, obj: T) {
let rc = Rc::new(obj);
let count = self.cms.add(&rc);
let size = self.obj2count.len();
match self.obj2count.entry(Rc::clone(&rc)) {
Entry::Occupied(mut o) => {
let n = o.get_mut();
*n += 1;
let mut entry = TreeEntry {
obj: Rc::clone(&rc),
n: *n - 1,
};
self.tree.remove(&entry);
entry.n += 1;
self.tree.insert(entry);
}
Entry::Vacant(v) => {
if size < self.k {
debug_assert!(count == 1);
v.insert(1);
self.tree.insert(TreeEntry {
obj: Rc::clone(&rc),
n: 1,
});
} else {
let min: TreeEntry<T> = (*self.tree.iter().next().unwrap()).clone();
if count > min.n {
self.tree.remove(&min);
self.tree.insert(TreeEntry {
obj: Rc::clone(&rc),
n: count,
});
self.obj2count.insert(rc, count);
self.obj2count.remove(&min.obj);
}
}
}
}
}
pub fn iter(&self) -> impl '_ + Iterator<Item = T> {
self.tree.iter().map(|x| (*x.obj).clone())
}
pub fn is_empty(&self) -> bool {
self.obj2count.is_empty()
}
pub fn clear(&mut self) {
self.cms.clear();
self.tree.clear();
self.obj2count.clear();
}
}
impl<T> fmt::Debug for CMSHeap<T>
where
    T: Clone + Eq + Hash + Ord,
{
    /// Prints only the capacity `k`. Tracked objects are not shown, which is
    /// why `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CMSHeap {{ k: {k} }}", k = self.k)
    }
}
impl<T> Extend<T> for CMSHeap<T>
where
    T: Clone + Eq + Hash + Ord,
{
    /// Feeds every item of `iter` to `add`, in iteration order.
    fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {
        iter.into_iter().for_each(|item| self.add(item));
    }
}
#[cfg(test)]
mod tests {
    use super::CMSHeap;
    use crate::countminsketch::CountMinSketch;

    /// Snapshot of the heap's tracked objects, in iteration order.
    fn contents(heap: &CMSHeap<u32>) -> Vec<u32> {
        heap.iter().collect()
    }

    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn new_panics_k0() {
        let sketch = CountMinSketch::with_params(10, 20);
        let _ = CMSHeap::<usize>::new(0, sketch);
    }

    #[test]
    fn getter() {
        let heap: CMSHeap<usize> = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        assert_eq!(heap.k(), 2);
    }

    #[test]
    fn add_1() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        heap.add(1);
        assert_eq!(contents(&heap), vec![1]);
    }

    #[test]
    fn add_2_same() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        (0..2).for_each(|_| heap.add(1));
        assert_eq!(contents(&heap), vec![1]);
    }

    #[test]
    fn add_2_different() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        heap.add(1);
        heap.add(2);
        assert_eq!(contents(&heap), vec![1, 2]);
    }

    #[test]
    fn add_n() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        (0..5).for_each(|i| heap.add(i));
        (0..100).for_each(|_| heap.add(99));
        (0..100).for_each(|_| heap.add(100));
        (0..5).for_each(|i| heap.add(i));
        assert_eq!(contents(&heap), vec![99, 100]);
    }

    #[test]
    fn is_empty() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        assert!(heap.is_empty());
        heap.add(0);
        assert!(!heap.is_empty());
    }

    #[test]
    fn clear() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        heap.add(0);
        heap.clear();
        assert!(heap.is_empty());
        heap.add(1);
        assert_eq!(contents(&heap), vec![1]);
    }

    #[test]
    fn clone() {
        let mut original = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        original.add(0);
        let mut copy = original.clone();
        copy.add(1);
        assert_eq!(contents(&original), vec![0]);
        assert_eq!(contents(&copy), vec![0, 1]);
    }

    #[test]
    fn debug() {
        let heap: CMSHeap<usize> = CMSHeap::new(2, CountMinSketch::with_params(10, 20));
        assert_eq!(format!("{:?}", heap), "CMSHeap { k: 2 }");
    }

    #[test]
    fn extend() {
        let mut heap = CMSHeap::new(2, CountMinSketch::with_params(10, 10));
        heap.extend(vec![0, 1]);
        assert_eq!(contents(&heap), vec![0, 1]);
    }
}