tower_sesh/
session.rs

1use std::{
2    ops::{Deref, DerefMut},
3    sync::Arc,
4};
5
6use async_trait::async_trait;
7use parking_lot::{Mutex, MutexGuard};
8use tower_sesh_core::{store::Ttl, Record, SessionKey};
9
10/// Extractor to read and mutate session data.
11///
12/// # Session migration
13///
14/// TODO
15///
16/// # Logging rejections
17///
18/// To see the logs, enable the `tracing` feature for `tower-sesh` (enabled by
19/// default) and the `tower_sesh::rejection=trace` tracing target, for example
20/// with `RUST_LOG=info,tower_sesh::rejection=trace cargo run`.
21pub struct Session<T>(Arc<Mutex<Inner<T>>>);
22
23/// A RAII mutex guard holding a lock to a mutex contained in `Session<T>`. The
24/// data `T` can be accessed through this guard via its [`Deref`] and
25/// [`DerefMut`] implementations.
26///
27/// The lock is automatically released whenever the guard is dropped.
28//
29// # Invariants
30//
31// 1. When constructing `SessionGuard`, the `data` contained within
32//    `SessionInner` must contain a `Some` variant. This invariant must be met
33//    while the mutex lock is held.
34// 2. After the previous invariant is met, and until the `SessionGuard` is
35//    dropped, the lock must never be released and `data` must never be replaced
36//    with `None`.
37pub struct SessionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
38
39/// A RAII mutex guard holding a lock to a mutex contained in `Session<T>`. The
40/// data `Option<T>` can be accessed through this guard via its [`Deref`] and
41/// [`DerefMut`] implementations.
42///
43/// The lock is automatically released whenever the guard is dropped.
44pub struct OptionSessionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
45
46struct Inner<T> {
47    session_key: Option<SessionKey>,
48    data: Option<T>,
49    expires_at: Option<Ttl>,
50    status: Status,
51}
52
/// Records what has happened to a session since it was constructed.
///
/// # State transitions
///
/// | From        | Allowed next states            |
/// |-------------|--------------------------------|
/// | `Unchanged` | `Changed`, `Renewed`, `Purged` |
/// | `Renewed`   | `Changed`, `Purged`            |
/// | `Changed`   | `Purged`                       |
/// | `Purged`    | none (terminal)                |
enum Status {
    /// Starting state. Every `Session` constructor in this module uses it.
    Unchanged,
    /// NOTE(review): presumably "expiry extended, data untouched"; confirm
    /// where this state is set.
    Renewed,
    /// The data was, or may have been, modified. Set by `Inner::changed`.
    Changed,
    /// Terminal state. `Inner::changed` never moves a session out of it.
    Purged,
}
use Status::*;
66
67impl<T> Inner<T> {
68    fn changed(&mut self) {
69        if !matches!(self.status, Purged) {
70            self.status = Changed;
71        }
72    }
73}
74
75impl<T> Session<T> {
76    fn new(session_key: SessionKey, record: Record<T>) -> Session<T> {
77        let inner = Inner {
78            session_key: Some(session_key),
79            data: Some(record.data),
80            expires_at: Some(record.ttl),
81            status: Unchanged,
82        };
83        Session(Arc::new(Mutex::new(inner)))
84    }
85
86    fn empty() -> Session<T> {
87        let inner = Inner {
88            session_key: None,
89            data: None,
90            expires_at: None,
91            status: Unchanged,
92        };
93        Session(Arc::new(Mutex::new(inner)))
94    }
95
96    fn ignored(session_key: SessionKey) -> Session<T> {
97        let inner = Inner {
98            session_key: Some(session_key),
99            data: None,
100            expires_at: None,
101            status: Unchanged,
102        };
103        Session(Arc::new(Mutex::new(inner)))
104    }
105
106    #[must_use]
107    pub fn get(&self) -> OptionSessionGuard<'_, T> {
108        let lock = self.0.lock();
109
110        OptionSessionGuard::new(lock)
111    }
112
113    pub fn insert(&self, value: T) -> SessionGuard<'_, T> {
114        let mut lock = self.0.lock();
115
116        lock.data = Some(value);
117        lock.changed();
118
119        // SAFETY: a `None` variant for `data` would have been replaced by a
120        // `Some` variant in the code above.
121        unsafe { SessionGuard::new(lock) }
122    }
123
124    pub fn get_or_insert(&self, value: T) -> SessionGuard<'_, T> {
125        let mut lock = self.0.lock();
126
127        if lock.data.is_none() {
128            lock.data = Some(value);
129            lock.changed();
130        }
131
132        // SAFETY: a `None` variant for `data` would have been replaced by a
133        // `Some` variant in the code above.
134        unsafe { SessionGuard::new(lock) }
135    }
136
137    pub fn get_or_insert_with<F>(&self, f: F) -> SessionGuard<'_, T>
138    where
139        F: FnOnce() -> T,
140    {
141        let mut lock = self.0.lock();
142
143        if lock.data.is_none() {
144            lock.data = Some(f());
145            lock.changed();
146        }
147
148        // SAFETY: a `None` variant for `data` would have been replaced by a
149        // `Some` variant in the code above.
150        unsafe { SessionGuard::new(lock) }
151    }
152
153    #[inline]
154    pub fn get_or_insert_default(&self) -> SessionGuard<'_, T>
155    where
156        T: Default,
157    {
158        self.get_or_insert_with(T::default)
159    }
160}
161
162impl<T> Clone for Session<T> {
163    fn clone(&self) -> Self {
164        Session(Arc::clone(&self.0))
165    }
166}
167
// Rejection produced by the `Session` extractor when the lazy session loader
// yields `None`. That happens when the session store returns an error, unless
// the error is a deserialization (`ErrorKind::Serde`) error and
// `SessionConfig::ignore_invalid_session` is set.
// NOTE(review): `#[status]` / `#[body]` presumably set the HTTP status code and
// response body for this rejection; confirm against `define_rejection!`.
define_rejection! {
    #[status = INTERNAL_SERVER_ERROR]
    #[body = "Failed to load session"]
    /// Rejection for [`Session`] if an unrecoverable error occurred when
    /// loading the session.
    pub struct SessionRejection;
}
175
/// Extracts the request's [`Session`], loading it from the session store the
/// first time it is extracted.
///
/// Rejects with [`SessionRejection`] when the session could not be loaded.
///
/// # Panics
///
/// Panics if the request extensions hold no lazy session for this `T`. That
/// means `SessionLayer` has not run, or it was set up for a different `T`.
#[cfg(feature = "axum")]
#[async_trait]
impl<S, T> axum::extract::FromRequestParts<S> for Session<T>
where
    T: 'static + Send + Sync,
{
    type Rejection = SessionRejection;

    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let lookup = lazy::get_or_init::<T>(&mut parts.extensions).await;

        match lookup {
            // A missing extension is a bug in how the application is wired
            // up, not an expected per-request failure, so panic rather than
            // reject.
            Err(_) => panic!(
                "Missing request extension. `SessionLayer` must be called \
                before the `Session` extractor is run. Also, check that the \
                generic type for `Session<T>` is correct."
            ),
            // `None` means the store failed to load the session.
            Ok(loaded) => loaded.ok_or(SessionRejection),
        }
    }
}
201
202impl<'a, T> SessionGuard<'a, T> {
203    /// # Safety
204    ///
205    /// The caller of this method must ensure that `guard.data` is a
206    /// `Some` variant.
207    #[track_caller]
208    unsafe fn new(guard: MutexGuard<'a, Inner<T>>) -> Self {
209        debug_assert!(guard.data.is_some());
210        SessionGuard(guard)
211    }
212}
213
214impl<T> Deref for SessionGuard<'_, T> {
215    type Target = T;
216
217    fn deref(&self) -> &Self::Target {
218        // SAFETY: `SessionGuard` holds the lock, so `data` can never be set
219        // to `None`.
220        unsafe { self.0.data.as_ref().unwrap_unchecked() }
221    }
222}
223
224impl<T> DerefMut for SessionGuard<'_, T> {
225    fn deref_mut(&mut self) -> &mut Self::Target {
226        self.0.changed();
227
228        // SAFETY: `SessionGuard` holds the lock, so `data` can never be set
229        // to `None`.
230        unsafe { self.0.data.as_mut().unwrap_unchecked() }
231    }
232}
233
234impl<'a, T> OptionSessionGuard<'a, T> {
235    fn new(guard: MutexGuard<'a, Inner<T>>) -> Self {
236        OptionSessionGuard(guard)
237    }
238}
239
240impl<T> Deref for OptionSessionGuard<'_, T> {
241    type Target = Option<T>;
242
243    fn deref(&self) -> &Self::Target {
244        &self.0.data
245    }
246}
247
248impl<T> DerefMut for OptionSessionGuard<'_, T> {
249    fn deref_mut(&mut self) -> &mut Self::Target {
250        self.0.changed();
251
252        &mut self.0.data
253    }
254}
255
256pub(crate) mod lazy {
257    use std::{error::Error as StdError, fmt, sync::Arc};
258
259    use async_once_cell::OnceCell;
260    use cookie::Cookie;
261    use http::Extensions;
262    use tower_sesh_core::{store::ErrorKind, SessionKey, SessionStore};
263
264    use crate::{middleware::SessionConfig, util::ErrorExt};
265
266    use super::Session;
267
268    pub(crate) fn insert<T>(
269        cookie: Option<Cookie<'static>>,
270        store: &Arc<impl SessionStore<T>>,
271        extensions: &mut Extensions,
272        session_config: SessionConfig,
273    ) where
274        T: 'static + Send,
275    {
276        debug_assert!(
277            extensions.get::<LazySession<T>>().is_none(),
278            "`session::lazy::insert` was called more than once!"
279        );
280
281        let lazy_session = match cookie {
282            Some(cookie) => LazySession::new(cookie, Arc::clone(store), session_config),
283            None => LazySession::empty(),
284        };
285        extensions.insert::<LazySession<T>>(lazy_session);
286    }
287
288    pub(super) async fn get_or_init<T>(
289        extensions: &mut Extensions,
290    ) -> Result<Option<Session<T>>, Error>
291    where
292        T: 'static + Send,
293    {
294        match extensions.get::<LazySession<T>>() {
295            Some(lazy_session) => Ok(lazy_session.get_or_init().await.cloned()),
296            None => Err(Error),
297        }
298    }
299
300    pub(crate) fn take<T>(extensions: &mut Extensions) -> Result<Option<Session<T>>, Error>
301    where
302        T: 'static + Send,
303    {
304        match extensions.remove::<LazySession<T>>() {
305            Some(lazy_session) => Ok(lazy_session.get().cloned()),
306            None => Err(Error),
307        }
308    }
309
310    enum LazySession<T> {
311        Empty(Arc<OnceCell<Session<T>>>),
312        Init {
313            cookie: Cookie<'static>,
314            store: Arc<dyn SessionStore<T> + 'static>,
315            session: Arc<OnceCell<Option<Session<T>>>>,
316            config: SessionConfig,
317        },
318    }
319
320    impl<T> Clone for LazySession<T> {
321        fn clone(&self) -> Self {
322            match self {
323                LazySession::Empty(session) => LazySession::Empty(Arc::clone(session)),
324                LazySession::Init {
325                    cookie,
326                    store,
327                    session,
328                    config,
329                } => LazySession::Init {
330                    cookie: cookie.clone(),
331                    store: Arc::clone(store),
332                    session: Arc::clone(session),
333                    config: config.clone(),
334                },
335            }
336        }
337    }
338
339    impl<T> LazySession<T>
340    where
341        T: 'static,
342    {
343        fn new(
344            cookie: Cookie<'static>,
345            store: Arc<impl SessionStore<T>>,
346            config: SessionConfig,
347        ) -> LazySession<T> {
348            LazySession::Init {
349                cookie,
350                store,
351                session: Arc::new(OnceCell::new()),
352                config,
353            }
354        }
355
356        fn empty() -> LazySession<T> {
357            LazySession::Empty(Arc::new(OnceCell::new()))
358        }
359
360        async fn get_or_init(&self) -> Option<&Session<T>> {
361            match self {
362                LazySession::Empty(session) => {
363                    Some(session.get_or_init(async { Session::empty() }).await)
364                }
365                LazySession::Init {
366                    cookie,
367                    store,
368                    session,
369                    config,
370                } => session
371                    .get_or_init(init_session(cookie, store.as_ref(), config))
372                    .await
373                    .as_ref(),
374            }
375        }
376
377        fn get(&self) -> Option<&Session<T>> {
378            match self {
379                LazySession::Empty(session) => session.get(),
380                LazySession::Init { session, .. } => session.get().and_then(Option::as_ref),
381            }
382        }
383    }
384
385    async fn init_session<T>(
386        cookie: &Cookie<'static>,
387        store: &dyn SessionStore<T>,
388        config: &SessionConfig,
389    ) -> Option<Session<T>>
390    where
391        T: 'static,
392    {
393        let session_key = match SessionKey::decode(cookie.value()) {
394            Ok(session_key) => session_key,
395            Err(_) => return Some(Session::empty()),
396        };
397
398        match store.load(&session_key).await {
399            Ok(Some(record)) => Some(Session::new(session_key, record)),
400            Ok(None) => Some(Session::empty()),
401            Err(err) => {
402                match err.kind() {
403                    ErrorKind::Serde(_) if config.ignore_invalid_session => {
404                        Some(Session::ignored(session_key))
405                    }
406                    _ => {
407                        // TODO: Better error reporting
408                        error!(message = %err.display_chain());
409                        None
410                    }
411                }
412            }
413        }
414    }
415
416    pub(crate) struct Error;
417
418    impl StdError for Error {
419        fn source(&self) -> Option<&(dyn StdError + 'static)> {
420            None
421        }
422    }
423
424    impl fmt::Display for Error {
425        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426            f.write_str("missing request extension")
427        }
428    }
429
430    impl fmt::Debug for Error {
431        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432            write!(f, "Error({:?})", self.to_string())
433        }
434    }
435}