#![no_std]
#[cfg(any(feature = "std", test))]
extern crate std;
use core::fmt::{self, Debug, Formatter};
use core::sync::atomic::{AtomicU64, Ordering};
use core::time::Duration;
use thiserror::Error;
/// Time source consulted by the rate limiter.
///
/// Implementors report the time elapsed since an epoch of their own choosing.
/// The limiter derives token accrual from the nanosecond value of these
/// readings.
pub trait Clock {
    /// Returns the time elapsed since this clock's epoch.
    fn elapsed(&self) -> Duration;
}
/// [`Clock`] backed by [`std::time::Instant`], measuring time since the clock
/// was created.
///
/// Derives `Debug`, `Clone` and `Copy` so that `Builder<StdClock>` (whose
/// derives require those traits of its clock parameter) is itself
/// `Debug`/`Clone`/`Copy`. Copies share the same start instant.
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy)]
pub struct StdClock(std::time::Instant);

#[cfg(feature = "std")]
impl StdClock {
    /// Creates a clock whose epoch is the moment of this call.
    pub fn new() -> Self {
        Self(std::time::Instant::now())
    }
}

#[cfg(feature = "std")]
impl Default for StdClock {
    /// Equivalent to [`StdClock::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "std")]
impl Clock for StdClock {
    /// Returns the wall time elapsed since this clock was created.
    fn elapsed(&self) -> Duration {
        self.0.elapsed()
    }
}
/// Fixed-point scale for token counters: internal token counts are stored in
/// millionths of a token so that fractional accrual is not truncated.
const TOKEN_SCALE: u64 = 10_u64.pow(6);
/// Configuration errors reported when building a [`Ratelimiter`].
#[non_exhaustive]
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The requested `initial_available` is larger than the bucket capacity.
    #[error("initial available tokens cannot exceed max tokens")]
    AvailableTokensTooHigh,
    /// The bucket capacity is zero while the rate is nonzero.
    #[error("max tokens must be at least 1")]
    MaxTokensTooLow,
    /// The refill period is zero.
    #[error("period must be greater than zero")]
    PeriodTooShort,
}
/// Reasons a non-blocking token acquisition can fail.
#[non_exhaustive]
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryWaitError {
    /// Not enough tokens are stored yet; carries an estimated delay before a
    /// retry may succeed.
    #[error("insufficient tokens; retry after {0:?}")]
    Insufficient(Duration),
    /// More tokens were requested than the bucket can ever hold.
    #[error("requested tokens exceed bucket capacity")]
    ExceedsCapacity,
}
/// Token-bucket rate limiter.
///
/// All mutable state is held in atomics, so every operation takes `&self`.
/// With the `std` feature the clock type defaults to [`StdClock`].
#[cfg(feature = "std")]
#[must_use]
pub struct Ratelimiter<C = StdClock>
where
    C: Clock,
{
    /// Tokens credited per period; `0` disables limiting.
    rate: AtomicU64,
    /// Length of the refill period in nanoseconds.
    period_ns: AtomicU64,
    /// Bucket capacity in whole tokens.
    max_tokens: AtomicU64,
    /// Stored tokens, scaled by `TOKEN_SCALE`.
    tokens: AtomicU64,
    /// Tokens discarded because the bucket was full, scaled by `TOKEN_SCALE`.
    dropped: AtomicU64,
    /// Clock reading, in nanoseconds, up to which accrual has been credited.
    last_refill_ns: AtomicU64,
    /// Time source used to measure accrual.
    clock: C,
}

