1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use axum_extra::extract::cookie::Key;
8use cookie::{Cookie, CookieJar, SameSite};
9use http::{HeaderValue, Method, Request, Response};
10use serde::Deserialize;
11use tower::{Layer, Service};
12
13#[non_exhaustive]
19#[derive(Debug, Clone, Deserialize)]
20#[serde(default)]
21pub struct CsrfConfig {
22 pub cookie_name: String,
24 pub header_name: String,
26 pub field_name: String,
30 pub ttl_secs: u64,
32 pub exempt_methods: Vec<String>,
34}
35
36impl Default for CsrfConfig {
37 fn default() -> Self {
38 Self {
39 cookie_name: "_csrf".to_string(),
40 header_name: "X-CSRF-Token".to_string(),
41 field_name: "_csrf_token".to_string(),
42 ttl_secs: 21600,
43 exempt_methods: vec!["GET", "HEAD", "OPTIONS"]
44 .into_iter()
45 .map(String::from)
46 .collect(),
47 }
48 }
49}
50
/// The plain (unsigned) CSRF token associated with the current request.
///
/// [`CsrfService`] inserts it into request and response extensions.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
55
56#[derive(Clone)]
59pub struct CsrfLayer {
60 config: CsrfConfig,
61 key: Key,
62}
63
64impl<S> Layer<S> for CsrfLayer {
65 type Service = CsrfService<S>;
66
67 fn layer(&self, inner: S) -> Self::Service {
68 CsrfService {
69 inner,
70 config: self.config.clone(),
71 key: self.key.clone(),
72 }
73 }
74}
75
76#[derive(Clone)]
89pub struct CsrfService<S> {
90 inner: S,
91 config: CsrfConfig,
92 key: Key,
93}
94
95impl<S> CsrfService<S> {
96 fn sign_token(&self, token: &str) -> String {
98 let mut jar = CookieJar::new();
99 jar.signed_mut(&self.key).add(Cookie::new(
100 self.config.cookie_name.clone(),
101 token.to_string(),
102 ));
103 jar.get(&self.config.cookie_name)
104 .expect("cookie was just added")
105 .value()
106 .to_string()
107 }
108
109 fn verify_token(&self, signed_value: &str) -> Option<String> {
111 let mut jar = CookieJar::new();
112 jar.add_original(Cookie::new(
113 self.config.cookie_name.clone(),
114 signed_value.to_string(),
115 ));
116 jar.signed(&self.key)
117 .get(&self.config.cookie_name)
118 .map(|c: Cookie<'_>| c.value().to_string())
119 }
120
121 fn build_set_cookie(&self, signed_value: &str) -> String {
123 Cookie::build((self.config.cookie_name.clone(), signed_value.to_string()))
124 .http_only(true)
125 .same_site(SameSite::Lax)
126 .path("/")
127 .max_age(cookie::time::Duration::seconds(self.config.ttl_secs as i64))
128 .build()
129 .to_string()
130 }
131
132 fn is_exempt(&self, method: &Method) -> bool {
134 self.config
135 .exempt_methods
136 .iter()
137 .any(|m| m.eq_ignore_ascii_case(method.as_str()))
138 }
139
140 fn extract_submitted_token<B>(&self, request: &Request<B>) -> Option<String> {
142 request
143 .headers()
144 .get(&self.config.header_name)
145 .and_then(|v| v.to_str().ok())
146 .map(|s| s.to_string())
147 }
148
149 fn extract_cookie_value<B>(&self, request: &Request<B>) -> Option<String> {
151 let cookie_header = request.headers().get(http::header::COOKIE)?;
152 let cookie_str = cookie_header.to_str().ok()?;
153
154 for pair in cookie_str.split(';') {
155 let pair = pair.trim();
156 if let Some((name, value)) = pair.split_once('=')
157 && name.trim() == self.config.cookie_name
158 {
159 return Some(value.trim().to_string());
160 }
161 }
162
163 None
164 }
165}
166
167impl<S, ReqBody> Service<Request<ReqBody>> for CsrfService<S>
168where
169 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
170 S::Future: Send + 'static,
171 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
172 ReqBody: Send + 'static,
173{
174 type Response = Response<Body>;
175 type Error = S::Error;
176 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
177
178 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
179 self.inner.poll_ready(cx)
180 }
181
182 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
183 let mut inner = self.inner.clone();
185 std::mem::swap(&mut self.inner, &mut inner);
186
187 let is_exempt = self.is_exempt(request.method());
188
189 if is_exempt {
190 let existing = self
195 .extract_cookie_value(&request)
196 .and_then(|signed| self.verify_token(&signed));
197
198 let (token, set_cookie_value) = match existing {
199 Some(t) => (t, None),
200 None => {
201 let t = crate::id::ulid();
202 let signed = self.sign_token(&t);
203 let sc = self.build_set_cookie(&signed);
204 (t, Some(sc))
205 }
206 };
207
208 request.extensions_mut().insert(CsrfToken(token.clone()));
209
210 Box::pin(async move {
211 let mut response = inner.call(request).await?;
212
213 if let Some(sc) = set_cookie_value
214 && let Ok(header_value) = HeaderValue::from_str(&sc)
215 {
216 response
217 .headers_mut()
218 .append(http::header::SET_COOKIE, header_value);
219 }
220
221 response.extensions_mut().insert(CsrfToken(token));
222
223 Ok(response)
224 })
225 } else {
226 let cookie_value = self.extract_cookie_value(&request);
228 let submitted_token = self.extract_submitted_token(&request);
229
230 let verified = cookie_value
231 .and_then(|signed| self.verify_token(&signed))
232 .zip(submitted_token)
233 .is_some_and(|(cookie_token, header_token)| {
234 use subtle::ConstantTimeEq;
235 cookie_token
236 .as_bytes()
237 .ct_eq(header_token.as_bytes())
238 .into()
239 });
240
241 if verified {
242 Box::pin(async move { inner.call(request).await })
243 } else {
244 let header_name = self.config.header_name.clone();
245 Box::pin(async move {
246 let error = crate::error::Error::forbidden(format!(
247 "CSRF validation failed: missing or invalid {header_name}"
248 ));
249 Ok(error.into_response())
250 })
251 }
252 }
253 }
254}
255
256pub fn csrf(config: &CsrfConfig, key: &Key) -> CsrfLayer {
270 CsrfLayer {
271 config: config.clone(),
272 key: key.clone(),
273 }
274}