1use std::sync::Arc;
2
3use actix_web::cookie::time::Duration;
4use anyhow::Error;
5use redis::{aio::ConnectionManager, AsyncCommands, Client, Cmd, FromRedisValue, Value};
6
7use super::SessionKey;
8use crate::storage::{
9 interface::{LoadError, SaveError, SessionState, UpdateError},
10 utils::generate_session_key,
11 SessionStore,
12};
13
14#[derive(Clone)]
78pub struct RedisSessionStore {
79 configuration: CacheConfiguration,
80 client: RedisSessionConn,
81}
82
83#[derive(Clone)]
84enum RedisSessionConn {
85 Single(ConnectionManager),
87
88 #[cfg(feature = "redis-pool")]
90 Pool(deadpool_redis::Pool),
91}
92
/// Store-level configuration shared by all clones of a [`RedisSessionStore`].
#[derive(Clone)]
struct CacheConfiguration {
    /// Maps a session key to the Redis key under which its state is stored.
    ///
    /// Held in an `Arc` so that cloning the configuration does not clone the closure.
    cache_keygen: Arc<dyn Fn(&str) -> String + Sync + Send>,
}
97
98impl Default for CacheConfiguration {
99 fn default() -> Self {
100 Self {
101 cache_keygen: Arc::new(str::to_owned),
102 }
103 }
104}
105
106impl RedisSessionStore {
107 pub fn builder(connection_string: impl Into<String>) -> RedisSessionStoreBuilder {
112 RedisSessionStoreBuilder {
113 configuration: CacheConfiguration::default(),
114 conn_builder: RedisSessionConnBuilder::Single(connection_string.into()),
115 }
116 }
117
118 #[cfg(feature = "redis-pool")]
123 pub fn builder_pooled(pool: impl Into<deadpool_redis::Pool>) -> RedisSessionStoreBuilder {
124 RedisSessionStoreBuilder {
125 configuration: CacheConfiguration::default(),
126 conn_builder: RedisSessionConnBuilder::Pool(pool.into()),
127 }
128 }
129
130 pub async fn new(connection_string: impl Into<String>) -> Result<RedisSessionStore, Error> {
135 Self::builder(connection_string).build().await
136 }
137
138 #[cfg(feature = "redis-pool")]
143 pub async fn new_pooled(
144 pool: impl Into<deadpool_redis::Pool>,
145 ) -> anyhow::Result<RedisSessionStore> {
146 Self::builder_pooled(pool).build().await
147 }
148}
149
150#[must_use]
153pub struct RedisSessionStoreBuilder {
154 configuration: CacheConfiguration,
155 conn_builder: RedisSessionConnBuilder,
156}
157
/// Not-yet-established connection settings, turned into a [`RedisSessionConn`] at build time.
enum RedisSessionConnBuilder {
    /// A Redis connection string, opened with `redis::Client::open`.
    Single(String),