/// Token-bucket rate limiter.
///
/// All mutable state is held in atomics, so every operation takes `&self`.
/// Without the `std` feature a [`Clock`] implementation must be supplied.
#[cfg(not(feature = "std"))]
#[must_use]
pub struct Ratelimiter<C>
where
    C: Clock,
{
    /// Tokens credited per period; `0` disables limiting.
    rate: AtomicU64,
    /// Length of the refill period in nanoseconds.
    period_ns: AtomicU64,
    /// Bucket capacity in whole tokens.
    max_tokens: AtomicU64,
    /// Stored tokens, scaled by `TOKEN_SCALE`.
    tokens: AtomicU64,
    /// Tokens discarded because the bucket was full, scaled by `TOKEN_SCALE`.
    dropped: AtomicU64,
    /// Clock reading, in nanoseconds, up to which accrual has been credited.
    last_refill_ns: AtomicU64,
    /// Time source used to measure accrual.
    clock: C,
}
/// Default refill period: one second, expressed in nanoseconds.
const DEFAULT_PERIOD_NS: u64 = 10_u64.pow(9);
/// Estimates how long to wait, in nanoseconds, for `deficit` scaled token
/// units (millionths of a token, see `TOKEN_SCALE`) to accrue at `rate`
/// tokens per `period_ns` nanoseconds.
///
/// The result is rounded *up*. Refill rounds its credit down, so a rounded-down
/// hint could leave the caller a fraction of a unit short after waiting the
/// suggested time. The result is clamped to `[1, u64::MAX]`. A zero
/// denominator (zero `rate`) is replaced by 1 to avoid division by zero.
#[inline]
fn wait_ns_for_deficit(deficit: u64, rate: u64, period_ns: u64) -> u64 {
    let units_per_period = u128::from(rate)
        .saturating_mul(u128::from(TOKEN_SCALE))
        .max(1);
    // Cannot overflow: (2^64 - 1)^2 < 2^128.
    let needed = u128::from(deficit) * u128::from(period_ns);
    let wait = needed / units_per_period + u128::from(needed % units_per_period != 0);
    wait.clamp(1, u128::from(u64::MAX)) as u64
}
#[cfg(feature = "std")]
impl Ratelimiter<StdClock> {
    /// Creates a limiter adding `rate` tokens per period, timed by a freshly
    /// created [`StdClock`].
    pub fn new(rate: u64) -> Self {
        let clock = StdClock::default();
        Self::with_clock(rate, clock)
    }

