1use std::future::{ready, Future, Ready};
4use std::ops::{Deref, DerefMut};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use crate::{
9 host_prefix, secure_prefix, CsrfError, DEFAULT_CSRF_COOKIE_NAME, DEFAULT_CSRF_TOKEN_NAME,
10};
11
12use actix_web::dev::Payload;
13use actix_web::http::header::HeaderName;
14use actix_web::{FromRequest, HttpMessage, HttpRequest};
15use serde::de::{Error, Visitor};
16use serde::{Deserialize, Serialize};
17
18#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
20pub struct CsrfHeader(CsrfToken);
21
22impl CsrfHeader {
23 pub fn validate(&self, header_value: impl AsRef<str>) -> bool {
25 self.0.as_ref() == header_value.as_ref()
26 }
27}
28
29impl CsrfGuarded for CsrfHeader {
30 fn csrf_token(&self) -> &CsrfToken {
31 &self.0
32 }
33}
34
35impl FromRequest for CsrfHeader {
36 type Error = CsrfError;
37 type Future = Ready<Result<Self, Self::Error>>;
38
39 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
40 let header_name = req
41 .app_data::<CsrfHeaderConfig>()
42 .map_or(DEFAULT_CSRF_TOKEN_NAME, |v| v.header_name.as_ref());
43
44 let resp = req
45 .headers()
46 .get(header_name)
47 .map_or(Err(CsrfError::MissingCookie), |header| {
48 match header.to_str() {
49 Ok(header) => Ok(Self(CsrfToken(header.to_owned()))),
50 Err(_) => Err(CsrfError::MissingToken),
51 }
52 });
53
54 ready(resp)
55 }
56}
57
58impl AsRef<str> for CsrfHeader {
59 fn as_ref(&self) -> &str {
60 self.0.as_ref()
61 }
62}
63
64#[derive(Clone, Eq, PartialEq, Hash, Debug)]
66pub struct CsrfHeaderConfig {
67 header_name: HeaderName,
68}
69
70impl Default for CsrfHeaderConfig {
71 fn default() -> Self {
72 Self {
73 header_name: HeaderName::from_static(DEFAULT_CSRF_TOKEN_NAME),
74 }
75 }
76}
77
78impl CsrfHeaderConfig {
79 pub const fn new(header_name: HeaderName) -> Self {
81 Self { header_name }
82 }
83}
84
85#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
87pub struct CsrfCookie(String);
88
89impl CsrfCookie {
90 pub fn validate(&self, token: impl AsRef<str>) -> bool {
92 self.0 == token.as_ref()
93 }
94
95 fn from_request_sync(req: &HttpRequest) -> Result<Self, CsrfError> {
96 let cookie_name = req
97 .app_data::<CsrfCookieConfig>()
98 .map_or(DEFAULT_CSRF_COOKIE_NAME, |v| v.cookie_name.as_ref());
99
100 req.cookie(cookie_name)
101 .ok_or(CsrfError::MissingCookie)
102 .map(|cookie| Self(cookie.value().to_string()))
103 }
104}
105
106impl FromRequest for CsrfCookie {
107 type Error = CsrfError;
108 type Future = Ready<Result<Self, Self::Error>>;
109
110 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
111 ready(Self::from_request_sync(req))
112 }
113}
114
115impl AsRef<str> for CsrfCookie {
116 fn as_ref(&self) -> &str {
117 self.0.as_ref()
118 }
119}
120
121#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
123pub struct CsrfCookieConfig {
124 cookie_name: String,
125}
126
127impl Default for CsrfCookieConfig {
128 fn default() -> Self {
129 Self {
130 cookie_name: DEFAULT_CSRF_COOKIE_NAME.to_string(),
131 }
132 }
133}
134
135impl CsrfCookieConfig {
136 #[must_use]
142 pub const fn new(cookie_name: String) -> Self {
143 Self { cookie_name }
144 }
145
146 #[must_use]
151 pub fn with_host_prefix(cookie_name: String) -> Self {
152 Self::with_prefix(host_prefix!(), cookie_name)
153 }
154
155 #[must_use]
159 pub fn with_secure_prefix(cookie_name: String) -> Self {
160 Self::with_prefix(secure_prefix!(), cookie_name)
161 }
162
163 fn with_prefix(prefix: &'static str, cookie_name: String) -> Self {
164 if cookie_name.starts_with(prefix) {
165 Self { cookie_name }
166 } else {
167 Self {
168 cookie_name: format!("{}{}", prefix, cookie_name),
169 }
170 }
171 }
172}
173
174#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
176pub struct CsrfToken(pub(crate) String);
177
178impl Serialize for CsrfToken {
179 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
180 where
181 S: serde::Serializer,
182 {
183 serializer.serialize_newtype_struct("Csrf Token", &self.0)
184 }
185}
186
187impl<'de> Deserialize<'de> for CsrfToken {
188 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
189 where
190 D: serde::Deserializer<'de>,
191 {
192 struct CsrfTokenVisitor;
193 impl<'de> Visitor<'de> for CsrfTokenVisitor {
194 type Value = CsrfToken;
195
196 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
197 formatter.write_str("a valid csrf token")
198 }
199
200 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
201 where
202 E: Error,
203 {
204 Ok(CsrfToken(v.to_owned()))
205 }
206 }
207
208 deserializer.deserialize_string(CsrfTokenVisitor)
209 }
210}
211
212impl CsrfToken {
213 #[must_use]
216 #[doc(hidden)]
217 pub const fn test_create(value: String) -> Self {
218 Self(value)
219 }
220
221 #[must_use]
223 pub fn get(&self) -> &str {
224 self.0.as_ref()
225 }
226
227 #[must_use]
229 #[allow(clippy::missing_const_for_fn)] pub fn into_inner(self) -> String {
231 self.0
232 }
233
234 fn from_request_sync(req: &HttpRequest) -> Result<Self, CsrfError> {
235 req.extensions()
236 .get::<Self>()
237 .cloned()
238 .ok_or(CsrfError::MissingToken)
239 }
240}
241
242impl AsRef<str> for CsrfToken {
243 fn as_ref(&self) -> &str {
244 self.0.as_ref()
245 }
246}
247
248impl FromRequest for CsrfToken {
249 type Error = CsrfError;
250 type Future = Ready<Result<Self, Self::Error>>;
251
252 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
253 ready(Self::from_request_sync(req))
254 }
255}
256
257#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
287pub struct Csrf<Inner>(Inner);
288
289impl<Inner> Csrf<Inner> {
290 #[must_use]
292 pub fn into_inner(self) -> Inner {
293 self.0
294 }
295}
296
297impl<Inner> Deref for Csrf<Inner> {
298 type Target = Inner;
299
300 fn deref(&self) -> &Self::Target {
301 &self.0
302 }
303}
304
305impl<Inner> DerefMut for Csrf<Inner> {
306 fn deref_mut(&mut self) -> &mut Self::Target {
307 &mut self.0
308 }
309}
310
311impl<Inner> FromRequest for Csrf<Inner>
312where
313 Inner: FromRequest + CsrfGuarded,
314{
315 type Error = CsrfExtractorError<Inner::Error>;
316 type Future = CsrfExtractorFuture<Inner::Future>;
317
318 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
319 CsrfExtractorFuture {
320 csrf_token: CsrfCookie::from_request_sync(req),
321 inner: Box::pin(Inner::from_request(req, payload)),
322 }
323 }
324}
325
326macro_rules! derive_csrf_guarded {
327 ($type:path) => {
328 impl<T> CsrfGuarded for $type
329 where
330 T: CsrfGuarded,
331 {
332 fn csrf_token(&self) -> &CsrfToken {
333 self.0.csrf_token()
334 }
335 }
336 };
337}
338
339derive_csrf_guarded!(actix_web::web::Form<T>);
340derive_csrf_guarded!(actix_web::web::Json<T>);
341
342pub struct CsrfExtractorFuture<Fut> {
346 csrf_token: Result<CsrfCookie, CsrfError>,
347 inner: Pin<Box<Fut>>,
348}
349
350impl<Fut, FutOut, FutErr> Future for CsrfExtractorFuture<Fut>
351where
352 Fut: Future<Output = Result<FutOut, FutErr>>,
353 FutOut: CsrfGuarded,
354{
355 type Output = Result<Csrf<FutOut>, CsrfExtractorError<FutErr>>;
356
357 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358 match self.inner.as_mut().poll(cx) {
359 Poll::Ready(Ok(out)) => {
360 if let Ok(ref token) = self.csrf_token {
361 if out.csrf_token().as_ref() == token.as_ref() {
362 return Poll::Ready(Ok(Csrf(out)));
363 }
364 }
365
366 Poll::Ready(Err(CsrfExtractorError::InvalidToken))
367 }
368 Poll::Ready(Err(e)) => Poll::Ready(Err(CsrfExtractorError::Inner(e))),
369 Poll::Pending => Poll::Pending,
370 }
371 }
372}
373
374pub trait CsrfGuarded {
379 fn csrf_token(&self) -> &CsrfToken;
381}
382
383#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
385pub enum CsrfExtractorError<Inner> {
386 InvalidToken,
388 Inner(Inner),
390}
391
392impl<Inner> From<CsrfExtractorError<Inner>> for actix_web::error::Error
393where
394 Inner: Into<Self>,
395{
396 fn from(e: CsrfExtractorError<Inner>) -> Self {
397 match e {
398 CsrfExtractorError::InvalidToken => CsrfError::TokenMismatch.into(),
399 CsrfExtractorError::Inner(e) => e.into(),
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    use std::error::Error;

    use actix_web::http::header;
    use actix_web::test::TestRequest;

    use super::*;
    use crate::DEFAULT_CSRF_COOKIE_NAME;

    // The token header is present under its default name, so extraction
    // succeeds and the value round-trips.
    #[tokio::test]
    async fn extract_from_header() -> Result<(), Box<dyn Error>> {
        let request = TestRequest::default()
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, "sometoken"))
            .to_http_request();

        let extracted = CsrfHeader::extract(&request).await?;
        assert!(extracted.validate("sometoken"));
        Ok(())
    }

    // Only an unrelated header is set, so header extraction must fail.
    #[tokio::test]
    async fn not_found_header() {
        let request = TestRequest::default()
            .insert_header(("fake", "sometoken"))
            .to_http_request();

        assert!(CsrfHeader::extract(&request).await.is_err());
    }

    // The CSRF cookie is present under its default name, so extraction
    // succeeds and the value round-trips.
    #[tokio::test]
    async fn extract_from_cookie() -> Result<(), Box<dyn Error>> {
        let cookie_header = format!("{DEFAULT_CSRF_COOKIE_NAME}=sometoken");
        let request = TestRequest::default()
            .insert_header((header::COOKIE, cookie_header))
            .to_http_request();

        let extracted = CsrfCookie::extract(&request).await?;
        assert!(extracted.validate("sometoken"));
        Ok(())
    }

    // No cookie header at all, so cookie extraction must fail.
    #[tokio::test]
    async fn not_found_cookie() {
        let request = TestRequest::default()
            .insert_header(("fake", "sometoken"))
            .to_http_request();

        assert!(CsrfCookie::extract(&request).await.is_err());
    }
}