use std::{
fmt::Debug,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use async_trait::async_trait;
use crossbeam_utils::Backoff;
use tokio::time::timeout;
use crate::algorithms::{RateLimitAlgorithm, RequestSample};
/// Count of requests/tokens; also used for rates expressed in requests per second.
type RequestCount = u64;
/// Permit handed out by [`RateLimiter::acquire`]; records when the wait for
/// the permit ended so `release` can compute the observed response time.
#[derive(Debug)]
pub struct Token {
    // Set to `Instant::now()` at the moment of acquisition.
    start_time: Instant,
}
/// Asynchronous rate limiter: callers obtain a [`Token`] before issuing a
/// request and hand it back via [`release`](RateLimiter::release) together
/// with the request's outcome.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Waits until a token is available and returns it.
    async fn acquire(&self) -> Token;
    /// Like [`acquire`](RateLimiter::acquire), but gives up after
    /// `duration`, returning `None` on timeout.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;
    /// Reports completion of the request associated with `token`. When
    /// `outcome` is provided, implementations may feed it to an adaptive
    /// algorithm. NOTE(review): in `DefaultRateLimiter` this does NOT put a
    /// token back into the bucket — tokens are replenished by time.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
/// Token-bucket rate limiter whose rate is adjusted by an adaptive
/// algorithm `T` after each released request.
#[derive(Debug)]
pub struct DefaultRateLimiter<T> {
    // Algorithm consulted on `release` to compute the next rate.
    algorithm: T,
    // Tokens currently available in the bucket.
    tokens: Arc<AtomicU64>,
    // Wall-clock time (nanoseconds since UNIX epoch) of the last refill.
    last_refill_nanos: Arc<AtomicU64>,
    // Most recent rate reported by the algorithm.
    requests_per_second: Arc<AtomicU64>,
    // Maximum tokens the bucket can hold; fixed at construction.
    bucket_capacity: RequestCount,
    // Nanoseconds between single-token refills (1e9 / rps).
    refill_interval_nanos: Arc<AtomicU64>,
}
/// Point-in-time snapshot of a [`DefaultRateLimiter`], produced by
/// [`DefaultRateLimiter::state`].
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    // Rate reported by the algorithm at snapshot time.
    requests_per_second: RequestCount,
    // Tokens available in the bucket at snapshot time.
    available_tokens: RequestCount,
    // Fixed bucket capacity.
    bucket_capacity: RequestCount,
}
/// Classification of a completed request, fed back to the rate algorithm
/// via [`RateLimiter::release`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// Request completed successfully.
    Success,
    /// Request failed; presumably indicates downstream overload — the exact
    /// semantics are defined by the consuming algorithm.
    Overload,
    /// Request failed due to a client-side error.
    ClientError,
}
impl<T> DefaultRateLimiter<T>
where
    T: RateLimitAlgorithm,
{
    /// Creates a limiter whose bucket starts full at the algorithm's
    /// initial rate (capacity = initial requests/second).
    ///
    /// # Panics
    /// Panics if the algorithm reports a rate below 1 request/second.
    pub fn new(algorithm: T) -> Self {
        let initial_rps = algorithm.requests_per_second();
        // Validate before deriving anything from the rate so the division
        // below can never be by zero.
        assert!(initial_rps >= 1);
        let bucket_capacity = initial_rps;
        // NOTE(review): SystemTime is not monotonic; backwards clock jumps
        // are tolerated via saturating_sub in refill_tokens.
        let now_nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos() as u64;
        Self {
            algorithm,
            tokens: Arc::new(AtomicU64::new(bucket_capacity)),
            last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
            requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
            bucket_capacity,
            refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
        }
    }

    /// Credits any tokens that have accrued since the last refill, capped
    /// at `bucket_capacity`.
    ///
    /// Exactly one thread wins the CAS on `last_refill_nanos` for a given
    /// elapsed window and credits its tokens; losers return without adding.
    /// (The previous implementation credited tokens even when the CAS
    /// failed, letting two racing threads credit the same window twice.)
    #[inline]
    fn refill_tokens(&self) {
        if self.tokens.load(Ordering::Relaxed) >= self.bucket_capacity {
            return;
        }
        let now_nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos() as u64;
        let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
        let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
        let elapsed_nanos = now_nanos.saturating_sub(last_refill);
        let tokens_to_add = if refill_interval == 0 {
            // An interval of 0ns (rps > 1e9) means more than one token
            // accrues per nanosecond: the bucket is effectively always
            // refillable to capacity. Also avoids dividing by zero.
            self.bucket_capacity
        } else {
            elapsed_nanos / refill_interval
        };
        if tokens_to_add == 0 {
            return;
        }
        // Advance the refill clock by exactly the whole intervals we are
        // about to credit; the fractional remainder stays banked for the
        // next refill instead of being discarded (no long-term drift).
        // tokens_to_add * refill_interval <= elapsed_nanos, so no overflow.
        let new_last = last_refill + tokens_to_add * refill_interval;
        // Strong CAS: a spurious failure would silently drop this refill
        // window, so compare_exchange_weak is not appropriate here.
        if self
            .last_refill_nanos
            .compare_exchange(last_refill, new_last, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            // Another thread claimed this window and will credit it.
            return;
        }
        self.tokens
            .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
                let new_tokens = current.saturating_add(tokens_to_add).min(self.bucket_capacity);
                // Returning None (bucket already full) skips the store.
                if new_tokens > current {
                    Some(new_tokens)
                } else {
                    None
                }
            })
            .ok();
    }

    /// Returns a snapshot of the limiter (rate, available tokens,
    /// capacity) after refreshing the bucket.
    pub fn state(&self) -> RateLimiterState {
        self.refill_tokens();
        RateLimiterState {
            requests_per_second: self.algorithm.requests_per_second(),
            available_tokens: self.tokens.load(Ordering::Acquire),
            bucket_capacity: self.bucket_capacity,
        }
    }
}
#[async_trait]
impl<T> RateLimiter for DefaultRateLimiter<T>
where
    T: RateLimitAlgorithm + Sync + Debug,
{
    /// Spins (with exponential backoff) until a token can be taken from
    /// the bucket, refilling opportunistically between attempts. Yields to
    /// the tokio scheduler once the backoff is exhausted so other tasks —
    /// and `acquire_timeout`'s timer — can make progress.
    async fn acquire(&self) -> Token {
        // Atomically take one token. `checked_sub` yields `None` at zero,
        // making `fetch_update` bail out exactly like the previous
        // hand-written `current > 0` guard — deduplicated here because the
        // identical closure appeared twice in the loop body.
        let try_take = || {
            self.tokens
                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
                    current.checked_sub(1)
                })
                .is_ok()
        };
        let backoff = Backoff::new();
        loop {
            if try_take() {
                return Token {
                    start_time: Instant::now(),
                };
            }
            // Bucket looked empty: credit any elapsed refill windows and
            // retry once before backing off.
            self.refill_tokens();
            if try_take() {
                return Token {
                    start_time: Instant::now(),
                };
            }
            if backoff.is_completed() {
                // Backoff budget exhausted: yield so we don't starve the
                // runtime (and so a surrounding `timeout` can fire).
                tokio::task::yield_now().await;
                backoff.reset();
            } else {
                backoff.spin();
            }
        }
    }

    /// Same as [`RateLimiter::acquire`], but abandons the wait after
    /// `duration`, returning `None` on timeout.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
        timeout(duration, self.acquire()).await.ok()
    }

    /// Feeds the request's response time and `outcome` to the algorithm
    /// and applies the rate it returns. Tokens are replenished by elapsed
    /// time (see `refill_tokens`), not by this call.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
        let response_time = token.start_time.elapsed();
        if let Some(outcome) = outcome {
            let current_rps = self.requests_per_second.load(Ordering::Relaxed);
            let sample = RequestSample::new(response_time, current_rps, outcome);
            let new_rps = self.algorithm.update(sample).await;
            self.requests_per_second.store(new_rps, Ordering::Relaxed);
            // Only recompute the interval for a non-zero rate; a zero rate
            // would divide by zero here. An unchanged rate is skipped too.
            if new_rps != current_rps && new_rps > 0 {
                self.refill_interval_nanos
                    .store(1_000_000_000 / new_rps, Ordering::Relaxed);
            }
        }
    }
}
impl RateLimiterState {
    /// Maximum number of tokens the bucket can hold.
    pub fn bucket_capacity(&self) -> RequestCount {
        self.bucket_capacity
    }

    /// Tokens that were available when this snapshot was taken.
    pub fn available_tokens(&self) -> RequestCount {
        self.available_tokens
    }

    /// Request rate (per second) reported by the algorithm at snapshot time.
    pub fn requests_per_second(&self) -> RequestCount {
        self.requests_per_second
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    /// With spare capacity, a token is acquired and released without blocking.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));
        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;
    }

    /// With a capacity of one, a second `acquire` must wait until the bucket
    /// refills before it can complete.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;
        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));
        let first = limiter.acquire().await;
        let waiter = {
            let limiter = Arc::clone(&limiter);
            tokio::spawn(async move { limiter.acquire().await })
        };
        tokio::time::sleep(Duration::from_millis(10)).await;
        limiter.release(first, Some(RequestOutcome::Success)).await;
        let second = waiter.await.unwrap();
        limiter.release(second, Some(RequestOutcome::Success)).await;
    }
}