1use std::{
2 fmt::Debug,
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use axum::http::{self, Request, Response};
9use tower_cookies::CookieManager;
10use tower_layer::Layer;
11use tower_service::Service;
12use tower_sessions::{
13 service::{CookieController, PlaintextCookie},
14 Session, SessionManager, SessionManagerLayer, SessionStore,
15};
16use tracing::Instrument;
17
18use crate::{AuthSession, AuthUser, AuthnBackend};
19
20#[derive(Debug, Clone)]
22pub struct AuthManager<S, Backend: AuthnBackend> {
23 inner: S,
24 backend: Backend,
25 data_key: &'static str,
26}
27
28impl<S, Backend: AuthnBackend> AuthManager<S, Backend> {
29 pub fn new(inner: S, backend: Backend, data_key: &'static str) -> Self {
31 Self {
32 inner,
33 backend,
34 data_key,
35 }
36 }
37}
38
39impl<ReqBody, ResBody, S, Backend> Service<Request<ReqBody>> for AuthManager<S, Backend>
40where
41 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
42 S::Future: Send + 'static,
43 ReqBody: Send + 'static,
44 ResBody: Default + Send,
45 Backend: AuthnBackend + 'static,
46{
47 type Response = S::Response;
48 type Error = S::Error;
49 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
50
51 #[inline]
52 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
53 self.inner.poll_ready(cx)
54 }
55
56 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
57 let span = tracing::info_span!("call", user.id = tracing::field::Empty);
58
59 let backend = self.backend.clone();
60 let data_key = self.data_key;
61
62 let clone = self.inner.clone();
67 let mut inner = std::mem::replace(&mut self.inner, clone);
68
69 Box::pin(
70 async move {
71 let Some(session) = req.extensions().get::<Session>().cloned() else {
72 tracing::error!("session not found in request extensions");
73 let mut res = Response::default();
74 *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
75 return Ok(res);
76 };
77
78 let auth_session = match AuthSession::from_session(session, backend, data_key).await
79 {
80 Ok(auth_session) => auth_session,
81 Err(err) => {
82 tracing::error!(
83 err = %err,
84 "could not create auth session from session"
85 );
86 let mut res = Response::default();
87 *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
88 return Ok(res);
89 }
90 };
91
92 if let Some(ref user) = auth_session.user {
93 tracing::Span::current().record("user.id", user.id().to_string());
94 }
95
96 req.extensions_mut().insert(auth_session);
97
98 inner.call(req).await
99 }
100 .instrument(span),
101 )
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct AuthManagerLayer<
108 Backend: AuthnBackend,
109 Sessions: SessionStore,
110 C: CookieController = PlaintextCookie,
111> {
112 backend: Backend,
113 session_manager_layer: SessionManagerLayer<Sessions, C>,
114 data_key: &'static str,
115}
116
117impl<Backend: AuthnBackend, Sessions: SessionStore, C: CookieController>
118 AuthManagerLayer<Backend, Sessions, C>
119{
120 pub(crate) fn new(
122 backend: Backend,
123 data_key: &'static str,
124 session_manager_layer: SessionManagerLayer<Sessions, C>,
125 ) -> Self {
126 Self {
127 backend,
128 session_manager_layer,
129 data_key,
130 }
131 }
132}
133
134impl<S, Backend: AuthnBackend, Sessions: SessionStore, C: CookieController> Layer<S>
135 for AuthManagerLayer<Backend, Sessions, C>
136{
137 type Service = CookieManager<SessionManager<AuthManager<S, Backend>, Sessions, C>>;
138
139 fn layer(&self, inner: S) -> Self::Service {
140 let login_manager = AuthManager {
141 inner,
142 backend: self.backend.clone(),
143 data_key: self.data_key,
144 };
145
146 self.session_manager_layer.layer(login_manager)
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct AuthManagerLayerBuilder<
153 Backend: AuthnBackend,
154 Sessions: SessionStore,
155 C: CookieController = PlaintextCookie,
156> {
157 backend: Backend,
158 session_manager_layer: SessionManagerLayer<Sessions, C>,
159 data_key: Option<&'static str>,
160}
161
162impl<Backend: AuthnBackend, Sessions: SessionStore, C: CookieController>
163 AuthManagerLayerBuilder<Backend, Sessions, C>
164{
165 pub fn new(backend: Backend, session_manager_layer: SessionManagerLayer<Sessions, C>) -> Self {
168 Self {
169 backend,
170 session_manager_layer,
171 data_key: None,
172 }
173 }
174
175 pub fn with_data_key(
178 mut self,
179 data_key: &'static str,
180 ) -> AuthManagerLayerBuilder<Backend, Sessions, C> {
181 self.data_key = Some(data_key);
182 self
183 }
184
185 pub fn build(self) -> AuthManagerLayer<Backend, Sessions, C> {
187 AuthManagerLayer::new(
188 self.backend,
189 self.data_key.unwrap_or("axum-login.data"),
190 self.session_manager_layer,
191 )
192 }
193}