1use std::{
2 fmt::Debug,
3 future::{ready, Ready},
4 rc::Rc,
5};
6
7#[cfg(feature = "csrf")]
8use actix_web::HttpMessage;
9use actix_web::{
10 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
11 web::ServiceConfig,
12 Route,
13};
14use anyhow::Result;
15use async_trait::async_trait;
16use futures::future::LocalBoxFuture;
17
/// Builds an actix-web configuration closure that registers every entry of `router`.
///
/// Each route is wrapped first in a clone of the shared `csrf` middleware and then in a
/// `RouterGuard` carrying the entry's checker and CSRF mode. Because the guard is the
/// last `wrap`, it is the outermost layer and runs before the CSRF middleware.
/// Entries whose `path` is empty are skipped.
#[cfg(feature = "csrf")]
pub fn build_router<F, Fut>(
    router: Vec<Router>,
    csrf: crate::csrf::Middleware<F>,
) -> impl FnOnce(&mut ServiceConfig)
where
    F: Fn(actix_web::HttpRequest, String) -> Fut + 'static,
    Fut: futures::Future<Output = Result<bool, actix_web::Error>>,
{
    move |cfg| {
        for Router {
            path,
            route,
            checker,
            csrf: csrf_mode,
        } in router
        {
            if path.is_empty() {
                continue;
            }
            let guard = RouterGuard {
                checker,
                csrf: csrf_mode,
            };
            cfg.route(&path, route.wrap(csrf.clone()).wrap(guard));
        }
    }
}
41
42#[cfg(not(feature = "csrf"))]
43pub fn build_router(router: Vec<Router>) -> impl FnOnce(&mut ServiceConfig) {
44 |cfg| {
45 for i in router {
46 if !i.path.is_empty() {
47 cfg.route(&i.path, i.route.wrap(RouterGuard { checker: i.checker }));
48 }
49 }
50 }
51}
52
53#[async_trait(?Send)]
54pub trait Checker {
55 async fn check(&self, req: &mut ServiceRequest) -> Result<bool>;
56}
57
/// Per-route CSRF mode.
///
/// The router guard inserts this value into the request extensions. Its exact
/// interpretation lives in the CSRF middleware (`crate::csrf`). The variant notes
/// below are inferred from the names and should be confirmed there.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the mode can be logged and compared.
#[cfg(feature = "csrf")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, enum_as_inner::EnumAsInner)]
pub enum CSRFType {
    /// Presumably: token accepted from a request header — TODO confirm in `crate::csrf`.
    Header,
    /// Presumably: token accepted from a request parameter — TODO confirm.
    Param,
    /// Presumably: header token is mandatory — TODO confirm.
    ForceHeader,
    /// Presumably: parameter token is mandatory — TODO confirm.
    ForceParam,
    /// Presumably: CSRF checking is skipped for this route — TODO confirm.
    Disabled,
}
68
69pub struct Router {
70 pub path: String,
71 pub route: Route,
72 pub checker: Option<Rc<dyn Checker>>,
73 #[cfg(feature = "csrf")]
74 pub csrf: CSRFType,
75}
76
77pub(crate) struct RouterGuard {
78 checker: Option<Rc<dyn Checker>>,
79 #[cfg(feature = "csrf")]
80 csrf: CSRFType,
81}
82
83impl<S, B> Transform<S, ServiceRequest> for RouterGuard
84where
85 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
86 S::Future: 'static,
87 B: 'static + Debug,
88{
89 type Response = ServiceResponse<B>;
90 type Error = actix_web::Error;
91 type InitError = ();
92 type Transform = RouterGuardMiddleware<S>;
93 type Future = Ready<Result<Self::Transform, Self::InitError>>;
94
95 fn new_transform(&self, service: S) -> Self::Future {
96 ready(Ok(RouterGuardMiddleware {
97 service: Rc::new(service),
98 checker: self.checker.clone(),
99 #[cfg(feature = "csrf")]
100 csrf: self.csrf,
101 }))
102 }
103}
104
105pub(crate) struct RouterGuardMiddleware<S> {
106 service: Rc<S>,
107 checker: Option<Rc<dyn Checker>>,
108 #[cfg(feature = "csrf")]
109 csrf: CSRFType,
110}
111
112impl<S, B> Service<ServiceRequest> for RouterGuardMiddleware<S>
113where
114 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
115 S::Future: 'static,
116 B: 'static + Debug,
117{
118 type Response = ServiceResponse<B>;
119 type Error = actix_web::Error;
120 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
121
122 forward_ready!(service);
123
124 fn call(&self, mut req: ServiceRequest) -> Self::Future {
125 let srv = self.service.clone();
126 let checker = self.checker.clone();
127 #[cfg(feature = "csrf")]
128 req.extensions_mut().insert(self.csrf);
129 Box::pin(async move {
130 if let Some(checker) = checker {
131 match checker.check(&mut req).await {
132 Ok(ok) => {
133 if ok {
134 srv.call(req).await
135 } else {
136 Err(actix_web::error::ErrorForbidden("Checker failed"))
137 }
138 }
139 Err(e) => Err(actix_web::error::ErrorInternalServerError(e)),
140 }
141 } else {
142 srv.call(req).await
143 }
144 })
145 }
146}