    /// Starts a [`Builder`] for a limiter timed by a freshly created
    /// [`StdClock`].
    pub fn builder(rate: u64) -> Builder<StdClock> {
        let clock = StdClock::default();
        Builder::with_clock(rate, clock)
    }
}
impl<C> Ratelimiter<C>
where
    C: Clock,
{
    /// Minimum clock advance, in nanoseconds since the last credited refill,
    /// before `refill` attempts to credit tokens again.
    const MIN_REFILL_INTERVAL_NS: u64 = 1_000;

    /// Converts a [`Duration`] to whole nanoseconds, saturating at `u64::MAX`.
    fn saturating_nanos(duration: Duration) -> u64 {
        duration.as_nanos().min(u64::MAX as u128) as u64
    }

    /// Creates a limiter that adds `rate` tokens per period, timed by `clock`.
    ///
    /// The period defaults to one second. The capacity defaults to `rate`, or
    /// `u64::MAX` when `rate` is zero, which disables limiting. The bucket
    /// starts empty. The refill baseline is the clock's current reading, so
    /// time that had already elapsed on `clock` before construction is not
    /// credited as tokens.
    pub fn with_clock(rate: u64, clock: C) -> Self {
        let start_ns = Self::saturating_nanos(clock.elapsed());
        Self {
            rate: AtomicU64::new(rate),
            period_ns: AtomicU64::new(DEFAULT_PERIOD_NS),
            max_tokens: AtomicU64::new(if rate == 0 { u64::MAX } else { rate }),
            tokens: AtomicU64::new(0),
            dropped: AtomicU64::new(0),
            last_refill_ns: AtomicU64::new(start_ns),
            clock,
        }
    }

    /// Returns the number of tokens added per period; `0` means unlimited.
    pub fn rate(&self) -> u64 {
        self.rate.load(Ordering::Relaxed)
    }

    /// Sets the number of tokens added per period.
    ///
    /// `0` disables limiting and raises the capacity to `u64::MAX`. When the
    /// capacity is `u64::MAX` (for example after running unlimited) and a
    /// nonzero `rate` is set, the capacity becomes `rate`. Stored tokens are
    /// then clamped to it, as [`set_max_tokens`](Self::set_max_tokens) does.
    /// Any other capacity is left unchanged.
    pub fn set_rate(&self, rate: u64) {
        if rate == 0 {
            self.max_tokens.store(u64::MAX, Ordering::Release);
        } else if self.max_tokens.load(Ordering::Acquire) == u64::MAX {
            // Go through set_max_tokens so tokens banked under the old,
            // unbounded capacity cannot exceed the new one.
            self.set_max_tokens(rate);
        }
        self.rate.store(rate, Ordering::Release);
    }

    /// Returns the bucket capacity in whole tokens.
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens.load(Ordering::Relaxed)
    }

    /// Sets the bucket capacity and atomically discards stored tokens above it.
    ///
    /// Discarded tokens are not counted in [`dropped`](Self::dropped).
    pub fn set_max_tokens(&self, tokens: u64) {
        self.max_tokens.store(tokens, Ordering::Release);
        self.tokens
            .fetch_min(tokens.saturating_mul(TOKEN_SCALE), Ordering::AcqRel);
    }

    /// Returns the refill period.
    pub fn period(&self) -> Duration {
        Duration::from_nanos(self.period_ns.load(Ordering::Relaxed))
    }

    /// Sets the refill period.
    ///
    /// A zero period is stored as 1 ns (whereas `Builder::build` rejects it).
    /// Periods longer than `u64::MAX` nanoseconds saturate.
    pub fn set_period(&self, period: Duration) {
        let ns = Self::saturating_nanos(period).max(1);
        self.period_ns.store(ns, Ordering::Release);
    }

    /// Sets the period and then the rate. This is equivalent to calling
    /// [`set_period`](Self::set_period) followed by
    /// [`set_rate`](Self::set_rate).
    pub fn set_rate_per(&self, rate: u64, period: Duration) {
        self.set_period(period);
        self.set_rate(rate);
    }

    /// Returns the number of whole tokens currently stored (no refill is done).
    pub fn available(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed) / TOKEN_SCALE
    }

    /// Returns the number of whole tokens discarded because the bucket was full.
    pub fn dropped(&self) -> u64 {
        self.dropped.load(Ordering::Relaxed) / TOKEN_SCALE
    }

    /// Credits tokens for the clock time elapsed since the last refill.
    ///
    /// Does nothing in any of these cases: the limiter is unlimited, less
    /// than `MIN_REFILL_INTERVAL_NS` has elapsed, or the elapsed time is worth
    /// less than one scaled unit.
    ///
    /// Only the time that the credited units account for (rounded up) is
    /// consumed from `last_refill_ns`. The fractional remainder carries over
    /// to the next refill instead of being discarded by integer truncation.
    /// Credit that does not fit under the capacity is added to `dropped`.
    /// Concurrent callers race on `last_refill_ns`, and only the winner
    /// credits tokens.
    fn refill(&self) {
        let rate = self.rate.load(Ordering::Relaxed);
        if rate == 0 {
            return;
        }
        let now_ns = Self::saturating_nanos(self.clock.elapsed());
        let last_ns = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed_ns = now_ns.saturating_sub(last_ns);
        if elapsed_ns < Self::MIN_REFILL_INTERVAL_NS {
            return;
        }
        let period_ns = u128::from(self.period_ns.load(Ordering::Relaxed).max(1));
        // Scaled units accrued per period; nonzero because `rate != 0`.
        let units_per_period = u128::from(rate) * u128::from(TOKEN_SCALE);
        let new_tokens = (units_per_period.saturating_mul(u128::from(elapsed_ns)) / period_ns)
            .min(u128::from(u64::MAX)) as u64;
        if new_tokens == 0 {
            return;
        }
        // Nanoseconds represented by `new_tokens`, rounded up so the carried
        // remainder can never over-credit. This never exceeds `elapsed_ns`.
        // The product cannot overflow: (2^64 - 1)^2 < 2^128.
        let accounted = u128::from(new_tokens) * period_ns;
        let used_ns = (accounted / units_per_period
            + u128::from(accounted % units_per_period != 0))
            .min(u128::from(elapsed_ns)) as u64;
        if self
            .last_refill_ns
            .compare_exchange(
                last_ns,
                last_ns + used_ns,
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .is_err()
        {
            // Another caller already credited this interval.
            return;
        }
        let max_scaled = self
            .max_tokens
            .load(Ordering::Acquire)
            .saturating_mul(TOKEN_SCALE);
        let capped_total = |current: u64| current.saturating_add(new_tokens).min(max_scaled);
        let update = self
            .tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let total = capped_total(current);
                (total > current).then_some(total)
            });
        let added = match update {
            Ok(previous) => capped_total(previous) - previous,
            // Bucket already at (or above) capacity: nothing fits.
            Err(_) => 0,
        };
        if added < new_tokens {
            self.dropped
                .fetch_add(new_tokens - added, Ordering::Relaxed);
        }
    }

    /// Attempts to take a single token without blocking.
    ///
    /// # Errors
    ///
    /// Same as [`try_wait_n`](Self::try_wait_n) with `n == 1`.
    pub fn try_wait(&self) -> Result<(), TryWaitError> {
        self.try_wait_n(1)
    }

    /// Attempts to take `n` tokens at once without blocking.
    ///
    /// Succeeds immediately when the limiter is unlimited (`rate == 0`) or
    /// `n == 0`. Tokens are taken all-or-nothing: on failure nothing is
    /// consumed.
    ///
    /// # Errors
    ///
    /// - [`TryWaitError::ExceedsCapacity`] if `n` is larger than the bucket
    ///   capacity, so the request can never succeed.
    /// - [`TryWaitError::Insufficient`] with an estimated wait until enough
    ///   tokens should have accrued.
    pub fn try_wait_n(&self, n: u64) -> Result<(), TryWaitError> {
        let rate = self.rate.load(Ordering::Relaxed);
        if rate == 0 || n == 0 {
            return Ok(());
        }
        if n > self.max_tokens.load(Ordering::Relaxed) {
            return Err(TryWaitError::ExceedsCapacity);
        }
        self.refill();
        let period_ns = self.period_ns.load(Ordering::Relaxed).max(1);
        let cost = n.saturating_mul(TOKEN_SCALE);
        // All-or-nothing: subtract only when the full cost is available.
        match self
            .tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_sub(cost)
            }) {
            Ok(_) => Ok(()),
            Err(current) => {
                let wait_ns = wait_ns_for_deficit(cost - current, rate, period_ns);
                Err(TryWaitError::Insufficient(Duration::from_nanos(wait_ns)))
            }
        }
    }
}
// Compile-time proof that `Ratelimiter<C>` is `Send + Sync` whenever its
// clock is; the functions are never called, only type-checked.
const _: () = {
    #[allow(dead_code)]
    fn require_send_sync<T: Send + Sync>() {}

    #[allow(dead_code)]
    fn ratelimiter_is_send_sync<C: Clock + Send + Sync>() {
        require_send_sync::<Ratelimiter<C>>();
    }
};
impl<C> Debug for Ratelimiter<C>
where
C: Clock,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Ratelimiter")
.field("rate", &self.rate.load(Ordering::Relaxed))
.field("period", &self.period())
.field("max_tokens", &self.max_tokens.load(Ordering::Relaxed))
.field("available", &self.available())
.finish()
}
}
/// Configures and validates a [`Ratelimiter`] before constructing it.
///
/// With the `std` feature the clock type defaults to [`StdClock`].
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy)]
#[must_use = "call .build() to construct the Ratelimiter"]
pub struct Builder<C = StdClock> {
    /// Tokens added per period; `0` disables limiting.
    rate: u64,
    /// Refill period.
    period: Duration,
    /// Bucket capacity in whole tokens; `None` selects the default.
    max_tokens: Option<u64>,
    /// Whole tokens in the bucket at construction.
    initial_available: u64,
    /// Time source handed to the limiter.
    clock: C,
}

