1use std::fmt::Debug;
2
3use serde::{Deserialize, Serialize};
4use subtle::ConstantTimeEq;
5use tower_sessions::{session, Session};
6
7use crate::{
8 backend::{AuthUser, UserId},
9 AuthnBackend,
10};
11
/// Errors produced by [`AuthSession`] operations.
///
/// Wraps either a failure from the underlying `tower_sessions` session or a
/// failure reported by the authentication backend.
#[derive(thiserror::Error)]
pub enum Error<Backend: AuthnBackend> {
    /// An error from the `tower_sessions` session (e.g. loading, storing,
    /// flushing, or cycling the session ID).
    ///
    /// `#[error(transparent)]` forwards `Display` and `source` to the wrapped error.
    #[error(transparent)]
    Session(session::Error),

    /// An error reported by the authentication backend (`Backend::Error`).
    ///
    /// `#[error(transparent)]` forwards `Display` and `source` to the wrapped error.
    #[error(transparent)]
    Backend(Backend::Error),
}
23
24impl<Backend: AuthnBackend> Debug for Error<Backend> {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 Error::Session(err) => write!(f, "{err:?}")?,
28 Error::Backend(err) => write!(f, "{err:?}")?,
29 };
30
31 Ok(())
32 }
33}
34
35impl<Backend: AuthnBackend> From<session::Error> for Error<Backend> {
36 fn from(value: session::Error) -> Self {
37 Self::Session(value)
38 }
39}
40
/// Authentication state persisted in the session under `AuthSession::data_key`.
///
/// NOTE(review): the generic parameter is named `UserId`, which shadows the
/// `crate::backend::UserId` alias imported at the top of this file; inside this
/// definition `UserId` refers to the generic, not the alias.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct Data<UserId> {
    /// Identifier of the logged-in user, or `None` when nobody is logged in.
    user_id: Option<UserId>,
    /// The user's `session_auth_hash()` captured at login. On load it is compared
    /// against the user's current hash; a mismatch invalidates the session.
    auth_hash: Option<Vec<u8>>,
}
46
47impl<UserId: Clone> Default for Data<UserId> {
48 fn default() -> Self {
49 Self {
50 user_id: None,
51 auth_hash: None,
52 }
53 }
54}
55
/// Authentication handle that ties a `tower_sessions` [`Session`] to an
/// [`AuthnBackend`].
///
/// It exposes the currently authenticated user and provides `authenticate`,
/// `login` and `logout`. Within the crate it is built by
/// `AuthSession::from_session`, which loads the persisted auth data and
/// re-validates the user's session auth hash.
#[derive(Debug, Clone)]
pub struct AuthSession<Backend: AuthnBackend> {
    /// The currently authenticated user, or `None` if no one is logged in.
    pub user: Option<Backend::User>,

    /// The backend used to verify credentials and load users by id.
    pub backend: Backend,

    /// The underlying session in which authentication data is stored.
    pub session: Session,

