// chimes_auth/middleware.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3use std::task::{Context, Poll};
4use actix_web::{Error, error, HttpResponse, web, HttpMessage};
5use actix_web::body::{MessageBody, EitherBody};
6use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
7use actix_web::http::header::{HeaderValue};
8use futures_core::future::{LocalBoxFuture};
9use serde::de::DeserializeOwned;
10
11#[cfg(target_feature="session")]
12use actix_session::{Session, SessionExt, {storage::SessionStore}};
13
14
15use crate::{ChimesAuthUser, ApiResult};
16use crate::ChimesAuthService;
17
18// The custom ChimesAuthorization for auth
19pub struct ChimesAuthorization<T, P> 
20where
21    T: Clone + Sized + ChimesAuthUser<T> + DeserializeOwned,
22    P: ChimesAuthService<T>
23{
24    #[allow(unused)]
25    auth_info: Option<T>,
26    auth_service: Rc<P>,
27    allow_urls: Rc<Vec<String>>,
28    header_key: Option<String>,
29    nojwt_header_key: Option<String>,
30    #[cfg(target_feature="session")]
31    session_key: Option<String>,
32}
33
34impl <T, P> ChimesAuthorization<T, P> 
35where
36    T: Clone + Sized + ChimesAuthUser<T> + DeserializeOwned,
37    P: ChimesAuthService<T>
38{
39    pub fn new(auth_service: P) -> Self {
40        Self{
41            auth_info: None,
42            auth_service: Rc::new(auth_service),
43            allow_urls: Rc::new(vec![]),
44            header_key: None,
45            nojwt_header_key: None,
46            #[cfg(target_feature="session")]
47            session_key: None,
48        }
49    }
50
51    pub fn allow(mut self, url: &String) -> Self {
52        Rc::get_mut(&mut self.allow_urls)
53                .unwrap()
54                .push(url.to_string());
55        
56        self
57    }
58
59    pub fn header_key(mut self, new_key: &String) -> Self {
60        self.header_key = Some(new_key.to_string());
61        self
62    }
63
64    pub fn nojwt_header_key(mut self, new_key: &String) -> Self {
65        self.nojwt_header_key = Some(new_key.to_string());
66        self
67    }    
68
69    #[cfg(target_feature="session")]
70    pub fn session_key(mut self, new_key: &String) -> Self {
71        self.session_key = Some(new_key.to_string());
72        self
73    }
74
75}
76
77impl<S, B, T, P> Transform<S, ServiceRequest> for ChimesAuthorization<T, P>
78    where
79        S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
80        S::Future: 'static,
81        B: MessageBody + 'static,
82        T: Clone + Sized + ChimesAuthUser<T> + DeserializeOwned + 'static,
83        P: Sized + ChimesAuthService<T>  + 'static,
84{
85    type Response = ServiceResponse<EitherBody<B>>;
86    type Error = Error;
87    type InitError = ();
88    type Transform = ChimesAuthenticationMiddleware<S, T, P>;
89    type Future = actix_utils::future::Ready<Result<Self::Transform, Self::InitError>>;
90
91    fn new_transform(&self, service: S) -> Self::Future {
92        actix_utils::future::ok(ChimesAuthenticationMiddleware {
93            auth_info: None,
94            service: Rc::new(RefCell::new(service)),
95            auth_service: self.auth_service.clone(),
96            allow_urls: self.allow_urls.clone(),
97            header_key: self.header_key.clone(),
98            nojwt_header_key: self.nojwt_header_key.clone(),
99            #[cfg(target_feature="session")]
100            session_key: self.session_key.clone()
101        })
102    }
103}
104
/// Middleware instance produced by [`ChimesAuthorization`] for one wrapped
/// service; its `Service` implementation performs the per-request checks.
pub struct ChimesAuthenticationMiddleware<S, T, P> {
    /// The wrapped (inner) service, behind `Rc<RefCell<_>>` so request
    /// futures can hold their own handle to it.
    service: Rc<RefCell<S>>,
    /// Authentication/authorization backend shared with the factory.
    auth_service: Rc<P>,
    /// Route patterns that skip authentication, shared with the factory.
    allow_urls: Rc<Vec<String>>,
    /// Header carrying the JWT token; `None` selects the default key.
    header_key: Option<String>,
    /// Optional header carrying a non-JWT token.
    nojwt_header_key: Option<String>,
    /// Session entry holding the logged-in user; `None` selects the default key.
    ///
    /// NOTE(review): `target_feature` is a CPU-feature cfg; `feature = "session"`
    /// was probably intended (change it file-wide, not here alone).
    #[cfg(target_feature = "session")]
    session_key: Option<String>,
    // Nothing in this module reads this field (it is only ever set to `None`);
    // the allow keeps the dead-code lint quiet.
    #[allow(unused)]
    auth_info: Option<T>,
}
116
117impl<S, T, P, B> Service<ServiceRequest> for ChimesAuthenticationMiddleware<S, T, P>
118    where
119        S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
120        S::Future: 'static,
121        B: MessageBody + 'static,
122        T: Clone + Sized + ChimesAuthUser<T> + DeserializeOwned + 'static,
123        P: Sized + ChimesAuthService<T>  + 'static,
124{
125    // type Response = ServiceResponse<B>;
126    type Response = ServiceResponse<EitherBody<B>>;
127    type Error = Error;
128    // type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
129    type Future = LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>;
130
131    fn poll_ready(self: &ChimesAuthenticationMiddleware<S, T, P>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132        self.service.poll_ready(cx)
133    }
134
135    fn call(&self, req: ServiceRequest) -> Self::Future {
136        let service = self.service.clone();
137        let auth = self.auth_service.clone();
138        let url_pattern = req.match_pattern().unwrap_or_default();
139        let passed_url = self.allow_urls.contains(&url_pattern);
140
141        #[cfg(target_feature="session")]
142        let session = req.get_session();
143        #[cfg(target_feature="session")]
144        let default_session_key = "chimes-logged-user".to_string();
145        #[cfg(target_feature="session")]
146        let auth_user = session.get::<T>(self.session_key.clone().unwrap_or(default_session_key).as_str()).unwrap();
147        
148
149        let header_key = self.header_key.clone().unwrap_or("Authentication".to_string());
150        let nojwt_header_key = self.nojwt_header_key.clone();
151
152        Box::pin(async move {
153            let value = HeaderValue::from_str("").unwrap();
154            let token = req.headers().get(header_key.as_str()).unwrap_or(&value);
155            let nojwt_token = if nojwt_header_key.is_some() {
156                req.headers().get(nojwt_header_key.unwrap().as_str())
157            } else {
158                None
159            };
160
161            let req_method = req.method().to_string();
162            
163            if passed_url {
164                Ok(service.call(req).await?.map_into_left_body())
165            } else {
166                #[cfg(not(target_feature= "session"))]
167                let ust = if nojwt_token.is_some() {
168                    match nojwt_token.unwrap_or(&value).to_str() {
169                        Ok(st) => {
170                            let us = auth.nojwt_authenticate(&st.to_string()).await;
171                            us
172                        }
173                        Err(_) => {
174                            None
175                        }
176                    }
177                } else {
178                    match token.to_str() {
179                        Ok(st) => {
180                            let us = auth.authenticate(&st.to_string()).await;
181                            us
182                        }
183                        Err(_) => {
184                            None
185                        }
186                    }
187                };
188
189                #[cfg(target_feature= "session")]
190                let ust = auth_user;
191
192                let permitted = auth.permit(&ust, &req_method, &url_pattern).await;
193
194                if permitted.is_some() {
195                    if ust.is_some() {
196                        req.extensions_mut().insert(ust.unwrap().clone());
197                    }
198                    let res = service.call(req).await?;
199                    Ok(res.map_into_left_body())
200                } else {
201                    if ust.is_none() {
202                        
203                        let err = actix_web::error::ErrorUnauthorized("Not-Authorized");
204
205                        let errresp = req.error_response(err);
206                        let wbj: web::Json<ApiResult<String>> = web::Json(ApiResult::error(401, &"Not-Authorized".to_string()));
207                        let hrp = HttpResponse::Unauthorized().json(wbj).map_into_boxed_body();
208                        
209                        let m = ServiceResponse::new(
210                            errresp.request().clone(),
211                            hrp,
212                        );
213                        Ok(m.map_into_right_body())
214                    } else {
215                        let err = actix_web::error::ErrorForbidden("Forbidden");
216
217                        let errresp = req.error_response(err);
218                        let wbj: web::Json<ApiResult<String>> = web::Json(ApiResult::error(403, &"Forbidden".to_string()));
219                        let hrp = HttpResponse::Forbidden().json(wbj).map_into_boxed_body();
220                        
221                        let m = ServiceResponse::new(
222                            errresp.request().clone(),
223                            hrp,
224                        );
225                        Ok(m.map_into_right_body())
226                    }
227                }
228            }
229        })
230
231    }
232}