shield_tower/
service.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use http::{Request, Response};
8use shield::{Session, Shield, User};
9use tower_service::Service;
10
11use crate::session::TowerSessionStorage;
12
13#[derive(Clone)]
14pub struct ShieldService<S, U: User> {
15    inner: S,
16    shield: Shield<U>,
17    session_key: &'static str,
18}
19
20impl<S, U: User> ShieldService<S, U> {
21    pub fn new(inner: S, shield: Shield<U>, session_key: &'static str) -> Self {
22        Self {
23            inner,
24            shield,
25            session_key,
26        }
27    }
28
29    fn internal_server_error<ResBody: Default>() -> Response<ResBody> {
30        let mut response = Response::default();
31        *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
32        response
33    }
34}
35
36impl<S, U: User + Clone + 'static, ReqBody, ResBody> Service<Request<ReqBody>>
37    for ShieldService<S, U>
38where
39    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
40    S::Future: Send + 'static,
41    ReqBody: Send + 'static,
42    ResBody: Default + Send,
43{
44    type Response = S::Response;
45    type Error = S::Error;
46    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
47
48    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49        self.inner.poll_ready(cx)
50    }
51
52    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
53        // TODO: Improve error handling to not only return a 500 response.
54
55        //  https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
56        let clone = self.inner.clone();
57        let mut inner = std::mem::replace(&mut self.inner, clone);
58
59        let shield = self.shield.clone();
60        let session_key = self.session_key;
61
62        Box::pin(async move {
63            let session = match req.extensions().get::<tower_sessions::Session>() {
64                Some(session) => session,
65                None => {
66                    return Ok(Self::internal_server_error());
67                }
68            };
69
70            let session_storage =
71                match TowerSessionStorage::load(session.clone(), session_key).await {
72                    Ok(session_storage) => session_storage,
73                    Err(_err) => return Ok(Self::internal_server_error()),
74                };
75            let shield_session = Session::new(session_storage);
76
77            // let user = match shield.user(&shield_session).await {
78            //     Ok(user) => user,
79            //     Err(_err) => return Ok(Self::internal_server_error()),
80            // };
81
82            req.extensions_mut().insert(shield);
83            req.extensions_mut().insert(shield_session);
84            // req.extensions_mut().insert(user);
85
86            inner.call(req).await
87        })
88    }
89}