Skip to main content

allowthem_server/
extractors.rs

1use std::sync::Arc;
2
3use axum::extract::{FromRef, FromRequestParts};
4use axum::http::header::COOKIE;
5use axum::http::request::Parts;
6
7use allowthem_core::{AuthClient, User, parse_session_cookie};
8
9use crate::error::AuthExtractError;
10
11/// Axum extractor that provides the authenticated user.
12///
13/// Reads the session cookie, validates the session (with sliding-window
14/// renewal), and fetches the user. Rejects with 401 if not authenticated.
15///
16/// Usage: `AuthUser(user): AuthUser` in handler arguments.
17pub struct AuthUser(pub User);
18
19impl<S> FromRequestParts<S> for AuthUser
20where
21    Arc<dyn AuthClient>: FromRef<S>,
22    S: Send + Sync,
23{
24    type Rejection = AuthExtractError;
25
26    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
27        let client = <Arc<dyn AuthClient>>::from_ref(state);
28
29        let cookie_header = parts
30            .headers
31            .get(COOKIE)
32            .and_then(|v| v.to_str().ok())
33            .ok_or(AuthExtractError::Unauthenticated)?;
34
35        let token = parse_session_cookie(cookie_header, client.session_cookie_name())
36            .ok_or(AuthExtractError::Unauthenticated)?;
37
38        let user = client
39            .validate_session(&token)
40            .await
41            .map_err(AuthExtractError::Internal)?
42            .ok_or(AuthExtractError::Unauthenticated)?;
43
44        Ok(AuthUser(user))
45    }
46}
47
48/// Axum extractor that optionally provides the authenticated user.
49///
50/// Same flow as [`AuthUser`] but wraps `Option<User>` and never rejects.
51/// Returns `None` when not authenticated. Returns `Some(user)` when valid.
52/// Internal errors (database failures) are logged and treated as `None`.
53///
54/// Usage: `OptionalAuthUser(user): OptionalAuthUser` in handler arguments.
55pub struct OptionalAuthUser(pub Option<User>);
56
57impl<S> FromRequestParts<S> for OptionalAuthUser
58where
59    Arc<dyn AuthClient>: FromRef<S>,
60    S: Send + Sync,
61{
62    type Rejection = std::convert::Infallible;
63
64    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
65        match AuthUser::from_request_parts(parts, state).await {
66            Ok(AuthUser(user)) => Ok(OptionalAuthUser(Some(user))),
67            Err(AuthExtractError::Internal(err)) => {
68                tracing::error!("auth extraction error: {err}");
69                Ok(OptionalAuthUser(None))
70            }
71            Err(_) => Ok(OptionalAuthUser(None)),
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, generate_token,
        hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use axum::{Json, Router};
    use chrono::{DateTime, Duration, Utc};
    use tower::ServiceExt;

    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Build an in-memory AllowThem containing one user (`address`) whose
    /// session expires at `expires`.
    ///
    /// Returns (AllowThem, cookie_header_value). The second element is a
    /// `name=value` pair suitable for a `Cookie` request header.
    async fn setup_session(address: &str, expires: DateTime<Utc>) -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let email = Email::new(address.into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();

        let cookie = ath.session_cookie(&token);
        // session_cookie returns a Set-Cookie value. Extract just the
        // name=value pair for the Cookie request header (everything before
        // the first ';').
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    /// Build an AllowThem where `test@example.com` has a session valid for
    /// the next 24 hours, and return (AllowThem, cookie_header_value).
    async fn test_setup() -> (AllowThem, String) {
        setup_session("test@example.com", Utc::now() + Duration::hours(24)).await
    }

    fn test_app(ath: AllowThem) -> Router {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(protected_handler))
            .route("/optional", get(optional_handler))
            .with_state(state)
    }

    async fn protected_handler(AuthUser(user): AuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn optional_handler(OptionalAuthUser(user): OptionalAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"user": user.map(|u| u.email)}))
    }

    /// Send a GET request for `uri` through `app`, with an optional `Cookie`
    /// header value.
    async fn send_get(
        app: Router,
        uri: &str,
        cookie: Option<&str>,
    ) -> axum::http::Response<axum::body::Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(cookie) = cookie {
            builder = builder.header(COOKIE, cookie);
        }
        let req = builder.body(axum::body::Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    async fn read_body(resp: axum::http::Response<axum::body::Body>) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn no_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let resp = send_get(test_app(ath), "/protected", None).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn garbage_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let resp = send_get(
            test_app(ath),
            "/protected",
            Some("allowthem_session=garbage"),
        )
        .await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let resp = send_get(test_app(ath), "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn expired_session_returns_401() {
        // Session already expired.
        let (ath, cookie_value) =
            setup_session("expired@example.com", Utc::now() - Duration::hours(1)).await;
        let resp = send_get(test_app(ath), "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn inactive_user_returns_401() {
        let (ath, cookie_value) = test_setup().await;

        // Deactivate the user.
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().update_user_active(user.id, false).await.unwrap();

        let resp = send_get(test_app(ath), "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn optional_no_cookie_returns_none() {
        let (ath, _) = test_setup().await;
        let resp = send_get(test_app(ath), "/optional", None).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_garbage_cookie_returns_none() {
        let (ath, _) = test_setup().await;
        let resp = send_get(
            test_app(ath),
            "/optional",
            Some("allowthem_session=garbage"),
        )
        .await;

        // OptionalAuthUser never rejects; an invalid session is just anonymous.
        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let resp = send_get(test_app(ath), "/optional", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["user"], "test@example.com");
    }
}