Skip to main content

modo/auth/
guard.rs

1//! Route-level gating layers — `require_authenticated`, `require_unauthenticated`,
2//! `require_role`, `require_scope`.
3//!
4//! Provides guard layers that reject requests based on authentication state,
5//! role membership, or API key scope. All guards run after route matching
6//! (`.route_layer()`) and expect upstream middleware (role extractor,
7//! [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer)) to have populated
8//! extensions before the guard executes.
9
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use axum::body::Body;
15use axum::response::IntoResponse;
16use http::Request;
17use tower::{Layer, Service};
18
19use crate::Error;
20use crate::auth::apikey::ApiKeyMeta;
21use crate::auth::role::Role;
22use crate::auth::session::Session;
23
24// --- shared redirect helper ---
25
26/// Build a redirect response for guard short-circuits.
27///
28/// For htmx requests (`hx-request: true`), returns `200 OK` with the
29/// `HX-Redirect: <path>` header so htmx performs the client-side navigation.
30/// For all other requests, returns `303 See Other` with `Location: <path>`.
31///
32/// The caller must pass a pre-validated `HeaderValue` (built at layer
33/// construction) so every emitted redirect carries a valid header — no silent
34/// drops at request time.
35fn redirect_response(path: &http::HeaderValue, headers: &http::HeaderMap) -> http::Response<Body> {
36    let is_htmx = headers.get("hx-request").and_then(|v| v.to_str().ok()) == Some("true");
37
38    let mut response = http::Response::new(Body::empty());
39    if is_htmx {
40        *response.status_mut() = http::StatusCode::OK;
41        response.headers_mut().insert("hx-redirect", path.clone());
42    } else {
43        *response.status_mut() = http::StatusCode::SEE_OTHER;
44        response
45            .headers_mut()
46            .insert(http::header::LOCATION, path.clone());
47    }
48    response
49}
50
51// --- require_role ---
52
53/// Creates a guard layer that rejects requests unless the resolved
54/// [`Role`] matches ANY of the allowed roles. Exact string match only;
55/// there is no hierarchy.
56///
57/// # Status codes
58///
59/// - **401 Unauthorized** — no [`Role`] in request extensions (upstream
60///   middleware never populated one).
61/// - **403 Forbidden** — a role is present but not in the allowed list.
62///   An empty `roles` iterator always returns 403.
63///
64/// # Wiring
65///
66/// Apply with `.route_layer()` so the guard runs after route matching.
67/// A role-resolving middleware (e.g. from [`crate::auth::role`]) must run
68/// earlier via `.layer()` so that [`Role`] is in extensions when this
69/// guard runs.
70///
71/// # Example
72///
73/// ```rust,no_run
74/// # fn example() {
75/// use axum::Router;
76/// use axum::routing::get;
77/// use modo::auth::guard::require_role;
78///
79/// let app: Router = Router::new()
80///     .route("/admin", get(|| async { "admin area" }))
81///     .route_layer(require_role(["admin", "owner"]));
82/// # }
83/// ```
84pub fn require_role(roles: impl IntoIterator<Item = impl Into<String>>) -> RequireRoleLayer {
85    RequireRoleLayer {
86        roles: Arc::new(roles.into_iter().map(Into::into).collect()),
87    }
88}
89
/// Tower layer produced by [`require_role()`].
///
/// The allow-list lives behind an [`Arc`], so cloning the layer — and handing
/// the list to every service it builds — only bumps a reference count.
///
/// `Clone` is derived (it is field-wise and identical to a hand-written impl),
/// matching [`ScopeLayer`]; `Debug` is derived so the layer can be inspected in
/// logs and assertions like any other public type.
#[derive(Clone, Debug)]
pub struct RequireRoleLayer {
    roles: Arc<Vec<String>>,
}
102
103impl<S> Layer<S> for RequireRoleLayer {
104    type Service = RequireRoleService<S>;
105
106    fn layer(&self, inner: S) -> Self::Service {
107        RequireRoleService {
108            inner,
109            roles: self.roles.clone(),
110        }
111    }
112}
113
/// Tower service produced by [`RequireRoleLayer`].
///
/// Compares the request's [`Role`] extension against the shared allow-list
/// and only forwards to `inner` on a match.
///
/// `Clone` is derived (the generated `impl<S: Clone> Clone` is exactly the
/// bound the previous hand-written impl used); `Debug` is available whenever
/// the wrapped service is `Debug`.
#[derive(Clone, Debug)]
pub struct RequireRoleService<S> {
    inner: S,
    roles: Arc<Vec<String>>,
}
128
129impl<S> Service<Request<Body>> for RequireRoleService<S>
130where
131    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
132    S::Future: Send + 'static,
133    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
134{
135    type Response = http::Response<Body>;
136    type Error = S::Error;
137    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
138
139    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140        self.inner.poll_ready(cx)
141    }
142
143    fn call(&mut self, request: Request<Body>) -> Self::Future {
144        let roles = self.roles.clone();
145        let mut inner = self.inner.clone();
146        std::mem::swap(&mut self.inner, &mut inner);
147
148        Box::pin(async move {
149            let role = match request.extensions().get::<Role>() {
150                Some(r) => r,
151                None => {
152                    return Ok(Error::unauthorized("authentication required").into_response());
153                }
154            };
155
156            if !roles.iter().any(|allowed| allowed == role.as_str()) {
157                return Ok(Error::forbidden("insufficient role").into_response());
158            }
159
160            inner.call(request).await
161        })
162    }
163}
164
165// --- require_authenticated ---
166
167/// Creates a guard layer that redirects requests without a [`Session`] in
168/// extensions to `redirect_to`. The session's contents are not inspected —
169/// any present session passes the check.
170///
171/// # Response behavior
172///
173/// When a session is absent:
174/// - **htmx** (`hx-request: true`) — `200 OK` with `HX-Redirect: <redirect_to>`
175/// - **non-htmx** — `303 See Other` with `Location: <redirect_to>`
176///
177/// When a session is present, the request is forwarded to the inner service.
178///
179/// # Wiring
180///
181/// Apply with `.route_layer()` so the guard runs after route matching.
182/// The session middleware ([`CookieSessionLayer`](crate::auth::session::CookieSessionLayer)
183/// or the JWT session middleware) must run earlier via `.layer()` so that
184/// [`Session`] is in extensions when this guard runs. No role middleware is
185/// required.
186///
187/// # Example
188///
189/// ```rust,no_run
190/// # fn example() {
191/// use axum::Router;
192/// use axum::routing::get;
193/// use modo::auth::guard::require_authenticated;
194///
195/// let app: Router = Router::new()
196///     .route("/app", get(|| async { "dashboard" }))
197///     .route_layer(require_authenticated("/auth"));
198/// # }
199/// ```
200///
201/// # Panics
202///
203/// Panics at construction if `redirect_to` is not a valid HTTP header value
204/// (e.g. contains a newline or non-visible bytes). Since `redirect_to` is
205/// typically a compile-time constant like `"/auth"`, this surfaces
206/// misconfiguration at startup rather than silently at request time.
207pub fn require_authenticated(redirect_to: impl Into<String>) -> RequireAuthenticatedLayer {
208    let raw = redirect_to.into();
209    let value = http::HeaderValue::from_str(&raw)
210        .expect("require_authenticated: redirect_to must be a valid HTTP header value");
211    RequireAuthenticatedLayer {
212        redirect_to: Arc::new(value),
213    }
214}
215
216/// Tower layer produced by [`require_authenticated()`].
217pub struct RequireAuthenticatedLayer {
218    redirect_to: Arc<http::HeaderValue>,
219}
220
221impl Clone for RequireAuthenticatedLayer {
222    fn clone(&self) -> Self {
223        Self {
224            redirect_to: self.redirect_to.clone(),
225        }
226    }
227}
228
229impl<S> Layer<S> for RequireAuthenticatedLayer {
230    type Service = RequireAuthenticatedService<S>;
231
232    fn layer(&self, inner: S) -> Self::Service {
233        RequireAuthenticatedService {
234            inner,
235            redirect_to: self.redirect_to.clone(),
236        }
237    }
238}
239
240/// Tower service produced by [`RequireAuthenticatedLayer`].
241pub struct RequireAuthenticatedService<S> {
242    inner: S,
243    redirect_to: Arc<http::HeaderValue>,
244}
245
246impl<S: Clone> Clone for RequireAuthenticatedService<S> {
247    fn clone(&self) -> Self {
248        Self {
249            inner: self.inner.clone(),
250            redirect_to: self.redirect_to.clone(),
251        }
252    }
253}
254
255impl<S> Service<Request<Body>> for RequireAuthenticatedService<S>
256where
257    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
258    S::Future: Send + 'static,
259    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
260{
261    type Response = http::Response<Body>;
262    type Error = S::Error;
263    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
264
265    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266        self.inner.poll_ready(cx)
267    }
268
269    fn call(&mut self, request: Request<Body>) -> Self::Future {
270        let redirect_to = self.redirect_to.clone();
271        let mut inner = self.inner.clone();
272        std::mem::swap(&mut self.inner, &mut inner);
273
274        Box::pin(async move {
275            if request.extensions().get::<Session>().is_none() {
276                return Ok(redirect_response(&redirect_to, request.headers()));
277            }
278            inner.call(request).await
279        })
280    }
281}
282
283// --- require_unauthenticated ---
284
285/// Creates a guard layer that redirects requests *with* a [`Session`] in
286/// extensions to `redirect_to`. Use it on guest-only routes (login, signup,
287/// magic-link entry) so an already-signed-in caller doesn't see the login
288/// form.
289///
290/// # Response behavior
291///
292/// When a session is present:
293/// - **htmx** (`hx-request: true`) — `200 OK` with `HX-Redirect: <redirect_to>`
294/// - **non-htmx** — `303 See Other` with `Location: <redirect_to>`
295///
296/// When a session is absent, the request is forwarded to the inner service.
297///
298/// # Wiring
299///
300/// Apply with `.route_layer()` so the guard runs after route matching.
301/// The session middleware ([`CookieSessionLayer`](crate::auth::session::CookieSessionLayer)
302/// or the JWT session middleware) must run earlier via `.layer()` so that
303/// [`Session`] is in extensions when this guard runs.
304///
305/// # Example
306///
307/// ```rust,no_run
308/// # fn example() {
309/// use axum::Router;
310/// use axum::routing::get;
311/// use modo::auth::guard::require_unauthenticated;
312///
313/// let app: Router = Router::new()
314///     .route("/auth", get(|| async { "login page" }))
315///     .route_layer(require_unauthenticated("/app"));
316/// # }
317/// ```
318///
319/// # Panics
320///
321/// Panics at construction if `redirect_to` is not a valid HTTP header value
322/// (e.g. contains a newline or non-visible bytes). Since `redirect_to` is
323/// typically a compile-time constant like `"/app"`, this surfaces
324/// misconfiguration at startup rather than silently at request time.
325pub fn require_unauthenticated(redirect_to: impl Into<String>) -> RequireUnauthenticatedLayer {
326    let raw = redirect_to.into();
327    let value = http::HeaderValue::from_str(&raw)
328        .expect("require_unauthenticated: redirect_to must be a valid HTTP header value");
329    RequireUnauthenticatedLayer {
330        redirect_to: Arc::new(value),
331    }
332}
333
334/// Tower layer produced by [`require_unauthenticated()`].
335pub struct RequireUnauthenticatedLayer {
336    redirect_to: Arc<http::HeaderValue>,
337}
338
339impl Clone for RequireUnauthenticatedLayer {
340    fn clone(&self) -> Self {
341        Self {
342            redirect_to: self.redirect_to.clone(),
343        }
344    }
345}
346
347impl<S> Layer<S> for RequireUnauthenticatedLayer {
348    type Service = RequireUnauthenticatedService<S>;
349
350    fn layer(&self, inner: S) -> Self::Service {
351        RequireUnauthenticatedService {
352            inner,
353            redirect_to: self.redirect_to.clone(),
354        }
355    }
356}
357
358/// Tower service produced by [`RequireUnauthenticatedLayer`].
359pub struct RequireUnauthenticatedService<S> {
360    inner: S,
361    redirect_to: Arc<http::HeaderValue>,
362}
363
364impl<S: Clone> Clone for RequireUnauthenticatedService<S> {
365    fn clone(&self) -> Self {
366        Self {
367            inner: self.inner.clone(),
368            redirect_to: self.redirect_to.clone(),
369        }
370    }
371}
372
373impl<S> Service<Request<Body>> for RequireUnauthenticatedService<S>
374where
375    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
376    S::Future: Send + 'static,
377    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
378{
379    type Response = http::Response<Body>;
380    type Error = S::Error;
381    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
382
383    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
384        self.inner.poll_ready(cx)
385    }
386
387    fn call(&mut self, request: Request<Body>) -> Self::Future {
388        let redirect_to = self.redirect_to.clone();
389        let mut inner = self.inner.clone();
390        std::mem::swap(&mut self.inner, &mut inner);
391
392        Box::pin(async move {
393            if request.extensions().get::<Session>().is_some() {
394                return Ok(redirect_response(&redirect_to, request.headers()));
395            }
396            inner.call(request).await
397        })
398    }
399}
400
401// --- require_scope ---
402
403/// Creates a guard layer that rejects requests unless the verified API
404/// key's scope list contains the required scope. Uses exact string
405/// matching; there is no wildcard or hierarchy.
406///
407/// # Status codes
408///
409/// - **500 Internal Server Error** — no [`ApiKeyMeta`] in request
410///   extensions. The guard is fail-closed and logs an error; this state
411///   indicates the wiring is wrong (missing
412///   [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer) upstream).
413/// - **403 Forbidden** — the API key is present but does not carry the
414///   required scope.
415///
416/// # Wiring
417///
418/// Apply with `.route_layer()` so the guard runs after route matching.
419/// [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer) must run earlier
420/// (via `.layer()`) so that [`ApiKeyMeta`] is in extensions when this
421/// guard runs.
422///
423/// # Example
424///
425/// ```rust,no_run
426/// # fn example() {
427/// use axum::Router;
428/// use axum::routing::get;
429/// use modo::auth::guard::require_scope;
430///
431/// let app: Router = Router::new()
432///     .route("/orders", get(|| async { "orders" }))
433///     .route_layer(require_scope("read:orders"));
434/// # }
435/// ```
436pub fn require_scope(scope: &str) -> ScopeLayer {
437    ScopeLayer {
438        scope: scope.to_owned(),
439    }
440}
441
/// Tower [`Layer`] that checks for a required scope on the verified API key.
///
/// Created by [`require_scope`]. Apply as a `.route_layer()` after
/// [`ApiKeyLayer`](crate::auth::apikey::ApiKeyLayer).
///
/// Derives `Debug` in addition to `Clone` so it is inspectable like the other
/// public guard layers.
#[derive(Clone, Debug)]
pub struct ScopeLayer {
    scope: String,
}
450
451impl<S> Layer<S> for ScopeLayer {
452    type Service = ScopeMiddleware<S>;
453
454    fn layer(&self, inner: S) -> Self::Service {
455        ScopeMiddleware {
456            inner,
457            scope: self.scope.clone(),
458        }
459    }
460}
461
/// Tower [`Service`] that checks for a required scope.
///
/// Forwards to `inner` only when the request's API key carries `scope`.
///
/// `Clone` is derived (the generated `impl<S: Clone> Clone` matches the former
/// hand-written bound, and mirrors [`ScopeLayer`]'s derive); `Debug` is
/// available whenever `S: Debug`.
#[derive(Clone, Debug)]
pub struct ScopeMiddleware<S> {
    inner: S,
    scope: String,
}
476
477impl<S> Service<Request<Body>> for ScopeMiddleware<S>
478where
479    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
480    S::Future: Send + 'static,
481    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
482{
483    type Response = http::Response<Body>;
484    type Error = S::Error;
485    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
486
487    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
488        self.inner.poll_ready(cx)
489    }
490
491    fn call(&mut self, request: Request<Body>) -> Self::Future {
492        let scope = self.scope.clone();
493        let mut inner = self.inner.clone();
494        std::mem::swap(&mut self.inner, &mut inner);
495
496        Box::pin(async move {
497            let Some(meta) = request.extensions().get::<ApiKeyMeta>() else {
498                tracing::error!(
499                    "require_scope guard reached without an API key in extensions; \
500                     ApiKeyLayer must run before this guard"
501                );
502                return Ok(Error::internal("server misconfigured").into_response());
503            };
504
505            if !meta.scopes.iter().any(|s| s == &scope) {
506                return Ok(
507                    Error::forbidden(format!("missing required scope: {scope}")).into_response()
508                );
509            }
510
511            inner.call(request).await
512        })
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;
    use tower::util::BoxCloneService;

    /// Terminal handler that always answers `200 OK` with body `"ok"`.
    async fn ok_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    /// Inner service that sets `called` when it is invoked. Reject-path tests
    /// wrap it to prove a guard short-circuited instead of forwarding.
    fn recording_service(
        called: Arc<AtomicBool>,
    ) -> BoxCloneService<Request<Body>, Response<Body>, Infallible> {
        BoxCloneService::new(tower::service_fn(move |_req: Request<Body>| {
            let called = called.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }))
    }

    // --- redirect_response helper tests ---

    #[test]
    fn redirect_response_non_htmx_returns_303_with_location() {
        let headers = http::HeaderMap::new();
        let path = http::HeaderValue::from_static("/auth");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
        assert!(resp.headers().get("hx-redirect").is_none());
    }

    #[test]
    fn redirect_response_htmx_returns_200_with_hx_redirect() {
        let mut headers = http::HeaderMap::new();
        headers.insert("hx-request", http::HeaderValue::from_static("true"));
        let path = http::HeaderValue::from_static("/app");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/app");
        assert!(resp.headers().get(http::header::LOCATION).is_none());
    }

    #[test]
    fn redirect_response_hx_request_false_uses_303() {
        let mut headers = http::HeaderMap::new();
        headers.insert("hx-request", http::HeaderValue::from_static("false"));
        let path = http::HeaderValue::from_static("/x");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
    }

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_authenticated_panics_on_invalid_redirect() {
        let _ = require_authenticated("bad\npath");
    }

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_unauthenticated_panics_on_invalid_redirect() {
        let _ = require_unauthenticated("bad\npath");
    }

    // --- require_role tests ---

    #[tokio::test]
    async fn require_role_passes_when_role_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_403_when_role_not_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_401_when_role_missing() {
        let layer = require_role(["admin"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_403_when_empty_roles_list() {
        let layer = require_role(std::iter::empty::<String>());
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_empty_string_matches() {
        let layer = require_role([""]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Role matching is documented as exact string equality, so case differs => 403.
    #[tokio::test]
    async fn require_role_is_case_sensitive() {
        let svc = require_role(["admin"]).layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("Admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let svc = require_role(["admin"]).layer(recording_service(called.clone()));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    // --- require_authenticated tests (session-based) ---

    /// A minimal, unexpired session used to populate request extensions.
    fn test_session() -> Session {
        let now = Utc::now();
        Session {
            id: "sess-1".into(),
            user_id: "user-1".into(),
            ip_address: "127.0.0.1".into(),
            user_agent: "test".into(),
            device_name: "test".into(),
            device_type: "other".into(),
            fingerprint: "fp".into(),
            data: serde_json::json!({}),
            created_at: now,
            last_active_at: now,
            expires_at: now + chrono::Duration::hours(1),
        }
    }

    #[tokio::test]
    async fn require_authenticated_passes_when_session_present() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_authenticated_redirects_non_htmx_when_session_missing() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_redirects_htmx_when_session_missing() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder()
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_role_without_session_still_redirects() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let svc = require_authenticated("/auth").layer(recording_service(called.clone()));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert!(!called.load(Ordering::SeqCst));
    }

    // --- require_unauthenticated tests ---

    #[tokio::test]
    async fn require_unauthenticated_passes_when_session_absent() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_non_htmx_when_session_present() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/app");
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_htmx_when_session_present() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder()
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/app");
    }

    #[tokio::test]
    async fn require_unauthenticated_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let svc = require_unauthenticated("/app").layer(recording_service(called.clone()));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert!(!called.load(Ordering::SeqCst));
    }

    // --- require_scope tests ---

    /// API key metadata carrying exactly `scopes`.
    fn meta_with_scopes(scopes: &[&str]) -> ApiKeyMeta {
        ApiKeyMeta {
            id: "01HX".into(),
            tenant_id: "t".into(),
            name: "test key".into(),
            scopes: scopes.iter().map(|s| (*s).into()).collect(),
            expires_at: None,
            last_used_at: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    #[tokio::test]
    async fn require_scope_passes_when_scope_present() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders", "write:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_scope_403_when_scope_absent() {
        let layer = require_scope("admin:all");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// Scope matching is documented as exact: neither a prefix nor a
    /// different-case spelling of the required scope is accepted.
    #[tokio::test]
    async fn require_scope_requires_exact_match() {
        let svc = require_scope("read").layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders", "READ"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_scope_500_when_apikey_meta_missing() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn require_scope_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let svc = require_scope("admin:all").layer(recording_service(called.clone()));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn require_scope_does_not_call_inner_when_meta_missing() {
        let called = Arc::new(AtomicBool::new(false));
        let svc = require_scope("read:orders").layer(recording_service(called.clone()));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!called.load(Ordering::SeqCst));
    }
}