#![forbid(unsafe_code)]
use hashbrown::raw::{RawIntoIter, RawTable};
use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};
use std::borrow::Borrow;
use std::convert::TryInto;
use std::hash::{BuildHasher, Hash};
use std::{fmt, fmt::Debug};
use std::collections::hash_map::RandomState;
const DEFAULT_SHARD_COUNT: usize = 128;
/// Builds a predicate that tests whether a stored `(key, value)` entry has key `k`.
///
/// The comparison is `K::eq(k, &entry.0)`, i.e. the probe key is always the left operand.
#[inline]
fn equivalent_key<K, V>(k: &K) -> impl Fn(&(K, V)) -> bool + '_
where
    K: Eq,
{
    move |entry: &(K, V)| *k == entry.0
}
/// Hashes `val` with a hasher produced by `hash_builder`.
///
/// Delegates to [`BuildHasher::hash_one`], so the result matches what `hash_builder` itself
/// would report for the same value.
#[inline]
fn make_hash<K, S>(hash_builder: &S, val: &K) -> u64
where
    K: Hash,
    S: BuildHasher,
{
    S::hash_one(hash_builder, val)
}
/// Returns a closure that hashes the key half of a stored `(key, value)` entry.
///
/// It is handed to `RawTable` operations that must recompute the hash of an existing entry.
#[inline]
fn make_hasher<K, V, S>(hash_builder: &S) -> impl Fn(&(K, V)) -> u64 + '_
where
    K: Hash,
    S: BuildHasher,
{
    move |(key, _): &(K, V)| make_hash::<K, S>(hash_builder, key)
}
/// A hash map split into `N` shards, each behind its own `parking_lot::RwLock`.
///
/// A key's hash (computed with `hash_builder`) selects the shard that stores it. Operations on
/// keys that land in different shards therefore take different locks. `S` defaults to
/// `RandomState` and `N` defaults to `DEFAULT_SHARD_COUNT`.
pub struct ConcurrentHashMap<K, V, S = RandomState, const N: usize = DEFAULT_SHARD_COUNT> {
/// Hasher used to compute key hashes, which route each key to a shard.
hash_builder: S,
/// The shards. Each lock guards one `Shard` (a raw table plus its own copy of the hasher).
shards: [RwLock<Shard<K, V, S>>; N],
}
impl<K, V> ConcurrentHashMap<K, V, RandomState, DEFAULT_SHARD_COUNT> {
#[must_use]
pub fn new() -> ConcurrentHashMap<K, V, RandomState> {
Default::default()
}
#[inline]
#[must_use]
pub fn with_capacity(
capacity: usize,
) -> ConcurrentHashMap<K, V, RandomState, DEFAULT_SHARD_COUNT> {
ConcurrentHashMap::<_, _, _, DEFAULT_SHARD_COUNT>::with_capacity_and_hasher(
capacity,
RandomState::default(),
)
}
}
impl<K, V, S: BuildHasher, const N: usize> ConcurrentHashMap<K, V, S, N> {
    /// Creates an empty map with `N` shards that hashes keys with `hash_builder`.
    ///
    /// # Panics
    ///
    /// Panics if `N` is zero.
    #[inline]
    pub fn with_hasher(hash_builder: S) -> ConcurrentHashMap<K, V, S, N>
    where
        S: Clone,
    {
        ConcurrentHashMap::<_, _, _, N>::with_capacity_and_hasher(0, hash_builder)
    }

    /// Creates an empty map with `N` shards, pre-sizing each shard for an equal share of
    /// `capacity` (rounded up).
    ///
    /// Every shard receives its own clone of `hash_builder`. The original value is kept on
    /// the map and used to route keys to shards.
    ///
    /// # Panics
    ///
    /// Panics if `N` is zero.
    pub fn with_capacity_and_hasher(
        capacity: usize,
        hash_builder: S,
    ) -> ConcurrentHashMap<K, V, S, N>
    where
        S: Clone,
    {
        assert!(N > 0, "number of shards must be > 0");
        // Ceiling division written so it cannot overflow when `capacity` is near `usize::MAX`.
        let per_shard = capacity / N + usize::from(capacity % N != 0);
        // Build exactly `N` shards. The shard count used to be hard-coded to
        // DEFAULT_SHARD_COUNT, which made every other `N` panic.
        let shards = std::array::from_fn(|_| {
            RwLock::new(Shard {
                inner: RawTable::with_capacity(per_shard),
                hash_builder: hash_builder.clone(),
            })
        });
        ConcurrentHashMap {
            hash_builder,
            shards,
        }
    }

