actix_cloud/csrf.rs

1use std::{future::Future, rc::Rc};
2
3use actix_web::{
4    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
5    HttpMessage, HttpRequest,
6};
7use futures::future::{ready, LocalBoxFuture, Ready};
8use qstring::QString;
9
10use crate::router::CSRFType;
11
/// Actix-web middleware factory (see its `Transform` impl) that enforces a CSRF token check.
///
/// Every field is reference-counted, so the factory and each service it creates can share
/// the configuration without requiring `F: Clone`.
pub struct Middleware<F> {
    /// Name of the cookie holding the expected token.
    cookie: Rc<String>,
    /// Name of the header carrying the submitted token; the same name is used as the
    /// query-string key when a route permits parameter tokens.
    header: Rc<String>,
    /// Async callback that performs the final validation of a token.
    checker: Rc<F>,
}

impl<F> Clone for Middleware<F> {
    // Written by hand because `#[derive(Clone)]` would add an `F: Clone` bound, while only
    // the `Rc` handles actually need cloning (a refcount bump each).
    fn clone(&self) -> Self {
        let Self {
            cookie,
            header,
            checker,
        } = self;
        Self {
            cookie: Rc::clone(cookie),
            header: Rc::clone(header),
            checker: Rc::clone(checker),
        }
    }
}
27
28impl<F, Fut> Middleware<F>
29where
30    F: Fn(HttpRequest, String) -> Fut,
31    Fut: Future<Output = Result<bool, actix_web::Error>>,
32{
33    pub fn new(cookie: String, header: String, checker: F) -> Self {
34        Self {
35            cookie: Rc::new(cookie),
36            header: Rc::new(header),
37            checker: Rc::new(checker),
38        }
39    }
40}
41
42impl<S, B, F, Fut> Transform<S, ServiceRequest> for Middleware<F>
43where
44    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
45    S::Future: 'static,
46    B: 'static,
47    F: Fn(HttpRequest, String) -> Fut + 'static,
48    Fut: Future<Output = Result<bool, actix_web::Error>>,
49{
50    type Response = ServiceResponse<B>;
51    type Error = actix_web::Error;
52    type InitError = ();
53    type Transform = MiddlewareService<S, F>;
54    type Future = Ready<Result<Self::Transform, Self::InitError>>;
55
56    fn new_transform(&self, service: S) -> Self::Future {
57        ready(Ok(MiddlewareService {
58            service: Rc::new(service),
59            cookie: self.cookie.clone(),
60            header: self.header.clone(),
61            checker: self.checker.clone(),
62        }))
63    }
64}
65
66pub struct MiddlewareService<S, F> {
67    service: Rc<S>,
68    cookie: Rc<String>,
69    header: Rc<String>,
70    checker: Rc<F>,
71}
72
73impl<S, B, F, Fut> MiddlewareService<S, F>
74where
75    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
76    S::Future: 'static,
77    B: 'static,
78    F: Fn(HttpRequest, String) -> Fut + 'static,
79    Fut: Future<Output = Result<bool, actix_web::Error>>,
80{
81    fn get_safe_header(req: &ServiceRequest, name: &str) -> Option<String> {
82        let mut ret: Vec<&str> = req
83            .headers()
84            .get_all(name)
85            .map(|x| x.to_str().unwrap())
86            .collect();
87        if ret.len() != 1 {
88            return None;
89        }
90        ret.pop().map(ToOwned::to_owned)
91    }
92
93    async fn check_csrf(
94        req: &ServiceRequest,
95        cookie: &str,
96        header: &str,
97        checker: Rc<F>,
98        allow_param: bool,
99    ) -> Result<bool, actix_web::Error> {
100        let Some(cookie) = req.cookie(cookie) else {
101            return Ok(false);
102        };
103        let mut csrf = Self::get_safe_header(req, header);
104        if csrf.is_none() && allow_param {
105            let qs = QString::from(req.query_string());
106            csrf = qs.get(header).map(ToOwned::to_owned);
107        }
108        let Some(csrf) = csrf else {
109            return Ok(false);
110        };
111        if csrf != cookie.value() {
112            return Ok(false);
113        }
114        checker(req.request().clone(), csrf).await
115    }
116}
117
118impl<S, B, F, Fut> Service<ServiceRequest> for MiddlewareService<S, F>
119where
120    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
121    S::Future: 'static,
122    B: 'static,
123    F: Fn(HttpRequest, String) -> Fut + 'static,
124    Fut: Future<Output = Result<bool, actix_web::Error>>,
125{
126    type Response = ServiceResponse<B>;
127    type Error = actix_web::Error;
128    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130    forward_ready!(service);
131
132    fn call(&self, req: ServiceRequest) -> Self::Future {
133        let srv = self.service.clone();
134        let header = self.header.clone();
135        let cookie = self.cookie.clone();
136        let checker = self.checker.clone();
137        Box::pin(async move {
138            let csrf = req.extensions().get::<CSRFType>().unwrap().to_owned();
139            if csrf.is_force_header() || csrf.is_force_param() || !req.method().is_safe() {
140                let ret = match csrf {
141                    CSRFType::Header => {
142                        Self::check_csrf(&req, &cookie, &header, checker, false).await
143                    }
144                    CSRFType::Param => {
145                        Self::check_csrf(&req, &cookie, &header, checker, true).await
146                    }
147                    CSRFType::ForceHeader => {
148                        Self::check_csrf(&req, &cookie, &header, checker, false).await
149                    }
150                    CSRFType::ForceParam => {
151                        Self::check_csrf(&req, &cookie, &header, checker, true).await
152                    }
153                    CSRFType::Disabled => Ok(true),
154                }?;
155                if !ret {
156                    return Err(actix_web::error::ErrorBadRequest("CSRF check failed"));
157                }
158            }
159            srv.call(req).await
160        })
161    }
162}