use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use thiserror::Error;
/// Fixed-point scale used for token accounting: one whole token is stored as
/// `TOKEN_SCALE` internal units, so balances can carry fractions of a token.
const TOKEN_SCALE: u64 = 1_000_000;
/// Configuration errors reported by [`Builder::build`].
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
/// The requested initial balance is larger than the bucket capacity.
#[error("initial available tokens cannot exceed max tokens")]
AvailableTokensTooHigh,
/// The bucket capacity is zero while the rate is non-zero.
#[error("max tokens must be at least 1")]
MaxTokensTooLow,
}
/// A token-bucket rate limiter whose state is held entirely in atomics.
///
/// Every method takes `&self`, and the type is asserted to be `Send + Sync`
/// further down in this file, so one instance can be shared between threads.
#[must_use]
pub struct Ratelimiter {
/// Refill rate in whole tokens per second; `0` means unlimited.
rate: AtomicU64,
/// Bucket capacity in whole tokens (`u64::MAX` while the rate is `0`).
max_tokens: AtomicU64,
/// Current balance in scaled units (`TOKEN_SCALE` units per token).
tokens: AtomicU64,
/// Scaled units discarded by refills because the bucket was already full.
dropped: AtomicU64,
/// Time of the last refill, in nanoseconds since `start`.
last_refill_ns: AtomicU64,
/// Creation time; the reference point for `last_refill_ns`.
start: Instant,
}
impl Ratelimiter {
    /// Creates a limiter that issues `rate` tokens per second.
    ///
    /// The bucket starts empty and its capacity defaults to `rate` tokens.
    /// A `rate` of `0` disables limiting: the capacity becomes `u64::MAX` and
    /// [`try_wait`](Self::try_wait) always succeeds.
    pub fn new(rate: u64) -> Self {
        Self {
            rate: AtomicU64::new(rate),
            max_tokens: AtomicU64::new(if rate == 0 { u64::MAX } else { rate }),
            tokens: AtomicU64::new(0),
            dropped: AtomicU64::new(0),
            last_refill_ns: AtomicU64::new(0),
            start: Instant::now(),
        }
    }

    /// Starts a [`Builder`] for a limiter with the given `rate` (tokens per
    /// second), allowing the capacity and initial balance to be customised.
    pub fn builder(rate: u64) -> Builder {
        Builder::new(rate)
    }

    /// Returns the configured rate in tokens per second (`0` = unlimited).
    pub fn rate(&self) -> u64 {
        self.rate.load(Ordering::Relaxed)
    }

    /// Changes the rate to `rate` tokens per second.
    ///
    /// * `rate == 0` switches to unlimited mode and sets the capacity to
    ///   `u64::MAX`.
    /// * A non-zero `rate` while the capacity is `u64::MAX` sets the capacity
    ///   to `rate`.
    /// * Otherwise the capacity is left as it is.
    ///
    /// The current token balance is not modified.
    pub fn set_rate(&self, rate: u64) {
        if rate == 0 {
            self.max_tokens.store(u64::MAX, Ordering::Release);
        } else if self.max_tokens.load(Ordering::Acquire) == u64::MAX {
            self.max_tokens.store(rate, Ordering::Release);
        }
        self.rate.store(rate, Ordering::Release);
    }

