use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use dashmap::DashMap;
use crate::api::{AppState, ErrorDetail, ErrorResponse};
use crate::auth::validate_bearer_header;
/// Number of recorded failures at which an IP address is treated as locked out.
const MAX_AUTH_FAILURES: u32 = 5;

/// Length of a lockout, in seconds (five minutes).
///
/// The same value bounds the failure-counting window and the age at which
/// idle entries are pruned from the failure map.
const LOCKOUT_DURATION_SECS: u64 = 5 * 60;

/// Minimum number of seconds between two pruning passes over the failure map.
const PRUNE_INTERVAL_SECS: u64 = 60;
/// Failure bookkeeping kept for a single client IP address.
struct AuthFailureEntry {
    /// Failures counted in the current window.
    count: u32,
    /// When the current counting window started.
    first_failure: Instant,
    /// When the most recent failure was recorded; lockouts are timed from here.
    last_failure: Instant,
}
/// Tracks failed authentication attempts per client IP address and decides
/// when an address should be locked out.
pub struct AuthRateLimiter {
    /// Failure records keyed by client IP address.
    failures: DashMap<IpAddr, AuthFailureEntry>,
    /// Time of the last pruning pass over `failures`.
    last_prune: std::sync::Mutex<Instant>,
}
impl AuthRateLimiter {
    /// Creates a limiter with no recorded failures; the prune timer starts now.
    pub fn new() -> Self {
        Self {
            failures: DashMap::new(),
            last_prune: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Returns `true` while `ip` is locked out.
    ///
    /// An address is locked out when it has at least `MAX_AUTH_FAILURES`
    /// recorded failures and fewer than `LOCKOUT_DURATION_SECS` whole seconds
    /// have passed since the most recent one.
    pub fn is_locked_out(&self, ip: &IpAddr) -> bool {
        match self.failures.get(ip) {
            Some(entry) => {
                entry.count >= MAX_AUTH_FAILURES
                    && entry.last_failure.elapsed().as_secs() < LOCKOUT_DURATION_SECS
            }
            None => false,
        }
    }

    /// Records one failed authentication attempt for `ip`.
    ///
    /// The counter restarts at 1 once the current window (measured from the
    /// first failure) is `LOCKOUT_DURATION_SECS` old or older; otherwise it is
    /// incremented. Afterwards a pruning pass may run (see `maybe_prune`).
    pub fn record_failure(&self, ip: IpAddr) {
        let now = Instant::now();
        self.failures
            .entry(ip)
            .and_modify(|entry| {
                // `>=` mirrors the `<` used by `is_locked_out` and `maybe_prune`:
                // all three agree that an entry expires at exactly
                // LOCKOUT_DURATION_SECS. With `>` a single failure in the first
                // second after a lockout ended would bump the stale counter
                // and immediately lock the client out again.
                if entry.first_failure.elapsed().as_secs() >= LOCKOUT_DURATION_SECS {
                    entry.count = 1;
                    entry.first_failure = now;
                } else {
                    entry.count += 1;
                }
                entry.last_failure = now;
            })
            .or_insert(AuthFailureEntry {
                count: 1,
                first_failure: now,
                last_failure: now,
            });
        // The map guard returned above is a temporary that is dropped at the
        // end of the statement, so pruning does not run while it is held.
        self.maybe_prune();
    }

    /// Forgets every recorded failure for `ip` after a successful authentication.
    pub fn record_success(&self, ip: &IpAddr) {
        self.failures.remove(ip);
    }

    /// Drops expired entries, at most once every `PRUNE_INTERVAL_SECS` seconds.
    ///
    /// An entry is expired when its most recent failure is
    /// `LOCKOUT_DURATION_SECS` old or older.
    fn maybe_prune(&self) {
        {
            // The mutex only guards a timestamp, so a poisoned lock carries no
            // broken invariant; recover the guard instead of panicking.
            let mut last = self
                .last_prune
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if last.elapsed().as_secs() < PRUNE_INTERVAL_SECS {
                return;
            }
            *last = Instant::now();
            // Guard dropped here, before the comparatively slow map sweep.
        }
        self.failures
            .retain(|_, entry| entry.last_failure.elapsed().as_secs() < LOCKOUT_DURATION_SECS);
    }
}
impl Default for AuthRateLimiter {
fn default() -> Self {
Self::new()
}
}
/// Axum middleware that enforces bearer-token authentication and throttles
/// clients that keep failing it.
///
/// Flow:
/// 1. If `auth_config.enabled` is false the request is forwarded unchanged.
/// 2. The client IP is taken from the first `X-Forwarded-For` entry when it
///    parses as an IP address, otherwise from the connection's `ConnectInfo`.
/// 3. A locked-out IP gets a `429` response without the token being checked.
/// 4. Otherwise the `Authorization` header is validated; success clears the
///    IP's failure record and runs the next handler, failure records an
///    attempt and yields a `401` response.
///
/// When no client IP can be determined, no rate limiting is applied.
///
/// NOTE(review): `X-Forwarded-For` is supplied by the client unless a trusted
/// proxy overwrites it. If this service can be reached directly, a caller can
/// rotate the header to dodge the lockout or spoof another address to lock it
/// out. Confirm the deployment always sits behind such a proxy.
pub async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    request: axum::extract::Request,
    next: Next,
) -> Response {
    if !state.auth_config.enabled {
        return next.run(request).await;
    }

    // Prefer the forwarded address; fall back to the socket peer address.
    let forwarded_ip = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|first| first.trim().parse::<IpAddr>().ok());
    let client_ip: Option<IpAddr> = match forwarded_ip {
        Some(ip) => Some(ip),
        None => request
            .extensions()
            .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
            .map(|peer| peer.0.ip()),
    };

    let locked_out = match client_ip {
        Some(ip) => state.auth_rate_limiter.is_locked_out(&ip),
        None => false,
    };
    if locked_out {
        return rate_limited_response();
    }

    let bearer = request
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok());
    let verdict = validate_bearer_header(bearer, &state.auth_config.tokens);

    if let Err(err) = verdict {
        if let Some(ip) = client_ip {
            state.auth_rate_limiter.record_failure(ip);
        }
        return unauthorized_response(err.message());
    }

    if let Some(ip) = client_ip {
        state.auth_rate_limiter.record_success(&ip);
    }
    next.run(request).await
}
fn unauthorized_response(message: &str) -> Response {
let body = ErrorResponse {
ok: false,
error: ErrorDetail {
code: "UNAUTHORIZED",
message: message.to_string(),
details: None,
},
};
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
}
fn rate_limited_response() -> Response {
let body = ErrorResponse {
ok: false,
error: ErrorDetail {
code: "RATE_LIMITED",
message: "Too many failed authentication attempts. Try again later.".to_string(),
details: None,
},
};
(StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response()
}