use crate::utils::trad::t;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use tokio::time::interval;
/// Per-key counter table: key -> (requests counted in the current window, window start).
type Store = Arc<Mutex<HashMap<String, (u32, Instant)>>>;

/// In-memory, fixed-window rate limiter keyed by an arbitrary string
/// (the middleware in this file keys it by client address).
///
/// Cloning is cheap and clones share the same underlying counter table, because
/// `store` is an `Arc`.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Shared counter table; every clone points at the same map.
    store: Store,
    /// Maximum number of requests a single key may make within one window.
    pub max_requests: u32,
    /// Length of a key's counting window; the counter resets once it elapses.
    pub window: Duration,
}
impl RateLimiter {
pub fn new() -> Self {
Self {
store: Arc::new(Mutex::new(HashMap::new())),
max_requests: 60,
window: Duration::from_secs(60),
}
}
#[must_use]
pub fn max_requests(mut self, max: u32) -> Self {
self.max_requests = max;
self
}
#[must_use]
pub fn retry_after(mut self, secs: u64) -> Self {
self.window = Duration::from_secs(secs);
self
}
pub fn spawn_cleanup(&self, period: tokio::time::Duration) {
let store = self.store.clone();
let window = self.window;
tokio::spawn(async move {
let mut ticker = interval(period);
loop {
ticker.tick().await;
let mut guard = match store.lock() {
Ok(g) => g,
Err(p) => p.into_inner(),
};
let now = Instant::now();
guard.retain(|_, (_, start)| now.duration_since(*start) < window);
}
});
}
#[must_use]
pub fn retry_after_secs(&self, key: &str) -> u64 {
let store = match self.store.lock() {
Ok(s) => s,
Err(p) => p.into_inner(),
};
match store.get(key) {
Some((_, start)) => {
let interval = Instant::now().duration_since(*start);
self.window.saturating_sub(interval).as_secs()
}
None => 0,
}
}
#[must_use]
pub fn is_allowed(&self, key: &str) -> bool {
let mut store = match self.store.lock() {
Ok(s) => s,
Err(p) => p.into_inner(),
};
let now = Instant::now();
let entry = store.entry(key.to_string()).or_insert((0, now));
if now.duration_since(entry.1) >= self.window {
*entry = (1, now);
true
} else if entry.0 < self.max_requests {
entry.0 = entry.0.saturating_add(1);
true
} else {
false
}
}
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
fn extract_ip(req: &Request<Body>) -> String {
req.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
.map(|s| s.trim().to_string())
.or_else(|| {
req.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_string())
})
.unwrap_or_else(|| "unknown".to_string())
}
/// Axum middleware that applies the shared [`RateLimiter`] per client key.
///
/// Requests within the limit are passed on to `next`. Otherwise the middleware
/// answers `429 Too Many Requests` directly. The response carries a `Retry-After`
/// header taken from [`RateLimiter::retry_after_secs`] and the text returned by
/// `t("html.429_text")` as its body.
pub async fn rate_limit_middleware(
    State(limiter): State<Arc<RateLimiter>>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let client = extract_ip(&req);

    if !limiter.is_allowed(&client) {
        let wait_secs = limiter.retry_after_secs(&client).to_string();
        let headers = [(header::RETRY_AFTER, wait_secs)];
        let body = t("html.429_text").into_owned();
        return (StatusCode::TOO_MANY_REQUESTS, headers, body).into_response();
    }

    next.run(req).await
}