actix_session/storage/redis_rs.rs

1use std::sync::Arc;
2
3use actix_web::cookie::time::Duration;
4use anyhow::Error;
5use redis::{aio::ConnectionManager, AsyncCommands, Client, Cmd, FromRedisValue, Value};
6
7use super::SessionKey;
8use crate::storage::{
9    interface::{LoadError, SaveError, SessionState, UpdateError},
10    utils::generate_session_key,
11    SessionStore,
12};
13
14/// Use Redis as session storage backend.
15///
16/// ```no_run
17/// use actix_web::{web, App, HttpServer, HttpResponse, Error};
18/// use actix_session::{SessionMiddleware, storage::RedisSessionStore};
19/// use actix_web::cookie::Key;
20///
21/// // The secret key would usually be read from a configuration file/environment variables.
22/// fn get_secret_key() -> Key {
23///     # todo!()
24///     // [...]
25/// }
26///
27/// #[actix_web::main]
28/// async fn main() -> std::io::Result<()> {
29///     let secret_key = get_secret_key();
30///     let redis_connection_string = "redis://127.0.0.1:6379";
31///     let store = RedisSessionStore::new(redis_connection_string).await.unwrap();
32///
33///     HttpServer::new(move ||
34///             App::new()
35///             .wrap(SessionMiddleware::new(
36///                 store.clone(),
37///                 secret_key.clone()
38///             ))
39///             .default_service(web::to(|| HttpResponse::Ok())))
40///         .bind(("127.0.0.1", 8080))?
41///         .run()
42///         .await
43/// }
44/// ```
45///
46/// # TLS support
47/// Add the `redis-rs-tls-session` or `redis-rs-tls-session-rustls` feature flag to enable TLS support. You can then establish a TLS
48/// connection to Redis using the `rediss://` URL scheme:
49///
50/// ```no_run
51/// use actix_session::{storage::RedisSessionStore};
52///
53/// # actix_web::rt::System::new().block_on(async {
54/// let redis_connection_string = "rediss://127.0.0.1:6379";
55/// let store = RedisSessionStore::new(redis_connection_string).await.unwrap();
56/// # })
57/// ```
58///
59/// # Pooled Redis Connections
60///
61/// When the `redis-pool` crate feature is enabled, a pre-existing pool from [`deadpool_redis`] can
62/// be provided.
63///
64/// ```no_run
65/// use actix_session::storage::RedisSessionStore;
66/// use deadpool_redis::{Config, Runtime};
67///
68/// let redis_cfg = Config::from_url("redis://127.0.0.1:6379");
69/// let redis_pool = redis_cfg.create_pool(Some(Runtime::Tokio1)).unwrap();
70///
71/// let store = RedisSessionStore::new_pooled(redis_pool);
72/// ```
73///
74/// # Implementation notes
75///
76/// `RedisSessionStore` leverages the [`redis`] crate as the underlying Redis client.
77#[derive(Clone)]
78pub struct RedisSessionStore {
79    configuration: CacheConfiguration,
80    client: RedisSessionConn,
81}
82
83#[derive(Clone)]
84enum RedisSessionConn {
85    /// Single connection.
86    Single(ConnectionManager),
87
88    /// Connection pool.
89    #[cfg(feature = "redis-pool")]
90    Pool(deadpool_redis::Pool),
91}
92
/// Settings shared by every clone of a [`RedisSessionStore`].
#[derive(Clone)]
struct CacheConfiguration {
    /// Maps a session key to the Redis key under which that session's state is stored.
    cache_keygen: Arc<dyn Fn(&str) -> String + Send + Sync>,
}

