use std::fmt;
use std::num::NonZeroUsize;
use std::time::Duration;
use clock_lib::{Clock, Monotonic, SystemClock};
use crate::algo::AlgoState;
use crate::algorithm::Algorithm;
use crate::decision::Decision;
use crate::eviction::Eviction;
use crate::key::Key;
use crate::quota::Quota;
use crate::store::Store;
/// Default number of store shards for a new limiter.
///
/// Uses four shards per unit of available parallelism (falling back to a
/// parallelism of 1 when it cannot be queried), rounded up to a power of two
/// and capped at 4096.
///
/// The cap is applied *before* rounding: `next_power_of_two` overflows
/// (panicking in debug builds, returning 0 in release) for inputs above the
/// largest power of two, which `saturating_mul` could otherwise produce.
pub(crate) fn default_shard_count() -> usize {
    // Must itself be a power of two so rounding never exceeds it.
    const MAX_SHARDS: usize = 4096;
    let parallelism = std::thread::available_parallelism()
        .map(NonZeroUsize::get)
        .unwrap_or(1);
    // `next_power_of_two` maps 0 to 1, so the result is always within 1..=MAX_SHARDS.
    parallelism
        .saturating_mul(4)
        .min(MAX_SHARDS)
        .next_power_of_two()
}
/// Common interface for anything that admits or rejects work per key.
///
/// Implementors provide [`Limiter::check_n`]; [`Limiter::check`] is a
/// convenience wrapper that asks for a single unit.
pub trait Limiter {
    /// Asks whether `n` units of quota may be taken for `key`.
    fn check_n(&self, key: impl Into<Key>, n: u32) -> Decision;

    /// Asks whether exactly one unit of quota may be taken for `key`.
    fn check(&self, key: impl Into<Key>) -> Decision {
        const SINGLE_UNIT: u32 = 1;
        Self::check_n(self, key, SINGLE_UNIT)
    }
}
/// A keyed rate limiter that keeps an independent admission state per key.
///
/// `C` is the time source. It defaults to `SystemClock`; a different clock can
/// be substituted with `with_clock`. The `with_*` configuration methods rebuild
/// the limiter from scratch.
pub struct RateLimiter<C: Clock + Clone = SystemClock> {
    /// Algorithm used to create each key's state.
    algorithm: Algorithm,
    /// Quota handed to every newly created per-key state.
    quota: Quota,
    /// Time source; a clone is given to each new per-key state.
    clock: C,
    /// Clock reading taken at construction; elapsed time is measured from here.
    epoch: Monotonic,
    /// Shard count as requested at construction. It is reused when the limiter
    /// is rebuilt; the effective count is whatever the store reports.
    shards: usize,
    /// Eviction policy the store was created with.
    eviction: Eviction,
    /// Per-key state storage.
    store: Store<C>,
    /// Whether checks must read the clock. False only for the token bucket
    /// without an idle TTL, in which case `Duration::ZERO` is passed instead.
    reads_clock: bool,
}
impl RateLimiter<SystemClock> {
    /// Creates a system-clock limiter built from `Quota::per_second(limit)`,
    /// using the default algorithm, shard count, and eviction policy.
    #[must_use]
    pub fn per_second(limit: u32) -> Self {
        Self::with_quota(Quota::per_second(limit))
    }

    /// Creates a system-clock limiter built from `Quota::per_minute(limit)`,
    /// using the default algorithm, shard count, and eviction policy.
    #[must_use]
    pub fn per_minute(limit: u32) -> Self {
        Self::with_quota(Quota::per_minute(limit))
    }

    /// Creates a system-clock limiter enforcing `quota`.
    ///
    /// Uses `Algorithm::default()`, `default_shard_count()` shards, and
    /// `Eviction::default()`.
    #[must_use]
    pub fn with_quota(quota: Quota) -> Self {
        Self::build(
            Algorithm::default(),
            quota,
            SystemClock::new(),
            default_shard_count(),
            Eviction::default(),
        )
    }

