use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use dashmap::DashMap;
use serde_json::json;
use crate::config::ServeConfig;
/// Per-key counter for one fixed rate-limit window.
struct Bucket {
/// Number of requests admitted since `window_start`.
count: u32,
/// Monotonic timestamp at which the current window began.
window_start: Instant,
}
/// Details of a rejected request, used to build the `429` response headers.
#[derive(Debug, Clone)]
pub struct RateLimitError {
/// Seconds the client should wait before retrying.
pub retry_after: u64,
/// Maximum number of requests allowed per window for the bucket that rejected.
pub limit: u32,
/// Unix timestamp (seconds) at which the current window ends.
pub reset_at: u64,
}
/// Shared in-memory state for the read, write and registration rate limiters.
///
/// NOTE(review): nothing in this section of the file ever removes entries from
/// the maps, so they appear to grow with the number of distinct keys seen —
/// confirm whether eviction happens elsewhere.
pub struct RateLimitStore {
/// Read buckets, keyed by the string passed to `check_read` (a client IP).
read_store: DashMap<String, Bucket>,
/// Write buckets, keyed by the string passed to `check_write` (a bearer token).
write_store: DashMap<String, Bucket>,
/// Registration buckets, keyed by the string passed to `check_registration` (a client IP).
registration_store: DashMap<String, Bucket>,
/// Requests allowed per window for reads.
read_limit: u32,
/// Requests allowed per window for writes.
write_limit: u32,
/// Window length in seconds, shared by the read and write limiters.
window_secs: u64,
}
/// Registration attempts allowed per key within one registration window.
const REGISTRATION_LIMIT: u32 = 5;
/// Length of the registration window in seconds.
const REGISTRATION_WINDOW_SECS: u64 = 60;
impl RateLimitStore {
    /// Builds an empty store whose read/write limits and window length are
    /// taken from `config`, wrapped in an `Arc` for sharing.
    pub fn new(config: &ServeConfig) -> Arc<Self> {
        let store = Self {
            read_limit: config.rate_limit_read,
            write_limit: config.rate_limit_write,
            window_secs: config.rate_limit_window,
            read_store: DashMap::new(),
            write_store: DashMap::new(),
            registration_store: DashMap::new(),
        };
        Arc::new(store)
    }

    /// Counts one read against `ip`.
    ///
    /// # Errors
    /// Returns a [`RateLimitError`] when `ip` has used up its read budget for
    /// the current window.
    pub fn check_read(&self, ip: &str) -> Result<(), RateLimitError> {
        let Self {
            read_store,
            read_limit,
            window_secs,
            ..
        } = self;
        check_bucket(read_store, ip, *read_limit, *window_secs)
    }

    /// Counts one write against `token`.
    ///
    /// # Errors
    /// Returns a [`RateLimitError`] when `token` has used up its write budget
    /// for the current window.
    pub fn check_write(&self, token: &str) -> Result<(), RateLimitError> {
        let Self {
            write_store,
            write_limit,
            window_secs,
            ..
        } = self;
        check_bucket(write_store, token, *write_limit, *window_secs)
    }

