use std::borrow::Cow;
use std::sync::Arc;
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use rusty_gasket::observability::{AuthSummary, LoggingContext, RequestId};
use rusty_gasket::rate_limit::RateLimitSubject;
use rusty_gasket::auth::audit::{
AuditLogger, AuditLoggerHandle, AuthAuditEvent, AuthAuditOutcome, IntoAuditLoggerHandle,
};
use rusty_gasket::auth::chain::AuthChain;
use rusty_gasket::auth::context::{AuthContext, AuthResult, FailedReason};
use rusty_gasket::auth::identity::Identity;
/// Shared state for [`auth_middleware`]: the authentication chain to run on
/// each request, plus an optional sink for auth audit events.
///
/// Build with [`AuthMiddlewareState::new`] and optionally attach a logger via
/// [`AuthMiddlewareState::with_audit_logger`] or
/// [`AuthMiddlewareState::with_audit_logger_handle`].
#[non_exhaustive]
pub struct AuthMiddlewareState {
    /// Authenticators consulted for every incoming request.
    pub(crate) chain: AuthChain,
    /// Where auth outcomes are audited; `None` means no audit events are emitted.
    pub(crate) audit_logger: Option<AuditLoggerHandle>,
}

impl AuthMiddlewareState {
    /// Creates state that runs `chain` with auditing disabled.
    #[must_use]
    pub fn new(chain: AuthChain) -> Self {
        Self {
            chain,
            audit_logger: None,
        }
    }

    /// Attaches an audit logger from anything convertible into an
    /// [`AuditLoggerHandle`], replacing any logger configured earlier.
    #[must_use]
    pub fn with_audit_logger(self, logger: impl IntoAuditLoggerHandle) -> Self {
        self.with_audit_logger_handle(logger.into_audit_logger_handle())
    }

    /// Attaches an already-built [`AuditLoggerHandle`], replacing any logger
    /// configured earlier.
    #[must_use]
    pub fn with_audit_logger_handle(mut self, logger: AuditLoggerHandle) -> Self {
        self.audit_logger = Some(logger);
        self
    }

    /// Legacy builder: wraps `logger` in a shared [`AuditLoggerHandle`] and
    /// attaches it.
    #[must_use]
    #[deprecated(note = "use `with_audit_logger` for consistency with other builders")]
    pub fn audit_logger(self, logger: Arc<dyn AuditLogger>) -> Self {
        let handle = AuditLoggerHandle::shared(logger);
        self.with_audit_logger_handle(handle)
    }

    /// Returns the authentication chain this state runs.
    #[must_use]
    pub const fn chain(&self) -> &AuthChain {
        &self.chain
    }
}