    /// Returns the bucket capacity in whole tokens.
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens.load(Ordering::Relaxed)
    }

    /// Sets the bucket capacity to `tokens` and, if the current balance is
    /// larger than the new capacity, clamps the balance down to it.
    ///
    /// Raising the capacity never adds tokens.
    pub fn set_max_tokens(&self, tokens: u64) {
        self.max_tokens.store(tokens, Ordering::Release);
        // `fetch_min` atomically performs "lower the balance if it exceeds the
        // cap", which is exactly what a hand-written compare-exchange loop
        // would do.
        self.tokens
            .fetch_min(tokens.saturating_mul(TOKEN_SCALE), Ordering::AcqRel);
    }

    /// Returns the number of whole tokens currently in the bucket.
    ///
    /// This only reads the stored balance; it does not trigger a refill.
    pub fn available(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed) / TOKEN_SCALE
    }

    /// Returns the number of whole tokens that refills have discarded because
    /// the bucket was already at capacity.
    pub fn dropped(&self) -> u64 {
        self.dropped.load(Ordering::Relaxed) / TOKEN_SCALE
    }

    /// Credits the tokens accrued since the last refill, capped at the bucket
    /// capacity; anything above the cap is added to the `dropped` counter.
    ///
    /// Does nothing when the rate is `0` or when less than one microsecond has
    /// passed since the last refill. When several threads race, only the one
    /// that wins the compare-exchange on `last_refill_ns` credits the
    /// interval; the others return without touching the balance.
    fn refill(&self) {
        let rate = self.rate.load(Ordering::Relaxed);
        if rate == 0 {
            return;
        }
        // Saturate instead of silently wrapping if the nanosecond count ever
        // exceeds `u64::MAX`.
        let now_ns = self
            .start
            .elapsed()
            .as_nanos()
            .min(u128::from(u64::MAX)) as u64;
        let last_ns = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed_ns = now_ns.saturating_sub(last_ns);
        if elapsed_ns < 1_000 {
            return;
        }
        // rate [tokens/s] * elapsed [ns] is a count of 1e-9 tokens; dividing
        // by 1_000 converts it to scaled units (1e-6 tokens).
        let accrued = u128::from(rate) * u128::from(elapsed_ns);
        let new_tokens = (accrued / 1_000).min(u128::from(u64::MAX)) as u64;
        if new_tokens == 0 {
            return;
        }
        // The division above discards up to one scaled unit. Rather than
        // losing that on every refill (which systematically under-issues at
        // low rates when polled often), leave the time it represents
        // unaccounted so the next refill picks it up. The carry is at most
        // 999 ns, which is always less than `elapsed_ns`, so the new timestamp
        // still moves strictly forward.
        let carry_ns = ((accrued % 1_000) / u128::from(rate)) as u64;
        let refill_ns = now_ns - carry_ns;
        if self
            .last_refill_ns
            .compare_exchange(last_ns, refill_ns, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            // Another thread claimed this interval.
            return;
        }
        let max_scaled = self
            .max_tokens
            .load(Ordering::Acquire)
            .saturating_mul(TOKEN_SCALE);
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            let new_total = current.saturating_add(new_tokens).min(max_scaled);
            if new_total <= current {
                // Bucket already full (or over the cap): nothing is added.
                self.dropped.fetch_add(new_tokens, Ordering::Relaxed);
                break;
            }
            if self
                .tokens
                .compare_exchange_weak(current, new_total, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                let added = new_total - current;
                if added < new_tokens {
                    // Only part of the accrual fitted under the cap.
                    self.dropped
                        .fetch_add(new_tokens - added, Ordering::Relaxed);
                }
                break;
            }
            std::hint::spin_loop();
        }
    }

    /// Tries to take one token without blocking.
    ///
    /// Returns `Ok(())` when a token was consumed (always, when the rate is
    /// `0`). Otherwise returns `Err(wait)`, where `wait` is the time it takes
    /// to accrue the missing fraction of a token at the current rate; it is
    /// never shorter than one nanosecond.
    pub fn try_wait(&self) -> Result<(), std::time::Duration> {
        let rate = self.rate.load(Ordering::Relaxed);
        if rate == 0 {
            return Ok(());
        }
        self.refill();
        let cost = TOKEN_SCALE;
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current < cost {
                let deficit = cost - current;
                // deficit [1e-6 tokens] / rate [tokens/s] = deficit * 1_000 / rate
                // nanoseconds. Clamp to 1 ns so that very high rates do not
                // round the hint down to zero.
                let wait_ns = (deficit as u128 * 1_000 / rate as u128).max(1) as u64;
                return Err(std::time::Duration::from_nanos(wait_ns));
            }
            if self
                .tokens
                .compare_exchange_weak(current, current - cost, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return Ok(());
            }
            std::hint::spin_loop();
        }
    }
}
// Compile-time guarantee that `Ratelimiter` can be shared across threads.
// The closure is never called; it only has to type-check.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<Ratelimiter>();
};
/// Prints the rate, the capacity and the whole-token balance.
impl std::fmt::Debug for Ratelimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot each value once, then hand them to the struct formatter.
        let rate = self.rate.load(Ordering::Relaxed);
        let capacity = self.max_tokens.load(Ordering::Relaxed);
        let available = self.available();

        let mut out = f.debug_struct("Ratelimiter");
        out.field("rate", &rate);
        out.field("max_tokens", &capacity);
        out.field("available", &available);
        out.finish()
    }
}
/// Step-by-step configuration for a [`Ratelimiter`]; finish with
/// [`Builder::build`].
#[derive(Debug, Clone, Copy)]
#[must_use = "call .build() to construct the Ratelimiter"]
pub struct Builder {
/// Refill rate in tokens per second; `0` means unlimited.
rate: u64,
/// Explicit bucket capacity in whole tokens; `None` lets `build` pick the
/// default (`rate`, or `u64::MAX` when the rate is `0`).
max_tokens: Option<u64>,
/// Whole tokens placed in the bucket at construction time.
initial_available: u64,
}
impl Builder {
    /// Begins a configuration with the given rate, no explicit capacity and
    /// an empty bucket.
    fn new(rate: u64) -> Self {
        let (max_tokens, initial_available) = (None, 0);
        Self {
            rate,
            max_tokens,
            initial_available,
        }
    }

