async_fred_session/
lib.rs

//! # async-fred-session
//! Redis-backed session store for `async-session`, built on `fred`.
//! ```rust
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! use async_fred_session::{RedisSessionStore, fred::{pool::RedisPool, types::RedisConfig}};
//! use async_session::{Session, SessionStore};
//!
//! // pool creation
//! let config = RedisConfig::from_url("redis://127.0.0.1:6379").unwrap();
//! let rds_pool = RedisPool::new(config, None, None, 6).unwrap();
//! rds_pool.connect();
//! rds_pool.wait_for_connect().await.unwrap();
//!
//! // store and session
//! let store = RedisSessionStore::from_pool(rds_pool, Some("async-fred-session/".into()));
//! let mut session = Session::new();
//! session.insert("key", "value").unwrap();
//!
//! let cookie_value = store.store_session(session).await.unwrap().unwrap();
//! let session = store.load_session(cookie_value).await.unwrap().unwrap();
//! assert_eq!(&session.get::<String>("key").unwrap(), "value");
//! # }
//! ```
25
26#![forbid(unsafe_code, future_incompatible)]
27
28pub use fred;
29
30use async_session::{async_trait, serde_json, Result, Session, SessionStore};
31use fred::{
32    pool::RedisPool,
33    prelude::*,
34    types::{RedisKey, ScanType, Scanner},
35};
36use futures::stream::StreamExt;
37
38#[derive(Clone)]
39pub struct RedisSessionStore {
40    pool: RedisPool,
41    prefix: Option<String>,
42}
43
44impl std::fmt::Debug for RedisSessionStore {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{:?}", self.prefix)
47    }
48}
49
50impl RedisSessionStore {
51    /// creates a redis store from an existing [`fred::pool::RedisPool`]
52    /// ```rust
53    /// # #[tokio::main(flavor = "current_thread")]
54    /// # async fn main() {
55    /// use async_fred_session::RedisSessionStore;
56    /// use async_session::{Session, SessionStore};
57    /// use fred::{pool::RedisPool, prelude::*};
58    ///
59    /// let conf = RedisConfig::from_url("redis://127.0.0.1:6379").unwrap();
60    /// let pool = RedisPool::new(conf, None, None, 6).unwrap();
61    /// pool.connect();
62    /// pool.wait_for_connect().await.unwrap();
63    /// let store = RedisSessionStore::from_pool(pool, Some("async-fred-session/".into()));
64    /// # }
65    /// ```
66    pub fn from_pool(pool: RedisPool, prefix: Option<String>) -> Self {
67        Self { pool, prefix }
68    }
69
70    /// returns the number of sessions in this store
71    pub async fn count(&self) -> Result<usize> {
72        match self.prefix {
73            None => Ok(self.pool.dbsize().await?),
74            Some(_) => Ok(self.ids().await?.map_or(0, |v| v.len())),
75        }
76    }
77
78    async fn ids(&self) -> Result<Option<Vec<RedisKey>>> {
79        let mut result = Vec::new();
80        let mut scanner = self
81            .pool
82            .scan(self.prefix_key("*"), None, Some(ScanType::String));
83
84        while let Some(res) = scanner.next().await {
85            if let Some(keys) = res?.take_results() {
86                result.extend_from_slice(&keys);
87            }
88        }
89
90        Ok((!result.is_empty()).then_some(result))
91    }
92
93    fn prefix_key(&self, key: &str) -> String {
94        match &self.prefix {
95            None => key.to_string(),
96            Some(prefix) => format!("{prefix}{key}"),
97        }
98    }
99
100    #[cfg(test)]
101    async fn ttl_for_session(&self, session: &Session) -> Result<usize> {
102        Ok(self.pool.ttl(self.prefix_key(session.id())).await?)
103    }
104}
105
106#[async_trait]
107impl SessionStore for RedisSessionStore {
108    async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
109        let id = Session::id_from_cookie_value(&cookie_value)?;
110        Ok(self
111            .pool
112            .get::<Option<String>, String>(self.prefix_key(&id))
113            .await?
114            .map(|v| serde_json::from_str(&v))
115            .transpose()?)
116    }
117
118    async fn store_session(&self, session: Session) -> Result<Option<String>> {
119        let id = self.prefix_key(session.id());
120        let string = serde_json::to_string(&session)?;
121        let expiration = session
122            .expires_in()
123            .map(|d| Expiration::EX(d.as_secs() as i64));
124
125        self.pool.set(id, string, expiration, None, false).await?;
126
127        Ok(session.into_cookie_value())
128    }
129
130    async fn destroy_session(&self, session: Session) -> Result {
131        Ok(self.pool.del(self.prefix_key(session.id())).await?)
132    }
133
134    async fn clear_store(&self) -> Result {
135        match self.prefix {
136            None => Ok(self.pool.flushall(false).await?),
137            Some(_) => match self.ids().await? {
138                None => Ok(()),
139                Some(ids) => Ok(self.pool.del(ids).await?),
140            },
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Connects to the local Redis instance and returns an empty store whose keys
    /// live under a prefix unique to `test_name`.
    ///
    /// The test harness runs tests in parallel by default. Every test below clears
    /// its store and asserts exact session counts, so a shared prefix would let one
    /// test's sessions (or its `clear_store`) leak into another test's assertions.
    async fn create_session_store(test_name: &str) -> RedisSessionStore {
        let conf = RedisConfig::from_url("redis://127.0.0.1:6379").unwrap();
        let pool = RedisPool::new(conf, None, None, 6).unwrap();

        pool.connect();
        pool.wait_for_connect().await.unwrap();

        let store = RedisSessionStore::from_pool(
            pool,
            Some(format!("async-session-test/{test_name}/")),
        );
        store.clear_store().await.unwrap();
        store
    }

    #[tokio::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = create_session_store("creating_a_new_session_with_no_expiry").await;
        let mut session = Session::new();
        session.insert("key", "Hello")?;

        let cloned = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();
        let loaded_session = store.load_session(cookie_value).await?.unwrap();

        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("Hello", &loaded_session.get::<String>("key").unwrap());
        assert!(!loaded_session.is_expired());
        assert!(loaded_session.validate().is_some());

        Ok(())
    }

    #[tokio::test]
    async fn updating_a_session() -> Result {
        let store = create_session_store("updating_a_session").await;
        let mut session = Session::new();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();
        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();

        session.insert("key", "other value")?;
        assert_eq!(None, store.store_session(session).await?);
        let session = store.load_session(cookie_value.clone()).await?.unwrap();

        assert_eq!(&session.get::<String>("key").unwrap(), "other value");
        assert_eq!(1, store.count().await.unwrap());

        Ok(())
    }

    #[tokio::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = create_session_store("updating_a_session_extending_expiry").await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(5));
        let original_expires = *session.expiry().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        let ttl = store.ttl_for_session(&session).await?;
        assert!(ttl > 3 && ttl < 5);

        assert_eq!(session.expiry().unwrap(), &original_expires);
        session.expire_in(Duration::from_secs(10));
        let new_expires = *session.expiry().unwrap();
        store.store_session(session).await?;

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        let ttl = store.ttl_for_session(&session).await?;
        assert!(ttl > 8 && ttl < 10);
        assert_eq!(session.expiry().unwrap(), &new_expires);

        assert_eq!(1, store.count().await.unwrap());
        sleep(Duration::from_secs(10)).await;
        assert_eq!(0, store.count().await.unwrap());

        Ok(())
    }

    #[tokio::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = create_session_store("creating_a_new_session_with_expiry").await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(3));
        session.insert("key", "value")?;
        let cloned = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        assert!(store.ttl_for_session(&cloned).await? > 1);

        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());

        sleep(Duration::from_secs(2)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    #[tokio::test]
    async fn destroying_a_single_session() -> Result {
        let store = create_session_store("destroying_a_single_session").await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);
        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session.clone()).await?;
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        // attempting to destroy the session again is not an error
        assert!(store.destroy_session(session).await.is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn clearing_the_whole_store() -> Result {
        let store = create_session_store("clearing_the_whole_store").await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await?;
        assert_eq!(0, store.count().await?);

        Ok(())
    }
}