    /// Counts one registration attempt against `ip`, using the fixed
    /// `REGISTRATION_LIMIT` / `REGISTRATION_WINDOW_SECS` budget rather than the
    /// configured one.
    ///
    /// # Errors
    /// Returns a [`RateLimitError`] when `ip` has used up its registration
    /// budget for the current window.
    pub fn check_registration(&self, ip: &str) -> Result<(), RateLimitError> {
        let buckets = &self.registration_store;
        check_bucket(buckets, ip, REGISTRATION_LIMIT, REGISTRATION_WINDOW_SECS)
    }
}
/// Applies a fixed-window rate limit to `key` in `store`.
///
/// The bucket for `key` is created on first use. Once `window_secs` whole
/// seconds have passed since the bucket's window began, the counter is reset
/// and a new window starts at the current instant.
///
/// # Errors
/// Returns a [`RateLimitError`] when the bucket has already admitted `limit`
/// requests in the current window. `retry_after` is the number of seconds left
/// in the window and `reset_at` is the matching Unix timestamp.
fn check_bucket(
    store: &DashMap<String, Bucket>,
    key: &str,
    limit: u32,
    window_secs: u64,
) -> Result<(), RateLimitError> {
    let now = Instant::now();
    let mut entry = store.entry(key.to_string()).or_insert_with(|| Bucket {
        count: 0,
        window_start: now,
    });

    let mut elapsed = now.duration_since(entry.window_start).as_secs();
    if elapsed >= window_secs {
        entry.count = 0;
        entry.window_start = now;
        // The window just restarted, so nothing has elapsed in it yet. Without
        // this, a rejection right after a rollover (possible when `limit` is 0)
        // would report `Retry-After: 0` and a reset time in the past.
        elapsed = 0;
    }

    if entry.count < limit {
        entry.count += 1;
        return Ok(());
    }

    // Derive both values from the time left in the window so they always agree
    // and cannot overflow.
    let retry_after = window_secs.saturating_sub(elapsed);
    let reset_at = unix_now().saturating_add(retry_after);
    Err(RateLimitError {
        retry_after,
        limit,
        reset_at,
    })
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Picks the string used to identify the client for IP-keyed rate limits.
///
/// Order of preference:
/// 1. the first comma-separated entry of the `x-forwarded-for` header, trimmed,
///    if it is non-empty;
/// 2. the peer address from the `ConnectInfo<SocketAddr>` request extension;
/// 3. the literal `"unknown"`.
///
/// NOTE(review): the forwarded value is used as-is, without validation. Unless
/// a trusted proxy overwrites this header, a client can presumably choose its
/// own rate-limit key — confirm the deployment guarantees this.
fn extract_client_ip(parts: &Parts) -> String {
    let forwarded = parts
        .headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty());
    if let Some(ip) = forwarded {
        return ip.to_string();
    }

    match parts
        .extensions
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
    {
        Some(peer) => peer.0.ip().to_string(),
        None => "unknown".to_string(),
    }
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, because HTTP authentication
/// schemes are case-insensitive (RFC 9110 §11.1). This means `bearer <token>`
/// and `BEARER <token>` are keyed by the same token as `Bearer <token>`.
///
/// Returns `None` when the header is absent, is not valid visible ASCII, or
/// does not start with the bearer scheme followed by a single space. The token
/// is returned exactly as sent and may be empty.
fn extract_bearer_from_headers(headers: &HeaderMap) -> Option<String> {
    const SCHEME: &str = "Bearer ";

    let auth = headers.get("authorization")?.to_str().ok()?;
    // `get` returns `None` for a too-short value instead of panicking.
    let scheme = auth.get(..SCHEME.len())?;
    let token = auth.get(SCHEME.len()..)?;
    if scheme.eq_ignore_ascii_case(SCHEME) {
        Some(token.to_string())
    } else {
        None
    }
}
pub fn too_many_requests_response(err: &RateLimitError) -> Response {
(
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", err.retry_after.to_string()),
("X-RateLimit-Limit", err.limit.to_string()),
("X-RateLimit-Remaining", "0".to_string()),
("X-RateLimit-Reset", err.reset_at.to_string()),
("Content-Type", "application/json".to_string()),
],
Json(json!({"error": "Too many requests"})).to_string(),
)
.into_response()
}
/// Extractor that enforces the per-client read rate limit.
///
/// It carries no data; its presence in a handler signature means the check ran
/// and passed.
pub struct ReadRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for ReadRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    /// Looks up the shared [`RateLimitStore`] in the request extensions and
    /// counts one read against the client IP.
    ///
    /// Rejects with `500` if the store extension is missing, or with the `429`
    /// response when the client is over its read budget.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = match parts.extensions.get::<Arc<RateLimitStore>>() {
            Some(shared) => Arc::clone(shared),
            None => {
                return Err(
                    (StatusCode::INTERNAL_SERVER_ERROR, "Rate limit store missing")
                        .into_response(),
                );
            }
        };

        let client_ip = extract_client_ip(parts);
        match store.check_read(&client_ip) {
            Ok(()) => Ok(ReadRateLimit),
            Err(limited) => Err(too_many_requests_response(&limited)),
        }
    }
}
/// Extractor that enforces the per-client registration rate limit.
///
/// It carries no data; its presence in a handler signature means the check ran
/// and passed.
pub struct RegistrationRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for RegistrationRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    /// Looks up the shared [`RateLimitStore`] in the request extensions and
    /// counts one registration attempt against the client IP.
    ///
    /// Rejects with `500` if the store extension is missing, or with the `429`
    /// response when the client is over its registration budget.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = match parts.extensions.get::<Arc<RateLimitStore>>() {
            Some(shared) => Arc::clone(shared),
            None => {
                return Err(
                    (StatusCode::INTERNAL_SERVER_ERROR, "Rate limit store missing")
                        .into_response(),
                );
            }
        };

        let client_ip = extract_client_ip(parts);
        match store.check_registration(&client_ip) {
            Ok(()) => Ok(RegistrationRateLimit),
            Err(limited) => Err(too_many_requests_response(&limited)),
        }
    }
}
/// Extractor that enforces the per-token write rate limit.
///
/// It carries no data; its presence in a handler signature means the check ran
/// and passed, or that there was no bearer token to check.
pub struct WriteRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for WriteRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    /// Looks up the shared [`RateLimitStore`] in the request extensions and
    /// counts one write against the request's bearer token.
    ///
    /// Requests from which no bearer token can be extracted are let through
    /// without being counted. Rejects with `500` if the store extension is
    /// missing, or with the `429` response when the token is over its write
    /// budget.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = match parts.extensions.get::<Arc<RateLimitStore>>() {
            Some(shared) => Arc::clone(shared),
            None => {
                return Err(
                    (StatusCode::INTERNAL_SERVER_ERROR, "Rate limit store missing")
                        .into_response(),
                );
            }
        };

        match extract_bearer_from_headers(&parts.headers) {
            None => Ok(WriteRateLimit),
            Some(token) => match store.check_write(&token) {
                Ok(()) => Ok(WriteRateLimit),
                Err(limited) => Err(too_many_requests_response(&limited)),
            },
        }
    }
}