    /// Sets the bucket capacity, in whole tokens.
    pub fn max_tokens(self, tokens: u64) -> Self {
        Self {
            max_tokens: Some(tokens),
            ..self
        }
    }

    /// Sets how many whole tokens the bucket holds right after construction.
    pub fn initial_available(self, tokens: u64) -> Self {
        Self {
            initial_available: tokens,
            ..self
        }
    }

    /// Validates the configuration and constructs the [`Ratelimiter`].
    ///
    /// # Errors
    ///
    /// * [`Error::MaxTokensTooLow`] when the capacity is `0` and the rate is
    ///   not `0`.
    /// * [`Error::AvailableTokensTooHigh`] when the initial balance exceeds
    ///   the capacity.
    pub fn build(self) -> Result<Ratelimiter, Error> {
        let Self {
            rate,
            max_tokens,
            initial_available,
        } = self;

        // Without an explicit capacity, an unlimited limiter gets an
        // unbounded bucket and a limited one holds one second's worth.
        let capacity = match max_tokens {
            Some(explicit) => explicit,
            None if rate == 0 => u64::MAX,
            None => rate,
        };

        if rate != 0 && capacity == 0 {
            return Err(Error::MaxTokensTooLow);
        }
        if capacity < initial_available {
            return Err(Error::AvailableTokensTooHigh);
        }

        let scaled_balance = initial_available.saturating_mul(TOKEN_SCALE);
        Ok(Ratelimiter {
            rate: AtomicU64::new(rate),
            max_tokens: AtomicU64::new(capacity),
            tokens: AtomicU64::new(scaled_balance),
            dropped: AtomicU64::new(0),
            last_refill_ns: AtomicU64::new(0),
            start: Instant::now(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A rate of zero never rejects a request.
    #[test]
    fn unlimited() {
        let rl = Ratelimiter::new(0);
        for _ in 0..1000 {
            assert!(rl.try_wait().is_ok());
        }
    }

    /// The initial balance can be spent immediately, then requests fail.
    #[test]
    fn basic_rate() {
        let rl = Ratelimiter::builder(1000)
            .initial_available(10)
            .build()
            .unwrap();
        for _ in 0..10 {
            assert!(rl.try_wait().is_ok());
        }
        assert!(rl.try_wait().is_err());
    }

    /// Roughly `rate * elapsed` tokens become available after sleeping.
    #[test]
    fn refill_over_time() {
        let rl = Ratelimiter::new(1000);
        std::thread::sleep(Duration::from_millis(100));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count >= 50, "expected >= 50, got {count}");
        assert!(count <= 200, "expected <= 200, got {count}");
    }

    /// A full bucket allows exactly `max_tokens` back-to-back requests.
    #[test]
    fn burst_capacity() {
        let rl = Ratelimiter::builder(100)
            .max_tokens(10)
            .initial_available(10)
            .build()
            .unwrap();
        for _ in 0..10 {
            assert!(rl.try_wait().is_ok());
        }
        assert!(rl.try_wait().is_err());
    }

    /// Idle time never accumulates more than the bucket capacity.
    #[test]
    fn idle_does_not_exceed_capacity() {
        let rl = Ratelimiter::builder(1000).max_tokens(10).build().unwrap();
        std::thread::sleep(Duration::from_millis(100));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count <= 10, "expected <= 10, got {count}");
    }

    /// The rate can be changed while the limiter is in use.
    #[test]
    fn set_rate() {
        let rl = Ratelimiter::new(100);
        std::thread::sleep(Duration::from_millis(50));
        rl.set_rate(1000);
        std::thread::sleep(Duration::from_millis(50));
        let mut count = 0;
        while rl.try_wait().is_ok() {
            count += 1;
        }
        assert!(count >= 30, "expected >= 30, got {count}");
    }

    /// Lowering the capacity also lowers an over-full balance.
    #[test]
    fn set_max_tokens_clamps_down() {
        let rl = Ratelimiter::builder(1000)
            .max_tokens(100)
            .initial_available(100)
            .build()
            .unwrap();
        assert_eq!(rl.available(), 100);
        rl.set_max_tokens(10);
        assert!(rl.available() <= 10);
    }

    /// An empty bucket reports how long to wait for the next token.
    ///
    /// The hint shrinks with every microsecond that passes between
    /// construction and the call, so it cannot be compared for equality. At
    /// one token per second it is at most one second, and only a stall of
    /// hundreds of milliseconds could push it below the lower bound.
    #[test]
    fn try_wait_returns_duration_hint() {
        let rl = Ratelimiter::new(1);
        let hint = rl.try_wait().unwrap_err();
        assert!(hint <= Duration::from_secs(1), "hint too long: {hint:?}");
        assert!(
            hint >= Duration::from_millis(500),
            "hint too short: {hint:?}"
        );
    }

    /// The builder rejects an initial balance above the capacity.
    #[test]
    fn builder_error_available_too_high() {
        let result = Ratelimiter::builder(100)
            .max_tokens(10)
            .initial_available(20)
            .build();
        assert!(matches!(result, Err(Error::AvailableTokensTooHigh)));
    }

    /// Tokens accrued beyond the capacity are counted as dropped.
    #[test]
    fn dropped_tokens() {
        let rl = Ratelimiter::builder(1000).max_tokens(10).build().unwrap();
        std::thread::sleep(Duration::from_millis(100));
        let _ = rl.try_wait();
        assert!(rl.dropped() > 0, "expected dropped > 0");
    }

    /// Sleeping for the returned hint sustains roughly the configured rate.
    #[test]
    fn wait_loop() {
        let rl = Ratelimiter::new(10_000);
        let start = std::time::Instant::now();
        let mut count = 0;
        while start.elapsed() < Duration::from_millis(100) {
            match rl.try_wait() {
                Ok(()) => count += 1,
                Err(wait) => std::thread::sleep(wait),
            }
        }
        assert!(count >= 500, "expected >= 500, got {count}");
        assert!(count <= 2000, "expected <= 2000, got {count}");
    }

    /// Several threads sharing one limiter stay within the overall rate.
    #[test]
    fn multithread() {
        use std::sync::Arc;
        let rl = Arc::new(
            Ratelimiter::builder(10_000)
                .max_tokens(10_000)
                .build()
                .unwrap(),
        );
        let duration = Duration::from_millis(200);
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let rl = Arc::clone(&rl);
                std::thread::spawn(move || {
                    let start = std::time::Instant::now();
                    let mut count = 0u64;
                    while start.elapsed() < duration {
                        if rl.try_wait().is_ok() {
                            count += 1;
                        }
                    }
                    count
                })
            })
            .collect();
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total >= 1000, "expected >= 1000, got {total}");
        assert!(total <= 4000, "expected <= 4000, got {total}");
    }

    /// Very high rates do not overflow the refill arithmetic.
    #[test]
    fn high_rate() {
        let rl = Ratelimiter::new(1_000_000_000_000);
        std::thread::sleep(Duration::from_millis(10));
        assert!(rl.try_wait().is_ok());
    }

    /// Above one token per nanosecond the computed wait truncates to zero and
    /// must be clamped to at least 1 ns.
    ///
    /// A capacity of zero keeps the bucket empty no matter how much time has
    /// elapsed, so the `Err` branch is reached deterministically (at this
    /// rate a single microsecond would otherwise refill 10,000 tokens).
    #[test]
    fn try_wait_hint_at_high_rate() {
        let rl = Ratelimiter::new(10_000_000_000);
        rl.set_max_tokens(0);
        let err = rl.try_wait().unwrap_err();
        assert!(err >= Duration::from_nanos(1));
    }

    /// Switching from unlimited to a finite rate starts issuing tokens.
    #[test]
    fn unlimited_then_set_rate() {
        let rl = Ratelimiter::new(0);
        assert!(rl.try_wait().is_ok());
        rl.set_rate(1000);
        std::thread::sleep(Duration::from_millis(50));
        assert!(rl.try_wait().is_ok());
    }

    /// Going unlimited and back resets the capacity to the new rate.
    #[test]
    fn set_rate_to_zero_and_back() {
        let rl = Ratelimiter::new(1000);
        rl.set_rate(0);
        assert_eq!(rl.max_tokens(), u64::MAX);
        for _ in 0..100 {
            assert!(rl.try_wait().is_ok());
        }
        rl.set_rate(500);
        assert_eq!(rl.max_tokens(), 500);
        std::thread::sleep(Duration::from_millis(50));
        assert!(rl.try_wait().is_ok());
    }

    /// The builder rejects a zero capacity for a rate-limited bucket.
    #[test]
    fn builder_error_max_tokens_zero() {
        let result = Ratelimiter::builder(100).max_tokens(0).build();
        assert!(matches!(result, Err(Error::MaxTokensTooLow)));
    }

    /// A runtime capacity of zero blocks everything until it is raised again.
    #[test]
    fn max_tokens_zero() {
        let rl = Ratelimiter::new(1000);
        rl.set_max_tokens(0);
        std::thread::sleep(Duration::from_millis(10));
        assert!(rl.try_wait().is_err());
        rl.set_max_tokens(1000);
        std::thread::sleep(Duration::from_millis(10));
        assert!(rl.try_wait().is_ok());
    }
}