/// Configures and validates a [`Ratelimiter`] before constructing it.
///
/// Without the `std` feature a [`Clock`] implementation must be supplied.
#[cfg(not(feature = "std"))]
#[derive(Debug, Clone, Copy)]
#[must_use = "call .build() to construct the Ratelimiter"]
pub struct Builder<C> {
    /// Tokens added per period; `0` disables limiting.
    rate: u64,
    /// Refill period.
    period: Duration,
    /// Bucket capacity in whole tokens; `None` selects the default.
    max_tokens: Option<u64>,
    /// Whole tokens in the bucket at construction.
    initial_available: u64,
    /// Time source handed to the limiter.
    clock: C,
}
impl<C> Builder<C> {
    /// Starts a builder for a limiter adding `rate` tokens per period, timed
    /// by `clock`.
    ///
    /// Defaults: a one-second period, a capacity equal to `rate` (or
    /// `u64::MAX` when `rate` is zero), and an empty bucket.
    pub fn with_clock(rate: u64, clock: C) -> Self {
        Self {
            rate,
            period: Duration::from_nanos(DEFAULT_PERIOD_NS),
            max_tokens: None,
            initial_available: 0,
            clock,
        }
    }

    /// Sets the refill period; `build` rejects a zero period.
    pub fn period(mut self, period: Duration) -> Self {
        self.period = period;
        self
    }

