use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use dashmap::DashMap;
use serde_json::json;
use crate::config::ServeConfig;
/// Per-key state for one fixed rate-limit window.
struct Bucket {
    /// Monotonic instant at which the current window began.
    window_start: Instant,
    /// Number of requests admitted so far in the current window.
    count: u32,
}
/// Details of a rejected request, used to populate the `429` response headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitError {
    /// Seconds the client should wait before retrying (sent as `Retry-After`).
    pub retry_after: u64,
    /// Maximum number of requests admitted per window (sent as `X-RateLimit-Limit`).
    pub limit: u32,
    /// Unix timestamp, in seconds, at which the current window resets
    /// (sent as `X-RateLimit-Reset`).
    pub reset_at: u64,
}
/// In-memory fixed-window rate limiter, shared between the request extractors via `Arc`.
///
/// Each category keeps its own map from key to [`Bucket`]. Reads and registrations are
/// keyed by client IP and writes by bearer token (see the `check_*` methods).
pub struct RateLimitStore {
    // Per-category bucket maps.
    read_store: DashMap<String, Bucket>,
    write_store: DashMap<String, Bucket>,
    registration_store: DashMap<String, Bucket>,
    // Maximum admitted requests per window for each category.
    read_limit: u32,
    write_limit: u32,
    registration_limit: u32,
    // Window length, in seconds, shared by reads and writes.
    window_secs: u64,
}
/// Fixed window length, in seconds, applied to registration attempts.
const REGISTRATION_WINDOW_SECS: u64 = 60;

impl RateLimitStore {
    /// Creates an empty store configured from `config`, wrapped in an `Arc` for sharing.
    ///
    /// Reads and writes share `config.rate_limit_window`; registrations always use
    /// [`REGISTRATION_WINDOW_SECS`].
    pub fn new(config: &ServeConfig) -> Arc<Self> {
        Arc::new(Self {
            read_store: DashMap::new(),
            write_store: DashMap::new(),
            registration_store: DashMap::new(),
            read_limit: config.rate_limit_read,
            write_limit: config.rate_limit_write,
            window_secs: config.rate_limit_window,
            registration_limit: config.registration_limit,
        })
    }

    /// Records one read request from `ip`.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError`] when `ip` has exhausted the read limit for the current window.
    pub fn check_read(&self, ip: &str) -> Result<(), RateLimitError> {
        check_bucket(&self.read_store, ip, self.read_limit, self.window_secs)
    }

    /// Records one write request made with bearer `token`.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError`] when `token` has exhausted the write limit for the
    /// current window.
    pub fn check_write(&self, token: &str) -> Result<(), RateLimitError> {
        check_bucket(&self.write_store, token, self.write_limit, self.window_secs)
    }

