acton_htmx/middleware/
auth.rs

1//! Authentication middleware for protecting routes
2//!
3//! This module provides middleware for requiring authentication on routes
4//! and extractors for accessing the authenticated user.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use acton_htmx::middleware::AuthMiddleware;
10//! use acton_htmx::auth::Authenticated;
11//! use axum::{Router, routing::get, middleware};
12//!
13//! async fn protected_handler(
14//!     Authenticated(user): Authenticated<acton_htmx::auth::User>,
15//! ) -> String {
16//!     format!("Hello, {}!", user.email)
17//! }
18//!
19//! # async fn example() {
20//! // Default login path (/login)
21//! let app = Router::new()
22//!     .route("/protected", get(protected_handler))
23//!     .layer(middleware::from_fn(AuthMiddleware::handle));
24//!
25//! // Custom login path
26//! let custom_middleware = AuthMiddleware::with_login_path("/auth/login");
27//! let app = Router::new()
28//!     .route("/protected", get(protected_handler))
29//!     .layer(middleware::from_fn(move |req, next| {
30//!         custom_middleware.clone().handle_with_config(req, next)
31//!     }));
32//! # }
33//! ```
34
35use super::helpers::is_htmx_request;
36use axum::{
37    extract::Request,
38    http::StatusCode,
39    middleware::Next,
40    response::{IntoResponse, Redirect, Response},
41};
42
/// Route guard that only lets requests with an authenticated session through.
///
/// When no authenticated session is present, standard browser requests are
/// redirected to the login page. HTMX requests instead receive a
/// `401 Unauthorized` response carrying an `HX-Redirect` header that points at
/// the login page.
///
/// # Login Path Configuration
///
/// The login page defaults to `/login`. Use [`AuthMiddleware::with_login_path`]
/// to send unauthenticated users somewhere else.
#[derive(Debug, Clone)]
pub struct AuthMiddleware {
    /// Path that unauthenticated users are sent to.
    login_path: String,
}

impl Default for AuthMiddleware {
    /// Builds a middleware whose login path is `/login`.
    fn default() -> Self {
        let login_path = String::from("/login");
        Self { login_path }
    }
}
64
65impl AuthMiddleware {
66    /// Create a new authentication middleware with default settings
67    ///
68    /// By default, redirects to `/login` for unauthenticated requests.
69    #[must_use]
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    /// Create authentication middleware with custom login path
75    ///
76    /// # Example
77    ///
78    /// ```rust
79    /// use acton_htmx::middleware::AuthMiddleware;
80    ///
81    /// let middleware = AuthMiddleware::with_login_path("/auth/login");
82    /// ```
83    #[must_use]
84    pub fn with_login_path(login_path: impl Into<String>) -> Self {
85        Self {
86            login_path: login_path.into(),
87        }
88    }
89
90    /// Middleware handler that checks for authentication with default login path
91    ///
92    /// This is a convenience method that uses the default login path `/login`.
93    /// For custom login paths, use [`AuthMiddleware::with_login_path`] and
94    /// [`AuthMiddleware::handle_with_config`].
95    ///
96    /// # Errors
97    ///
98    /// Returns [`AuthMiddlewareError`] if:
99    /// - No valid session exists in request extensions
100    /// - Session exists but does not contain a user_id
101    ///
102    /// For HTMX requests, returns 401 with HX-Redirect header to login page.
103    /// For standard browser requests, redirects to login page.
104    pub async fn handle(
105        request: Request,
106        next: Next,
107    ) -> Result<Response, AuthMiddlewareError> {
108        Self::default().handle_with_config(request, next).await
109    }
110
111    /// Middleware handler that checks for authentication with configured login path
112    ///
113    /// This method uses the login path configured in this middleware instance.
114    ///
115    /// # Errors
116    ///
117    /// Returns [`AuthMiddlewareError`] if:
118    /// - No valid session exists in request extensions
119    /// - Session exists but does not contain a user_id
120    ///
121    /// For HTMX requests, returns 401 with HX-Redirect header to configured login page.
122    /// For standard browser requests, redirects to configured login page.
123    pub async fn handle_with_config(
124        self,
125        request: Request,
126        next: Next,
127    ) -> Result<Response, AuthMiddlewareError> {
128        // Check if user is authenticated by looking for user_id in session
129        let (parts, body) = request.into_parts();
130
131        // Get session from request extensions
132        let session = parts.extensions.get::<crate::auth::Session>().cloned();
133
134        let is_authenticated = session
135            .as_ref()
136            .and_then(super::super::auth::Session::user_id)
137            .is_some();
138
139        if !is_authenticated {
140            // Use helper to create appropriate error for request type
141            return Err(AuthMiddlewareError::for_request(
142                is_htmx_request(&parts.headers),
143                self.login_path,
144            ));
145        }
146
147        // User is authenticated, continue with the request
148        let request = Request::from_parts(parts, body);
149        Ok(next.run(request).await)
150    }
151}
152
/// Authentication middleware errors.
///
/// Every variant carries the login path the client should be sent to. The
/// variant determines how that redirect is expressed when the error is turned
/// into a response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMiddlewareError {
    /// User is not authenticated (HTMX request).
    ///
    /// Contains the login path to redirect to.
    Unauthorized(String),
    /// Redirect to login page (regular request).
    ///
    /// Contains the login path to redirect to.
    RedirectToLogin(String),
}

