allowthem_server/
extractors.rs1use std::sync::Arc;
2
3use axum::extract::{FromRef, FromRequestParts};
4use axum::http::header::COOKIE;
5use axum::http::request::Parts;
6
7use allowthem_core::{AuthClient, User, parse_session_cookie};
8
9use crate::error::AuthExtractError;
10
11pub struct AuthUser(pub User);
18
19impl<S> FromRequestParts<S> for AuthUser
20where
21 Arc<dyn AuthClient>: FromRef<S>,
22 S: Send + Sync,
23{
24 type Rejection = AuthExtractError;
25
26 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
27 let client = <Arc<dyn AuthClient>>::from_ref(state);
28
29 let cookie_header = parts
30 .headers
31 .get(COOKIE)
32 .and_then(|v| v.to_str().ok())
33 .ok_or(AuthExtractError::Unauthenticated)?;
34
35 let token = parse_session_cookie(cookie_header, client.session_cookie_name())
36 .ok_or(AuthExtractError::Unauthenticated)?;
37
38 let user = client
39 .validate_session(&token)
40 .await
41 .map_err(AuthExtractError::Internal)?
42 .ok_or(AuthExtractError::Unauthenticated)?;
43
44 Ok(AuthUser(user))
45 }
46}
47
48pub struct OptionalAuthUser(pub Option<User>);
56
57impl<S> FromRequestParts<S> for OptionalAuthUser
58where
59 Arc<dyn AuthClient>: FromRef<S>,
60 S: Send + Sync,
61{
62 type Rejection = std::convert::Infallible;
63
64 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
65 match AuthUser::from_request_parts(parts, state).await {
66 Ok(AuthUser(user)) => Ok(OptionalAuthUser(Some(user))),
67 Err(AuthExtractError::Internal(err)) => {
68 tracing::error!("auth extraction error: {err}");
69 Ok(OptionalAuthUser(None))
70 }
71 Err(_) => Ok(OptionalAuthUser(None)),
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, generate_token,
        hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use axum::{Json, Router};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    /// Minimal router state exposing only the auth client the extractors need.
    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Builds an in-memory `AllowThem`, creates a user with `address`, and
    /// stores a session for that user expiring `expires_in` from now (a
    /// negative duration yields an already-expired session).
    ///
    /// Returns the instance together with the `name=value` part of the session
    /// cookie, ready to be sent as a `Cookie` request header.
    async fn setup_session(address: &'static str, expires_in: Duration) -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let email = Email::new(address.into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires = Utc::now() + expires_in;
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();

        // `session_cookie` produces a full cookie string; keep only the part
        // before the first ';' (the `name=value` pair).
        let cookie = ath.session_cookie(&token);
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    /// Default fixture: user `test@example.com` with a session valid for 24 hours.
    async fn test_setup() -> (AllowThem, String) {
        setup_session("test@example.com", Duration::hours(24)).await
    }

    /// Router with one route per extractor under test.
    fn test_app(ath: AllowThem) -> Router {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(protected_handler))
            .route("/optional", get(optional_handler))
            .with_state(state)
    }

    async fn protected_handler(AuthUser(user): AuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    async fn optional_handler(OptionalAuthUser(user): OptionalAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"user": user.map(|u| u.email)}))
    }

    /// Sends a GET request to `uri`, optionally with a `Cookie` header, and
    /// returns the response.
    async fn send_get(
        app: Router,
        uri: &str,
        cookie: Option<&str>,
    ) -> axum::http::Response<axum::body::Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(cookie) = cookie {
            builder = builder.header(COOKIE, cookie);
        }
        let req = builder.body(axum::body::Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Collects the response body and parses it as JSON.
    async fn read_body(resp: axum::http::Response<axum::body::Body>) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn no_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/protected", None).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn garbage_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/protected", Some("allowthem_session=garbage")).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["email"], "test@example.com");
    }

    #[tokio::test]
    async fn expired_session_returns_401() {
        // Session whose expiry is one hour in the past.
        let (ath, cookie_value) = setup_session("expired@example.com", Duration::hours(-1)).await;
        let app = test_app(ath);

        let resp = send_get(app, "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn inactive_user_returns_401() {
        let (ath, cookie_value) = test_setup().await;

        // Deactivate the fixture user while its session is still unexpired.
        let email = Email::new("test@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().update_user_active(user.id, false).await.unwrap();

        let app = test_app(ath);

        let resp = send_get(app, "/protected", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unauthenticated");
    }

    #[tokio::test]
    async fn optional_no_cookie_returns_none() {
        let (ath, _) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/optional", None).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_invalid_cookie_returns_none() {
        // A cookie that is present but does not resolve to a user must still
        // produce a successful response with no user attached.
        let (ath, _) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/optional", Some("allowthem_session=garbage")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert!(body["user"].is_null());
    }

    #[tokio::test]
    async fn optional_valid_session_returns_user() {
        let (ath, cookie_value) = test_setup().await;
        let app = test_app(ath);

        let resp = send_get(app, "/optional", Some(&cookie_value)).await;

        assert_eq!(resp.status(), StatusCode::OK);
        let body = read_body(resp).await;
        assert_eq!(body["user"], "test@example.com");
    }
}