use std::sync::{Arc, Mutex};
use crate::{Handler, MiddleWareHandler, Next, Request, Response, Result, SilentError, StatusCode};
use async_trait::async_trait;
use http::header::RETRY_AFTER;
/// Mutable token-bucket state shared by every clone of a [`RateLimiter`].
#[derive(Debug)]
struct BucketState {
    /// Tokens currently available; fractional so partial refills accumulate.
    tokens: f64,
    /// Moment the bucket was last topped up.
    last_refill: std::time::Instant,
}

/// Token-bucket rate limiter middleware.
///
/// Tokens refill continuously at `rate` per second, up to `capacity`, and each
/// admitted request consumes one token. The bucket lives behind an `Arc`, so
/// cloning a limiter does not create a new quota: all clones draw from the same
/// bucket.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    state: Arc<Mutex<BucketState>>,
    rate: f64,
    capacity: usize,
}

impl RateLimiter {
    /// Creates a limiter that refills `rate` tokens per second and holds at most
    /// `capacity` tokens.
    ///
    /// The bucket starts full, so up to `capacity` requests are admitted
    /// immediately. `rate` is expected to be finite and non-negative. A rate of
    /// `0.0` gives a fixed budget of `capacity` requests that never refills.
    pub fn new(rate: f64, capacity: usize) -> Self {
        Self {
            state: Arc::new(Mutex::new(BucketState {
                tokens: capacity as f64,
                last_refill: std::time::Instant::now(),
            })),
            rate,
            capacity,
        }
    }

    /// Creates a limiter allowing `rate` requests per second.
    ///
    /// The burst capacity is `rate` rounded up to a whole number of tokens.
    pub fn per_second(rate: f64) -> Self {
        Self::new(rate, rate.ceil() as usize)
    }

