use super::config::RateLimitConfig;
use axum::{
extract::Request,
http::StatusCode,
response::{IntoResponse, Response},
};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use tower::{Layer, Service};
#[derive(serde::Serialize)]
struct RateLimitError {
error: String,
message: String,
retry_after: u64,
}
impl IntoResponse for RateLimitError {
fn into_response(self) -> Response {
(
StatusCode::TOO_MANY_REQUESTS,
[("Retry-After", self.retry_after.to_string())],
axum::Json(self),
)
.into_response()
}
}
#[derive(Clone)]
struct RateLimitState {
per_ip: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
global: Arc<Mutex<Vec<Instant>>>,
config: RateLimitConfig,
}
impl RateLimitState {
fn new(config: RateLimitConfig) -> Self {
Self {
per_ip: Arc::new(Mutex::new(HashMap::new())),
global: Arc::new(Mutex::new(Vec::new())),
config,
}
}
fn check_rate_limit(&self, key: Option<&str>) -> Result<(), u64> {
let now = Instant::now();
let window = Duration::from_secs(self.config.window_seconds);
let max_requests = self.config.max_requests as usize;
if self.config.strategy == "per_ip" {
if let Some(ip) = key {
let mut per_ip = self.per_ip.lock().unwrap_or_else(|e| e.into_inner());
let requests = per_ip.entry(ip.to_string()).or_insert_with(Vec::new);
requests.retain(|&req_time| now.duration_since(req_time) < window);
if requests.len() >= max_requests {
let oldest = requests.first().copied().unwrap_or(now);
let wait_time = window.saturating_sub(now.duration_since(oldest));
return Err(wait_time.as_secs());
}
requests.push(now);
}
} else {
let mut global = self.global.lock().unwrap_or_else(|e| e.into_inner());
global.retain(|&req_time| now.duration_since(req_time) < window);
if global.len() >= max_requests {
let oldest = global.first().copied().unwrap_or(now);
let wait_time = window.saturating_sub(now.duration_since(oldest));
return Err(wait_time.as_secs());
}
global.push(now);
}
Ok(())
}
}
#[derive(Clone)]
pub struct RateLimitLayer {
state: RateLimitState,
}
impl RateLimitLayer {
pub fn new(config: RateLimitConfig) -> Self {
Self {
state: RateLimitState::new(config),
}
}
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
state: self.state.clone(),
}
}
}
#[derive(Clone)]
pub struct RateLimitService<S> {
inner: S,
state: RateLimitState,
}
impl<S> Service<Request> for RateLimitService<S>
where
S: Service<Request> + Clone + Send + Sync + 'static,
S::Response: IntoResponse,
S::Future: Send + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
let path = req.uri().path();
if path == "/health" || path.starts_with("/health/") {
let mut svc = self.inner.clone();
return Box::pin(async move {
let response = svc.call(req).await?;
Ok(response.into_response())
});
}
let ip: Option<String> = req
.headers()
.get("x-forwarded-for")
.or_else(|| req.headers().get("x-real-ip"))
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.or_else(|| {
req.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
.map(|addr| addr.ip().to_string())
});
let key = ip.as_deref();
match self.state.check_rate_limit(key) {
Ok(()) => {
let mut svc = self.inner.clone();
Box::pin(async move {
let response = svc.call(req).await?;
Ok(response.into_response())
})
}
Err(retry_after) => {
let error = RateLimitError {
error: "rate_limit_exceeded".to_string(),
message: format!(
"Rate limit exceeded. Please try again in {} seconds",
retry_after
),
retry_after,
};
Box::pin(async move { Ok(error.into_response()) })
}
}
}
}
pub fn build_rate_limit_layer(config: &RateLimitConfig) -> Option<RateLimitLayer> {
if !config.enabled {
return None;
}
Some(RateLimitLayer::new(config.clone()))
}