    /// In-memory copy of the auth data stored in `session` under `data_key`.
    data: Data<UserId<Backend>>,
    /// Session key under which `data` is read and written.
    data_key: &'static str,
}
88
89impl<Backend: AuthnBackend> AuthSession<Backend> {
90 #[tracing::instrument(level = "debug", skip_all, fields(user.id), ret, err)]
93 pub async fn authenticate(
94 &self,
95 creds: Backend::Credentials,
96 ) -> Result<Option<Backend::User>, Error<Backend>> {
97 let result = self
98 .backend
99 .authenticate(creds)
100 .await
101 .map_err(Error::Backend);
102
103 if let Ok(Some(ref user)) = result {
104 tracing::Span::current().record("user.id", user.id().to_string());
105 }
106
107 result
108 }
109
110 #[tracing::instrument(level = "debug", skip_all, fields(user.id = user.id().to_string()), ret, err)]
112 pub async fn login(&mut self, user: &Backend::User) -> Result<(), Error<Backend>> {
113 self.user = Some(user.clone());
114
115 if self.data.auth_hash.is_none() {
116 self.session.cycle_id().await?; }
119
120 self.data.user_id = Some(user.id());
121 self.data.auth_hash = Some(user.session_auth_hash().to_owned());
122
123 self.update_session().await?;
124
125 Ok(())
126 }
127
128 #[tracing::instrument(level = "debug", skip_all, fields(user.id), ret, err)]
130 pub async fn logout(&mut self) -> Result<Option<Backend::User>, Error<Backend>> {
131 let user = self.user.take();
132
133 if let Some(ref user) = user {
134 tracing::Span::current().record("user.id", user.id().to_string());
135 }
136
137 self.session.flush().await?;
138
139 Ok(user)
140 }
141
142 async fn update_session(&mut self) -> Result<(), session::Error> {
143 self.session.insert(self.data_key, self.data.clone()).await
144 }
145
146 pub(crate) async fn from_session(
147 session: Session,
148 backend: Backend,
149 data_key: &'static str,
150 ) -> Result<Self, Error<Backend>> {
151 let mut data: Data<_> = session.get(data_key).await?.unwrap_or_default();
152
153 let mut user = if let Some(ref user_id) = data.user_id {
154 backend.get_user(user_id).await.map_err(Error::Backend)?
155 } else {
156 None
157 };
158
159 if let Some(ref authed_user) = user {
160 let session_auth_hash = authed_user.session_auth_hash();
161 let session_verified = data
162 .auth_hash
163 .as_ref()
164 .is_some_and(|auth_hash| auth_hash.ct_eq(session_auth_hash).into());
165 if !session_verified {
166 user = None;
167 data = Data::default();
168 session.flush().await?;
169 }
170 }
171
172 Ok(Self {
173 user,
174 data,
175 backend,
176 session,
177 data_key,
178 })
179 }
180}
181
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use mockall::{predicate::*, *};
    use tower_sessions::MemoryStore;

    use super::*;

    // `mockall`-generated `MockBackend`: implements `Clone`, `Debug` and
    // `AuthnBackend` so tests can set expectations on `authenticate` and
    // `get_user` without a real user store.
    mock! {
        #[derive(Debug)]
        Backend {}

        impl Clone for Backend {
            fn clone(&self) -> Self;
        }

        impl AuthnBackend for Backend {
            type User = MockUser;
            type Credentials = MockCredentials;
            type Error = MockError;

            async fn authenticate(&self, creds: MockCredentials) -> Result<Option<MockUser>, MockError>;
            async fn get_user(&self, user_id: &i64) -> Result<Option<MockUser>, MockError>;

        }
    }

    // Minimal user whose id and session auth hash are set directly by each test.
    #[derive(Debug, Clone)]
    struct MockUser {
        id: i64,
        auth_hash: Vec<u8>,
    }

    impl AuthUser for MockUser {
        type Id = i64;

        fn id(&self) -> Self::Id {
            self.id
        }

        fn session_auth_hash(&self) -> &[u8] {
            &self.auth_hash
        }
    }

    // Unit credentials; `PartialEq` lets mockall match them with `eq(...)`.
    #[derive(Debug, Clone, PartialEq)]
    struct MockCredentials;

    // Backend error type; never produced by the tests below.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Mock error")
        }
    }

    impl std::error::Error for MockError {}

    // Valid credentials: `authenticate` returns the backend's user.
    #[tokio::test]
    async fn test_authenticate() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };
        let creds = MockCredentials;

        mock_backend
            .expect_authenticate()
            .with(eq(creds.clone()))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());

        let session = Session::new(None, store, None);
        let auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let result = auth_session.authenticate(creds).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    // Rejected credentials: `authenticate` succeeds but yields no user.
    #[tokio::test]
    async fn test_authenticate_bad_credentials() {
        let mut mock_backend = MockBackend::default();
        let bad_creds = MockCredentials;

        mock_backend
            .expect_authenticate()
            .with(eq(bad_creds.clone()))
            .times(1)
            .returning(|_| Ok(None));

        let store = Arc::new(MemoryStore::default());

        let session = Session::new(None, store, None);
        let auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let result = auth_session.authenticate(bad_creds).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    // `login` sets the user and the session ends up with an ID after saving.
    #[tokio::test]
    async fn test_login() {
        let mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store, None);
        let original_session_id = session.id();
        // Presumably `Session` clones share state (the assertions below on
        // `session` rely on seeing changes made through `auth_session.session`).
        let mut auth_session = AuthSession {
            user: None,
            backend: mock_backend,
            data: Data::default(),
            session: session.clone(),
            data_key: "auth_data",
        };

        auth_session.login(&mock_user).await.unwrap();
        assert!(auth_session.user.is_some());
        assert_eq!(auth_session.user.unwrap().id(), 42);

        session.save().await.unwrap();

        // The freshly created session had no ID ...
        assert!(original_session_id.is_none());

        // ... and after logging in and saving it has one.
        assert!(session.id().is_some());
    }

    // `logout` returns the previously logged-in user and clears `user`.
    #[tokio::test]
    async fn test_logout() {
        let mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: Default::default(),
        };

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store, None);
        let mut auth_session = AuthSession {
            user: Some(mock_user.clone()),
            backend: mock_backend,
            data: Data::default(),
            session,
            data_key: "auth_data",
        };

        let logged_out_user = auth_session.logout().await.unwrap();
        assert!(logged_out_user.is_some());
        assert_eq!(logged_out_user.unwrap().id(), 42);
        assert!(auth_session.user.is_none());
    }

    // Stored auth hash matches the user's hash: the user is restored.
    #[tokio::test]
    async fn test_from_session() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: vec![1, 2, 3, 4],
        };

        mock_backend
            .expect_get_user()
            .with(eq(mock_user.id))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store.clone(), None);
        let data_key = "auth_data";

        // Same hash as `mock_user.auth_hash`.
        let data = Data {
            user_id: Some(42),
            auth_hash: Some(vec![1, 2, 3, 4]),
        };
        session.insert(data_key, &data).await.unwrap();

        let auth_session = AuthSession::from_session(session, mock_backend, data_key)
            .await
            .unwrap();

        assert!(auth_session.user.is_some());
        assert_eq!(auth_session.user.unwrap().id(), 42);
    }

    // Stored auth hash differs from the user's hash: the session is invalidated.
    #[tokio::test]
    async fn test_from_session_bad_auth_hash() {
        let mut mock_backend = MockBackend::default();
        let mock_user = MockUser {
            id: 42,
            auth_hash: vec![1, 2, 3, 4],
        };

        mock_backend
            .expect_get_user()
            .with(eq(mock_user.id))
            .times(1)
            .returning(move |_| Ok(Some(mock_user.clone())));

        let store = Arc::new(MemoryStore::default());
        let session = Session::new(None, store.clone(), None);
        let data_key = "auth_data";

        // Deliberately different from `mock_user.auth_hash`.
        let data = Data {
            user_id: Some(42),
            auth_hash: Some(vec![4, 3, 2, 1]),
        };
        session.insert(data_key, &data).await.unwrap();

        let auth_session = AuthSession::from_session(session, mock_backend, data_key)
            .await
            .unwrap();

        assert!(auth_session.user.is_none());
    }
}