axum_login/
session.rs

1use std::fmt::Debug;
2
3use serde::{Deserialize, Serialize};
4use subtle::ConstantTimeEq;
5use tower_sessions::{session, Session};
6
7use crate::{
8    backend::{AuthUser, UserId},
9    AuthnBackend,
10};
11
12/// An error type which maps session and backend errors.
13#[derive(thiserror::Error)]
14pub enum Error<Backend: AuthnBackend> {
15    /// A mapping to `tower_sessions::session::Error'.
16    #[error(transparent)]
17    Session(session::Error),
18
19    /// A mapping to `Backend::Error`.
20    #[error(transparent)]
21    Backend(Backend::Error),
22}
23
24impl<Backend: AuthnBackend> Debug for Error<Backend> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Error::Session(err) => write!(f, "{err:?}")?,
28            Error::Backend(err) => write!(f, "{err:?}")?,
29        };
30
31        Ok(())
32    }
33}
34
35impl<Backend: AuthnBackend> From<session::Error> for Error<Backend> {
36    fn from(value: session::Error) -> Self {
37        Self::Session(value)
38    }
39}
40
41#[derive(Debug, Clone, Deserialize, Serialize)]
42struct Data<UserId> {
43    user_id: Option<UserId>,
44    auth_hash: Option<Vec<u8>>,
45}
46
47impl<UserId: Clone> Default for Data<UserId> {
48    fn default() -> Self {
49        Self {
50            user_id: None,
51            auth_hash: None,
52        }
53    }
54}
55
/// A specialized session for identification, authentication, and authorization
/// of users associated with a backend.
///
/// The session is generic over some backend which implements [`AuthnBackend`].
/// The backend may also implement [`AuthzBackend`](crate::AuthzBackend),
/// in which case it will also supply authorization methods.
///
/// Methods are provided for authenticating credentials and for logging a user
/// in and out.
///
/// Generally this session will be used in the context of some authentication
/// workflow, for example via a frontend login form. There a user would provide
/// their credentials, such as username and password, and via the backend
/// the session would authenticate those credentials.
///
/// Once the supplied credentials have been authenticated, a user will be
/// returned. In the case the credentials are invalid, no user will be returned.
/// When we do have a user, it's then possible to set the state of the session
/// so that the user is logged in.
#[derive(Debug, Clone)]
pub struct AuthSession<Backend: AuthnBackend> {
    /// The user associated by the backend. `None` when not logged in.
    pub user: Option<Backend::User>,

    /// The authentication and authorization backend.
    pub backend: Backend,

    /// The underlying session.
    pub session: Session,

