1use std::sync::Arc;
2
3use actix_web::cookie::time::Duration;
4use anyhow::Error;
5use redis::{aio::ConnectionManager, AsyncCommands, Client, Cmd, FromRedisValue, Value};
6
7use super::SessionKey;
8use crate::storage::{
9 interface::{LoadError, SaveError, SessionState, UpdateError},
10 utils::generate_session_key,
11 SessionStore,
12};
13
14#[derive(Clone)]
78pub struct RedisSessionStore {
79 configuration: CacheConfiguration,
80 client: RedisSessionConn,
81}
82
83#[derive(Clone)]
84enum RedisSessionConn {
85 Single(ConnectionManager),
87
88 #[cfg(feature = "redis-pool")]
90 Pool(deadpool_redis::Pool),
91}
92
/// Settings controlling how session keys are turned into Redis keys.
#[derive(Clone)]
struct CacheConfiguration {
    /// Maps a session key to the Redis key under which its state is stored.
    cache_keygen: Arc<dyn Fn(&str) -> String + Send + Sync>,
}

impl Default for CacheConfiguration {
    /// Uses the session key verbatim as the Redis key.
    fn default() -> Self {
        let identity = |session_key: &str| session_key.to_owned();

        Self {
            cache_keygen: Arc::new(identity),
        }
    }
}
105
106impl RedisSessionStore {
107 pub fn builder(connection_string: impl Into<String>) -> RedisSessionStoreBuilder {
112 RedisSessionStoreBuilder {
113 configuration: CacheConfiguration::default(),
114 conn_builder: RedisSessionConnBuilder::Single(connection_string.into()),
115 }
116 }
117
118 #[cfg(feature = "redis-pool")]
123 pub fn builder_pooled(pool: impl Into<deadpool_redis::Pool>) -> RedisSessionStoreBuilder {
124 RedisSessionStoreBuilder {
125 configuration: CacheConfiguration::default(),
126 conn_builder: RedisSessionConnBuilder::Pool(pool.into()),
127 }
128 }
129
130 pub async fn new(connection_string: impl Into<String>) -> Result<RedisSessionStore, Error> {
135 Self::builder(connection_string).build().await
136 }
137
138 #[cfg(feature = "redis-pool")]
143 pub async fn new_pooled(
144 pool: impl Into<deadpool_redis::Pool>,
145 ) -> anyhow::Result<RedisSessionStore> {
146 Self::builder_pooled(pool).build().await
147 }
148}
149
150#[must_use]
153pub struct RedisSessionStoreBuilder {
154 configuration: CacheConfiguration,
155 conn_builder: RedisSessionConnBuilder,
156}
157
/// Connection parameters captured by the builder.
///
/// `into_client` turns this into a live `RedisSessionConn`.
enum RedisSessionConnBuilder {
    /// A connection string for a single managed connection.
    Single(String),

