shield_tower/
service.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use http::{Request, Response};
8use shield::{Session, Shield, User};
9use tower_service::Service;
10use tracing::debug;
11
12use crate::session::TowerSessionStorage;
13
14#[derive(Clone)]
15pub struct ShieldService<S, U: User> {
16    inner: S,
17    shield: Shield<U>,
18    session_key: &'static str,
19}
20
21impl<S, U: User> ShieldService<S, U> {
22    pub fn new(inner: S, shield: Shield<U>, session_key: &'static str) -> Self {
23        Self {
24            inner,
25            shield,
26            session_key,
27        }
28    }
29
30    fn internal_server_error<ResBody: Default>() -> Response<ResBody> {
31        let mut response = Response::default();
32        *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
33        response
34    }
35}
36
37impl<S, U: User + Clone + 'static, ReqBody, ResBody> Service<Request<ReqBody>>
38    for ShieldService<S, U>
39where
40    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
41    S::Future: Send + 'static,
42    ReqBody: Send + 'static,
43    ResBody: Default + Send,
44{
45    type Response = S::Response;
46    type Error = S::Error;
47    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
48
49    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        self.inner.poll_ready(cx)
51    }
52
53    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
54        // TODO: Improve error handling to not only return a 500 response.
55
56        //  https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
57        let clone = self.inner.clone();
58        let mut inner = std::mem::replace(&mut self.inner, clone);
59
60        let shield = self.shield.clone();
61        let session_key = self.session_key;
62
63        Box::pin(async move {
64            let session = match req.extensions().get::<tower_sessions::Session>() {
65                Some(session) => session,
66                None => {
67                    return Ok(Self::internal_server_error());
68                }
69            };
70
71            let session_storage =
72                match TowerSessionStorage::load(session.clone(), session_key).await {
73                    Ok(session_storage) => session_storage,
74                    Err(_err) => return Ok(Self::internal_server_error()),
75                };
76            let shield_session = Session::new(session_storage);
77
78            let authenticated = match shield_session.data().lock() {
79                Ok(session) => session.authentication.clone(),
80                Err(_err) => return Ok(Self::internal_server_error()),
81            };
82
83            let user = if let Some(authenticated) = authenticated {
84                // TODO: Verify provider and subprovider still exist.
85
86                match shield.storage().user_by_id(&authenticated.user_id).await {
87                    Ok(user) => {
88                        if user.is_none() {
89                            if let Err(_err) = shield_session.purge().await {
90                                return Ok(Self::internal_server_error());
91                            }
92                        }
93
94                        user
95                    }
96                    Err(_err) => return Ok(Self::internal_server_error()),
97                }
98            } else {
99                None
100            };
101
102            debug!("{:?}", user.as_ref().map(|user| user.id()));
103
104            req.extensions_mut().insert(shield);
105            req.extensions_mut().insert(shield_session);
106            req.extensions_mut().insert(user);
107
108            inner.call(req).await
109        })
110    }
111}