acton_htmx/middleware/
auth.rs1use super::helpers::is_htmx_request;
36use axum::{
37 extract::Request,
38 http::StatusCode,
39 middleware::Next,
40 response::{IntoResponse, Redirect, Response},
41};
42
/// Route guard that only forwards requests whose session has a user id.
///
/// Requests without such a session are rejected and pointed at a configurable
/// login path, which is `/login` unless built with
/// [`AuthMiddleware::with_login_path`].
#[derive(Clone, Debug)]
pub struct AuthMiddleware {
    /// Path unauthenticated clients are sent to.
    login_path: String,
}

impl Default for AuthMiddleware {
    /// Builds a guard whose login path is `/login`.
    fn default() -> Self {
        const DEFAULT_LOGIN_PATH: &str = "/login";
        Self {
            login_path: String::from(DEFAULT_LOGIN_PATH),
        }
    }
}
64
65impl AuthMiddleware {
66 #[must_use]
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 #[must_use]
84 pub fn with_login_path(login_path: impl Into<String>) -> Self {
85 Self {
86 login_path: login_path.into(),
87 }
88 }
89
90 pub async fn handle(
105 request: Request,
106 next: Next,
107 ) -> Result<Response, AuthMiddlewareError> {
108 Self::default().handle_with_config(request, next).await
109 }
110
111 pub async fn handle_with_config(
124 self,
125 request: Request,
126 next: Next,
127 ) -> Result<Response, AuthMiddlewareError> {
128 let (parts, body) = request.into_parts();
130
131 let session = parts.extensions.get::<crate::auth::Session>().cloned();
133
134 let is_authenticated = session
135 .as_ref()
136 .and_then(super::super::auth::Session::user_id)
137 .is_some();
138
139 if !is_authenticated {
140 return Err(AuthMiddlewareError::for_request(
142 is_htmx_request(&parts.headers),
143 self.login_path,
144 ));
145 }
146
147 let request = Request::from_parts(parts, body);
149 Ok(next.run(request).await)
150 }
151}
152
/// Rejection produced by [`AuthMiddleware`] for requests without an authenticated session.
///
/// Each variant carries the login path the client should be sent to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AuthMiddlewareError {
    /// Unauthenticated HTMX request.
    ///
    /// Rendered as `401 Unauthorized` with the login path in an `HX-Redirect` header.
    Unauthorized(String),
    /// Unauthenticated regular (non-HTMX) request.
    ///
    /// Rendered as a `303 See Other` redirect to the login path.
    RedirectToLogin(String),
}

impl std::fmt::Display for AuthMiddlewareError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthorized(login_path) => {
                write!(f, "unauthenticated HTMX request (login path: {login_path})")
            }
            Self::RedirectToLogin(login_path) => {
                write!(f, "unauthenticated request, redirecting to {login_path}")
            }
        }
    }
}

impl std::error::Error for AuthMiddlewareError {}

