use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::response::IntoResponse;
use http::Request;
use tower::{Layer, Service};
use crate::Error;
use crate::auth::apikey::ApiKeyMeta;
use crate::auth::role::Role;
use crate::auth::session::Session;
/// Builds the response that sends the client to `path`.
///
/// Requests issued by htmx (carrying an `hx-request: true` header) receive a
/// `200 OK` with an `hx-redirect` header so htmx performs the navigation on
/// the client; every other request receives a `303 See Other` with a
/// `Location` header. The body is always empty.
fn redirect_response(path: &http::HeaderValue, headers: &http::HeaderMap) -> http::Response<Body> {
    let from_htmx = matches!(
        headers.get("hx-request").map(http::HeaderValue::to_str),
        Some(Ok("true"))
    );

    let (status, target_header) = if from_htmx {
        (http::StatusCode::OK, http::HeaderName::from_static("hx-redirect"))
    } else {
        (http::StatusCode::SEE_OTHER, http::header::LOCATION)
    };

    let mut redirect = http::Response::new(Body::empty());
    *redirect.status_mut() = status;
    redirect.headers_mut().insert(target_header, path.clone());
    redirect
}
/// Returns a layer that only forwards requests whose [`Role`] extension is one
/// of `roles`.
///
/// A request with no [`Role`] extension is answered with an "unauthorized"
/// error; a request whose role is not listed gets a "forbidden" error. An empty
/// `roles` list therefore rejects every request that reaches the check.
pub fn require_role(roles: impl IntoIterator<Item = impl Into<String>>) -> RequireRoleLayer {
    let allowed: Vec<String> = roles.into_iter().map(|role| role.into()).collect();
    RequireRoleLayer {
        roles: Arc::new(allowed),
    }
}
/// Tower [`Layer`] produced by [`require_role`]; wraps services in
/// [`RequireRoleService`].
///
/// The allow-list lives behind an [`Arc`], so cloning the layer only bumps a
/// reference count instead of copying the role names.
#[derive(Clone)]
pub struct RequireRoleLayer {
    roles: Arc<Vec<String>>,
}
impl<S> Layer<S> for RequireRoleLayer {
type Service = RequireRoleService<S>;
fn layer(&self, inner: S) -> Self::Service {
RequireRoleService {
inner,
roles: self.roles.clone(),
}
}
}
/// Service that checks the request's [`Role`] against an allow-list before
/// delegating to the wrapped service. Built by [`RequireRoleLayer`].
///
/// Cloning requires `S: Clone`; the allow-list itself is shared via [`Arc`].
#[derive(Clone)]
pub struct RequireRoleService<S> {
    inner: S,
    roles: Arc<Vec<String>>,
}
impl<S> Service<Request<Body>> for RequireRoleService<S>
where
S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
{
type Response = http::Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let roles = self.roles.clone();
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
let role = match request.extensions().get::<Role>() {
Some(r) => r,
None => {
return Ok(Error::unauthorized("authentication required").into_response());
}
};
if !roles.iter().any(|allowed| allowed == role.as_str()) {
return Ok(Error::forbidden("insufficient role").into_response());
}
inner.call(request).await
})
}
}
/// Returns a layer that only forwards requests carrying a [`Session`]
/// extension; anonymous requests are redirected to `redirect_to`.
///
/// Ordinary requests get a `303 See Other`; htmx requests get an
/// `hx-redirect` response instead.
///
/// # Panics
///
/// Panics if `redirect_to` is not a valid HTTP header value (for example, if it
/// contains a newline).
pub fn require_authenticated(redirect_to: impl Into<String>) -> RequireAuthenticatedLayer {
    let target: String = redirect_to.into();
    RequireAuthenticatedLayer {
        redirect_to: Arc::new(
            http::HeaderValue::from_str(target.as_str())
                .expect("require_authenticated: redirect_to must be a valid HTTP header value"),
        ),
    }
}
/// Tower [`Layer`] produced by [`require_authenticated`]; wraps services in
/// [`RequireAuthenticatedService`].
///
/// The pre-validated redirect target is shared via [`Arc`], so clones are cheap.
#[derive(Clone)]
pub struct RequireAuthenticatedLayer {
    redirect_to: Arc<http::HeaderValue>,
}
impl<S> Layer<S> for RequireAuthenticatedLayer {
type Service = RequireAuthenticatedService<S>;
fn layer(&self, inner: S) -> Self::Service {
RequireAuthenticatedService {
inner,
redirect_to: self.redirect_to.clone(),
}
}
}
/// Service that redirects requests lacking a [`Session`] and forwards the
/// rest to the wrapped service. Built by [`RequireAuthenticatedLayer`].
///
/// Cloning requires `S: Clone`; the redirect target is shared via [`Arc`].
#[derive(Clone)]
pub struct RequireAuthenticatedService<S> {
    inner: S,
    redirect_to: Arc<http::HeaderValue>,
}
impl<S> Service<Request<Body>> for RequireAuthenticatedService<S>
where
S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
{
type Response = http::Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let redirect_to = self.redirect_to.clone();
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
if request.extensions().get::<Session>().is_none() {
return Ok(redirect_response(&redirect_to, request.headers()));
}
inner.call(request).await
})
}
}
/// Returns a layer that only forwards requests *without* a [`Session`]
/// extension; requests that already have one are redirected to `redirect_to`.
///
/// Ordinary requests get a `303 See Other`; htmx requests get an
/// `hx-redirect` response instead.
///
/// # Panics
///
/// Panics if `redirect_to` is not a valid HTTP header value (for example, if it
/// contains a newline).
pub fn require_unauthenticated(redirect_to: impl Into<String>) -> RequireUnauthenticatedLayer {
    let target: String = redirect_to.into();
    RequireUnauthenticatedLayer {
        redirect_to: Arc::new(
            http::HeaderValue::from_str(target.as_str())
                .expect("require_unauthenticated: redirect_to must be a valid HTTP header value"),
        ),
    }
}
/// Tower [`Layer`] produced by [`require_unauthenticated`]; wraps services in
/// [`RequireUnauthenticatedService`].
///
/// The pre-validated redirect target is shared via [`Arc`], so clones are cheap.
#[derive(Clone)]
pub struct RequireUnauthenticatedLayer {
    redirect_to: Arc<http::HeaderValue>,
}
impl<S> Layer<S> for RequireUnauthenticatedLayer {
type Service = RequireUnauthenticatedService<S>;
fn layer(&self, inner: S) -> Self::Service {
RequireUnauthenticatedService {
inner,
redirect_to: self.redirect_to.clone(),
}
}
}
/// Service that redirects requests which already carry a [`Session`] and
/// forwards the rest to the wrapped service. Built by
/// [`RequireUnauthenticatedLayer`].
///
/// Cloning requires `S: Clone`; the redirect target is shared via [`Arc`].
#[derive(Clone)]
pub struct RequireUnauthenticatedService<S> {
    inner: S,
    redirect_to: Arc<http::HeaderValue>,
}
impl<S> Service<Request<Body>> for RequireUnauthenticatedService<S>
where
S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
{
type Response = http::Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let redirect_to = self.redirect_to.clone();
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
if request.extensions().get::<Session>().is_some() {
return Ok(redirect_response(&redirect_to, request.headers()));
}
inner.call(request).await
})
}
}
pub fn require_scope(scope: &str) -> ScopeLayer {
ScopeLayer {
scope: scope.to_owned(),
}
}
#[derive(Clone)]
pub struct ScopeLayer {
scope: String,
}
impl<S> Layer<S> for ScopeLayer {
type Service = ScopeMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
ScopeMiddleware {
inner,
scope: self.scope.clone(),
}
}
}
pub struct ScopeMiddleware<S> {
inner: S,
scope: String,
}
impl<S: Clone> Clone for ScopeMiddleware<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
scope: self.scope.clone(),
}
}
}
impl<S> Service<Request<Body>> for ScopeMiddleware<S>
where
S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
{
type Response = http::Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let scope = self.scope.clone();
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
let Some(meta) = request.extensions().get::<ApiKeyMeta>() else {
tracing::error!(
"require_scope guard reached without an API key in extensions; \
ApiKeyLayer must run before this guard"
);
return Ok(Error::internal("server misconfigured").into_response());
};
if !meta.scopes.iter().any(|s| s == &scope) {
return Ok(
Error::forbidden(format!("missing required scope: {scope}")).into_response()
);
}
inner.call(request).await
})
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the auth guards. Each guard is wrapped around a trivial
    // inner service and driven with `oneshot`; "does_not_call_inner" tests use
    // an `AtomicBool` flag to prove rejected requests never reach the handler.
    use super::*;
    use chrono::Utc;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;

    /// Inner service that always answers `200 OK` with body "ok".
    async fn ok_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    #[test]
    fn redirect_response_non_htmx_returns_303_with_location() {
        let headers = http::HeaderMap::new();
        let path = http::HeaderValue::from_static("/auth");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
        assert!(resp.headers().get("hx-redirect").is_none());
    }

    #[test]
    fn redirect_response_htmx_returns_200_with_hx_redirect() {
        let mut headers = http::HeaderMap::new();
        headers.insert("hx-request", http::HeaderValue::from_static("true"));
        let path = http::HeaderValue::from_static("/app");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/app");
        assert!(resp.headers().get(http::header::LOCATION).is_none());
    }

    #[test]
    fn redirect_response_hx_request_false_uses_303() {
        let mut headers = http::HeaderMap::new();
        headers.insert("hx-request", http::HeaderValue::from_static("false"));
        let path = http::HeaderValue::from_static("/x");
        let resp = redirect_response(&path, &headers);
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
    }

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_authenticated_panics_on_invalid_redirect() {
        let _ = require_authenticated("bad\npath");
    }

    #[test]
    #[should_panic(expected = "valid HTTP header value")]
    fn require_unauthenticated_panics_on_invalid_redirect() {
        let _ = require_unauthenticated("bad\npath");
    }

    #[tokio::test]
    async fn require_role_passes_when_role_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_accepts_owned_strings() {
        let layer = require_role(vec![String::from("admin")]);
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_403_when_role_not_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_401_when_role_missing() {
        let layer = require_role(["admin"]);
        let svc = layer.layer(tower::service_fn(ok_handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_403_when_empty_roles_list() {
        let layer = require_role(std::iter::empty::<String>());
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_empty_string_matches() {
        let layer = require_role([""]);
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_cloned_service_enforces_same_roles() {
        let svc = require_role(["admin"]).layer(tower::service_fn(ok_handler));
        let cloned = svc.clone();

        let mut rejected = Request::builder().body(Body::empty()).unwrap();
        rejected.extensions_mut().insert(Role("viewer".into()));
        let resp = cloned.oneshot(rejected).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let mut accepted = Request::builder().body(Body::empty()).unwrap();
        accepted.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(accepted).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let layer = require_role(["admin"]);
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    /// A session valid for one hour from now; field values are arbitrary.
    fn test_session() -> Session {
        let now = Utc::now();
        Session {
            id: "sess-1".into(),
            user_id: "user-1".into(),
            ip_address: "127.0.0.1".into(),
            user_agent: "test".into(),
            device_name: "test".into(),
            device_type: "other".into(),
            fingerprint: "fp".into(),
            data: serde_json::json!({}),
            created_at: now,
            last_active_at: now,
            expires_at: now + chrono::Duration::hours(1),
        }
    }

    #[tokio::test]
    async fn require_authenticated_passes_when_session_present() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_authenticated_redirects_non_htmx_when_session_missing() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_redirects_htmx_when_session_missing() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let req = Request::builder()
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_role_without_session_still_redirects() {
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/auth");
    }

    #[tokio::test]
    async fn require_authenticated_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let layer = require_authenticated("/auth");
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn require_unauthenticated_passes_when_session_absent() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_unauthenticated_passes_with_role_but_no_session() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_non_htmx_when_session_present() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get(http::header::LOCATION).unwrap(), "/app");
    }

    #[tokio::test]
    async fn require_unauthenticated_redirects_htmx_when_session_present() {
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder()
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("hx-redirect").unwrap(), "/app");
    }

    #[tokio::test]
    async fn require_unauthenticated_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let layer = require_unauthenticated("/app");
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(test_session());
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert!(!called.load(Ordering::SeqCst));
    }

    /// API key metadata granting exactly `scopes`; other fields are arbitrary.
    fn meta_with_scopes(scopes: &[&str]) -> ApiKeyMeta {
        ApiKeyMeta {
            id: "01HX".into(),
            tenant_id: "t".into(),
            name: "test key".into(),
            scopes: scopes.iter().map(|s| (*s).into()).collect(),
            expires_at: None,
            last_used_at: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    #[tokio::test]
    async fn require_scope_passes_when_scope_present() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders", "write:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_scope_403_when_scope_absent() {
        let layer = require_scope("admin:all");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_scope_500_when_apikey_meta_missing() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn require_scope_layer_and_service_are_reusable() {
        // One layer wrapping several services, and clones of a wrapped
        // service, must all enforce the same scope.
        let layer = require_scope("read:orders");
        for _ in 0..2 {
            let svc = layer.layer(tower::service_fn(ok_handler));
            let cloned = svc.clone();

            let mut allowed = Request::builder().body(Body::empty()).unwrap();
            allowed.extensions_mut().insert(meta_with_scopes(&["read:orders"]));
            let resp = svc.oneshot(allowed).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);

            let mut denied = Request::builder().body(Body::empty()).unwrap();
            denied.extensions_mut().insert(meta_with_scopes(&["write:orders"]));
            let resp = cloned.oneshot(denied).await.unwrap();
            assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        }
    }

    #[tokio::test]
    async fn require_scope_does_not_call_inner_on_reject() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let layer = require_scope("admin:all");
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn require_scope_does_not_call_inner_when_meta_missing() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!called.load(Ordering::SeqCst));
    }
}