1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use http::{Request, Response};
8use shield::{Session, Shield, User};
9use tower_service::Service;
10
11use crate::session::TowerSessionStorage;
12
13#[derive(Clone)]
14pub struct ShieldService<S, U: User> {
15 inner: S,
16 shield: Shield<U>,
17 session_key: &'static str,
18}
19
20impl<S, U: User> ShieldService<S, U> {
21 pub fn new(inner: S, shield: Shield<U>, session_key: &'static str) -> Self {
22 Self {
23 inner,
24 shield,
25 session_key,
26 }
27 }
28
29 fn internal_server_error<ResBody: Default>() -> Response<ResBody> {
30 let mut response = Response::default();
31 *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
32 response
33 }
34}
35
36impl<S, U: User + Clone + 'static, ReqBody, ResBody> Service<Request<ReqBody>>
37 for ShieldService<S, U>
38where
39 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
40 S::Future: Send + 'static,
41 ReqBody: Send + 'static,
42 ResBody: Default + Send,
43{
44 type Response = S::Response;
45 type Error = S::Error;
46 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
47
48 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49 self.inner.poll_ready(cx)
50 }
51
52 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
53 let clone = self.inner.clone();
57 let mut inner = std::mem::replace(&mut self.inner, clone);
58
59 let shield = self.shield.clone();
60 let session_key = self.session_key;
61
62 Box::pin(async move {
63 let session = match req.extensions().get::<tower_sessions::Session>() {
64 Some(session) => session,
65 None => {
66 return Ok(Self::internal_server_error());
67 }
68 };
69
70 let session_storage =
71 match TowerSessionStorage::load(session.clone(), session_key).await {
72 Ok(session_storage) => session_storage,
73 Err(_err) => return Ok(Self::internal_server_error()),
74 };
75 let shield_session = Session::new(session_storage);
76
77 req.extensions_mut().insert(shield);
83 req.extensions_mut().insert(shield_session);
84 inner.call(req).await
87 })
88 }
89}