impl AuthMiddlewareError {
    /// Create an authentication error appropriate for the request type.
    ///
    /// This helper reduces duplication by encapsulating the HTMX detection logic.
    ///
    /// # Arguments
    ///
    /// * `is_htmx` - Whether the request is from HTMX
    /// * `login_path` - The path to redirect to for login
    ///
    /// # Returns
    ///
    /// * [`Unauthorized`](Self::Unauthorized) for HTMX requests (returns 401 with HX-Redirect)
    /// * [`RedirectToLogin`](Self::RedirectToLogin) for regular requests (returns 303 redirect)
    #[must_use]
    pub fn for_request(is_htmx: bool, login_path: impl Into<String>) -> Self {
        let login_path = login_path.into();
        if is_htmx {
            Self::Unauthorized(login_path)
        } else {
            Self::RedirectToLogin(login_path)
        }
    }
}

/// Human-readable description, suitable for logs and error chains.
impl std::fmt::Display for AuthMiddlewareError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthorized(path) => {
                write!(f, "authentication required (HTMX request); redirect to {path}")
            }
            Self::RedirectToLogin(path) => {
                write!(f, "authentication required; redirect to {path}")
            }
        }
    }
}

/// Lets the error be boxed as `dyn Error` or wrapped by other error types.
impl std::error::Error for AuthMiddlewareError {}
190
191impl IntoResponse for AuthMiddlewareError {
192    fn into_response(self) -> Response {
193        match self {
194            Self::Unauthorized(login_path) => {
195                // Return 401 with HX-Redirect header for HTMX
196                (
197                    StatusCode::UNAUTHORIZED,
198                    [("HX-Redirect", login_path.as_str())],
199                    "Unauthorized",
200                )
201                    .into_response()
202            }
203            Self::RedirectToLogin(login_path) => {
204                // Regular HTTP redirect
205                Redirect::to(&login_path).into_response()
206            }
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{Session, SessionData, SessionId};
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn protected_handler() -> &'static str {
        "Protected content"
    }

    /// Router whose `/protected` route is guarded by [`AuthMiddleware::handle`].
    fn default_app() -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(AuthMiddleware::handle))
    }

    /// Router whose `/protected` route is guarded by a middleware using `login_path`.
    fn custom_app(login_path: &str) -> Router {
        let custom_middleware = AuthMiddleware::with_login_path(login_path);
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn(move |req, next| {
                custom_middleware.clone().handle_with_config(req, next)
            }))
    }

    /// `GET /protected`, flagged as an HTMX request when `htmx` is true.
    fn protected_request(htmx: bool) -> Request<Body> {
        let builder = Request::builder().uri("/protected");
        let builder = if htmx {
            builder.header("HX-Request", "true")
        } else {
            builder
        };
        builder.body(Body::empty()).unwrap()
    }

    /// `GET /protected` carrying a session whose `user_id` is set.
    fn authenticated_request() -> Request<Body> {
        let mut session_data = SessionData::new();
        session_data.user_id = Some(1);
        let session = Session::new(SessionId::generate(), session_data);

        let mut request = protected_request(false);
        request.extensions_mut().insert(session);
        request
    }

    /// Asserts a `303 See Other` redirect whose `Location` header is `path`.
    fn assert_redirects_to(response: &Response, path: &str) {
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get("location").unwrap(), path);
    }

    /// Asserts a `401 Unauthorized` whose `HX-Redirect` header is `path`.
    fn assert_htmx_redirects_to(response: &Response, path: &str) {
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(response.headers().get("HX-Redirect").unwrap(), path);
    }

    #[tokio::test]
    async fn test_unauthenticated_regular_request_redirects() {
        let response = default_app().oneshot(protected_request(false)).await.unwrap();
        assert_redirects_to(&response, "/login");
    }

    #[tokio::test]
    async fn test_unauthenticated_htmx_request_returns_401() {
        let response = default_app().oneshot(protected_request(true)).await.unwrap();
        assert_htmx_redirects_to(&response, "/login");
    }

    #[tokio::test]
    async fn test_authenticated_request_proceeds() {
        let response = default_app().oneshot(authenticated_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_custom_login_path_regular_request() {
        let response = custom_app("/auth/signin")
            .oneshot(protected_request(false))
            .await
            .unwrap();
        assert_redirects_to(&response, "/auth/signin");
    }

    #[tokio::test]
    async fn test_custom_login_path_htmx_request() {
        let response = custom_app("/auth/signin")
            .oneshot(protected_request(true))
            .await
            .unwrap();
        assert_htmx_redirects_to(&response, "/auth/signin");
    }

    #[tokio::test]
    async fn test_custom_login_path_with_authenticated_request() {
        let response = custom_app("/auth/signin")
            .oneshot(authenticated_request())
            .await
            .unwrap();
        // The configured login path must not affect authenticated requests.
        assert_eq!(response.status(), StatusCode::OK);
    }

    // The tests below are synchronous; they need no async runtime.

    #[test]
    fn test_default_login_path_is_slash_login() {
        assert_eq!(AuthMiddleware::new().login_path, "/login");
        assert_eq!(AuthMiddleware::default().login_path, "/login");
    }

    #[test]
    fn test_with_login_path_accepts_string() {
        let auth = AuthMiddleware::with_login_path("/custom".to_string());
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_with_login_path_accepts_str() {
        let auth = AuthMiddleware::with_login_path("/custom");
        assert_eq!(auth.login_path, "/custom");
    }

    #[test]
    fn test_for_request_returns_unauthorized_when_htmx() {
        let error = AuthMiddlewareError::for_request(true, "/login");
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_returns_redirect_when_not_htmx() {
        let error = AuthMiddlewareError::for_request(false, "/login");
        assert!(matches!(error, AuthMiddlewareError::RedirectToLogin(path) if path == "/login"));
    }

    #[test]
    fn test_for_request_accepts_string() {
        let error = AuthMiddlewareError::for_request(true, "/custom/login".to_string());
        assert!(matches!(error, AuthMiddlewareError::Unauthorized(path) if path == "/custom/login"));
    }

    #[test]
    fn test_redirect_to_login_into_response() {
        let response = AuthMiddlewareError::RedirectToLogin("/login".into()).into_response();
        assert_redirects_to(&response, "/login");
    }

    #[test]
    fn test_unauthorized_into_response() {
        let response = AuthMiddlewareError::Unauthorized("/auth/signin".into()).into_response();
        assert_htmx_redirects_to(&response, "/auth/signin");
    }
}