use async_trait::async_trait;
use dashmap::DashMap;
use http::Extensions;
use rand::Rng;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result as MiddlewareResult};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Instant;
use tokio::time::sleep;
use crate::builder::RateLimitBuilder;
use crate::error::RateLimitError;
use crate::gcra::GcraState;
use crate::types::{Route, RouteKey, ThrottleBehavior};
/// `reqwest_middleware` layer that throttles outgoing requests using per-route GCRA limits.
///
/// Cloning is cheap: the route table and the limiter state sit behind `Arc`s, so every clone
/// shares the same state and the same `start_instant` clock origin.
#[derive(Clone, Debug)]
pub struct RateLimitMiddleware {
    /// Configured routes, scanned in order; `RouteKey::route_index` indexes into this list.
    pub(crate) routes: Arc<Vec<Route>>,
    /// Limiter state for each (route, limit) pair, inserted on first use.
    pub(crate) state: Arc<DashMap<RouteKey, GcraState>>,
    /// Clock origin: timestamps handed to the GCRA state are nanoseconds since this instant.
    pub(crate) start_instant: Instant,
}
/// Converts `duration` to whole nanoseconds, saturating at `u64::MAX`.
///
/// `Duration::as_nanos` returns a `u128`. A bare `as u64` cast keeps only the low 64 bits, so
/// durations longer than ~584 years (e.g. `Duration::MAX` used as an "unbounded" window) would
/// turn into an arbitrary, possibly tiny, value instead of the largest representable one.
#[inline]
fn duration_as_nanos_saturating(duration: std::time::Duration) -> u64 {
    u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
}

/// Returns `wait` extended by a uniformly random jitter in `0..=wait/2`.
///
/// The random extra delay spreads out the wake-ups of tasks that were delayed by the same
/// limit. The addition saturates instead of panicking on overflow.
fn add_jitter(wait: std::time::Duration) -> std::time::Duration {
    let jitter_max_nanos = duration_as_nanos_saturating(wait) / 2;
    let jitter_nanos = if jitter_max_nanos > 0 {
        rand::rng().random_range(0..=jitter_max_nanos)
    } else {
        0
    };
    wait.saturating_add(std::time::Duration::from_nanos(jitter_nanos))
}

impl RateLimitMiddleware {
    /// Returns a [`RateLimitBuilder`] used to configure routes and their limits.
    #[must_use]
    pub fn builder() -> RateLimitBuilder {
        RateLimitBuilder::new()
    }

    /// Nanoseconds elapsed since `start_instant`, saturating at `u64::MAX`.
    ///
    /// This is the clock value passed to `GcraState::try_acquire` and compared against stored
    /// TATs in [`Self::cleanup`].
    #[inline]
    pub(crate) fn now_nanos(&self) -> u64 {
        duration_as_nanos_saturating(self.start_instant.elapsed())
    }

    /// Removes stored limiter state entries that are orphaned or stale.
    ///
    /// An entry is dropped when:
    /// - its `RouteKey` is out of range for the configured route table, or
    /// - its theoretical arrival time (TAT) lies at least two windows before "now".
    ///
    /// NOTE(review): dropping a stale entry assumes that a `GcraState` whose TAT is far in the
    /// past behaves like a fresh `GcraState::new()`. Confirm against `GcraState::try_acquire`.
    pub fn cleanup(&self) {
        let now = self.now_nanos();
        self.state.retain(|key, gcra_state| {
            let limit = match self
                .routes
                .get(key.route_index)
                .and_then(|route| route.limits.get(key.limit_index))
            {
                Some(limit) => limit,
                // The key does not refer to any configured route/limit pair.
                None => return false,
            };
            let window_nanos = duration_as_nanos_saturating(limit.window);
            let tat = gcra_state.tat(Ordering::Acquire);
            // Keep entries whose TAT is newer than two windows ago; saturate so neither the
            // doubling nor the subtraction can wrap.
            tat > now.saturating_sub(window_nanos.saturating_mul(2))
        });
    }

    /// Number of (route, limit) state entries currently stored.
    #[must_use]
    pub fn state_count(&self) -> usize {
        self.state.len()
    }

    /// Acquires one unit from every limit of every route that matches `req`.
    ///
    /// Routes and their limits are checked in declaration order against a single `now`
    /// timestamp. When a limit reports that it is exhausted:
    /// - with [`ThrottleBehavior::Delay`], the task sleeps for the reported wait plus up to
    ///   50% random jitter, then restarts the whole scan with a fresh `now`;
    /// - with [`ThrottleBehavior::Error`], the function returns immediately.
    ///
    /// Restarting the full scan means the final, successful pass admits the request against
    /// every limit at the same instant.
    ///
    /// NOTE(review): a successful `try_acquire` appears to commit a unit before later limits
    /// are checked. As a result, each restart charges the earlier limits again, and an `Error`
    /// return leaves them charged for a request that is never sent. Fixing this properly needs
    /// a check-then-commit (or rollback) operation on `GcraState`. Confirm against gcra.rs.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError::RateLimited`] carrying the reported wait when an exhausted
    /// limit belongs to a route configured with [`ThrottleBehavior::Error`].
    async fn check_and_apply_limits(&self, req: &Request) -> Result<(), RateLimitError> {
        'scan: loop {
            let now = self.now_nanos();
            for (route_index, route) in self.routes.iter().enumerate() {
                if !route.matches(req) {
                    continue;
                }
                for (limit_index, limit) in route.limits.iter().enumerate() {
                    let key = RouteKey {
                        route_index,
                        limit_index,
                    };
                    let emission_interval_nanos =
                        duration_as_nanos_saturating(limit.emission_interval());
                    let window_nanos = duration_as_nanos_saturating(limit.window);
                    // The DashMap guard returned by `entry` is a temporary of this statement,
                    // so it is released here, before any `.await` below.
                    let outcome = self
                        .state
                        .entry(key)
                        .or_insert_with(GcraState::new)
                        .try_acquire(now, emission_interval_nanos, window_nanos);
                    let wait_duration = match outcome {
                        Ok(()) => continue,
                        Err(wait_duration) => wait_duration,
                    };
                    match route.on_limit {
                        ThrottleBehavior::Delay => {
                            sleep(add_jitter(wait_duration)).await;
                            continue 'scan;
                        }
                        ThrottleBehavior::Error => {
                            return Err(RateLimitError::RateLimited(wait_duration));
                        }
                    }
                }
            }
            break Ok(());
        }
    }
}
#[async_trait]
impl Middleware for RateLimitMiddleware {
    /// Enforces every matching rate limit on `req`, then forwards it to the rest of the chain.
    ///
    /// If the limit check fails, its [`RateLimitError`] is converted into the middleware error
    /// type and returned, and the request is not forwarded.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> MiddlewareResult<Response> {
        if let Err(limit_error) = self.check_and_apply_limits(&req).await {
            return Err(limit_error.into());
        }
        next.run(req, extensions).await
    }
}
impl Default for RateLimitMiddleware {
fn default() -> Self {
Self::builder().build()
}
}