use axum::http::Method;
use axum::response::IntoResponse;
use std::{collections::HashMap, sync::Arc};
use crate::middleware::common;
use authn_resolver_sdk::{AuthNResolverClient, AuthNResolverError};
use modkit::api::Problem;
use modkit_security::SecurityContext;
/// Set of matchit-syntax route patterns that always require authentication.
///
/// `GatewayRoutePolicy` keeps one of these per HTTP method.
#[derive(Clone)]
pub struct RouteMatcher {
    matcher: matchit::Router<()>,
}

impl RouteMatcher {
    /// Creates a matcher with no registered patterns.
    fn new() -> Self {
        let matcher = matchit::Router::new();
        Self { matcher }
    }

    /// Registers a matchit-syntax pattern such as `/users/{id}`.
    ///
    /// # Errors
    ///
    /// Returns matchit's `InsertError` when the pattern is rejected (for example, when it
    /// conflicts with an already-registered pattern).
    fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
        let pattern = path;
        self.matcher.insert(pattern, ())
    }

    /// Returns `true` when `path` matches any registered pattern.
    fn find(&self, path: &str) -> bool {
        let lookup = self.matcher.at(path);
        lookup.is_ok()
    }
}
/// Set of matchit-syntax route patterns exempt from `require_auth_by_default`.
///
/// `GatewayRoutePolicy` keeps one of these per HTTP method.
#[derive(Clone)]
pub struct PublicRouteMatcher {
    matcher: matchit::Router<()>,
}

impl PublicRouteMatcher {
    /// Creates a matcher with no registered patterns.
    fn new() -> Self {
        let matcher = matchit::Router::new();
        Self { matcher }
    }

