use axum::{
body::Body,
extract::{ConnectInfo, FromRequestParts, Request, State},
http::{request::Parts, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use super::audit::{log_auth_failure, log_auth_success, log_rate_limit};
use super::authn::Authenticator;
use super::error::AuthError;
use super::principal::Principal;
use super::rate_limit::{RateLimitConfig, RateLimitInfo, RateLimiter};
/// State shared by the authentication middlewares in this module.
///
/// Bundles the credential verifier with the rate limiter that the middlewares consult for
/// per-IP and per-tenant limits.
#[derive(Clone)]
pub struct AuthState {
    /// Verifies request headers and produces the caller's [`Principal`].
    pub authenticator: Arc<dyn Authenticator>,
    /// Limiter used for per-IP checks, per-tenant checks and auth success/failure bookkeeping.
    pub rate_limiter: Arc<RateLimiter>,
}

impl AuthState {
    /// Creates state whose rate limiter is built from [`RateLimitConfig::disabled`].
    pub fn new(authenticator: Arc<dyn Authenticator>) -> Self {
        let limiter = RateLimiter::new(RateLimitConfig::disabled());
        Self::with_rate_limiter(authenticator, Arc::new(limiter))
    }

    /// Creates state around a caller-supplied (and possibly shared) rate limiter.
    pub fn with_rate_limiter(
        authenticator: Arc<dyn Authenticator>,
        rate_limiter: Arc<RateLimiter>,
    ) -> Self {
        Self {
            rate_limiter,
            authenticator,
        }
    }
}
pub async fn auth_middleware(
State(auth_state): State<AuthState>,
mut request: Request<Body>,
next: Next,
) -> Response {
let trust_proxy = auth_state.rate_limiter.trust_proxy_headers();
let client_ip = extract_client_ip(&request, trust_proxy);
if let Some(ip) = client_ip {
if let Err(exceeded) = auth_state.rate_limiter.check_ip_limit(&ip) {
log_rate_limit(&ip.to_string(), None, "ip");
return exceeded.into_response();
}
}
let headers = request.headers().clone();
match auth_state.authenticator.authenticate(&headers).await {
Ok(principal) => {
if let Some(ip) = client_ip {
auth_state.rate_limiter.record_auth_success(&ip);
}
let auth_method = match principal.auth_method() {
super::principal::AuthMethod::ApiKey => "api_key",
super::principal::AuthMethod::Bearer => "jwt",
super::principal::AuthMethod::None => "anonymous",
_ => "other",
};
log_auth_success(
principal.id(),
principal.tenant_id(),
client_ip.as_ref().map(|ip| ip.to_string()).as_deref(),
auth_method,
);
if principal.is_expired() {
return AuthError::TokenExpired.into_response();
}
if let Err(exceeded) = auth_state
.rate_limiter
.check_tenant_limit(principal.tenant_id())
{
log_rate_limit(
&client_ip
.map(|ip| ip.to_string())
.unwrap_or_else(|| "unknown".to_string()),
Some(principal.tenant_id()),
"tenant",
);
return add_rate_limit_headers(exceeded.into_response(), None);
}
request.extensions_mut().insert(principal);
let mut response = next.run(request).await;
if let Some(ip) = client_ip {
if let Ok(info) = auth_state.rate_limiter.check_ip_limit(&ip) {
response = add_rate_limit_headers(response, Some(info));
}
}
response
}
Err(e) => {
if let Some(ip) = client_ip {
if matches!(
e,
AuthError::InvalidCredentials(_)
| AuthError::ApiKeyNotFound
| AuthError::ApiKeyDisabled
) {
auth_state.rate_limiter.record_auth_failure(&ip);
log_auth_failure(Some(&ip.to_string()), &e.to_string());
}
}
e.into_response()
}
}
}
pub async fn require_auth_middleware(
State(auth_state): State<AuthState>,
mut request: Request<Body>,
next: Next,
) -> Response {
let trust_proxy = auth_state.rate_limiter.trust_proxy_headers();
let client_ip = extract_client_ip(&request, trust_proxy);
if let Some(ip) = client_ip {
if let Err(exceeded) = auth_state.rate_limiter.check_ip_limit(&ip) {
log_rate_limit(&ip.to_string(), None, "ip");
return exceeded.into_response();
}
}
let headers = request.headers().clone();
match auth_state.authenticator.authenticate(&headers).await {
Ok(principal) => {
if let Some(ip) = client_ip {
auth_state.rate_limiter.record_auth_success(&ip);
}
let auth_method = match principal.auth_method() {
super::principal::AuthMethod::ApiKey => "api_key",
super::principal::AuthMethod::Bearer => "jwt",
super::principal::AuthMethod::None => "anonymous",
_ => "other",
};
log_auth_success(
principal.id(),
principal.tenant_id(),
client_ip.as_ref().map(|ip| ip.to_string()).as_deref(),
auth_method,
);
if principal.is_expired() {
return AuthError::TokenExpired.into_response();
}
if principal.is_anonymous() {
return AuthError::Unauthenticated.into_response();
}
if let Err(exceeded) = auth_state
.rate_limiter
.check_tenant_limit(principal.tenant_id())
{
log_rate_limit(
&client_ip
.map(|ip| ip.to_string())
.unwrap_or_else(|| "unknown".to_string()),
Some(principal.tenant_id()),
"tenant",
);
return exceeded.into_response();
}
request.extensions_mut().insert(principal);
let mut response = next.run(request).await;
if let Some(ip) = client_ip {
if let Ok(info) = auth_state.rate_limiter.check_ip_limit(&ip) {
response = add_rate_limit_headers(response, Some(info));
}
}
response
}
Err(e) => {
if let Some(ip) = client_ip {
if matches!(
e,
AuthError::InvalidCredentials(_)
| AuthError::ApiKeyNotFound
| AuthError::ApiKeyDisabled
) {
auth_state.rate_limiter.record_auth_failure(&ip);
log_auth_failure(Some(&ip.to_string()), &e.to_string());
}
}
e.into_response()
}
}
}
fn extract_client_ip(request: &Request<Body>, trust_proxy_headers: bool) -> Option<IpAddr> {
if trust_proxy_headers {
if let Some(xff) = request.headers().get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
if let Some(first_ip) = xff_str.split(',').next() {
if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
return Some(ip);
}
}
}
}
if let Some(real_ip) = request.headers().get("x-real-ip") {
if let Ok(ip_str) = real_ip.to_str() {
if let Ok(ip) = ip_str.trim().parse::<IpAddr>() {
return Some(ip);
}
}
}
}
request
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip())
}
/// Copies the rate-limit headers described by `info` onto `response`.
///
/// Each `(name, value)` pair yielded by `RateLimitInfo::headers` is inserted into the response
/// headers, replacing any existing value of that name. With `None`, the response is returned
/// unchanged.
fn add_rate_limit_headers(mut response: Response, info: Option<RateLimitInfo>) -> Response {
    let Some(info) = info else {
        return response;
    };
    let target = response.headers_mut();
    for (name, value) in info.headers() {
        target.insert(name, value);
    }
    response
}
/// Extractor for handlers that need the [`Principal`] stored by the auth middleware.
///
/// Yields a clone of the principal found in the request extensions. When none is present, it
/// rejects with `401 Unauthorized` and the body `"Authentication required"`.
///
/// This only checks presence. `auth_middleware` stores any non-expired principal, including
/// anonymous ones, so behind that middleware this extractor can yield an anonymous principal;
/// `require_auth_middleware` rejects those before they reach the handler.
#[derive(Debug, Clone)]
pub struct AuthenticatedPrincipal(pub Principal);

