1use std::sync::{Arc, Mutex};
2
3use async_trait::async_trait;
4use shield::{SessionData, SessionError, SessionStorage};
5use tower_sessions::Session;
6
7#[derive(Clone, Debug)]
8pub struct TowerSessionStorage {
9 session: Session,
10 session_key: &'static str,
11 session_data: Arc<Mutex<SessionData>>,
12}
13
14impl TowerSessionStorage {
15 pub async fn load(session: Session, session_key: &'static str) -> Result<Self, SessionError> {
16 let data = Self::load_data(&session, session_key).await?;
17
18 Ok(Self {
19 session,
20 session_key,
21 session_data: Arc::new(Mutex::new(data)),
22 })
23 }
24
25 async fn load_data(
26 session: &Session,
27 session_key: &'static str,
28 ) -> Result<SessionData, SessionError> {
29 session
30 .get::<SessionData>(session_key)
31 .await
32 .map_err(|err| SessionError::Engine(err.to_string()))
33 .map(|session_data| session_data.unwrap_or_default())
34 }
35}
36
37#[async_trait]
38impl SessionStorage for TowerSessionStorage {
39 fn data(&self) -> Arc<Mutex<SessionData>> {
40 self.session_data.clone()
41 }
42
43 async fn update(&self) -> Result<(), SessionError> {
44 let data = self
45 .session_data
46 .lock()
47 .map_err(|err| SessionError::Lock(err.to_string()))?
48 .clone();
49
50 self.session
51 .insert(self.session_key, data)
52 .await
53 .map_err(|err| SessionError::Engine(err.to_string()))
54 }
55
56 async fn renew(&self) -> Result<(), SessionError> {
57 self.session
58 .cycle_id()
59 .await
60 .map_err(|err| SessionError::Engine(err.to_string()))
61 }
62
63 async fn purge(&self) -> Result<(), SessionError> {
64 self.session
65 .flush()
66 .await
67 .map_err(|err| SessionError::Engine(err.to_string()))?;
68
69 let data = Self::load_data(&self.session, self.session_key).await?;
70
71 {
72 let mut session_data = self
73 .session_data
74 .lock()
75 .map_err(|err| SessionError::Lock(err.to_string()))?;
76 *session_data = data;
77 }
78
79 Ok(())
80 }
81}