use axum::extract::State;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::Response;
use std::net::IpAddr;
use std::sync::Arc;
use crate::api::utils::error_response;
use crate::auth::v4::PresignedParams;
use crate::AppState;
/// Credential configuration state, derived from the configured access/secret key pair.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Neither the access key nor the secret key is set.
    Passthrough,
    /// Both the access key and the secret key are set.
    Authenticated,
    /// Exactly one of the two keys is set.
    Misconfigured,
}

/// Classifies a credential pair by which of its halves are non-empty.
///
/// * both empty → [`AuthMode::Passthrough`]
/// * both non-empty → [`AuthMode::Authenticated`]
/// * exactly one non-empty → [`AuthMode::Misconfigured`]
///
/// Only emptiness is checked; a whitespace-only key counts as set.
pub fn detect_auth_mode(access_key: &str, secret_key: &str) -> AuthMode {
    let has_access = !access_key.is_empty();
    let has_secret = !secret_key.is_empty();
    if has_access && has_secret {
        AuthMode::Authenticated
    } else if has_access || has_secret {
        AuthMode::Misconfigured
    } else {
        AuthMode::Passthrough
    }
}
/// Request paths that bypass SigV4 authentication (exact match on the URI path, no prefixes).
const EXEMPT_PATHS: &[&str] = &["/health", "/ready", "/metrics", "/openapi.json"];
/// Failed authentications tolerated per client IP within one window; the failure after this
/// many is answered with `429 Too Many Requests` instead of `403 AccessDenied`.
const MAX_AUTH_FAILURES: u32 = 10;
/// Length, in seconds, of the per-IP window in which failures are counted.
const AUTH_FAILURE_WINDOW_SECS: u64 = 60;
/// Best-effort client address used to key the authentication-failure counters.
///
/// Resolution order:
/// 1. the first (left-most) comma-separated entry of `X-Forwarded-For`, if it parses as an IP;
/// 2. the peer address from axum's `ConnectInfo<SocketAddr>` request extension;
/// 3. `127.0.0.1` as a last resort. Every request that reaches this fallback therefore shares
///    a single failure counter.
///
/// NOTE(review): `X-Forwarded-For` is trusted unconditionally, and its left-most entry is
/// supplied by the client. A caller can choose its own key, either to evade the failure limit or
/// to charge failures to another address. This is presumably safe only if a fronting proxy
/// overwrites the header. Confirm the deployment, or add a trusted-proxy setting.
fn extract_client_ip(request: &axum::extract::Request) -> IpAddr {
    let forwarded = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|first| first.trim().parse::<IpAddr>().ok());

    forwarded
        .or_else(|| {
            request
                .extensions()
                .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
                .map(|info| info.0.ip())
        })
        .unwrap_or(IpAddr::from([127, 0, 0, 1]))
}
/// Records one authentication failure for `ip` and reports whether that IP is now over budget.
///
/// Each IP has a fixed window of `AUTH_FAILURE_WINDOW_SECS` that starts at its first failure.
/// The first failure after the window has elapsed restarts both the window and the count.
/// Returns `true` once more than `MAX_AUTH_FAILURES` failures fall in the current window.
///
/// Apart from this, entries are removed only after a successful authentication. Addresses
/// that fail and never come back would otherwise accumulate forever, for example clients
/// rotating source addresses. Once the table holds at least `SWEEP_MIN_ENTRIES` entries,
/// expired entries are swept. The sweep runs at most once per window, so its O(n) cost is
/// not paid on every failure during a flood.
///
/// A poisoned lock is recovered rather than propagated.
fn check_and_record_failure(state: &AppState, ip: IpAddr) -> bool {
    // Table size below which no sweep is attempted.
    const SWEEP_MIN_ENTRIES: usize = 4096;
    // Time of the last sweep. It is process-wide, shared by every `AppState`. It only throttles
    // cleanup and never influences the counting itself.
    static LAST_SWEEP: std::sync::Mutex<Option<std::time::Instant>> =
        std::sync::Mutex::new(None);

    let mut counts = state
        .auth_failure_counts
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    let now = std::time::Instant::now();
    let window = std::time::Duration::from_secs(AUTH_FAILURE_WINDOW_SECS);

    if counts.len() >= SWEEP_MIN_ENTRIES {
        let mut last_sweep = LAST_SWEEP.lock().unwrap_or_else(|e| e.into_inner());
        let sweep_due = match *last_sweep {
            Some(at) => now.duration_since(at) >= window,
            None => true,
        };
        if sweep_due {
            // An entry whose window has elapsed would be reset to (0, now) on its next failure
            // anyway. Dropping it does not change any later result of this function.
            counts.retain(|_, entry| now.duration_since(entry.1) < window);
            *last_sweep = Some(now);
        }
    }

    let entry = counts.entry(ip).or_insert((0, now));
    if now.duration_since(entry.1) >= window {
        *entry = (0, now);
    }
    entry.0 = entry.0.saturating_add(1);
    entry.0 > MAX_AUTH_FAILURES
}
/// Forgets every recorded authentication failure for `ip`.
///
/// The middleware calls this after a request from `ip` verifies successfully. A poisoned lock
/// is recovered rather than propagated.
fn clear_failure_counter(state: &AppState, ip: IpAddr) {
    let mut table = match state.auth_failure_counts.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let _ = table.remove(&ip);
}
/// Builds the S3-style `429 TooManyRequests` XML error sent to a client that has exhausted
/// its authentication-failure budget.
///
/// The response carries `Retry-After: 60` and a fresh random `x-amz-request-id`. If the
/// builder rejects the response, a bare 429 with an empty body is returned instead. The
/// path argument is currently unused.
fn too_many_requests_response(_uri_path: &str) -> Response {
    use axum::body::Body;

    const XML_BODY: &str = concat!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
        "<Error>",
        "<Code>TooManyRequests</Code>",
        "<Message>Too many authentication failures</Message>",
        "</Error>"
    );

    let request_id = uuid::Uuid::new_v4().to_string();
    let built = Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("Content-Type", "application/xml")
        .header("Retry-After", "60")
        .header("x-amz-request-id", request_id)
        .body(Body::from(XML_BODY));

    match built {
        Ok(response) => response,
        Err(_) => {
            let mut fallback = Response::new(Body::empty());
            *fallback.status_mut() = StatusCode::TOO_MANY_REQUESTS;
            fallback
        }
    }
}
pub async fn sigv4_auth_layer(
State(state): State<AppState>,
request: axum::extract::Request,
next: Next,
) -> Result<Response, Response> {
let path = request.uri().path().to_owned();
if EXEMPT_PATHS.iter().any(|&p| path == p) {
return Ok(next.run(request).await);
}
{
let mode = detect_auth_mode(&state.config.access_key, &state.config.secret_key);
match mode {
AuthMode::Passthrough => {
tracing::warn!(
"Auth middleware is in PASSTHROUGH mode (no credentials configured). \
All requests pass without authentication. Do not use in production."
);
}
AuthMode::Misconfigured => {
tracing::error!(
"Auth misconfiguration: only one of access_key/secret_key is set. \
Server should refuse to start."
);
}
AuthMode::Authenticated => {}
}
}
let verifier = match &state.verifier {
Some(v) => Arc::clone(v),
None => return Ok(next.run(request).await),
};
let client_ip = extract_client_ip(&request);
let method = request.method().as_str().to_owned();
let uri_path = path;
let query_string = request.uri().query().unwrap_or("").to_owned();
let headers_vec: Vec<(String, String)> = request
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.as_str().to_lowercase(), v.to_owned()))
})
.collect();
let payload_hash = headers_vec
.iter()
.find(|(name, _)| name == "x-amz-content-sha256")
.map(|(_, v)| v.as_str())
.unwrap_or("UNSIGNED-PAYLOAD")
.to_owned();
if PresignedParams::is_presigned_request(&query_string) {
match verifier.verify_presigned_request(&method, &uri_path, &query_string, &headers_vec) {
Ok(()) => {
clear_failure_counter(&state, client_ip);
return Ok(next.run(request).await);
}
Err(e) => {
tracing::warn!(
"Presigned URL auth failed for {} {}: {}",
method,
uri_path,
e
);
if check_and_record_failure(&state, client_ip) {
return Err(too_many_requests_response(&uri_path));
}
return Err(error_response(
StatusCode::FORBIDDEN,
"AccessDenied",
"Access Denied",
&uri_path,
));
}
}
}
let auth_header_value = headers_vec
.iter()
.find(|(name, _)| name == "authorization")
.map(|(_, v)| v.clone());
if let Some(ref auth_header) = auth_header_value {
if auth_header.starts_with("AWS4-HMAC-SHA256") {
match verifier.verify_request(
&method,
&uri_path,
&query_string,
&headers_vec,
&payload_hash,
Some(auth_header.as_str()),
) {
Ok(()) => {
clear_failure_counter(&state, client_ip);
return Ok(next.run(request).await);
}
Err(e) => {
tracing::warn!(
"SigV4 header auth failed for {} {}: {}",
method,
uri_path,
e
);
if check_and_record_failure(&state, client_ip) {
return Err(too_many_requests_response(&uri_path));
}
return Err(error_response(
StatusCode::FORBIDDEN,
"AccessDenied",
"Access Denied",
&uri_path,
));
}
}
}
}
tracing::warn!(
"Request {} {} missing SigV4 auth (verifier enabled)",
method,
uri_path
);
if check_and_record_failure(&state, client_ip) {
return Err(too_many_requests_response(&uri_path));
}
Err(error_response(
StatusCode::FORBIDDEN,
"AccessDenied",
"Access Denied",
&uri_path,
))
}