use super::helpers::is_htmx_request;
use axum::{
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Redirect, Response},
};
/// Middleware configuration that rejects requests lacking an authenticated session.
///
/// Unauthenticated requests are pointed at `login_path`. Requests flagged as HTMX
/// get `401 Unauthorized` with an `HX-Redirect` header. All other requests get a
/// redirect response.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthMiddleware {
    /// Path unauthenticated users are sent to (`/login` by default).
    login_path: String,
}

impl Default for AuthMiddleware {
    /// Uses `/login` as the login path.
    fn default() -> Self {
        Self {
            login_path: "/login".to_string(),
        }
    }
}
impl AuthMiddleware {
    /// Creates a middleware that sends unauthenticated users to `/login`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a middleware that sends unauthenticated users to `login_path`.
    #[must_use]
    pub fn with_login_path(login_path: impl Into<String>) -> Self {
        Self {
            login_path: login_path.into(),
        }
    }

    /// Middleware entry point using the default configuration (`/login`).
    ///
    /// The signature fits `axum::middleware::from_fn(AuthMiddleware::handle)`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::handle_with_config`].
    pub async fn handle(request: Request, next: Next) -> Result<Response, AuthMiddlewareError> {
        Self::default().handle_with_config(request, next).await
    }

    /// Runs `next` only if the request carries an authenticated session.
    ///
    /// A request is treated as authenticated when its extensions contain a
    /// `crate::auth::Session` whose `user_id` is present. The session is only
    /// borrowed for the check, so no session data is cloned per request.
    ///
    /// # Errors
    ///
    /// When no authenticated session is found, returns the rejection chosen by
    /// [`AuthMiddlewareError::for_request`]. That is `Unauthorized` when
    /// `is_htmx_request` flags the request and `RedirectToLogin` otherwise. Both
    /// carry this middleware's login path.
    pub async fn handle_with_config(
        self,
        request: Request,
        next: Next,
    ) -> Result<Response, AuthMiddlewareError> {
        // The borrow of the extensions ends with this statement, so `request`
        // can be moved into `next.run` afterwards.
        let is_authenticated = request
            .extensions()
            .get::<crate::auth::Session>()
            .and_then(crate::auth::Session::user_id)
            .is_some();

        if !is_authenticated {
            return Err(AuthMiddlewareError::for_request(
                is_htmx_request(request.headers()),
                self.login_path,
            ));
        }

        Ok(next.run(request).await)
    }
}
/// Rejection produced when a request has no authenticated session.
///
/// Both variants carry the login path the client should be sent to.
#[derive(Debug)]
pub enum AuthMiddlewareError {
    /// HTMX request. Rendered as `401 Unauthorized` with an `HX-Redirect` header.
    Unauthorized(String),
    /// Ordinary request. Rendered as a redirect to the login path.
    RedirectToLogin(String),
}

impl AuthMiddlewareError {
    /// Picks the rejection that suits the kind of request being rejected.
    ///
    /// `is_htmx == true` yields [`Self::Unauthorized`]. Otherwise the result is
    /// [`Self::RedirectToLogin`]. Either way the variant holds `login_path`.
    #[must_use]
    pub fn for_request(is_htmx: bool, login_path: impl Into<String>) -> Self {
        // Choose the variant constructor first, then apply it to the owned path.
        let make: fn(String) -> Self = if is_htmx {
            Self::Unauthorized
        } else {
            Self::RedirectToLogin
        };
        make(login_path.into())
    }
}
impl IntoResponse for AuthMiddlewareError {
    /// Converts the rejection into an HTTP response.
    ///
    /// - `RedirectToLogin` becomes `Redirect::to(login_path)` (a `303 See Other`).
    /// - `Unauthorized` becomes `401 Unauthorized` with body `"Unauthorized"` and an
    ///   `HX-Redirect` header set to the login path.
    fn into_response(self) -> Response {
        let login_path = match self {
            Self::RedirectToLogin(target) => return Redirect::to(&target).into_response(),
            Self::Unauthorized(target) => target,
        };
        // htmx reads `HX-Redirect` to perform a client-side navigation.
        let hx_redirect = [("HX-Redirect", login_path.as_str())];
        (StatusCode::UNAUTHORIZED, hx_redirect, "Unauthorized").into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{Session, SessionData, SessionId};
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn protected_handler() -> &'static str {
        "Protected content"
    }

