Skip to main content

allowthem_server/
extractors.rs

1use std::sync::Arc;
2
3use axum::extract::{FromRef, FromRequestParts};
4use axum::http::header::COOKIE;
5use axum::http::request::Parts;
6use axum::response::{IntoResponse, Response};
7
8use allowthem_core::{AuthClient, RoleName, User, parse_session_cookie};
9
10use crate::error::{AuthExtractError, BrowserAdminForbidden, BrowserAuthRedirect};
11
12/// Axum extractor that provides the authenticated user.
13///
14/// Reads the session cookie, validates the session (with sliding-window
15/// renewal), and fetches the user. Rejects with 401 if not authenticated.
16///
17/// Usage: `AuthUser(user): AuthUser` in handler arguments.
18pub struct AuthUser(pub User);
19
20impl<S> FromRequestParts<S> for AuthUser
21where
22    Arc<dyn AuthClient>: FromRef<S>,
23    S: Send + Sync,
24{
25    type Rejection = AuthExtractError;
26
27    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
28        let client = <Arc<dyn AuthClient>>::from_ref(state);
29
30        let cookie_header = parts
31            .headers
32            .get(COOKIE)
33            .and_then(|v| v.to_str().ok())
34            .ok_or(AuthExtractError::Unauthenticated)?;
35
36        let token = parse_session_cookie(cookie_header, client.session_cookie_name())
37            .ok_or(AuthExtractError::Unauthenticated)?;
38
39        let user = client
40            .validate_session(&token)
41            .await
42            .map_err(AuthExtractError::Internal)?
43            .ok_or(AuthExtractError::Unauthenticated)?;
44
45        Ok(AuthUser(user))
46    }
47}
48
49/// Axum extractor that optionally provides the authenticated user.
50///
51/// Same flow as [`AuthUser`] but wraps `Option<User>` and never rejects.
52/// Returns `None` when not authenticated. Returns `Some(user)` when valid.
53/// Internal errors (database failures) are logged and treated as `None`.
54///
55/// Usage: `OptionalAuthUser(user): OptionalAuthUser` in handler arguments.
56pub struct OptionalAuthUser(pub Option<User>);
57
58impl<S> FromRequestParts<S> for OptionalAuthUser
59where
60    Arc<dyn AuthClient>: FromRef<S>,
61    S: Send + Sync,
62{
63    type Rejection = std::convert::Infallible;
64
65    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
66        match AuthUser::from_request_parts(parts, state).await {
67            Ok(AuthUser(user)) => Ok(OptionalAuthUser(Some(user))),
68            Err(AuthExtractError::Internal(err)) => {
69                tracing::error!("auth extraction error: {err}");
70                Ok(OptionalAuthUser(None))
71            }
72            Err(_) => Ok(OptionalAuthUser(None)),
73        }
74    }
75}
76
77/// Axum extractor for browser-facing routes that require authentication.
78///
79/// Same session validation as [`AuthUser`], but rejects with a 303 redirect
80/// to `/login?next={path}` instead of a JSON 401. Use this for routes that
81/// render HTML — unauthenticated users are sent to the login page and
82/// returned to the original path after logging in.
83pub struct BrowserAuthUser(pub User);
84
85impl<S> FromRequestParts<S> for BrowserAuthUser
86where
87    Arc<dyn AuthClient>: FromRef<S>,
88    S: Send + Sync,
89{
90    type Rejection = BrowserAuthRedirect;
91
92    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
93        let redirect = BrowserAuthRedirect::new(parts.uri.path());
94        let client = <Arc<dyn AuthClient>>::from_ref(state);
95
96        let cookie_header = parts
97            .headers
98            .get(COOKIE)
99            .and_then(|v| v.to_str().ok())
100            .ok_or(redirect)?;
101
102        let redirect = BrowserAuthRedirect::new(parts.uri.path());
103        let token =
104            parse_session_cookie(cookie_header, client.session_cookie_name()).ok_or(redirect)?;
105
106        let redirect = BrowserAuthRedirect::new(parts.uri.path());
107        let user = client
108            .validate_session(&token)
109            .await
110            .map_err(|err| {
111                tracing::error!("auth extraction error: {err}");
112                BrowserAuthRedirect::new(parts.uri.path())
113            })?
114            .ok_or(redirect)?;
115
116        Ok(BrowserAuthUser(user))
117    }
118}
119
120/// Axum extractor for admin browser routes.
121///
122/// Validates the session cookie and checks the `admin` role. Rejects with
123/// a redirect to `/login` if unauthenticated, or a 403 HTML response if
124/// authenticated but not an admin.
125pub struct BrowserAdminUser(pub User);
126
127impl<S> FromRequestParts<S> for BrowserAdminUser
128where
129    Arc<dyn AuthClient>: FromRef<S>,
130    S: Send + Sync,
131{
132    type Rejection = Response;
133
134    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
135        let client = <Arc<dyn AuthClient>>::from_ref(state);
136
137        // 1. Session validation — same flow as BrowserAuthUser
138        let cookie_header = parts
139            .headers
140            .get(COOKIE)
141            .and_then(|v| v.to_str().ok())
142            .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
143
144        let token = parse_session_cookie(cookie_header, client.session_cookie_name())
145            .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
146
147        let user = client
148            .validate_session(&token)
149            .await
150            .map_err(|err| {
151                tracing::error!("auth extraction error: {err}");
152                BrowserAuthRedirect::new(parts.uri.path()).into_response()
153            })?
154            .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
155
156        // 2. Admin role check
157        let admin_role = RoleName::new("admin");
158        let is_admin = client
159            .check_role(&user.id, &admin_role)
160            .await
161            .map_err(|err| {
162                tracing::error!("role check error: {err}");
163                BrowserAdminForbidden.into_response()
164            })?;
165
166        if !is_admin {
167            return Err(BrowserAdminForbidden.into_response());
168        }
169
170        Ok(BrowserAdminUser(user))
171    }
172}
173
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, RoleName,
        generate_token, hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use axum::{Json, Router};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    /// Response type produced by driving the test router.
    type TestResponse = axum::http::Response<axum::body::Body>;

    /// Minimal router state exposing only the auth client the extractors need.
    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Build an in-memory AllowThem, create a user with the given email and
    /// a session expiring `expires_in` from now (negative means already
    /// expired), and return (AllowThem, cookie_header_value).
    async fn setup_with_session(email: &'static str, expires_in: Duration) -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let email = Email::new(email.into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires = Utc::now() + expires_in;
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();

        let cookie = ath.session_cookie(&token);
        // session_cookie returns a Set-Cookie value; extract just the name=value
        // for the Cookie request header (everything before the first ';').
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    /// Build an AllowThem, create a test user with an active session,
    /// and return (AllowThem, cookie_header_value).
    async fn test_setup() -> (AllowThem, String) {
        setup_with_session("test@example.com", Duration::hours(24)).await
    }

    /// Same as [`test_setup`], but for a separate user whose only session
    /// expired an hour ago.
    async fn expired_setup() -> (AllowThem, String) {
        setup_with_session("expired@example.com", Duration::hours(-1)).await
    }

    /// Router with one route per extractor under test.
    fn test_app(ath: AllowThem) -> Router {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(protected_handler))
            .route("/optional", get(optional_handler))
            .route("/browser", get(browser_handler))
            .route("/admin", get(admin_handler))
            .with_state(state)
    }

    async fn protected_handler(AuthUser(user): AuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn optional_handler(OptionalAuthUser(user): OptionalAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"user": user.map(|u| u.email)}))
    }

    async fn browser_handler(BrowserAuthUser(user): BrowserAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn admin_handler(BrowserAdminUser(user): BrowserAdminUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    /// Issue a body-less GET for `uri` against `app`, attaching `cookie` as
    /// the `Cookie` header when one is given.
    async fn send(app: Router, uri: &str, cookie: Option<&str>) -> TestResponse {
        let mut builder = Request::builder().uri(uri);
        if let Some(value) = cookie {
            builder = builder.header(COOKIE, value);
        }
        let req = builder.body(axum::body::Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Collect the response body and parse it as JSON.
    async fn read_body(resp: TestResponse) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Assert that `resp` is a 303 whose `location` header equals `location`.
    fn assert_redirects_to(resp: &TestResponse, location: &str) {
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), location);
    }

    // --- AuthUser tests ---

    #[tokio::test]
    async fn no_cookie_returns_401() {
        let (ath, _) = test_setup().await;

        let resp = send(test_app(ath), "/protected", None).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn garbage_cookie_returns_401() {
        let (ath, _) = test_setup().await;

        let resp = send(
            test_app(ath),
            "/protected",
            Some("allowthem_session=garbage"),
        )
        .await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;

        let resp = send(test_app(ath), "/protected", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn expired_session_returns_401() {
        let (ath, cookie_value) = expired_setup().await;

        let resp = send(test_app(ath), "/protected", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn inactive_user_returns_401() {
        let (ath, cookie_value) = test_setup().await;

        // Deactivate the user
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().update_user_active(user.id, false).await.unwrap();

        let resp = send(test_app(ath), "/protected", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    // --- OptionalAuthUser tests ---

    #[tokio::test]
    async fn optional_no_cookie_returns_none() {
        let (ath, _) = test_setup().await;

        let resp = send(test_app(ath), "/optional", None).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;

        let resp = send(test_app(ath), "/optional", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["user"], "test@example.com");
    }

    // --- BrowserAuthUser tests ---

    #[tokio::test]
    async fn browser_auth_no_cookie_redirects() {
        let (ath, _) = test_setup().await;

        let resp = send(test_app(ath), "/browser", None).await;

        assert_redirects_to(&resp, "/login?next=/browser");
    }

    #[tokio::test]
    async fn browser_auth_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;

        let resp = send(test_app(ath), "/browser", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn browser_auth_expired_session_redirects() {
        let (ath, cookie_value) = expired_setup().await;

        let resp = send(test_app(ath), "/browser", Some(cookie_value.as_str())).await;

        assert_redirects_to(&resp, "/login?next=/browser");
    }

    // --- BrowserAdminUser tests ---

    #[tokio::test]
    async fn browser_admin_user_unauthenticated_redirects() {
        let (ath, _) = test_setup().await;

        let resp = send(test_app(ath), "/admin", None).await;

        assert_redirects_to(&resp, "/login?next=/admin");
    }

    #[tokio::test]
    async fn browser_admin_user_non_admin_gets_403() {
        let (ath, cookie_value) = test_setup().await;

        let resp = send(test_app(ath), "/admin", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn browser_admin_user_admin_succeeds() {
        let (ath, cookie_value) = test_setup().await;

        // Create admin role and assign to the test user
        let role_name = RoleName::new("admin");
        let role = ath.db().create_role(&role_name, None).await.unwrap();
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&user.id, &role.id).await.unwrap();

        let resp = send(test_app(ath), "/admin", Some(cookie_value.as_str())).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }
}