    /// Locks the bucket state, recovering it if a previous holder panicked.
    ///
    /// The state holds only plain values that are each assigned atomically, so
    /// a poisoned lock is safe to keep using.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, BucketState> {
        self.state.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Adds the tokens accrued since the last refill and advances the refill
    /// timestamp to `now`.
    ///
    /// The result is capped at `capacity`. Splitting an interval across several
    /// calls yields the same token count as a single call covering it.
    fn refill(&self, state: &mut BucketState, now: std::time::Instant) {
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            state.tokens = (state.tokens + elapsed * self.rate).min(self.capacity as f64);
            state.last_refill = now;
        }
    }

    /// Tries to take one token; returns `true` if the request may proceed.
    fn try_acquire(&self) -> bool {
        let mut state = self.lock_state();
        self.refill(&mut state, std::time::Instant::now());
        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Returns the whole seconds until at least one token will be available.
    ///
    /// Returns 0 if a token is already available. The bucket is refilled
    /// first, so the estimate accounts for time elapsed since the last
    /// acquisition attempt instead of using a stale token count.
    fn retry_after_secs(&self) -> u64 {
        let mut state = self.lock_state();
        self.refill(&mut state, std::time::Instant::now());
        let deficit = 1.0 - state.tokens;
        if deficit <= 0.0 {
            return 0;
        }
        // With `rate == 0.0` this is +inf, which the float-to-int cast saturates
        // to `u64::MAX` (the bucket will never refill).
        (deficit / self.rate).ceil() as u64
    }
}
#[async_trait]
impl MiddleWareHandler for RateLimiter {
async fn handle(&self, req: Request, next: &Next) -> Result<Response> {
if self.try_acquire() {
next.call(req).await
} else {
let retry_after = self.retry_after_secs().max(1);
tracing::debug!(retry_after, "rate limit exceeded");
let mut err = SilentError::business_error(
StatusCode::TOO_MANY_REQUESTS,
"Too Many Requests".to_string(),
);
if let SilentError::BusinessError { .. } = &mut err {
}
let mut res = Response::empty();
res.set_status(StatusCode::TOO_MANY_REQUESTS);
res.headers_mut()
.insert(RETRY_AFTER, retry_after.to_string().parse().unwrap());
res.set_body(crate::core::res_body::full("Too Many Requests"));
Ok(res)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_new() {
        let rl = RateLimiter::new(10.0, 20);
        assert_eq!(rl.rate, 10.0);
        assert_eq!(rl.capacity, 20);
    }

    #[test]
    fn test_rate_limiter_per_second() {
        let rl = RateLimiter::per_second(5.0);
        assert_eq!(rl.rate, 5.0);
        assert_eq!(rl.capacity, 5);
    }

    #[test]
    fn test_rate_limiter_per_second_fractional() {
        // Burst capacity is rounded up from the fractional rate.
        let rl = RateLimiter::per_second(1.5);
        assert_eq!(rl.rate, 1.5);
        assert_eq!(rl.capacity, 2);
    }

    #[test]
    fn test_rate_limiter_clone() {
        let rl1 = RateLimiter::new(10.0, 20);
        let rl2 = rl1.clone();
        assert_eq!(rl1.rate, rl2.rate);
        assert_eq!(rl1.capacity, rl2.capacity);
        assert!(Arc::ptr_eq(&rl1.state, &rl2.state));
    }

    #[test]
    fn test_rate_limiter_clone_shares_quota() {
        // A near-zero rate keeps refills negligible for the duration of the test.
        let rl1 = RateLimiter::new(0.001, 1);
        let rl2 = rl1.clone();
        assert!(rl1.try_acquire());
        assert!(!rl2.try_acquire());
    }

    #[test]
    fn test_try_acquire_success() {
        let rl = RateLimiter::new(10.0, 5);
        for _ in 0..5 {
            assert!(rl.try_acquire());
        }
    }

    #[test]
    fn test_try_acquire_exhausted() {
        let rl = RateLimiter::new(10.0, 2);
        assert!(rl.try_acquire());
        assert!(rl.try_acquire());
        assert!(!rl.try_acquire());
    }

    #[test]
    fn test_try_acquire_refill() {
        let rl = RateLimiter::new(1000.0, 1);
        assert!(rl.try_acquire());
        assert!(!rl.try_acquire());
        // 5 ms at 1000 tokens/s is enough to refill the single slot.
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert!(rl.try_acquire());
    }

    #[test]
    fn test_try_acquire_capacity_cap() {
        let rl = RateLimiter::new(10000.0, 3);
        // Plenty of time to over-refill; the bucket must still be capped at 3.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(rl.try_acquire());
        assert!(rl.try_acquire());
        assert!(rl.try_acquire());
        assert!(!rl.try_acquire());
    }

    #[test]
    fn test_retry_after_secs_with_tokens() {
        let rl = RateLimiter::new(10.0, 5);
        assert_eq!(rl.retry_after_secs(), 0);
    }

    #[test]
    fn test_retry_after_secs_exhausted() {
        let rl = RateLimiter::new(1.0, 1);
        assert!(rl.try_acquire());
        let retry = rl.retry_after_secs();
        assert!(retry >= 1);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_rate_limiter_allows_request() {
        use crate::route::Route;
        let rl = RateLimiter::new(100.0, 10);
        let route = Route::new("/")
            .hook(rl)
            .get(|_req: Request| async { Ok("ok") });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let res: Result<Response> = crate::Handler::call(&route, req).await;
        assert!(res.is_ok());
        let resp = res.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_rate_limiter_blocks_excess() {
        use crate::route::Route;
        let rl = RateLimiter::new(0.001, 1);
        let route = Route::new("/")
            .hook(rl)
            .get(|_req: Request| async { Ok("ok") });
        let route = Route::new_root().append(route);

        let req1 = Request::empty();
        let res1: Result<Response> = crate::Handler::call(&route, req1).await;
        assert!(res1.is_ok());
        assert_eq!(res1.unwrap().status(), StatusCode::OK);

        let req2 = Request::empty();
        let res2: Result<Response> = crate::Handler::call(&route, req2).await;
        assert!(res2.is_ok());
        let resp2 = res2.unwrap();
        assert_eq!(resp2.status(), StatusCode::TOO_MANY_REQUESTS);
        let retry_after: u64 = resp2
            .headers()
            .get(RETRY_AFTER)
            .expect("429 response must carry Retry-After")
            .to_str()
            .expect("Retry-After must be visible ASCII")
            .parse()
            .expect("Retry-After must be an integer number of seconds");
        assert!(retry_after >= 1);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_rate_limiter_shared_across_requests() {
        use crate::route::Route;
        let rl = RateLimiter::new(0.001, 2);
        let route = Route::new("/")
            .hook(rl)
            .get(|_req: Request| async { Ok("ok") });
        let route = Arc::new(Route::new_root().append(route));

        let req1 = Request::empty();
        let res1: Result<Response> = crate::Handler::call(&*route, req1).await;
        assert_eq!(res1.unwrap().status(), StatusCode::OK);

        let req2 = Request::empty();
        let res2: Result<Response> = crate::Handler::call(&*route, req2).await;
        assert_eq!(res2.unwrap().status(), StatusCode::OK);

        let req3 = Request::empty();
        let res3: Result<Response> = crate::Handler::call(&*route, req3).await;
        assert_eq!(res3.unwrap().status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_rate_limiter_concurrent() {
        use crate::route::Route;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let rl = RateLimiter::new(0.001, 5);
        // Counts how many requests actually reached the inner handler.
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let route = Route::new("/").hook(rl).get(move |_req: Request| {
            let c = counter_clone.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok("ok")
            }
        });
        let route = Arc::new(Route::new_root().append(route));

        let mut tasks = Vec::new();
        for _ in 0..10 {
            let route = Arc::clone(&route);
            tasks.push(tokio::spawn(async move {
                let req = Request::empty();
                let res: Result<Response> = crate::Handler::call(&*route, req).await;
                res.unwrap().status()
            }));
        }

        let mut ok_count = 0;
        let mut limited_count = 0;
        for task in tasks {
            match task.await.unwrap() {
                StatusCode::OK => ok_count += 1,
                StatusCode::TOO_MANY_REQUESTS => limited_count += 1,
                _ => panic!("unexpected status"),
            }
        }
        assert_eq!(ok_count, 5);
        assert_eq!(limited_count, 5);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}