impl AuthMiddlewareError {
    /// Builds the rejection matching the kind of request.
    ///
    /// HTMX requests (`is_htmx == true`) get [`Self::Unauthorized`]; all other
    /// requests get [`Self::RedirectToLogin`]. `login_path` is stored in either variant.
    #[must_use]
    pub fn for_request(is_htmx: bool, login_path: impl Into<String>) -> Self {
        let login_path = login_path.into();
        if is_htmx {
            Self::Unauthorized(login_path)
        } else {
            Self::RedirectToLogin(login_path)
        }
    }
}
190
191impl IntoResponse for AuthMiddlewareError {
192 fn into_response(self) -> Response {
193 match self {
194 Self::Unauthorized(login_path) => {
195 (
197 StatusCode::UNAUTHORIZED,
198 [("HX-Redirect", login_path.as_str())],
199 "Unauthorized",
200 )
201 .into_response()
202 }
203 Self::RedirectToLogin(login_path) => {
204 Redirect::to(&login_path).into_response()
206 }
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{Session, SessionData, SessionId};
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Login path used by the custom-configuration tests.
    const CUSTOM_LOGIN_PATH: &str = "/auth/signin";

    async fn protected_handler() -> &'static str {
        "Protected content"
    }

    /// Router whose `/protected` route is guarded by `AuthMiddleware::handle`
    /// (default configuration).
    fn default_app() -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(AuthMiddleware::handle))
    }

    /// Router whose `/protected` route is guarded by `auth` through
    /// `AuthMiddleware::handle_with_config`.
    fn configured_app(auth: AuthMiddleware) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(move |req, next| {
                auth.clone().handle_with_config(req, next)
            }))
    }

    /// Builds `GET /protected`, optionally marked as an HTMX request.
    fn protected_request(htmx: bool) -> Request<Body> {
        let builder = Request::builder().uri("/protected");
        let builder = if htmx {
            builder.header("HX-Request", "true")
        } else {
            builder
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Attaches a session that has a user id, so the middleware treats the
    /// request as authenticated.
    fn with_authenticated_session(mut request: Request<Body>) -> Request<Body> {
        let session_id = SessionId::generate();
        let mut session_data = SessionData::new();
        session_data.user_id = Some(1);
        request
            .extensions_mut()
            .insert(Session::new(session_id, session_data));
        request
    }

    #[tokio::test]
    async fn test_unauthenticated_regular_request_redirects() {
        let response = default_app()
            .oneshot(protected_request(false))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get("location").unwrap(), "/login");
    }

    #[tokio::test]
    async fn test_unauthenticated_htmx_request_returns_401() {
        let response = default_app()
            .oneshot(protected_request(true))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(response.headers().get("HX-Redirect").unwrap(), "/login");
    }

    #[tokio::test]
    async fn test_authenticated_request_proceeds() {
        let response = default_app()
            .oneshot(with_authenticated_session(protected_request(false)))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_authenticated_htmx_request_proceeds() {
        let response = default_app()
            .oneshot(with_authenticated_session(protected_request(true)))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_custom_login_path_regular_request() {
        let app = configured_app(AuthMiddleware::with_login_path(CUSTOM_LOGIN_PATH));
        let response = app.oneshot(protected_request(false)).await.unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            response.headers().get("location").unwrap(),
            CUSTOM_LOGIN_PATH
        );
    }

    #[tokio::test]
    async fn test_custom_login_path_htmx_request() {
        let app = configured_app(AuthMiddleware::with_login_path(CUSTOM_LOGIN_PATH));
        let response = app.oneshot(protected_request(true)).await.unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response.headers().get("HX-Redirect").unwrap(),
            CUSTOM_LOGIN_PATH
        );
    }

    #[tokio::test]
    async fn test_custom_login_path_with_authenticated_request() {
        let app = configured_app(AuthMiddleware::with_login_path(CUSTOM_LOGIN_PATH));
        let response = app
            .oneshot(with_authenticated_session(protected_request(false)))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // The following tests are purely synchronous, so they use `#[test]` rather
    // than spinning up a tokio runtime.

    #[test]
    fn test_default_login_path_is_slash_login() {
        let auth = AuthMiddleware::new();
        assert_eq!(auth.login_path, "/login");

        let default_auth = AuthMiddleware::default();
        assert_eq!(default_auth.login_path, "/login");
    }

    #[test]
    fn test_with_login_path_accepts_string() {
        let auth = AuthMiddleware::with_login_path("/custom".to_string());
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_with_login_path_accepts_str() {
        let auth = AuthMiddleware::with_login_path("/custom");
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_for_request_returns_unauthorized_when_htmx() {
        let error = AuthMiddlewareError::for_request(true, "/login");
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_returns_redirect_when_not_htmx() {
        let error = AuthMiddlewareError::for_request(false, "/login");
        assert!(matches!(error, AuthMiddlewareError::RedirectToLogin(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_accepts_string() {
        let error = AuthMiddlewareError::for_request(true, "/custom/login".to_string());
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/custom/login"));
    }
}