    /// Returns the total capacity of all shards, i.e. the sum of each shard's capacity.
    ///
    /// Keys are spread over the shards by hash. An individual shard may therefore have to
    /// grow before the map as a whole holds this many entries.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shards
            .iter()
            .map(|lock| lock.read_recursive().inner.capacity())
            .sum()
    }

    /// Returns a read guard for the value stored under `key`, or `None` if it is absent.
    ///
    /// The guard keeps the key's shard read-locked until it is dropped. The shard is locked
    /// with `read_recursive`, so a thread that already holds guards from this map can call
    /// `get` again without deadlocking against a queued writer. The trade-off is that a
    /// continuous stream of readers can delay writers to that shard. Calling
    /// [`insert`](Self::insert) for a key in a shard this thread still holds a guard on
    /// will block forever.
    #[inline]
    pub fn get<'a>(&'a self, key: &'a K) -> Option<MappedRwLockReadGuard<'_, V>>
    where
        K: Hash + Eq,
    {
        let hash = make_hash::<K, _>(&self.hash_builder, key);
        let shard = self.shards[Self::shard_index(hash)].read_recursive();
        RwLockReadGuard::try_map(shard, |shard| shard.get(hash, key)).ok()
    }

    /// Stores `v` under `k`, returning the value previously stored there, if any.
    ///
    /// Blocks until the key's shard can be write-locked.
    #[inline]
    pub fn insert(&self, k: K, v: V) -> Option<V>
    where
        K: Hash + Eq,
    {
        let hash = make_hash::<K, _>(&self.hash_builder, &k);
        self.shards[Self::shard_index(hash)].write().insert(hash, k, v)
    }

    /// Maps a key hash to the index of the shard that owns it.
    ///
    /// hashbrown derives a bucket's probe position from the low bits of the hash and a
    /// per-entry tag from the top 7 bits. Taking the shard from the low bits (as a plain
    /// `hash % N` does) would give every key in a shard the same low bits, clustering their
    /// probe sequences. Bits 32 and up are used instead.
    #[inline]
    fn shard_index(hash: u64) -> usize {
        // The remainder is < N, so converting back to usize is lossless.
        ((hash >> 32) % (N as u64)) as usize
    }
}
impl<K, V, S, const N: usize> Default for ConcurrentHashMap<K, V, S, N>
where
S: Default + BuildHasher + Clone,
{
#[inline]
fn default() -> ConcurrentHashMap<K, V, S, N> {
if N == 0 {
panic!("number of shards must be > 0")
}
ConcurrentHashMap::<K, V, S, N>::with_hasher(Default::default())
}
}
/// Owning iterator over `(key, value)` pairs, draining one shard at a time.
///
/// NOTE(review): `shards` is `Vec<Shard<K, V>>`, i.e. it uses the default `RandomState`
/// hasher parameter, and no constructor for this type is visible here. Confirm that it is
/// only meant for `RandomState` maps.
pub struct IntoIter<K: 'static, V: 'static> {
/// Iterator over the entries of the shard currently being drained.
iter: RawIntoIter<(K, V)>,
/// Shards not yet started; `next` takes them from the back once `iter` is exhausted.
shards: Vec<Shard<K, V>>,
}
/// Owning iterator over a map's values; wraps [`IntoIter`] and discards the keys.
pub struct IntoValues<K: 'static, V: 'static> {
/// Underlying iterator over `(key, value)` pairs.
iter: IntoIter<K, V>,
}
impl<K, V> Iterator for IntoIter<K, V> {
    type Item = (K, V);

    /// Yields the next entry, moving to the next pending shard whenever the current one
    /// runs out. Shards are taken from the back of `shards`.
    #[inline]
    fn next(&mut self) -> Option<(K, V)> {
        // Loop instead of recursing, so a run of empty shards does not cost one stack frame
        // per shard.
        loop {
            if let Some(item) = self.iter.next() {
                return Some(item);
            }
            let shard = self.shards.pop()?;
            self.iter = shard.inner.into_iter();
        }
    }

    /// Bounds on the remaining entries.
    ///
    /// The result combines the current shard iterator's bounds with the exact entry count
    /// of every shard not yet opened.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending: usize = self.shards.iter().map(|shard| shard.inner.len()).sum();
        let (lower, upper) = self.iter.size_hint();
        (
            lower.saturating_add(pending),
            upper.and_then(|u| u.checked_add(pending)),
        )
    }
}
impl<K, V> Iterator for IntoValues<K, V> {
    type Item = V;

    /// Yields the value of the next entry, discarding its key.
    #[inline]
    fn next(&mut self) -> Option<V> {
        self.iter.next().map(|(_, v)| v)
    }

