actix_jwt_session/
redis_adapter.rs

//! Default session storage which uses async redis requests.
//!
//! Sessions are serialized to a binary format and stored using their
//! [uuid::Uuid] key as bytes. Each session is written under two keys (the JWT
//! JTI and the refresh-token JTI), and every record has an expiration time
//! after which redis removes it automatically.
//!
//! [RedisStorage] is constructed by [SessionMiddlewareBuilder::with_redis_pool]
//! from a [deadpool_redis::Pool] and shared between all middleware instances.
10
11use std::marker::PhantomData;
12use std::sync::Arc;
13
14pub use deadpool_redis;
15use deadpool_redis::Pool;
16use redis::AsyncCommands;
17
18use crate::*;
19
20/// Redis implementation for [TokenStorage]
21#[derive(Clone)]
22struct RedisStorage<ClaimsType: Claims> {
23    pool: Pool,
24    _claims_type_marker: PhantomData<ClaimsType>,
25}
26
27impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
28    pub fn new(pool: Pool) -> Self {
29        Self {
30            pool,
31            _claims_type_marker: Default::default(),
32        }
33    }
34}
35
36#[async_trait::async_trait(?Send)]
37impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
38where
39    ClaimsType: Claims,
40{
41    async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
42        let pool = self.pool.clone();
43        let mut conn = pool.get().await.map_err(|e| {
44            #[cfg(feature = "use-tracing")]
45            tracing::error!("Unable to obtain redis connection: {e}");
46            Error::RedisConn
47        })?;
48        conn.get::<_, Vec<u8>>(jti).await.map_err(|e| {
49            #[cfg(feature = "use-tracing")]
50            tracing::error!("Session record not found in redis: {e}");
51            Error::NotFound
52        })
53    }
54
55    async fn set_by_jti(
56        self: Arc<Self>,
57        jwt_jti: &[u8],
58        refresh_jti: &[u8],
59        bytes: &[u8],
60        mut exp: Duration,
61    ) -> Result<(), Error> {
62        bad_ttl!(
63            exp,
64            Duration::seconds(1),
65            "Expiration time is bellow 1s. This is not allowed for redis server."
66        );
67        let pool = self.pool.clone();
68        let mut conn = pool.get().await.map_err(|e| {
69            #[cfg(feature = "use-tracing")]
70            tracing::error!("Unable to obtain redis connection: {e}");
71            Error::RedisConn
72        })?;
73        let mut pipeline = redis::Pipeline::new();
74        let _: () = pipeline
75            .set_ex(jwt_jti, bytes, exp.as_seconds_f32() as u64)
76            .set_ex(refresh_jti, bytes, exp.as_seconds_f32() as u64)
77            .query_async(&mut conn)
78            .await
79            .map_err(|e| {
80                #[cfg(feature = "use-tracing")]
81                tracing::error!("Failed to save session in redis: {e}");
82                Error::WriteFailed
83            })?;
84        Ok(())
85    }
86
87    async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
88        let pool = self.pool.clone();
89        let mut conn = pool.get().await.map_err(|e| {
90            #[cfg(feature = "use-tracing")]
91            tracing::error!("Unable to obtain redis connection: {e}");
92            Error::RedisConn
93        })?;
94        let _: () = conn.del(jti).await.map_err(|e| {
95            #[cfg(feature = "use-tracing")]
96            tracing::error!("Session record can't be removed from redis: {e}");
97            Error::NotFound
98        })?;
99        Ok(())
100    }
101}
102
103impl<ClaimsType: Claims> SessionMiddlewareBuilder<ClaimsType> {
104    #[must_use]
105    pub fn with_redis_pool(mut self, pool: Pool) -> Self {
106        let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
107        let storage = SessionStorage::new(storage, self.jwt_encoding_key.clone(), self.algorithm);
108        self.storage = Some(storage);
109        self
110    }
111}
112
#[cfg(test)]
mod tests {
    use actix_web::cookie::time::*;

    use super::*;

    /// Claims fixture used to round-trip a session through redis.
    #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
    #[serde(rename_all = "snake_case")]
    pub struct Claims {
        #[serde(rename = "exp")]
        pub expires_at: usize,
        #[serde(rename = "iat")]
        pub issues_at: usize,
        /// Account login
        #[serde(rename = "sub")]
        pub subject: String,
        #[serde(rename = "aud")]
        pub audience: String,
        #[serde(rename = "jti")]
        pub jwt_id: uuid::Uuid,
        #[serde(rename = "aci")]
        pub account_id: i32,
    }

    impl crate::Claims for Claims {
        fn jti(&self) -> uuid::Uuid {
            self.jwt_id
        }

        fn subject(&self) -> &str {
            &self.subject
        }
    }

    /// Builds session storage and a middleware factory backed by the redis
    /// server at `redis://localhost:6379`.
    ///
    /// NOTE: tests using this helper need that redis server to be running.
    async fn create_storage() -> (SessionStorage, SessionMiddlewareFactory<Claims>) {
        use deadpool_redis::{Config, Runtime};

        let redis = Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))
            .expect("redis pool configuration must be valid");
        let jwt_signing_keys = JwtSigningKeys::generate(false).unwrap();
        SessionMiddlewareFactory::<Claims>::build(
            Arc::new(jwt_signing_keys.encoding_key),
            Arc::new(jwt_signing_keys.decoding_key),
            Algorithm::EdDSA,
        )
        .with_redis_pool(redis)
        .with_extractors(
            Extractors::default()
                .with_refresh_cookie(REFRESH_COOKIE_NAME)
                .with_refresh_header(REFRESH_HEADER_NAME)
                .with_jwt_cookie(JWT_COOKIE_NAME)
                .with_jwt_header(JWT_HEADER_NAME),
        )
        .finish()
    }

    /// A stored session can be loaded back unchanged by its JTI.
    #[tokio::test]
    async fn check_encode() {
        let (store, _) = create_storage().await;
        let jwt_exp = JwtTtl(Duration::days(31));
        let refresh_exp = RefreshTtl(Duration::days(31));
        // Single timestamp so `iat` and `exp` are derived from the same instant.
        let now = OffsetDateTime::now_utc();

        let original = Claims {
            subject: "me".into(),
            expires_at: (now + Duration::days(31)).unix_timestamp() as usize,
            issues_at: now.unix_timestamp() as usize,
            audience: "web".into(),
            jwt_id: Uuid::new_v4(),
            account_id: 24234,
        };

        store
            .store(original.clone(), jwt_exp, refresh_exp)
            .await
            .unwrap();
        let loaded = store.find_jwt(original.jwt_id).await.unwrap();
        assert_eq!(original, loaded);
    }
}