    /// Router protecting `/protected` through the default `AuthMiddleware::handle` entry point.
    fn default_app() -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(AuthMiddleware::handle))
    }

    /// Router protecting `/protected` with a middleware configured for `login_path`.
    fn custom_app(login_path: &str) -> Router {
        let auth = AuthMiddleware::with_login_path(login_path);
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(move |req, next| {
                auth.clone().handle_with_config(req, next)
            }))
    }

    /// Builds a `GET /protected` request, optionally flagged as an HTMX request.
    fn protected_request(htmx: bool) -> Request<Body> {
        let builder = Request::builder().uri("/protected");
        let builder = if htmx {
            builder.header("HX-Request", "true")
        } else {
            builder
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Attaches a session to `request`. The session has `user_id = Some(1)` when
    /// `logged_in` is true and `None` otherwise.
    fn with_session(mut request: Request<Body>, logged_in: bool) -> Request<Body> {
        let mut session_data = SessionData::new();
        session_data.user_id = if logged_in { Some(1) } else { None };
        let session = Session::new(SessionId::generate(), session_data);
        request.extensions_mut().insert(session);
        request
    }

    #[tokio::test]
    async fn test_unauthenticated_regular_request_redirects() {
        let response = default_app()
            .oneshot(protected_request(false))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get("location").unwrap(), "/login");
    }

    #[tokio::test]
    async fn test_unauthenticated_htmx_request_returns_401() {
        let response = default_app()
            .oneshot(protected_request(true))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(response.headers().get("HX-Redirect").unwrap(), "/login");
    }

    #[tokio::test]
    async fn test_session_without_user_is_rejected() {
        let request = with_session(protected_request(false), false);
        let response = default_app().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get("location").unwrap(), "/login");
    }

    #[tokio::test]
    async fn test_authenticated_request_proceeds() {
        let request = with_session(protected_request(false), true);
        let response = default_app().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_authenticated_htmx_request_proceeds() {
        let request = with_session(protected_request(true), true);
        let response = default_app().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_custom_login_path_regular_request() {
        let response = custom_app("/auth/signin")
            .oneshot(protected_request(false))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get("location").unwrap(),
            "/auth/signin"
        );
    }

    #[tokio::test]
    async fn test_custom_login_path_htmx_request() {
        let response = custom_app("/auth/signin")
            .oneshot(protected_request(true))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response.headers().get("HX-Redirect").unwrap(),
            "/auth/signin"
        );
    }

    #[tokio::test]
    async fn test_custom_login_path_with_authenticated_request() {
        let request = with_session(protected_request(false), true);
        let response = custom_app("/auth/signin").oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_default_login_path_is_slash_login() {
        assert_eq!(AuthMiddleware::new().login_path, "/login");
        assert_eq!(AuthMiddleware::default().login_path, "/login");
    }

    #[test]
    fn test_with_login_path_accepts_string() {
        let auth = AuthMiddleware::with_login_path("/custom".to_string());
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_with_login_path_accepts_str() {
        let auth = AuthMiddleware::with_login_path("/custom");
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_for_request_returns_unauthorized_when_htmx() {
        let error = AuthMiddlewareError::for_request(true, "/login");
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_returns_redirect_when_not_htmx() {
        let error = AuthMiddlewareError::for_request(false, "/login");
        assert!(matches!(error, AuthMiddlewareError::RedirectToLogin(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_accepts_string() {
        let error = AuthMiddlewareError::for_request(true, "/custom/login".to_string());
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/custom/login"));
    }
}