impl Default for CacheConfiguration {
    /// Uses the session key verbatim as the Redis key.
    fn default() -> Self {
        let identity = |session_key: &str| session_key.to_owned();

        Self {
            cache_keygen: Arc::new(identity),
        }
    }
}
105
106impl RedisSessionStore {
107    /// Returns a fluent API builder to configure [`RedisSessionStore`].
108    ///
109    /// It takes as input the only required input to create a new instance of [`RedisSessionStore`]
110    /// - a connection string for Redis.
111    pub fn builder(connection_string: impl Into<String>) -> RedisSessionStoreBuilder {
112        RedisSessionStoreBuilder {
113            configuration: CacheConfiguration::default(),
114            conn_builder: RedisSessionConnBuilder::Single(connection_string.into()),
115        }
116    }
117
118    /// Returns a fluent API builder to configure [`RedisSessionStore`].
119    ///
120    /// It takes as input the only required input to create a new instance of [`RedisSessionStore`]
121    /// - a pool object for Redis.
122    #[cfg(feature = "redis-pool")]
123    pub fn builder_pooled(pool: impl Into<deadpool_redis::Pool>) -> RedisSessionStoreBuilder {
124        RedisSessionStoreBuilder {
125            configuration: CacheConfiguration::default(),
126            conn_builder: RedisSessionConnBuilder::Pool(pool.into()),
127        }
128    }
129
130    /// Creates a new instance of [`RedisSessionStore`] using the default configuration.
131    ///
132    /// It takes as input the only required input to create a new instance of [`RedisSessionStore`]
133    /// - a connection string for Redis.
134    pub async fn new(connection_string: impl Into<String>) -> Result<RedisSessionStore, Error> {
135        Self::builder(connection_string).build().await
136    }
137
138    /// Creates a new instance of [`RedisSessionStore`] using the default configuration.
139    ///
140    /// It takes as input the only required input to create a new instance of [`RedisSessionStore`]
141    /// - a pool object for Redis.
142    #[cfg(feature = "redis-pool")]
143    pub async fn new_pooled(
144        pool: impl Into<deadpool_redis::Pool>,
145    ) -> anyhow::Result<RedisSessionStore> {
146        Self::builder_pooled(pool).build().await
147    }
148}
149
150/// A fluent builder to construct a [`RedisSessionStore`] instance with custom configuration
151/// parameters.
152#[must_use]
153pub struct RedisSessionStoreBuilder {
154    configuration: CacheConfiguration,
155    conn_builder: RedisSessionConnBuilder,
156}
157
/// Not-yet-established connection settings held by a [`RedisSessionStoreBuilder`].
enum RedisSessionConnBuilder {
    /// A Redis connection string; a connection manager is created from it at build time.
    Single(String),

