1use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use axum::body::Body;
15use axum::response::IntoResponse;
16use http::Request;
17use tower::{Layer, Service};
18
19use crate::Error;
20use crate::auth::apikey::ApiKeyMeta;
21use crate::auth::role::Role;
22use crate::auth::session::Session;
23
24fn redirect_response(path: &http::HeaderValue, headers: &http::HeaderMap) -> http::Response<Body> {
36 let is_htmx = headers.get("hx-request").and_then(|v| v.to_str().ok()) == Some("true");
37
38 let mut response = http::Response::new(Body::empty());
39 if is_htmx {
40 *response.status_mut() = http::StatusCode::OK;
41 response.headers_mut().insert("hx-redirect", path.clone());
42 } else {
43 *response.status_mut() = http::StatusCode::SEE_OTHER;
44 response
45 .headers_mut()
46 .insert(http::header::LOCATION, path.clone());
47 }
48 response
49}
50
51pub fn require_role(roles: impl IntoIterator<Item = impl Into<String>>) -> RequireRoleLayer {
85 RequireRoleLayer {
86 roles: Arc::new(roles.into_iter().map(Into::into).collect()),
87 }
88}
89
/// Layer that wraps a service in a role check.
///
/// Cloning only bumps the reference count of the shared role list; the list
/// itself is never copied.
#[derive(Clone)]
pub struct RequireRoleLayer {
    /// Role names that are allowed through.
    roles: Arc<Vec<String>>,
}
102
103impl<S> Layer<S> for RequireRoleLayer {
104 type Service = RequireRoleService<S>;
105
106 fn layer(&self, inner: S) -> Self::Service {
107 RequireRoleService {
108 inner,
109 roles: self.roles.clone(),
110 }
111 }
112}
113
/// Service that checks a request's role before forwarding it to `inner`.
///
/// `Clone` is available whenever `S: Clone`; the role list is shared between
/// clones rather than copied.
#[derive(Clone)]
pub struct RequireRoleService<S> {
    /// The wrapped service that receives requests passing the check.
    inner: S,
    /// Role names that are allowed through.
    roles: Arc<Vec<String>>,
}
128
129impl<S> Service<Request<Body>> for RequireRoleService<S>
130where
131 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
132 S::Future: Send + 'static,
133 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
134{
135 type Response = http::Response<Body>;
136 type Error = S::Error;
137 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
138
139 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 self.inner.poll_ready(cx)
141 }
142
143 fn call(&mut self, request: Request<Body>) -> Self::Future {
144 let roles = self.roles.clone();
145 let mut inner = self.inner.clone();
146 std::mem::swap(&mut self.inner, &mut inner);
147
148 Box::pin(async move {
149 let role = match request.extensions().get::<Role>() {
150 Some(r) => r,
151 None => {
152 return Ok(Error::unauthorized("authentication required").into_response());
153 }
154 };
155
156 if !roles.iter().any(|allowed| allowed == role.as_str()) {
157 return Ok(Error::forbidden("insufficient role").into_response());
158 }
159
160 inner.call(request).await
161 })
162 }
163}
164
165pub fn require_authenticated(redirect_to: impl Into<String>) -> RequireAuthenticatedLayer {
208 let raw = redirect_to.into();
209 let value = http::HeaderValue::from_str(&raw)
210 .expect("require_authenticated: redirect_to must be a valid HTTP header value");
211 RequireAuthenticatedLayer {
212 redirect_to: Arc::new(value),
213 }
214}
215
216pub struct RequireAuthenticatedLayer {
218 redirect_to: Arc<http::HeaderValue>,
219}
220
221impl Clone for RequireAuthenticatedLayer {
222 fn clone(&self) -> Self {
223 Self {
224 redirect_to: self.redirect_to.clone(),
225 }
226 }
227}
228
229impl<S> Layer<S> for RequireAuthenticatedLayer {
230 type Service = RequireAuthenticatedService<S>;
231
232 fn layer(&self, inner: S) -> Self::Service {
233 RequireAuthenticatedService {
234 inner,
235 redirect_to: self.redirect_to.clone(),
236 }
237 }
238}
239
240pub struct RequireAuthenticatedService<S> {
242 inner: S,
243 redirect_to: Arc<http::HeaderValue>,
244}
245
246impl<S: Clone> Clone for RequireAuthenticatedService<S> {
247 fn clone(&self) -> Self {
248 Self {
249 inner: self.inner.clone(),
250 redirect_to: self.redirect_to.clone(),
251 }
252 }
253}
254
255impl<S> Service<Request<Body>> for RequireAuthenticatedService<S>
256where
257 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
258 S::Future: Send + 'static,
259 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
260{
261 type Response = http::Response<Body>;
262 type Error = S::Error;
263 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
264
265 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266 self.inner.poll_ready(cx)
267 }
268
269 fn call(&mut self, request: Request<Body>) -> Self::Future {
270 let redirect_to = self.redirect_to.clone();
271 let mut inner = self.inner.clone();
272 std::mem::swap(&mut self.inner, &mut inner);
273
274 Box::pin(async move {
275 if request.extensions().get::<Session>().is_none() {
276 return Ok(redirect_response(&redirect_to, request.headers()));
277 }
278 inner.call(request).await
279 })
280 }
281}
282
283pub fn require_unauthenticated(redirect_to: impl Into<String>) -> RequireUnauthenticatedLayer {
326 let raw = redirect_to.into();
327 let value = http::HeaderValue::from_str(&raw)
328 .expect("require_unauthenticated: redirect_to must be a valid HTTP header value");
329 RequireUnauthenticatedLayer {
330 redirect_to: Arc::new(value),
331 }
332}
333
334pub struct RequireUnauthenticatedLayer {
336 redirect_to: Arc<http::HeaderValue>,
337}
338
339impl Clone for RequireUnauthenticatedLayer {
340 fn clone(&self) -> Self {
341 Self {
342 redirect_to: self.redirect_to.clone(),
343 }
344 }
345}
346
347impl<S> Layer<S> for RequireUnauthenticatedLayer {
348 type Service = RequireUnauthenticatedService<S>;
349
350 fn layer(&self, inner: S) -> Self::Service {
351 RequireUnauthenticatedService {
352 inner,
353 redirect_to: self.redirect_to.clone(),
354 }
355 }
356}
357
358pub struct RequireUnauthenticatedService<S> {
360 inner: S,
361 redirect_to: Arc<http::HeaderValue>,
362}
363
364impl<S: Clone> Clone for RequireUnauthenticatedService<S> {
365 fn clone(&self) -> Self {
366 Self {
367 inner: self.inner.clone(),
368 redirect_to: self.redirect_to.clone(),
369 }
370 }
371}
372
373impl<S> Service<Request<Body>> for RequireUnauthenticatedService<S>
374where
375 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
376 S::Future: Send + 'static,
377 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
378{
379 type Response = http::Response<Body>;
380 type Error = S::Error;
381 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
382
383 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
384 self.inner.poll_ready(cx)
385 }
386
387 fn call(&mut self, request: Request<Body>) -> Self::Future {
388 let redirect_to = self.redirect_to.clone();
389 let mut inner = self.inner.clone();
390 std::mem::swap(&mut self.inner, &mut inner);
391
392 Box::pin(async move {
393 if request.extensions().get::<Session>().is_some() {
394 return Ok(redirect_response(&redirect_to, request.headers()));
395 }
396 inner.call(request).await
397 })
398 }
399}
400
401pub fn require_scope(scope: &str) -> ScopeLayer {
437 ScopeLayer {
438 scope: scope.to_owned(),
439 }
440}
441
442#[derive(Clone)]
447pub struct ScopeLayer {
448 scope: String,
449}
450
451impl<S> Layer<S> for ScopeLayer {
452 type Service = ScopeMiddleware<S>;
453
454 fn layer(&self, inner: S) -> Self::Service {
455 ScopeMiddleware {
456 inner,
457 scope: self.scope.clone(),
458 }
459 }
460}
461
462pub struct ScopeMiddleware<S> {
464 inner: S,
465 scope: String,
466}
467
468impl<S: Clone> Clone for ScopeMiddleware<S> {
469 fn clone(&self) -> Self {
470 Self {
471 inner: self.inner.clone(),
472 scope: self.scope.clone(),
473 }
474 }
475}
476
477impl<S> Service<Request<Body>> for ScopeMiddleware<S>
478where
479 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
480 S::Future: Send + 'static,
481 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
482{
483 type Response = http::Response<Body>;
484 type Error = S::Error;
485 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
486
487 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
488 self.inner.poll_ready(cx)
489 }
490
491 fn call(&mut self, request: Request<Body>) -> Self::Future {
492 let scope = self.scope.clone();
493 let mut inner = self.inner.clone();
494 std::mem::swap(&mut self.inner, &mut inner);
495
496 Box::pin(async move {
497 let Some(meta) = request.extensions().get::<ApiKeyMeta>() else {
498 tracing::error!(
499 "require_scope guard reached without an API key in extensions; \
500 ApiKeyLayer must run before this guard"
501 );
502 return Ok(Error::internal("server misconfigured").into_response());
503 };
504
505 if !meta.scopes.iter().any(|s| s == &scope) {
506 return Ok(
507 Error::forbidden(format!("missing required scope: {scope}")).into_response()
508 );
509 }
510
511 inner.call(request).await
512 })
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;

    // ---- shared fixtures -------------------------------------------------

    /// Inner-service stub that always answers `200 OK` with the body `ok`.
    async fn ok_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    /// A request with an empty body, no headers and no extensions.
    fn bare_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// Like [`bare_request`], but carrying the `hx-request: true` header.
    fn htmx_request() -> Request<Body> {
        Request::builder()
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap()
    }

    /// Returns `req` with `ext` stored in its extensions.
    fn with_extension<T>(mut req: Request<Body>, ext: T) -> Request<Body>
    where
        T: Clone + Send + Sync + 'static,
    {
        req.extensions_mut().insert(ext);
        req
    }

    /// A bare request tagged with the given role name.
    fn role_request(name: &'static str) -> Request<Body> {
        with_extension(bare_request(), Role(name.into()))
    }

    /// A session valid for one hour from now, filled with placeholder values.
    fn test_session() -> Session {
        let created = Utc::now();
        Session {
            id: "sess-1".into(),
            user_id: "user-1".into(),
            ip_address: "127.0.0.1".into(),
            user_agent: "test".into(),
            device_name: "test".into(),
            device_type: "other".into(),
            fingerprint: "fp".into(),
            data: serde_json::json!({}),
            created_at: created,
            last_active_at: created,
            expires_at: created + chrono::Duration::hours(1),
        }
    }

    /// API-key metadata carrying exactly the given scopes.
    fn meta_with_scopes(scopes: &[&str]) -> ApiKeyMeta {
        ApiKeyMeta {
            id: "01HX".into(),
            tenant_id: "t".into(),
            name: "test key".into(),
            scopes: scopes.iter().map(|&scope| scope.into()).collect(),
            expires_at: None,
            last_used_at: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    // ---- redirect_response -----------------------------------------------

    #[test]
    fn redirect_response_non_htmx_returns_303_with_location() {
        let target = http::HeaderValue::from_static("/auth");
        let response = redirect_response(&target, &http::HeaderMap::new());

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get(http::header::LOCATION).unwrap(),
            "/auth"
        );
        assert!(!response.headers().contains_key("hx-redirect"));
    }

    #[test]
    fn redirect_response_htmx_returns_200_with_hx_redirect() {
        let target = http::HeaderValue::from_static("/app");
        let mut request_headers = http::HeaderMap::new();
        request_headers.insert("hx-request", http::HeaderValue::from_static("true"));

        let response = redirect_response(&target, &request_headers);

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers().get("hx-redirect").unwrap(), "/app");
        assert!(!response.headers().contains_key(http::header::LOCATION));
    }

    #[test]
    fn redirect_response_hx_request_false_uses_303() {
        let target = http::HeaderValue::from_static("/x");
        let mut request_headers = http::HeaderMap::new();
        request_headers.insert("hx-request", http::HeaderValue::from_static("false"));

        let response = redirect_response(&target, &request_headers);

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
    }

    // ---- constructor validation ------------------------------------------

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_authenticated_panics_on_invalid_redirect() {
        let _ = require_authenticated("bad\npath");
    }

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_unauthenticated_panics_on_invalid_redirect() {
        let _ = require_unauthenticated("bad\npath");
    }

    // ---- require_role ----------------------------------------------------

    #[tokio::test]
    async fn require_role_passes_when_role_in_list() {
        let guarded = require_role(["admin", "owner"]).layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(role_request("admin")).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_403_when_role_not_in_list() {
        let guarded = require_role(["admin", "owner"]).layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(role_request("viewer")).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_401_when_role_missing() {
        let guarded = require_role(["admin"]).layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(bare_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_403_when_empty_roles_list() {
        let guarded =
            require_role(std::iter::empty::<String>()).layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(role_request("admin")).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_empty_string_matches() {
        let guarded = require_role([""]).layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(role_request("")).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_does_not_call_inner_on_reject() {
        let reached = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&reached);

        let guarded =
            require_role(["admin"]).layer(tower::service_fn(move |_req: Request<Body>| {
                let flag = Arc::clone(&flag);
                async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            }));

        let response = guarded.oneshot(role_request("viewer")).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert!(!reached.load(Ordering::SeqCst));
    }

    // ---- require_authenticated -------------------------------------------

    #[tokio::test]
    async fn require_authenticated_passes_when_session_present() {
        let guarded = require_authenticated("/auth").layer(tower::service_fn(ok_handler));
        let request = with_extension(bare_request(), test_session());

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_authenticated_redirects_non_htmx_when_session_missing() {
        let guarded = require_authenticated("/auth").layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(bare_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get(http::header::LOCATION).unwrap(),
            "/auth"
        );
    }

    #[tokio::test]
    async fn require_authenticated_redirects_htmx_when_session_missing() {
        let guarded = require_authenticated("/auth").layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(htmx_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers().get("hx-redirect").unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_role_without_session_still_redirects() {
        let guarded = require_authenticated("/auth").layer(tower::service_fn(ok_handler));

        // A role alone is not enough: only a `Session` extension counts.
        let response = guarded.oneshot(role_request("admin")).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get(http::header::LOCATION).unwrap(),
            "/auth"
        );
    }

    #[tokio::test]
    async fn require_authenticated_does_not_call_inner_on_reject() {
        let reached = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&reached);

        let guarded =
            require_authenticated("/auth").layer(tower::service_fn(move |_req: Request<Body>| {
                let flag = Arc::clone(&flag);
                async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            }));

        let response = guarded.oneshot(bare_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert!(!reached.load(Ordering::SeqCst));
    }

    // ---- require_unauthenticated -----------------------------------------

    #[tokio::test]
    async fn require_unauthenticated_passes_when_session_absent() {
        let guarded = require_unauthenticated("/app").layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(bare_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_non_htmx_when_session_present() {
        let guarded = require_unauthenticated("/app").layer(tower::service_fn(ok_handler));
        let request = with_extension(bare_request(), test_session());

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get(http::header::LOCATION).unwrap(),
            "/app"
        );
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_htmx_when_session_present() {
        let guarded = require_unauthenticated("/app").layer(tower::service_fn(ok_handler));
        let request = with_extension(htmx_request(), test_session());

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers().get("hx-redirect").unwrap(), "/app");
    }

    #[tokio::test]
    async fn require_unauthenticated_does_not_call_inner_on_reject() {
        let reached = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&reached);

        let guarded =
            require_unauthenticated("/app").layer(tower::service_fn(move |_req: Request<Body>| {
                let flag = Arc::clone(&flag);
                async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            }));
        let request = with_extension(bare_request(), test_session());

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert!(!reached.load(Ordering::SeqCst));
    }

    // ---- require_scope ---------------------------------------------------

    #[tokio::test]
    async fn require_scope_passes_when_scope_present() {
        let guarded = require_scope("read:orders").layer(tower::service_fn(ok_handler));
        let request = with_extension(
            bare_request(),
            meta_with_scopes(&["read:orders", "write:orders"]),
        );

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_scope_403_when_scope_absent() {
        let guarded = require_scope("admin:all").layer(tower::service_fn(ok_handler));
        let request = with_extension(bare_request(), meta_with_scopes(&["read:orders"]));

        let response = guarded.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_scope_500_when_apikey_meta_missing() {
        let guarded = require_scope("read:orders").layer(tower::service_fn(ok_handler));

        let response = guarded.oneshot(bare_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}