use core::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::num::NonZeroUsize;
use std::sync::Mutex;
/// Upper bound on the number of shards that [`shard_count`] will ever return.
pub(crate) const DEFAULT_SHARDS: usize = 16;
/// Sizing granularity used by [`shard_count`]: capacities below twice this
/// value get a single shard, and larger capacities get at most
/// `capacity / MIN_PER_SHARD` shards.
pub(crate) const MIN_PER_SHARD: usize = 16;
/// Chooses how many shards to use for a structure with the given total `capacity`.
///
/// The result is always a power of two. Capacities smaller than
/// `2 * MIN_PER_SHARD` yield a single shard; otherwise the result is the
/// largest power of two that exceeds neither `capacity / MIN_PER_SHARD` nor
/// `DEFAULT_SHARDS`.
pub(crate) fn shard_count(capacity: NonZeroUsize) -> usize {
    let total = capacity.get();

    // Not enough room to give two shards `MIN_PER_SHARD` slots each.
    if total < MIN_PER_SHARD.saturating_mul(2) {
        return 1;
    }

    // The shard count is bounded both by the capacity and by the global limit.
    let ceiling = (total / MIN_PER_SHARD).min(DEFAULT_SHARDS);

    // Walk 1, 2, 4, ... and keep the last power of two still under the ceiling.
    std::iter::successors(Some(1usize), |&candidate| candidate.checked_mul(2))
        .take_while(|&candidate| candidate <= ceiling)
        .last()
        .unwrap_or(1)
}
/// Splits a total `capacity` evenly across `num_shards`, rounding down.
///
/// The share is never zero: when the floored quotient would be `0`, the
/// smallest non-zero value (`1`) is returned instead.
///
/// # Panics
///
/// Panics if `num_shards` is zero, because the split is an integer division.
pub(crate) fn per_shard_capacity(capacity: NonZeroUsize, num_shards: usize) -> NonZeroUsize {
    match NonZeroUsize::new(capacity.get() / num_shards) {
        Some(share) => share,
        // More shards than capacity: clamp the share up to one slot.
        None => NonZeroUsize::MIN,
    }
}
/// A fixed collection of values, each guarded by its own [`Mutex`], with
/// hash-based routing from keys to shards.
pub(crate) struct Sharded<T> {
    /// The shards, in the order the factory produced them.
    shards: Box<[Mutex<T>]>,
    /// `num_shards - 1`; ANDed with a key's hash to pick a shard index.
    shard_mask: usize,
}

impl<T> Sharded<T> {
    /// Builds `num_shards` shards, calling `factory(i)` once per index in
    /// ascending order and wrapping each result in a [`Mutex`].
    ///
    /// `num_shards` is expected to be a power of two, since routing masks the
    /// hash with `num_shards - 1`. This is only checked when debug assertions
    /// are enabled.
    pub(crate) fn from_factory<F>(num_shards: usize, mut factory: F) -> Self
    where
        F: FnMut(usize) -> T,
    {
        debug_assert!(num_shards.is_power_of_two() && num_shards >= 1);

        let mut built = Vec::with_capacity(num_shards);
        for index in 0..num_shards {
            built.push(Mutex::new(factory(index)));
        }

        let shards = built.into_boxed_slice();
        let shard_mask = num_shards - 1;
        Self { shards, shard_mask }
    }

    /// Returns the shard responsible for `key`.
    ///
    /// The key is hashed with a freshly created [`DefaultHasher`] and the
    /// hash is masked down to a shard index, so equal keys always land on the
    /// same shard of a given `Sharded`.
    pub(crate) fn shard_for<K: Hash + ?Sized>(&self, key: &K) -> &Mutex<T> {
        let digest = {
            let mut state = DefaultHasher::new();
            key.hash(&mut state);
            state.finish()
        };
        // Only the low bits survive the mask, so narrowing to `usize` first is harmless.
        let slot = (digest as usize) & self.shard_mask;
        &self.shards[slot]
    }

