// shield_tower/session.rs

1use std::sync::{Arc, Mutex};
2
3use async_trait::async_trait;
4use shield::{SessionData, SessionError, SessionStorage};
5use tower_sessions::Session;
6
7#[derive(Clone, Debug)]
8pub struct TowerSessionStorage {
9    session: Session,
10    session_key: &'static str,
11    session_data: Arc<Mutex<SessionData>>,
12}
13
14impl TowerSessionStorage {
15    pub async fn load(session: Session, session_key: &'static str) -> Result<Self, SessionError> {
16        let data = Self::load_data(&session, session_key).await?;
17
18        Ok(Self {
19            session,
20            session_key,
21            session_data: Arc::new(Mutex::new(data)),
22        })
23    }
24
25    async fn load_data(
26        session: &Session,
27        session_key: &'static str,
28    ) -> Result<SessionData, SessionError> {
29        session
30            .get::<SessionData>(session_key)
31            .await
32            .map_err(|err| SessionError::Engine(err.to_string()))
33            .map(|session_data| session_data.unwrap_or_default())
34    }
35}
36
37#[async_trait]
38impl SessionStorage for TowerSessionStorage {
39    fn data(&self) -> Arc<Mutex<SessionData>> {
40        self.session_data.clone()
41    }
42
43    async fn update(&self) -> Result<(), SessionError> {
44        let data = self
45            .session_data
46            .lock()
47            .map_err(|err| SessionError::Lock(err.to_string()))?
48            .clone();
49
50        self.session
51            .insert(self.session_key, data)
52            .await
53            .map_err(|err| SessionError::Engine(err.to_string()))
54    }
55
56    async fn renew(&self) -> Result<(), SessionError> {
57        self.session
58            .cycle_id()
59            .await
60            .map_err(|err| SessionError::Engine(err.to_string()))
61    }
62
63    async fn purge(&self) -> Result<(), SessionError> {
64        self.session
65            .flush()
66            .await
67            .map_err(|err| SessionError::Engine(err.to_string()))?;
68
69        let data = Self::load_data(&self.session, self.session_key).await?;
70
71        {
72            let mut session_data = self
73                .session_data
74                .lock()
75                .map_err(|err| SessionError::Lock(err.to_string()))?;
76            *session_data = data;
77        }
78
79        Ok(())
80    }
81}