    /// Returns a builder for configuring a system-clock limiter step by step.
    ///
    /// Marked `#[must_use]` like the other constructors: discarding the
    /// builder is always a mistake.
    #[must_use]
    pub fn builder() -> crate::builder::Builder<SystemClock> {
        crate::builder::Builder::new()
    }
}
impl<C: Clock + Clone> RateLimiter<C> {
    /// Assembles a limiter from its configuration parts.
    ///
    /// Records the current clock reading as the epoch that elapsed time is
    /// measured from, and creates an empty store with `shards` shards and the
    /// given eviction policy.
    pub(crate) fn build(
        algorithm: Algorithm,
        quota: Quota,
        clock: C,
        shards: usize,
        eviction: Eviction,
    ) -> Self {
        let epoch = clock.now();
        let store = Store::new(shards, eviction);
        // Each per-key state receives its own clock clone (see `new_state`), so
        // for the token bucket without an idle TTL the limiter skips its own
        // clock read and passes `Duration::ZERO`. Presumably the token bucket
        // reads its clone directly — confirm in `AlgoState`.
        let reads_clock = algorithm != Algorithm::TokenBucket || eviction.idle_ttl().is_some();
        Self {
            algorithm,
            quota,
            clock,
            epoch,
            shards,
            eviction,
            store,
            reads_clock,
        }
    }

    /// Returns an equivalently configured limiter driven by `clock`.
    ///
    /// Algorithm, quota, requested shard count, and eviction policy are kept.
    /// A fresh store and epoch are created, so all tracked keys are forgotten.
    #[must_use]
    pub fn with_clock<C2: Clock + Clone>(self, clock: C2) -> RateLimiter<C2> {
        RateLimiter::build(
            self.algorithm,
            self.quota,
            clock,
            self.shards,
            self.eviction,
        )
    }

    /// Returns a limiter using `algorithm`, keeping the other settings.
    ///
    /// The limiter is rebuilt, so all tracked keys are forgotten.
    #[must_use]
    pub fn with_algorithm(self, algorithm: Algorithm) -> Self {
        Self::build(
            algorithm,
            self.quota,
            self.clock,
            self.shards,
            self.eviction,
        )
    }

    /// Returns a limiter requesting `shards` store shards, keeping the other
    /// settings. The store may round the count (see [`Self::shards`]).
    ///
    /// The limiter is rebuilt, so all tracked keys are forgotten.
    #[must_use]
    pub fn with_shards(self, shards: usize) -> Self {
        Self::build(
            self.algorithm,
            self.quota,
            self.clock,
            shards,
            self.eviction,
        )
    }

    /// Returns a limiter using `eviction`, keeping the other settings.
    ///
    /// The limiter is rebuilt, so all tracked keys are forgotten.
    #[must_use]
    pub fn with_eviction(self, eviction: Eviction) -> Self {
        Self::build(
            self.algorithm,
            self.quota,
            self.clock,
            self.shards,
            eviction,
        )
    }

    /// Asks whether one unit of quota may be taken for `key`.
    #[inline]
    pub fn check(&self, key: impl Into<Key>) -> Decision {
        self.check_inner(key.into(), 1)
    }

    /// Asks whether `n` units of quota may be taken for `key`.
    #[inline]
    pub fn check_n(&self, key: impl Into<Key>, n: u32) -> Decision {
        self.check_inner(key.into(), n)
    }

    /// The quota this limiter enforces.
    ///
    /// `const` like the other Copy-field getters below.
    #[must_use]
    pub const fn quota(&self) -> Quota {
        self.quota
    }

    /// The algorithm used for newly tracked keys.
    #[must_use]
    pub const fn algorithm(&self) -> Algorithm {
        self.algorithm
    }

    /// The eviction policy the store was created with.
    #[must_use]
    pub const fn eviction(&self) -> Eviction {
        self.eviction
    }

    /// Effective number of store shards.
    ///
    /// This is what the store reports and may differ from the requested count
    /// (e.g. requesting 5 yields 8).
    #[must_use]
    pub fn shards(&self) -> usize {
        self.store.shard_count()
    }

    /// Number of keys currently held in the store.
    #[must_use]
    pub fn tracked_keys(&self) -> usize {
        self.store.len()
    }

    /// Shared implementation behind `check`, `check_n`, and the `Limiter` impl.
    #[inline]
    fn check_inner(&self, key: Key, n: u32) -> Decision {
        // Skip the clock read entirely when `build` determined it is not needed.
        let now = if self.reads_clock {
            self.now()
        } else {
            Duration::ZERO
        };
        self.store.check(key, n, now, || self.new_state(now))
    }

