use core::array::from_fn;
use core::hash::{Hash, Hasher};
#[cfg(feature = "std")]
use crate::StdClock;
use crate::{Clock, RateLimiter, Snapshot};
/// A fixed-size set of `N` [`RateLimiter`]s, with each call routed to one shard by a
/// caller-supplied 64-bit hash.
///
/// Every shard is a separate limiter built from the same `capacity`, `refill_per_sec`
/// and a clone of the same clock (see [`ShardedRateLimiter::with_clock`]).
///
/// With `N == 0` there is no shard to route to: the `allow_*` methods return `false`,
/// [`ShardedRateLimiter::remaining_by_hash`] returns `0`, and the lookups return `None`.
pub struct ShardedRateLimiter<C: Clock + Clone, const N: usize> {
    /// One limiter per shard, indexed by the value computed in `index_for_hash`.
    shards: [RateLimiter<C>; N],
}

impl<C: Clock + Clone, const N: usize> ShardedRateLimiter<C, N> {
    /// Builds `N` shards, each via `RateLimiter::with_clock(capacity, refill_per_sec, _)`
    /// with its own clone of `clock`.
    pub fn with_clock(capacity: u64, refill_per_sec: u64, clock: C) -> Self {
        Self {
            shards: from_fn(|_| RateLimiter::with_clock(capacity, refill_per_sec, clock.clone())),
        }
    }

    /// Returns the number of shards, i.e. the const parameter `N`.
    #[inline]
    pub const fn shard_count(&self) -> usize {
        N
    }

    /// Calls `RateLimiter::allow` on the shard selected by `hash` and returns its result.
    ///
    /// Returns `false` when there is no shard (`N == 0`).
    #[inline]
    pub fn allow_by_hash(&self, hash: u64) -> bool {
        self.shard_for_hash(hash).is_some_and(RateLimiter::allow)
    }

    /// Calls `RateLimiter::allow_n(n)` on the shard selected by `hash` and returns its
    /// result.
    ///
    /// Returns `false` when there is no shard (`N == 0`).
    #[inline]
    pub fn allow_n_by_hash(&self, hash: u64, n: u64) -> bool {
        self.shard_for_hash(hash)
            .is_some_and(|limiter| limiter.allow_n(n))
    }

    /// Returns `RateLimiter::remaining` for the shard selected by `hash`, or `0` when
    /// there is no shard (`N == 0`).
    #[inline]
    pub fn remaining_by_hash(&self, hash: u64) -> u64 {
        self.shard_for_hash(hash).map_or(0, RateLimiter::remaining)
    }

    /// Returns `RateLimiter::snapshot` for the shard selected by `hash`, or `None` when
    /// there is no shard (`N == 0`).
    #[inline]
    pub fn snapshot_by_hash(&self, hash: u64) -> Option<Snapshot> {
        self.shard_for_hash(hash).map(RateLimiter::snapshot)
    }

    /// Returns the shard at `index`, or `None` if `index >= N`.
    #[inline]
    pub fn shard(&self, index: usize) -> Option<&RateLimiter<C>> {
        self.shards.get(index)
    }

    /// Returns the shard that `hash` maps to (`hash % N`), or `None` when `N == 0`.
    #[inline]
    pub fn shard_for_hash(&self, hash: u64) -> Option<&RateLimiter<C>> {
        self.shard(self.index_for_hash(hash)?)
    }

    /// Maps `hash` to a shard index in `0..N`; `None` when `N == 0`.
    #[inline]
    fn index_for_hash(&self, hash: u64) -> Option<usize> {
        // Reduce in `u64` rather than casting the hash to `usize` first: on targets where
        // `usize` is narrower than 64 bits the cast would discard the upper hash bits and
        // make the mapping platform-dependent. For power-of-two `N` the remainder equals
        // `hash & (N - 1)`, so those shard assignments are unchanged.
        //
        // `checked_rem` yields `None` for a zero divisor, which is exactly the
        // "no shards" case. The remainder is `< N`, so casting it back to `usize`
        // cannot truncate.
        hash.checked_rem(N as u64).map(|index| index as usize)
    }

    /// Hashes `key` with a fresh `H::default()` hasher and returns `Hasher::finish`.
    ///
    /// `K` may be unsized, so `str` and slice keys can be passed directly. The result is
    /// suitable as the `hash` argument of the `*_by_hash` methods.
    #[inline]
    pub fn hash_key_with<H: Hasher + Default, K: Hash + ?Sized>(&self, key: &K) -> u64 {
        let mut hasher = H::default();
        key.hash(&mut hasher);
        hasher.finish()
    }
}
/// Convenience API available with the `std` feature: uses `StdClock` as the clock and
/// `std::collections::hash_map::DefaultHasher` to turn keys into shard hashes.
#[cfg(feature = "std")]
impl<const N: usize> ShardedRateLimiter<StdClock, N> {
    /// Builds the limiter via [`ShardedRateLimiter::with_clock`] with `StdClock`.
    #[inline]
    pub fn new(capacity: u64, refill_per_sec: u64) -> Self {
        Self::with_clock(capacity, refill_per_sec, StdClock)
    }

    /// Hashes `key` and forwards to [`ShardedRateLimiter::allow_by_hash`].
    ///
    /// `K` may be unsized, so `limiter.allow_by_key("user-1")` works without an extra
    /// reference.
    #[inline]
    pub fn allow_by_key<K: Hash + ?Sized>(&self, key: &K) -> bool {
        self.allow_by_hash(self.default_hash(key))
    }

    /// Hashes `key` and forwards to [`ShardedRateLimiter::allow_n_by_hash`] with `n`.
    #[inline]
    pub fn allow_n_by_key<K: Hash + ?Sized>(&self, key: &K, n: u64) -> bool {
        self.allow_n_by_hash(self.default_hash(key), n)
    }

    /// Hashes `key` and forwards to [`ShardedRateLimiter::remaining_by_hash`].
    #[inline]
    pub fn remaining_by_key<K: Hash + ?Sized>(&self, key: &K) -> u64 {
        self.remaining_by_hash(self.default_hash(key))
    }

    /// Hashes `key` and forwards to [`ShardedRateLimiter::snapshot_by_hash`].
    #[inline]
    pub fn snapshot_by_key<K: Hash + ?Sized>(&self, key: &K) -> Option<Snapshot> {
        self.snapshot_by_hash(self.default_hash(key))
    }

    /// Hashes `key` with a default-constructed `DefaultHasher`; shared by every
    /// `*_by_key` method so they all map a given key to the same shard.
    #[inline]
    fn default_hash<K: Hash + ?Sized>(&self, key: &K) -> u64 {
        // Hash through `&K` instead of `K`: a reference is always `Sized`, so this works
        // for unsized keys such as `str`, and the standard library's `Hash for &T`
        // forwards to `T`, so the resulting value is the same as hashing `K` directly.
        self.hash_key_with::<std::collections::hash_map::DefaultHasher, &K>(&key)
    }
}