    // Local copy of the auth data (user ID + auth hash) that is written to
    // `session` under `data_key`.
    data: Data<UserId<Backend>>,
    // Session key under which `data` is stored.
    data_key: &'static str,
}
88
89impl<Backend: AuthnBackend> AuthSession<Backend> {
90    /// Verifies the provided credentials via the backend returning the
91    /// authenticated user if valid and otherwise `None`.
92    #[tracing::instrument(level = "debug", skip_all, fields(user.id), ret, err)]
93    pub async fn authenticate(
94        &self,
95        creds: Backend::Credentials,
96    ) -> Result<Option<Backend::User>, Error<Backend>> {
97        let result = self
98            .backend
99            .authenticate(creds)
100            .await
101            .map_err(Error::Backend);
102
103        if let Ok(Some(ref user)) = result {
104            tracing::Span::current().record("user.id", user.id().to_string());
105        }
106
107        result
108    }
109
110    /// Updates the session such that the user is logged in.
111    #[tracing::instrument(level = "debug", skip_all, fields(user.id = user.id().to_string()), ret, err)]
112    pub async fn login(&mut self, user: &Backend::User) -> Result<(), Error<Backend>> {
113        self.user = Some(user.clone());
114
115        if self.data.auth_hash.is_none() {
116            self.session.cycle_id().await?; // Session-fixation
117                                            // mitigation.
118        }
119
120        self.data.user_id = Some(user.id());
121        self.data.auth_hash = Some(user.session_auth_hash().to_owned());
122
123        self.update_session().await?;
124
125        Ok(())
126    }
127
128    /// Updates the session such that the user is logged out.
129    #[tracing::instrument(level = "debug", skip_all, fields(user.id), ret, err)]
130    pub async fn logout(&mut self) -> Result<Option<Backend::User>, Error<Backend>> {
131        let user = self.user.take();
132
133        if let Some(ref user) = user {
134            tracing::Span::current().record("user.id", user.id().to_string());
135        }
136
137        self.session.flush().await?;
138
139        Ok(user)
140    }
141
142    async fn update_session(&mut self) -> Result<(), session::Error> {
143        self.session.insert(self.data_key, self.data.clone()).await
144    }
145
146    pub(crate) async fn from_session(
147        session: Session,
148        backend: Backend,
149        data_key: &'static str,
150    ) -> Result<Self, Error<Backend>> {
151        let mut data: Data<_> = session.get(data_key).await?.unwrap_or_default();
152
153        let mut user = if let Some(ref user_id) = data.user_id {
154            backend.get_user(user_id).await.map_err(Error::Backend)?
155        } else {
156            None
157        };
158
159        if let Some(ref authed_user) = user {
160            let session_auth_hash = authed_user.session_auth_hash();
161            let session_verified = data
162                .auth_hash
163                .as_ref()
164                .is_some_and(|auth_hash| auth_hash.ct_eq(session_auth_hash).into());
165            if !session_verified {
166                user = None;
167                data = Data::default();
168                session.flush().await?;
169            }
170        }
171
172        Ok(Self {
173            user,
174            data,
175            backend,
176            session,
177            data_key,
178        })
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use mockall::{predicate::*, *};
    use tower_sessions::MemoryStore;

    use super::*;

    // Generates `MockBackend`, a mockall test double implementing `Clone` and
    // `AuthnBackend`; each test sets expectations on the methods it calls.
    mock! {
        #[derive(Debug)]
        Backend {}

        impl Clone for Backend {
            fn clone(&self) -> Self;
        }

        impl AuthnBackend for Backend {
            type User = MockUser;
            type Credentials = MockCredentials;
            type Error = MockError;

            async fn authenticate(&self, creds: MockCredentials) -> Result<Option<MockUser>, MockError>;
            async fn get_user(&self, user_id: &i64) -> Result<Option<MockUser>, MockError>;

        }
    }

    // Minimal user with an `i64` ID and a configurable session auth hash.
    #[derive(Debug, Clone)]
    struct MockUser {
        id: i64,
        auth_hash: Vec<u8>,
    }

    impl AuthUser for MockUser {
        type Id = i64;

        fn id(&self) -> Self::Id {
            self.id
        }

        fn session_auth_hash(&self) -> &[u8] {
            &self.auth_hash
        }
    }

    // Unit credentials; `PartialEq` lets mockall's `eq` predicate match them.
    #[derive(Debug, Clone, PartialEq)]
    struct MockCredentials;

    // Backend error type required by `AuthnBackend::Error`.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Mock error")
        }
    }

    impl std::error::Error for MockError {}

    // Valid credentials: the backend returns a user, which is passed through.
    #[tokio::test]
    async fn test_authenticate() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };
        let creds = MockCredentials;

        mock_backend
            .expect_authenticate()
            .with(eq(creds.clone()))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());

        let session = Session::new(None, store, None);
        let auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let result = auth_session.authenticate(creds).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    // Invalid credentials: the backend returns `Ok(None)`, which is not an error.
    #[tokio::test]
    async fn test_authenticate_bad_credentials() {
        let mut mock_backend = MockBackend::default();
        let bad_creds = MockCredentials;

        mock_backend
            .expect_authenticate()
            .with(eq(bad_creds.clone()))
            .times(1)
            .returning(|_| Ok(None));

        let store = Arc::new(MemoryStore::default());

        let session = Session::new(None, store, None);
        let auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let result = auth_session.authenticate(bad_creds).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    // Logging in sets `user`, and saving the session assigns it an ID.
    #[tokio::test]
    async fn test_login() {
        let mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store, None);
        let original_session_id = session.id();
        // `session.clone()` keeps a handle so the test can save/inspect it below.
        let mut auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session: session.clone(),
            data_key: "auth_data",
        };

        auth_session.login(&mock_user).await.unwrap();
        assert!(auth_session.user.is_some());
        assert_eq!(auth_session.user.unwrap().id(), 42);

        // Simulate request persisting session.
        session.save().await.unwrap();

        // We were provided no session initially.
        assert!(original_session_id.is_none());

        // We have a session ID after saving.
        assert!(session.id().is_some());
    }

    // Logging out returns the previously logged-in user and clears `user`.
    #[tokio::test]
    async fn test_logout() {
        let mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store, None);
        let mut auth_session = AuthSession {
            user: Some(mock_user.clone()),
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let logged_out_user = auth_session.logout().await.unwrap();
        assert!(logged_out_user.is_some());
        assert_eq!(logged_out_user.unwrap().id(), 42);
        assert!(auth_session.user.is_none());
    }

    // Stored data whose auth hash matches the backend's user yields a logged-in session.
    #[tokio::test]
    async fn test_from_session() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: vec![1, 2, 3, 4],
        };

        mock_backend
            .expect_get_user()
            .with(eq(mock_user.id))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store.clone(), None);
        let data_key = "auth_data";

        // Simulate a user being logged in
        let data = Data {
            user_id: Some(42),
            auth_hash: Some(vec![1, 2, 3, 4]),
        };
        session.insert(data_key, &data).await.unwrap();

        let auth_session = AuthSession::from_session(session, mock_backend, data_key)
            .await
            .unwrap();

        assert!(auth_session.user.is_some());
        assert_eq!(auth_session.user.unwrap().id(), 42);
    }

    // A stored auth hash that differs from the backend's user is rejected.
    #[tokio::test]
    async fn test_from_session_bad_auth_hash() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: vec![1, 2, 3, 4],
        };

        mock_backend
            .expect_get_user()
            .with(eq(mock_user.id))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store.clone(), None);
        let data_key = "auth_data";

        // Try to use a malformed auth hash.
        let data = Data {
            user_id: Some(42),
            auth_hash: Some(vec![4, 3, 2, 1]),
        };
        session.insert(data_key, &data).await.unwrap();

        let auth_session = AuthSession::from_session(session, mock_backend, data_key)
            .await
            .unwrap();

        assert!(auth_session.user.is_none());
    }
}