use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Parameters for a [`RateLimiter`] token bucket.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Refill rate, in tokens per second.
    pub tokens_per_second: f64,
    /// Capacity of the bucket, i.e. the largest number of tokens that can be held at once.
    pub bucket_size: u32,
    /// Tokens present when the limiter is created; `None` starts with a full bucket.
    pub initial_tokens: Option<u32>,
}

impl Default for RateLimitConfig {
    /// 1000 tokens per second with a bucket of 100, starting full.
    fn default() -> Self {
        Self::preset(1000.0, 100)
    }
}

impl RateLimitConfig {
    /// Shared constructor for the presets below: given rate and capacity, full bucket.
    fn preset(tokens_per_second: f64, bucket_size: u32) -> Self {
        Self {
            tokens_per_second,
            bucket_size,
            initial_tokens: None,
        }
    }

    /// High-throughput preset: 10 000 tokens per second, bucket of 1 000.
    pub fn permissive() -> Self {
        Self::preset(10_000.0, 1000)
    }

    /// Conservative preset: 100 tokens per second, bucket of 10.
    pub fn strict() -> Self {
        Self::preset(100.0, 10)
    }

    /// Very large preset: 10^9 tokens per second, bucket of 1 000 000.
    ///
    /// The bucket is still finite, so a single request for more than 1 000 000 tokens
    /// cannot be satisfied.
    pub fn unlimited() -> Self {
        Self::preset(1_000_000_000.0, 1_000_000)
    }
}
/// Token-bucket rate limiter whose shared state lives in atomics (no mutex).
///
/// Token counts are fixed-point: one token is `RateLimiter::SCALE` stored units.
pub struct RateLimiter {
    /// Reference instant; `last_refill` is measured in nanoseconds after it.
    start: Instant,
    /// Nanoseconds after `start` up to which refills have been credited.
    last_refill: AtomicU64,
    /// Tokens currently in the bucket, in scaled units.
    tokens: AtomicU64,
    /// Bucket capacity in scaled units (`bucket_size * SCALE`).
    max_tokens: u64,
    /// Refill rate in scaled units per nanosecond.
    tokens_per_ns: f64,
}
impl RateLimiter {
    /// Fixed-point scale: tokens are stored in units of 1/SCALE of a token so that refills
    /// can credit fractions of a token.
    const SCALE: u64 = 1000;

    /// Creates a limiter from `config`.
    ///
    /// The bucket starts with `config.initial_tokens` tokens, or full when that is `None`.
    /// An `initial_tokens` larger than `bucket_size` is clamped to `bucket_size`, so the
    /// bucket never starts above its own capacity.
    pub fn new(config: RateLimitConfig) -> Self {
        let max_tokens = u64::from(config.bucket_size) * Self::SCALE;
        let initial = config
            .initial_tokens
            .map_or(max_tokens, |t| (u64::from(t) * Self::SCALE).min(max_tokens));
        Self {
            tokens: AtomicU64::new(initial),
            last_refill: AtomicU64::new(0),
            // tokens/second -> scaled units per nanosecond.
            tokens_per_ns: config.tokens_per_second * (Self::SCALE as f64) / 1_000_000_000.0,
            max_tokens,
            start: Instant::now(),
        }
    }

