actix_csrf/
extractor.rs

1//! Contains various extractors related to CSRF tokens.
2
3use std::future::{ready, Future, Ready};
4use std::ops::{Deref, DerefMut};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use crate::{
9    host_prefix, secure_prefix, CsrfError, DEFAULT_CSRF_COOKIE_NAME, DEFAULT_CSRF_TOKEN_NAME,
10};
11
12use actix_web::dev::Payload;
13use actix_web::http::header::HeaderName;
14use actix_web::{FromRequest, HttpMessage, HttpRequest};
15use serde::de::{Error, Visitor};
16use serde::{Deserialize, Serialize};
17
18/// Extractor to get the CSRF header from the request.
19#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
20pub struct CsrfHeader(CsrfToken);
21
22impl CsrfHeader {
23    /// Checks if the header matches the CSRF header.
24    pub fn validate(&self, header_value: impl AsRef<str>) -> bool {
25        self.0.as_ref() == header_value.as_ref()
26    }
27}
28
29impl CsrfGuarded for CsrfHeader {
30    fn csrf_token(&self) -> &CsrfToken {
31        &self.0
32    }
33}
34
35impl FromRequest for CsrfHeader {
36    type Error = CsrfError;
37    type Future = Ready<Result<Self, Self::Error>>;
38
39    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
40        let header_name = req
41            .app_data::<CsrfHeaderConfig>()
42            .map_or(DEFAULT_CSRF_TOKEN_NAME, |v| v.header_name.as_ref());
43
44        let resp = req
45            .headers()
46            .get(header_name)
47            .map_or(Err(CsrfError::MissingCookie), |header| {
48                match header.to_str() {
49                    Ok(header) => Ok(Self(CsrfToken(header.to_owned()))),
50                    Err(_) => Err(CsrfError::MissingToken),
51                }
52            });
53
54        ready(resp)
55    }
56}
57
58impl AsRef<str> for CsrfHeader {
59    fn as_ref(&self) -> &str {
60        self.0.as_ref()
61    }
62}
63
64/// Configuration struct for [`CsrfHeader`].
65#[derive(Clone, Eq, PartialEq, Hash, Debug)]
66pub struct CsrfHeaderConfig {
67    header_name: HeaderName,
68}
69
70impl Default for CsrfHeaderConfig {
71    fn default() -> Self {
72        Self {
73            header_name: HeaderName::from_static(DEFAULT_CSRF_TOKEN_NAME),
74        }
75    }
76}
77
78impl CsrfHeaderConfig {
79    /// Sets the header name to read the CSRF token from.
80    pub const fn new(header_name: HeaderName) -> Self {
81        Self { header_name }
82    }
83}
84
85/// Extractor to get the CSRF cookie from the request.
86#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
87pub struct CsrfCookie(String);
88
89impl CsrfCookie {
90    /// Checks if the input matches the cookie.
91    pub fn validate(&self, token: impl AsRef<str>) -> bool {
92        self.0 == token.as_ref()
93    }
94
95    fn from_request_sync(req: &HttpRequest) -> Result<Self, CsrfError> {
96        let cookie_name = req
97            .app_data::<CsrfCookieConfig>()
98            .map_or(DEFAULT_CSRF_COOKIE_NAME, |v| v.cookie_name.as_ref());
99
100        req.cookie(cookie_name)
101            .ok_or(CsrfError::MissingCookie)
102            .map(|cookie| Self(cookie.value().to_string()))
103    }
104}
105
106impl FromRequest for CsrfCookie {
107    type Error = CsrfError;
108    type Future = Ready<Result<Self, Self::Error>>;
109
110    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
111        ready(Self::from_request_sync(req))
112    }
113}
114
115impl AsRef<str> for CsrfCookie {
116    fn as_ref(&self) -> &str {
117        self.0.as_ref()
118    }
119}
120
121/// Configuration struct for [`CsrfCookie`].
122#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
123pub struct CsrfCookieConfig {
124    cookie_name: String,
125}
126
127impl Default for CsrfCookieConfig {
128    fn default() -> Self {
129        Self {
130            cookie_name: DEFAULT_CSRF_COOKIE_NAME.to_string(),
131        }
132    }
133}
134
135impl CsrfCookieConfig {
136    /// Sets the cookie name. Consider using [`with_host_prefix`][1] or
137    /// [`with_secure_prefix`][2] if possible for increased security.
138    ///
139    /// [1]: Self::with_host_prefix
140    /// [2]: Self::with_secure_prefix
141    #[must_use]
142    pub const fn new(cookie_name: String) -> Self {
143        Self { cookie_name }
144    }
145
146    /// Sets the cookie name, prefixing it with `__Host-` if it wasn't already
147    /// prefixed. Note that this requires the cookie to be served with the
148    /// `secure` flag, must be set over HTTPS, must not have a domain specified,
149    /// and the path must be `/`.
150    #[must_use]
151    pub fn with_host_prefix(cookie_name: String) -> Self {
152        Self::with_prefix(host_prefix!(), cookie_name)
153    }
154
155    /// Sets the cookie name, prefixing it with `__Secure-` if it wasn't already
156    /// prefixed. Note that this requires the cookie to be served with the
157    /// `secure` flag.
158    #[must_use]
159    pub fn with_secure_prefix(cookie_name: String) -> Self {
160        Self::with_prefix(secure_prefix!(), cookie_name)
161    }
162
163    fn with_prefix(prefix: &'static str, cookie_name: String) -> Self {
164        if cookie_name.starts_with(prefix) {
165            Self { cookie_name }
166        } else {
167            Self {
168                cookie_name: format!("{}{}", prefix, cookie_name),
169            }
170        }
171    }
172}
173
174/// Extractor to get the CSRF token that will be set as a cookie.
175#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
176pub struct CsrfToken(pub(crate) String);
177
178impl Serialize for CsrfToken {
179    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
180    where
181        S: serde::Serializer,
182    {
183        serializer.serialize_newtype_struct("Csrf Token", &self.0)
184    }
185}
186
187impl<'de> Deserialize<'de> for CsrfToken {
188    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
189    where
190        D: serde::Deserializer<'de>,
191    {
192        struct CsrfTokenVisitor;
193        impl<'de> Visitor<'de> for CsrfTokenVisitor {
194            type Value = CsrfToken;
195
196            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
197                formatter.write_str("a valid csrf token")
198            }
199
200            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
201            where
202                E: Error,
203            {
204                Ok(CsrfToken(v.to_owned()))
205            }
206        }
207
208        deserializer.deserialize_string(CsrfTokenVisitor)
209    }
210}
211
212impl CsrfToken {
213    /// Used for testing purposes only. Using this without permission may cause
214    /// horribleness.
215    #[must_use]
216    #[doc(hidden)]
217    pub const fn test_create(value: String) -> Self {
218        Self(value)
219    }
220
221    /// Retrieves a reference of the csrf token.
222    #[must_use]
223    pub fn get(&self) -> &str {
224        self.0.as_ref()
225    }
226
227    /// Consumes the struct, returning the underlying string.
228    #[must_use]
229    #[allow(clippy::missing_const_for_fn)] // false positive
230    pub fn into_inner(self) -> String {
231        self.0
232    }
233
234    fn from_request_sync(req: &HttpRequest) -> Result<Self, CsrfError> {
235        req.extensions()
236            .get::<Self>()
237            .cloned()
238            .ok_or(CsrfError::MissingToken)
239    }
240}
241
242impl AsRef<str> for CsrfToken {
243    fn as_ref(&self) -> &str {
244        self.0.as_ref()
245    }
246}
247
248impl FromRequest for CsrfToken {
249    type Error = CsrfError;
250    type Future = Ready<Result<Self, Self::Error>>;
251
252    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
253        ready(Self::from_request_sync(req))
254    }
255}
256
/// Wraps another extractor whose output carries a CSRF token (see
/// [`CsrfGuarded`]) and only succeeds when that token equals the CSRF cookie
/// sent with the request. A missing cookie or a mismatched token makes the
/// extraction fail. If the wrapped extractor fails, its error is passed through.
///
/// ```
/// use actix_csrf::extractor::{Csrf, CsrfGuarded, CsrfToken};
/// use actix_web::{post, Responder};
/// use actix_web::web::Form;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Login {
///    csrf: CsrfToken,
///    email: String,
///    password: String,
/// }
///
/// impl CsrfGuarded for Login {
///     fn csrf_token(&self) -> &CsrfToken {
///         &self.csrf
///     }
/// }
///
/// #[post("/login")]
/// async fn login(form: Csrf<Form<Login>>) -> impl Responder {
///    // If we got here, then the CSRF token passed validation!
///    format!("hello, {}, your password is {}", &form.email, &form.password)
/// }
/// ```
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Csrf<Inner>(Inner);