    /// Forwards the entry iterator's bounds unchanged, upper bound included. Dropping the
    /// key does not change how many items remain.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// One shard of a [`ConcurrentHashMap`]: a raw hashbrown table plus a hasher.
///
/// `with_capacity_and_hasher` gives every shard a clone of the map's hasher.
#[derive(Clone)]
pub(crate) struct Shard<K, V, S = RandomState> {
/// Hasher wrapped by `make_hasher` and passed to `RawTable::insert`, which uses it to
/// rehash existing entries.
hash_builder: S,
/// The entries, stored as `(key, value)` tuples.
inner: RawTable<(K, V)>,
}
/// Debug output for a shard, available for every hasher type `S`.
///
/// Only the entry count is shown. Listing the entries would require hashbrown's `unsafe`
/// raw-iteration API, which `#![forbid(unsafe_code)]` rules out in this file, so `K` and
/// `V` do not need to implement `Debug`.
impl<K, V, S> Debug for Shard<K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("shard")
            .field("len", &self.inner.len())
            .finish_non_exhaustive()
    }
}
// Some of these helpers (e.g. `remove`, `get_mut`) are not called anywhere in this file;
// the allow keeps them available without dead-code warnings.
#[allow(dead_code)]
impl<K, V, S> Shard<K, V, S>
where
    S: BuildHasher,
{
    /// Number of entries stored in this shard.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the shard holds no entries.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes the entry for `key` (whose hash is `hash`) and returns its value, if present.
    #[inline]
    pub(crate) fn remove(&mut self, hash: u64, key: K) -> Option<V>
    where
        K: Hash + Eq,
    {
        self.inner
            .remove_entry(hash, equivalent_key(&key))
            .map(|(_, v)| v)
    }

    /// Returns a mutable reference to the value stored for `key`, if present.
    #[inline]
    pub(crate) fn get_mut(&mut self, hash: u64, key: &K) -> Option<&mut V>
    where
        K: Hash + Eq,
    {
        self.inner
            .get_mut(hash, equivalent_key(key))
            .map(|(_, v)| v)
    }

    /// Stores `v` under `key`, returning the replaced value if the key was already present.
    ///
    /// `hash` must be the hash of `key` under this shard's hasher. New entries go through
    /// `RawTable::insert`, which receives a rehash closure built from `self.hash_builder`.
    #[inline]
    pub(crate) fn insert(&mut self, hash: u64, key: K, v: V) -> Option<V>
    where
        K: Hash + Eq,
    {
        if let Some((_, slot)) = self.inner.get_mut(hash, equivalent_key(&key)) {
            return Some(std::mem::replace(slot, v));
        }
        self.inner
            .insert(hash, (key, v), make_hasher::<K, V, S>(&self.hash_builder));
        None
    }

    /// Returns a shared reference to the value stored for `key`, if present.
    #[inline]
    pub(crate) fn get(&self, hash: u64, key: &K) -> Option<&V>
    where
        K: Hash + Eq,
    {
        self.inner.get(hash, equivalent_key(key)).map(|(_, v)| v)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    #[test]
    fn test_insert_values() {
        let map = ConcurrentHashMap::new();
        assert_eq!(map.insert("k", "v"), None);
        assert_eq!(*map.get(&"k").unwrap(), "v");
        // Re-inserting an existing key hands back the value it replaced.
        assert_eq!(map.insert("k", "w"), Some("v"));
        assert_eq!(*map.get(&"k").unwrap(), "w");
        assert!(map.get(&"missing").is_none());
    }

    /// Stress test: one thread keeps overwriting every key while another holds read guards
    /// for up to 100 keys at once.
    ///
    /// The workers are deliberately not joined, because a deadlocked worker would make
    /// `join` hang forever. The test therefore fails only if a read observes a missing key
    /// or an unexpected value. It cannot by itself detect a deadlock. Both workers poll
    /// `stop`, so unless they are stuck they exit instead of spinning for the rest of the
    /// test run.
    #[test]
    fn test_other_deadlock() {
        const KEYS: i32 = 1000;
        let map_1 = Arc::new(ConcurrentHashMap::<i32, String>::default());
        let map_2 = Arc::clone(&map_1);
        for i in 0..KEYS {
            map_1.insert(i, "foobar".to_string());
        }
        // Set by the main thread to ask the workers to finish.
        let stop = Arc::new(AtomicBool::new(false));
        // Set by the reader if a lookup misses or sees a value that was never written.
        let bad_read = Arc::new(AtomicBool::new(false));

        let writer_stop = Arc::clone(&stop);
        let _writer = std::thread::spawn(move || {
            while !writer_stop.load(Ordering::SeqCst) {
                for i in 0..KEYS {
                    map_1.insert(i, "foobaz".to_string());
                }
            }
        });

        let reader_stop = Arc::clone(&stop);
        let reader_bad = Arc::clone(&bad_read);
        let _reader = std::thread::spawn(move || {
            while !reader_stop.load(Ordering::SeqCst) {
                for i in 0..KEYS {
                    let j = i32::min(i + 100, KEYS);
                    // The guards borrow from `keys`, so it must outlive them.
                    let keys: Vec<i32> = (i..j).collect();
                    let guards: Vec<_> = keys.iter().map(|k| map_2.get(k)).collect();
                    let all_ok = guards.iter().all(|g| {
                        matches!(
                            g.as_deref().map(String::as_str),
                            Some("foobar" | "foobaz")
                        )
                    });
                    if !all_ok {
                        reader_bad.store(true, Ordering::SeqCst);
                    }
                }
            }
        });

        // Long enough for many writer/reader overlaps without slowing the suite for 10 s.
        std::thread::sleep(Duration::from_secs(2));
        stop.store(true, Ordering::SeqCst);
        assert!(
            !bad_read.load(Ordering::SeqCst),
            "a pre-inserted key was missing or held an unexpected value"
        );
    }
}