    /// Iterates over every shard's mutex in index order.
    pub(crate) fn iter(&self) -> impl Iterator<Item = &Mutex<T>> {
        self.shards.iter()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: wraps `n` in a `NonZeroUsize`, substituting the minimum
    /// value when `n` is zero.
    fn nz(n: usize) -> NonZeroUsize {
        match NonZeroUsize::new(n) {
            Some(value) => value,
            None => NonZeroUsize::MIN,
        }
    }

    #[test]
    fn shard_count_tiny_caches_use_one_shard() {
        let tiny_capacities = [1usize, 2, 3, 4, 8, 16, 24, 31];
        for cap in tiny_capacities {
            let shards = shard_count(nz(cap));
            assert_eq!(
                shards,
                1,
                "capacity {cap} should give 1 shard (below MIN_PER_SHARD * 2 = 32)",
            );
        }
    }

    #[test]
    fn shard_count_scales_with_capacity() {
        // (capacity, expected shard count)
        let expectations = [
            (32usize, 2usize),
            (48, 2),
            (64, 4),
            (127, 4),
            (128, 8),
            (256, 16),
            (1_000_000, 16),
        ];
        for (cap, expected) in expectations {
            assert_eq!(shard_count(nz(cap)), expected);
        }
    }

    #[test]
    fn shard_count_is_always_power_of_two() {
        let capacities = [1usize, 16, 32, 33, 64, 100, 256, 1024, 65_536];
        for cap in capacities {
            let shards = shard_count(nz(cap));
            assert!(
                shards.is_power_of_two(),
                "shard_count({cap}) must be a power of two",
            );
        }
    }

    #[test]
    fn per_shard_capacity_distributes_evenly_when_divisible() {
        // (total capacity, shard count, expected share)
        let cases = [(64usize, 4usize, 16usize), (256, 16, 16), (1024, 16, 64)];
        for (total, shards, share) in cases {
            assert_eq!(per_shard_capacity(nz(total), shards).get(), share);
        }
    }

    #[test]
    fn per_shard_capacity_floors_when_not_divisible() {
        let cases = [(17usize, 16usize, 1usize), (100, 8, 12)];
        for (total, shards, share) in cases {
            assert_eq!(per_shard_capacity(nz(total), shards).get(), share);
        }
    }

    #[test]
    fn per_shard_capacity_never_returns_zero() {
        let share = per_shard_capacity(nz(1), 16);
        assert_eq!(share.get(), 1);
    }

    #[test]
    fn from_factory_creates_requested_number_of_shards() {
        let sharded: Sharded<usize> = Sharded::from_factory(4, |i| i * 10);
        assert_eq!(sharded.iter().count(), 4);

        // Read every shard, recovering the value even if a lock were poisoned.
        let mut values: Vec<usize> = Vec::new();
        for shard in sharded.iter() {
            let guard = shard.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            values.push(*guard);
        }
        assert_eq!(values, vec![0, 10, 20, 30]);
    }

    #[test]
    fn shard_for_routes_deterministically() {
        let sharded: Sharded<usize> = Sharded::from_factory(16, |_| 0);
        let key = "hello";
        let first: *const Mutex<usize> = sharded.shard_for(key);
        for _ in 0..32 {
            let again: *const Mutex<usize> = sharded.shard_for(key);
            assert_eq!(again, first);
        }
    }

    #[test]
    fn shard_for_distributes_keys_across_shards() {
        let sharded: Sharded<usize> = Sharded::from_factory(16, |_| 0);
        // Identify each shard by the address of its mutex.
        let distinct: std::collections::HashSet<usize> = (0..1024u32)
            .map(|i| sharded.shard_for(&i) as *const Mutex<usize> as usize)
            .collect();
        assert!(
            distinct.len() >= 8,
            "expected at least 8 distinct shards across 1024 keys, hit {}",
            distinct.len(),
        );
    }
}