viz_core/middleware/session/config.rs

1use std::{
2    fmt,
3    sync::{atomic::Ordering, Arc},
4    time::Duration,
5};
6
7use crate::{
8    middleware::helper::{CookieOptions, Cookieable},
9    types::{Cookie, Session},
10    Handler, IntoResponse, Request, RequestExt, Response, Result, Transform,
11};
12
13use super::{Storage, Store, PURGED, RENEWED, UNCHANGED};
14
/// A configuration for [`SessionMiddleware`].
///
/// Bundles the session [`Store`] with the [`CookieOptions`] used for the
/// session cookie. Both live behind a single [`Arc`], so clones of the
/// configuration (and every middleware produced from it) share the same store
/// and cookie settings.
///
/// When used as middleware, `S` is the [`Storage`] backend, `G` generates new
/// session ids and `V` validates ids read from incoming cookies.
pub struct Config<S, G, V>(Arc<(Store<S, G, V>, CookieOptions)>);
17
18impl<S, G, V> Config<S, G, V> {
19    /// Creates a new configuration with the [`Store`] and [`CookieOptions`].
20    #[must_use]
21    pub fn new(store: Store<S, G, V>, cookie: CookieOptions) -> Self {
22        Self(Arc::new((store, cookie)))
23    }
24
25    /// Gets the store.
26    #[must_use]
27    pub fn store(&self) -> &Store<S, G, V> {
28        &self.0 .0
29    }
30
31    /// Gets the TTL.
32    #[must_use]
33    pub fn ttl(&self) -> Option<Duration> {
34        self.options().max_age
35    }
36}
37
38impl<S, G, V> Clone for Config<S, G, V> {
39    fn clone(&self) -> Self {
40        Self(self.0.clone())
41    }
42}
43
44impl<S, G, V> Cookieable for Config<S, G, V> {
45    fn options(&self) -> &CookieOptions {
46        &self.0 .1
47    }
48}
49
50impl<S, G, V> fmt::Debug for Config<S, G, V> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("SessionConfig").finish()
53    }
54}
55
56impl<H, S, G, V> Transform<H> for Config<S, G, V> {
57    type Output = SessionMiddleware<H, S, G, V>;
58
59    fn transform(&self, h: H) -> Self::Output {
60        SessionMiddleware {
61            h,
62            config: self.clone(),
63        }
64    }
65}
66
67/// Session middleware.
68#[derive(Debug)]
69pub struct SessionMiddleware<H, S, G, V> {
70    h: H,
71    config: Config<S, G, V>,
72}
73
74impl<H, S, G, V> Clone for SessionMiddleware<H, S, G, V>
75where
76    H: Clone,
77{
78    fn clone(&self) -> Self {
79        Self {
80            h: self.h.clone(),
81            config: self.config.clone(),
82        }
83    }
84}
85
86#[crate::async_trait]
87impl<H, O, S, G, V> Handler<Request> for SessionMiddleware<H, S, G, V>
88where
89    H: Handler<Request, Output = Result<O>>,
90    O: IntoResponse,
91    S: Storage + 'static,
92    G: Fn() -> String + Send + Sync + 'static,
93    V: Fn(&str) -> bool + Send + Sync + 'static,
94{
95    type Output = Result<Response>;
96
97    async fn call(&self, mut req: Request) -> Self::Output {
98        let Self { h, config } = self;
99
100        let cookies = req.cookies()?;
101        let cookie = config.get_cookie(&cookies);
102
103        let mut session_id = cookie.as_ref().map(Cookie::value).map(ToString::to_string);
104        let data = match &session_id {
105            Some(sid) if (config.store().verify)(sid) => config.store().get(sid).await?,
106            _ => None,
107        };
108        if data.is_none() && session_id.is_some() {
109            session_id.take();
110        }
111        let session = Session::new(data.unwrap_or_default());
112        req.extensions_mut().insert(session.clone());
113
114        let resp = h.call(req).await.map(IntoResponse::into_response);
115
116        let status = session.status().load(Ordering::Acquire);
117
118        if status == UNCHANGED {
119            return resp;
120        }
121
122        if status == PURGED {
123            if let Some(sid) = &session_id {
124                config.store().remove(sid).await?;
125                config.remove_cookie(&cookies);
126            }
127
128            return resp;
129        }
130
131        if status == RENEWED {
132            if let Some(sid) = &session_id.take() {
133                config.store().remove(sid).await?;
134            }
135        }
136
137        let sid = session_id.unwrap_or_else(|| {
138            let sid = (config.store().generate)();
139            config.set_cookie(&cookies, &sid);
140            sid
141        });
142
143        config
144            .store()
145            .set(&sid, session.data()?, &config.ttl().unwrap_or_else(max_age))
146            .await?;
147
148        resp
149    }
150}
151
152const fn max_age() -> Duration {
153    Duration::from_secs(CookieOptions::MAX_AGE)
154}