Skip to main content

actix_jwt/store/
memory.rs

1//! Thread-safe, in-memory refresh-token store.
2//!
3//! Mirrors
4//! [`store/memory.go`](https://github.com/LdDl/echo-jwt/blob/master/store/memory.go)
5//! from the Go implementation.
6
7use std::collections::HashMap;
8
9use async_trait::async_trait;
10use chrono::Utc;
11use tokio::sync::RwLock;
12
13use crate::core::{RefreshTokenData, TokenStore};
14use crate::errors::JwtError;
15
16/// Thread-safe, in-memory implementation of [`TokenStore`].
17///
18/// Internally keeps a `HashMap<String, RefreshTokenData>` behind a
19/// [`tokio::sync::RwLock`], allowing concurrent reads while serializing
20/// writes.  Suitable for development, testing and single-instance
21/// deployments.
22///
23/// # Examples
24///
25/// ```
26/// use chrono::{Duration, Utc};
27/// use actix_jwt::core::TokenStore;
28/// use actix_jwt::store::InMemoryRefreshTokenStore;
29///
30/// # #[tokio::main]
31/// # async fn main() {
32/// let store = InMemoryRefreshTokenStore::new();
33///
34/// let expiry = Utc::now() + Duration::hours(1);
35/// store.set("tok-1", serde_json::json!({"uid": 1}), expiry).await.unwrap();
36///
37/// let data = store.get("tok-1").await.unwrap();
38/// assert_eq!(data["uid"], 1);
39///
40/// store.delete("tok-1").await.unwrap();
41/// assert!(store.get("tok-1").await.is_err());
42/// # }
43/// ```
44pub struct InMemoryRefreshTokenStore {
45    tokens: RwLock<HashMap<String, RefreshTokenData>>,
46}
47
48impl InMemoryRefreshTokenStore {
49    /// Creates a new empty store.
50    pub fn new() -> Self {
51        Self {
52            tokens: RwLock::new(HashMap::new()),
53        }
54    }
55
56    /// Returns a clone of all **non-expired** tokens in the store.
57    ///
58    /// Expired tokens are filtered out but **not** removed from the
59    /// underlying map.  Call [`TokenStore::cleanup`] to actually evict them.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// use chrono::{Duration, Utc};
65    /// use actix_jwt::store::InMemoryRefreshTokenStore;
66    /// use actix_jwt::core::TokenStore;
67    ///
68    /// # #[tokio::main]
69    /// # async fn main() {
70    /// let store = InMemoryRefreshTokenStore::new();
71    /// let expiry = Utc::now() + Duration::hours(1);
72    /// store.set("a", serde_json::json!(1), expiry).await.unwrap();
73    ///
74    /// let all = store.get_all().await;
75    /// assert_eq!(all.len(), 1);
76    /// # }
77    /// ```
78    pub async fn get_all(&self) -> HashMap<String, RefreshTokenData> {
79        let tokens = self.tokens.read().await;
80        let now = Utc::now();
81        tokens
82            .iter()
83            .filter(|(_, data)| data.expiry > now)
84            .map(|(k, v)| (k.clone(), v.clone()))
85            .collect()
86    }
87
88    /// Removes **all** tokens from the store (including non-expired ones).
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use chrono::{Duration, Utc};
94    /// use actix_jwt::store::InMemoryRefreshTokenStore;
95    /// use actix_jwt::core::TokenStore;
96    ///
97    /// # #[tokio::main]
98    /// # async fn main() {
99    /// let store = InMemoryRefreshTokenStore::new();
100    /// store.set("x", serde_json::json!(1), Utc::now() + Duration::hours(1)).await.unwrap();
101    /// store.clear().await;
102    /// assert_eq!(store.count().await.unwrap(), 0);
103    /// # }
104    /// ```
105    pub async fn clear(&self) {
106        let mut tokens = self.tokens.write().await;
107        tokens.clear();
108    }
109}
110
111impl Default for InMemoryRefreshTokenStore {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117#[async_trait]
118impl TokenStore for InMemoryRefreshTokenStore {
119    /// Stores a refresh token with associated user data and expiration.
120    ///
121    /// # Errors
122    ///
123    /// Returns [`JwtError::TokenEmpty`] if `token` is an empty string.
124    async fn set(
125        &self,
126        token: &str,
127        user_data: serde_json::Value,
128        expiry: chrono::DateTime<Utc>,
129    ) -> Result<(), JwtError> {
130        if token.is_empty() {
131            return Err(JwtError::TokenEmpty);
132        }
133
134        let data = RefreshTokenData {
135            user_data,
136            expiry,
137            created: Utc::now(),
138        };
139
140        let mut tokens = self.tokens.write().await;
141        tokens.insert(token.to_string(), data);
142        Ok(())
143    }
144
145    /// Retrieves user data for the given refresh token.
146    ///
147    /// Performs lazy cleanup: if the token exists but is expired it is removed
148    /// from the store and [`JwtError::RefreshTokenNotFound`] is returned.
149    ///
150    /// # Errors
151    ///
152    /// * [`JwtError::TokenEmpty`] - empty token string.
153    /// * [`JwtError::RefreshTokenNotFound`] - token absent or expired.
154    async fn get(&self, token: &str) -> Result<serde_json::Value, JwtError> {
155        if token.is_empty() {
156            return Err(JwtError::TokenEmpty);
157        }
158
159        let mut tokens = self.tokens.write().await;
160        match tokens.get(token) {
161            Some(data) => {
162                if data.is_expired() {
163                    tokens.remove(token);
164                    Err(JwtError::RefreshTokenNotFound)
165                } else {
166                    Ok(data.user_data.clone())
167                }
168            }
169            None => Err(JwtError::RefreshTokenNotFound),
170        }
171    }
172
173    /// Removes a refresh token from storage.
174    ///
175    /// Silently succeeds if the token is empty or does not exist.
176    async fn delete(&self, token: &str) -> Result<(), JwtError> {
177        if token.is_empty() {
178            return Ok(());
179        }
180
181        let mut tokens = self.tokens.write().await;
182        tokens.remove(token);
183        Ok(())
184    }
185
186    /// Removes all expired tokens from the store.
187    ///
188    /// Returns the number of entries that were evicted.
189    async fn cleanup(&self) -> Result<usize, JwtError> {
190        let mut tokens = self.tokens.write().await;
191        let now = Utc::now();
192        let before = tokens.len();
193        tokens.retain(|_, data| data.expiry > now);
194        let after = tokens.len();
195        Ok(before - after)
196    }
197
198    /// Returns the total number of tokens in the store **including expired
199    /// ones**.
200    ///
201    /// To get only valid tokens use [`get_all`](Self::get_all).
202    async fn count(&self) -> Result<usize, JwtError> {
203        let tokens = self.tokens.read().await;
204        Ok(tokens.len())
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// Writes an entry straight into the map with `expiry` one second in the
    /// past and `created` back-dated by one hour.
    async fn insert_expired(
        store: &InMemoryRefreshTokenStore,
        key: &str,
        user_data: serde_json::Value,
    ) {
        let mut tokens = store.tokens.write().await;
        tokens.insert(
            key.to_string(),
            RefreshTokenData {
                user_data,
                expiry: Utc::now() - Duration::seconds(1),
                created: Utc::now() - Duration::hours(1),
            },
        );
    }

    #[tokio::test]
    async fn test_set() {
        let store = InMemoryRefreshTokenStore::new();
        let user_data =
            serde_json::json!({"id": "123", "username": "testuser", "email": "test@example.com"});
        let expiry = Utc::now() + Duration::hours(1);

        store.set("token123", user_data, expiry).await.unwrap();

        let count = store.count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_get() {
        let store = InMemoryRefreshTokenStore::new();
        let user_data =
            serde_json::json!({"id": "123", "username": "testuser", "email": "test@example.com"});
        let expiry = Utc::now() + Duration::hours(1);

        store.set("token123", user_data, expiry).await.unwrap();

        let result = store.get("token123").await.unwrap();
        assert_eq!(result["id"], "123");
        assert_eq!(result["username"], "testuser");
        assert_eq!(result["email"], "test@example.com");
    }

    #[tokio::test]
    async fn test_set_overwrites_existing() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("dup", serde_json::json!({"v": 1}), expiry)
            .await
            .unwrap();
        store
            .set("dup", serde_json::json!({"v": 2}), expiry)
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        assert_eq!(store.get("dup").await.unwrap()["v"], 2);
    }

    #[tokio::test]
    async fn test_set_empty_token() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        let result = store.set("", serde_json::json!({}), expiry).await;
        assert!(matches!(result, Err(JwtError::TokenEmpty)));
    }

    #[tokio::test]
    async fn test_get_empty_token() {
        let store = InMemoryRefreshTokenStore::new();

        let result = store.get("").await;
        assert!(matches!(result, Err(JwtError::TokenEmpty)));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let store = InMemoryRefreshTokenStore::new();

        let result = store.get("nonexistent").await;
        assert!(matches!(result, Err(JwtError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_get_expired_auto_cleanup() {
        let store = InMemoryRefreshTokenStore::new();
        insert_expired(&store, "expired", serde_json::json!({"user_id": "123"})).await;

        let result = store.get("expired").await;
        assert!(matches!(result, Err(JwtError::RefreshTokenNotFound)));

        // Token should have been removed
        let count = store.count().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_get_after_reset_of_expired_token() {
        let store = InMemoryRefreshTokenStore::new();
        insert_expired(&store, "tok", serde_json::json!({"v": "old"})).await;

        // Re-issuing the same token string must make it readable again.
        store
            .set(
                "tok",
                serde_json::json!({"v": "new"}),
                Utc::now() + Duration::hours(1),
            )
            .await
            .unwrap();

        assert_eq!(store.get("tok").await.unwrap()["v"], "new");
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("token1", serde_json::json!({}), expiry)
            .await
            .unwrap();

        store.delete("token1").await.unwrap();

        let result = store.get("token1").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_delete_empty_token() {
        let store = InMemoryRefreshTokenStore::new();
        // Should not error
        store.delete("").await.unwrap();
    }

    #[tokio::test]
    async fn test_cleanup() {
        let store = InMemoryRefreshTokenStore::new();
        let valid_expiry = Utc::now() + Duration::hours(1);

        store
            .set("valid", serde_json::json!({}), valid_expiry)
            .await
            .unwrap();
        insert_expired(&store, "expired", serde_json::json!({})).await;

        let cleaned = store.cleanup().await.unwrap();
        assert_eq!(cleaned, 1);

        let count = store.count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_get_all_filters_expired() {
        let store = InMemoryRefreshTokenStore::new();
        let valid_expiry = Utc::now() + Duration::hours(1);

        store
            .set("valid", serde_json::json!({"id": 1}), valid_expiry)
            .await
            .unwrap();
        insert_expired(&store, "expired", serde_json::json!({"id": 2})).await;

        let all = store.get_all().await;
        assert_eq!(all.len(), 1);
        assert!(all.contains_key("valid"));
    }

    #[tokio::test]
    async fn test_clear() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("t1", serde_json::json!({}), expiry)
            .await
            .unwrap();
        store
            .set("t2", serde_json::json!({}), expiry)
            .await
            .unwrap();

        store.clear().await;

        let count = store.count().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_new_store() {
        let store = InMemoryRefreshTokenStore::new();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "New store should be empty");
    }

    #[tokio::test]
    async fn test_delete_nonexistent() {
        let store = InMemoryRefreshTokenStore::new();
        // Deleting a token that does not exist should succeed without error
        let result = store.delete("nonexistent_token").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_count() {
        let store = InMemoryRefreshTokenStore::new();
        let valid_expiry = Utc::now() + Duration::hours(1);

        // Add 3 valid tokens
        for i in 0..3 {
            store
                .set(
                    &format!("valid{}", i),
                    serde_json::json!({"id": i}),
                    valid_expiry,
                )
                .await
                .unwrap();
        }

        // Add 2 expired tokens directly
        for i in 0..2 {
            insert_expired(&store, &format!("expired{}", i), serde_json::json!({"id": i})).await;
        }

        // count() returns total (including expired)
        let count = store.count().await.unwrap();
        assert_eq!(
            count, 5,
            "Count should include both valid and expired tokens"
        );

        // After cleanup, only valid tokens remain
        let cleaned = store.cleanup().await.unwrap();
        assert_eq!(cleaned, 2);

        let count = store.count().await.unwrap();
        assert_eq!(count, 3, "Count after cleanup should be 3");
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        use std::sync::Arc;

        let store = Arc::new(InMemoryRefreshTokenStore::new());
        let num_tasks = 100usize;

        // Concurrent writes
        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                let user_data = serde_json::json!({"id": i});
                let expiry = Utc::now() + Duration::hours(1);
                store.set(&token, user_data, expiry).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        let count = store.count().await.unwrap();
        assert_eq!(count, num_tasks);

        // Concurrent reads
        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                let result = store.get(&token).await;
                assert!(result.is_ok(), "Failed to get token{}", i);
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        // Concurrent deletes
        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                store.delete(&token).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        let count = store.count().await.unwrap();
        assert_eq!(
            count, 0,
            "All tokens should be deleted after concurrent deletes"
        );
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn test_is_expired() {
        // Non-expired token
        let data = RefreshTokenData {
            user_data: serde_json::json!({"user_id": "123"}),
            expiry: Utc::now() + Duration::hours(1),
            created: Utc::now(),
        };
        assert!(
            !data.is_expired(),
            "Token with future expiry should not be expired"
        );

        // Expired token
        let data = RefreshTokenData {
            user_data: serde_json::json!({"user_id": "123"}),
            expiry: Utc::now() - Duration::hours(1),
            created: Utc::now() - Duration::hours(2),
        };
        assert!(
            data.is_expired(),
            "Token with past expiry should be expired"
        );
    }
}
543}