1use axum::http::Method;
2use axum::response::IntoResponse;
3use std::{collections::HashMap, sync::Arc};
4
5use crate::middleware::common;
6
7use authn_resolver_sdk::{AuthNResolverClient, AuthNResolverError};
8use modkit::api::Problem;
9use modkit_security::SecurityContext;
10
11#[derive(Clone)]
13pub struct RouteMatcher {
14 matcher: matchit::Router<()>,
15}
16
17impl RouteMatcher {
18 fn new() -> Self {
19 Self {
20 matcher: matchit::Router::new(),
21 }
22 }
23
24 fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
25 self.matcher.insert(path, ())
26 }
27
28 fn find(&self, path: &str) -> bool {
29 self.matcher.at(path).is_ok()
30 }
31}
32
33#[derive(Clone)]
35pub struct PublicRouteMatcher {
36 matcher: matchit::Router<()>,
37}
38
39impl PublicRouteMatcher {
40 fn new() -> Self {
41 Self {
42 matcher: matchit::Router::new(),
43 }
44 }
45
46 fn insert(&mut self, path: &str) -> Result<(), matchit::InsertError> {
47 self.matcher.insert(path, ())
48 }
49
50 fn find(&self, path: &str) -> bool {
51 self.matcher.at(path).is_ok()
52 }
53}
54
/// Converts a route pattern written in axum's legacy syntax into matchit's brace syntax.
///
/// * `:name` becomes `{name}` (a named parameter).
/// * `*name` at the start of a path segment becomes `{*name}` (a catch-all parameter).
///
/// A parameter name is the longest run of alphanumeric characters and `_` following the
/// sigil. Every other character is copied unchanged. This includes a `*` that does not
/// start a segment or is not followed by a name. Patterns already written in `{...}`
/// syntax therefore pass through untouched.
fn convert_axum_path_to_matchit(path: &str) -> String {
    /// Characters allowed in a parameter name.
    fn is_name_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // The conversion adds braces, so reserve a little headroom beyond the input length.
    let mut result = String::with_capacity(path.len() + 4);
    let mut chars = path.chars().peekable();
    let mut prev: Option<char> = None;

    while let Some(ch) = chars.next() {
        match ch {
            ':' => {
                result.push('{');
                while let Some(c) = chars.next_if(|c| is_name_char(*c)) {
                    result.push(c);
                }
                result.push('}');
            }
            // Catch-alls used to be copied verbatim. A literal `*rest` pattern only matches
            // the exact text `*rest`, so an authenticated `/files/*rest` route never matched
            // real sub-paths (or the `{*rest}` form) and fell through to the default policy.
            '*' if prev == Some('/') && chars.peek().is_some_and(|c| is_name_char(*c)) => {
                result.push_str("{*");
                while let Some(c) = chars.next_if(|c| is_name_char(*c)) {
                    result.push(c);
                }
                result.push('}');
            }
            _ => result.push(ch),
        }
        prev = Some(ch);
    }

    result
}
81
/// Outcome of [`GatewayRoutePolicy::resolve`]: whether a request must carry credentials.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthRequirement {
    /// The request continues with an anonymous security context.
    None,
    /// A bearer token must be presented and accepted before the request continues.
    Required,
}
90
91#[derive(Clone)]
93pub struct GatewayRoutePolicy {
94 route_matchers: Arc<HashMap<Method, RouteMatcher>>,
95 public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
96 require_auth_by_default: bool,
97}
98
99impl GatewayRoutePolicy {
100 #[must_use]
101 pub fn new(
102 route_matchers: Arc<HashMap<Method, RouteMatcher>>,
103 public_matchers: Arc<HashMap<Method, PublicRouteMatcher>>,
104 require_auth_by_default: bool,
105 ) -> Self {
106 Self {
107 route_matchers,
108 public_matchers,
109 require_auth_by_default,
110 }
111 }
112
113 #[must_use]
115 pub fn resolve(&self, method: &Method, path: &str) -> AuthRequirement {
116 let is_authenticated = self
118 .route_matchers
119 .get(method)
120 .is_some_and(|matcher| matcher.find(path));
121
122 let is_public = self
124 .public_matchers
125 .get(method)
126 .is_some_and(|matcher| matcher.find(path));
127
128 let needs_authn = is_authenticated || (self.require_auth_by_default && !is_public);
130
131 if needs_authn {
132 AuthRequirement::Required
133 } else {
134 AuthRequirement::None
135 }
136 }
137}
138
139#[derive(Clone)]
141pub struct AuthState {
142 pub authn_client: Arc<dyn AuthNResolverClient>,
143 pub route_policy: GatewayRoutePolicy,
144}
145
146#[allow(clippy::implicit_hasher)]
152pub fn build_route_policy(
153 cfg: &crate::config::ApiGatewayConfig,
154 authenticated_routes: std::collections::HashSet<(Method, String)>,
155 public_routes: std::collections::HashSet<(Method, String)>,
156) -> Result<GatewayRoutePolicy, anyhow::Error> {
157 let mut route_matchers_map: HashMap<Method, RouteMatcher> = HashMap::new();
159
160 for (method, path) in authenticated_routes {
161 let matcher = route_matchers_map
162 .entry(method)
163 .or_insert_with(RouteMatcher::new);
164 let matchit_path = convert_axum_path_to_matchit(&path);
166 matcher
167 .insert(&matchit_path)
168 .map_err(|e| anyhow::anyhow!("Failed to insert route pattern '{path}': {e}"))?;
169 }
170
171 let mut public_matchers_map: HashMap<Method, PublicRouteMatcher> = HashMap::new();
173
174 for (method, path) in public_routes {
175 let matcher = public_matchers_map
176 .entry(method)
177 .or_insert_with(PublicRouteMatcher::new);
178 let matchit_path = convert_axum_path_to_matchit(&path);
180 matcher
181 .insert(&matchit_path)
182 .map_err(|e| anyhow::anyhow!("Failed to insert public route pattern '{path}': {e}"))?;
183 }
184
185 Ok(GatewayRoutePolicy::new(
186 Arc::new(route_matchers_map),
187 Arc::new(public_matchers_map),
188 cfg.require_auth_by_default,
189 ))
190}
191
192pub async fn authn_middleware(
200 axum::extract::State(state): axum::extract::State<AuthState>,
201 mut req: axum::extract::Request,
202 next: axum::middleware::Next,
203) -> axum::response::Response {
204 if is_preflight_request(req.method(), req.headers()) {
207 req.extensions_mut().insert(SecurityContext::anonymous());
208 return next.run(req).await;
209 }
210
211 let path = req
212 .extensions()
213 .get::<axum::extract::MatchedPath>()
214 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
215
216 let path = common::resolve_path(&req, path.as_str());
217
218 let requirement = state.route_policy.resolve(req.method(), path.as_str());
219
220 match requirement {
221 AuthRequirement::None => {
222 req.extensions_mut().insert(SecurityContext::anonymous());
223 next.run(req).await
224 }
225 AuthRequirement::Required => {
226 let Some(token) = extract_bearer_token(req.headers()) else {
227 return Problem::new(
228 axum::http::StatusCode::UNAUTHORIZED,
229 "Unauthorized",
230 "Missing or invalid Authorization header",
231 )
232 .into_response();
233 };
234
235 match state.authn_client.authenticate(token).await {
236 Ok(result) => {
237 req.extensions_mut().insert(result.security_context);
238 next.run(req).await
239 }
240 Err(err) => authn_error_to_response(&err),
241 }
242 }
243 }
244}
245
246fn authn_error_to_response(err: &AuthNResolverError) -> axum::response::Response {
248 log_authn_error(err);
249 let (status, title, detail) = match err {
250 AuthNResolverError::Unauthorized(_) => (
251 axum::http::StatusCode::UNAUTHORIZED,
252 "Unauthorized",
253 "Authentication failed",
254 ),
255 AuthNResolverError::NoPluginAvailable | AuthNResolverError::ServiceUnavailable(_) => (
256 axum::http::StatusCode::SERVICE_UNAVAILABLE,
257 "Service Unavailable",
258 "Authentication service unavailable",
259 ),
260 AuthNResolverError::TokenAcquisitionFailed(_) | AuthNResolverError::Internal(_) => (
261 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
262 "Internal Server Error",
263 "Internal authentication error",
264 ),
265 };
266 Problem::new(status, title, detail).into_response()
267}
268
269#[allow(clippy::cognitive_complexity)]
273fn log_authn_error(err: &AuthNResolverError) {
274 match err {
275 AuthNResolverError::Unauthorized(msg) => tracing::debug!("AuthN rejected: {msg}"),
276 AuthNResolverError::NoPluginAvailable => tracing::error!("No AuthN plugin available"),
277 AuthNResolverError::ServiceUnavailable(msg) => {
278 tracing::error!("AuthN service unavailable: {msg}");
279 }
280 AuthNResolverError::TokenAcquisitionFailed(msg) => {
281 tracing::error!("AuthN token acquisition failed: {msg}");
282 }
283 AuthNResolverError::Internal(msg) => tracing::error!("AuthN internal error: {msg}"),
284 }
285}
286
287fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
289 headers
290 .get(axum::http::header::AUTHORIZATION)
291 .and_then(|v| v.to_str().ok())
292 .and_then(|s| s.strip_prefix("Bearer ").map(str::trim))
293}
294
295fn is_preflight_request(method: &Method, headers: &axum::http::HeaderMap) -> bool {
301 method == Method::OPTIONS
302 && headers.contains_key(axum::http::header::ORIGIN)
303 && headers.contains_key(axum::http::header::ACCESS_CONTROL_REQUEST_METHOD)
304}
305
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use axum::http::Method;
    use axum::http::{header, HeaderMap, HeaderValue};

    /// Wraps the given matchers into a policy without going through configuration.
    fn build_test_policy(
        route_matchers: HashMap<Method, RouteMatcher>,
        public_matchers: HashMap<Method, PublicRouteMatcher>,
        require_auth_by_default: bool,
    ) -> GatewayRoutePolicy {
        GatewayRoutePolicy::new(
            Arc::new(route_matchers),
            Arc::new(public_matchers),
            require_auth_by_default,
        )
    }

    /// Builds a header map holding a single `Authorization` header with `value`.
    fn authorization_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_convert_axum_path_to_matchit() {
        assert_eq!(convert_axum_path_to_matchit("/users/:id"), "/users/{id}");
        assert_eq!(
            convert_axum_path_to_matchit("/posts/:post_id/comments/:comment_id"),
            "/posts/{post_id}/comments/{comment_id}"
        );
        assert_eq!(convert_axum_path_to_matchit("/health"), "/health");
        assert_eq!(
            convert_axum_path_to_matchit("/api/v1/:resource/:id/status"),
            "/api/v1/{resource}/{id}/status"
        );
    }

    #[test]
    fn test_matchit_router_with_params() {
        let mut router = matchit::Router::new();
        router.insert("/users/{id}", "user_route").unwrap();

        let result = router.at("/users/42");
        assert!(
            result.is_ok(),
            "matchit should match /users/{{id}} against /users/42"
        );
        assert_eq!(*result.unwrap().value, "user_route");
    }

    #[test]
    fn explicit_public_route_with_path_params_returns_none() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/users/{id}").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/users/42");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn explicit_public_route_exact_match_returns_none() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/health").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/health");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn explicit_authenticated_route_returns_required() {
        let mut route_matchers = HashMap::new();
        let mut matcher = RouteMatcher::new();
        matcher.insert("/admin/metrics").unwrap();
        route_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/admin/metrics");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn route_without_requirement_with_require_auth_by_default_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        let result = policy.resolve(&Method::GET, "/profile");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn route_without_requirement_without_require_auth_by_default_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/profile");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_true_returns_required() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), true);

        let result = policy.resolve(&Method::POST, "/unknown");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn unknown_route_with_require_auth_by_default_false_returns_none() {
        let policy = build_test_policy(HashMap::new(), HashMap::new(), false);

        let result = policy.resolve(&Method::POST, "/unknown");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn public_route_overrides_require_auth_by_default() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/public").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::GET, "/public");
        assert_eq!(result, AuthRequirement::None);
    }

    #[test]
    fn public_route_is_scoped_to_its_method() {
        let mut public_matchers = HashMap::new();
        let mut matcher = PublicRouteMatcher::new();
        matcher.insert("/public").unwrap();
        public_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(HashMap::new(), public_matchers, true);

        let result = policy.resolve(&Method::POST, "/public");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn authenticated_route_has_priority_over_default() {
        let mut route_matchers = HashMap::new();
        let mut matcher = RouteMatcher::new();
        matcher.insert("/users/{id}").unwrap();
        route_matchers.insert(Method::GET, matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let result = policy.resolve(&Method::GET, "/users/123");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn authenticated_route_wins_over_public_route() {
        let mut route_matchers = HashMap::new();
        let mut authenticated = RouteMatcher::new();
        authenticated.insert("/items").unwrap();
        route_matchers.insert(Method::GET, authenticated);

        let mut public_matchers = HashMap::new();
        let mut public = PublicRouteMatcher::new();
        public.insert("/items").unwrap();
        public_matchers.insert(Method::GET, public);

        let policy = build_test_policy(route_matchers, public_matchers, false);

        let result = policy.resolve(&Method::GET, "/items");
        assert_eq!(result, AuthRequirement::Required);
    }

    #[test]
    fn different_methods_resolve_independently() {
        let mut route_matchers = HashMap::new();

        let mut get_matcher = RouteMatcher::new();
        get_matcher.insert("/user-management/v1/users").unwrap();
        route_matchers.insert(Method::GET, get_matcher);

        let policy = build_test_policy(route_matchers, HashMap::new(), false);

        let get_result = policy.resolve(&Method::GET, "/user-management/v1/users");
        assert_eq!(get_result, AuthRequirement::Required);

        let post_result = policy.resolve(&Method::POST, "/user-management/v1/users");
        assert_eq!(post_result, AuthRequirement::None);
    }

    #[test]
    fn extract_bearer_token_returns_token_for_bearer_scheme() {
        let headers = authorization_headers("Bearer abc.def");
        assert_eq!(extract_bearer_token(&headers), Some("abc.def"));
    }

    #[test]
    fn extract_bearer_token_rejects_missing_header_and_other_schemes() {
        assert_eq!(extract_bearer_token(&HeaderMap::new()), None);

        let headers = authorization_headers("Basic dXNlcjpwYXNz");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn preflight_detection_requires_options_origin_and_request_method() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ORIGIN,
            HeaderValue::from_static("https://example.com"),
        );
        assert!(!is_preflight_request(&Method::OPTIONS, &headers));

        headers.insert(
            header::ACCESS_CONTROL_REQUEST_METHOD,
            HeaderValue::from_static("POST"),
        );
        assert!(is_preflight_request(&Method::OPTIONS, &headers));
        assert!(!is_preflight_request(&Method::GET, &headers));
    }
}