use crate::auth::{Session, User, UserError};
use crate::middleware::is_htmx_request;
use crate::state::ActonHtmxState;
use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, StatusCode},
response::{IntoResponse, Redirect, Response},
};
/// Extractor that requires a logged-in user.
///
/// For `T = User`, extraction reads the [`Session`] from the request
/// extensions, takes the session's user id, and loads that [`User`] through
/// the application's database pool. Any failure along the way rejects the
/// request with an [`AuthenticationError`].
pub struct Authenticated<T>(pub T);
impl<S> FromRequestParts<S> for Authenticated<User>
where
    S: Send + Sync,
    ActonHtmxState: FromRef<S>,
{
    type Rejection = AuthenticationError;

    /// Resolves the logged-in [`User`] for the current request.
    ///
    /// # Errors
    ///
    /// - `MissingSession` / `MissingSessionHtmx` when no [`Session`] is stored
    ///   in the request extensions.
    /// - `NotAuthenticated` / `NotAuthenticatedHtmx` when the session has no
    ///   user id, or when `User::find_by_id` reports `UserError::NotFound`.
    /// - `DatabaseError` for any other `UserError` from the lookup.
    ///
    /// The HTMX flavour of each rejection is chosen from the request headers.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Determine the request flavour once so every rejection below is consistent.
        let htmx = is_htmx_request(&parts.headers);

        let Some(session) = parts.extensions.get::<Session>().cloned() else {
            return Err(AuthenticationError::missing_session(htmx));
        };
        let Some(user_id) = session.user_id() else {
            return Err(AuthenticationError::not_authenticated(htmx));
        };

        let shared = ActonHtmxState::from_ref(state);
        match User::find_by_id(user_id, shared.database_pool()).await {
            Ok(user) => Ok(Self(user)),
            // A session whose user no longer exists is treated as logged out.
            Err(UserError::NotFound) => Err(AuthenticationError::not_authenticated(htmx)),
            Err(other) => Err(AuthenticationError::DatabaseError(other)),
        }
    }
}
/// Extractor that loads the current user when one is available.
///
/// For `T = User`, extraction never rejects. It yields `OptionalAuth(None)`
/// when there is no [`Session`], the session has no user id, or the user
/// lookup returns any error. Otherwise it yields `OptionalAuth(Some(user))`.
pub struct OptionalAuth<T>(pub Option<T>);
impl<S> FromRequestParts<S> for OptionalAuth<User>
where
    S: Send + Sync,
    ActonHtmxState: FromRef<S>,
{
    type Rejection = AuthenticationError;

    /// Attempts to resolve the logged-in [`User`]; always returns `Ok`.
    ///
    /// Produces `OptionalAuth(None)` when the request carries no [`Session`],
    /// when the session has no user id, or when `User::find_by_id` fails for
    /// any reason. Database failures are deliberately folded into `None`
    /// rather than surfaced as a rejection.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let session = match parts.extensions.get::<Session>() {
            Some(found) => found.clone(),
            None => return Ok(Self(None)),
        };

        let user_id = match session.user_id() {
            Some(id) => id,
            None => return Ok(Self(None)),
        };

        let shared = ActonHtmxState::from_ref(state);
        let lookup = User::find_by_id(user_id, shared.database_pool()).await;
        Ok(Self(lookup.ok()))
    }
}
/// Rejection type used by the [`Authenticated`] and [`OptionalAuth`] extractors.
///
/// The `*Htmx` variants mirror their plain counterparts. They are selected
/// when the request was detected as an HTMX request, and they respond with
/// an `HX-Redirect` header instead of an HTTP redirect.
#[derive(Debug)]
pub enum AuthenticationError {
    /// No `Session` in the request extensions (HTMX request).
    /// Responds `401 Unauthorized` with `HX-Redirect: /login`.
    MissingSessionHtmx,
    /// No `Session` in the request extensions (regular request).
    /// Responds with a `303 See Other` redirect to `/login`.
    MissingSession,
    /// The session has no user, or the user was not found (HTMX request).
    /// Responds `401 Unauthorized` with `HX-Redirect: /login`.
    NotAuthenticatedHtmx,
    /// The session has no user, or the user was not found (regular request).
    /// Responds with a `303 See Other` redirect to `/login`.
    NotAuthenticated,
    /// Responds `500 Internal Server Error` with body "Database not configured".
    DatabaseNotConfigured,
    /// The user lookup failed with a `UserError` other than `NotFound`.
    /// Responds `500 Internal Server Error` with body "Failed to load user";
    /// the wrapped error is not included in the response.
    DatabaseError(UserError),
}
impl AuthenticationError {
    /// Builds the "no session" rejection for the given request flavour.
    ///
    /// Returns [`Self::MissingSessionHtmx`] for HTMX requests and
    /// [`Self::MissingSession`] otherwise.
    #[must_use]
    pub const fn missing_session(is_htmx: bool) -> Self {
        if !is_htmx {
            return Self::MissingSession;
        }
        Self::MissingSessionHtmx
    }

