acton_htmx/auth/
extractors.rs1use crate::auth::{Session, User, UserError};
37use crate::middleware::is_htmx_request;
38use crate::state::ActonHtmxState;
39use axum::{
40 extract::{FromRef, FromRequestParts},
41 http::{request::Parts, StatusCode},
42 response::{IntoResponse, Redirect, Response},
43};
44
/// Extractor wrapper around an authenticated value.
///
/// With `T = User`, extracting this type rejects the request (see `AuthenticationError`)
/// when there is no session, the session has no logged-in user id, or the user cannot be
/// loaded. The inner value is public so handlers can destructure it directly:
/// `Authenticated(user): Authenticated<User>`.
///
/// `Debug` and `Clone` are derived so the wrapper is usable wherever `T` itself is (the
/// derives only apply when `T: Debug` / `T: Clone`).
#[derive(Debug, Clone)]
pub struct Authenticated<T>(pub T);
64
65impl<S> FromRequestParts<S> for Authenticated<User>
66where
67 S: Send + Sync,
68 ActonHtmxState: FromRef<S>,
69{
70 type Rejection = AuthenticationError;
71
72 async fn from_request_parts(
73 parts: &mut Parts,
74 state: &S,
75 ) -> Result<Self, Self::Rejection> {
76 let is_htmx = is_htmx_request(&parts.headers);
78
79 let session = parts
81 .extensions
82 .get::<Session>()
83 .cloned()
84 .ok_or_else(|| AuthenticationError::missing_session(is_htmx))?;
85
86 let user_id = session
88 .user_id()
89 .ok_or_else(|| AuthenticationError::not_authenticated(is_htmx))?;
90
91 let app_state = ActonHtmxState::from_ref(state);
93
94 let user = User::find_by_id(user_id, app_state.database_pool())
96 .await
97 .map_err(|e| match e {
98 UserError::NotFound => AuthenticationError::not_authenticated(is_htmx),
99 _ => AuthenticationError::DatabaseError(e),
100 })?;
101
102 Ok(Self(user))
103 }
104}
105
/// Extractor wrapper around a value that may or may not be present.
///
/// With `T = User`, extraction never rejects: the inner value is `None` whenever no user
/// could be resolved for the request, and `Some(user)` otherwise.
///
/// `Debug` and `Clone` are derived so the wrapper is usable wherever `T` itself is (the
/// derives only apply when `T: Debug` / `T: Clone`).
#[derive(Debug, Clone)]
pub struct OptionalAuth<T>(pub Option<T>);
126
127impl<S> FromRequestParts<S> for OptionalAuth<User>
128where
129 S: Send + Sync,
130 ActonHtmxState: FromRef<S>,
131{
132 type Rejection = AuthenticationError;
133
134 async fn from_request_parts(
135 parts: &mut Parts,
136 state: &S,
137 ) -> Result<Self, Self::Rejection> {
138 let Some(session) = parts.extensions.get::<Session>().cloned() else {
140 return Ok(Self(None)); };
142
143 let Some(user_id) = session.user_id() else {
145 return Ok(Self(None)); };
147
148 let app_state = ActonHtmxState::from_ref(state);
150
151 let user = User::find_by_id(user_id, app_state.database_pool())
153 .await
154 .ok(); Ok(Self(user))
157 }
158}
159
160#[derive(Debug)]
162pub enum AuthenticationError {
163 MissingSessionHtmx,
165
166 MissingSession,
168
169 NotAuthenticatedHtmx,
171
172 NotAuthenticated,
174
175 DatabaseNotConfigured,
177
178 DatabaseError(UserError),
180}
181
182impl AuthenticationError {
183 #[must_use]
196 pub const fn missing_session(is_htmx: bool) -> Self {
197 if is_htmx {
198 Self::MissingSessionHtmx
199 } else {
200 Self::MissingSession
201 }
202 }
203
204 #[must_use]
217 pub const fn not_authenticated(is_htmx: bool) -> Self {
218 if is_htmx {
219 Self::NotAuthenticatedHtmx
220 } else {
221 Self::NotAuthenticated
222 }
223 }
224}
225
226impl IntoResponse for AuthenticationError {
227 fn into_response(self) -> Response {
228 match self {
229 Self::MissingSessionHtmx | Self::NotAuthenticatedHtmx => {
230 (
232 StatusCode::UNAUTHORIZED,
233 [("HX-Redirect", "/login")],
234 "Unauthorized",
235 )
236 .into_response()
237 }
238 Self::MissingSession | Self::NotAuthenticated => {
239 Redirect::to("/login").into_response()
241 }
242 Self::DatabaseNotConfigured => {
243 (
244 StatusCode::INTERNAL_SERVER_ERROR,
245 "Database not configured",
246 )
247 .into_response()
248 }
249 Self::DatabaseError(_) => {
250 (
251 StatusCode::INTERNAL_SERVER_ERROR,
252 "Failed to load user",
253 )
254 .into_response()
255 }
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    /// Path every login redirect produced by `AuthenticationError` must point at.
    const LOGIN: &str = "/login";

    /// Renders `error` and checks it is a regular `303 See Other` redirect to the login page.
    fn expect_login_redirect(error: AuthenticationError) {
        let response = error.into_response();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let location = response.headers().get("location").unwrap();
        assert_eq!(location, LOGIN);
    }

    /// Renders `error` and checks it is a `401` carrying an `HX-Redirect` to the login page.
    fn expect_hx_login_redirect(error: AuthenticationError) {
        let response = error.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let hx_redirect = response.headers().get("HX-Redirect").unwrap();
        assert_eq!(hx_redirect, LOGIN);
    }

    /// Renders `error` and checks it maps to `500 Internal Server Error`.
    fn expect_server_error(error: AuthenticationError) {
        let response = error.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_authentication_error_missing_session_regular_returns_redirect() {
        expect_login_redirect(AuthenticationError::MissingSession);
    }

    #[test]
    fn test_authentication_error_missing_session_htmx_returns_401_with_hx_redirect() {
        expect_hx_login_redirect(AuthenticationError::MissingSessionHtmx);
    }

    #[test]
    fn test_authentication_error_not_authenticated_regular_returns_redirect() {
        expect_login_redirect(AuthenticationError::NotAuthenticated);
    }

    #[test]
    fn test_authentication_error_not_authenticated_htmx_returns_401_with_hx_redirect() {
        expect_hx_login_redirect(AuthenticationError::NotAuthenticatedHtmx);
    }

    #[test]
    fn test_authentication_error_database_not_configured_returns_500() {
        expect_server_error(AuthenticationError::DatabaseNotConfigured);
    }

    #[test]
    fn test_authentication_error_database_error_returns_500() {
        expect_server_error(AuthenticationError::DatabaseError(UserError::NotFound));
    }

    #[test]
    fn test_missing_session_helper_returns_htmx_variant_when_is_htmx_true() {
        assert!(matches!(
            AuthenticationError::missing_session(true),
            AuthenticationError::MissingSessionHtmx
        ));
    }

    #[test]
    fn test_missing_session_helper_returns_regular_variant_when_is_htmx_false() {
        assert!(matches!(
            AuthenticationError::missing_session(false),
            AuthenticationError::MissingSession
        ));
    }

    #[test]
    fn test_not_authenticated_helper_returns_htmx_variant_when_is_htmx_true() {
        assert!(matches!(
            AuthenticationError::not_authenticated(true),
            AuthenticationError::NotAuthenticatedHtmx
        ));
    }

    #[test]
    fn test_not_authenticated_helper_returns_regular_variant_when_is_htmx_false() {
        assert!(matches!(
            AuthenticationError::not_authenticated(false),
            AuthenticationError::NotAuthenticated
        ));
    }
}