1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use axum_extra::extract::cookie::Key;
8use cookie::{Cookie, CookieJar, SameSite};
9use http::{HeaderValue, Method, Request, Response};
10use serde::Deserialize;
11use tower::{Layer, Service};
12
13#[non_exhaustive]
19#[derive(Debug, Clone, Deserialize)]
20#[serde(default)]
21pub struct CsrfConfig {
22 pub cookie_name: String,
24 pub header_name: String,
26 pub field_name: String,
30 pub ttl_secs: u64,
32 pub exempt_methods: Vec<String>,
34}
35
36impl Default for CsrfConfig {
37 fn default() -> Self {
38 Self {
39 cookie_name: "_csrf".to_string(),
40 header_name: "X-CSRF-Token".to_string(),
41 field_name: "_csrf_token".to_string(),
42 ttl_secs: 21600,
43 exempt_methods: vec!["GET", "HEAD", "OPTIONS"]
44 .into_iter()
45 .map(String::from)
46 .collect(),
47 }
48 }
49}
50
/// The raw (unsigned) CSRF token value.
///
/// The middleware stores this in request and response extensions when it
/// issues a token, so downstream code can read it from there.
#[derive(Clone, Debug)]
pub struct CsrfToken(
    /// The token string itself.
    pub String,
);
55
56#[derive(Clone)]
59pub struct CsrfLayer {
60 config: CsrfConfig,
61 key: Key,
62}
63
64impl<S> Layer<S> for CsrfLayer {
65 type Service = CsrfService<S>;
66
67 fn layer(&self, inner: S) -> Self::Service {
68 CsrfService {
69 inner,
70 config: self.config.clone(),
71 key: self.key.clone(),
72 }
73 }
74}
75
76#[derive(Clone)]
86pub struct CsrfService<S> {
87 inner: S,
88 config: CsrfConfig,
89 key: Key,
90}
91
92impl<S> CsrfService<S> {
93 fn sign_token(&self, token: &str) -> String {
95 let mut jar = CookieJar::new();
96 jar.signed_mut(&self.key).add(Cookie::new(
97 self.config.cookie_name.clone(),
98 token.to_string(),
99 ));
100 jar.get(&self.config.cookie_name)
101 .expect("cookie was just added")
102 .value()
103 .to_string()
104 }
105
106 fn verify_token(&self, signed_value: &str) -> Option<String> {
108 let mut jar = CookieJar::new();
109 jar.add_original(Cookie::new(
110 self.config.cookie_name.clone(),
111 signed_value.to_string(),
112 ));
113 jar.signed(&self.key)
114 .get(&self.config.cookie_name)
115 .map(|c: Cookie<'_>| c.value().to_string())
116 }
117
118 fn build_set_cookie(&self, signed_value: &str) -> String {
120 Cookie::build((self.config.cookie_name.clone(), signed_value.to_string()))
121 .http_only(true)
122 .same_site(SameSite::Lax)
123 .path("/")
124 .max_age(cookie::time::Duration::seconds(self.config.ttl_secs as i64))
125 .build()
126 .to_string()
127 }
128
129 fn is_exempt(&self, method: &Method) -> bool {
131 self.config
132 .exempt_methods
133 .iter()
134 .any(|m| m.eq_ignore_ascii_case(method.as_str()))
135 }
136
137 fn extract_submitted_token<B>(&self, request: &Request<B>) -> Option<String> {
139 request
140 .headers()
141 .get(&self.config.header_name)
142 .and_then(|v| v.to_str().ok())
143 .map(|s| s.to_string())
144 }
145
146 fn extract_cookie_value<B>(&self, request: &Request<B>) -> Option<String> {
148 let cookie_header = request.headers().get(http::header::COOKIE)?;
149 let cookie_str = cookie_header.to_str().ok()?;
150
151 for pair in cookie_str.split(';') {
152 let pair = pair.trim();
153 if let Some((name, value)) = pair.split_once('=')
154 && name.trim() == self.config.cookie_name
155 {
156 return Some(value.trim().to_string());
157 }
158 }
159
160 None
161 }
162}
163
164impl<S, ReqBody> Service<Request<ReqBody>> for CsrfService<S>
165where
166 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
167 S::Future: Send + 'static,
168 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
169 ReqBody: Send + 'static,
170{
171 type Response = Response<Body>;
172 type Error = S::Error;
173 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
174
175 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176 self.inner.poll_ready(cx)
177 }
178
179 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
180 let mut inner = self.inner.clone();
182 std::mem::swap(&mut self.inner, &mut inner);
183
184 let is_exempt = self.is_exempt(request.method());
185
186 if is_exempt {
187 let token = crate::id::ulid();
189 let signed_value = self.sign_token(&token);
190 let set_cookie_value = self.build_set_cookie(&signed_value);
191
192 request.extensions_mut().insert(CsrfToken(token.clone()));
193
194 Box::pin(async move {
195 let mut response = inner.call(request).await?;
196
197 if let Ok(header_value) = HeaderValue::from_str(&set_cookie_value) {
198 response
199 .headers_mut()
200 .append(http::header::SET_COOKIE, header_value);
201 }
202
203 response.extensions_mut().insert(CsrfToken(token));
204
205 Ok(response)
206 })
207 } else {
208 let cookie_value = self.extract_cookie_value(&request);
210 let submitted_token = self.extract_submitted_token(&request);
211
212 let verified = cookie_value
213 .and_then(|signed| self.verify_token(&signed))
214 .zip(submitted_token)
215 .is_some_and(|(cookie_token, header_token)| {
216 use subtle::ConstantTimeEq;
217 cookie_token
218 .as_bytes()
219 .ct_eq(header_token.as_bytes())
220 .into()
221 });
222
223 if verified {
224 Box::pin(async move { inner.call(request).await })
225 } else {
226 let header_name = self.config.header_name.clone();
227 Box::pin(async move {
228 let error = crate::error::Error::forbidden(format!(
229 "CSRF validation failed: missing or invalid {header_name}"
230 ));
231 Ok(error.into_response())
232 })
233 }
234 }
235 }
236}
237
238pub fn csrf(config: &CsrfConfig, key: &Key) -> CsrfLayer {
252 CsrfLayer {
253 config: config.clone(),
254 key: key.clone(),
255 }
256}