    /// Takes a single token if one is available; see [`Self::try_acquire_n`].
    #[inline]
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_n(1)
    }

    /// Refills, then takes `n` tokens at once if at least `n` are available.
    ///
    /// All-or-nothing: returns `false` and leaves the bucket untouched when fewer than `n`
    /// tokens are present. `n == 0` always succeeds.
    #[inline]
    pub fn try_acquire_n(&self, n: u32) -> bool {
        let cost = u64::from(n) * Self::SCALE;
        self.refill();
        // `checked_sub` returns `None` when the bucket is short, which aborts the update.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(cost)
            })
            .is_ok()
    }

    /// Number of whole tokens currently available, after refilling.
    pub fn available(&self) -> u32 {
        self.refill();
        // Cannot truncate: the stored amount never exceeds `bucket_size * SCALE` and
        // `bucket_size` is a `u32`.
        (self.tokens.load(Ordering::Relaxed) / Self::SCALE) as u32
    }

    /// `true` when less than one whole token is available (after refilling).
    pub fn is_empty(&self) -> bool {
        self.available() == 0
    }

    /// Credits the tokens earned since the last refill, capped at the bucket capacity.
    ///
    /// Concurrent callers race on `last_refill` with a compare-exchange. Only the winner adds
    /// tokens, so each span of time is credited at most once.
    fn refill(&self) {
        let now_ns = self.start.elapsed().as_nanos() as u64;
        let last = self.last_refill.load(Ordering::Relaxed);
        if now_ns <= last {
            return;
        }
        let elapsed_ns = now_ns - last;
        // Float -> int `as` saturates (and maps NaN/negatives to 0).
        let new_tokens = (elapsed_ns as f64 * self.tokens_per_ns) as u64;
        if new_tokens == 0 {
            // Not even one scaled unit earned yet: leave `last_refill` alone so time keeps
            // accumulating.
            return;
        }
        // Advance the clock only by the time that `new_tokens` actually paid for. The
        // fractional unit truncated above is then carried into the next refill instead of
        // being discarded on every call, which otherwise lowers the effective rate for
        // frequent callers.
        // - Rounding up never over-credits.
        // - `min` guards against float error overshooting `now_ns`.
        // - The result is >= 1 here because `new_tokens >= 1` implies a positive rate.
        let paid_ns = ((new_tokens as f64 / self.tokens_per_ns).ceil() as u64).min(elapsed_ns);
        if self
            .last_refill
            .compare_exchange(last, last + paid_ns, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another caller already credited this interval.
            return;
        }
        // `saturating_add` prevents overflow when a huge rate meets a long idle period.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                let new_total = current.saturating_add(new_tokens).min(self.max_tokens);
                if new_total == current {
                    None
                } else {
                    Some(new_total)
                }
            });
    }

    /// Waits until one token can be taken; see [`Self::acquire_n`].
    pub async fn acquire(&self) -> Duration {
        self.acquire_n(1).await
    }

    /// Waits asynchronously until `n` tokens can be taken at once, then takes them.
    ///
    /// Polls [`Self::try_acquire_n`]. Between attempts it sleeps for roughly the time the
    /// bucket needs to cover the current shortfall, clamped to the range 1 ms to 100 ms.
    /// Returns the total time spent in this call. If the configured rate is zero or negative
    /// and the bucket is short, no tokens are ever added and this never returns.
    ///
    /// # Panics
    ///
    /// Panics if `n` exceeds the bucket size. Such a request can never be satisfied and
    /// would otherwise wait forever.
    pub async fn acquire_n(&self, n: u32) -> Duration {
        let cost = u64::from(n) * Self::SCALE;
        assert!(
            cost <= self.max_tokens,
            "acquire_n({}) can never succeed: the bucket holds at most {} tokens",
            n,
            self.max_tokens / Self::SCALE
        );
        let start = Instant::now();
        while !self.try_acquire_n(n) {
            // Wait for the missing amount only, not for the whole request.
            let deficit = cost.saturating_sub(self.tokens.load(Ordering::Relaxed));
            let wait_ns = (deficit as f64 / self.tokens_per_ns).ceil() as u64;
            let wait = Duration::from_nanos(wait_ns.max(1_000_000));
            tokio::time::sleep(wait.min(Duration::from_millis(100))).await;
        }
        start.elapsed()
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new(RateLimitConfig::default())
}
}
impl std::fmt::Debug for RateLimiter {
    /// Shows whole-token counts: currently available tokens and bucket capacity.
    ///
    /// Note that reading `available` triggers a refill.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let available = self.available();
        let capacity = self.max_tokens / Self::SCALE;
        let mut out = f.debug_struct("RateLimiter");
        out.field("available", &available);
        out.field("max_tokens", &capacity);
        out.finish()
    }
}
/// Outcome of a non-blocking rate-limit check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateLimitResult {
    /// A token was available and has been consumed.
    Allowed,
    /// No token was available.
    Denied {
        /// Estimated wait before retrying, derived from the limiter's refill rate.
        retry_after: Duration,
    },
}
impl RateLimiter {
pub fn check(&self) -> RateLimitResult {
if self.try_acquire() {
RateLimitResult::Allowed
} else {
let deficit = Self::SCALE; let wait_ns = (deficit as f64 / self.tokens_per_ns) as u64;
RateLimitResult::Denied {
retry_after: Duration::from_nanos(wait_ns),
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(RateLimitConfig {
            tokens_per_second: 10.0,
            bucket_size: 5,
            initial_tokens: Some(5),
        });
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        // At 10 tokens/s the microseconds spent above refill far less than one token.
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_rate_limiter_unlimited() {
        let limiter = RateLimiter::new(RateLimitConfig::unlimited());
        for _ in 0..10000 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn test_rate_limiter_available() {
        let limiter = RateLimiter::new(RateLimitConfig {
            tokens_per_second: 100.0,
            bucket_size: 10,
            initial_tokens: Some(10),
        });
        assert_eq!(limiter.available(), 10);
        // Check the result directly so a failed acquire is reported here, not as a
        // confusing count mismatch below.
        assert!(limiter.try_acquire_n(3), "a full bucket of 10 must cover 3");
        assert_eq!(limiter.available(), 7);
    }

    #[test]
    fn test_try_acquire_n_is_all_or_nothing() {
        let limiter = RateLimiter::new(RateLimitConfig {
            tokens_per_second: 1.0,
            bucket_size: 10,
            initial_tokens: Some(10),
        });
        // Asking for more than is present must fail without taking a partial amount.
        assert!(!limiter.try_acquire_n(11));
        assert_eq!(limiter.available(), 10);
    }

    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let limiter = RateLimiter::new(RateLimitConfig {
            tokens_per_second: 1000.0,
            bucket_size: 10,
            initial_tokens: Some(0),
        });
        assert!(!limiter.try_acquire());
        // 20 ms at 1000 tokens/s earns ~20 tokens, capped at the bucket size of 10.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_rate_limit_result() {
        let limiter = RateLimiter::new(RateLimitConfig {
            tokens_per_second: 10.0,
            bucket_size: 1,
            initial_tokens: Some(1),
        });
        assert_eq!(limiter.check(), RateLimitResult::Allowed);
        match limiter.check() {
            RateLimitResult::Denied { retry_after } => {
                assert!(retry_after > Duration::ZERO);
            }
            RateLimitResult::Allowed => panic!("Expected Denied"),
        }
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn bucket_never_exceeds_max(
            tokens_per_sec in 1.0f64..10000.0,
            bucket_size in 1u32..1000,
            initial in 0u32..1000,
        ) {
            let initial = initial.min(bucket_size);
            let limiter = RateLimiter::new(RateLimitConfig {
                tokens_per_second: tokens_per_sec,
                bucket_size,
                initial_tokens: Some(initial),
            });
            prop_assert!(limiter.available() <= bucket_size);
        }

        #[test]
        fn acquire_reduces_available(
            bucket_size in 10u32..100,
            acquire_count in 1u32..10,
        ) {
            let limiter = RateLimiter::new(RateLimitConfig {
                tokens_per_second: 1000.0,
                bucket_size,
                initial_tokens: Some(bucket_size),
            });
            let before = limiter.available();
            // The bucket starts full and `acquire_count < bucket_size`, so the acquire must
            // succeed; asserting it keeps the property from passing vacuously.
            prop_assert!(limiter.try_acquire_n(acquire_count));
            let after = limiter.available();
            prop_assert_eq!(after, before - acquire_count);
        }

        #[test]
        fn unlimited_always_allows(count in 1u32..10000) {
            let limiter = RateLimiter::new(RateLimitConfig::unlimited());
            for _ in 0..count {
                prop_assert!(limiter.try_acquire());
            }
        }
    }

    /// The presets are constants, so this is a plain unit test. Under `proptest!` it only
    /// repeated the same check once per generated (unused) `bool`.
    #[test]
    fn strict_less_than_permissive() {
        let strict = RateLimitConfig::strict();
        let permissive = RateLimitConfig::permissive();
        assert!(strict.tokens_per_second < permissive.tokens_per_second);
        assert!(strict.bucket_size < permissive.bucket_size);
    }
}