    /// Records one registration attempt from `ip`, using [`REGISTRATION_WINDOW_SECS`].
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError`] when `ip` has exhausted the registration limit for the
    /// current window.
    pub fn check_registration(&self, ip: &str) -> Result<(), RateLimitError> {
        check_bucket(
            &self.registration_store,
            ip,
            self.registration_limit,
            REGISTRATION_WINDOW_SECS,
        )
    }

    /// Removes buckets whose window started at least two window lengths ago.
    ///
    /// Such buckets would be reset on their next use anyway, so dropping them only
    /// reclaims memory. Entries are not removed anywhere else in this module.
    /// NOTE(review): presumably called periodically by a background task — confirm.
    pub fn evict_expired(&self) {
        // Saturate so a huge configured window cannot overflow (panic in debug builds,
        // or wrap to a tiny cutoff that would evict live buckets in release builds).
        let cutoff_secs = self.window_secs.saturating_mul(2);
        let registration_cutoff = REGISTRATION_WINDOW_SECS * 2;
        self.read_store
            .retain(|_, bucket| bucket.window_start.elapsed().as_secs() < cutoff_secs);
        self.write_store
            .retain(|_, bucket| bucket.window_start.elapsed().as_secs() < cutoff_secs);
        self.registration_store
            .retain(|_, bucket| bucket.window_start.elapsed().as_secs() < registration_cutoff);
    }
}
/// Records one hit for `key` in `store` using a fixed window of `window_secs` seconds.
///
/// A bucket is created on first use. Once `window_secs` have elapsed since its window
/// started, the counter is cleared and a new window begins at this request. A hit is
/// admitted while the counter is below `limit`.
///
/// # Errors
///
/// Returns [`RateLimitError`] when the bucket already holds `limit` hits in the current
/// window. The error carries:
/// - `retry_after`: the seconds remaining in the window.
/// - `limit`: the configured limit.
/// - `reset_at`: the approximate Unix time at which the window ends.
fn check_bucket(
    store: &DashMap<String, Bucket>,
    key: &str,
    limit: u32,
    window_secs: u64,
) -> Result<(), RateLimitError> {
    let now = Instant::now();
    // The entry guard is held for the whole read-modify-write below.
    let mut bucket = store.entry(key.to_string()).or_insert_with(|| Bucket {
        count: 0,
        window_start: now,
    });

    let mut elapsed = now.duration_since(bucket.window_start).as_secs();
    if elapsed >= window_secs {
        // The previous window is over: start a new one at this request. `elapsed`
        // must restart too, otherwise a rejection below (e.g. with `limit == 0`)
        // would report a reset time in the past and `Retry-After: 0`.
        bucket.count = 0;
        bucket.window_start = now;
        elapsed = 0;
    }

    if bucket.count < limit {
        bucket.count += 1;
        return Ok(());
    }

    // Map the monotonic window start onto wall-clock time for the response headers.
    // Saturating arithmetic keeps very large configured windows from overflowing.
    let unix = unix_now();
    let reset_at = unix.saturating_sub(elapsed).saturating_add(window_secs);
    Err(RateLimitError {
        retry_after: reset_at.saturating_sub(unix),
        limit,
        reset_at,
    })
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Resolves the client address used as the key for read and registration limits.
///
/// Hands the request headers to `crate::helpers::extract_client_ip`, which decides the
/// final value. It also passes the peer socket address when a
/// `ConnectInfo<SocketAddr>` extension is present.
///
/// NOTE(review): `SocketAddr`'s `Display` includes the port (e.g. `1.2.3.4:5678`). Unless
/// the helper strips it, direct (non-proxied) clients would get a fresh bucket per
/// source port — confirm against `helpers::extract_client_ip`.
fn extract_client_ip(parts: &Parts) -> String {
    let peer = parts
        .extensions
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|axum::extract::ConnectInfo(addr)| addr.to_string());
    crate::helpers::extract_client_ip(&parts.headers, peer.as_deref())
}
/// Returns the credential from an `Authorization: Bearer <token>` header, if any.
///
/// The auth-scheme is matched case-insensitively, as HTTP requires. This means
/// `bearer abc` cannot slip past write limiting. Whitespace around the token is
/// trimmed so `Bearer  abc` and `Bearer abc` share one bucket. A header of exactly
/// `Bearer ` yields `Some("")`, so such requests are still limited under one shared key.
///
/// Returns `None` when:
/// - the header is missing;
/// - it is not visible ASCII;
/// - it uses a different scheme.
fn extract_bearer_from_headers(headers: &HeaderMap) -> Option<String> {
    let value = headers.get("authorization")?.to_str().ok()?.trim_start();
    let (scheme, token) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then(|| token.trim().to_string())
}
pub fn too_many_requests_response(err: &RateLimitError) -> Response {
(
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", err.retry_after.to_string()),
("X-RateLimit-Limit", err.limit.to_string()),
("X-RateLimit-Remaining", "0".to_string()),
("X-RateLimit-Reset", err.reset_at.to_string()),
("Content-Type", "application/json".to_string()),
],
Json(json!({"error": "Too many requests"})).to_string(),
)
.into_response()
}
/// Fetches the shared [`RateLimitStore`] installed as a request extension.
///
/// Returns a `500 Internal Server Error` response when it is missing, so every rate-limit
/// extractor reports a misconfigured router in the same way.
fn rate_limit_store(parts: &Parts) -> Result<Arc<RateLimitStore>, Response> {
    parts
        .extensions
        .get::<Arc<RateLimitStore>>()
        .cloned()
        .ok_or_else(|| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Rate limit store missing",
            )
                .into_response()
        })
}

/// Extractor that applies the per-client read limit.
///
/// Rejects with `429 Too Many Requests` once the client IP exhausts its read quota.
pub struct ReadRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for ReadRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = rate_limit_store(parts)?;
        let ip = extract_client_ip(parts);
        store
            .check_read(&ip)
            .map_err(|e| too_many_requests_response(&e))?;
        Ok(ReadRateLimit)
    }
}

/// Extractor that applies the per-client registration limit.
///
/// Rejects with `429 Too Many Requests` once the client IP exhausts its registration quota.
pub struct RegistrationRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for RegistrationRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = rate_limit_store(parts)?;
        let ip = extract_client_ip(parts);
        store
            .check_registration(&ip)
            .map_err(|e| too_many_requests_response(&e))?;
        Ok(RegistrationRateLimit)
    }
}

/// Extractor that applies the per-token write limit.
///
/// Requests carrying a bearer token are limited per token. Requests without one pass
/// through unlimited here. NOTE(review): presumably authentication rejects those
/// elsewhere — confirm.
pub struct WriteRateLimit;

#[async_trait]
impl<S> FromRequestParts<S> for WriteRateLimit
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = rate_limit_store(parts)?;
        if let Some(token) = extract_bearer_from_headers(&parts.headers) {
            store
                .check_write(&token)
                .map_err(|e| too_many_requests_response(&e))?;
        }
        Ok(WriteRateLimit)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store whose read and write categories share `limit` and `window_secs`;
    /// registrations are capped at 5 per [`REGISTRATION_WINDOW_SECS`].
    fn make_store(limit: u32, window_secs: u64) -> Arc<RateLimitStore> {
        Arc::new(RateLimitStore {
            read_store: DashMap::new(),
            write_store: DashMap::new(),
            registration_store: DashMap::new(),
            read_limit: limit,
            write_limit: limit,
            window_secs,
            registration_limit: 5,
        })
    }

    #[test]
    fn check_bucket_within_limit() {
        let store = make_store(5, 60);
        assert!(store.check_read("192.0.2.1").is_ok());
    }

    #[test]
    fn check_bucket_at_limit() {
        let store = make_store(3, 60);
        assert!(store.check_read("10.0.0.1").is_ok());
        assert!(store.check_read("10.0.0.1").is_ok());
        assert!(store.check_read("10.0.0.1").is_ok());
    }

    #[test]
    fn check_bucket_over_limit() {
        let store = make_store(2, 60);
        assert!(store.check_read("10.0.0.2").is_ok());
        assert!(store.check_read("10.0.0.2").is_ok());
        let result = store.check_read("10.0.0.2");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.limit, 2);
        assert!((1..=60).contains(&err.retry_after));
        // The reset lies in the future while the window is still open.
        assert!(err.reset_at > unix_now());
    }

    #[test]
    fn window_reset() {
        // A zero-length window expires immediately, so every hit starts a fresh window.
        let store = make_store(1, 0);
        assert!(store.check_read("10.0.0.3").is_ok());
        assert!(store.check_read("10.0.0.3").is_ok());
    }

    #[test]
    fn separate_buckets() {
        let store = make_store(1, 60);
        assert!(store.check_read("10.0.0.4").is_ok());
        assert!(store.check_read("10.0.0.4").is_err());
        assert!(store.check_read("10.0.0.5").is_ok());
    }

    #[test]
    fn write_limit_is_per_token() {
        let store = make_store(1, 60);
        assert!(store.check_write("token-a").is_ok());
        assert!(store.check_write("token-a").is_err());
        assert!(store.check_write("token-b").is_ok());
    }

    #[test]
    fn registration_uses_its_own_limit_and_window() {
        let store = make_store(100, 60);
        for _ in 0..5 {
            assert!(store.check_registration("10.0.0.6").is_ok());
        }
        let err = store.check_registration("10.0.0.6").unwrap_err();
        assert_eq!(err.limit, 5);
        assert!((1..=REGISTRATION_WINDOW_SECS).contains(&err.retry_after));
        // Registration attempts do not consume read quota.
        assert!(store.check_read("10.0.0.6").is_ok());
    }

    #[test]
    fn evict_expired() {
        // With a zero-length window the eviction cutoff is zero, so every bucket goes.
        let store = make_store(10, 0);
        let _ = store.check_read("evict-a");
        let _ = store.check_read("evict-b");
        assert_eq!(store.read_store.len(), 2);
        store.evict_expired();
        assert_eq!(store.read_store.len(), 0);
    }

    #[test]
    fn evict_keeps_live_buckets() {
        let store = make_store(10, 60);
        let _ = store.check_read("keep-a");
        let _ = store.check_write("keep-token");
        store.evict_expired();
        assert_eq!(store.read_store.len(), 1);
        assert_eq!(store.write_store.len(), 1);
    }
}