1use axum::{
2 body::Body,
3 extract::FromRequestParts,
4 http::{Request, StatusCode, header, request::Parts},
5 middleware::Next,
6 response::Response,
7};
8use uuid::Uuid;
9
/// Name of the cookie that carries the double-submit CSRF token.
const CSRF_COOKIE_NAME: &str = "csrf_token";

/// CSRF token associated with the current request.
///
/// [`csrf_middleware`] stores one of these in the request extensions; handlers obtain it
/// through the `FromRequestParts` implementation.
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrows the token text, e.g. for rendering it into a `csrf_token` form field.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
25
26impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
27 type Rejection = StatusCode;
28
29 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
30 parts
31 .extensions
32 .get::<CsrfToken>()
33 .cloned()
34 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
35 }
36}
37
38pub async fn csrf_middleware(
48 mut request: Request<Body>,
49 next: Next,
50) -> Result<Response, StatusCode> {
51 let method = request.method().clone();
52 let is_safe = matches!(
53 method,
54 axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
55 );
56
57 let cookie_token = extract_csrf_cookie(request.headers());
58
59 if is_safe {
60 let is_new = cookie_token.is_none();
61 let token = cookie_token.unwrap_or_else(|| Uuid::new_v4().to_string());
62
63 request.extensions_mut().insert(CsrfToken(token.clone()));
64
65 let mut response = next.run(request).await;
66
67 if is_new {
68 let cookie = format!("{}={}; SameSite=Lax; Path=/", CSRF_COOKIE_NAME, token);
69 if let Ok(value) = cookie.parse() {
70 response.headers_mut().append(header::SET_COOKIE, value);
71 }
72 }
73
74 Ok(response)
75 } else {
76 let submitted = extract_submitted_token(&mut request).await?;
77
78 let cookie_val = cookie_token.ok_or(StatusCode::FORBIDDEN)?;
79
80 if submitted != cookie_val {
81 return Err(StatusCode::FORBIDDEN);
82 }
83
84 request.extensions_mut().insert(CsrfToken(cookie_val));
85
86 Ok(next.run(request).await)
87 }
88}
89
90fn extract_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
92 let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
93 for pair in cookie_header.split("; ") {
94 if let Some((name, value)) = pair.split_once('=')
95 && name.trim() == CSRF_COOKIE_NAME
96 {
97 return Some(value.trim().to_string());
98 }
99 }
100 None
101}
102
103async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
108 if let Some(header_val) = request.headers().get("x-csrf-token")
110 && let Ok(token) = header_val.to_str()
111 {
112 return Ok(token.to_string());
113 }
114
115 let is_form = request
117 .headers()
118 .get(header::CONTENT_TYPE)
119 .and_then(|v| v.to_str().ok())
120 .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
121 .unwrap_or(false);
122
123 if !is_form {
124 return Err(StatusCode::FORBIDDEN);
125 }
126
127 let body = std::mem::replace(request.body_mut(), Body::empty());
129 let bytes = axum::body::to_bytes(body, 64 * 1024)
130 .await
131 .map_err(|_| StatusCode::BAD_REQUEST)?;
132
133 *request.body_mut() = Body::from(bytes.clone());
135
136 let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
138 for pair in body_str.split('&') {
139 if let Some((key, value)) = pair.split_once('=')
140 && key == "csrf_token"
141 {
142 return Ok(value.to_string());
143 }
144 }
145
146 Err(StatusCode::FORBIDDEN)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, middleware, routing::get};
    use tower::ServiceExt;

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Returns the token the middleware exposed through the `CsrfToken` extractor.
    async fn echo_token(token: CsrfToken) -> String {
        token.as_str().to_owned()
    }

    /// Returns the request body unchanged, proving the middleware restored it after reading.
    async fn echo_body(body: String) -> String {
        body
    }

    /// `/` answers `200 OK` to GET and POST, behind the CSRF middleware.
    fn test_app() -> Router {
        Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(middleware::from_fn(csrf_middleware))
    }

    /// `/` echoes the CSRF token on GET and the raw body on POST, behind the CSRF middleware.
    fn echo_app() -> Router {
        Router::new()
            .route("/", get(echo_token).post(echo_body))
            .layer(middleware::from_fn(csrf_middleware))
    }

    /// `/` answers `200 OK` to any method, behind the CSRF middleware.
    fn any_method_app() -> Router {
        Router::new()
            .route("/", axum::routing::any(ok_handler))
            .layer(middleware::from_fn(csrf_middleware))
    }

    fn get_set_cookie(response: &Response) -> Option<String> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        set_cookie
            .split(';')
            .next()
            .and_then(|pair| pair.split_once('='))
            .map(|(_, v)| v.trim().to_string())
            .expect("csrf token not found in Set-Cookie")
    }

    /// Issues a cookie-less GET and returns the token minted in the `Set-Cookie` response.
    async fn fetch_token(app: Router) -> String {
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie missing");
        extract_token_from_set_cookie(&set_cookie)
    }

    /// Collects a response body into a `String`.
    async fn body_text(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("failed to read response body");
        String::from_utf8(bytes.to_vec()).expect("response body is not UTF-8")
    }

    #[tokio::test]
    async fn get_sets_csrf_cookie() {
        let app = test_app();
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_token="));
        assert!(set_cookie.contains("SameSite=Lax"));
    }

    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let response = any_method_app()
            .oneshot(
                Request::builder()
                    .method("HEAD")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn options_does_not_require_csrf() {
        let response = any_method_app()
            .oneshot(
                Request::builder()
                    .method("OPTIONS")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_valid_header_token_passes() {
        let app = test_app();
        let token = fetch_token(app.clone()).await;

        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_token={token}"))
                    .header("x-csrf-token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_valid_form_token_passes() {
        let app = test_app();
        let token = fetch_token(app.clone()).await;

        let body = format!("username=alice&csrf_token={token}");
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_token={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_missing_token_returns_403() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=someval")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from("username=alice"))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_wrong_token_returns_403() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=correct")
                    .header("x-csrf-token", "wrong")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_wrong_form_token_returns_403() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=correct")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from("username=alice&csrf_token=wrong"))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_non_form_body_and_no_header_returns_403() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=correct")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"csrf_token":"correct"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_missing_cookie_returns_403() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("x-csrf-token", "sometoken")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn existing_cookie_not_overwritten_on_get() {
        let response = test_app()
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=existing_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn handler_receives_existing_cookie_token() {
        let response = echo_app()
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=existing_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_text(response).await, "existing_token");
    }

    #[tokio::test]
    async fn handler_receives_freshly_minted_token() {
        let response = echo_app()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie missing");
        let minted = extract_token_from_set_cookie(&set_cookie);
        assert_eq!(body_text(response).await, minted);
    }

    #[tokio::test]
    async fn form_body_is_restored_for_handler() {
        let app = echo_app();
        let token = fetch_token(app.clone()).await;
        let body = format!("username=alice&csrf_token={token}");

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_token={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body.clone()))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_text(response).await, body);
    }
}