use num_traits::{One, Zero};
use std::borrow::Borrow;
use std::collections::{hash_map::Iter, HashMap};
use std::hash::Hash;
use std::iter;
use std::ops::{Add, AddAssign, BitAnd, BitOr, Deref, DerefMut, Index, IndexMut, Sub, SubAssign};
/// Backing storage of a [`Counter`]: maps each item to its count.
type CounterMap<T, N> = HashMap<T, N>;

/// A multiset that records how many times each item has been seen.
///
/// `T` is the item type and `N` the count type (`usize` by default).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Counter<T: Hash + Eq, N = usize> {
    /// Item -> count storage.
    map: CounterMap<T, N>,
    /// A stored zero value. `Index` returns `&self.zero` for keys that are
    /// absent, which needs a value that lives as long as the counter.
    zero: N,
}

/// Creates an empty counter whose stored zero is `N::default()`.
///
/// Written by hand rather than derived: `#[derive(Default)]` would add a
/// `T: Default` bound, which is never needed to build an empty map.
impl<T, N> Default for Counter<T, N>
where
    T: Hash + Eq,
    N: Default,
{
    fn default() -> Self {
        Counter {
            map: HashMap::new(),
            zero: N::default(),
        }
    }
}
impl<T, N> Counter<T, N>
where
    T: Hash + Eq,
    N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
    /// Creates an empty counter whose stored zero is `N::zero()`.
    pub fn new() -> Counter<T, N> {
        let map = HashMap::new();
        let zero = N::zero();
        Counter { map, zero }
    }

    /// Builds a counter from `iterable`, counting each yielded item once.
    pub fn init<I>(iterable: I) -> Counter<T, N>
    where
        I: IntoIterator<Item = T>,
    {
        let mut fresh = Self::new();
        fresh.update(iterable);
        fresh
    }

    /// Adds one to the count of every item yielded by `iterable`.
    ///
    /// Items not present yet start from `N::zero()` before being incremented.
    pub fn update<I>(&mut self, iterable: I)
    where
        I: IntoIterator<Item = T>,
    {
        iterable.into_iter().for_each(|item| {
            let count = self.map.entry(item).or_insert_with(N::zero);
            *count += N::one();
        });
    }

    /// Consumes the counter and returns its backing item -> count map.
    pub fn into_map(self) -> HashMap<T, N> {
        let Counter { map, .. } = self;
        map
    }

    /// Takes one away from the count of every item yielded by `iterable`.
    ///
    /// Items that are not in the counter are ignored. A count is only
    /// decremented while it is greater than zero. Any item whose count equals
    /// zero afterwards is removed from the map.
    pub fn subtract<I>(&mut self, iterable: I)
    where
        I: IntoIterator<Item = T>,
    {
        for item in iterable {
            let hit_zero = match self.map.get_mut(&item) {
                None => false,
                Some(count) => {
                    if *count > N::zero() {
                        *count -= N::one();
                    }
                    *count == N::zero()
                }
            };
            if hit_zero {
                self.map.remove(&item);
            }
        }
    }
}
impl<T, N> Counter<T, N>
where
T: Hash + Eq + Clone,
N: Clone + Ord,
{
pub fn most_common(&self) -> Vec<(T, N)> {
use std::cmp::Ordering;
self.most_common_tiebreaker(|ref _a, ref _b| Ordering::Equal)
}
pub fn most_common_tiebreaker<F>(&self, tiebreaker: F) -> Vec<(T, N)>
where
F: Fn(&T, &T) -> ::std::cmp::Ordering,
{
use std::cmp::Ordering;
let mut items = self
.map
.iter()
.map(|(key, count)| (key.clone(), count.clone()))
.collect::<Vec<_>>();
items.sort_by(|&(ref a_item, ref a_count), &(ref b_item, ref b_count)| {
match b_count.cmp(&a_count) {
Ordering::Equal => tiebreaker(&a_item, &b_item),
unequal => unequal,
}
});
items
}
}
impl<T, N> Counter<T, N>
where
T: Hash + Eq + Clone + Ord,
N: Clone + Ord,
{
pub fn most_common_ordered(&self) -> Vec<(T, N)> {
self.most_common_tiebreaker(|ref a, ref b| a.cmp(&b))
}
}
impl<T, N> AddAssign for Counter<T, N>
where
    T: Clone + Hash + Eq,
    N: Clone + Zero + AddAssign,
{
    /// Adds every count in `rhs` to the matching count in `self`. Keys that
    /// `self` does not have yet start from `N::zero()`.
    ///
    /// `rhs` is owned, so its keys and counts are moved rather than cloned.
    /// The `Clone` bounds are kept so that the impl's public bounds stay unchanged.
    fn add_assign(&mut self, rhs: Self) {
        for (key, value) in rhs.map {
            *self.map.entry(key).or_insert_with(N::zero) += value;
        }
    }
}
impl<T, N> Add for Counter<T, N>
where
    T: Clone + Hash + Eq,
    N: Clone + PartialOrd + PartialEq + AddAssign + Zero,
{
    type Output = Counter<T, N>;

    /// Element-wise sum of two counters. For each key, the result's count is
    /// the sum of the two inputs' counts. This delegates to `+=`.
    fn add(self, rhs: Counter<T, N>) -> Self::Output {
        let mut total = self;
        total += rhs;
        total
    }
}
impl<T, N> SubAssign for Counter<T, N>
where
    T: Hash + Eq,
    N: Clone + PartialOrd + PartialEq + SubAssign + Zero,
{
    /// Subtracts every count in `rhs` from the matching count in `self`.
    ///
    /// Keys absent from `self` are ignored. A key is removed from `self` in
    /// two cases: when its count is not at least `rhs`'s count for that key,
    /// or when the subtraction leaves it exactly at zero.
    fn sub_assign(&mut self, rhs: Self) {
        for (key, value) in rhs.map {
            let drop_key = match self.map.get_mut(&key) {
                None => false,
                Some(count) if *count >= value => {
                    *count -= value;
                    *count == N::zero()
                }
                Some(_) => true,
            };
            if drop_key {
                self.map.remove(&key);
            }
        }
    }
}
impl<T, N> Sub for Counter<T, N>
where
    T: Hash + Eq,
    N: Clone + PartialOrd + PartialEq + SubAssign + Zero,
{
    type Output = Counter<T, N>;

    /// Element-wise difference of two counters. This delegates to `-=`, which
    /// drops keys whose count would fall below `rhs`'s count or land on zero.
    fn sub(self, rhs: Counter<T, N>) -> Self::Output {
        let mut remaining = self;
        remaining -= rhs;
        remaining
    }
}
impl<T, N> BitAnd for Counter<T, N>
where
    T: Clone + Hash + Eq,
    N: Clone + Ord + AddAssign + SubAssign + Zero + One,
{
    type Output = Counter<T, N>;

    /// Intersection: keeps only the keys present in both counters, each with
    /// the smaller of its two counts. On a tie, `self`'s value is kept, as
    /// `std::cmp::min` does.
    ///
    /// This makes a single pass over `self` with one lookup per key into `rhs`.
    /// Keys are moved out of `self` instead of being cloned.
    fn bitand(self, rhs: Counter<T, N>) -> Self::Output {
        use std::cmp::min;
        let mut counter = Counter::new();
        for (key, count) in self.map {
            if let Some(other) = rhs.map.get(&key) {
                let smaller = min(count, other.clone());
                counter.map.insert(key, smaller);
            }
        }
        counter
    }
}
impl<T, N> BitOr for Counter<T, N>
where
    T: Clone + Hash + Eq,
    N: Clone + Ord + Zero,
{
    type Output = Counter<T, N>;

    /// Union: every key from either counter, with the larger of its counts.
    /// Keys that appear only in `rhs` are compared against `N::zero()`.
    ///
    /// `rhs` is owned, so its keys and counts are moved in rather than cloned.
    fn bitor(mut self, rhs: Counter<T, N>) -> Self::Output {
        use std::cmp::Ordering;
        for (key, value) in rhs.map {
            let slot = self.map.entry(key).or_insert_with(N::zero);
            // Same tie rule as `std::cmp::max(slot, value)`: when the counts
            // compare equal, the right-hand value is kept.
            if (*slot).cmp(&value) != Ordering::Greater {
                *slot = value;
            }
        }
        self
    }
}
/// Exposes the backing `HashMap` read-only, so `HashMap` query methods
/// (`len`, `get`, `keys`, ...) can be called directly on a counter.
///
/// Only `T: Hash + Eq` is required. The former `N: Clone` bound was never
/// used, because borrowing the map clones nothing.
impl<T, N> Deref for Counter<T, N>
where
    T: Hash + Eq,
{
    type Target = CounterMap<T, N>;

    fn deref(&self) -> &CounterMap<T, N> {
        &self.map
    }
}

