1use std::sync::Arc;
2
3use axum::extract::{FromRef, FromRequestParts};
4use axum::http::header::COOKIE;
5use axum::http::request::Parts;
6use axum::response::{IntoResponse, Response};
7
8use allowthem_core::{AuthClient, RoleName, User, parse_session_cookie};
9
10use crate::error::{AuthExtractError, BrowserAdminForbidden, BrowserAuthRedirect};
11
12pub struct AuthUser(pub User);
19
20impl<S> FromRequestParts<S> for AuthUser
21where
22 Arc<dyn AuthClient>: FromRef<S>,
23 S: Send + Sync,
24{
25 type Rejection = AuthExtractError;
26
27 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
28 let client = <Arc<dyn AuthClient>>::from_ref(state);
29
30 let cookie_header = parts
31 .headers
32 .get(COOKIE)
33 .and_then(|v| v.to_str().ok())
34 .ok_or(AuthExtractError::Unauthenticated)?;
35
36 let token = parse_session_cookie(cookie_header, client.session_cookie_name())
37 .ok_or(AuthExtractError::Unauthenticated)?;
38
39 let user = client
40 .validate_session(&token)
41 .await
42 .map_err(AuthExtractError::Internal)?
43 .ok_or(AuthExtractError::Unauthenticated)?;
44
45 Ok(AuthUser(user))
46 }
47}
48
49pub struct OptionalAuthUser(pub Option<User>);
57
58impl<S> FromRequestParts<S> for OptionalAuthUser
59where
60 Arc<dyn AuthClient>: FromRef<S>,
61 S: Send + Sync,
62{
63 type Rejection = std::convert::Infallible;
64
65 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
66 match AuthUser::from_request_parts(parts, state).await {
67 Ok(AuthUser(user)) => Ok(OptionalAuthUser(Some(user))),
68 Err(AuthExtractError::Internal(err)) => {
69 tracing::error!("auth extraction error: {err}");
70 Ok(OptionalAuthUser(None))
71 }
72 Err(_) => Ok(OptionalAuthUser(None)),
73 }
74 }
75}
76
77pub struct BrowserAuthUser(pub User);
84
85impl<S> FromRequestParts<S> for BrowserAuthUser
86where
87 Arc<dyn AuthClient>: FromRef<S>,
88 S: Send + Sync,
89{
90 type Rejection = BrowserAuthRedirect;
91
92 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
93 let redirect = BrowserAuthRedirect::new(parts.uri.path());
94 let client = <Arc<dyn AuthClient>>::from_ref(state);
95
96 let cookie_header = parts
97 .headers
98 .get(COOKIE)
99 .and_then(|v| v.to_str().ok())
100 .ok_or(redirect)?;
101
102 let redirect = BrowserAuthRedirect::new(parts.uri.path());
103 let token =
104 parse_session_cookie(cookie_header, client.session_cookie_name()).ok_or(redirect)?;
105
106 let redirect = BrowserAuthRedirect::new(parts.uri.path());
107 let user = client
108 .validate_session(&token)
109 .await
110 .map_err(|err| {
111 tracing::error!("auth extraction error: {err}");
112 BrowserAuthRedirect::new(parts.uri.path())
113 })?
114 .ok_or(redirect)?;
115
116 Ok(BrowserAuthUser(user))
117 }
118}
119
120pub struct BrowserAdminUser(pub User);
126
127impl<S> FromRequestParts<S> for BrowserAdminUser
128where
129 Arc<dyn AuthClient>: FromRef<S>,
130 S: Send + Sync,
131{
132 type Rejection = Response;
133
134 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
135 let client = <Arc<dyn AuthClient>>::from_ref(state);
136
137 let cookie_header = parts
139 .headers
140 .get(COOKIE)
141 .and_then(|v| v.to_str().ok())
142 .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
143
144 let token = parse_session_cookie(cookie_header, client.session_cookie_name())
145 .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
146
147 let user = client
148 .validate_session(&token)
149 .await
150 .map_err(|err| {
151 tracing::error!("auth extraction error: {err}");
152 BrowserAuthRedirect::new(parts.uri.path()).into_response()
153 })?
154 .ok_or_else(|| BrowserAuthRedirect::new(parts.uri.path()).into_response())?;
155
156 let admin_role = RoleName::new("admin");
158 let is_admin = client
159 .check_role(&user.id, &admin_role)
160 .await
161 .map_err(|err| {
162 tracing::error!("role check error: {err}");
163 BrowserAdminForbidden.into_response()
164 })?;
165
166 if !is_admin {
167 return Err(BrowserAdminForbidden.into_response());
168 }
169
170 Ok(BrowserAdminUser(user))
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, RoleName,
        generate_token, hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use axum::{Json, Router};
    use chrono::{DateTime, Duration, Utc};
    use tower::ServiceExt;

    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Build an in-memory `AllowThem`, create a user with `email`, and give it a
    /// session that expires at `expires`.
    ///
    /// Returns the instance plus the `name=value` part of the session cookie,
    /// ready to be sent as a `Cookie` request header.
    async fn setup_with_session(email: &str, expires: DateTime<Utc>) -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let email = Email::new(email.into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();

        // The cookie string may carry attributes after `;`; a request only
        // sends the leading `name=value` pair.
        let cookie = ath.session_cookie(&token);
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    /// Standard fixture: `test@example.com` with a session valid for 24 hours.
    async fn test_setup() -> (AllowThem, String) {
        setup_with_session("test@example.com", Utc::now() + Duration::hours(24)).await
    }

    /// Fixture: `expired@example.com` whose only session expired an hour ago.
    async fn expired_setup() -> (AllowThem, String) {
        setup_with_session("expired@example.com", Utc::now() - Duration::hours(1)).await
    }

    fn test_app(ath: AllowThem) -> Router {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(protected_handler))
            .route("/optional", get(optional_handler))
            .route("/browser", get(browser_handler))
            .route("/admin", get(admin_handler))
            .with_state(state)
    }

    /// Build a request for `uri`, optionally carrying a `Cookie` header.
    fn get_request(uri: &str, cookie: Option<&str>) -> Request<axum::body::Body> {
        let builder = Request::builder().uri(uri);
        let builder = match cookie {
            Some(cookie) => builder.header(COOKIE, cookie),
            None => builder,
        };
        builder.body(axum::body::Body::empty()).unwrap()
    }

    async fn protected_handler(AuthUser(user): AuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn optional_handler(OptionalAuthUser(user): OptionalAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"user": user.map(|u| u.email)}))
    }

    async fn browser_handler(BrowserAuthUser(user): BrowserAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn admin_handler(BrowserAdminUser(user): BrowserAdminUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn read_body(resp: axum::http::Response<axum::body::Body>) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn no_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/protected", None))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn garbage_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/protected", Some("allowthem_session=garbage")))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/protected", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn expired_session_returns_401() {
        let (ath, cookie_value) = expired_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/protected", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn inactive_user_returns_401() {
        let (ath, cookie_value) = test_setup().await;

        // Deactivate the fixture user while its session is still unexpired.
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().update_user_active(user.id, false).await.unwrap();

        let resp = test_app(ath)
            .oneshot(get_request("/protected", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn optional_no_cookie_returns_none() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/optional", None))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_garbage_cookie_returns_none() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/optional", Some("allowthem_session=garbage")))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/optional", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["user"], "test@example.com");
    }

    #[tokio::test]
    async fn browser_auth_no_cookie_redirects() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/browser", None))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            resp.headers().get("location").unwrap(),
            "/login?next=/browser"
        );
    }

    #[tokio::test]
    async fn browser_auth_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/browser", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn browser_auth_expired_session_redirects() {
        let (ath, cookie_value) = expired_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/browser", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            resp.headers().get("location").unwrap(),
            "/login?next=/browser"
        );
    }

    #[tokio::test]
    async fn browser_admin_user_unauthenticated_redirects() {
        let (ath, _) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/admin", None))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(
            resp.headers().get("location").unwrap(),
            "/login?next=/admin"
        );
    }

    #[tokio::test]
    async fn browser_admin_user_non_admin_gets_403() {
        let (ath, cookie_value) = test_setup().await;
        let resp = test_app(ath)
            .oneshot(get_request("/admin", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn browser_admin_user_admin_succeeds() {
        let (ath, cookie_value) = test_setup().await;

        // Grant the fixture user the `admin` role before issuing the request.
        let role_name = RoleName::new("admin");
        let role = ath.db().create_role(&role_name, None).await.unwrap();
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&user.id, &role.id).await.unwrap();

        let resp = test_app(ath)
            .oneshot(get_request("/admin", Some(cookie_value.as_str())))
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }
}