use super::exhaust::ExhaustIter;
use ahash::RandomState;
use core::borrow::Borrow;
use core::hash::{BuildHasher, Hash};
use core::mem::{swap, MaybeUninit};
use hashbrown::raw::{Bucket, RawIter, RawIterHash, RawTable};
use hashbrown::TryReserveError;
#[cfg(not(feature = "nightly"))]
use core::convert::identity as likely;
#[cfg(feature = "nightly")]
use core::intrinsics::likely;
/// Builds a predicate telling whether a stored `(key, value)` entry has a key equal to `key`.
///
/// The stored key is viewed through [`Borrow`], so the lookup key may be a borrowed form `Q`
/// of the owned key type `K` (for instance `str` for `String`).
#[inline]
fn equivalent_key<Q, K, V>(key: &Q) -> impl Fn(&(K, V)) -> bool + '_
where
    K: Borrow<Q>,
    Q: ?Sized + Hash + Eq,
{
    move |entry| {
        // Spell out the target type so the `Borrow<Q>` impl of `K` is the one selected.
        let stored_key: &Q = entry.0.borrow();
        key == stored_key
    }
}
/// Builds a predicate telling whether the entry stored in a raw bucket has a key equal to `key`.
///
/// The comparison is wrapped in `likely`: callers apply this predicate to buckets that already
/// share the key's hash, so a match is the expected outcome.
#[inline]
fn bucket_with_key<Q, K, V>(key: &Q) -> impl Fn(&Bucket<(K, V)>) -> bool + '_
where
    K: Borrow<Q>,
    Q: ?Sized + Hash + Eq,
{
    move |bucket| {
        // SAFETY: relies on the caller only handing over buckets that still point at a live
        // entry of the table they were obtained from.
        let entry = unsafe { bucket.as_ref() };
        likely(entry.0.borrow() == key)
    }
}
/// Hashes `value` with a hasher produced by `hash_builder` and returns the 64-bit result.
#[inline]
fn make_hash<T, S>(hash_builder: &S, value: &T) -> u64
where
    T: ?Sized + Hash,
    S: BuildHasher,
{
    // The reference itself is what gets hashed, which is how unsized `T` (e.g. `str`) is supported.
    BuildHasher::hash_one(hash_builder, value)
}
/// Builds the re-hashing callback handed to the raw table: it hashes the key of a stored
/// `(key, value)` entry with `hash_builder` and ignores the value.
#[inline]
fn make_hasher<K, V, S>(hash_builder: &S) -> impl Fn(&(K, V)) -> u64 + '_
where
    K: Hash,
    S: BuildHasher,
{
    move |entry| {
        let (key, _) = entry;
        make_hash(hash_builder, key)
    }
}
/// A hash table that can hold several values under the same key.
///
/// Every call to `insert` adds a new `(key, value)` entry to the underlying raw table instead
/// of replacing an existing one, so a key may be associated with any number of values.
#[derive(Clone)]
pub struct MashMap<K, V, S = RandomState> {
/// Builds the hashers used to hash keys.
hash_builder: S,
/// Raw table holding one `(key, value)` tuple per inserted entry.
pub(crate) table: RawTable<(K, V)>,
}
impl<K, V> MashMap<K, V, RandomState> {
    /// Creates an empty map that uses a default-constructed [`RandomState`] as its hash builder.
    pub fn new() -> Self {
        let hash_builder = RandomState::default();
        Self::with_hasher(hash_builder)
    }

    /// Creates an empty map with a default-constructed [`RandomState`], passing `capacity` on
    /// to `with_capacity_and_hasher`.
    pub fn with_capacity(capacity: usize) -> Self {
        let hash_builder = RandomState::default();
        Self::with_capacity_and_hasher(capacity, hash_builder)
    }
}
impl<K, V> Default for MashMap<K, V, RandomState> {
fn default() -> Self {
Self::new()
}
}
impl<K, V, S> MashMap<K, V, S> {
    /// Creates an empty map that hashes keys with `hash_builder`.
    pub const fn with_hasher(hash_builder: S) -> Self {
        Self {
            hash_builder,
            table: RawTable::new(),
        }
    }

