use std::{
hash::{BuildHasher, RandomState},
sync::{Arc, OnceLock},
};
use crossbeam_utils::CachePadded;
use hashbrown::hash_table::Entry;
use crate::{
mapref::{MapRef, MapRefMut},
shard::Shard,
};
/// Shared state behind a `ShardMap` handle.
struct Inner<K, V, S = RandomState> {
    /// The shards; each one is wrapped in `CachePadded` so neighbouring shards
    /// do not share a cache line.
    shards: Box<[CachePadded<Shard<K, V>>]>,
    /// Right-shift used to turn a hash into a shard index: the bit width of
    /// the shifted value minus log2 of the shard count.
    shift: usize,
    /// Hash builder used both to pick a shard and to hash entries inside it.
    hasher: S,
}
impl<K, V, S> std::ops::Deref for Inner<K, V, S> {
type Target = Box<[CachePadded<Shard<K, V>>]>;
fn deref(&self) -> &Self::Target {
&self.shards
}
}
impl<K, V, S> std::ops::DerefMut for Inner<K, V, S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.shards
}
}
pub struct ShardMap<K, V, S = std::hash::RandomState> {
inner: Arc<Inner<K, V, S>>,
}
impl<K, V, H> Clone for ShardMap<K, V, H> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Computes the default shard count: four shards per unit of available
/// parallelism, or 4 when parallelism cannot be queried. The count is
/// rounded up to a power of two, as the shard-index math requires.
#[inline(always)]
fn calculate_shard_count() -> usize {
    let parallelism = match std::thread::available_parallelism() {
        Ok(units) => units.get(),
        Err(_) => 1,
    };
    let wanted = parallelism * 4;
    wanted.next_power_of_two()
}
/// Returns the process-wide default shard count.
///
/// It is computed by `calculate_shard_count` on first use and cached in a
/// `OnceLock` for all later calls.
#[inline(always)]
fn shard_count() -> usize {
    static CACHED: OnceLock<usize> = OnceLock::new();
    let count: &usize = CACHED.get_or_init(calculate_shard_count);
    *count
}
impl<K, V> ShardMap<K, V, RandomState>
where
K: Eq + std::hash::Hash + 'static,
V: 'static,
{
pub fn new() -> Self {
Self::with_shards(shard_count())
}
pub fn with_shards(shards: usize) -> Self {
Self::with_shards_and_hasher(shards, RandomState::new())
}
pub fn with_capacity(capacity: usize) -> Self {
Self::with_capacity_and_hasher(capacity, RandomState::new())
}
pub fn with_shards_and_capacity(shards: usize, cap: usize) -> Self {
Self::with_shards_and_capacity_and_hasher(shards, cap, RandomState::new())
}
}
/// Returns the bit width of `usize`, the type the shard-index computation
/// shifts.
///
/// `usize` is pointer-sized, so this is also the pointer width; using
/// `usize::BITS` ties the value to the type actually being shifted.
const fn ptr_size_bits() -> usize {
    usize::BITS as usize
}
impl<K, V, S: BuildHasher> ShardMap<K, V, S>
where
    K: Eq + std::hash::Hash + 'static,
    V: 'static,
{
    /// Creates an empty map with the default shard count and the given hasher.
    pub fn with_hasher(hasher: S) -> Self {
        Self::with_shards_and_hasher(shard_count(), hasher)
    }

    /// Creates an empty map with the default shard count, a capacity hint of
    /// `cap` entries split evenly across shards, and the given hasher.
    pub fn with_capacity_and_hasher(cap: usize, hasher: S) -> Self {
        Self::with_shards_and_capacity_and_hasher(shard_count(), cap, hasher)
    }

    /// Creates an empty map with `shards` shards, a total capacity hint of 4,
    /// and the given hasher.
    ///
    /// # Panics
    ///
    /// Panics if `shards` is not a power of two greater than 1.
    pub fn with_shards_and_hasher(shards: usize, hasher: S) -> Self {
        Self::with_shards_and_capacity_and_hasher(shards, 4, hasher)
    }

    /// Creates an empty map with `shards` shards, a capacity hint of `cap`
    /// entries split evenly (rounded up) across the shards, and the given
    /// hasher.
    ///
    /// # Panics
    ///
    /// Panics if `shards` is not a power of two greater than 1.
    pub fn with_shards_and_capacity_and_hasher(shards: usize, cap: usize, hasher: S) -> Self {
        // These are soundness requirements, not sanity checks. `shard()` indexes
        // with `get_unchecked`, which is only in bounds when the shard count is
        // a power of two > 1. (With 1 shard the shift equals the full bit width
        // and wraps; with 0 it is 0.) `debug_assert!` would let release builds
        // reach UB.
        assert!(shards > 1, "ShardMap: shard count must be greater than 1");
        assert!(
            shards.is_power_of_two(),
            "ShardMap: shard count must be a power of two"
        );

        // Shifting right by this keeps the top log2(shards) bits.
        let shift = ptr_size_bits() - shards.trailing_zeros() as usize;
        // Equivalent to rounding `cap` up to a multiple of `shards` and then
        // dividing, but it cannot overflow for very large `cap`.
        let shard_capacity = cap.div_ceil(shards);
        let shards = (0..shards)
            .map(|_| CachePadded::new(Shard::with_capacity(shard_capacity)))
            .collect();

        Self {
            inner: Arc::new(Inner {
                shards,
                shift,
                hasher,
            }),
        }
    }

    /// Maps a hash to a shard index in `0..shards`.
    ///
    /// The top 7 bits are dropped first, presumably because hashbrown uses
    /// them as the per-bucket control tag inside each shard. That way shard
    /// selection draws on different bits. The next log2(shards) bits become
    /// the index.
    #[inline]
    fn shard_for_hash(&self, hash: usize) -> usize {
        (hash << 7) >> self.inner.shift
    }

    /// Hashes `key` once and returns its shard together with the full hash.
    ///
    /// Callers reuse the hash for the lookup inside the shard.
    #[inline]
    fn shard(&self, key: &K) -> (&CachePadded<Shard<K, V>>, u64) {
        let hash = self.inner.hasher.hash_one(key);
        let idx = self.shard_for_hash(hash as usize);
        // SAFETY: the constructor asserts the shard count is `2^k` with
        // `k >= 1` and sets `shift = W - k`, where `W = ptr_size_bits()` is the
        // width of `usize` (`usize` is pointer-sized). Shifting a `usize` right
        // by `W - k` leaves at most `k` bits, so `idx < 2^k == shards.len()`.
        let shard = unsafe { self.inner.shards.get_unchecked(idx) };
        (shard, hash)
    }

    /// Inserts `key -> value` and returns the previous value for an equal key,
    /// if any.
    ///
    /// When an equal key already exists, the whole `(key, value)` pair is
    /// replaced, so the stored key becomes `key`. `std::collections::HashMap::insert`
    /// instead keeps the old key.
    pub async fn insert(&self, key: K, value: V) -> Option<V> {
        let (shard, hash) = self.shard(&key);
        let mut writer = shard.write().await;
        let (old, slot) = match writer.entry(
            hash,
            |(k, _)| k == &key,
            |(k, _)| self.inner.hasher.hash_one(k),
        ) {
            Entry::Occupied(entry) => {
                // Take the old pair out and reuse the freed slot for the new one.
                let ((_, old), slot) = entry.remove();
                (Some(old), slot)
            }
            Entry::Vacant(slot) => (None, slot),
        };
        slot.insert((key, value));
        old
    }

    /// Looks up `key` and returns a reference to its entry, or `None`.
    ///
    /// The shard's read guard is handed to the returned [`MapRef`].
    pub async fn get<'a>(&'a self, key: &'a K) -> Option<MapRef<'a, K, V>> {
        let (shard, hash) = self.shard(key);
        let reader = shard.read().await;
        let (k, v) = reader.find(hash, |(k, _)| k == key)?;
        let (k, v) = (k as *const K, v as *const V);
        // SAFETY: `k`/`v` point into the shard's table. The read guard that
        // prevents the table from being mutated is moved into the returned
        // `MapRef`, so the pointers stay valid while it lives. This assumes
        // `MapRef` keeps the guard alive for `'a` — confirm in `mapref`.
        Some(unsafe { MapRef::new(reader, &*k, &*v) })
    }

    /// Looks up `key` and returns a mutable reference to its entry, or `None`.
    ///
    /// The shard's write guard is handed to the returned [`MapRefMut`].
    pub async fn get_mut<'a>(&'a self, key: &'a K) -> Option<MapRefMut<'a, K, V>> {
        let (shard, hash) = self.shard(key);
        let mut writer = shard.write().await;
        let (k, v) = writer.find_mut(hash, |(k, _)| k == key)?;
        let (k, v) = (k as *const K, v as *mut V);
        // SAFETY: `k`/`v` point into the shard's table, and the exclusive
        // write guard is moved into the returned `MapRefMut`, so nothing else
        // can touch the entry while it lives. This assumes `MapRefMut` keeps
        // the guard alive for `'a` and does not use it to restructure the
        // table — confirm in `mapref`.
        Some(unsafe { MapRefMut::new(writer, &*k, &mut *v) })
    }

    /// Returns `true` if an entry equal to `key` is present.
    pub async fn contains_key(&self, key: &K) -> bool {
        let (shard, hash) = self.shard(key);
        let reader = shard.read().await;
        reader.find(hash, |(k, _)| k == key).is_some()
    }

    /// Removes the entry equal to `key` and returns its value, or `None` if
    /// absent.
    pub async fn remove(&self, key: &K) -> Option<V> {
        let (shard, hash) = self.shard(key);
        let mut writer = shard.write().await;
        match writer.find_entry(hash, |(k, _)| k == key) {
            Ok(occupied) => {
                let ((_, value), _) = occupied.remove();
                Some(value)
            }
            Err(_) => None,
        }
    }

    /// Returns the total number of entries.
    ///
    /// Shards are read-locked one at a time, so under concurrent writes the
    /// result is not an atomic snapshot of the whole map.
    pub async fn len(&self) -> usize {
        let mut total = 0;
        for shard in self.inner.shards.iter() {
            total += shard.read().await.len();
        }
        total
    }

    /// Returns `true` if no shard holds any entry.
    ///
    /// Stops at the first non-empty shard instead of locking every shard to
    /// compute the full length. Like [`Self::len`], this is not an atomic
    /// snapshot under concurrent writes.
    pub async fn is_empty(&self) -> bool {
        for shard in self.inner.shards.iter() {
            if !shard.read().await.is_empty() {
                return false;
            }
        }
        true
    }

    /// Removes every entry, write-locking the shards one at a time.
    pub async fn clear(&self) {
        for shard in self.inner.shards.iter() {
            shard.write().await.clear();
        }
    }
}