    /// A ready-made `deadpool_redis` pool.
    #[cfg(feature = "redis-pool")]
    Pool(deadpool_redis::Pool),
}
166
167impl RedisSessionConnBuilder {
168 async fn into_client(self) -> anyhow::Result<RedisSessionConn> {
169 Ok(match self {
170 RedisSessionConnBuilder::Single(conn_string) => {
171 RedisSessionConn::Single(ConnectionManager::new(Client::open(conn_string)?).await?)
172 }
173
174 #[cfg(feature = "redis-pool")]
175 RedisSessionConnBuilder::Pool(pool) => RedisSessionConn::Pool(pool),
176 })
177 }
178}
179
180impl RedisSessionStoreBuilder {
181 pub fn cache_keygen<F>(mut self, keygen: F) -> Self
183 where
184 F: Fn(&str) -> String + 'static + Send + Sync,
185 {
186 self.configuration.cache_keygen = Arc::new(keygen);
187 self
188 }
189
190 pub async fn build(self) -> anyhow::Result<RedisSessionStore> {
192 let client = self.conn_builder.into_client().await?;
193
194 Ok(RedisSessionStore {
195 configuration: self.configuration,
196 client,
197 })
198 }
199}
200
201impl SessionStore for RedisSessionStore {
202 async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
203 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
204
205 let value: Option<String> = self
206 .execute_command(redis::cmd("GET").arg(&[&cache_key]))
207 .await
208 .map_err(Into::into)
209 .map_err(LoadError::Other)?;
210
211 match value {
212 None => Ok(None),
213 Some(value) => Ok(serde_json::from_str(&value)
214 .map_err(Into::into)
215 .map_err(LoadError::Deserialization)?),
216 }
217 }
218
219 async fn save(
220 &self,
221 session_state: SessionState,
222 ttl: &Duration,
223 ) -> Result<SessionKey, SaveError> {
224 let body = serde_json::to_string(&session_state)
225 .map_err(Into::into)
226 .map_err(SaveError::Serialization)?;
227 let session_key = generate_session_key();
228 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
229
230 self.execute_command::<()>(
231 redis::cmd("SET")
232 .arg(&[
233 &cache_key, &body, "NX", "EX", ])
238 .arg(
239 ttl.whole_seconds(), ),
241 )
242 .await
243 .map_err(Into::into)
244 .map_err(SaveError::Other)?;
245
246 Ok(session_key)
247 }
248
249 async fn update(
250 &self,
251 session_key: SessionKey,
252 session_state: SessionState,
253 ttl: &Duration,
254 ) -> Result<SessionKey, UpdateError> {
255 let body = serde_json::to_string(&session_state)
256 .map_err(Into::into)
257 .map_err(UpdateError::Serialization)?;
258
259 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
260
261 let v: Value = self
262 .execute_command(redis::cmd("SET").arg(&[
263 &cache_key,
264 &body,
265 "XX", "EX", &format!("{}", ttl.whole_seconds()),
268 ]))
269 .await
270 .map_err(Into::into)
271 .map_err(UpdateError::Other)?;
272
273 match v {
274 Value::Nil => {
275 self.save(session_state, ttl)
280 .await
281 .map_err(|err| match err {
282 SaveError::Serialization(err) => UpdateError::Serialization(err),
283 SaveError::Other(err) => UpdateError::Other(err),
284 })
285 }
286 Value::Int(_) | Value::Okay | Value::SimpleString(_) => Ok(session_key),
287 val => Err(UpdateError::Other(anyhow::anyhow!(
288 "Failed to update session state. {:?}",
289 val
290 ))),
291 }
292 }
293
294 async fn update_ttl(&self, session_key: &SessionKey, ttl: &Duration) -> anyhow::Result<()> {
295 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
296
297 match self.client {
298 RedisSessionConn::Single(ref conn) => {
299 conn.clone()
300 .expire::<_, ()>(&cache_key, ttl.whole_seconds())
301 .await?;
302 }
303
304 #[cfg(feature = "redis-pool")]
305 RedisSessionConn::Pool(ref pool) => {
306 pool.get()
307 .await?
308 .expire::<_, ()>(&cache_key, ttl.whole_seconds())
309 .await?;
310 }
311 }
312
313 Ok(())
314 }
315
316 async fn delete(&self, session_key: &SessionKey) -> Result<(), Error> {
317 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
318
319 self.execute_command::<()>(redis::cmd("DEL").arg(&[&cache_key]))
320 .await
321 .map_err(Into::into)
322 .map_err(UpdateError::Other)?;
323
324 Ok(())
325 }
326}
327
328impl RedisSessionStore {
329 #[allow(clippy::needless_pass_by_ref_mut)]
343 async fn execute_command<T: FromRedisValue>(&self, cmd: &mut Cmd) -> anyhow::Result<T> {
344 let mut can_retry = true;
345
346 match self.client {
347 RedisSessionConn::Single(ref conn) => {
348 let mut conn = conn.clone();
349
350 loop {
351 match cmd.query_async(&mut conn).await {
352 Ok(value) => return Ok(value),
353 Err(err) => {
354 if can_retry && err.is_connection_dropped() {
355 tracing::debug!(
356 "Connection dropped while trying to talk to Redis. Retrying."
357 );
358
359 can_retry = false;
361
362 continue;
363 } else {
364 return Err(err.into());
365 }
366 }
367 }
368 }
369 }
370
371 #[cfg(feature = "redis-pool")]
372 RedisSessionConn::Pool(ref pool) => {
373 let mut conn = pool.get().await?;
374
375 loop {
376 match cmd.query_async(&mut conn).await {
377 Ok(value) => return Ok(value),
378 Err(err) => {
379 if can_retry && err.is_connection_dropped() {
380 tracing::debug!(
381 "Connection dropped while trying to talk to Redis. Retrying."
382 );
383
384 can_retry = false;
386
387 continue;
388 } else {
389 return Err(err.into());
390 }
391 }
392 }
393 }
394 }
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::cookie::time;
    #[cfg(feature = "redis-pool")]
    use deadpool_redis::{Config, Runtime};

    use super::*;
    use crate::test_helpers::acceptance_test_suite;

    /// Builds the store under test against a Redis server at `127.0.0.1:6379`.
    ///
    /// With the `redis-pool` feature the store is backed by a `deadpool_redis` pool; otherwise
    /// it uses a single managed connection.
    async fn redis_store() -> RedisSessionStore {
        #[cfg(not(feature = "redis-pool"))]
        {
            RedisSessionStore::new("redis://127.0.0.1:6379")
                .await
                .unwrap()
        }

        #[cfg(feature = "redis-pool")]
        {
            let redis_pool = Config::from_url("redis://127.0.0.1:6379")
                .create_pool(Some(Runtime::Tokio1))
                .unwrap();
            RedisSessionStore::new_pooled(redis_pool).await.unwrap()
        }
    }

    #[actix_web::test]
    async fn test_session_workflow() {
        let redis_store = redis_store().await;
        acceptance_test_suite(move || redis_store.clone(), true).await;
    }

    #[actix_web::test]
    async fn loading_a_missing_session_returns_none() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        assert!(store.load(&session_key).await.unwrap().is_none());
    }

    #[actix_web::test]
    async fn loading_an_invalid_session_state_returns_deserialization_error() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        // Write through the configured key generator so the test does not rely on the
        // default keygen being the identity.
        let cache_key = (store.configuration.cache_keygen)(session_key.as_ref());

        match store.client {
            RedisSessionConn::Single(ref conn) => conn
                .clone()
                .set::<_, _, ()>(&cache_key, "random-thing-which-is-not-json")
                .await
                .unwrap(),

            #[cfg(feature = "redis-pool")]
            RedisSessionConn::Pool(ref pool) => {
                pool.get()
                    .await
                    .unwrap()
                    .set::<_, _, ()>(&cache_key, "random-thing-which-is-not-json")
                    .await
                    .unwrap();
            }
        }

        assert!(matches!(
            store.load(&session_key).await.unwrap_err(),
            LoadError::Deserialization(_),
        ));
    }

    #[actix_web::test]
    async fn updating_of_an_expired_state_is_handled_gracefully() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        let initial_session_key = session_key.as_ref().to_owned();
        let updated_session_key = store
            .update(session_key, HashMap::new(), &time::Duration::seconds(1))
            .await
            .unwrap();
        assert_ne!(initial_session_key, updated_session_key.as_ref());
    }
}