actix_jwt_session/
redis_adapter.rs1use std::marker::PhantomData;
12use std::sync::Arc;
13
14pub use deadpool_redis;
15use deadpool_redis::Pool;
16use redis::AsyncCommands;
17
18use crate::*;
19
20#[derive(Clone)]
22struct RedisStorage<ClaimsType: Claims> {
23 pool: Pool,
24 _claims_type_marker: PhantomData<ClaimsType>,
25}
26
27impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
28 pub fn new(pool: Pool) -> Self {
29 Self {
30 pool,
31 _claims_type_marker: Default::default(),
32 }
33 }
34}
35
36#[async_trait::async_trait(?Send)]
37impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
38where
39 ClaimsType: Claims,
40{
41 async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
42 let pool = self.pool.clone();
43 let mut conn = pool.get().await.map_err(|e| {
44 #[cfg(feature = "use-tracing")]
45 tracing::error!("Unable to obtain redis connection: {e}");
46 Error::RedisConn
47 })?;
48 conn.get::<_, Vec<u8>>(jti).await.map_err(|e| {
49 #[cfg(feature = "use-tracing")]
50 tracing::error!("Session record not found in redis: {e}");
51 Error::NotFound
52 })
53 }
54
55 async fn set_by_jti(
56 self: Arc<Self>,
57 jwt_jti: &[u8],
58 refresh_jti: &[u8],
59 bytes: &[u8],
60 mut exp: Duration,
61 ) -> Result<(), Error> {
62 bad_ttl!(
63 exp,
64 Duration::seconds(1),
65 "Expiration time is bellow 1s. This is not allowed for redis server."
66 );
67 let pool = self.pool.clone();
68 let mut conn = pool.get().await.map_err(|e| {
69 #[cfg(feature = "use-tracing")]
70 tracing::error!("Unable to obtain redis connection: {e}");
71 Error::RedisConn
72 })?;
73 let mut pipeline = redis::Pipeline::new();
74 let _: () = pipeline
75 .set_ex(jwt_jti, bytes, exp.as_seconds_f32() as u64)
76 .set_ex(refresh_jti, bytes, exp.as_seconds_f32() as u64)
77 .query_async(&mut conn)
78 .await
79 .map_err(|e| {
80 #[cfg(feature = "use-tracing")]
81 tracing::error!("Failed to save session in redis: {e}");
82 Error::WriteFailed
83 })?;
84 Ok(())
85 }
86
87 async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
88 let pool = self.pool.clone();
89 let mut conn = pool.get().await.map_err(|e| {
90 #[cfg(feature = "use-tracing")]
91 tracing::error!("Unable to obtain redis connection: {e}");
92 Error::RedisConn
93 })?;
94 let _: () = conn.del(jti).await.map_err(|e| {
95 #[cfg(feature = "use-tracing")]
96 tracing::error!("Session record can't be removed from redis: {e}");
97 Error::NotFound
98 })?;
99 Ok(())
100 }
101}
102
103impl<ClaimsType: Claims> SessionMiddlewareBuilder<ClaimsType> {
104 #[must_use]
105 pub fn with_redis_pool(mut self, pool: Pool) -> Self {
106 let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
107 let storage = SessionStorage::new(storage, self.jwt_encoding_key.clone(), self.algorithm);
108 self.storage = Some(storage);
109 self
110 }
111}
112
#[cfg(test)]
mod tests {
    use actix_web::cookie::time::*;

    use super::*;

    /// Claims fixture used to round-trip a session through Redis.
    #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
    #[serde(rename_all = "snake_case")]
    pub struct Claims {
        #[serde(rename = "exp")]
        pub expires_at: usize,
        #[serde(rename = "iat")]
        pub issues_at: usize,
        #[serde(rename = "sub")]
        pub subject: String,
        #[serde(rename = "aud")]
        pub audience: String,
        #[serde(rename = "jti")]
        pub jwt_id: uuid::Uuid,
        #[serde(rename = "aci")]
        pub account_id: i32,
    }

    impl crate::Claims for Claims {
        fn jti(&self) -> uuid::Uuid {
            self.jwt_id
        }

        fn subject(&self) -> &str {
            &self.subject
        }
    }

    /// Builds a Redis-backed session storage plus middleware factory.
    ///
    /// Requires a Redis server reachable at `redis://localhost:6379`.
    async fn create_storage() -> (SessionStorage, SessionMiddlewareFactory<Claims>) {
        use deadpool_redis::{Config, Runtime};

        let redis = Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))
            .expect("static redis pool configuration must be valid");
        let jwt_signing_keys = JwtSigningKeys::generate(false).unwrap();
        SessionMiddlewareFactory::<Claims>::build(
            Arc::new(jwt_signing_keys.encoding_key),
            Arc::new(jwt_signing_keys.decoding_key),
            Algorithm::EdDSA,
        )
        .with_redis_pool(redis)
        .with_extractors(
            Extractors::default()
                .with_refresh_cookie(REFRESH_COOKIE_NAME)
                .with_refresh_header(REFRESH_HEADER_NAME)
                .with_jwt_cookie(JWT_COOKIE_NAME)
                .with_jwt_header(JWT_HEADER_NAME),
        )
        .finish()
    }

    /// Storing claims and loading them back by JTI yields identical claims.
    #[tokio::test]
    async fn check_encode() {
        let (store, _) = create_storage().await;
        let jwt_exp = JwtTtl(Duration::days(31));
        let refresh_exp = RefreshTtl(Duration::days(31));

        // Sample the clock once so `iat` and `exp` are derived from the same instant.
        let now = OffsetDateTime::now_utc();
        let original = Claims {
            subject: "me".into(),
            expires_at: (now + Duration::days(31)).unix_timestamp() as usize,
            issues_at: now.unix_timestamp() as usize,
            audience: "web".into(),
            jwt_id: Uuid::new_v4(),
            account_id: 24234,
        };

        store
            .store(original.clone(), jwt_exp, refresh_exp)
            .await
            .unwrap();
        let loaded = store.find_jwt(original.jwt_id).await.unwrap();
        assert_eq!(original, loaded);
    }
}