axum_login/
service.rs

1use std::{
2    fmt::Debug,
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use axum::http::{self, Request, Response};
9use tower_cookies::CookieManager;
10use tower_layer::Layer;
11use tower_service::Service;
12use tower_sessions::{
13    service::{CookieController, PlaintextCookie},
14    Session, SessionManager, SessionManagerLayer, SessionStore,
15};
16use tracing::Instrument;
17
18use crate::{AuthSession, AuthUser, AuthnBackend};
19
20/// A middleware that provides [`AuthSession`] as a request extension.
21#[derive(Debug, Clone)]
22pub struct AuthManager<S, Backend: AuthnBackend> {
23    inner: S,
24    backend: Backend,
25    data_key: &'static str,
26}
27
28impl<S, Backend: AuthnBackend> AuthManager<S, Backend> {
29    /// Create a new [`AuthManager`] with the provided access controller.
30    pub fn new(inner: S, backend: Backend, data_key: &'static str) -> Self {
31        Self {
32            inner,
33            backend,
34            data_key,
35        }
36    }
37}
38
39impl<ReqBody, ResBody, S, Backend> Service<Request<ReqBody>> for AuthManager<S, Backend>
40where
41    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
42    S::Future: Send + 'static,
43    ReqBody: Send + 'static,
44    ResBody: Default + Send,
45    Backend: AuthnBackend + 'static,
46{
47    type Response = S::Response;
48    type Error = S::Error;
49    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
50
51    #[inline]
52    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
53        self.inner.poll_ready(cx)
54    }
55
56    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
57        let span = tracing::info_span!("call", user.id = tracing::field::Empty);
58
59        let backend = self.backend.clone();
60        let data_key = self.data_key;
61
62        // Because the inner service can panic until ready, we need to ensure we only
63        // use the ready service.
64        //
65        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
66        let clone = self.inner.clone();
67        let mut inner = std::mem::replace(&mut self.inner, clone);
68
69        Box::pin(
70            async move {
71                let Some(session) = req.extensions().get::<Session>().cloned() else {
72                    tracing::error!("session not found in request extensions");
73                    let mut res = Response::default();
74                    *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
75                    return Ok(res);
76                };
77
78                let auth_session = match AuthSession::from_session(session, backend, data_key).await
79                {
80                    Ok(auth_session) => auth_session,
81                    Err(err) => {
82                        tracing::error!(
83                            err = %err,
84                            "could not create auth session from session"
85                        );
86                        let mut res = Response::default();
87                        *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
88                        return Ok(res);
89                    }
90                };
91
92                if let Some(ref user) = auth_session.user {
93                    tracing::Span::current().record("user.id", user.id().to_string());
94                }
95
96                req.extensions_mut().insert(auth_session);
97
98                inner.call(req).await
99            }
100            .instrument(span),
101        )
102    }
103}
104
105/// A layer for providing [`AuthSession`] as a request extension.
106#[derive(Debug, Clone)]
107pub struct AuthManagerLayer<
108    Backend: AuthnBackend,
109    Sessions: SessionStore,
110    C: CookieController = PlaintextCookie,
111> {
112    backend: Backend,
113    session_manager_layer: SessionManagerLayer<Sessions, C>,
114    data_key: &'static str,
115}
116
117impl<Backend: AuthnBackend, Sessions: SessionStore, C: CookieController>
118    AuthManagerLayer<Backend, Sessions, C>
119{
120    /// Create a new [`AuthManagerLayer`] with the provided access controller.
121    pub(crate) fn new(
122        backend: Backend,
123        data_key: &'static str,
124        session_manager_layer: SessionManagerLayer<Sessions, C>,
125    ) -> Self {
126        Self {
127            backend,
128            session_manager_layer,
129            data_key,
130        }
131    }
132}
133
134impl<S, Backend: AuthnBackend, Sessions: SessionStore, C: CookieController> Layer<S>
135    for AuthManagerLayer<Backend, Sessions, C>
136{
137    type Service = CookieManager<SessionManager<AuthManager<S, Backend>, Sessions, C>>;
138
139    fn layer(&self, inner: S) -> Self::Service {
140        let login_manager = AuthManager {
141            inner,
142            backend: self.backend.clone(),
143            data_key: self.data_key,
144        };
145
146        self.session_manager_layer.layer(login_manager)
147    }
148}
149
150/// Builder for the [`AuthManagerLayer`].
151#[derive(Debug, Clone)]
152pub struct AuthManagerLayerBuilder<
153    Backend: AuthnBackend,
154    Sessions: SessionStore,
155    C: CookieController = PlaintextCookie,
156> {
157    backend: Backend,
158    session_manager_layer: SessionManagerLayer<Sessions, C>,
159    data_key: Option<&'static str>,
160}
161
162impl<Backend: AuthnBackend, Sessions: SessionStore, C: CookieController>
163    AuthManagerLayerBuilder<Backend, Sessions, C>
164{
165    /// Create a new [`AuthManagerLayerBuilder`] with the provided access
166    /// controller.
167    pub fn new(backend: Backend, session_manager_layer: SessionManagerLayer<Sessions, C>) -> Self {
168        Self {
169            backend,
170            session_manager_layer,
171            data_key: None,
172        }
173    }
174
175    /// Configure the `data_key` optional property of the builder. If not
176    /// configured it will default to "axum-login.data".
177    pub fn with_data_key(
178        mut self,
179        data_key: &'static str,
180    ) -> AuthManagerLayerBuilder<Backend, Sessions, C> {
181        self.data_key = Some(data_key);
182        self
183    }
184
185    /// Build the [`AuthManagerLayer`].
186    pub fn build(self) -> AuthManagerLayer<Backend, Sessions, C> {
187        AuthManagerLayer::new(
188            self.backend,
189            self.data_key.unwrap_or("axum-login.data"),
190            self.session_manager_layer,
191        )
192    }
193}