/// Mutable access to the backing `HashMap`. Changes made this way (for
/// example `insert` or `remove`) bypass the zero-removal done by `subtract`.
impl<T, N> DerefMut for Counter<T, N>
where
    T: Hash + Eq,
{
    fn deref_mut(&mut self) -> &mut CounterMap<T, N> {
        &mut self.map
    }
}
impl<'a, T, N> IntoIterator for &'a Counter<T, N>
where
T: Hash + Eq,
N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
type Item = (&'a T, &'a N);
type IntoIter = Iter<'a, T, N>;
fn into_iter(self) -> Iter<'a, T, N> {
self.map.iter()
}
}
/// `counter[&key]` returns the count for `key`. If `key` has never been
/// counted, it returns a reference to the counter's stored zero value.
/// Reading through `Index` never inserts.
///
/// `Q: ?Sized` matches `HashMap::get`. It allows lookups by unsized borrowed
/// forms, e.g. `counter["word"]` on a `Counter<String>`.
impl<T, Q, N> Index<&'_ Q> for Counter<T, N>
where
    T: Hash + Eq + Borrow<Q>,
    Q: Hash + Eq + ?Sized,
    N: Zero,
{
    type Output = N;

    fn index(&self, key: &'_ Q) -> &N {
        self.map.get(key).unwrap_or(&self.zero)
    }
}

