Skip to main content

modo/auth/
guard.rs

1//! Route-level gating layers — `require_authenticated`, `require_role`,
2//! `require_scope`.
3//!
4//! Provides guard layers that reject requests based on authentication state,
5//! role membership, or API key scope. All guards run after route matching
6//! (`.route_layer()`) and expect upstream middleware (role extractor,
7//! [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer)) to have populated
8//! extensions before the guard executes.
9
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use axum::body::Body;
15use axum::response::IntoResponse;
16use http::Request;
17use tower::{Layer, Service};
18
19use crate::Error;
20use crate::auth::apikey::ApiKeyMeta;
21use crate::auth::role::Role;
22
23// --- require_role ---
24
25/// Creates a guard layer that rejects requests unless the resolved
26/// [`Role`] matches ANY of the allowed roles. Exact string match only;
27/// there is no hierarchy.
28///
29/// # Status codes
30///
31/// - **401 Unauthorized** — no [`Role`] in request extensions (upstream
32///   middleware never populated one).
33/// - **403 Forbidden** — a role is present but not in the allowed list.
34///   An empty `roles` iterator always returns 403.
35///
36/// # Wiring
37///
38/// Apply with `.route_layer()` so the guard runs after route matching.
39/// A role-resolving middleware (e.g. from [`crate::auth::role`]) must run
40/// earlier via `.layer()` so that [`Role`] is in extensions when this
41/// guard runs.
42///
43/// # Example
44///
45/// ```rust,no_run
46/// # fn example() {
47/// use axum::Router;
48/// use axum::routing::get;
49/// use modo::auth::guard::require_role;
50///
51/// let app: Router = Router::new()
52///     .route("/admin", get(|| async { "admin area" }))
53///     .route_layer(require_role(["admin", "owner"]));
54/// # }
55/// ```
56pub fn require_role(roles: impl IntoIterator<Item = impl Into<String>>) -> RequireRoleLayer {
57    RequireRoleLayer {
58        roles: Arc::new(roles.into_iter().map(Into::into).collect()),
59    }
60}
61
/// Tower layer produced by [`require_role()`].
///
/// Holds the allowed-role list behind an `Arc`, so cloning the layer only
/// bumps a reference count and never copies the role strings.
#[derive(Debug, Clone)]
pub struct RequireRoleLayer {
    /// Role names that are allowed through the guard.
    roles: Arc<Vec<String>>,
}
74
75impl<S> Layer<S> for RequireRoleLayer {
76    type Service = RequireRoleService<S>;
77
78    fn layer(&self, inner: S) -> Self::Service {
79        RequireRoleService {
80            inner,
81            roles: self.roles.clone(),
82        }
83    }
84}
85
/// Tower service produced by [`RequireRoleLayer`].
///
/// `Clone` and `Debug` are available whenever the wrapped service `S`
/// implements them. Cloning shares the allowed-role list rather than
/// copying it.
#[derive(Debug, Clone)]
pub struct RequireRoleService<S> {
    /// The wrapped service that receives requests passing the guard.
    inner: S,
    /// Role names that are allowed through the guard.
    roles: Arc<Vec<String>>,
}
100
101impl<S> Service<Request<Body>> for RequireRoleService<S>
102where
103    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
104    S::Future: Send + 'static,
105    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
106{
107    type Response = http::Response<Body>;
108    type Error = S::Error;
109    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
110
111    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.inner.poll_ready(cx)
113    }
114
115    fn call(&mut self, request: Request<Body>) -> Self::Future {
116        let roles = self.roles.clone();
117        let mut inner = self.inner.clone();
118        std::mem::swap(&mut self.inner, &mut inner);
119
120        Box::pin(async move {
121            let role = match request.extensions().get::<Role>() {
122                Some(r) => r,
123                None => {
124                    return Ok(Error::unauthorized("authentication required").into_response());
125                }
126            };
127
128            if !roles.iter().any(|allowed| allowed == role.as_str()) {
129                return Ok(Error::forbidden("insufficient role").into_response());
130            }
131
132            inner.call(request).await
133        })
134    }
135}
136
137// --- require_authenticated ---
138
139/// Creates a guard layer that rejects requests unless a [`Role`] is
140/// present in extensions. The role's value is not inspected — any
141/// resolved role is accepted.
142///
143/// # Status codes
144///
145/// - **401 Unauthorized** — no [`Role`] in request extensions.
146///
147/// # Wiring
148///
149/// Apply with `.route_layer()` so the guard runs after route matching.
150/// A role-resolving middleware (e.g. from [`crate::auth::role`]) must run
151/// earlier via `.layer()` so that [`Role`] is in extensions when this
152/// guard runs.
153///
154/// # Example
155///
156/// ```rust,no_run
157/// # fn example() {
158/// use axum::Router;
159/// use axum::routing::get;
160/// use modo::auth::guard::require_authenticated;
161///
162/// let app: Router = Router::new()
163///     .route("/me", get(|| async { "profile" }))
164///     .route_layer(require_authenticated());
165/// # }
166/// ```
167pub fn require_authenticated() -> RequireAuthenticatedLayer {
168    RequireAuthenticatedLayer
169}
170
/// Tower layer produced by [`require_authenticated()`].
///
/// Carries no configuration; it exists only to wrap services in
/// [`RequireAuthenticatedService`].
#[derive(Debug, Clone)]
pub struct RequireAuthenticatedLayer;
179
180impl<S> Layer<S> for RequireAuthenticatedLayer {
181    type Service = RequireAuthenticatedService<S>;
182
183    fn layer(&self, inner: S) -> Self::Service {
184        RequireAuthenticatedService { inner }
185    }
186}
187
/// Tower service produced by [`RequireAuthenticatedLayer`].
///
/// `Clone` and `Debug` are available whenever the wrapped service `S`
/// implements them.
#[derive(Debug, Clone)]
pub struct RequireAuthenticatedService<S> {
    /// The wrapped service that receives requests passing the guard.
    inner: S,
}
200
201impl<S> Service<Request<Body>> for RequireAuthenticatedService<S>
202where
203    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
204    S::Future: Send + 'static,
205    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
206{
207    type Response = http::Response<Body>;
208    type Error = S::Error;
209    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
210
211    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        self.inner.poll_ready(cx)
213    }
214
215    fn call(&mut self, request: Request<Body>) -> Self::Future {
216        let mut inner = self.inner.clone();
217        std::mem::swap(&mut self.inner, &mut inner);
218
219        Box::pin(async move {
220            if request.extensions().get::<Role>().is_none() {
221                return Ok(Error::unauthorized("authentication required").into_response());
222            }
223
224            inner.call(request).await
225        })
226    }
227}
228
229// --- require_scope ---
230
231/// Creates a guard layer that rejects requests unless the verified API
232/// key's scope list contains the required scope. Uses exact string
233/// matching; there is no wildcard or hierarchy.
234///
235/// # Status codes
236///
237/// - **500 Internal Server Error** — no [`ApiKeyMeta`] in request
238///   extensions. The guard is fail-closed and logs an error; this state
239///   indicates the wiring is wrong (missing
240///   [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer) upstream).
241/// - **403 Forbidden** — the API key is present but does not carry the
242///   required scope.
243///
244/// # Wiring
245///
246/// Apply with `.route_layer()` so the guard runs after route matching.
247/// [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer) must run earlier
248/// (via `.layer()`) so that [`ApiKeyMeta`] is in extensions when this
249/// guard runs.
250///
251/// # Example
252///
253/// ```rust,no_run
254/// # fn example() {
255/// use axum::Router;
256/// use axum::routing::get;
257/// use modo::auth::guard::require_scope;
258///
259/// let app: Router = Router::new()
260///     .route("/orders", get(|| async { "orders" }))
261///     .route_layer(require_scope("read:orders"));
262/// # }
263/// ```
264pub fn require_scope(scope: &str) -> ScopeLayer {
265    ScopeLayer {
266        scope: scope.to_owned(),
267    }
268}
269
/// Tower [`Layer`] that checks for a required scope on the verified API key.
///
/// Created by [`require_scope`]. Apply as a `.route_layer()` after
/// [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer).
#[derive(Debug, Clone)]
pub struct ScopeLayer {
    /// The scope a key must carry, compared by exact string equality.
    scope: String,
}
278
279impl<S> Layer<S> for ScopeLayer {
280    type Service = ScopeMiddleware<S>;
281
282    fn layer(&self, inner: S) -> Self::Service {
283        ScopeMiddleware {
284            inner,
285            scope: self.scope.clone(),
286        }
287    }
288}
289
/// Tower [`Service`] that checks for a required scope.
///
/// `Clone` and `Debug` are available whenever the wrapped service `S`
/// implements them.
#[derive(Debug, Clone)]
pub struct ScopeMiddleware<S> {
    /// The wrapped service that receives requests passing the guard.
    inner: S,
    /// The scope a key must carry, compared by exact string equality.
    scope: String,
}
304
305impl<S> Service<Request<Body>> for ScopeMiddleware<S>
306where
307    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
308    S::Future: Send + 'static,
309    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
310{
311    type Response = http::Response<Body>;
312    type Error = S::Error;
313    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
314
315    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
316        self.inner.poll_ready(cx)
317    }
318
319    fn call(&mut self, request: Request<Body>) -> Self::Future {
320        let scope = self.scope.clone();
321        let mut inner = self.inner.clone();
322        std::mem::swap(&mut self.inner, &mut inner);
323
324        Box::pin(async move {
325            let Some(meta) = request.extensions().get::<ApiKeyMeta>() else {
326                tracing::error!(
327                    "require_scope guard reached without an API key in extensions; \
328                     ApiKeyLayer must run before this guard"
329                );
330                return Ok(Error::internal("server misconfigured").into_response());
331            };
332
333            if !meta.scopes.iter().any(|s| s == &scope) {
334                return Ok(
335                    Error::forbidden(format!("missing required scope: {scope}")).into_response()
336                );
337            }
338
339            inner.call(request).await
340        })
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;

    /// Inner handler that always answers 200 with the body `ok`.
    async fn ok_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    /// Request with an empty body and no extensions.
    fn bare_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// Request that carries the given role in its extensions.
    fn request_with_role(role: &'static str) -> Request<Body> {
        let mut req = bare_request();
        req.extensions_mut().insert(Role(role.into()));
        req
    }

    /// API key metadata with fixed placeholder fields and the given scopes.
    fn meta_with_scopes(scopes: &[&str]) -> ApiKeyMeta {
        ApiKeyMeta {
            id: "01HX".into(),
            tenant_id: "t".into(),
            name: "test key".into(),
            scopes: scopes.iter().map(|s| (*s).into()).collect(),
            expires_at: None,
            last_used_at: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    /// Request that carries API key metadata with the given scopes.
    fn request_with_scopes(scopes: &[&str]) -> Request<Body> {
        let mut req = bare_request();
        req.extensions_mut().insert(meta_with_scopes(scopes));
        req
    }

    // --- require_role tests ---

    #[tokio::test]
    async fn require_role_passes_when_role_in_list() {
        let svc = require_role(["admin", "owner"]).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(request_with_role("admin")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_403_when_role_not_in_list() {
        let svc = require_role(["admin", "owner"]).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(request_with_role("viewer")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_401_when_role_missing() {
        let svc = require_role(["admin"]).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(bare_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_403_when_empty_roles_list() {
        let svc =
            require_role(std::iter::empty::<String>()).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(request_with_role("admin")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // The presence check runs before the membership check, so an empty
    // allow-list does not turn a missing role into a 403.
    #[tokio::test]
    async fn require_role_401_when_empty_roles_list_and_role_missing() {
        let svc =
            require_role(std::iter::empty::<String>()).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(bare_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_empty_string_matches() {
        let svc = require_role([""]).layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(request_with_role("")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let svc = require_role(["admin"]).layer(tower::service_fn(
            move |_req: Request<Body>| {
                let called = called_clone.clone();
                async move {
                    called.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            },
        ));

        let resp = svc.oneshot(request_with_role("viewer")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    // --- require_authenticated tests ---

    #[tokio::test]
    async fn require_authenticated_passes_when_role_present() {
        let svc = require_authenticated().layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(request_with_role("viewer")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_authenticated_401_when_role_missing() {
        let svc = require_authenticated().layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(bare_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_authenticated_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let svc = require_authenticated().layer(tower::service_fn(
            move |_req: Request<Body>| {
                let called = called_clone.clone();
                async move {
                    called.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            },
        ));

        let resp = svc.oneshot(bare_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(!called.load(Ordering::SeqCst));
    }

    // --- require_scope tests ---

    #[tokio::test]
    async fn require_scope_passes_when_scope_present() {
        let svc = require_scope("read:orders").layer(tower::service_fn(ok_handler));

        let req = request_with_scopes(&["read:orders", "write:orders"]);
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_scope_403_when_scope_absent() {
        let svc = require_scope("admin:all").layer(tower::service_fn(ok_handler));

        let resp = svc
            .oneshot(request_with_scopes(&["read:orders"]))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // Matching is whole-string equality: a held scope that merely starts
    // with the required one does not satisfy the guard.
    #[tokio::test]
    async fn require_scope_403_when_only_prefix_matches() {
        let svc = require_scope("read").layer(tower::service_fn(ok_handler));

        let resp = svc
            .oneshot(request_with_scopes(&["read:orders"]))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_scope_500_when_apikey_meta_missing() {
        let svc = require_scope("read:orders").layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(bare_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn require_scope_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let svc = require_scope("admin:all").layer(tower::service_fn(
            move |_req: Request<Body>| {
                let called = called_clone.clone();
                async move {
                    called.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            },
        ));

        let resp = svc
            .oneshot(request_with_scopes(&["read:orders"]))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }
}
526}