1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use tower_sesh_core::{
6 store::{Error, SessionStoreImpl, Ttl},
7 Record, SessionKey,
8};
9
10pub use tower_sesh_core::SessionStore;
11
12type Result<T, E = Error> = std::result::Result<T, E>;
13
14#[derive(Clone)]
15pub struct MemoryStore<T>(Arc<Mutex<HashMap<SessionKey, Record<T>>>>);
16
17impl<T> Default for MemoryStore<T> {
18 fn default() -> Self {
19 let store = HashMap::new();
20 MemoryStore(Arc::new(Mutex::new(store)))
21 }
22}
23
24impl<T> MemoryStore<T> {
25 pub fn new() -> Self {
26 Self::default()
27 }
28}
29
30impl<T> SessionStore<T> for MemoryStore<T> where T: 'static + Send + Sync + Clone {}
31
32#[async_trait]
33impl<T> SessionStoreImpl<T> for MemoryStore<T>
34where
35 T: 'static + Send + Sync + Clone,
36{
37 async fn create(&self, data: &T, ttl: Ttl) -> Result<SessionKey> {
38 let session_key = SessionKey::generate();
39 self.update(&session_key, data, ttl).await?;
40 Ok(session_key)
41 }
42
43 async fn load(&self, session_key: &SessionKey) -> Result<Option<Record<T>>> {
44 let store_guard = self.0.lock();
45 Ok(store_guard.get(session_key).cloned())
46 }
47
48 async fn update(&self, session_key: &SessionKey, data: &T, ttl: Ttl) -> Result<()> {
49 let record = Record::new(data.clone(), ttl);
50 self.0.lock().insert(session_key.clone(), record);
51 Ok(())
52 }
53
54 async fn update_ttl(&self, session_key: &SessionKey, ttl: Ttl) -> Result<()> {
55 if let Some(record) = self.0.lock().get_mut(session_key) {
56 record.ttl = ttl;
57 }
58 Ok(())
59 }
60
61 async fn delete(&self, session_key: &SessionKey) -> Result<()> {
62 self.0.lock().remove(session_key);
63 Ok(())
64 }
65}
66
/// A two-tier session store that puts a `cache` store in front of a backing
/// `store`.
///
/// Reads are served from `cache` when possible and fall back to `store`;
/// writes are sent to both. Trait bounds live on the `impl` blocks rather than
/// on the type itself, so merely naming `CachingStore<T, Cache, Store>` does
/// not require repeating them.
pub struct CachingStore<T, Cache, Store> {
    cache: Cache,
    store: Store,
    // `fn() -> T` ties the type to `T` without owning a `T`. This keeps
    // `CachingStore` covariant in `T` and leaves its auto traits
    // (`Send`/`Sync`) determined only by `Cache` and `Store`.
    _marker: PhantomData<fn() -> T>,
}
72
73impl<T, Cache: SessionStore<T>, Store: SessionStore<T>> CachingStore<T, Cache, Store> {
74 pub fn from_cache_and_store(cache: Cache, store: Store) -> Self {
75 Self {
76 cache,
77 store,
78 _marker: PhantomData,
79 }
80 }
81}
82
83impl<T, Cache: SessionStore<T>, Store: SessionStore<T>> SessionStore<T>
84 for CachingStore<T, Cache, Store>
85where
86 T: 'static + Send + Sync,
87{
88}
89
90#[async_trait]
91impl<T, Cache: SessionStore<T>, Store: SessionStore<T>> SessionStoreImpl<T>
92 for CachingStore<T, Cache, Store>
93where
94 T: 'static + Send + Sync,
95{
96 async fn create(&self, data: &T, ttl: Ttl) -> Result<SessionKey> {
97 let session_key = self.store.create(data, ttl).await?;
98 self.cache.update(&session_key, data, ttl).await?;
99
100 Ok(session_key)
101 }
102
103 async fn load(&self, session_key: &SessionKey) -> Result<Option<Record<T>>> {
104 match self.cache.load(session_key).await {
105 Ok(Some(record)) => Ok(Some(record)),
106 Ok(None) | Err(_) => {
107 let record = self.store.load(session_key).await?;
108
109 if let Some(record) = record.as_ref() {
110 let _ = self
111 .cache
112 .update(session_key, &record.data, record.ttl)
113 .await;
114 }
115
116 Ok(record)
117 }
118 }
119 }
120
121 async fn update(&self, session_key: &SessionKey, data: &T, ttl: Ttl) -> Result<()> {
122 let store_fut = self.store.update(session_key, data, ttl);
123 let cache_fut = self.cache.update(session_key, data, ttl);
124
125 futures::try_join!(store_fut, cache_fut)?;
126
127 Ok(())
128 }
129
130 async fn update_ttl(&self, session_key: &SessionKey, ttl: Ttl) -> Result<()> {
131 let store_fut = self.store.update_ttl(session_key, ttl);
132 let cache_fut = self.cache.update_ttl(session_key, ttl);
133
134 futures::try_join!(store_fut, cache_fut)?;
135
136 Ok(())
137 }
138
139 async fn delete(&self, session_key: &SessionKey) -> Result<()> {
140 let store_fut = self.store.delete(session_key);
141 let cache_fut = self.cache.delete(session_key);
142
143 futures::try_join!(store_fut, cache_fut)?;
144
145 Ok(())
146 }
147}