use std::{
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use subtle::ConstantTimeEq as _;
/// Length, in seconds, of the fixed window in which failed auth attempts are counted.
const ADMIN_AUTH_WINDOW_SECS: u64 = 60;

/// Failure bookkeeping for one client key.
#[derive(Clone)]
struct FailureRecord {
    /// Unix time (whole seconds) at which the current counting window opened.
    window_start: u64,
    /// Failed attempts observed since `window_start`.
    count: u32,
}
/// Per-client, fixed-window counter of failed authentication attempts.
///
/// All state lives behind `Arc`s, so clones share the same records.
#[derive(Clone)]
pub(crate) struct FailureLimiter {
    /// Failure records keyed by the caller-supplied client key.
    records: Arc<DashMap<String, FailureRecord>>,
    /// Failures within one window at which a client counts as blocked.
    max_failures: u32,
    /// Unix time (seconds) of the most recent sweep of expired records.
    last_prune: Arc<std::sync::atomic::AtomicU64>,
}

impl FailureLimiter {
    /// Creates an empty limiter that blocks a client once it accumulates
    /// `max_failures` failures within `ADMIN_AUTH_WINDOW_SECS` seconds.
    pub(crate) fn new(max_failures: u32) -> Self {
        Self {
            records: Arc::new(DashMap::new()),
            max_failures,
            last_prune: Arc::new(std::sync::atomic::AtomicU64::new(Self::now_secs())),
        }
    }

    /// Current Unix time in whole seconds, or `0` if the clock reads before the epoch.
    fn now_secs() -> u64 {
        SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0)
    }

    /// Whether a window opened at `window_start` is still open at `now`.
    fn in_window(window_start: u64, now: u64) -> bool {
        now < window_start.saturating_add(ADMIN_AUTH_WINDOW_SECS)
    }

    /// Removes records whose window has closed, at most once per window.
    ///
    /// Without this, every key that fails once and never comes back would stay
    /// in the map forever. Gating the O(n) sweep to once per window keeps its
    /// cost amortized even when many distinct keys are failing.
    ///
    /// Must not be called while holding a guard into `records`: `retain` locks
    /// every shard and would deadlock against an outstanding guard.
    fn prune_expired(&self, now: u64) {
        use std::sync::atomic::Ordering;

        let last = self.last_prune.load(Ordering::Relaxed);
        if Self::in_window(last, now) {
            return;
        }
        // Only the caller that wins the exchange performs the sweep. Relaxed is
        // enough: the timestamp only schedules housekeeping and publishes no data.
        if self
            .last_prune
            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            self.records.retain(|_, rec| Self::in_window(rec.window_start, now));
        }
    }

    /// Records one failed attempt for `ip` and reports whether it is now blocked.
    ///
    /// Opens a fresh window when `ip` has no open window. Returns `true` once the
    /// count within the current window reaches `max_failures`.
    pub(crate) fn record_failure(&self, ip: &str) -> bool {
        let now = Self::now_secs();
        // Sweep before taking the entry guard below (see `prune_expired`).
        self.prune_expired(now);
        let mut entry = self
            .records
            .entry(ip.to_string())
            .or_insert_with(|| FailureRecord { count: 0, window_start: now });
        if !Self::in_window(entry.window_start, now) {
            // Expired window: restart it, then fall through to the shared
            // increment/compare so this behaves exactly like a first failure
            // (relevant when `max_failures <= 1`).
            entry.count = 0;
            entry.window_start = now;
        }
        entry.count = entry.count.saturating_add(1);
        entry.count >= self.max_failures
    }

    /// Whether `ip` has reached `max_failures` within a window that is still open.
    pub(crate) fn is_blocked(&self, ip: &str) -> bool {
        let now = Self::now_secs();
        if let Some(rec) = self.records.get(ip) {
            if Self::in_window(rec.window_start, now) {
                return rec.count >= self.max_failures;
            }
        }
        false
    }

    /// Forgets all recorded failures for `ip`.
    pub(crate) fn record_success(&self, ip: &str) {
        self.records.remove(ip);
    }

    /// Raw failure count stored for `ip` (0 if none), ignoring window expiry.
    #[cfg(test)]
    pub(crate) fn failure_count(&self, ip: &str) -> u32 {
        self.records.get(ip).map_or(0, |e| e.count)
    }
}
/// Shared state for the bearer-auth middleware: the expected token and a
/// limiter on failed attempts per client.
#[derive(Clone)]
pub struct BearerAuthState {
    /// Tracks failed token comparisons per client key.
    failure_limiter: FailureLimiter,
    /// Token that requests must present as `Authorization: Bearer <token>`.
    pub token: Arc<String>,
}
impl BearerAuthState {
    /// Failure limit applied by [`BearerAuthState::new`].
    const DEFAULT_MAX_FAILURES: u32 = 10;

    /// Builds state that accepts `token` and uses the default failure limit (10).
    #[must_use]
    pub fn new(token: String) -> Self {
        Self::with_max_failures(token, Self::DEFAULT_MAX_FAILURES)
    }

    /// Builds state that accepts `token` and tracks failed attempts against
    /// `max_failures`.
    #[must_use]
    pub fn with_max_failures(token: String, max_failures: u32) -> Self {
        let token = Arc::new(token);
        let failure_limiter = FailureLimiter::new(max_failures);
        Self { token, failure_limiter }
    }
}
pub async fn bearer_auth_middleware(
State(auth_state): State<BearerAuthState>,
request: Request<Body>,
next: Next,
) -> Response {
use std::net::SocketAddr;
use axum::extract::ConnectInfo;
let peer_key = request
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip().to_string())
.or_else(|| {
request
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.map(|v| v.split(',').next().unwrap_or(v).trim().to_string())
})
.unwrap_or_else(|| "unknown".to_string());
if auth_state.failure_limiter.is_blocked(&peer_key) {
return (StatusCode::TOO_MANY_REQUESTS, "Too many failed auth attempts").into_response();
}
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
None => {
return (
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, "Bearer")],
"Missing Authorization header",
)
.into_response();
},
Some(header_value) => {
if !header_value.starts_with("Bearer ") {
return (
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, "Bearer")],
"Invalid Authorization header format. Expected: Bearer <token>",
)
.into_response();
}
let token = &header_value[7..];
if !constant_time_compare(token, &auth_state.token) {
if auth_state.failure_limiter.record_failure(&peer_key) {
return (StatusCode::TOO_MANY_REQUESTS, "Too many failed auth attempts")
.into_response();
}
return (StatusCode::FORBIDDEN, "Invalid token").into_response();
}
auth_state.failure_limiter.record_success(&peer_key);
},
}
next.run(request).await
}
/// Returns the token portion of an `Authorization` header value of the form
/// `Bearer <token>`, or `None` for any other scheme.
///
/// The scheme match is exact and case-sensitive, and exactly one space is
/// consumed. Anything after it, including further spaces, is returned verbatim.
#[must_use]
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    // "Bearer" contains no space, so the first space splits scheme from token.
    match header_value.split_once(' ') {
        Some(("Bearer", token)) => Some(token),
        _ => None,
    }
}
pub(crate) fn constant_time_compare(a: &str, b: &str) -> bool {
a.as_bytes().ct_eq(b.as_bytes()).into()
}