use core::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::num::NonZeroUsize;
use std::sync::Mutex;
/// Upper bound on the number of shards [`shard_count`] will return.
pub(crate) const DEFAULT_SHARDS: usize = 16;
/// Minimum share of the total capacity each shard must receive before another shard is added.
pub(crate) const MIN_PER_SHARD: usize = 16;

/// Chooses how many shards to split `capacity` across.
///
/// The result is always a power of two (so it can be used as a hash mask), is at
/// most [`DEFAULT_SHARDS`], and is small enough that each shard keeps at least
/// [`MIN_PER_SHARD`] of the capacity. Capacities too small for two such shards
/// yield a single shard.
pub(crate) fn shard_count(capacity: NonZeroUsize) -> usize {
    // Number of shards the capacity can support at MIN_PER_SHARD each, capped.
    let limit = (capacity.get() / MIN_PER_SHARD).min(DEFAULT_SHARDS);
    match limit {
        // Not enough capacity for two shards: don't split at all.
        0 | 1 => 1,
        // Round down to the nearest power of two.
        n => 1 << n.ilog2(),
    }
}
/// Splits `capacity` evenly across `num_shards`, rounding down, with a floor of 1.
///
/// Because of the round-down, the shards' combined capacity may be slightly below
/// `capacity` (e.g. 33 over 2 shards gives 16 each).
///
/// # Panics
///
/// Panics (integer division by zero) if `num_shards` is 0.
pub(crate) fn per_shard_capacity(capacity: NonZeroUsize, num_shards: usize) -> NonZeroUsize {
    // A zero quotient (more shards than capacity) falls back to the minimum of 1.
    NonZeroUsize::new(capacity.get() / num_shards).unwrap_or(NonZeroUsize::MIN)
}
/// A fixed set of independently locked shards of `T`, selected by hashing a key.
///
/// A key is routed by hashing it with [`DefaultHasher`] and masking the low bits
/// of the hash, which is why the shard count must be a non-zero power of two.
pub(crate) struct Sharded<T> {
    /// The shards; the length is always a non-zero power of two.
    shards: Box<[Mutex<T>]>,
    /// `shards.len() - 1`; ANDed with a hash to produce a shard index.
    shard_mask: usize,
}

impl<T> Sharded<T> {
    /// Builds `num_shards` shards, calling `factory(i)` for each index `i` in `0..num_shards`, in order.
    ///
    /// # Panics
    ///
    /// Panics if `num_shards` is zero or not a power of two.
    pub(crate) fn from_factory<F>(num_shards: usize, mut factory: F) -> Self
    where
        F: FnMut(usize) -> T,
    {
        // Checked in release builds too (not `debug_assert!`). The check runs once per
        // construction, and a violation would otherwise either wrap `shard_mask` to
        // `usize::MAX` (zero shards -> out-of-bounds index on the first lookup) or leave
        // some shards permanently unselectable (non-power-of-two count -> silent skew).
        // `is_power_of_two()` is false for 0, so this also rejects an empty shard set.
        assert!(
            num_shards.is_power_of_two(),
            "num_shards must be a non-zero power of two, got {num_shards}"
        );
        let shards: Box<[Mutex<T>]> = (0..num_shards).map(|i| Mutex::new(factory(i))).collect();
        Self {
            shards,
            shard_mask: num_shards - 1,
        }
    }

    /// Returns the shard responsible for `key`.
    ///
    /// `DefaultHasher::new()` always starts from the same state, so a given key maps
    /// to the same shard on every call.
    pub(crate) fn shard_for<K: Hash + ?Sized>(&self, key: &K) -> &Mutex<T> {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        // Truncating the u64 hash on 32-bit targets is harmless: only the low bits
        // selected by the mask are used.
        let idx = (hasher.finish() as usize) & self.shard_mask;
        &self.shards[idx]
    }

    /// Iterates over every shard in index order.
    pub(crate) fn iter(&self) -> impl Iterator<Item = &Mutex<T>> {
        self.shards.iter()
    }
}