1use std::{future::Future, rc::Rc};
2
3use actix_web::{
4 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
5 HttpMessage, HttpRequest,
6};
7use futures::future::{ready, LocalBoxFuture, Ready};
8use qstring::QString;
9
10use crate::router::CSRFType;
11
/// Actix-web middleware factory that enforces a cookie-vs-submitted-token
/// CSRF check on requests.
///
/// All configuration is held behind `Rc` so that every service produced by
/// this factory shares the same strings and checker without copying them.
pub struct Middleware<F> {
    /// Name of the cookie whose value is the expected CSRF token.
    cookie: Rc<String>,
    /// Name of the request header carrying the submitted token; for
    /// param-based policies the same name is looked up in the query string.
    header: Rc<String>,
    /// Caller-supplied async predicate, run only after the submitted token
    /// equals the cookie value.
    checker: Rc<F>,
}
17
18impl<F> Clone for Middleware<F> {
19 fn clone(&self) -> Self {
20 Self {
21 cookie: self.cookie.clone(),
22 header: self.header.clone(),
23 checker: self.checker.clone(),
24 }
25 }
26}
27
28impl<F, Fut> Middleware<F>
29where
30 F: Fn(HttpRequest, String) -> Fut,
31 Fut: Future<Output = Result<bool, actix_web::Error>>,
32{
33 pub fn new(cookie: String, header: String, checker: F) -> Self {
34 Self {
35 cookie: Rc::new(cookie),
36 header: Rc::new(header),
37 checker: Rc::new(checker),
38 }
39 }
40}
41
42impl<S, B, F, Fut> Transform<S, ServiceRequest> for Middleware<F>
43where
44 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
45 S::Future: 'static,
46 B: 'static,
47 F: Fn(HttpRequest, String) -> Fut + 'static,
48 Fut: Future<Output = Result<bool, actix_web::Error>>,
49{
50 type Response = ServiceResponse<B>;
51 type Error = actix_web::Error;
52 type InitError = ();
53 type Transform = MiddlewareService<S, F>;
54 type Future = Ready<Result<Self::Transform, Self::InitError>>;
55
56 fn new_transform(&self, service: S) -> Self::Future {
57 ready(Ok(MiddlewareService {
58 service: Rc::new(service),
59 cookie: self.cookie.clone(),
60 header: self.header.clone(),
61 checker: self.checker.clone(),
62 }))
63 }
64}
65
/// Per-service instance of the CSRF middleware, wrapping the inner service
/// `S` and sharing its configuration with the [`Middleware`] that built it.
pub struct MiddlewareService<S, F> {
    /// Inner service, reference-counted so request futures can own a handle.
    service: Rc<S>,
    /// Name of the cookie holding the expected token.
    cookie: Rc<String>,
    /// Name of the header (or query parameter) carrying the submitted token.
    header: Rc<String>,
    /// Caller-supplied predicate run after the token comparison succeeds.
    checker: Rc<F>,
}
72
73impl<S, B, F, Fut> MiddlewareService<S, F>
74where
75 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
76 S::Future: 'static,
77 B: 'static,
78 F: Fn(HttpRequest, String) -> Fut + 'static,
79 Fut: Future<Output = Result<bool, actix_web::Error>>,
80{
81 fn get_safe_header(req: &ServiceRequest, name: &str) -> Option<String> {
82 let mut ret: Vec<&str> = req
83 .headers()
84 .get_all(name)
85 .map(|x| x.to_str().unwrap())
86 .collect();
87 if ret.len() != 1 {
88 return None;
89 }
90 ret.pop().map(ToOwned::to_owned)
91 }
92
93 async fn check_csrf(
94 req: &ServiceRequest,
95 cookie: &str,
96 header: &str,
97 checker: Rc<F>,
98 allow_param: bool,
99 ) -> Result<bool, actix_web::Error> {
100 let Some(cookie) = req.cookie(cookie) else {
101 return Ok(false);
102 };
103 let mut csrf = Self::get_safe_header(req, header);
104 if csrf.is_none() && allow_param {
105 let qs = QString::from(req.query_string());
106 csrf = qs.get(header).map(ToOwned::to_owned);
107 }
108 let Some(csrf) = csrf else {
109 return Ok(false);
110 };
111 if csrf != cookie.value() {
112 return Ok(false);
113 }
114 checker(req.request().clone(), csrf).await
115 }
116}
117
118impl<S, B, F, Fut> Service<ServiceRequest> for MiddlewareService<S, F>
119where
120 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
121 S::Future: 'static,
122 B: 'static,
123 F: Fn(HttpRequest, String) -> Fut + 'static,
124 Fut: Future<Output = Result<bool, actix_web::Error>>,
125{
126 type Response = ServiceResponse<B>;
127 type Error = actix_web::Error;
128 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130 forward_ready!(service);
131
132 fn call(&self, req: ServiceRequest) -> Self::Future {
133 let srv = self.service.clone();
134 let header = self.header.clone();
135 let cookie = self.cookie.clone();
136 let checker = self.checker.clone();
137 Box::pin(async move {
138 let csrf = req.extensions().get::<CSRFType>().unwrap().to_owned();
139 if csrf.is_force_header() || csrf.is_force_param() || !req.method().is_safe() {
140 let ret = match csrf {
141 CSRFType::Header => {
142 Self::check_csrf(&req, &cookie, &header, checker, false).await
143 }
144 CSRFType::Param => {
145 Self::check_csrf(&req, &cookie, &header, checker, true).await
146 }
147 CSRFType::ForceHeader => {
148 Self::check_csrf(&req, &cookie, &header, checker, false).await
149 }
150 CSRFType::ForceParam => {
151 Self::check_csrf(&req, &cookie, &header, checker, true).await
152 }
153 CSRFType::Disabled => Ok(true),
154 }?;
155 if !ret {
156 return Err(actix_web::error::ErrorBadRequest("CSRF check failed"));
157 }
158 }
159 srv.call(req).await
160 })
161 }
162}