impl<Inner> Csrf<Inner> {
    /// Unwraps the validated value.
    #[must_use]
    pub fn into_inner(self) -> Inner {
        let Self(inner) = self;
        inner
    }
}
296
297impl<Inner> Deref for Csrf<Inner> {
298    type Target = Inner;
299
300    fn deref(&self) -> &Self::Target {
301        &self.0
302    }
303}
304
305impl<Inner> DerefMut for Csrf<Inner> {
306    fn deref_mut(&mut self) -> &mut Self::Target {
307        &mut self.0
308    }
309}
310
311impl<Inner> FromRequest for Csrf<Inner>
312where
313    Inner: FromRequest + CsrfGuarded,
314{
315    type Error = CsrfExtractorError<Inner::Error>;
316    type Future = CsrfExtractorFuture<Inner::Future>;
317
318    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
319        CsrfExtractorFuture {
320            csrf_token: CsrfCookie::from_request_sync(req),
321            inner: Box::pin(Inner::from_request(req, payload)),
322        }
323    }
324}
325
326macro_rules! derive_csrf_guarded {
327    ($type:path) => {
328        impl<T> CsrfGuarded for $type
329        where
330            T: CsrfGuarded,
331        {
332            fn csrf_token(&self) -> &CsrfToken {
333                self.0.csrf_token()
334            }
335        }
336    };
337}
338
339derive_csrf_guarded!(actix_web::web::Form<T>);
340derive_csrf_guarded!(actix_web::web::Json<T>);
341
342/// Polls the underlying future, returning the underlying result if and only if
343/// the CSRF token is valid. This is an implementation detail of [`Csrf`], and
344/// cannot be constructed normally.
345pub struct CsrfExtractorFuture<Fut> {
346    csrf_token: Result<CsrfCookie, CsrfError>,
347    inner: Pin<Box<Fut>>,
348}
349
350impl<Fut, FutOut, FutErr> Future for CsrfExtractorFuture<Fut>
351where
352    Fut: Future<Output = Result<FutOut, FutErr>>,
353    FutOut: CsrfGuarded,
354{
355    type Output = Result<Csrf<FutOut>, CsrfExtractorError<FutErr>>;
356
357    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358        match self.inner.as_mut().poll(cx) {
359            Poll::Ready(Ok(out)) => {
360                if let Ok(ref token) = self.csrf_token {
361                    if out.csrf_token().as_ref() == token.as_ref() {
362                        return Poll::Ready(Ok(Csrf(out)));
363                    }
364                }
365
366                Poll::Ready(Err(CsrfExtractorError::InvalidToken))
367            }
368            Poll::Ready(Err(e)) => Poll::Ready(Err(CsrfExtractorError::Inner(e))),
369            Poll::Pending => Poll::Pending,
370        }
371    }
372}
373
374/// This trait represents types who have a field that represents a CSRF token.
375///
376/// This trait is required on an underlying type for the [`Csrf`] extractor to
377/// correctly function.
378pub trait CsrfGuarded {
379    /// Retrieves the CSRF token from the struct.
380    fn csrf_token(&self) -> &CsrfToken;
381}
382
/// Error produced by [`CsrfExtractorFuture`], and therefore by the [`Csrf`]
/// extractor.
///
/// `Inner` is the error type of the wrapped extractor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CsrfExtractorError<Inner> {
    /// The CSRF cookie was missing, or the token carried by the wrapped
    /// extractor's output did not match it.
    InvalidToken,
    /// The wrapped extractor failed. Its error is carried unchanged.
    Inner(Inner),
}

