1use axum::{
2 body::Body,
3 extract::{FromRequestParts, State},
4 http::{Request, StatusCode, header, request::Parts},
5 middleware::Next,
6 response::Response,
7};
8use subtle::ConstantTimeEq;
9use uuid::Uuid;
10
11use allowthem_core::{AllowThem, derive_csrf_token, verify_csrf_token};
12
/// Name of the cookie holding the double-submit CSRF token for requests that
/// carry no session cookie (the "pre-auth" path in `csrf_middleware`).
const PRE_AUTH_CSRF_COOKIE: &str = "csrf_pre";
14
/// The CSRF token associated with the current request.
///
/// [`csrf_middleware`] stores one of these in the request extensions, so
/// handlers can extract it (e.g. to embed it in a rendered form as the
/// `csrf_token` field or expose it for an `x-csrf-token` header).
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrows the token text without copying it.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
28
29impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
30 type Rejection = StatusCode;
31
32 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
33 parts
34 .extensions
35 .get::<CsrfToken>()
36 .cloned()
37 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
38 }
39}
40
41pub async fn csrf_middleware(
55 State(ath): State<AllowThem>,
56 mut request: Request<Body>,
57 next: Next,
58) -> Result<Response, StatusCode> {
59 let csrf_key = ath
60 .csrf_key()
61 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
62
63 let method = request.method().clone();
64 let is_safe = matches!(
65 method,
66 axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
67 );
68
69 let session_token = ath.parse_session_cookie(
70 request
71 .headers()
72 .get(header::COOKIE)
73 .and_then(|v| v.to_str().ok())
74 .unwrap_or(""),
75 );
76
77 if is_safe {
78 let csrf_token = match &session_token {
79 Some(tok) => derive_csrf_token(tok, csrf_key),
80 None => extract_pre_auth_csrf_cookie(request.headers())
81 .unwrap_or_else(|| Uuid::new_v4().to_string()),
82 };
83
84 let is_new_pre_auth =
85 session_token.is_none() && extract_pre_auth_csrf_cookie(request.headers()).is_none();
86
87 request
88 .extensions_mut()
89 .insert(CsrfToken(csrf_token.clone()));
90
91 let mut response = next.run(request).await;
92
93 if is_new_pre_auth {
94 let secure = ath.session_config().secure;
95 set_pre_auth_csrf_cookie(&mut response, &csrf_token, secure);
96 }
97
98 Ok(response)
99 } else {
100 let submitted = extract_submitted_token(&mut request).await?;
101
102 match &session_token {
103 Some(tok) => {
104 if !verify_csrf_token(tok, csrf_key, &submitted) {
105 return Err(StatusCode::FORBIDDEN);
106 }
107 request.extensions_mut().insert(CsrfToken(submitted));
108 }
109 None => {
110 let cookie_val =
111 extract_pre_auth_csrf_cookie(request.headers()).ok_or(StatusCode::FORBIDDEN)?;
112 if cookie_val.len() != submitted.len() {
113 return Err(StatusCode::FORBIDDEN);
114 }
115 let matches: bool = cookie_val.as_bytes().ct_eq(submitted.as_bytes()).into();
116 if !matches {
117 return Err(StatusCode::FORBIDDEN);
118 }
119 request.extensions_mut().insert(CsrfToken(submitted));
120 }
121 }
122
123 Ok(next.run(request).await)
124 }
125}
126
127fn extract_pre_auth_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
128 let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
129 for pair in cookie_header.split("; ") {
130 if let Some((name, value)) = pair.split_once('=')
131 && name.trim() == PRE_AUTH_CSRF_COOKIE
132 {
133 return Some(value.trim().to_string());
134 }
135 }
136 None
137}
138
139fn set_pre_auth_csrf_cookie(response: &mut Response, token: &str, secure: bool) {
140 let mut cookie = format!(
141 "{}={}; SameSite=Lax; Path=/; Max-Age=1800",
142 PRE_AUTH_CSRF_COOKIE, token
143 );
144 if secure {
145 cookie.push_str("; Secure");
146 }
147 if let Ok(value) = cookie.parse() {
148 response.headers_mut().append(header::SET_COOKIE, value);
149 }
150}
151
152async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
157 if let Some(header_val) = request.headers().get("x-csrf-token")
158 && let Ok(token) = header_val.to_str()
159 {
160 return Ok(token.to_string());
161 }
162
163 let is_form = request
164 .headers()
165 .get(header::CONTENT_TYPE)
166 .and_then(|v| v.to_str().ok())
167 .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
168 .unwrap_or(false);
169
170 if !is_form {
171 return Err(StatusCode::FORBIDDEN);
172 }
173
174 let body = std::mem::replace(request.body_mut(), Body::empty());
175 let bytes = axum::body::to_bytes(body, 64 * 1024)
176 .await
177 .map_err(|_| StatusCode::BAD_REQUEST)?;
178
179 *request.body_mut() = Body::from(bytes.clone());
180
181 let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
182 for pair in body_str.split('&') {
183 if let Some((key, value)) = pair.split_once('=')
184 && key == "csrf_token"
185 {
186 return Ok(value.to_string());
187 }
188 }
189
190 Err(StatusCode::FORBIDDEN)
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
    use axum::{
        Router, middleware,
        routing::{get, post},
    };
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    const TEST_CSRF_KEY: [u8; 32] = *b"test-csrf-key-32bytes-padding!!!";

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Echoes the raw request body, proving the middleware restored it.
    async fn echo_handler(body: String) -> String {
        body
    }

    async fn build_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    fn test_app(ath: AllowThem) -> Router {
        Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(middleware::from_fn_with_state(ath.clone(), csrf_middleware))
            .with_state(ath)
    }

    fn get_set_cookie(response: &Response) -> Option<String> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        set_cookie
            .split(';')
            .next()
            .and_then(|pair| pair.split_once('='))
            .map(|(_, v)| v.trim().to_string())
            .expect("csrf token not found in Set-Cookie")
    }

    #[tokio::test]
    async fn pre_auth_get_sets_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_pre="));
        assert!(set_cookie.contains("SameSite=Lax"));
        assert!(set_cookie.contains("Max-Age=1800"));
        assert!(!set_cookie.contains("Secure"));
    }

    #[tokio::test]
    async fn pre_auth_get_does_not_reset_existing_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=existing_value")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_matching_cookie_and_header() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header("x-csrf-token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_mismatched_token() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=correct")
                    .header("x-csrf-token", "wrong")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_same_length_mismatched_token() {
        // The tokens in `pre_auth_post_rejects_mismatched_token` differ in
        // length and never reach the constant-time comparison; equal lengths
        // force the `ct_eq` path.
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=aaaaaaaa")
                    .header("x-csrf-token", "aaaaaaab")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_missing_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("x-csrf-token", "sometoken")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_form_token() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let body = format!("username=alice&csrf_token={token}");
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_form_without_token_field() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=sometoken")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from("username=alice"))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn form_body_is_still_readable_by_handler() {
        let ath = build_ath().await;
        let app = Router::new()
            .route("/echo", post(echo_handler))
            .layer(middleware::from_fn_with_state(ath.clone(), csrf_middleware))
            .with_state(ath);
        let body = "username=alice&csrf_token=tok123";
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/echo")
                    .header(header::COOKIE, "csrf_pre=tok123")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = axum::body::to_bytes(response.into_body(), 64 * 1024)
            .await
            .unwrap();
        assert_eq!(&echoed[..], body.as_bytes());
    }

    async fn make_session_cookie(ath: &AllowThem) -> (String, String) {
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, hash, None, None, expires)
            .await
            .unwrap();
        let cookie_header = ath.session_cookie(&token);
        let cookie_value = cookie_header.split(';').next().unwrap().to_string();
        let csrf = derive_csrf_token(&token, &TEST_CSRF_KEY);
        (cookie_value, csrf)
    }

    #[tokio::test]
    async fn session_bound_get_does_not_set_csrf_pre_cookie() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn session_bound_post_accepts_derived_token_in_header() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header("x-csrf-token", &csrf)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn session_bound_post_rejects_wrong_token() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(
                        "x-csrf-token",
                        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                    )
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn session_bound_post_without_any_token_is_forbidden() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn session_bound_post_accepts_form_token() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let body = format!("field=value&csrf_token={csrf}");
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn returns_500_when_csrf_key_not_configured() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();
        let app = test_app(ath);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("HEAD")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}