    /// A connection pool supplied by the caller, passed through unchanged at build time.
    #[cfg(feature = "redis-pool")]
    Pool(deadpool_redis::Pool),
}
166
167impl RedisSessionConnBuilder {
168    async fn into_client(self) -> anyhow::Result<RedisSessionConn> {
169        Ok(match self {
170            RedisSessionConnBuilder::Single(conn_string) => {
171                RedisSessionConn::Single(ConnectionManager::new(Client::open(conn_string)?).await?)
172            }
173
174            #[cfg(feature = "redis-pool")]
175            RedisSessionConnBuilder::Pool(pool) => RedisSessionConn::Pool(pool),
176        })
177    }
178}
179
180impl RedisSessionStoreBuilder {
181    /// Set a custom cache key generation strategy, expecting a session key as input.
182    pub fn cache_keygen<F>(mut self, keygen: F) -> Self
183    where
184        F: Fn(&str) -> String + 'static + Send + Sync,
185    {
186        self.configuration.cache_keygen = Arc::new(keygen);
187        self
188    }
189
190    /// Finalises builder and returns a [`RedisSessionStore`] instance.
191    pub async fn build(self) -> anyhow::Result<RedisSessionStore> {
192        let client = self.conn_builder.into_client().await?;
193
194        Ok(RedisSessionStore {
195            configuration: self.configuration,
196            client,
197        })
198    }
199}
200
201impl SessionStore for RedisSessionStore {
202    async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
203        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
204
205        let value: Option<String> = self
206            .execute_command(redis::cmd("GET").arg(&[&cache_key]))
207            .await
208            .map_err(Into::into)
209            .map_err(LoadError::Other)?;
210
211        match value {
212            None => Ok(None),
213            Some(value) => Ok(serde_json::from_str(&value)
214                .map_err(Into::into)
215                .map_err(LoadError::Deserialization)?),
216        }
217    }
218
219    async fn save(
220        &self,
221        session_state: SessionState,
222        ttl: &Duration,
223    ) -> Result<SessionKey, SaveError> {
224        let body = serde_json::to_string(&session_state)
225            .map_err(Into::into)
226            .map_err(SaveError::Serialization)?;
227        let session_key = generate_session_key();
228        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
229
230        self.execute_command::<()>(
231            redis::cmd("SET")
232                .arg(&[
233                    &cache_key, // key
234                    &body,      // value
235                    "NX",       // only set the key if it does not already exist
236                    "EX",       // set expiry / TTL
237                ])
238                .arg(
239                    ttl.whole_seconds(), // EXpiry in seconds
240                ),
241        )
242        .await
243        .map_err(Into::into)
244        .map_err(SaveError::Other)?;
245
246        Ok(session_key)
247    }
248
249    async fn update(
250        &self,
251        session_key: SessionKey,
252        session_state: SessionState,
253        ttl: &Duration,
254    ) -> Result<SessionKey, UpdateError> {
255        let body = serde_json::to_string(&session_state)
256            .map_err(Into::into)
257            .map_err(UpdateError::Serialization)?;
258
259        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
260
261        let v: Value = self
262            .execute_command(redis::cmd("SET").arg(&[
263                &cache_key,
264                &body,
265                "XX", // XX: Only set the key if it already exist.
266                "EX", // EX: set expiry
267                &format!("{}", ttl.whole_seconds()),
268            ]))
269            .await
270            .map_err(Into::into)
271            .map_err(UpdateError::Other)?;
272
273        match v {
274            Value::Nil => {
275                // The SET operation was not performed because the XX condition was not verified.
276                // This can happen if the session state expired between the load operation and the
277                // update operation. Unlucky, to say the least. We fall back to the `save` routine
278                // to ensure that the new key is unique.
279                self.save(session_state, ttl)
280                    .await
281                    .map_err(|err| match err {
282                        SaveError::Serialization(err) => UpdateError::Serialization(err),
283                        SaveError::Other(err) => UpdateError::Other(err),
284                    })
285            }
286            Value::Int(_) | Value::Okay | Value::SimpleString(_) => Ok(session_key),
287            val => Err(UpdateError::Other(anyhow::anyhow!(
288                "Failed to update session state. {:?}",
289                val
290            ))),
291        }
292    }
293
294    async fn update_ttl(&self, session_key: &SessionKey, ttl: &Duration) -> anyhow::Result<()> {
295        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
296
297        match self.client {
298            RedisSessionConn::Single(ref conn) => {
299                conn.clone()
300                    .expire::<_, ()>(&cache_key, ttl.whole_seconds())
301                    .await?;
302            }
303
304            #[cfg(feature = "redis-pool")]
305            RedisSessionConn::Pool(ref pool) => {
306                pool.get()
307                    .await?
308                    .expire::<_, ()>(&cache_key, ttl.whole_seconds())
309                    .await?;
310            }
311        }
312
313        Ok(())
314    }
315
316    async fn delete(&self, session_key: &SessionKey) -> Result<(), Error> {
317        let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
318
319        self.execute_command::<()>(redis::cmd("DEL").arg(&[&cache_key]))
320            .await
321            .map_err(Into::into)
322            .map_err(UpdateError::Other)?;
323
324        Ok(())
325    }
326}
327
328impl RedisSessionStore {
329    /// Execute Redis command and retry once in certain cases.
330    ///
331    /// `ConnectionManager` automatically reconnects when it encounters an error talking to Redis.
332    /// The request that bumped into the error, though, fails.
333    ///
334    /// This is generally OK, but there is an unpleasant edge case: Redis client timeouts. The
335    /// server is configured to drop connections who have been active longer than a pre-determined
336    /// threshold. `redis-rs` does not proactively detect that the connection has been dropped - you
337    /// only find out when you try to use it.
338    ///
339    /// This helper method catches this case (`.is_connection_dropped`) to execute a retry. The
340    /// retry will be executed on a fresh connection, therefore it is likely to succeed (or fail for
341    /// a different more meaningful reason).
342    #[allow(clippy::needless_pass_by_ref_mut)]
343    async fn execute_command<T: FromRedisValue>(&self, cmd: &mut Cmd) -> anyhow::Result<T> {
344        let mut can_retry = true;
345
346        match self.client {
347            RedisSessionConn::Single(ref conn) => {
348                let mut conn = conn.clone();
349
350                loop {
351                    match cmd.query_async(&mut conn).await {
352                        Ok(value) => return Ok(value),
353                        Err(err) => {
354                            if can_retry && err.is_connection_dropped() {
355                                tracing::debug!(
356                                    "Connection dropped while trying to talk to Redis. Retrying."
357                                );
358
359                                // Retry at most once
360                                can_retry = false;
361
362                                continue;
363                            } else {
364                                return Err(err.into());
365                            }
366                        }
367                    }
368                }
369            }
370
371            #[cfg(feature = "redis-pool")]
372            RedisSessionConn::Pool(ref pool) => {
373                let mut conn = pool.get().await?;
374
375                loop {
376                    match cmd.query_async(&mut conn).await {
377                        Ok(value) => return Ok(value),
378                        Err(err) => {
379                            if can_retry && err.is_connection_dropped() {
380                                tracing::debug!(
381                                    "Connection dropped while trying to talk to Redis. Retrying."
382                                );
383
384                                // Retry at most once
385                                can_retry = false;
386
387                                continue;
388                            } else {
389                                return Err(err.into());
390                            }
391                        }
392                    }
393                }
394            }
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::cookie::time;
    #[cfg(not(feature = "redis-session"))]
    use deadpool_redis::{Config, Runtime};

    use super::*;
    use crate::test_helpers::acceptance_test_suite;

    /// Builds the store under test against a Redis instance on `127.0.0.1:6379`.
    ///
    /// Uses a single managed connection when the `redis-session` feature is enabled, and a
    /// `deadpool_redis` pool otherwise.
    async fn redis_store() -> RedisSessionStore {
        #[cfg(feature = "redis-session")]
        {
            RedisSessionStore::new("redis://127.0.0.1:6379")
                .await
                .unwrap()
        }

        #[cfg(not(feature = "redis-session"))]
        {
            let redis_pool = Config::from_url("redis://127.0.0.1:6379")
                .create_pool(Some(Runtime::Tokio1))
                .unwrap();
            // `new` only accepts a connection string; a pool has to go through `new_pooled`.
            RedisSessionStore::new_pooled(redis_pool).await.unwrap()
        }
    }

    #[actix_web::test]
    async fn test_session_workflow() {
        let redis_store = redis_store().await;
        acceptance_test_suite(move || redis_store.clone(), true).await;
    }

    #[actix_web::test]
    async fn loading_a_missing_session_returns_none() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        assert!(store.load(&session_key).await.unwrap().is_none());
    }

    #[actix_web::test]
    async fn loading_an_invalid_session_state_returns_deserialization_error() {
        let store = redis_store().await;
        let session_key = generate_session_key();

        // Write straight under the raw session key: with the default cache keygen (identity)
        // this is exactly the key `load` will read.
        match store.client {
            RedisSessionConn::Single(ref conn) => conn
                .clone()
                .set::<_, _, ()>(session_key.as_ref(), "random-thing-which-is-not-json")
                .await
                .unwrap(),

            #[cfg(feature = "redis-pool")]
            RedisSessionConn::Pool(ref pool) => {
                pool.get()
                    .await
                    .unwrap()
                    .set::<_, _, ()>(session_key.as_ref(), "random-thing-which-is-not-json")
                    .await
                    .unwrap();
            }
        }

        assert!(matches!(
            store.load(&session_key).await.unwrap_err(),
            LoadError::Deserialization(_),
        ));
    }

    #[actix_web::test]
    async fn updating_of_an_expired_state_is_handled_gracefully() {
        let store = redis_store().await;
        let session_key = generate_session_key();
        let initial_session_key = session_key.as_ref().to_owned();
        let updated_session_key = store
            .update(session_key, HashMap::new(), &time::Duration::seconds(1))
            .await
            .unwrap();
        assert_ne!(initial_session_key, updated_session_key.as_ref());
    }
}
480}