impl<Inner> std::fmt::Display for CsrfExtractorError<Inner>
where
    Inner: std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidToken => {
                f.write_str("the CSRF token is missing or does not match the CSRF cookie")
            }
            // Transparent: the wrapped error describes itself.
            Self::Inner(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

impl<Inner> std::error::Error for CsrfExtractorError<Inner>
where
    Inner: std::error::Error,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::InvalidToken => None,
            // `Display` already forwards to the inner error, so expose *its*
            // source rather than the inner error itself. This avoids reporting
            // the same message twice in error chains.
            Self::Inner(inner) => inner.source(),
        }
    }
}
391
392impl<Inner> From<CsrfExtractorError<Inner>> for actix_web::error::Error
393where
394    Inner: Into<Self>,
395{
396    fn from(e: CsrfExtractorError<Inner>) -> Self {
397        match e {
398            CsrfExtractorError::InvalidToken => CsrfError::TokenMismatch.into(),
399            CsrfExtractorError::Inner(e) => e.into(),
400        }
401    }
402}
403
#[cfg(test)]
mod tests {
    use std::error::Error;

    use crate::DEFAULT_CSRF_COOKIE_NAME;

    use super::*;

    use actix_web::http::header;
    use actix_web::test::TestRequest;

    #[tokio::test]
    async fn extract_from_header() -> Result<(), Box<dyn Error>> {
        let req = TestRequest::default()
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, "sometoken"))
            .to_http_request();
        let token = CsrfHeader::extract(&req).await?;
        assert!(token.validate("sometoken"));

        Ok(())
    }

    #[tokio::test]
    async fn extract_from_configured_header() -> Result<(), Box<dyn Error>> {
        let req = TestRequest::default()
            .app_data(CsrfHeaderConfig::new(HeaderName::from_static("x-custom-csrf")))
            .insert_header(("x-custom-csrf", "sometoken"))
            .to_http_request();
        let token = CsrfHeader::extract(&req).await?;
        assert!(token.validate("sometoken"));

        Ok(())
    }

    #[tokio::test]
    async fn configured_header_ignores_default_name() {
        let req = TestRequest::default()
            .app_data(CsrfHeaderConfig::new(HeaderName::from_static("x-custom-csrf")))
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, "sometoken"))
            .to_http_request();
        let token = CsrfHeader::extract(&req).await;
        assert!(token.is_err());
    }

    #[tokio::test]
    async fn not_found_header() {
        let req = TestRequest::default()
            .insert_header(("fake", "sometoken"))
            .to_http_request();
        let token = CsrfHeader::extract(&req).await;
        assert!(token.is_err());
    }

    #[tokio::test]
    async fn extract_from_cookie() -> Result<(), Box<dyn Error>> {
        let req = TestRequest::default()
            .insert_header((
                header::COOKIE,
                format!("{DEFAULT_CSRF_COOKIE_NAME}=sometoken"),
            ))
            .to_http_request();

        let token = CsrfCookie::extract(&req).await?;
        assert!(token.validate("sometoken"));
        Ok(())
    }

    #[tokio::test]
    async fn extract_from_configured_cookie() -> Result<(), Box<dyn Error>> {
        let req = TestRequest::default()
            .app_data(CsrfCookieConfig::new("custom-csrf".to_owned()))
            .insert_header((header::COOKIE, "custom-csrf=sometoken"))
            .to_http_request();

        let token = CsrfCookie::extract(&req).await?;
        assert!(token.validate("sometoken"));
        Ok(())
    }

    #[tokio::test]
    async fn not_found_cookie() {
        let req = TestRequest::default()
            .insert_header(("fake", "sometoken"))
            .to_http_request();
        let token = CsrfCookie::extract(&req).await;
        assert!(token.is_err());
    }

    #[test]
    fn cookie_prefix_is_added_once() {
        let host = CsrfCookieConfig::with_host_prefix("csrf".to_owned());
        assert_eq!(host.cookie_name, "__Host-csrf");
        assert_eq!(
            CsrfCookieConfig::with_host_prefix("__Host-csrf".to_owned()),
            host
        );

        let secure = CsrfCookieConfig::with_secure_prefix("csrf".to_owned());
        assert_eq!(secure.cookie_name, "__Secure-csrf");
        assert_eq!(
            CsrfCookieConfig::with_secure_prefix("__Secure-csrf".to_owned()),
            secure
        );
    }
}