    /// Creates the state for a key seen for the first time, giving it its own
    /// clone of the clock.
    fn new_state(&self, now: Duration) -> AlgoState<C> {
        AlgoState::new(self.algorithm, &self.quota, self.clock.clone(), now)
    }

    /// Time elapsed since the epoch, saturating at zero if the clock reads
    /// earlier than the epoch.
    fn now(&self) -> Duration {
        self.clock.now().saturating_duration_since(self.epoch)
    }
}
impl<C: Clock + Clone> Limiter for RateLimiter<C> {
    /// Routes through the same internal path as the inherent `check_n`, so the
    /// trait and inherent APIs can never disagree.
    fn check_n(&self, key: impl Into<Key>, n: u32) -> Decision {
        let key: Key = key.into();
        self.check_inner(key, n)
    }
}
impl<C: Clock + Clone> fmt::Debug for RateLimiter<C> {
    /// Renders the configuration and the number of tracked keys.
    ///
    /// Key values are deliberately never printed, because they may be sensitive
    /// (for example, tokens). The clock and epoch are omitted because `C` is not
    /// required to implement `Debug`. `finish_non_exhaustive` marks the output
    /// as partial.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RateLimiter")
            .field("algorithm", &self.algorithm())
            .field("quota", &self.quota())
            .field("shards", &self.shards())
            .field("eviction", &self.eviction())
            .field("tracked_keys", &self.tracked_keys())
            .finish_non_exhaustive()
    }
}
#[cfg(all(test, not(loom)))]
mod tests {
    #![allow(clippy::unwrap_used)]
    use std::sync::Arc;
    use std::time::Duration;
    use clock_lib::ManualClock;
    use super::{Limiter, RateLimiter};
    use crate::algorithm::Algorithm;
    use crate::decision::Decision;
    use crate::eviction::Eviction;
    use crate::quota::Quota;

    /// A 5-per-second limiter driven by a manually advanced clock, plus that clock.
    fn manual() -> (Arc<ManualClock>, RateLimiter<Arc<ManualClock>>) {
        let clock = Arc::new(ManualClock::new());
        let limiter = RateLimiter::per_second(5).with_clock(Arc::clone(&clock));
        (clock, limiter)
    }

    #[test]
    fn test_fresh_key_is_admitted() {
        let limiter = RateLimiter::per_second(1);
        assert_eq!(limiter.check("user:1"), Decision::Allow);
    }

    #[test]
    fn test_quota_is_exhausted_then_refills_on_advance() {
        let (clock, limiter) = manual();
        for _ in 0..5 {
            assert_eq!(limiter.check("k"), Decision::Allow);
        }
        let decision = limiter.check("k");
        assert!(decision.is_deny());
        assert!(decision.retry_after().is_some());
        clock.advance(Duration::from_secs(1));
        assert_eq!(limiter.check("k"), Decision::Allow);
    }

    #[test]
    fn test_keys_are_independent() {
        let (_clock, limiter) = manual();
        for _ in 0..5 {
            assert!(limiter.check("a").is_allow());
        }
        assert!(limiter.check("a").is_deny());
        assert!(limiter.check("b").is_allow());
    }

    #[test]
    fn test_check_n_takes_multiple_units_atomically() {
        let (_clock, limiter) = manual();
        assert_eq!(limiter.check_n("batch", 3), Decision::Allow);
        assert_eq!(limiter.check_n("batch", 2), Decision::Allow);
        assert!(limiter.check_n("batch", 1).is_deny());
    }

    #[test]
    fn test_check_n_zero_always_admits() {
        let (_clock, limiter) = manual();
        for _ in 0..5 {
            assert!(limiter.check("k").is_allow());
        }
        assert_eq!(limiter.check_n("k", 0), Decision::Allow);
    }

    #[test]
    fn test_request_larger_than_quota_can_never_succeed() {
        let (clock, limiter) = manual();
        let decision = limiter.check_n("k", 6);
        assert_eq!(
            decision,
            Decision::Deny {
                retry_after: Duration::MAX
            }
        );
        clock.advance(Duration::from_secs(10));
        assert_eq!(limiter.check_n("k", 6).retry_after(), Some(Duration::MAX));
    }

