Skip to main content

api_gateway/middleware/
auth.rs

1use axum::http::Method;
2use axum::response::IntoResponse;
3use std::{collections::HashMap, sync::Arc};
4
5use crate::middleware::common;
6
7use authn_resolver_sdk::{AuthNResolverClient, AuthNResolverError};
8use modkit::api::Problem;
9use modkit_security::SecurityContext;
10
11/// Route matcher for a specific HTTP method (authenticated routes).
12#[derive(Clone)]
13pub struct RouteMatcher {
14    matcher: matchit::Router<()>,
15}
16
17impl RouteMatcher {
18    fn new() -> Self {
19        Self {
20            matcher: matchit::Router::new(),
21        }
22    }
23
24    fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
25        self.matcher.insert(path, ())
26    }
27
28    fn find(&self, path: &str) -> bool {
29        self.matcher.at(path).is_ok()
30    }
31}
32
33/// Public route matcher for explicitly public routes
34#[derive(Clone)]
35pub struct PublicRouteMatcher {
36    matcher: matchit::Router<()>,
37}
38
39impl PublicRouteMatcher {
40    fn new() -> Self {
41        Self {
42            matcher: matchit::Router::new(),
43        }
44    }
45
46    fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
47        self.matcher.insert(path, ())
48    }
49
50    fn find(&self, path: &str) -> bool {
51        self.matcher.at(path).is_ok()
52    }
53}
54
/// Convert an Axum 0.7-style route path to matchit 0.8 syntax.
///
/// Axum 0.7 writes path parameters as `:name` and catch-alls as `*name`;
/// matchit 0.8 uses `{name}` and `{*name}`. Both markers are only special at
/// the *start* of a path segment. A `:` or `*` anywhere else is copied through
/// as a literal character, e.g. the custom-method suffix in `/v1/items:batchGet`.
/// Paths already written in brace syntax (`/users/{id}`, `/{*rest}`) pass
/// through unchanged.
///
/// The parameter name is the run of alphanumeric / `_` characters after the
/// marker; anything after it in the same segment is kept (`/:id.json` becomes
/// `/{id}.json`). A marker with no name after it is kept literally rather than
/// producing an empty `{}` parameter, which matchit would reject.
fn convert_axum_path_to_matchit(path: &str) -> String {
    let mut result = String::with_capacity(path.len() + 8);

    for (index, segment) in path.split('/').enumerate() {
        if index > 0 {
            result.push('/');
        }

        // Only a marker at the very start of a segment introduces a parameter.
        let (is_catch_all, rest) = if let Some(rest) = segment.strip_prefix(':') {
            (false, rest)
        } else if let Some(rest) = segment.strip_prefix('*') {
            (true, rest)
        } else {
            result.push_str(segment);
            continue;
        };

        let name_len = rest
            .find(|c: char| !(c.is_alphanumeric() || c == '_'))
            .unwrap_or(rest.len());
        if name_len == 0 {
            // No parameter name follows the marker: keep the segment verbatim.
            result.push_str(segment);
            continue;
        }

        // `find` returns a char boundary, so this split cannot panic.
        let (name, tail) = rest.split_at(name_len);
        result.push('{');
        if is_catch_all {
            result.push('*');
        }
        result.push_str(name);
        result.push('}');
        result.push_str(tail);
    }

    result
}
81
/// Outcome of evaluating the gateway route policy for a request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthRequirement {
    /// Public route: the middleware lets the request through with an
    /// anonymous `SecurityContext`.
    None,
    /// Protected route: the request must carry a bearer token that the
    /// `AuthN` Resolver accepts.
    Required,
}
90
91/// Gateway-specific route policy implementation
92#[derive(Clone)]
93pub struct GatewayRoutePolicy {
94    route_matchers: Arc<HashMap<Method, RouteMatcher>>,
95    public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
96    require_auth_by_default: bool,
97}
98
99impl GatewayRoutePolicy {
100    #[must_use]
101    pub fn new(
102        route_matchers: Arc<HashMap<Method, RouteMatcher>>,
103        public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
104        require_auth_by_default: bool,
105    ) -> Self {
106        Self {
107            route_matchers,
108            public_matchers,
109            require_auth_by_default,
110        }
111    }
112
113    /// Resolve the authentication requirement for a given (method, path).
114    #[must_use]
115    pub fn resolve(&self, method: &Method, path: &str) -> AuthRequirement {
116        // Check if route is explicitly authenticated
117        let is_authenticated = self
118            .route_matchers
119            .get(method)
120            .is_some_and(|matcher| matcher.find(path));
121
122        // Check if route is explicitly public using pattern matching
123        let is_public = self
124            .public_matchers
125            .get(method)
126            .is_some_and(|matcher| matcher.find(path));
127
128        // Public routes should not be forced to auth by default
129        if is_public {
130            return AuthRequirement::None;
131        }
132
133        if is_authenticated {
134            return AuthRequirement::Required;
135        }
136
137        if self.require_auth_by_default {
138            AuthRequirement::Required
139        } else {
140            AuthRequirement::None
141        }
142    }
143}
144
145/// Shared state for the authentication middleware.
146#[derive(Clone)]
147pub struct AuthState {
148    pub authn_client: Arc<dyn AuthNResolverClient>,
149    pub route_policy: GatewayRoutePolicy,
150}
151
152/// Helper to build `GatewayRoutePolicy` from operation requirements.
153///
154/// # Errors
155///
156/// Returns an error if a route pattern cannot be inserted into the matcher.
157#[allow(clippy::implicit_hasher)]
158pub fn build_route_policy(
159    cfg: &crate::config::ApiGatewayConfig,
160    authenticated_routes: std::collections::HashSet<(Method, String)>,
161    public_routes: std::collections::HashSet<(Method, String)>,
162) -> Result<GatewayRoutePolicy, anyhow::Error> {
163    // Build route matchers per HTTP method (authenticated routes)
164    let mut route_matchers_map: HashMap<Method, RouteMatcher> = HashMap::new();
165
166    for (method, path) in authenticated_routes {
167        let matcher = route_matchers_map
168            .entry(method)
169            .or_insert_with(RouteMatcher::new);
170        // Convert Axum path syntax (:param) to matchit syntax ({param})
171        let matchit_path = convert_axum_path_to_matchit(&path);
172        matcher
173            .insert(&matchit_path)
174            .map_err(|e| anyhow::anyhow!("Failed to insert route pattern '{path}': {e}"))?;
175    }
176
177    // Build public matchers per HTTP method
178    let mut public_matchers_map: HashMap<Method, PublicRouteMatcher> = HashMap::new();
179
180    for (method, path) in public_routes {
181        let matcher = public_matchers_map
182            .entry(method)
183            .or_insert_with(PublicRouteMatcher::new);
184        // Convert Axum path syntax (:param) to matchit syntax ({param})
185        let matchit_path = convert_axum_path_to_matchit(&path);
186        matcher
187            .insert(&matchit_path)
188            .map_err(|e| anyhow::anyhow!("Failed to insert public route pattern '{path}': {e}"))?;
189    }
190
191    Ok(GatewayRoutePolicy::new(
192        Arc::new(route_matchers_map),
193        Arc::new(public_matchers_map),
194        cfg.require_auth_by_default,
195    ))
196}
197
198/// Authentication middleware that uses the `AuthN` Resolver to validate bearer tokens.
199///
200/// For each request:
201/// 1. Skips CORS preflight requests
202/// 2. Resolves the route's auth requirement via `GatewayRoutePolicy`
203/// 3. For public routes: inserts anonymous `SecurityContext`
204/// 4. For required routes: extracts bearer token, calls `AuthN` Resolver, inserts `SecurityContext`
205pub async fn authn_middleware(
206    axum::extract::State(state): axum::extract::State<AuthState>,
207    mut req: axum::extract::Request,
208    next: axum::middleware::Next,
209) -> axum::response::Response {
210    // Skip CORS preflight — insert anonymous SecurityContext so downstream
211    // handlers that extract Extension<SecurityContext> don't panic.
212    if is_preflight_request(req.method(), req.headers()) {
213        req.extensions_mut().insert(SecurityContext::anonymous());
214        return next.run(req).await;
215    }
216
217    let path = req
218        .extensions()
219        .get::<axum::extract::MatchedPath>()
220        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
221
222    let path = common::resolve_path(&req, path.as_str());
223
224    let requirement = state.route_policy.resolve(req.method(), path.as_str());
225
226    match requirement {
227        AuthRequirement::None => {
228            req.extensions_mut().insert(SecurityContext::anonymous());
229            next.run(req).await
230        }
231        AuthRequirement::Required => {
232            let Some(token) = extract_bearer_token(req.headers()) else {
233                return Problem::new(
234                    axum::http::StatusCode::UNAUTHORIZED,
235                    "Unauthorized",
236                    "Missing or invalid Authorization header",
237                )
238                .into_response();
239            };
240
241            match state.authn_client.authenticate(token).await {
242                Ok(result) => {
243                    req.extensions_mut().insert(result.security_context);
244                    next.run(req).await
245                }
246                Err(err) => authn_error_to_response(&err),
247            }
248        }
249    }
250}
251
252/// Convert `AuthNResolverError` to an RFC-9457 Problem Details response.
253fn authn_error_to_response(err: &AuthNResolverError) -> axum::response::Response {
254    log_authn_error(err);
255    let (status, title, detail) = match err {
256        AuthNResolverError::Unauthorized(_) => (
257            axum::http::StatusCode::UNAUTHORIZED,
258            "Unauthorized",
259            "Authentication failed",
260        ),
261        AuthNResolverError::NoPluginAvailable | AuthNResolverError::ServiceUnavailable(_) => (
262            axum::http::StatusCode::SERVICE_UNAVAILABLE,
263            "Service Unavailable",
264            "Authentication service unavailable",
265        ),
266        AuthNResolverError::TokenAcquisitionFailed(_) | AuthNResolverError::Internal(_) => (
267            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
268            "Internal Server Error",
269            "Internal authentication error",
270        ),
271    };
272    Problem::new(status, title, detail).into_response()
273}
274
275/// Log authentication errors at appropriate levels.
276///
277/// Cognitive complexity is inflated by tracing macro expansion.
278#[allow(clippy::cognitive_complexity)]
279fn log_authn_error(err: &AuthNResolverError) {
280    match err {
281        AuthNResolverError::Unauthorized(msg) => tracing::debug!("AuthN rejected: {msg}"),
282        AuthNResolverError::NoPluginAvailable => tracing::error!("No AuthN plugin available"),
283        AuthNResolverError::ServiceUnavailable(msg) => {
284            tracing::error!("AuthN service unavailable: {msg}");
285        }
286        AuthNResolverError::TokenAcquisitionFailed(msg) => {
287            tracing::error!("AuthN token acquisition failed: {msg}");
288        }
289        AuthNResolverError::Internal(msg) => tracing::error!("AuthN internal error: {msg}"),
290    }
291}
292
293/// Extract Bearer token from Authorization header
294fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
295    headers
296        .get(axum::http::header::AUTHORIZATION)
297        .and_then(|v| v.to_str().ok())
298        .and_then(|s| s.strip_prefix("Bearer ").map(str::trim))
299}
300
301/// Check if this is a CORS preflight request
302///
303/// Preflight requests are OPTIONS requests with:
304/// - Origin header present
305/// - Access-Control-Request-Method header present
306fn is_preflight_request(method: &Method, headers: &axum::http::HeaderMap) -> bool {
307    method == Method::OPTIONS
308        && headers.contains_key(axum::http::header::ORIGIN)
309        && headers.contains_key(axum::http::header::ACCESS_CONTROL_REQUEST_METHOD)
310}
311
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use axum::http::Method;

    /// Wrap owned matcher tables into a `GatewayRoutePolicy`.
    fn build_test_policy(
        route_matchers: HashMap<Method, RouteMatcher>,
        public_matchers: HashMap<Method, PublicRouteMatcher>,
        require_auth_by_default: bool,
    ) -> GatewayRoutePolicy {
        let authenticated = Arc::new(route_matchers);
        let public = Arc::new(public_matchers);
        GatewayRoutePolicy::new(authenticated, public, require_auth_by_default)
    }

    /// Authenticated-route table holding `patterns` (matchit syntax) for `method`.
    fn authenticated_table(method: Method, patterns: &[&str]) -> HashMap<Method, RouteMatcher> {
        let mut matcher = RouteMatcher::new();
        for pattern in patterns {
            matcher.insert(pattern).unwrap();
        }
        HashMap::from([(method, matcher)])
    }

    /// Public-route table holding `patterns` (matchit syntax) for `method`.
    fn public_table(method: Method, patterns: &[&str]) -> HashMap<Method, PublicRouteMatcher> {
        let mut matcher = PublicRouteMatcher::new();
        for pattern in patterns {
            matcher.insert(pattern).unwrap();
        }
        HashMap::from([(method, matcher)])
    }

    #[test]
    fn test_convert_axum_path_to_matchit() {
        let cases = [
            ("/users/:id", "/users/{id}"),
            (
                "/posts/:post_id/comments/:comment_id",
                "/posts/{post_id}/comments/{comment_id}",
            ),
            ("/health", "/health"), // No params
            (
                "/api/v1/:resource/:id/status",
                "/api/v1/{resource}/{id}/status",
            ),
        ];
        for (axum_path, expected) in cases {
            assert_eq!(convert_axum_path_to_matchit(axum_path), expected);
        }
    }

    #[test]
    fn test_matchit_router_with_params() {
        // matchit 0.8 spells path parameters `{param}`, not Axum's `:param`.
        let mut router = matchit::Router::new();
        router.insert("/users/{id}", "user_route").unwrap();

        let matched = router.at("/users/42");
        assert!(
            matched.is_ok(),
            "matchit should match /users/{{id}} against /users/42"
        );
        assert_eq!(*matched.unwrap().value, "user_route");
    }

    #[test]
    fn explicit_public_route_with_path_params_returns_none() {
        // Patterns are given in matchit syntax directly; production code converts
        // from Axum syntax before inserting.
        let policy = build_test_policy(
            HashMap::new(),
            public_table(Method::GET, &["/users/{id}"]),
            true,
        );

        // A parameterized pattern must match a concrete value.
        assert_eq!(
            policy.resolve(&Method::GET, "/users/42"),
            AuthRequirement::None
        );
    }

    #[test]
    fn explicit_public_route_exact_match_returns_none() {
        let policy = build_test_policy(HashMap::new(), public_table(Method::GET, &["/health"]), true);

        assert_eq!(policy.resolve(&Method::GET, "/health"), AuthRequirement::None);
    }

    #[test]
    fn explicit_authenticated_route_returns_required() {
        let policy = build_test_policy(
            authenticated_table(Method::GET, &["/admin/metrics"]),
            HashMap::new(),
            false,
        );

        assert_eq!(
            policy.resolve(&Method::GET, "/admin/metrics"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn route_without_requirement_with_require_auth_by_default_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        assert_eq!(
            policy.resolve(&Method::GET, "/profile"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn route_without_requirement_without_require_auth_by_default_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        assert_eq!(policy.resolve(&Method::GET, "/profile"), AuthRequirement::None);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_true_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        assert_eq!(
            policy.resolve(&Method::POST, "/unknown"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_false_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        assert_eq!(policy.resolve(&Method::POST, "/unknown"), AuthRequirement::None);
    }

    #[test]
    fn public_route_overrides_require_auth_by_default() {
        let policy = build_test_policy(HashMap::new(), public_table(Method::GET, &["/public"]), true);

        assert_eq!(policy.resolve(&Method::GET, "/public"), AuthRequirement::None);
    }

    #[test]
    fn authenticated_route_has_priority_over_default() {
        // matchit 0.8 uses {param} syntax
        let policy = build_test_policy(
            authenticated_table(Method::GET, &["/users/{id}"]),
            HashMap::new(),
            false,
        );

        assert_eq!(
            policy.resolve(&Method::GET, "/users/123"),
            AuthRequirement::Required
        );
    }

    #[test]
    fn explicit_public_overrides_wildcard_authenticated_fallback() {
        // A gateway may register a wildcard authenticated 404 fallback such as
        // `/{*rest}` so that anonymous 404s become 401s. That wildcard also
        // matches every public route; explicit public entries must still win.
        let policy = build_test_policy(
            authenticated_table(Method::GET, &["/{*rest}"]),
            public_table(Method::GET, &["/v1/auth/config"]),
            true,
        );

        assert_eq!(
            policy.resolve(&Method::GET, "/v1/auth/config"),
            AuthRequirement::None,
            "explicit public must win over wildcard authenticated fallback"
        );
        // Sanity: a path that only matches the wildcard fallback still requires auth.
        assert_eq!(
            policy.resolve(&Method::GET, "/some/other/path"),
            AuthRequirement::Required,
            "wildcard authenticated still applies to non-public paths"
        );
    }

    #[test]
    fn different_methods_resolve_independently() {
        // Only GET is registered as authenticated; POST has no entry at all.
        let policy = build_test_policy(
            authenticated_table(Method::GET, &["/user-management/v1/users"]),
            HashMap::new(),
            false,
        );

        assert_eq!(
            policy.resolve(&Method::GET, "/user-management/v1/users"),
            AuthRequirement::Required
        );
        // POST falls through to the default (require_auth_by_default = false).
        assert_eq!(
            policy.resolve(&Method::POST, "/user-management/v1/users"),
            AuthRequirement::None
        );
    }
}