    /// Builds the "not logged in" rejection for the given request flavour.
    ///
    /// Returns [`Self::NotAuthenticatedHtmx`] for HTMX requests and
    /// [`Self::NotAuthenticated`] otherwise.
    #[must_use]
    pub const fn not_authenticated(is_htmx: bool) -> Self {
        if !is_htmx {
            return Self::NotAuthenticated;
        }
        Self::NotAuthenticatedHtmx
    }
}
impl IntoResponse for AuthenticationError {
    /// Converts the rejection into the HTTP response sent to the client.
    ///
    /// - HTMX variants: `401 Unauthorized`, header `HX-Redirect: /login`,
    ///   body `"Unauthorized"`.
    /// - Plain variants: `Redirect::to("/login")`.
    /// - Database variants: `500 Internal Server Error` with a short text body.
    ///   The wrapped `UserError` is never exposed to the client.
    fn into_response(self) -> Response {
        const LOGIN_PATH: &str = "/login";

        let server_error =
            |body: &'static str| (StatusCode::INTERNAL_SERVER_ERROR, body).into_response();

        match self {
            Self::MissingSessionHtmx | Self::NotAuthenticatedHtmx => (
                StatusCode::UNAUTHORIZED,
                [("HX-Redirect", LOGIN_PATH)],
                "Unauthorized",
            )
                .into_response(),
            Self::MissingSession | Self::NotAuthenticated => {
                Redirect::to(LOGIN_PATH).into_response()
            }
            Self::DatabaseNotConfigured => server_error("Database not configured"),
            Self::DatabaseError(_) => server_error("Failed to load user"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    /// Checks that a plain rejection becomes a 303 redirect to `/login`.
    #[track_caller]
    fn assert_login_redirect(error: AuthenticationError) {
        let response = error.into_response();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get("location").unwrap(), "/login");
    }

    /// Checks that an HTMX rejection becomes a 401 carrying `HX-Redirect: /login`.
    #[track_caller]
    fn assert_htmx_login_redirect(error: AuthenticationError) {
        let response = error.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(response.headers().get("HX-Redirect").unwrap(), "/login");
    }

    /// Checks that a rejection becomes a 500 Internal Server Error.
    #[track_caller]
    fn assert_server_error(error: AuthenticationError) {
        let status = error.into_response().status();
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_authentication_error_missing_session_regular_returns_redirect() {
        assert_login_redirect(AuthenticationError::MissingSession);
    }

    #[test]
    fn test_authentication_error_missing_session_htmx_returns_401_with_hx_redirect() {
        assert_htmx_login_redirect(AuthenticationError::MissingSessionHtmx);
    }

    #[test]
    fn test_authentication_error_not_authenticated_regular_returns_redirect() {
        assert_login_redirect(AuthenticationError::NotAuthenticated);
    }

    #[test]
    fn test_authentication_error_not_authenticated_htmx_returns_401_with_hx_redirect() {
        assert_htmx_login_redirect(AuthenticationError::NotAuthenticatedHtmx);
    }

    #[test]
    fn test_authentication_error_database_not_configured_returns_500() {
        assert_server_error(AuthenticationError::DatabaseNotConfigured);
    }

    #[test]
    fn test_authentication_error_database_error_returns_500() {
        assert_server_error(AuthenticationError::DatabaseError(UserError::NotFound));
    }

    #[test]
    fn test_missing_session_helper_returns_htmx_variant_when_is_htmx_true() {
        let built = AuthenticationError::missing_session(true);
        assert!(matches!(built, AuthenticationError::MissingSessionHtmx));
    }

    #[test]
    fn test_missing_session_helper_returns_regular_variant_when_is_htmx_false() {
        let built = AuthenticationError::missing_session(false);
        assert!(matches!(built, AuthenticationError::MissingSession));
    }

    #[test]
    fn test_not_authenticated_helper_returns_htmx_variant_when_is_htmx_true() {
        let built = AuthenticationError::not_authenticated(true);
        assert!(matches!(built, AuthenticationError::NotAuthenticatedHtmx));
    }

    #[test]
    fn test_not_authenticated_helper_returns_regular_variant_when_is_htmx_false() {
        let built = AuthenticationError::not_authenticated(false);
        assert!(matches!(built, AuthenticationError::NotAuthenticated));
    }
}