Skip to main content

axum_security/cookie/
service.rs

1use std::{
2    convert::Infallible,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use axum::{extract::Request, http::StatusCode, response::IntoResponse};
8use tower::{Layer, Service};
9
10use crate::cookie::CookieContext;
11
12pub struct CookieService<S, SERV> {
13    inner: CookieContext<S>,
14    rest: SERV,
15}
16
17impl<S, SERV> Service<Request> for CookieService<S, SERV>
18where
19    SERV: Service<Request, Error = Infallible> + Clone + Send + 'static,
20    <SERV as Service<Request>>::Response: IntoResponse,
21    <SERV as Service<Request>>::Future: Send,
22    S: Clone + Send + Sync + 'static,
23{
24    type Response = axum::response::Response;
25
26    type Error = Infallible;
27
28    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
29
30    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
31        self.rest.poll_ready(cx)
32    }
33
34    fn call(&mut self, mut req: Request) -> Self::Future {
35        let mut this = self.clone();
36        Box::pin(async move {
37            match this.inner.load_from_headers(req.headers()).await {
38                Ok(Some(session)) => {
39                    req.extensions_mut().insert(session);
40                }
41                Ok(None) => {}
42                Err(_) => return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
43            }
44
45            this.rest.call(req).await.map(|e| e.into_response())
46        })
47    }
48}
49
50impl<SERV, T> Layer<SERV> for CookieContext<T>
51where
52    T: 'static,
53{
54    type Service = CookieService<T, SERV>;
55
56    fn layer(&self, inner: SERV) -> Self::Service {
57        CookieService {
58            inner: self.clone(),
59            rest: inner,
60        }
61    }
62}
63
64impl<T, SERV> Clone for CookieService<T, SERV>
65where
66    SERV: Clone,
67{
68    fn clone(&self) -> Self {
69        Self {
70            inner: self.inner.clone(),
71            rest: self.rest.clone(),
72        }
73    }
74}