// actix_cloud/router.rs

1use std::{
2    fmt::Debug,
3    future::{ready, Ready},
4    rc::Rc,
5};
6
7#[cfg(feature = "csrf")]
8use actix_web::HttpMessage;
9use actix_web::{
10    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
11    web::ServiceConfig,
12    Route,
13};
14use anyhow::Result;
15use async_trait::async_trait;
16use futures::future::LocalBoxFuture;
17
/// Builds a `ServiceConfig` callback (e.g. for `App::configure`) that registers every
/// [`Router`] entry whose `path` is non-empty; entries with an empty path are skipped.
///
/// Each route is wrapped twice: first in a clone of the shared `csrf` middleware, then
/// in a `RouterGuard`. Because the guard is wrapped last it is the outer layer, so it
/// runs first: it attaches the route's `CSRFType` to the request extensions and runs the
/// route's checker before the CSRF middleware sees the request.
#[cfg(feature = "csrf")]
pub fn build_router<F, Fut>(
    router: Vec<Router>,
    csrf: crate::csrf::Middleware<F>,
) -> impl FnOnce(&mut ServiceConfig)
where
    F: Fn(actix_web::HttpRequest, String) -> Fut + 'static,
    Fut: futures::Future<Output = Result<bool, actix_web::Error>>,
{
    move |cfg| {
        for entry in router.into_iter().filter(|r| !r.path.is_empty()) {
            let Router {
                path,
                route,
                checker,
                csrf: csrf_type,
            } = entry;
            let guard = RouterGuard {
                checker,
                csrf: csrf_type,
            };
            cfg.route(&path, route.wrap(csrf.clone()).wrap(guard));
        }
    }
}
41
42#[cfg(not(feature = "csrf"))]
43pub fn build_router(router: Vec<Router>) -> impl FnOnce(&mut ServiceConfig) {
44    |cfg| {
45        for i in router {
46            if !i.path.is_empty() {
47                cfg.route(&i.path, i.route.wrap(RouterGuard { checker: i.checker }));
48            }
49        }
50    }
51}
52
/// Per-route admission check, run by the router guard middleware before the request
/// reaches the wrapped route.
///
/// `RouterGuardMiddleware::call` interprets the returned `anyhow::Result<bool>` as:
/// - `Ok(true)`: the request is forwarded to the inner service;
/// - `Ok(false)`: the request is rejected with `403 Forbidden` ("Checker failed");
/// - `Err(e)`: the request is rejected with `500 Internal Server Error` wrapping `e`.
///
/// The request is passed as `&mut`, so an implementation may modify it (e.g. its
/// extensions) before it is forwarded. Declared with `async_trait(?Send)`, so the
/// returned future is not required to be `Send`.
#[async_trait(?Send)]
pub trait Checker {
    async fn check(&self, req: &mut ServiceRequest) -> Result<bool>;
}
57
/// Per-route CSRF policy.
///
/// The router guard middleware inserts the route's value into each request's extensions,
/// so later layers (presumably the `crate::csrf` middleware; verify there) can look it up.
///
/// NOTE(review): what each variant means (where the token is read from, what the
/// `Force*` variants add) is defined by the CSRF middleware, not in this file. Confirm
/// there before relying on it.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the policy can be logged and compared.
/// Variant order is unchanged, so discriminants and any serde representation are unaffected.
#[cfg(feature = "csrf")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, enum_as_inner::EnumAsInner)]
pub enum CSRFType {
    Header,
    Param,
    ForceHeader,
    ForceParam,
    Disabled,
}
68
/// A single route registration consumed by `build_router`.
pub struct Router {
    /// Path the route is registered under. `build_router` skips entries whose path is
    /// empty, so such entries are never registered.
    pub path: String,
    /// The actix-web route to register at `path`.
    pub route: Route,
    /// Optional check run before the route; `None` lets every request through.
    pub checker: Option<Rc<dyn Checker>>,
    /// CSRF policy for this route; the router guard inserts it into each request's
    /// extensions.
    #[cfg(feature = "csrf")]
    pub csrf: CSRFType,
}
76
/// Route-level middleware factory installed by `build_router`.
///
/// Implements `Transform`, producing a `RouterGuardMiddleware` that carries this
/// route's checker (and, with the `csrf` feature, its CSRF policy).
pub(crate) struct RouterGuard {
    /// Optional per-route check; `None` disables checking.
    checker: Option<Rc<dyn Checker>>,
    /// CSRF policy copied into every middleware instance this factory creates.
    #[cfg(feature = "csrf")]
    csrf: CSRFType,
}
82
83impl<S, B> Transform<S, ServiceRequest> for RouterGuard
84where
85    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
86    S::Future: 'static,
87    B: 'static + Debug,
88{
89    type Response = ServiceResponse<B>;
90    type Error = actix_web::Error;
91    type InitError = ();
92    type Transform = RouterGuardMiddleware<S>;
93    type Future = Ready<Result<Self::Transform, Self::InitError>>;
94
95    fn new_transform(&self, service: S) -> Self::Future {
96        ready(Ok(RouterGuardMiddleware {
97            service: Rc::new(service),
98            checker: self.checker.clone(),
99            #[cfg(feature = "csrf")]
100            csrf: self.csrf,
101        }))
102    }
103}
104
/// Per-route middleware created by `RouterGuard`; gates requests on the route's checker.
pub(crate) struct RouterGuardMiddleware<S> {
    /// The wrapped inner service, held in an `Rc` so `call` can clone it into the
    /// `'static` future it returns.
    service: Rc<S>,
    /// Optional per-route check; `None` forwards every request.
    checker: Option<Rc<dyn Checker>>,
    /// CSRF policy inserted into each request's extensions in `call`.
    #[cfg(feature = "csrf")]
    csrf: CSRFType,
}
111
112impl<S, B> Service<ServiceRequest> for RouterGuardMiddleware<S>
113where
114    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
115    S::Future: 'static,
116    B: 'static + Debug,
117{
118    type Response = ServiceResponse<B>;
119    type Error = actix_web::Error;
120    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
121
122    forward_ready!(service);
123
124    fn call(&self, mut req: ServiceRequest) -> Self::Future {
125        let srv = self.service.clone();
126        let checker = self.checker.clone();
127        #[cfg(feature = "csrf")]
128        req.extensions_mut().insert(self.csrf);
129        Box::pin(async move {
130            if let Some(checker) = checker {
131                match checker.check(&mut req).await {
132                    Ok(ok) => {
133                        if ok {
134                            srv.call(req).await
135                        } else {
136                            Err(actix_web::error::ErrorForbidden("Checker failed"))
137                        }
138                    }
139                    Err(e) => Err(actix_web::error::ErrorInternalServerError(e)),
140                }
141            } else {
142                srv.call(req).await
143            }
144        })
145    }
146}