    #[test]
    fn test_zero_limit_denies_everything() {
        let limiter = RateLimiter::with_quota(Quota::per_second(0));
        assert!(limiter.check("k").is_deny());
    }

    #[test]
    fn test_partial_refill_admits_proportionally() {
        let clock = Arc::new(ManualClock::new());
        let limiter = RateLimiter::per_second(10).with_clock(Arc::clone(&clock));
        for _ in 0..10 {
            assert!(limiter.check("k").is_allow());
        }
        assert!(limiter.check("k").is_deny());
        clock.advance(Duration::from_millis(300));
        assert!(limiter.check("k").is_allow());
        assert!(limiter.check("k").is_allow());
        assert!(limiter.check("k").is_allow());
        assert!(limiter.check("k").is_deny());
    }

    #[test]
    fn test_tracked_keys_counts_distinct_keys() {
        let (_clock, limiter) = manual();
        assert_eq!(limiter.tracked_keys(), 0);
        let _ = limiter.check("a");
        let _ = limiter.check("b");
        let _ = limiter.check("a");
        assert_eq!(limiter.tracked_keys(), 2);
    }

    #[test]
    fn test_introspection_reports_token_bucket() {
        let limiter = RateLimiter::per_second(1);
        assert_eq!(limiter.algorithm(), Algorithm::TokenBucket);
    }

    #[test]
    fn test_with_shards_rounds_to_power_of_two() {
        let limiter = RateLimiter::per_second(1).with_shards(5);
        assert_eq!(limiter.shards(), 8);
    }

    #[test]
    fn test_with_eviction_is_reported() {
        let limiter = RateLimiter::per_second(1).with_eviction(Eviction::capacity(10));
        assert_eq!(limiter.eviction().max_keys(), Some(10));
    }

    #[test]
    fn test_with_clock_keeps_quota_and_resets_state() {
        // Exhaust a key under one manual clock...
        let (_old_clock, limiter) = manual();
        for _ in 0..5 {
            assert!(limiter.check("k").is_allow());
        }
        assert!(limiter.check("k").is_deny());
        // ...then swap clocks: the rebuild keeps the configuration but starts
        // from an empty store.
        let clock = Arc::new(ManualClock::new());
        let limiter = limiter.with_clock(Arc::clone(&clock));
        assert_eq!(limiter.tracked_keys(), 0);
        assert_eq!(limiter.algorithm(), Algorithm::TokenBucket);
        for _ in 0..5 {
            assert!(limiter.check("k").is_allow());
        }
        assert!(limiter.check("k").is_deny());
        clock.advance(Duration::from_secs(1));
        assert!(limiter.check("k").is_allow());
    }

    #[test]
    fn test_unique_key_flood_is_bounded_by_capacity() {
        let limiter = RateLimiter::per_second(1)
            .with_shards(8)
            .with_eviction(Eviction::capacity(100));
        for k in 0..50_000u64 {
            let _ = limiter.check(k);
        }
        // Capacity is split per shard, rounded up, so the global bound is
        // ceil(capacity / shards) * shards.
        let bound = 100usize.div_ceil(8).max(1) * 8;
        assert!(
            limiter.tracked_keys() <= bound,
            "flood grew to {} keys, bound {bound}",
            limiter.tracked_keys()
        );
    }

    #[test]
    fn test_limiter_trait_generic() {
        fn count_admitted<L: Limiter>(limiter: &L, key: &str, attempts: u32) -> u32 {
            let admitted = (0..attempts)
                .filter(|_| limiter.check(key).is_allow())
                .count();
            // `admitted <= attempts`, which is a u32, so this cannot fail.
            u32::try_from(admitted).unwrap()
        }
        let limiter = RateLimiter::per_second(3);
        assert_eq!(count_admitted(&limiter, "k", 10), 3);
    }

    #[test]
    fn test_debug_does_not_leak_keys() {
        let (_clock, limiter) = manual();
        let _ = limiter.check("secret-token-do-not-print");
        let rendered = format!("{limiter:?}");
        assert!(!rendered.contains("secret-token"));
        assert!(rendered.contains("RateLimiter"));
        assert!(rendered.contains("tracked_keys"));
    }
}