1use axum::{
2 body::Body,
3 extract::FromRequestParts,
4 http::{Request, StatusCode, header, request::Parts},
5 middleware::Next,
6 response::Response,
7};
8use subtle::ConstantTimeEq;
9use uuid::Uuid;
10
11use allowthem_core::{AllowThem, derive_csrf_token, verify_csrf_token};
12
/// Name of the cookie that carries the double-submit CSRF token for requests that have no
/// session cookie; unsafe requests must echo its value back in the header or form body.
const PRE_AUTH_CSRF_COOKIE: &str = "csrf_pre";
14
/// CSRF token associated with the current request.
///
/// `csrf_middleware` stores one in the request extensions for every request it lets through;
/// handlers read it via the `FromRequestParts` extractor, for instance to render it into a
/// form's `csrf_token` field or hand it to scripts for the `x-csrf-token` header.
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrows the token text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
28
29impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
30 type Rejection = StatusCode;
31
32 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
33 parts
34 .extensions
35 .get::<CsrfToken>()
36 .cloned()
37 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
38 }
39}
40
41pub async fn csrf_middleware(
55 mut request: Request<Body>,
56 next: Next,
57) -> Result<Response, StatusCode> {
58 let ath = request
59 .extensions()
60 .get::<AllowThem>()
61 .cloned()
62 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
63
64 let csrf_key = ath
65 .csrf_key()
66 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
67
68 let method = request.method().clone();
69 let is_safe = matches!(
70 method,
71 axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
72 );
73
74 let session_token = ath.parse_session_cookie(
75 request
76 .headers()
77 .get(header::COOKIE)
78 .and_then(|v| v.to_str().ok())
79 .unwrap_or(""),
80 );
81
82 if is_safe {
83 let csrf_token = match &session_token {
84 Some(tok) => derive_csrf_token(tok, csrf_key),
85 None => extract_pre_auth_csrf_cookie(request.headers())
86 .unwrap_or_else(|| Uuid::new_v4().to_string()),
87 };
88
89 let is_new_pre_auth =
90 session_token.is_none() && extract_pre_auth_csrf_cookie(request.headers()).is_none();
91
92 request
93 .extensions_mut()
94 .insert(CsrfToken(csrf_token.clone()));
95
96 let mut response = next.run(request).await;
97
98 if is_new_pre_auth {
99 let secure = ath.session_config().secure;
100 set_pre_auth_csrf_cookie(&mut response, &csrf_token, secure);
101 }
102
103 Ok(response)
104 } else {
105 let submitted = extract_submitted_token(&mut request).await?;
106
107 match &session_token {
108 Some(tok) => {
109 if !verify_csrf_token(tok, csrf_key, &submitted) {
110 return Err(StatusCode::FORBIDDEN);
111 }
112 request.extensions_mut().insert(CsrfToken(submitted));
113 }
114 None => {
115 let cookie_val =
116 extract_pre_auth_csrf_cookie(request.headers()).ok_or(StatusCode::FORBIDDEN)?;
117 if cookie_val.len() != submitted.len() {
118 return Err(StatusCode::FORBIDDEN);
119 }
120 let matches: bool = cookie_val.as_bytes().ct_eq(submitted.as_bytes()).into();
121 if !matches {
122 return Err(StatusCode::FORBIDDEN);
123 }
124 request.extensions_mut().insert(CsrfToken(submitted));
125 }
126 }
127
128 Ok(next.run(request).await)
129 }
130}
131
132fn extract_pre_auth_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
133 let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
134 for pair in cookie_header.split("; ") {
135 if let Some((name, value)) = pair.split_once('=')
136 && name.trim() == PRE_AUTH_CSRF_COOKIE
137 {
138 return Some(value.trim().to_string());
139 }
140 }
141 None
142}
143
144fn set_pre_auth_csrf_cookie(response: &mut Response, token: &str, secure: bool) {
145 let mut cookie = format!(
146 "{}={}; SameSite=Lax; Path=/; Max-Age=1800",
147 PRE_AUTH_CSRF_COOKIE, token
148 );
149 if secure {
150 cookie.push_str("; Secure");
151 }
152 if let Ok(value) = cookie.parse() {
153 response.headers_mut().append(header::SET_COOKIE, value);
154 }
155}
156
157async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
162 if let Some(header_val) = request.headers().get("x-csrf-token")
163 && let Ok(token) = header_val.to_str()
164 {
165 return Ok(token.to_string());
166 }
167
168 let is_form = request
169 .headers()
170 .get(header::CONTENT_TYPE)
171 .and_then(|v| v.to_str().ok())
172 .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
173 .unwrap_or(false);
174
175 if !is_form {
176 return Err(StatusCode::FORBIDDEN);
177 }
178
179 let body = std::mem::replace(request.body_mut(), Body::empty());
180 let bytes = axum::body::to_bytes(body, 64 * 1024)
181 .await
182 .map_err(|_| StatusCode::BAD_REQUEST)?;
183
184 *request.body_mut() = Body::from(bytes.clone());
185
186 let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
187 for pair in body_str.split('&') {
188 if let Some((key, value)) = pair.split_once('=')
189 && key == "csrf_token"
190 {
191 return Ok(value.to_string());
192 }
193 }
194
195 Err(StatusCode::FORBIDDEN)
196}
197
// Tests drive a real `Router` through `tower::ServiceExt::oneshot`, with `csrf_middleware`
// layered inside `crate::cors::inject_ath_into_extensions`, against an in-memory SQLite
// `AllowThem`.
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
    use axum::{Router, middleware, routing::get};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    // Fixed 32-byte key so tests can recompute the expected session-bound CSRF token.
    const TEST_CSRF_KEY: [u8; 32] = *b"test-csrf-key-32bytes-padding!!!";

    /// Handler that always answers `200 OK`; the tests only observe the middleware.
    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Builds an in-memory `AllowThem` with `cookie_secure(false)` and `TEST_CSRF_KEY`.
    async fn build_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    /// Router serving `/` for GET and POST behind `csrf_middleware`.
    ///
    /// The outer `inject_ath_into_extensions` layer (defined in `crate::cors`, not shown
    /// here) is presumably what inserts `AllowThem` into the request extensions that
    /// `csrf_middleware` reads.
    fn test_app(ath: AllowThem) -> Router {
        Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(middleware::from_fn(csrf_middleware))
            .layer(middleware::from_fn_with_state(
                ath.clone(),
                crate::cors::inject_ath_into_extensions,
            ))
    }

    /// Returns the first `Set-Cookie` header value, if any.
    fn get_set_cookie(response: &Response) -> Option<String> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    /// Extracts the cookie value from a `Set-Cookie` string (`name=value; attrs...`).
    /// Panics when the leading `name=value` pair is missing.
    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        set_cookie
            .split(';')
            .next()
            .and_then(|pair| pair.split_once('='))
            .map(|(_, v)| v.trim().to_string())
            .expect("csrf token not found in Set-Cookie")
    }

    /// A cookie-less GET mints a pre-auth token and sets a non-`Secure` `csrf_pre` cookie.
    #[tokio::test]
    async fn pre_auth_get_sets_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_pre="));
        assert!(set_cookie.contains("SameSite=Lax"));
        assert!(set_cookie.contains("Max-Age=1800"));
        assert!(!set_cookie.contains("Secure"));
    }

    /// An existing `csrf_pre` cookie is reused, so no new `Set-Cookie` is emitted.
    #[tokio::test]
    async fn pre_auth_get_does_not_reset_existing_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=existing_value")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    /// Double-submit round trip: the token from the GET's cookie, echoed in the cookie and
    /// the `x-csrf-token` header, is accepted on POST.
    #[tokio::test]
    async fn pre_auth_post_accepts_matching_cookie_and_header() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header("x-csrf-token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    /// A submitted token that differs from the `csrf_pre` cookie is rejected with 403.
    #[tokio::test]
    async fn pre_auth_post_rejects_mismatched_token() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=correct")
                    .header("x-csrf-token", "wrong")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    /// A pre-auth POST without the `csrf_pre` cookie is rejected even if a token is sent.
    #[tokio::test]
    async fn pre_auth_post_rejects_missing_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("x-csrf-token", "sometoken")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    /// The pre-auth token is also accepted from a url-encoded `csrf_token` form field.
    #[tokio::test]
    async fn pre_auth_post_accepts_form_token() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let body = format!("username=alice&csrf_token={token}");
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    /// Creates a user plus a session expiring in 24 hours.
    ///
    /// Returns two values:
    /// * the first `;`-separated segment of `ath.session_cookie(&token)`, presumably the
    ///   session cookie's `name=value` pair;
    /// * the CSRF token a session-bound client must submit, derived from the raw session
    ///   token with `TEST_CSRF_KEY`.
    async fn make_session_cookie(ath: &AllowThem) -> (String, String) {
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, hash, None, None, expires)
            .await
            .unwrap();
        let cookie_header = ath.session_cookie(&token);
        let cookie_value = cookie_header.split(';').next().unwrap().to_string();
        let csrf = derive_csrf_token(&token, &TEST_CSRF_KEY);
        (cookie_value, csrf)
    }

    /// With a session cookie the token is session-derived, so no `csrf_pre` cookie is set.
    #[tokio::test]
    async fn session_bound_get_does_not_set_csrf_pre_cookie() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    /// The session-derived token in `x-csrf-token` passes verification.
    #[tokio::test]
    async fn session_bound_post_accepts_derived_token_in_header() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header("x-csrf-token", &csrf)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A well-formed but wrong token (64 hex-looking chars) is rejected for a session.
    #[tokio::test]
    async fn session_bound_post_rejects_wrong_token() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(
                        "x-csrf-token",
                        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                    )
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    /// The session-derived token is also accepted from a url-encoded form field.
    #[tokio::test]
    async fn session_bound_post_accepts_form_token() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let body = format!("field=value&csrf_token={csrf}");
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Without a configured CSRF key the middleware answers 500, even for a safe GET.
    #[tokio::test]
    async fn returns_500_when_csrf_key_not_configured() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();
        let app = Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(middleware::from_fn(csrf_middleware))
            .layer(middleware::from_fn_with_state(
                ath.clone(),
                crate::cors::inject_ath_into_extensions,
            ))
            .with_state(ath);

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// HEAD is a safe method and passes without any token or cookie.
    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("HEAD")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}