    /// An already-constructed `deadpool_redis` pool, used unchanged.
    #[cfg(feature = "redis-pool")]
    Pool(deadpool_redis::Pool),
}
166
167impl RedisSessionConnBuilder {
168 async fn into_client(self) -> anyhow::Result<RedisSessionConn> {
169 Ok(match self {
170 RedisSessionConnBuilder::Single(conn_string) => {
171 RedisSessionConn::Single(ConnectionManager::new(Client::open(conn_string)?).await?)
172 }
173
174 #[cfg(feature = "redis-pool")]
175 RedisSessionConnBuilder::Pool(pool) => RedisSessionConn::Pool(pool),
176 })
177 }
178}
179
180impl RedisSessionStoreBuilder {
181 pub fn cache_keygen<F>(mut self, keygen: F) -> Self
183 where
184 F: Fn(&str) -> String + 'static + Send + Sync,
185 {
186 self.configuration.cache_keygen = Arc::new(keygen);
187 self
188 }
189
190 pub async fn build(self) -> anyhow::Result<RedisSessionStore> {
192 let client = self.conn_builder.into_client().await?;
193
194 Ok(RedisSessionStore {
195 configuration: self.configuration,
196 client,
197 })
198 }
199}
200
201impl SessionStore for RedisSessionStore {
202 async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
203 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
204
205 let value: Option<String> = self
206 .execute_command(redis::cmd("GET").arg(&[&cache_key]))
207 .await
208 .map_err(LoadError::Other)?;
209
210 match value {
211 None => Ok(None),
212 Some(value) => Ok(serde_json::from_str(&value)
213 .map_err(Into::into)
214 .map_err(LoadError::Deserialization)?),
215 }
216 }
217
218 async fn save(
219 &self,
220 session_state: SessionState,
221 ttl: &Duration,
222 ) -> Result<SessionKey, SaveError> {
223 let body = serde_json::to_string(&session_state)
224 .map_err(Into::into)
225 .map_err(SaveError::Serialization)?;
226 let session_key = generate_session_key();
227 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
228
229 self.execute_command::<()>(
230 redis::cmd("SET")
231 .arg(&[
232 &cache_key, &body, "NX", "EX", ])
237 .arg(
238 ttl.whole_seconds(), ),
240 )
241 .await
242 .map_err(SaveError::Other)?;
243
244 Ok(session_key)
245 }
246
247 async fn update(
248 &self,
249 session_key: SessionKey,
250 session_state: SessionState,
251 ttl: &Duration,
252 ) -> Result<SessionKey, UpdateError> {
253 let body = serde_json::to_string(&session_state)
254 .map_err(Into::into)
255 .map_err(UpdateError::Serialization)?;
256
257 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
258
259 let v: Value = self
260 .execute_command(redis::cmd("SET").arg(&[
261 &cache_key,
262 &body,
263 "XX", "EX", &format!("{}", ttl.whole_seconds()),
266 ]))
267 .await
268 .map_err(UpdateError::Other)?;
269
270 match v {
271 Value::Nil => {
272 self.save(session_state, ttl)
277 .await
278 .map_err(|err| match err {
279 SaveError::Serialization(err) => UpdateError::Serialization(err),
280 SaveError::Other(err) => UpdateError::Other(err),
281 })
282 }
283 Value::Int(_) | Value::Okay | Value::SimpleString(_) => Ok(session_key),
284 val => Err(UpdateError::Other(anyhow::anyhow!(
285 "Failed to update session state. {:?}",
286 val
287 ))),
288 }
289 }
290
291 async fn update_ttl(&self, session_key: &SessionKey, ttl: &Duration) -> anyhow::Result<()> {
292 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
293
294 match self.client {
295 RedisSessionConn::Single(ref conn) => {
296 conn.clone()
297 .expire::<_, ()>(&cache_key, ttl.whole_seconds())
298 .await?;
299 }
300
301 #[cfg(feature = "redis-pool")]
302 RedisSessionConn::Pool(ref pool) => {
303 pool.get()
304 .await?
305 .expire::<_, ()>(&cache_key, ttl.whole_seconds())
306 .await?;
307 }
308 }
309
310 Ok(())
311 }
312
313 async fn delete(&self, session_key: &SessionKey) -> Result<(), Error> {
314 let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
315
316 self.execute_command::<()>(redis::cmd("DEL").arg(&[&cache_key]))
317 .await
318 .map_err(UpdateError::Other)?;
319
320 Ok(())
321 }
322}
323
324impl RedisSessionStore {
325 #[allow(clippy::needless_pass_by_ref_mut)]
339 async fn execute_command<T: FromRedisValue>(&self, cmd: &mut Cmd) -> anyhow::Result<T> {
340 let mut can_retry = true;
341
342 match self.client {
343 RedisSessionConn::Single(ref conn) => {
344 let mut conn = conn.clone();
345
346 loop {
347 match cmd.query_async(&mut conn).await {
348 Ok(value) => return Ok(value),
349 Err(err) => {
350 if can_retry && err.is_connection_dropped() {
351 tracing::debug!(
352 "Connection dropped while trying to talk to Redis. Retrying."
353 );
354
355 can_retry = false;
357
358 continue;
359 } else {
360 return Err(err.into());
361 }
362 }
363 }
364 }
365 }
366
367 #[cfg(feature = "redis-pool")]
368 RedisSessionConn::Pool(ref pool) => {
369 let mut conn = pool.get().await?;
370
371 loop {
372 match cmd.query_async(&mut conn).await {
373 Ok(value) => return Ok(value),
374 Err(err) => {
375 if can_retry && err.is_connection_dropped() {
376 tracing::debug!(
377 "Connection dropped while trying to talk to Redis. Retrying."
378 );
379
380 can_retry = false;
382
383 continue;
384 } else {
385 return Err(err.into());
386 }
387 }
388 }
389 }
390 }
391 }
392 }
393}
394
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::cookie::time;
    #[cfg(not(feature = "redis-session"))]
    use deadpool_redis::{Config, Runtime};

    use super::*;
    use crate::test_helpers::acceptance_test_suite;

    /// Builds a store against a local Redis at `127.0.0.1:6379`.
    ///
    /// With the `redis-session` feature a single managed connection is used; otherwise a
    /// `deadpool_redis` pool is created and passed to `RedisSessionStore::new_pooled`.
    async fn redis_store() -> RedisSessionStore {
        #[cfg(feature = "redis-session")]
        {
            RedisSessionStore::new("redis://127.0.0.1:6379")
                .await
                .unwrap()
        }

        #[cfg(not(feature = "redis-session"))]
        {
            let redis_pool = Config::from_url("redis://127.0.0.1:6379")
                .create_pool(Some(Runtime::Tokio1))
                .unwrap();
            // `new` takes a connection string, not a pool; pools go through `new_pooled`,
            // whose future must be awaited and its result unwrapped like the branch above.
            RedisSessionStore::new_pooled(redis_pool).await.unwrap()
        }
    }

    /// Runs the shared store acceptance suite against Redis.
    #[actix_web::test]
    async fn test_session_workflow() {
        let redis_store = redis_store().await;
        acceptance_test_suite(move || redis_store.clone(), true).await;
    }

    /// A key that was never saved loads as `None`.
    #[actix_web::test]
    async fn loading_a_missing_session_returns_none() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        assert!(store.load(&session_key).await.unwrap().is_none());
    }

    /// Non-JSON data stored for a session key surfaces as `LoadError::Deserialization`.
    #[actix_web::test]
    async fn loading_an_invalid_session_state_returns_deserialization_error() {
        let store = redis_store().await;
        let session_key = generate_session_key();

        // Writes directly under the raw session key, which matches the cache key only because
        // the store uses the default (identity) key generator.
        match store.client {
            RedisSessionConn::Single(ref conn) => conn
                .clone()
                .set::<_, _, ()>(session_key.as_ref(), "random-thing-which-is-not-json")
                .await
                .unwrap(),

            #[cfg(feature = "redis-pool")]
            RedisSessionConn::Pool(ref pool) => {
                pool.get()
                    .await
                    .unwrap()
                    .set::<_, _, ()>(session_key.as_ref(), "random-thing-which-is-not-json")
                    .await
                    .unwrap();
            }
        }

        assert!(matches!(
            store.load(&session_key).await.unwrap_err(),
            LoadError::Deserialization(_),
        ));
    }

    /// Updating a key with no stored state falls back to `save` and yields a new key.
    #[actix_web::test]
    async fn updating_of_an_expired_state_is_handled_gracefully() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        let initial_session_key = session_key.as_ref().to_owned();
        let updated_session_key = store
            .update(session_key, HashMap::new(), &time::Duration::seconds(1))
            .await
            .unwrap();
        assert_ne!(initial_session_key, updated_session_key.as_ref());
    }
}