impl std::fmt::Debug for AuthMiddlewareState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only whether an audit logger is configured is reported, not the logger itself.
        let has_logger = self.audit_logger.is_some();
        f.debug_struct("AuthMiddlewareState")
            .field("chain", &self.chain)
            .field("has_audit_logger", &has_logger)
            .finish()
    }
}
/// Axum middleware that authenticates each request with the configured
/// [`AuthChain`] before passing it to the next layer.
///
/// If the chain returns an error, the request is short-circuited: the
/// response produced by `handle_auth_failure` (the error's own response, with
/// a failed [`AuthContext`] attached) is returned and `next` is never called.
///
/// Otherwise (authenticated or anonymous) the outcome is audited when a logger
/// is configured, the request's [`LoggingContext`] is filled in if present, a
/// [`RateLimitSubject`] is inserted for authenticated callers, and an
/// [`AuthContext`] is inserted into the request extensions before `next` runs.
pub async fn auth_middleware(
    State(state): State<Arc<AuthMiddlewareState>>,
    mut request: Request,
    next: Next,
) -> Response {
    // A missing RequestId extension yields an empty id.
    let request_id = request
        .extensions()
        .get::<RequestId>()
        .map_or_else(String::new, |rid| rid.as_str().to_owned());
    let client_ip = extract_client_ip(&request);

    let authenticated = state
        .chain
        .authenticate(request.headers(), request.uri())
        .await;

    // Resolved after the await so no logger borrow is live across it.
    let audit = state.audit_logger.as_ref().map(AuditLoggerHandle::logger);

    let identity = match authenticated {
        Ok(identity) => identity,
        Err(err) => return handle_auth_failure(audit, &request, request_id, client_ip, err),
    };

    let principal = identity.as_ref();
    audit_success_or_anonymous(audit, &request_id, &client_ip, principal);
    populate_logging_context(&request, build_success_summary(&client_ip, principal));

    let auth_result = match principal {
        Some(id) => {
            // Tag the request with the caller's subject for downstream rate limiting.
            request
                .extensions_mut()
                .insert(RateLimitSubject::new(id.subject()));
            AuthResult::Authenticated {
                method: id.auth_method(),
            }
        }
        None => AuthResult::Anonymous,
    };

    let auth_context = AuthContext::new(identity, client_ip, request_id, auth_result);
    request.extensions_mut().insert(auth_context);
    next.run(request).await
}
fn handle_auth_failure(
audit: Option<&dyn AuditLogger>,
request: &Request,
request_id: String,
client_ip: String,
error: rusty_gasket::auth::error::AuthError,
) -> Response {
let reason = error.to_string();
let category = error.category();
if let Some(logger) = audit {
let outcome = match &error {
rusty_gasket::auth::error::AuthError::BackendError(_)
| rusty_gasket::auth::error::AuthError::Configuration(_) => AuthAuditOutcome::Error {
error: reason.clone(),
},
_ => AuthAuditOutcome::Denied {
reason: reason.clone(),
},
};
logger.log_auth_event(&AuthAuditEvent {
request_id: request_id.clone(),
client_ip: client_ip.clone(),
auth_method: None,
subject: None,
outcome,
});
}
populate_logging_context(
request,
AuthSummary::builder()
.client_ip(client_ip.clone())
.user_id(Cow::Borrowed("unknown"))
.auth_method(Cow::Borrowed("unknown"))
.auth_result(format!("failed:{category}"))
.build(),
);
let ctx = AuthContext::new(
None,
client_ip,
request_id,
AuthResult::Failed(FailedReason::new(reason)),
);
let mut response = error.into_response();
response.extensions_mut().insert(ctx);
response
}
/// Emits one audit event for a request that passed authentication.
///
/// With an identity the event records its auth method and subject with outcome
/// `Success`; without one both are `None` and the outcome is `Anonymous`.
/// Does nothing when `audit` is `None`.
fn audit_success_or_anonymous(
    audit: Option<&dyn AuditLogger>,
    request_id: &str,
    client_ip: &str,
    identity: Option<&Identity>,
) {
    if let Some(logger) = audit {
        let (auth_method, subject, outcome) = match identity {
            Some(id) => (
                Some(id.auth_method().to_owned()),
                Some(id.subject().to_owned()),
                AuthAuditOutcome::Success,
            ),
            None => (None, None, AuthAuditOutcome::Anonymous),
        };
        logger.log_auth_event(&AuthAuditEvent {
            request_id: request_id.to_owned(),
            client_ip: client_ip.to_owned(),
            auth_method,
            subject,
            outcome,
        });
    }
}
/// Builds the logging-context summary for a request that passed authentication.
///
/// Anonymous requests are labelled user `anonymous`, method `none`, result
/// `anonymous`. Authenticated requests use the identity's subject as both
/// client and user id, its auth method, result `authenticated:<method>`, and
/// its privileged flag.
fn build_success_summary(client_ip: &str, identity: Option<&Identity>) -> AuthSummary {
    let Some(id) = identity else {
        return AuthSummary::builder()
            .client_ip(Cow::Owned(client_ip.to_owned()))
            .user_id(Cow::Borrowed("anonymous"))
            .auth_method(Cow::Borrowed("none"))
            .auth_result(Cow::Borrowed("anonymous"))
            .build();
    };

    let subject = id.subject().to_owned();
    let method = id.auth_method();
    AuthSummary::builder()
        .client_id(subject.clone())
        .client_ip(client_ip.to_owned())
        .user_id(subject)
        .auth_method(Cow::Borrowed(method))
        .auth_result(format!("authenticated:{method}"))
        .privileged(id.is_privileged())
        .build()
}
/// Stores `summary` in the request's [`LoggingContext`] extension. If the
/// request carries no such extension, the summary is dropped.
fn populate_logging_context(request: &Request, summary: AuthSummary) {
    let Some(ctx) = request.extensions().get::<LoggingContext>() else {
        return;
    };
    ctx.set(summary);
}
/// Best-effort client address used for auditing and logging.
///
/// Sources are tried in order; the first that yields a non-blank value (after
/// trimming surrounding whitespace) wins:
/// 1. the `x-real-ip` header;
/// 2. the first comma-separated entry of `x-forwarded-for`;
/// 3. the peer IP from axum's `ConnectInfo<SocketAddr>` extension (port dropped);
/// 4. otherwise the literal `"unknown"`.
///
/// Header values that `HeaderValue::to_str` rejects are skipped.
///
/// NOTE(review): both headers are taken at face value. Nothing here checks that
/// a trusted reverse proxy set them, so a client that reaches this service
/// directly can choose the reported IP. Confirm the deployment always sits
/// behind a proxy that overwrites these headers.
fn extract_client_ip(request: &Request) -> String {
    let headers = request.headers();
    let header_text = |name: &str| headers.get(name).and_then(|value| value.to_str().ok());
    let non_blank = |candidate: &str| {
        let candidate = candidate.trim();
        (!candidate.is_empty()).then(|| candidate.to_string())
    };

    header_text("x-real-ip")
        .and_then(non_blank)
        .or_else(|| {
            header_text("x-forwarded-for")
                .and_then(|list| list.split(',').next())
                .and_then(non_blank)
        })
        .or_else(|| {
            request
                .extensions()
                .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
                .map(|info| info.0.ip().to_string())
        })
        .unwrap_or_else(|| "unknown".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::ConnectInfo;
    use http::Request as HttpRequest;
    use std::net::SocketAddr;

    /// Builds a request with the given headers and an empty body.
    fn request_with(headers: &[(&str, &str)]) -> Request {
        let mut builder = HttpRequest::builder();
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).expect("valid request")
    }

    /// Peer address used for the `ConnectInfo` fallback tests.
    fn peer() -> SocketAddr {
        "192.0.2.7:443".parse().expect("valid socket address")
    }

    #[test]
    fn extract_ip_from_x_forwarded_for() {
        let req = request_with(&[("x-forwarded-for", "10.0.0.1, 192.168.1.1")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.1");
    }

    #[test]
    fn extract_ip_from_x_forwarded_for_single() {
        let req = request_with(&[("x-forwarded-for", "203.0.113.50")]);
        assert_eq!(extract_client_ip(&req), "203.0.113.50");
    }

    #[test]
    fn extract_ip_from_x_real_ip() {
        let req = request_with(&[("x-real-ip", "10.0.0.2")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.2");
    }

    #[test]
    fn extract_ip_x_real_ip_takes_priority() {
        let req = request_with(&[("x-forwarded-for", "10.0.0.1"), ("x-real-ip", "10.0.0.2")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.2");
    }

    #[test]
    fn extract_ip_x_real_ip_wins_over_empty_forwarded() {
        let req = request_with(&[("x-forwarded-for", ""), ("x-real-ip", "10.0.0.2")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.2");
    }

    #[test]
    fn extract_ip_blank_real_ip_falls_through_to_forwarded() {
        // A whitespace-only x-real-ip must not shadow x-forwarded-for.
        let req = request_with(&[("x-real-ip", "   "), ("x-forwarded-for", "10.0.0.1")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.1");
    }

    #[test]
    fn extract_ip_empty_forwarded_falls_through() {
        // With no x-real-ip, an empty x-forwarded-for defers to the socket peer.
        let mut req = request_with(&[("x-forwarded-for", "")]);
        req.extensions_mut().insert(ConnectInfo(peer()));
        assert_eq!(extract_client_ip(&req), "192.0.2.7");
    }

    #[test]
    fn extract_ip_falls_back_to_connect_info() {
        let mut req = request_with(&[]);
        req.extensions_mut().insert(ConnectInfo(peer()));
        assert_eq!(extract_client_ip(&req), "192.0.2.7");
    }

    #[test]
    fn extract_ip_no_headers_returns_unknown() {
        let req = request_with(&[]);
        assert_eq!(extract_client_ip(&req), "unknown");
    }
}