1use axum::http::Method;
2use axum::response::IntoResponse;
3use std::{collections::HashMap, sync::Arc};
4
5use crate::middleware::common;
6
7use authn_resolver_sdk::{AuthNResolverClient, AuthNResolverError};
8use modkit::api::Problem;
9use modkit_security::SecurityContext;
10
11#[derive(Clone)]
13pub struct RouteMatcher {
14 matcher: matchit::Router<()>,
15}
16
17impl RouteMatcher {
18 fn new() -> Self {
19 Self {
20 matcher: matchit::Router::new(),
21 }
22 }
23
24 fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
25 self.matcher.insert(path, ())
26 }
27
28 fn find(&self, path: &str) -> bool {
29 self.matcher.at(path).is_ok()
30 }
31}
32
33#[derive(Clone)]
35pub struct PublicRouteMatcher {
36 matcher: matchit::Router<()>,
37}
38
39impl PublicRouteMatcher {
40 fn new() -> Self {
41 Self {
42 matcher: matchit::Router::new(),
43 }
44 }
45
46 fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
47 self.matcher.insert(path, ())
48 }
49
50 fn find(&self, path: &str) -> bool {
51 self.matcher.at(path).is_ok()
52 }
53}
54
/// Converts an axum route template into the `{param}` syntax used by `matchit`.
///
/// * `:name` becomes `{name}` (legacy axum named parameter).
/// * `*name` at the start of a path segment becomes `{*name}` (legacy axum
///   catch-all), so the catch-all also matches concrete request paths.
/// * A `:` or `*` that is not followed by a parameter name is copied through
///   literally. Emitting an empty `{}` placeholder here would make the
///   pattern invalid for `matchit`.
///
/// Templates already written as `{param}` / `{*rest}` pass through unchanged.
fn convert_axum_path_to_matchit(path: &str) -> String {
    fn is_param_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut result = String::with_capacity(path.len());
    let mut chars = path.chars().peekable();
    // Catch-alls are only recognized at a segment start. This keeps `{*rest}`
    // and mid-segment `*` characters intact.
    let mut at_segment_start = true;

    while let Some(ch) = chars.next() {
        let is_marker = ch == ':' || (ch == '*' && at_segment_start);
        if is_marker && chars.peek().is_some_and(|&c| is_param_char(c)) {
            result.push('{');
            if ch == '*' {
                result.push('*');
            }
            while let Some(c) = chars.next_if(|&c| is_param_char(c)) {
                result.push(c);
            }
            result.push('}');
            at_segment_start = false;
        } else {
            result.push(ch);
            at_segment_start = ch == '/';
        }
    }

    result
}
81
/// Authentication requirement resolved for a single request.
///
/// Fieldless, so it is `Copy` and can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthRequirement {
    /// No credentials are needed; the request proceeds with an anonymous security context.
    None,
    /// A bearer token must be presented and validated before the request proceeds.
    Required,
}
90
91#[derive(Clone)]
93pub struct GatewayRoutePolicy {
94 route_matchers: Arc<HashMap<Method, RouteMatcher>>,
95 public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
96 require_auth_by_default: bool,
97}
98
99impl GatewayRoutePolicy {
100 #[must_use]
101 pub fn new(
102 route_matchers: Arc<HashMap<Method, RouteMatcher>>,
103 public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
104 require_auth_by_default: bool,
105 ) -> Self {
106 Self {
107 route_matchers,
108 public_matchers,
109 require_auth_by_default,
110 }
111 }
112
113 #[must_use]
115 pub fn resolve(&self, method: &Method, path: &str) -> AuthRequirement {
116 let is_authenticated = self
118 .route_matchers
119 .get(method)
120 .is_some_and(|matcher| matcher.find(path));
121
122 let is_public = self
124 .public_matchers
125 .get(method)
126 .is_some_and(|matcher| matcher.find(path));
127
128 if is_public {
130 return AuthRequirement::None;
131 }
132
133 if is_authenticated {
134 return AuthRequirement::Required;
135 }
136
137 if self.require_auth_by_default {
138 AuthRequirement::Required
139 } else {
140 AuthRequirement::None
141 }
142 }
143}
144
145#[derive(Clone)]
147pub struct AuthState {
148 pub authn_client: Arc<dyn AuthNResolverClient>,
149 pub route_policy: GatewayRoutePolicy,
150}
151
152#[allow(clippy::implicit_hasher)]
158pub fn build_route_policy(
159 cfg: &crate::config::ApiGatewayConfig,
160 authenticated_routes: std::collections::HashSet<(Method, String)>,
161 public_routes: std::collections::HashSet<(Method, String)>,
162) -> Result<GatewayRoutePolicy, anyhow::Error> {
163 let mut route_matchers_map: HashMap<Method, RouteMatcher> = HashMap::new();
165
166 for (method, path) in authenticated_routes {
167 let matcher = route_matchers_map
168 .entry(method)
169 .or_insert_with(RouteMatcher::new);
170 let matchit_path = convert_axum_path_to_matchit(&path);
172 matcher
173 .insert(&matchit_path)
174 .map_err(|e| anyhow::anyhow!("Failed to insert route pattern '{path}': {e}"))?;
175 }
176
177 let mut public_matchers_map: HashMap<Method, PublicRouteMatcher> = HashMap::new();
179
180 for (method, path) in public_routes {
181 let matcher = public_matchers_map
182 .entry(method)
183 .or_insert_with(PublicRouteMatcher::new);
184 let matchit_path = convert_axum_path_to_matchit(&path);
186 matcher
187 .insert(&matchit_path)
188 .map_err(|e| anyhow::anyhow!("Failed to insert public route pattern '{path}': {e}"))?;
189 }
190
191 Ok(GatewayRoutePolicy::new(
192 Arc::new(route_matchers_map),
193 Arc::new(public_matchers_map),
194 cfg.require_auth_by_default,
195 ))
196}
197
198pub async fn authn_middleware(
206 axum::extract::State(state): axum::extract::State<AuthState>,
207 mut req: axum::extract::Request,
208 next: axum::middleware::Next,
209) -> axum::response::Response {
210 if is_preflight_request(req.method(), req.headers()) {
213 req.extensions_mut().insert(SecurityContext::anonymous());
214 return next.run(req).await;
215 }
216
217 let path = req
218 .extensions()
219 .get::<axum::extract::MatchedPath>()
220 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
221
222 let path = common::resolve_path(&req, path.as_str());
223
224 let requirement = state.route_policy.resolve(req.method(), path.as_str());
225
226 match requirement {
227 AuthRequirement::None => {
228 req.extensions_mut().insert(SecurityContext::anonymous());
229 next.run(req).await
230 }
231 AuthRequirement::Required => {
232 let Some(token) = extract_bearer_token(req.headers()) else {
233 return Problem::new(
234 axum::http::StatusCode::UNAUTHORIZED,
235 "Unauthorized",
236 "Missing or invalid Authorization header",
237 )
238 .into_response();
239 };
240
241 match state.authn_client.authenticate(token).await {
242 Ok(result) => {
243 req.extensions_mut().insert(result.security_context);
244 next.run(req).await
245 }
246 Err(err) => authn_error_to_response(&err),
247 }
248 }
249 }
250}
251
252fn authn_error_to_response(err: &AuthNResolverError) -> axum::response::Response {
254 log_authn_error(err);
255 let (status, title, detail) = match err {
256 AuthNResolverError::Unauthorized(_) => (
257 axum::http::StatusCode::UNAUTHORIZED,
258 "Unauthorized",
259 "Authentication failed",
260 ),
261 AuthNResolverError::NoPluginAvailable | AuthNResolverError::ServiceUnavailable(_) => (
262 axum::http::StatusCode::SERVICE_UNAVAILABLE,
263 "Service Unavailable",
264 "Authentication service unavailable",
265 ),
266 AuthNResolverError::TokenAcquisitionFailed(_) | AuthNResolverError::Internal(_) => (
267 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
268 "Internal Server Error",
269 "Internal authentication error",
270 ),
271 };
272 Problem::new(status, title, detail).into_response()
273}
274
275#[allow(clippy::cognitive_complexity)]
279fn log_authn_error(err: &AuthNResolverError) {
280 match err {
281 AuthNResolverError::Unauthorized(msg) => tracing::debug!("AuthN rejected: {msg}"),
282 AuthNResolverError::NoPluginAvailable => tracing::error!("No AuthN plugin available"),
283 AuthNResolverError::ServiceUnavailable(msg) => {
284 tracing::error!("AuthN service unavailable: {msg}");
285 }
286 AuthNResolverError::TokenAcquisitionFailed(msg) => {
287 tracing::error!("AuthN token acquisition failed: {msg}");
288 }
289 AuthNResolverError::Internal(msg) => tracing::error!("AuthN internal error: {msg}"),
290 }
291}
292
293fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
295 headers
296 .get(axum::http::header::AUTHORIZATION)
297 .and_then(|v| v.to_str().ok())
298 .and_then(|s| s.strip_prefix("Bearer ").map(str::trim))
299}
300
301fn is_preflight_request(method: &Method, headers: &axum::http::HeaderMap) -> bool {
307 method == Method::OPTIONS
308 && headers.contains_key(axum::http::header::ORIGIN)
309 && headers.contains_key(axum::http::header::ACCESS_CONTROL_REQUEST_METHOD)
310}
311
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use axum::http::Method;

    /// Wraps the given matcher maps in a policy, mirroring what `build_route_policy` produces.
    fn build_test_policy(
        route_matchers: HashMap<Method, RouteMatcher>,
        public_matchers: HashMap<Method, PublicRouteMatcher>,
        require_auth_by_default: bool,
    ) -> GatewayRoutePolicy {
        GatewayRoutePolicy::new(
            Arc::new(route_matchers),
            Arc::new(public_matchers),
            require_auth_by_default,
        )
    }

    #[test]
    fn test_convert_axum_path_to_matchit() {
        assert_eq!(convert_axum_path_to_matchit("/users/:id"), "/users/{id}");
        assert_eq!(
            convert_axum_path_to_matchit("/posts/:post_id/comments/:comment_id"),
            "/posts/{post_id}/comments/{comment_id}"
        );
        assert_eq!(convert_axum_path_to_matchit("/health"), "/health");
        assert_eq!(
            convert_axum_path_to_matchit("/api/v1/:resource/:id/status"),
            "/api/v1/{resource}/{id}/status"
        );
    }

    #[test]
    fn test_convert_axum_path_to_matchit_keeps_brace_syntax() {
        assert_eq!(convert_axum_path_to_matchit("/users/{id}"), "/users/{id}");
        assert_eq!(convert_axum_path_to_matchit("/{*rest}"), "/{*rest}");
    }

    #[test]
    fn test_matchit_router_with_params() {
        let mut router = matchit::Router::new();
        router.insert("/users/{id}", "user_route").unwrap();

        let result = router.at("/users/42");
        assert!(
            result.is_ok(),
            "matchit should match /users/{{id}} against /users/42"
        );
        assert_eq!(*result.unwrap().value, "user_route");
    }

    #[test]
    fn explicit_public_route_with_path_params_returns_none() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/users/{id}").unwrap();

        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/users/42");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn explicit_public_route_exact_match_returns_none() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/health").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/health");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn explicit_authenticated_route_returns_required() {
        let mut route_matchers = HashMap::new();
        let mut matcher = RouteMatcher::new();
        matcher.insert("/admin/metrics").unwrap();
        route_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/admin/metrics");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn route_without_requirement_with_require_auth_by_default_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        let result = policy.resolve(&Method::GET, "/profile");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn route_without_requirement_without_require_auth_by_default_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/profile");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_true_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        let result = policy.resolve(&Method::POST, "/unknown");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_false_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        let result = policy.resolve(&Method::POST, "/unknown");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn public_route_overrides_require_auth_by_default() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/public").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/public");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn authenticated_route_has_priority_over_default() {
        let mut route_matchers = HashMap::new();
        let mut matcher = RouteMatcher::new();
        matcher.insert("/users/{id}").unwrap();
        route_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/users/123");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn explicit_public_overrides_wildcard_authenticated_fallback() {
        let mut public_matchers = HashMap::new();
        let mut public_matcher = PublicRouteMatcher::new();
        public_matcher.insert("/v1/auth/config").unwrap();
        public_matchers.insert(Method::GET, public_matcher);

        let mut route_matchers = HashMap::new();
        let mut auth_matcher = RouteMatcher::new();
        auth_matcher.insert("/{*rest}").unwrap();
        route_matchers.insert(Method::GET, auth_matcher);

        let policy = build_test_policy(route_matchers, public_matchers, true);

        assert_eq!(
            policy.resolve(&Method::GET, "/v1/auth/config"),
            AuthRequirement::None,
            "explicit public must win over wildcard authenticated fallback"
        );
        assert_eq!(
            policy.resolve(&Method::GET, "/some/other/path"),
            AuthRequirement::Required,
            "wildcard authenticated still applies to non-public paths"
        );
    }

    #[test]
    fn explicit_public_wins_over_explicit_authenticated_for_same_path() {
        let mut route_matchers = HashMap::new();
        let mut auth_matcher = RouteMatcher::new();
        auth_matcher.insert("/status").unwrap();
        route_matchers.insert(Method::GET, auth_matcher);

        let mut public_matchers = HashMap::new();
        let mut public_matcher = PublicRouteMatcher::new();
        public_matcher.insert("/status").unwrap();
        public_matchers.insert(Method::GET, public_matcher);

        let policy = build_test_policy(route_matchers, public_matchers, true);

        assert_eq!(
            policy.resolve(&Method::GET, "/status"),
            AuthRequirement::None
        );
    }

    #[test]
    fn different_methods_resolve_independently() {
        let mut route_matchers = HashMap::new();

        let mut get_matcher = RouteMatcher::new();
        get_matcher.insert("/user-management/v1/users").unwrap();
        route_matchers.insert(Method::GET, get_matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let get_result = policy.resolve(&Method::GET, "/user-management/v1/users");
        assert_eq!(get_result, AuthRequirement::Required);

        let post_result = policy.resolve(&Method::POST, "/user-management/v1/users");
        assert_eq!(post_result, AuthRequirement::None);
    }

    #[test]
    fn extract_bearer_token_handles_common_headers() {
        use axum::http::{header::AUTHORIZATION, HeaderMap, HeaderValue};

        let mut headers = HeaderMap::new();
        assert_eq!(extract_bearer_token(&headers), None);

        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer abc123"));
        assert_eq!(extract_bearer_token(&headers), Some("abc123"));

        headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer  abc123 "));
        assert_eq!(extract_bearer_token(&headers), Some("abc123"));

        headers.insert(AUTHORIZATION, HeaderValue::from_static("Basic dXNlcjpwYXNz"));
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn preflight_detection_requires_options_and_cors_headers() {
        use axum::http::{header, HeaderMap, HeaderValue};

        let mut headers = HeaderMap::new();
        headers.insert(
            header::ORIGIN,
            HeaderValue::from_static("https://example.com"),
        );
        assert!(!is_preflight_request(&Method::OPTIONS, &headers));

        headers.insert(
            header::ACCESS_CONTROL_REQUEST_METHOD,
            HeaderValue::from_static("POST"),
        );
        assert!(is_preflight_request(&Method::OPTIONS, &headers));
        assert!(!is_preflight_request(&Method::GET, &headers));
    }
}