impl<S> FromRequestParts<S> for AuthenticatedPrincipal
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<Principal>() {
            Some(principal) => Ok(Self(principal.clone())),
            None => Err((StatusCode::UNAUTHORIZED, "Authentication required")),
        }
    }
}

impl std::ops::Deref for AuthenticatedPrincipal {
    type Target = Principal;

    fn deref(&self) -> &Principal {
        let Self(inner) = self;
        inner
    }
}
#[derive(Debug, Clone)]
pub struct OptionalPrincipal(pub Option<Principal>);
impl<S> FromRequestParts<S> for OptionalPrincipal
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(OptionalPrincipal(
parts.extensions.get::<Principal>().cloned(),
))
}
}
impl std::ops::Deref for OptionalPrincipal {
type Target = Option<Principal>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::authn::{AllowAllAuthenticator, DenyAllAuthenticator};
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler(AuthenticatedPrincipal(principal): AuthenticatedPrincipal) -> String {
        format!("Hello, {}!", principal.name())
    }

    /// Router exposing `/test` behind `auth_middleware`.
    fn create_test_app(authenticator: Arc<dyn Authenticator>) -> Router {
        let auth_state = AuthState::new(authenticator);
        Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn_with_state(
                auth_state.clone(),
                auth_middleware,
            ))
            .with_state(auth_state)
    }

    /// Router exposing `/test` behind `require_auth_middleware`.
    fn create_require_auth_app(authenticator: Arc<dyn Authenticator>) -> Router {
        let auth_state = AuthState::new(authenticator);
        Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn_with_state(
                auth_state.clone(),
                require_auth_middleware,
            ))
            .with_state(auth_state)
    }

    /// Sends `GET /test` to `app` and returns the response status.
    async fn get_test_status(app: Router) -> StatusCode {
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_auth_middleware_allows_authenticated() {
        let app = create_test_app(Arc::new(AllowAllAuthenticator));
        assert_eq!(get_test_status(app).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_denies_unauthenticated() {
        let app = create_test_app(Arc::new(DenyAllAuthenticator));
        assert_eq!(get_test_status(app).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_auth_middleware_denies_unauthenticated() {
        let app = create_require_auth_app(Arc::new(DenyAllAuthenticator));
        assert_eq!(get_test_status(app).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_extractor_rejects_without_middleware() {
        // No middleware ran, so no Principal is in the extensions and the extractor rejects.
        let app: Router = Router::new().route("/test", get(test_handler));
        assert_eq!(get_test_status(app).await, StatusCode::UNAUTHORIZED);
    }
}