    /// Sets the bucket capacity in whole tokens.
    pub fn max_tokens(mut self, tokens: u64) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Sets how many whole tokens the bucket holds at construction.
    pub fn initial_available(mut self, tokens: u64) -> Self {
        self.initial_available = tokens;
        self
    }

    /// Validates the configuration and constructs the limiter.
    ///
    /// Periods longer than `u64::MAX` nanoseconds saturate. The refill
    /// baseline is the clock's current reading, so time that had already
    /// elapsed on the clock is not credited as tokens.
    ///
    /// # Errors
    ///
    /// - [`Error::PeriodTooShort`] if the period is zero.
    /// - [`Error::MaxTokensTooLow`] if the capacity is zero while `rate` is
    ///   nonzero.
    /// - [`Error::AvailableTokensTooHigh`] if `initial_available` exceeds the
    ///   capacity.
    pub fn build(self) -> Result<Ratelimiter<C>, Error>
    where
        C: Clock,
    {
        let period_ns = self.period.as_nanos();
        if period_ns == 0 {
            return Err(Error::PeriodTooShort);
        }
        let period_ns = period_ns.min(u64::MAX as u128) as u64;
        let default_max = if self.rate == 0 { u64::MAX } else { self.rate };
        let max_tokens = self.max_tokens.unwrap_or(default_max);
        if max_tokens == 0 && self.rate != 0 {
            return Err(Error::MaxTokensTooLow);
        }
        if self.initial_available > max_tokens {
            return Err(Error::AvailableTokensTooHigh);
        }
        // Start accrual from "now" on this clock rather than from its epoch.
        let start_ns = self
            .clock
            .elapsed()
            .as_nanos()
            .min(u64::MAX as u128) as u64;
        Ok(Ratelimiter {
            rate: AtomicU64::new(self.rate),
            period_ns: AtomicU64::new(period_ns),
            max_tokens: AtomicU64::new(max_tokens),
            tokens: AtomicU64::new(self.initial_available.saturating_mul(TOKEN_SCALE)),
            dropped: AtomicU64::new(0),
            last_refill_ns: AtomicU64::new(start_ns),
            clock: self.clock,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::AtomicU64;
    use core::time::Duration;
    use std::sync::Arc;

    /// Clock that moves only when a test calls `advance`. Clones share the
    /// same counter, so a test can keep a handle after moving one into a
    /// limiter.
    #[derive(Clone, Debug)]
    struct TestClock {
        elapsed_ns: Arc<AtomicU64>,
    }

    impl TestClock {
        fn new() -> Self {
            Self {
                elapsed_ns: Arc::new(AtomicU64::new(0)),
            }
        }

        fn advance(&self, duration: Duration) {
            let elapsed_ns = duration.as_nanos().min(u64::MAX as u128) as u64;
            self.elapsed_ns.fetch_add(elapsed_ns, Ordering::Relaxed);
        }
    }

    impl Clock for TestClock {
        fn elapsed(&self) -> Duration {
            Duration::from_nanos(self.elapsed_ns.load(Ordering::Relaxed))
        }
    }

    #[test]
    fn unlimited() {
        let rl = Ratelimiter::with_clock(0, TestClock::new());
        for _ in 0..1000 {
            assert!(rl.try_wait().is_ok());
        }
    }

    #[test]
    fn basic_rate() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock)
            .initial_available(10)
            .build()
            .unwrap();
        for _ in 0..10 {
            assert!(rl.try_wait().is_ok());
        }
        assert!(rl.try_wait().is_err());
    }

    #[test]
    fn refill_over_time() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1000, clock.clone());
        // 100 ms at 1000 tokens/s should accrue roughly 100 tokens.
        clock.advance(Duration::from_millis(100));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count >= 50, "expected >= 50, got {count}");
        assert!(count <= 200, "expected <= 200, got {count}");
    }

    #[test]
    fn burst_capacity() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(100, clock)
            .max_tokens(10)
            .initial_available(10)
            .build()
            .unwrap();
        for _ in 0..10 {
            assert!(rl.try_wait().is_ok());
        }
        assert!(rl.try_wait().is_err());
    }

    #[test]
    fn idle_does_not_exceed_capacity() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock.clone())
            .max_tokens(10)
            .build()
            .unwrap();
        clock.advance(Duration::from_millis(100));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count <= 10, "expected <= 10, got {count}");
    }

    #[test]
    fn set_rate() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(100, clock.clone());
        clock.advance(Duration::from_millis(50));
        rl.set_rate(1000);
        clock.advance(Duration::from_millis(50));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count >= 30, "expected >= 30, got {count}");
    }

    #[test]
    fn set_max_tokens_clamps_down() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock)
            .max_tokens(100)
            .initial_available(100)
            .build()
            .unwrap();
        assert_eq!(rl.available(), 100);
        rl.set_max_tokens(10);
        assert!(rl.available() <= 10);
    }

    #[test]
    fn try_wait_returns_duration_hint() {
        // One token at 1000 tokens/s takes exactly 1 ms to accrue.
        let rl = Ratelimiter::with_clock(1000, TestClock::new());
        let err = rl.try_wait().unwrap_err();
        assert_eq!(err, TryWaitError::Insufficient(Duration::from_micros(1000)));
    }

    #[test]
    fn builder_error_available_too_high() {
        let clock = TestClock::new();
        let result = Builder::with_clock(100, clock)
            .max_tokens(10)
            .initial_available(20)
            .build();
        assert!(matches!(result, Err(Error::AvailableTokensTooHigh)));
    }

    #[test]
    fn dropped_tokens() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock.clone())
            .max_tokens(10)
            .build()
            .unwrap();
        // ~100 tokens accrue but only 10 fit; the rest must be counted.
        clock.advance(Duration::from_millis(100));
        let _ = rl.try_wait();
        assert!(rl.dropped() > 0, "expected dropped > 0");
    }

    #[test]
    fn wait_loop() {
        // Follow the wait hints like a real caller would, for 100 ms of
        // simulated time at 10_000 tokens/s.
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(10_000, clock.clone());
        let mut count = 0;
        while clock.elapsed() < Duration::from_millis(100) {
            match rl.try_wait() {
                Ok(()) => count += 1,
                Err(TryWaitError::Insufficient(wait)) => clock.advance(wait),
                Err(e) => panic!("unexpected error: {e}"),
            }
        }
        assert!(count >= 500, "expected >= 500, got {count}");
        assert!(count <= 2000, "expected <= 2000, got {count}");
    }

    #[test]
    fn high_rate() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1_000_000_000_000, clock.clone());
        clock.advance(Duration::from_millis(10));
        assert!(rl.try_wait().is_ok());
    }

    #[test]
    fn try_wait_hint_at_high_rate() {
        let rl = Ratelimiter::with_clock(10_000_000_000, TestClock::new());
        let err = rl.try_wait().unwrap_err();
        let TryWaitError::Insufficient(wait) = err else {
            panic!("expected Insufficient, got {err:?}");
        };
        assert!(wait >= Duration::from_nanos(1));
    }

    #[test]
    fn unlimited_then_set_rate() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(0, clock.clone());
        assert!(rl.try_wait().is_ok());
        rl.set_rate(1000);
        clock.advance(Duration::from_millis(50));
        assert!(rl.try_wait().is_ok());
    }

    #[test]
    fn set_rate_to_zero_and_back() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1000, clock.clone());
        rl.set_rate(0);
        assert_eq!(rl.max_tokens(), u64::MAX);
        for _ in 0..100 {
            assert!(rl.try_wait().is_ok());
        }
        rl.set_rate(500);
        assert_eq!(rl.max_tokens(), 500);
        clock.advance(Duration::from_millis(50));
        assert!(rl.try_wait().is_ok());
    }

    #[test]
    fn builder_error_max_tokens_zero() {
        let clock = TestClock::new();
        let result = Builder::with_clock(100, clock).max_tokens(0).build();
        assert!(matches!(result, Err(Error::MaxTokensTooLow)));
    }

    #[test]
    fn max_tokens_zero() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1000, clock.clone());
        rl.set_max_tokens(0);
        clock.advance(Duration::from_millis(10));
        assert_eq!(rl.try_wait(), Err(TryWaitError::ExceedsCapacity));
        rl.set_max_tokens(1000);
        clock.advance(Duration::from_millis(10));
        assert!(rl.try_wait().is_ok());
    }

    #[cfg(feature = "std")]
    #[test]
    fn std_convenience_apis() {
        let rl = Ratelimiter::new(1000);
        assert_eq!(rl.rate(), 1000);
        let rl = Ratelimiter::builder(1000)
            .max_tokens(100)
            .initial_available(50)
            .build()
            .unwrap();
        assert_eq!(rl.max_tokens(), 100);
        assert_eq!(rl.available(), 50);
        let clock = StdClock::new();
        let rl = Ratelimiter::with_clock(1000, clock);
        assert_eq!(rl.rate(), 1000);
    }

    #[cfg(feature = "std")]
    #[test]
    fn type_default_clock() {
        // The default type parameter must let callers omit the clock type.
        let rl: Ratelimiter = Ratelimiter::new(1000);
        assert_eq!(rl.rate(), 1000);
        let b: Builder = Ratelimiter::builder(1000);
        let rl = b.max_tokens(10).build().unwrap();
        assert_eq!(rl.max_tokens(), 10);
    }

    #[cfg(feature = "std")]
    #[test]
    fn multithread() {
        use std::vec::Vec;

        let rl = Arc::new(
            Ratelimiter::builder(10_000)
                .max_tokens(10_000)
                .build()
                .unwrap(),
        );
        let duration = Duration::from_millis(200);
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let rl = rl.clone();
                std::thread::spawn(move || {
                    let start = std::time::Instant::now();
                    let mut count = 0u64;
                    while start.elapsed() < duration {
                        if rl.try_wait().is_ok() {
                            count += 1;
                        }
                    }
                    count
                })
            })
            .collect();
        // 200 ms of real time at 10_000 tokens/s is ~2000 tokens in total,
        // shared by all threads; the bounds allow for scheduling jitter.
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total >= 1000, "expected >= 1000, got {total}");
        assert!(total <= 4000, "expected <= 4000, got {total}");
    }

    #[test]
    fn try_wait_n_basic() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock)
            .initial_available(10)
            .build()
            .unwrap();
        assert!(rl.try_wait_n(5).is_ok());
        assert!(rl.try_wait_n(5).is_ok());
        assert!(matches!(
            rl.try_wait_n(1),
            Err(TryWaitError::Insufficient(_))
        ));
    }

    #[test]
    fn try_wait_n_zero_is_noop() {
        let rl = Ratelimiter::with_clock(1000, TestClock::new());
        assert!(rl.try_wait_n(0).is_ok());
        assert_eq!(rl.available(), 0);
    }

    #[test]
    fn try_wait_n_unlimited() {
        let rl = Ratelimiter::with_clock(0, TestClock::new());
        assert!(rl.try_wait_n(1_000_000).is_ok());
    }

    #[test]
    fn try_wait_n_does_not_partially_consume() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock)
            .initial_available(5)
            .build()
            .unwrap();
        assert!(matches!(
            rl.try_wait_n(10),
            Err(TryWaitError::Insufficient(_))
        ));
        // The failed request must leave all 5 tokens in place.
        for _ in 0..5 {
            assert!(rl.try_wait().is_ok());
        }
    }

    #[test]
    fn try_wait_n_exceeds_capacity() {
        let clock = TestClock::new();
        let rl = Builder::with_clock(1000, clock)
            .max_tokens(10)
            .build()
            .unwrap();
        assert_eq!(rl.try_wait_n(100), Err(TryWaitError::ExceedsCapacity));
    }

    #[test]
    fn sub_hz_refill() {
        // 1 token per minute: half a minute is only half a token.
        let clock = TestClock::new();
        let rl = Builder::with_clock(1, clock.clone())
            .period(Duration::from_secs(60))
            .build()
            .unwrap();
        clock.advance(Duration::from_secs(30));
        assert!(rl.try_wait().is_err());
        clock.advance(Duration::from_secs(30));
        assert!(rl.try_wait().is_ok());
        assert!(rl.try_wait().is_err());
    }

    #[test]
    fn sub_hz_wait_hint() {
        let rl = Builder::with_clock(1, TestClock::new())
            .period(Duration::from_secs(60))
            .build()
            .unwrap();
        let err = rl.try_wait().unwrap_err();
        assert_eq!(err, TryWaitError::Insufficient(Duration::from_secs(60)));
    }

    #[test]
    fn set_period_changes_rate() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1, clock.clone());
        assert_eq!(rl.period(), Duration::from_secs(1));
        rl.set_period(Duration::from_secs(10));
        assert_eq!(rl.period(), Duration::from_secs(10));
        clock.advance(Duration::from_secs(5));
        assert!(rl.try_wait().is_err());
        clock.advance(Duration::from_secs(5));
        assert!(rl.try_wait().is_ok());
    }

    #[test]
    fn builder_error_period_zero() {
        let result = Builder::with_clock(1, TestClock::new())
            .period(Duration::ZERO)
            .build();
        assert!(matches!(result, Err(Error::PeriodTooShort)));
    }

    #[test]
    fn set_rate_per_updates_both() {
        let clock = TestClock::new();
        let rl = Ratelimiter::with_clock(1000, clock.clone());
        rl.set_rate_per(5, Duration::from_secs(3600));
        assert_eq!(rl.rate(), 5);
        assert_eq!(rl.period(), Duration::from_secs(3600));
        clock.advance(Duration::from_secs(3600));
        for _ in 0..5 {
            assert!(rl.try_wait().is_ok());
        }
        assert!(rl.try_wait().is_err());
    }
}