use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use axum::{
Json,
extract::{ConnectInfo, Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use super::ServerState;
use super::handlers::ErrorResponse;
/// Fixed-window request counter for a single source IP.
struct WindowState {
    /// Instant at which the current window began.
    window_start: Instant,
    /// Requests admitted so far in the current window.
    count: u32,
}

/// Per-source-IP, fixed-window rate limiter.
///
/// Each IP may have at most `limit` requests admitted per 60-second window. An
/// IP's window starts at its first request and restarts on the first request
/// made at least 60 seconds after the window began. Entries whose window began
/// two or more windows ago are swept out so the map does not grow without bound.
///
/// NOTE(review): keying on the raw peer address means all clients behind one
/// proxy/NAT share a budget, and an IPv6 client can rotate addresses within its
/// prefix to get fresh budgets — confirm this matches the deployment.
#[derive(Default)]
pub struct AuthRateLimiter {
    /// Window state per source IP.
    windows: Mutex<HashMap<IpAddr, WindowState>>,
    /// When `windows` was last swept for stale entries; `None` before the first
    /// call. Only ever locked while the `windows` lock is held.
    last_prune: Mutex<Option<Instant>>,
}

impl AuthRateLimiter {
    /// Length of one rate-limiting window.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty limiter wrapped in an `Arc` so it can be shared.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Records a request from `ip` at the current time; see [`Self::check_at`].
    fn check(&self, ip: IpAddr, limit: u32) -> bool {
        self.check_at(ip, limit, Instant::now())
    }

    /// Records a request from `ip` at time `now` and reports whether it is allowed.
    ///
    /// Returns `true` (and counts the request) when `ip` has had fewer than
    /// `limit` requests admitted in its current window, `false` otherwise.
    /// Rejected requests are not counted. A `limit` of 0 rejects every request.
    /// Taking `now` as a parameter lets tests drive the clock.
    fn check_at(&self, ip: IpAddr, limit: u32, now: Instant) -> bool {
        // The map only holds plain counters, so a panic while the lock was held
        // cannot leave it in a harmful state. Recover rather than turning one
        // panic into a panic on every later request.
        let mut windows = self
            .windows
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        self.prune_stale(&mut windows, now);

        let entry = windows.entry(ip).or_insert(WindowState {
            window_start: now,
            count: 0,
        });
        if now.duration_since(entry.window_start) >= Self::WINDOW {
            // The previous window has expired: start a fresh one at `now`.
            entry.window_start = now;
            entry.count = 0;
        }
        if entry.count >= limit {
            return false;
        }
        entry.count += 1;
        true
    }

    /// Drops entries whose window began at least two windows before `now`.
    ///
    /// Sweeps at most once per window. Sweeping on every call would scan every
    /// tracked IP under the lock for each request. Skipping a sweep never changes
    /// an admission decision, because `check_at` resets any expired entry it
    /// touches.
    fn prune_stale(&self, windows: &mut HashMap<IpAddr, WindowState>, now: Instant) {
        let mut last_prune = self
            .last_prune
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if matches!(*last_prune, Some(at) if now.duration_since(at) < Self::WINDOW) {
            return;
        }
        let retention = Self::WINDOW * 2;
        windows.retain(|_, state| now.duration_since(state.window_start) < retention);
        *last_prune = Some(now);
    }
}
/// Axum middleware that enforces the per-minute auth rate limit per source IP.
///
/// When `auth_rate_limit_per_minute` is unset, the request is forwarded
/// untouched. Otherwise the peer's IP is charged against the shared
/// [`AuthRateLimiter`]. An over-limit request is logged and answered with
/// `429 Too Many Requests` and a JSON [`ErrorResponse`] body, and never reaches
/// the inner service.
pub async fn rate_limit(
    State(state): State<Arc<ServerState>>,
    ConnectInfo(peer): ConnectInfo<SocketAddr>,
    axum::extract::Extension(limiter): axum::extract::Extension<Arc<AuthRateLimiter>>,
    request: Request,
    next: Next,
) -> Response {
    // Limiting is opt-in: with no configured limit, everything passes through.
    let limit = match state.auth_rate_limit_per_minute {
        Some(limit) => limit,
        None => return next.run(request).await,
    };

    let source_ip = peer.ip();
    if limiter.check(source_ip, limit) {
        return next.run(request).await;
    }

    tracing::warn!(
        "auth: rate limit exceeded for source {} (limit={}/min)",
        source_ip,
        limit
    );
    let body = Json(ErrorResponse {
        error: "rate limit exceeded".to_string(),
    });
    (StatusCode::TOO_MANY_REQUESTS, body).into_response()
}
/// Runs the rest of the middleware stack inside an `auth_request` tracing span
/// whose `body` field is the fixed placeholder `"<redacted>"`.
///
/// The request itself is forwarded unchanged; this layer neither reads nor
/// removes the body. NOTE(review): whether bodies actually stay out of logs
/// depends on the other layers and subscribers — confirm against how this
/// middleware is wired up.
pub async fn strip_body_for_logs(request: Request, next: Next) -> Response {
    let span = tracing::info_span!("auth_request", body = "<redacted>");
    // Do not hold a `Span::enter()` guard across `.await`. It would keep the span
    // entered while this task is suspended, so events from unrelated tasks on
    // the same thread would be recorded inside it. `Instrument` enters the span
    // only while the inner future is being polled.
    tracing::Instrument::instrument(next.run(request), span).await
}
#[cfg(test)]
mod tests;