/// `&mut counter[&key]` returns the count for `key`. If the key is absent,
/// an entry holding `N::zero()` is inserted first, so e.g.
/// `counter[&k] += 1` works for new keys.
///
/// The key is converted with `to_owned()` on every call, because the entry
/// API needs an owned `T`.
impl<T, Q, N> IndexMut<&'_ Q> for Counter<T, N>
where
    T: Hash + Eq + Borrow<Q>,
    Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
    N: Zero,
{
    fn index_mut(&mut self, key: &'_ Q) -> &mut N {
        self.map.entry(key.to_owned()).or_insert_with(N::zero)
    }
}
impl<I, T, N> AddAssign<I> for Counter<T, N>
where
I: IntoIterator<Item = T>,
T: Hash + Eq,
N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
fn add_assign(&mut self, rhs: I) {
self.update(rhs);
}
}
impl<I, T, N> Add<I> for Counter<T, N>
where
I: IntoIterator<Item = T>,
T: Hash + Eq,
N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
type Output = Self;
fn add(mut self, rhs: I) -> Self::Output {
self.update(rhs);
self
}
}
impl<I, T, N> SubAssign<I> for Counter<T, N>
where
I: IntoIterator<Item = T>,
T: Hash + Eq,
N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
fn sub_assign(&mut self, rhs: I) {
self.subtract(rhs);
}
}
impl<I, T, N> Sub<I> for Counter<T, N>
where
I: IntoIterator<Item = T>,
T: Clone + Hash + Eq,
N: Clone + PartialOrd + AddAssign + SubAssign + Zero + One,
{
type Output = Self;
fn sub(mut self, rhs: I) -> Self::Output {
self.subtract(rhs);
self
}
}
impl<T, N> iter::FromIterator<T> for Counter<T, N>
where
    T: Hash + Eq,
    N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
    /// `items.collect::<Counter<_>>()` counts each yielded item once.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut counter = Self::new();
        counter.update(iter);
        counter
    }
}
impl<T, N> iter::FromIterator<(T, N)> for Counter<T, N>
where
    T: Hash + Eq,
    N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
    /// Collects `(item, count)` pairs. Counts for a repeated item are summed.
    fn from_iter<I: IntoIterator<Item = (T, N)>>(iter: I) -> Self {
        iter.into_iter()
            .fold(Self::new(), |mut acc, (item, item_count)| {
                *acc.map.entry(item).or_insert_with(N::zero) += item_count;
                acc
            })
    }
}
impl<T, N> Extend<T> for Counter<T, N>
where
    T: Hash + Eq,
    N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
    /// Counts each yielded item once more, like [`Counter::update`].
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        Self::update(self, iter)
    }
}
impl<T, N> Extend<(T, N)> for Counter<T, N>
where
    T: Hash + Eq,
    N: PartialOrd + AddAssign + SubAssign + Zero + One,
{
    /// Adds each `(item, count)` pair's count to that item's total. New
    /// items start from `N::zero()`.
    fn extend<I: IntoIterator<Item = (T, N)>>(&mut self, iter: I) {
        iter.into_iter().for_each(|(item, item_count)| {
            let slot = self.map.entry(item).or_insert_with(N::zero);
            *slot += item_count;
        });
    }
}
impl<'a, T: 'a, N: 'a> Extend<(&'a T, &'a N)> for Counter<T, N>
where
    T: Hash + Eq + Clone,
    N: PartialOrd + AddAssign + SubAssign + Zero + One + Clone,
{
    /// Adds borrowed `(item, count)` pairs, cloning each item and count.
    ///
    /// This accepts `&other_counter` or `map.iter()`. Only `Clone` is required,
    /// not `Copy`, so counters keyed by e.g. `String` can be merged this way.
    /// Every `Copy` type is also `Clone`, so existing callers are unaffected.
    fn extend<I: IntoIterator<Item = (&'a T, &'a N)>>(&mut self, iter: I) {
        for (item, item_count) in iter {
            let entry = self.map.entry(item.clone()).or_insert_with(N::zero);
            *entry += item_count.clone();
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Counter`: construction, the arithmetic and set-like
    //! operators, the `most_common*` helpers and the std iterator traits.
    //! `assert_eq!` is used throughout so failures print both sides.
    use maplit::hashmap;
    use super::*;
    use std::collections::HashMap;
    #[test]
    fn test_creation() {
        let _: Counter<usize> = Counter::new();
        let initializer = &[1];
        let counter = Counter::init(initializer);
        let mut expected = HashMap::new();
        static ONE: usize = 1;
        expected.insert(&ONE, 1);
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_update() {
        let mut counter = Counter::init("abbccc".chars());
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
        counter.update("aeeeee".chars());
        let expected = hashmap! {
            'a' => 2,
            'b' => 2,
            'c' => 3,
            'e' => 5,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_add_update_iterable() {
        let mut counter = Counter::init("abbccc".chars());
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
        counter += "aeeeee".chars();
        let expected = hashmap! {
            'a' => 2,
            'b' => 2,
            'c' => 3,
            'e' => 5,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_add_update_counter() {
        let mut counter = Counter::init("abbccc".chars());
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
        let other = Counter::init("aeeeee".chars());
        counter += other;
        let expected = hashmap! {
            'a' => 2,
            'b' => 2,
            'c' => 3,
            'e' => 5,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_subtract() {
        let mut counter = Counter::init("abbccc".chars());
        counter.subtract("bbccddd".chars());
        let expected = hashmap! {
            'a' => 1,
            'c' => 1,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_sub_update_iterable() {
        let mut counter = Counter::init("abbccc".chars());
        counter -= "bbccddd".chars();
        let expected = hashmap! {
            'a' => 1,
            'c' => 1,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_sub_update_counter() {
        let mut counter = Counter::init("abbccc".chars());
        let other = Counter::init("bbccddd".chars());
        counter -= other;
        let expected = hashmap! {
            'a' => 1,
            'c' => 1,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_composite_add_sub() {
        let mut counts = Counter::<_>::init(
            "able babble table babble rabble table able fable scrabble".split_whitespace(),
        );
        counts += "cain and abel fable table cable".split_whitespace();
        let other_counts = Counter::init("scrabble cabbie fable babble".split_whitespace());
        let diff = counts - other_counts;
        // "scrabble" drops to zero and is removed; "cabbie" was never counted.
        let expected = hashmap! {
            "able" => 2,
            "abel" => 1,
            "and" => 1,
            "babble" => 1,
            "cable" => 1,
            "cain" => 1,
            "fable" => 1,
            "rabble" => 1,
            "table" => 3,
        };
        assert_eq!(diff.map, expected);
    }
    #[test]
    fn test_most_common() {
        let counter = Counter::init("abbccc".chars());
        let by_common = counter.most_common();
        let expected = vec![('c', 3), ('b', 2), ('a', 1)];
        assert_eq!(by_common, expected);
    }
    #[test]
    fn test_most_common_tiebreaker() {
        let counter = Counter::init("eaddbbccc".chars());
        let by_common = counter.most_common_tiebreaker(|&a, &b| a.cmp(&b));
        let expected = vec![('c', 3), ('b', 2), ('d', 2), ('a', 1), ('e', 1)];
        assert_eq!(by_common, expected);
    }
    #[test]
    fn test_most_common_tiebreaker_reversed() {
        let counter = Counter::init("eaddbbccc".chars());
        let by_common = counter.most_common_tiebreaker(|&a, &b| b.cmp(&a));
        let expected = vec![('c', 3), ('d', 2), ('b', 2), ('e', 1), ('a', 1)];
        assert_eq!(by_common, expected);
    }
    #[test]
    fn test_most_common_ordered() {
        let counter = Counter::init("eaddbbccc".chars());
        let by_common = counter.most_common_ordered();
        let expected = vec![('c', 3), ('b', 2), ('d', 2), ('a', 1), ('e', 1)];
        assert_eq!(by_common, expected);
    }
    #[test]
    fn test_add() {
        let d = Counter::<_>::init("abbccc".chars());
        let e = Counter::<_>::init("bccddd".chars());
        let out = d + e;
        let expected = Counter::init("abbbcccccddd".chars());
        assert_eq!(out, expected);
    }
    #[test]
    fn test_sub() {
        let d = Counter::<_>::init("abbccc".chars());
        let e = Counter::<_>::init("bccddd".chars());
        let out = d - e;
        let expected = Counter::init("abc".chars());
        assert_eq!(out, expected);
    }
    #[test]
    fn test_intersection() {
        let d = Counter::<_>::init("abbccc".chars());
        let e = Counter::<_>::init("bccddd".chars());
        let out = d & e;
        let expected = Counter::init("bcc".chars());
        assert_eq!(out, expected);
    }
    #[test]
    fn test_union() {
        let d = Counter::<_>::init("abbccc".chars());
        let e = Counter::<_>::init("bccddd".chars());
        let out = d | e;
        let expected = Counter::init("abbcccddd".chars());
        assert_eq!(out, expected);
    }
    #[test]
    fn test_delete_key_from_backing_map() {
        let mut counter = Counter::<_>::init("aa-bb-cc".chars());
        counter.remove(&'-');
        assert_eq!(counter, Counter::init("aabbcc".chars()));
    }
    #[test]
    fn test_from_iter_simple() {
        let counter = "abbccc".chars().collect::<Counter<_>>();
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_from_iter_tuple() {
        let items = [('a', 1), ('b', 2), ('c', 3)];
        let counter = items.iter().cloned().collect::<Counter<_>>();
        let expected: HashMap<char, usize> = items.iter().cloned().collect();
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_from_iter_tuple_with_duplicates() {
        let items = [('a', 1), ('b', 2), ('c', 3)];
        let counter = items
            .iter()
            .cycle()
            .take(items.len() * 2)
            .cloned()
            .collect::<Counter<_>>();
        let expected: HashMap<char, usize> = items.iter().map(|(c, n)| (*c, n * 2)).collect();
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_extend_simple() {
        let mut counter = "abbccc".chars().collect::<Counter<_>>();
        counter.extend("bccddd".chars());
        let expected = hashmap! {
            'a' => 1,
            'b' => 3,
            'c' => 5,
            'd' => 3,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_extend_tuple() {
        let mut counter = "bccddd".chars().collect::<Counter<_>>();
        let items = [('a', 1), ('b', 2), ('c', 3)];
        counter.extend(items.iter().cloned());
        let expected = hashmap! {
            'a' => 1,
            'b' => 3,
            'c' => 5,
            'd' => 3,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_extend_tuple_with_duplicates() {
        let mut counter = "ccc".chars().collect::<Counter<_>>();
        let items = [('a', 1), ('b', 2), ('c', 3)];
        counter.extend(items.iter().cycle().take(items.len() * 2 - 1).cloned());
        let expected: HashMap<char, usize> = items.iter().map(|(c, n)| (*c, n * 2)).collect();
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_count_minimal_type() {
        #[derive(Debug, Hash, PartialEq, Eq)]
        struct Inty {
            i: usize,
        }
        impl Inty {
            pub fn new(i: usize) -> Inty {
                Inty { i }
            }
        }
        let intys = vec![
            Inty::new(8),
            Inty::new(0),
            Inty::new(0),
            Inty::new(8),
            Inty::new(6),
            Inty::new(7),
            Inty::new(5),
            Inty::new(3),
            Inty::new(0),
            Inty::new(9),
        ];
        let inty_counts = Counter::init(intys);
        assert_eq!(inty_counts.map.get(&Inty { i: 8 }), Some(&2));
        assert_eq!(inty_counts.map.get(&Inty { i: 0 }), Some(&3));
        assert_eq!(inty_counts.map.get(&Inty { i: 6 }), Some(&1));
    }
    #[test]
    fn test_collect() {
        let counter: Counter<_> = "abbccc".chars().collect();
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
    }
    #[test]
    fn test_non_usize_count() {
        let counter: Counter<_, i8> = "abbccc".chars().collect();
        let expected = hashmap! {
            'a' => 1,
            'b' => 2,
            'c' => 3,
        };
        assert_eq!(counter.map, expected);
    }
}