    /// Creates an empty map that hashes keys with `hash_builder`, with its table built by
    /// `RawTable::with_capacity(capacity)`.
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
        let table = RawTable::with_capacity(capacity);
        Self {
            hash_builder,
            table,
        }
    }

    /// Returns a reference to the map's hash builder.
    #[inline]
    pub const fn hasher(&self) -> &S {
        &self.hash_builder
    }

    /// Returns the capacity reported by the underlying table.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.table.capacity()
    }

    /// Returns the number of `(key, value)` entries stored (not the number of distinct keys).
    #[inline]
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// Returns `true` when the map holds no entry.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.table.is_empty()
    }

    /// Iterates over every `(key, value)` entry, in the order of the raw table iterator.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &(K, V)> {
        // SAFETY: the raw iterator is created from `&self`, and the returned references are
        // tied to that borrow by the return type.
        let buckets = unsafe { self.table.iter() };
        buckets.map(|bucket| unsafe { bucket.as_ref() })
    }

    /// Iterates over every `(key, value)` entry with mutable access to each entry.
    #[inline]
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (K, V)> {
        // SAFETY: `&mut self` gives exclusive access to the table for the lifetime of the
        // returned references.
        let buckets = unsafe { self.table.iter() };
        buckets.map(|bucket| unsafe { bucket.as_mut() })
    }

    /// Returns the raw table's draining iterator, which yields the entries by value.
    #[inline]
    pub fn drain(&mut self) -> impl Iterator<Item = (K, V)> + '_ {
        self.table.drain()
    }

    /// Removes every entry from the map.
    #[inline]
    pub fn clear(&mut self) {
        self.table.clear();
    }

    /// Keeps only the entries for which `f(&key, &mut value)` returns `true`.
    ///
    /// `f` is called once per entry and may modify the value of the entries it keeps.
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&K, &mut V) -> bool,
    {
        // SAFETY: a bucket is only erased after the raw iterator has yielded it, and `&mut self`
        // gives exclusive access to the table while this runs.
        unsafe {
            for bucket in self.table.iter() {
                let &mut (ref key, ref mut value) = bucket.as_mut();
                if !f(key, value) {
                    self.table.erase(bucket);
                }
            }
        }
    }
}
impl<K, V, S> MashMap<K, V, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    /// Returns an iterator over every entry in which entries sharing a key are yielded
    /// consecutively.
    #[inline]
    pub fn iter_group_by_key(&self) -> IterGroupByKey<'_, K, V, S> {
        IterGroupByKey::new(self)
    }

    /// Same as [`iter_group_by_key`](Self::iter_group_by_key), with mutable access to the
    /// yielded entries.
    #[inline]
    pub fn iter_mut_group_by_key(&mut self) -> IterMutGroupByKey<'_, K, V, S> {
        IterMutGroupByKey::new(self)
    }

    /// Adds a `(key, value)` entry to the map.
    ///
    /// Entries already stored under `key` are left untouched: the new value is added next to
    /// them rather than replacing them.
    #[inline]
    pub fn insert(&mut self, key: K, value: V) {
        let hash = make_hash(&self.hash_builder, &key);
        self.table
            .insert(hash, (key, value), make_hasher(&self.hash_builder));
    }

    /// Returns `true` if at least one entry is stored under `key`.
    #[inline]
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        self.table.find(hash, equivalent_key(key)).is_some()
    }

    /// Returns a reference to the value of the first entry the table lookup finds for `key`,
    /// or `None` if the key is absent.
    ///
    /// When several values share the key, which one is returned depends on the table layout.
    #[inline]
    pub fn get_one<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        self.table
            .find(hash, equivalent_key(key))
            // SAFETY: the bucket was just returned by `find` on a table borrowed through
            // `&self`, and the reference handed out is tied to that borrow.
            .map(|bucket| unsafe { &bucket.as_ref().1 })
    }

    /// Returns a mutable reference to the value of the first entry the table lookup finds for
    /// `key`, or `None` if the key is absent.
    ///
    /// Takes `&mut self`: handing out `&mut V` from a shared borrow would let two calls produce
    /// aliasing mutable references to the same value.
    #[inline]
    pub fn get_one_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        self.table
            .find(hash, equivalent_key(key))
            // SAFETY: the bucket was just returned by `find`, and `&mut self` guarantees
            // exclusive access to the entry for the lifetime of the returned reference.
            .map(|bucket| unsafe { &mut bucket.as_mut().1 })
    }

    /// Removes one entry stored under `key` and returns its value, or `None` if the key is
    /// absent. Other entries with the same key stay in the map.
    #[inline]
    pub fn remove_one<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        self.table
            .remove_entry(hash, equivalent_key(key))
            .map(|(_, value)| value)
    }

    /// Iterates over the raw buckets of every entry stored under `key`.
    ///
    /// Buckets are first narrowed down by hash, then filtered by key equality.
    #[inline]
    pub(crate) fn get_iter_buckets<'a, Q>(
        &'a self,
        key: &'a Q,
    ) -> impl Iterator<Item = Bucket<(K, V)>> + 'a
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        // SAFETY: the returned iterator is bounded by `'a`, the lifetime of the borrow of `self`.
        unsafe { self.table.iter_hash(hash).filter(bucket_with_key(key)) }
    }

    /// Iterates over references to every value stored under `key`.
    #[inline]
    pub fn get_iter<'a, Q>(&'a self, key: &'a Q) -> impl Iterator<Item = &'a V> + 'a
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.get_iter_buckets(key)
            // SAFETY: the buckets come from the table borrowed for `'a`.
            .map(|bucket| unsafe { &bucket.as_ref().1 })
    }

    /// Iterates over mutable references to every value stored under `key`.
    #[inline]
    pub fn get_mut_iter<'a, Q>(&'a mut self, key: &'a Q) -> impl Iterator<Item = &'a mut V> + 'a
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.get_iter_buckets(key)
            // SAFETY: `&'a mut self` gives exclusive access to the table for `'a`.
            .map(|bucket| unsafe { &mut bucket.as_mut().1 })
    }

    /// Removes the entries stored under `key`, yielding their values.
    ///
    /// Each entry is removed when the iterator reaches it. The iterator is wrapped in
    /// `ExhaustIter`; judging by `test_drain_key_if_dropped`, that wrapper finishes the
    /// removal when it is dropped before being fully consumed.
    pub fn drain_key<'a, Q>(&'a mut self, key: &'a Q) -> impl Iterator<Item = V> + 'a
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        // SAFETY: a bucket is only removed after the hash iterator has yielded it, and the
        // table is exclusively borrowed for `'a`.
        ExhaustIter::new(unsafe {
            self.table
                .iter_hash(hash)
                .filter(bucket_with_key(key))
                // `remove` returns a pair whose first element is the `(key, value)` entry.
                .map(|bucket| (self.table.remove(bucket).0).1)
        })
    }

    /// Removes the entries stored under `key` whose value satisfies `predicate`, yielding the
    /// removed values. Entries for which `predicate` returns `false` stay in the map.
    ///
    /// Wrapped in `ExhaustIter` like [`drain_key`](Self::drain_key).
    pub fn drain_key_if<'a, Q>(
        &'a mut self,
        key: &'a Q,
        mut predicate: impl FnMut(&'a V) -> bool + 'a,
    ) -> impl Iterator<Item = V> + 'a
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        // SAFETY: a bucket is only removed after the hash iterator has yielded it, and the
        // table is exclusively borrowed for `'a`.
        ExhaustIter::new(unsafe {
            self.table
                .iter_hash(hash)
                .filter(bucket_with_key(key))
                .filter_map(move |bucket| {
                    predicate(&bucket.as_ref().1).then(|| (self.table.remove(bucket).0).1)
                })
        })
    }

    /// Removes the first entry stored under `key` whose value satisfies `predicate` and returns
    /// its value, or `None` when no entry under `key` satisfies it.
    pub fn remove_key_if<'a, Q>(
        &'a mut self,
        key: &'a Q,
        mut predicate: impl FnMut(&'a V) -> bool + 'a,
    ) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        // SAFETY: the search stops right after the single removal, so the hash iterator is not
        // advanced past a bucket that has been removed.
        unsafe {
            self.table
                .iter_hash(hash)
                .filter(bucket_with_key(key))
                .find_map(move |bucket| {
                    predicate(&bucket.as_ref().1).then(|| (self.table.remove(bucket).0).1)
                })
        }
    }

    /// Removes every entry stored under `key`, dropping the removed entries.
    pub fn remove_all<Q>(&mut self, key: &Q)
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let hash = make_hash(&self.hash_builder, key);
        // SAFETY: a bucket is only removed after the hash iterator has yielded it, and
        // `&mut self` gives exclusive access to the table.
        unsafe {
            self.table
                .iter_hash(hash)
                .filter(bucket_with_key(key))
                .for_each(|bucket| {
                    self.table.remove(bucket);
                })
        }
    }

    /// Reserves room for at least `additional` more entries, re-hashing stored keys if the
    /// table has to grow.
    #[inline]
    pub fn reserve(&mut self, additional: usize) {
        self.table
            .reserve(additional, make_hasher(&self.hash_builder));
    }

    /// Fallible version of [`reserve`](Self::reserve): reports a failure to grow the table as
    /// a [`TryReserveError`].
    #[inline]
    pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> {
        self.table
            .try_reserve(additional, make_hasher(&self.hash_builder))
    }

    /// Shrinks the table as much as possible (equivalent to `shrink_to(0)`).
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.table.shrink_to(0, make_hasher(&self.hash_builder));
    }

    /// Shrinks the table, asking the raw table to keep room for at least `min_capacity` entries.
    #[inline]
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.table
            .shrink_to(min_capacity, make_hasher(&self.hash_builder));
    }
}
/// Adds every `(key, value)` pair of an iterator to the map, one `insert` per pair.
impl<K, V, S> Extend<(K, V)> for MashMap<K, V, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    fn extend<I: IntoIterator<Item = (K, V)>>(&mut self, iter: I) {
        let entries = iter.into_iter();
        // Reserve up front for the number of entries the iterator guarantees to yield.
        let (lower_bound, _) = entries.size_hint();
        self.reserve(lower_bound);
        for (key, value) in entries {
            self.insert(key, value);
        }
    }
}
impl<K, V, S> MashMap<K, V, S>
where
    K: Eq + Hash + Copy,
    S: BuildHasher,
{
    /// Inserts every value of `value_iter` under the same `key`.
    ///
    /// The key is hashed once and the hash is reused for each entry. Room is reserved up front
    /// for the number of values the iterator guarantees to yield. Any [`IntoIterator`] is
    /// accepted, so collections can be passed directly as well as iterators.
    pub fn insert_iter<I: IntoIterator<Item = V>>(&mut self, key: K, value_iter: I) {
        let hash = make_hash(&self.hash_builder, &key);
        let values = value_iter.into_iter();
        self.reserve(values.size_hint().0);
        for value in values {
            // `K: Copy` lets the same key be stored in every entry.
            self.table
                .insert(hash, (key, value), make_hasher(&self.hash_builder));
        }
    }
}
/// Iterator over the entries of a [`MashMap`] that yields entries sharing a key consecutively.
pub struct IterGroupByKey<'a, K, V, S> {
/// Map being iterated.
pub(crate) map: &'a MashMap<K, V, S>,
/// One flag per bucket of the table, set once the entry in that bucket has been yielded.
pub(crate) seen: Vec<bool>,
/// Iterator over all buckets of the table, used to find the start of the next key group.
pub(crate) main_iter: RawIter<(K, V)>,
/// Iterator over the buckets matching the hash of the current key; uninitialized until the
/// first group has been started.
pub(crate) probe_iter: MaybeUninit<RawIterHash<(K, V)>>,
/// Key of the group currently being yielded; uninitialized until the first group has been started.
pub(crate) curr_key: MaybeUninit<&'a K>,
/// Bucket to yield on the next call, or `None` when a new group has to be started.
pub(crate) next_bucket: Option<Bucket<(K, V)>>,
}
impl<'a, K, V, S> IterGroupByKey<'a, K, V, S> {
    /// Creates a grouping iterator over `map`, with no bucket marked as seen and no key group
    /// started yet.
    pub fn new(map: &'a MashMap<K, V, S>) -> Self {
        let bucket_count = map.table.buckets();
        // SAFETY: the raw iterator carries no lifetime of its own; storing `map` (borrowed for
        // `'a`) in the same struct keeps the table borrowed for as long as the iterator exists.
        let main_iter = unsafe { map.table.iter() };
        Self {
            map,
            seen: vec![false; bucket_count],
            main_iter,
            // Both are written before being read, when the first group is started.
            probe_iter: MaybeUninit::uninit(),
            curr_key: MaybeUninit::uninit(),
            next_bucket: None,
        }
    }
}
// Yields the entries of the map one key group at a time: all entries whose key equals the
// current key are produced before the iterator moves on to a bucket it has not yielded yet.
impl<'a, K, V, S> Iterator for IterGroupByKey<'a, K, V, S>
where
K: Eq + Hash,
S: BuildHasher,
{
type Item = &'a (K, V);
fn next(&mut self) -> Option<Self::Item> {
// `next_bucket` is `None` before the first call and whenever the previous group ran out:
// in both cases a new group has to be started.
while self.next_bucket.is_none() {
// `?` ends the iteration once the full-table iterator is exhausted.
let mut bucket = self.main_iter.next()?;
let mut index = unsafe { self.map.table.bucket_index(&bucket) };
// Skip buckets that were already yielded as part of an earlier group.
while self.seen[index] {
bucket = self.main_iter.next()?;
index = unsafe { self.map.table.bucket_index(&bucket) };
}
// The key of this bucket starts a new group: remember it and probe the buckets
// matching its hash.
let key = unsafe { &bucket.as_ref().0 };
let hash = make_hash(self.map.hasher(), key);
self.curr_key = MaybeUninit::new(key);
self.probe_iter = MaybeUninit::new(unsafe { self.map.table.iter_hash(hash) });
// First bucket of the probe whose key equals the group's key.
self.next_bucket = unsafe {
self.probe_iter
.assume_init_mut()
.find(|bucket| likely(bucket.as_ref().0.borrow() == key))
};
}
// Look one match ahead in the current group. `probe_iter` and `curr_key` were written when
// the group was started, which is what the `assume_init_*` calls rely on.
let mut next_bucket = unsafe {
self.probe_iter.assume_init_mut().find(|bucket| {
likely(bucket.as_ref().0.borrow() == self.curr_key.assume_init_read())
})
};
// After the swap, the local holds the bucket to yield now and `self.next_bucket` holds the
// look-ahead (`None` if the group is exhausted).
swap(&mut next_bucket, &mut self.next_bucket);
// The loop above only exits with `self.next_bucket` set, so the swapped-out value is `Some`.
let bucket = unsafe { next_bucket.as_mut().unwrap_unchecked() };
let index = unsafe { self.map.table.bucket_index(bucket) };
// Mark the bucket so the full-table iterator does not start a group from it again.
self.seen[index] = true;
Some(unsafe { bucket.as_ref() })
}
}
/// Iterator over the entries of a [`MashMap`] that yields mutable references to the entries,
/// with entries sharing a key yielded consecutively.
// NOTE(review): `curr_key` keeps a shared reference to the key of an entry that is also handed
// out as `&mut (K, V)`; changing a key through the yielded reference looks unsupported — confirm.
pub struct IterMutGroupByKey<'a, K, V, S> {
/// Map being iterated. Stored as a shared reference even though `new` takes `&'a mut`.
pub(crate) map: &'a MashMap<K, V, S>,
/// One flag per bucket of the table, set once the entry in that bucket has been yielded.
pub(crate) seen: Vec<bool>,
/// Iterator over all buckets of the table, used to find the start of the next key group.
pub(crate) main_iter: RawIter<(K, V)>,
/// Iterator over the buckets matching the hash of the current key; uninitialized until the
/// first group has been started.
pub(crate) probe_iter: MaybeUninit<RawIterHash<(K, V)>>,
/// Key of the group currently being yielded; uninitialized until the first group has been started.
pub(crate) curr_key: MaybeUninit<&'a K>,
/// Bucket to yield on the next call, or `None` when a new group has to be started.
pub(crate) next_bucket: Option<Bucket<(K, V)>>,
}
impl<'a, K, V, S> IterMutGroupByKey<'a, K, V, S> {
    /// Creates a mutable grouping iterator over `map`, with no bucket marked as seen and no key
    /// group started yet.
    ///
    /// The exclusive borrow of `map` is held for `'a`, which is what allows the iterator to
    /// hand out `&'a mut` entries.
    pub fn new(map: &'a mut MashMap<K, V, S>) -> Self {
        let bucket_count = map.table.buckets();
        // SAFETY: the raw iterator carries no lifetime of its own; storing `map` (borrowed for
        // `'a`) in the same struct keeps the table borrowed for as long as the iterator exists.
        let main_iter = unsafe { map.table.iter() };
        Self {
            map,
            seen: vec![false; bucket_count],
            main_iter,
            // Both are written before being read, when the first group is started.
            probe_iter: MaybeUninit::uninit(),
            curr_key: MaybeUninit::uninit(),
            next_bucket: None,
        }
    }
}
// Yields mutable references to the entries of the map one key group at a time: all entries
// whose key equals the current key are produced before the iterator moves on to a bucket it has
// not yielded yet.
impl<'a, K, V, S> Iterator for IterMutGroupByKey<'a, K, V, S>
where
K: Eq + Hash,
S: BuildHasher,
{
type Item = &'a mut (K, V);
fn next(&mut self) -> Option<Self::Item> {
// `next_bucket` is `None` before the first call and whenever the previous group ran out:
// in both cases a new group has to be started.
while self.next_bucket.is_none() {
// `?` ends the iteration once the full-table iterator is exhausted.
let mut bucket = self.main_iter.next()?;
let mut index = unsafe { self.map.table.bucket_index(&bucket) };
// Skip buckets that were already yielded as part of an earlier group.
while self.seen[index] {
bucket = self.main_iter.next()?;
index = unsafe { self.map.table.bucket_index(&bucket) };
}
// The key of this bucket starts a new group: remember it and probe the buckets
// matching its hash.
let key = unsafe { &bucket.as_ref().0 };
let hash = make_hash(self.map.hasher(), key);
self.curr_key = MaybeUninit::new(key);
self.probe_iter = MaybeUninit::new(unsafe { self.map.table.iter_hash(hash) });
// First bucket of the probe whose key equals the group's key.
self.next_bucket = unsafe {
self.probe_iter
.assume_init_mut()
.find(|bucket| likely(bucket.as_ref().0.borrow() == key))
};
}
// Look one match ahead in the current group. `probe_iter` and `curr_key` were written when
// the group was started, which is what the `assume_init_*` calls rely on.
let mut next_bucket = unsafe {
self.probe_iter.assume_init_mut().find(|bucket| {
likely(bucket.as_ref().0.borrow() == self.curr_key.assume_init_read())
})
};
// After the swap, the local holds the bucket to yield now and `self.next_bucket` holds the
// look-ahead (`None` if the group is exhausted).
swap(&mut next_bucket, &mut self.next_bucket);
// The loop above only exits with `self.next_bucket` set, so the swapped-out value is `Some`.
let bucket = unsafe { next_bucket.as_mut().unwrap_unchecked() };
let index = unsafe { self.map.table.bucket_index(bucket) };
// Mark the bucket so the full-table iterator does not start a group from it again.
self.seen[index] = true;
// NOTE(review): the look-ahead above compares keys through `curr_key`, which may point into
// an entry already handed out mutably — confirm callers never modify keys.
Some(unsafe { bucket.as_mut() })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    /// Builds the fixture shared by the value-oriented tests:
    /// key 1 holds 10, 11 and 12; key 2 holds 20 and 21.
    fn sample_map() -> MashMap<usize, usize> {
        let mut map = MashMap::<usize, usize>::new();
        for (key, value) in [(1, 10), (1, 11), (1, 12), (2, 20), (2, 21)] {
            map.insert(key, value);
        }
        map
    }

    /// Collects the values stored under `key`, sorted so that assertions do not depend on the
    /// order in which the table yields them.
    fn sorted_values(map: &MashMap<usize, usize>, key: usize) -> Vec<usize> {
        let mut values: Vec<_> = map.get_iter(&key).copied().collect();
        values.sort_unstable();
        values
    }

    /// Fills a map with 100 000 random `u16` keys (fixed seed) and returns it together with the
    /// set of distinct keys that were inserted.
    fn random_key_map() -> (MashMap<u16, ()>, HashSet<u16>) {
        const N: usize = 100_000;
        let mut rng = StdRng::seed_from_u64(42);
        let mut map = MashMap::<u16, ()>::new();
        let mut inserted = HashSet::<u16>::new();
        for _ in 0..N {
            let x = rng.random::<u16>();
            map.insert(x, ());
            inserted.insert(x);
        }
        (map, inserted)
    }

    /// Values under one key can be mutated in place without touching other keys.
    #[test]
    fn test_map() {
        let mut map = sample_map();
        for val in map.get_mut_iter(&1) {
            *val += 1;
        }
        let values_1 = sorted_values(&map, 1);
        let values_2 = sorted_values(&map, 2);
        assert_eq!(values_1, vec![11, 12, 13]);
        assert_eq!(values_2, vec![20, 21]);
    }

    /// Only the values matching the predicate are drained.
    #[test]
    fn test_drain_key_if() {
        let mut map = sample_map();
        let v: Vec<usize> = map.drain_key_if(&1, |v| *v == 10).collect();
        assert_eq!(v, vec![10]);
        let values_1 = sorted_values(&map, 1);
        let values_2 = sorted_values(&map, 2);
        assert_eq!(values_1, vec![11, 12]);
        assert_eq!(values_2, vec![20, 21]);
    }

    /// Dropping the draining iterator without consuming it still removes the matching values.
    #[test]
    fn test_drain_key_if_dropped() {
        let mut map = sample_map();
        drop(map.drain_key_if(&1, |v| *v == 10));
        let values_1 = sorted_values(&map, 1);
        let values_2 = sorted_values(&map, 2);
        assert_eq!(values_1, vec![11, 12]);
        assert_eq!(values_2, vec![20, 21]);
    }

    /// Each call removes a single matching value, until none is left.
    #[test]
    fn test_remove_key_if() {
        let mut map = sample_map();
        // Two even values are stored under key 1; they may come out in either order.
        for _ in 0..2 {
            assert!(matches!(
                map.remove_key_if(&1, |v| *v % 2 == 0),
                Some(10) | Some(12)
            ));
        }
        assert_eq!(map.remove_key_if(&1, |v| *v % 2 == 0), None);
        let values_1 = sorted_values(&map, 1);
        let values_2 = sorted_values(&map, 2);
        assert_eq!(values_1, vec![11]);
        assert_eq!(values_2, vec![20, 21]);
    }

    /// After collapsing consecutive duplicates, every key shows up exactly once: equal keys are
    /// therefore yielded next to each other.
    #[test]
    fn test_iter_group_by_key() {
        let (map, inserted) = random_key_map();
        let mut seen = HashSet::<u16>::new();
        for (x, _) in map.iter_group_by_key().dedup() {
            assert!(!seen.contains(x));
            seen.insert(*x);
        }
        assert_eq!(inserted, seen);
    }

    /// Same grouping guarantee for the mutable iterator.
    #[test]
    fn test_iter_mut_group_by_key() {
        let (mut map, inserted) = random_key_map();
        let mut seen = HashSet::<u16>::new();
        for (x, _) in map.iter_mut_group_by_key().dedup() {
            assert!(!seen.contains(x));
            seen.insert(*x);
        }
        assert_eq!(inserted, seen);
    }
}