    /// Registers a matchit-syntax pattern such as `/health` or `/users/{id}`.
    ///
    /// # Errors
    ///
    /// Returns matchit's `InsertError` when the pattern is rejected (for example, when it
    /// conflicts with an already-registered pattern).
    fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
        let pattern = path;
        self.matcher.insert(pattern, ())
    }

    /// Returns `true` when `path` matches any registered pattern.
    fn find(&self, path: &str) -> bool {
        let lookup = self.matcher.at(path);
        lookup.is_ok()
    }
}
/// Converts an axum-style route path (`/users/:id`) into matchit brace syntax
/// (`/users/{id}`).
///
/// A `:` followed by one or more alphanumeric or `_` characters becomes a `{name}`
/// parameter. A `:` that is not followed by such a character is copied through as a literal
/// character. Every other character, including already-converted `{name}` segments, is
/// copied unchanged.
///
/// NOTE(review): axum 0.7 `*wildcard` segments are copied unchanged (not rewritten to
/// `{*wildcard}`).
fn convert_axum_path_to_matchit(path: &str) -> String {
    /// Characters accepted in a `:name` parameter name.
    fn is_param_name_char(c: &char) -> bool {
        c.is_alphanumeric() || *c == '_'
    }

    let mut result = String::with_capacity(path.len());
    let mut chars = path.chars().peekable();
    while let Some(ch) = chars.next() {
        // A bare `:` would otherwise become `{}`, an empty parameter name that matchit
        // refuses to insert; keep it literal instead.
        if ch == ':' && chars.peek().is_some_and(is_param_name_char) {
            result.push('{');
            while let Some(c) = chars.next_if(is_param_name_char) {
                result.push(c);
            }
            result.push('}');
        } else {
            result.push(ch);
        }
    }
    result
}
/// Outcome of [`GatewayRoutePolicy::resolve`] for a single request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthRequirement {
    /// No credentials needed; `authn_middleware` attaches an anonymous `SecurityContext`.
    None,
    /// A bearer token must be presented and successfully authenticated.
    Required,
}
#[derive(Clone)]
pub struct GatewayRoutePolicy {
route_matchers: Arc<HashMap<Method, RouteMatcher>>,
public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
require_auth_by_default: bool,
}
impl GatewayRoutePolicy {
#[must_use]
pub fn new(
route_matchers: Arc<HashMap<Method, RouteMatcher>>,
public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
require_auth_by_default: bool,
) -> Self {
Self {
route_matchers,
public_matchers,
require_auth_by_default,
}
}
#[must_use]
pub fn resolve(&self, method: &Method, path: &str) -> AuthRequirement {
let is_authenticated = self
.route_matchers
.get(method)
.is_some_and(|matcher| matcher.find(path));
let is_public = self
.public_matchers
.get(method)
.is_some_and(|matcher| matcher.find(path));
let needs_authn = is_authenticated || (self.require_auth_by_default && !is_public);
if needs_authn {
AuthRequirement::Required
} else {
AuthRequirement::None
}
}
}
/// State handed to [`authn_middleware`] through axum's `State` extractor.
#[derive(Clone)]
pub struct AuthState {
    /// Client whose `authenticate` call validates bearer tokens and yields a
    /// `SecurityContext`.
    pub authn_client: Arc<dyn AuthNResolverClient>,
    /// Policy deciding which routes require authentication.
    pub route_policy: GatewayRoutePolicy,
}
/// Builds a [`GatewayRoutePolicy`] from the gateway config and the registered route sets.
///
/// Each `(method, path)` pair is converted to matchit syntax with
/// `convert_axum_path_to_matchit` and inserted into the matcher for its HTTP method.
/// `cfg.require_auth_by_default` is copied into the policy unchanged.
///
/// # Errors
///
/// Returns an error naming the original (unconverted) path when matchit rejects a converted
/// pattern, for example because it conflicts with a pattern already inserted for the same
/// method.
// NOTE(review): the lint is allowed because the signature takes std `HashSet`s with the
// default hasher rather than being generic over `BuildHasher`.
#[allow(clippy::implicit_hasher)]
pub fn build_route_policy(
    cfg: &crate::config::ApiGatewayConfig,
    authenticated_routes: std::collections::HashSet<(Method, String)>,
    public_routes: std::collections::HashSet<(Method, String)>,
) -> Result<GatewayRoutePolicy, anyhow::Error> {
    let mut authenticated: HashMap<Method, RouteMatcher> = HashMap::new();
    for (method, path) in authenticated_routes {
        let pattern = convert_axum_path_to_matchit(&path);
        if let Err(e) = authenticated
            .entry(method)
            .or_insert_with(RouteMatcher::new)
            .insert(&pattern)
        {
            return Err(anyhow::anyhow!(
                "Failed to insert route pattern '{path}': {e}"
            ));
        }
    }

    let mut public: HashMap<Method, PublicRouteMatcher> = HashMap::new();
    for (method, path) in public_routes {
        let pattern = convert_axum_path_to_matchit(&path);
        if let Err(e) = public
            .entry(method)
            .or_insert_with(PublicRouteMatcher::new)
            .insert(&pattern)
        {
            return Err(anyhow::anyhow!(
                "Failed to insert public route pattern '{path}': {e}"
            ));
        }
    }

    let require_auth_by_default = cfg.require_auth_by_default;
    Ok(GatewayRoutePolicy::new(
        Arc::new(authenticated),
        Arc::new(public),
        require_auth_by_default,
    ))
}
/// Axum middleware that enforces the gateway's authentication policy.
///
/// CORS preflight requests, and requests whose route resolves to [`AuthRequirement::None`],
/// proceed with an anonymous [`SecurityContext`]. Requests that require authentication must
/// carry an `Authorization: Bearer <token>` header. The token is checked with
/// `state.authn_client.authenticate`, and on success the returned `security_context` is
/// inserted into the request extensions before the next handler runs.
///
/// Failures are returned as `Problem` responses: 401 with a `WWW-Authenticate: Bearer`
/// challenge when the header is missing or malformed, otherwise whatever
/// `authn_error_to_response` produces.
pub async fn authn_middleware(
    axum::extract::State(state): axum::extract::State<AuthState>,
    mut req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    // Browsers send CORS preflights without credentials, so they are let through anonymously.
    if is_preflight_request(req.method(), req.headers()) {
        req.extensions_mut().insert(SecurityContext::anonymous());
        return next.run(req).await;
    }
    // Prefer the matched route template; fall back to the raw URI path when no route matched.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
    // NOTE(review): exact transformation performed by `common::resolve_path` is defined
    // elsewhere — confirm there.
    let path = common::resolve_path(&req, path.as_str());
    let requirement = state.route_policy.resolve(req.method(), path.as_str());
    match requirement {
        AuthRequirement::None => {
            req.extensions_mut().insert(SecurityContext::anonymous());
            next.run(req).await
        }
        AuthRequirement::Required => {
            let Some(token) = extract_bearer_token(req.headers()) else {
                let mut response = Problem::new(
                    axum::http::StatusCode::UNAUTHORIZED,
                    "Unauthorized",
                    "Missing or invalid Authorization header",
                )
                .into_response();
                // RFC 7235 §3.1 / RFC 6750 §3: a 401 MUST carry a WWW-Authenticate challenge.
                response.headers_mut().insert(
                    axum::http::header::WWW_AUTHENTICATE,
                    axum::http::HeaderValue::from_static("Bearer"),
                );
                return response;
            };
            match state.authn_client.authenticate(token).await {
                Ok(result) => {
                    req.extensions_mut().insert(result.security_context);
                    next.run(req).await
                }
                Err(err) => authn_error_to_response(&err),
            }
        }
    }
}
/// Logs `err` and converts it into a `Problem` HTTP response.
///
/// Mapping:
/// - `Unauthorized` → 401, plus a `WWW-Authenticate: Bearer` challenge header
/// - `NoPluginAvailable` / `ServiceUnavailable` → 503
/// - `TokenAcquisitionFailed` / `Internal` → 500
///
/// The response detail is a fixed generic string; the error's own message is only logged.
fn authn_error_to_response(err: &AuthNResolverError) -> axum::response::Response {
    log_authn_error(err);
    let (status, title, detail) = match err {
        AuthNResolverError::Unauthorized(_) => (
            axum::http::StatusCode::UNAUTHORIZED,
            "Unauthorized",
            "Authentication failed",
        ),
        AuthNResolverError::NoPluginAvailable | AuthNResolverError::ServiceUnavailable(_) => (
            axum::http::StatusCode::SERVICE_UNAVAILABLE,
            "Service Unavailable",
            "Authentication service unavailable",
        ),
        AuthNResolverError::TokenAcquisitionFailed(_) | AuthNResolverError::Internal(_) => (
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            "Internal Server Error",
            "Internal authentication error",
        ),
    };
    let mut response = Problem::new(status, title, detail).into_response();
    // RFC 7235 §3.1 / RFC 6750 §3: a 401 MUST carry a WWW-Authenticate challenge.
    if status == axum::http::StatusCode::UNAUTHORIZED {
        response.headers_mut().insert(
            axum::http::header::WWW_AUTHENTICATE,
            axum::http::HeaderValue::from_static("Bearer"),
        );
    }
    response
}
#[allow(clippy::cognitive_complexity)]
fn log_authn_error(err: &AuthNResolverError) {
match err {
AuthNResolverError::Unauthorized(msg) => tracing::debug!("AuthN rejected: {msg}"),
AuthNResolverError::NoPluginAvailable => tracing::error!("No AuthN plugin available"),
AuthNResolverError::ServiceUnavailable(msg) => {
tracing::error!("AuthN service unavailable: {msg}");
}
AuthNResolverError::TokenAcquisitionFailed(msg) => {
tracing::error!("AuthN token acquisition failed: {msg}");
}
AuthNResolverError::Internal(msg) => tracing::error!("AuthN internal error: {msg}"),
}
}
/// Extracts the bearer token from the `Authorization` header.
///
/// The header must have the form `<scheme> <token>`. The scheme is matched against `Bearer`
/// case-insensitively, as RFC 7235 §2.1 requires for auth-schemes, and the token is trimmed.
///
/// Returns `None` in any of these cases:
/// - the header is absent or is not a valid `&str`,
/// - it has no space-separated credentials,
/// - it uses a different scheme,
/// - the token is empty after trimming.
fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
    let value = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // An empty token is treated as "missing" rather than being forwarded to the resolver.
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}
/// Returns `true` when the request is a CORS preflight: an `OPTIONS` request that carries
/// both an `Origin` and an `Access-Control-Request-Method` header.
fn is_preflight_request(method: &Method, headers: &axum::http::HeaderMap) -> bool {
    use axum::http::header::{ACCESS_CONTROL_REQUEST_METHOD, ORIGIN};

    if *method != Method::OPTIONS {
        return false;
    }
    [ORIGIN, ACCESS_CONTROL_REQUEST_METHOD]
        .iter()
        .all(|name| headers.contains_key(name))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use axum::http::Method;

    /// Wraps the given matcher maps in a `GatewayRoutePolicy`.
    fn build_test_policy(
        route_matchers: HashMap<Method, RouteMatcher>,
        public_matchers: HashMap<Method, PublicRouteMatcher>,
        require_auth_by_default: bool,
    ) -> GatewayRoutePolicy {
        GatewayRoutePolicy::new(
            Arc::new(route_matchers),
            Arc::new(public_matchers),
            require_auth_by_default,
        )
    }

    /// Builds an authenticated-route map with matchit `patterns` registered for `method`.
    fn authenticated(method: Method, patterns: &[&str]) -> HashMap<Method, RouteMatcher> {
        let mut matcher = RouteMatcher::new();
        for pattern in patterns {
            matcher.insert(pattern).expect("test pattern must be valid");
        }
        HashMap::from([(method, matcher)])
    }

    /// Builds a public-route map with matchit `patterns` registered for `method`.
    fn public(method: Method, patterns: &[&str]) -> HashMap<Method, PublicRouteMatcher> {
        let mut matcher = PublicRouteMatcher::new();
        for pattern in patterns {
            matcher.insert(pattern).expect("test pattern must be valid");
        }
        HashMap::from([(method, matcher)])
    }

    #[test]
    fn test_convert_axum_path_to_matchit() {
        assert_eq!(convert_axum_path_to_matchit("/users/:id"), "/users/{id}");
        assert_eq!(
            convert_axum_path_to_matchit("/posts/:post_id/comments/:comment_id"),
            "/posts/{post_id}/comments/{comment_id}"
        );
        assert_eq!(convert_axum_path_to_matchit("/health"), "/health");
        assert_eq!(
            convert_axum_path_to_matchit("/api/v1/:resource/:id/status"),
            "/api/v1/{resource}/{id}/status"
        );
    }

    #[test]
    fn test_matchit_router_with_params() {
        let mut router = matchit::Router::new();
        router.insert("/users/{id}", "user_route").unwrap();
        let result = router.at("/users/42");
        assert!(
            result.is_ok(),
            "matchit should match /users/{{id}} against /users/42"
        );
        assert_eq!(*result.unwrap().value, "user_route");
    }

    #[test]
    fn explicit_public_route_with_path_params_returns_none() {
        let policy =
            build_test_policy(HashMap::new(), public(Method::GET, &["/users/{id}"]), true);
        assert_eq!(policy.resolve(&Method::GET, "/users/42"), AuthRequirement::None);
    }

    #[test]
    fn explicit_public_route_exact_match_returns_none() {
        let policy = build_test_policy(HashMap::new(), public(Method::GET, &["/health"]), true);
        assert_eq!(policy.resolve(&Method::GET, "/health"), AuthRequirement::None);
    }

    #[test]
    fn explicit_authenticated_route_returns_required() {
        let policy = build_test_policy(
            authenticated(Method::GET, &["/admin/metrics"]),
            HashMap::new(),
            false,
        );
        assert_eq!(
            policy.resolve(&Method::GET, "/admin/metrics"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn route_without_requirement_with_require_auth_by_default_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);
        assert_eq!(policy.resolve(&Method::GET, "/profile"), AuthRequirement::Required);
    }

    #[test]
    fn route_without_requirement_without_require_auth_by_default_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);
        assert_eq!(policy.resolve(&Method::GET, "/profile"), AuthRequirement::None);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_true_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);
        assert_eq!(policy.resolve(&Method::POST, "/unknown"), AuthRequirement::Required);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_false_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);
        assert_eq!(policy.resolve(&Method::POST, "/unknown"), AuthRequirement::None);
    }

    #[test]
    fn public_route_overrides_require_auth_by_default() {
        let policy = build_test_policy(HashMap::new(), public(Method::GET, &["/public"]), true);
        assert_eq!(policy.resolve(&Method::GET, "/public"), AuthRequirement::None);
    }

    #[test]
    fn authenticated_route_has_priority_over_default() {
        let policy = build_test_policy(
            authenticated(Method::GET, &["/users/{id}"]),
            HashMap::new(),
            false,
        );
        assert_eq!(
            policy.resolve(&Method::GET, "/users/123"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn authenticated_route_wins_over_public_route() {
        let policy = build_test_policy(
            authenticated(Method::GET, &["/shared"]),
            public(Method::GET, &["/shared"]),
            true,
        );
        assert_eq!(policy.resolve(&Method::GET, "/shared"), AuthRequirement::Required);
    }

    #[test]
    fn public_route_is_method_specific() {
        let policy = build_test_policy(HashMap::new(), public(Method::GET, &["/health"]), true);
        assert_eq!(policy.resolve(&Method::POST, "/health"), AuthRequirement::Required);
    }

    #[test]
    fn different_methods_resolve_independently() {
        let policy = build_test_policy(
            authenticated(Method::GET, &["/user-management/v1/users"]),
            HashMap::new(),
            false,
        );
        assert_eq!(
            policy.resolve(&Method::GET, "/user-management/v1/users"),
            AuthRequirement::Required
        );
        assert_eq!(
            policy.resolve(&Method::POST, "/user-management/v1/users"),
            AuthRequirement::None
        );
    }
}