1use std::collections::HashMap;
8
9use async_trait::async_trait;
10use chrono::Utc;
11use tokio::sync::RwLock;
12
13use crate::core::{RefreshTokenData, TokenStore};
14use crate::errors::JwtError;
15
16pub struct InMemoryRefreshTokenStore {
45 tokens: RwLock<HashMap<String, RefreshTokenData>>,
46}
47
48impl InMemoryRefreshTokenStore {
49 pub fn new() -> Self {
51 Self {
52 tokens: RwLock::new(HashMap::new()),
53 }
54 }
55
56 pub async fn get_all(&self) -> HashMap<String, RefreshTokenData> {
79 let tokens = self.tokens.read().await;
80 let now = Utc::now();
81 tokens
82 .iter()
83 .filter(|(_, data)| data.expiry > now)
84 .map(|(k, v)| (k.clone(), v.clone()))
85 .collect()
86 }
87
88 pub async fn clear(&self) {
106 let mut tokens = self.tokens.write().await;
107 tokens.clear();
108 }
109}
110
111impl Default for InMemoryRefreshTokenStore {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117#[async_trait]
118impl TokenStore for InMemoryRefreshTokenStore {
119 async fn set(
125 &self,
126 token: &str,
127 user_data: serde_json::Value,
128 expiry: chrono::DateTime<Utc>,
129 ) -> Result<(), JwtError> {
130 if token.is_empty() {
131 return Err(JwtError::TokenEmpty);
132 }
133
134 let data = RefreshTokenData {
135 user_data,
136 expiry,
137 created: Utc::now(),
138 };
139
140 let mut tokens = self.tokens.write().await;
141 tokens.insert(token.to_string(), data);
142 Ok(())
143 }
144
145 async fn get(&self, token: &str) -> Result<serde_json::Value, JwtError> {
155 if token.is_empty() {
156 return Err(JwtError::TokenEmpty);
157 }
158
159 let mut tokens = self.tokens.write().await;
160 match tokens.get(token) {
161 Some(data) => {
162 if data.is_expired() {
163 tokens.remove(token);
164 Err(JwtError::RefreshTokenNotFound)
165 } else {
166 Ok(data.user_data.clone())
167 }
168 }
169 None => Err(JwtError::RefreshTokenNotFound),
170 }
171 }
172
173 async fn delete(&self, token: &str) -> Result<(), JwtError> {
177 if token.is_empty() {
178 return Ok(());
179 }
180
181 let mut tokens = self.tokens.write().await;
182 tokens.remove(token);
183 Ok(())
184 }
185
186 async fn cleanup(&self) -> Result<usize, JwtError> {
190 let mut tokens = self.tokens.write().await;
191 let now = Utc::now();
192 let before = tokens.len();
193 tokens.retain(|_, data| data.expiry > now);
194 let after = tokens.len();
195 Ok(before - after)
196 }
197
198 async fn count(&self) -> Result<usize, JwtError> {
203 let tokens = self.tokens.read().await;
204 Ok(tokens.len())
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use std::sync::Arc;

    /// Payload used by the basic set/get tests.
    fn sample_user() -> serde_json::Value {
        serde_json::json!({"id": "123", "username": "testuser", "email": "test@example.com"})
    }

    /// Writes an entry straight into the backing map, bypassing `set`, so tests
    /// can seed entries whose `expiry` is already in the past.
    async fn insert_raw(
        store: &InMemoryRefreshTokenStore,
        token: &str,
        user_data: serde_json::Value,
        expiry: chrono::DateTime<Utc>,
    ) {
        store.tokens.write().await.insert(
            token.to_string(),
            RefreshTokenData {
                user_data,
                expiry,
                created: Utc::now() - Duration::hours(1),
            },
        );
    }

    #[tokio::test]
    async fn test_set() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store.set("token123", sample_user(), expiry).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_set_overwrites_existing_token() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("token123", serde_json::json!({"v": 1}), expiry)
            .await
            .unwrap();
        store
            .set("token123", serde_json::json!({"v": 2}), expiry)
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        assert_eq!(store.get("token123").await.unwrap()["v"], 2);
    }

    #[tokio::test]
    async fn test_get() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store.set("token123", sample_user(), expiry).await.unwrap();

        let result = store.get("token123").await.unwrap();
        assert_eq!(result["id"], "123");
        assert_eq!(result["username"], "testuser");
        assert_eq!(result["email"], "test@example.com");

        // A successful lookup must not evict the entry.
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_set_empty_token() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        let result = store.set("", serde_json::json!({}), expiry).await;
        assert!(matches!(result, Err(JwtError::TokenEmpty)));
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_get_empty_token() {
        let store = InMemoryRefreshTokenStore::new();

        let result = store.get("").await;
        assert!(matches!(result, Err(JwtError::TokenEmpty)));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let store = InMemoryRefreshTokenStore::new();

        let result = store.get("nonexistent").await;
        assert!(matches!(result, Err(JwtError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_get_expired_auto_cleanup() {
        let store = InMemoryRefreshTokenStore::new();
        insert_raw(
            &store,
            "expired",
            serde_json::json!({"user_id": "123"}),
            Utc::now() - Duration::seconds(1),
        )
        .await;

        let result = store.get("expired").await;
        assert!(matches!(result, Err(JwtError::RefreshTokenNotFound)));

        // The expired entry is evicted as a side effect of the lookup.
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_set_replaces_expired_token() {
        let store = InMemoryRefreshTokenStore::new();
        insert_raw(
            &store,
            "token",
            serde_json::json!({"v": "old"}),
            Utc::now() - Duration::seconds(1),
        )
        .await;

        store
            .set(
                "token",
                serde_json::json!({"v": "new"}),
                Utc::now() + Duration::hours(1),
            )
            .await
            .unwrap();

        assert_eq!(store.get("token").await.unwrap()["v"], "new");
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("token1", serde_json::json!({}), expiry)
            .await
            .unwrap();

        store.delete("token1").await.unwrap();

        let result = store.get("token1").await;
        assert!(matches!(result, Err(JwtError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_delete_empty_token() {
        let store = InMemoryRefreshTokenStore::new();
        store.delete("").await.unwrap();
    }

    #[tokio::test]
    async fn test_delete_nonexistent() {
        let store = InMemoryRefreshTokenStore::new();
        let result = store.delete("nonexistent_token").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup() {
        let store = InMemoryRefreshTokenStore::new();

        store
            .set(
                "valid",
                serde_json::json!({}),
                Utc::now() + Duration::hours(1),
            )
            .await
            .unwrap();
        insert_raw(
            &store,
            "expired",
            serde_json::json!({}),
            Utc::now() - Duration::seconds(1),
        )
        .await;

        assert_eq!(store.cleanup().await.unwrap(), 1);
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_get_all_filters_expired() {
        let store = InMemoryRefreshTokenStore::new();

        store
            .set(
                "valid",
                serde_json::json!({"id": 1}),
                Utc::now() + Duration::hours(1),
            )
            .await
            .unwrap();
        insert_raw(
            &store,
            "expired",
            serde_json::json!({"id": 2}),
            Utc::now() - Duration::seconds(1),
        )
        .await;

        let all = store.get_all().await;
        assert_eq!(all.len(), 1);
        assert!(all.contains_key("valid"));

        // `get_all` only filters its snapshot; it does not evict anything.
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let store = InMemoryRefreshTokenStore::new();
        let expiry = Utc::now() + Duration::hours(1);

        store
            .set("t1", serde_json::json!({}), expiry)
            .await
            .unwrap();
        store
            .set("t2", serde_json::json!({}), expiry)
            .await
            .unwrap();

        store.clear().await;

        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_new_store() {
        let store = InMemoryRefreshTokenStore::new();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "New store should be empty");
    }

    #[tokio::test]
    async fn test_default_store() {
        let store = InMemoryRefreshTokenStore::default();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "Default store should be empty");
    }

    #[tokio::test]
    async fn test_count() {
        let store = InMemoryRefreshTokenStore::new();
        let valid_expiry = Utc::now() + Duration::hours(1);
        let expired_expiry = Utc::now() - Duration::seconds(1);

        for i in 0..3 {
            store
                .set(
                    &format!("valid{}", i),
                    serde_json::json!({"id": i}),
                    valid_expiry,
                )
                .await
                .unwrap();
        }
        for i in 0..2 {
            insert_raw(
                &store,
                &format!("expired{}", i),
                serde_json::json!({"id": i}),
                expired_expiry,
            )
            .await;
        }

        let count = store.count().await.unwrap();
        assert_eq!(
            count, 5,
            "Count should include both valid and expired tokens"
        );

        let cleaned = store.cleanup().await.unwrap();
        assert_eq!(cleaned, 2);

        let count = store.count().await.unwrap();
        assert_eq!(count, 3, "Count after cleanup should be 3");
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let store = Arc::new(InMemoryRefreshTokenStore::new());
        let num_tasks = 100usize;

        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                let user_data = serde_json::json!({"id": i});
                let expiry = Utc::now() + Duration::hours(1);
                store.set(&token, user_data, expiry).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(store.count().await.unwrap(), num_tasks);

        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                let result = store.get(&token).await;
                assert!(result.is_ok(), "Failed to get token{}", i);
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        // Concurrent reads of live tokens must not evict anything.
        assert_eq!(store.count().await.unwrap(), num_tasks);

        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let token = format!("token{}", i);
                store.delete(&token).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        let count = store.count().await.unwrap();
        assert_eq!(
            count, 0,
            "All tokens should be deleted after concurrent deletes"
        );
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_is_expired() {
        let data = RefreshTokenData {
            user_data: serde_json::json!({"user_id": "123"}),
            expiry: Utc::now() + Duration::hours(1),
            created: Utc::now(),
        };
        assert!(
            !data.is_expired(),
            "Token with future expiry should not be expired"
        );

        let data = RefreshTokenData {
            user_data: serde_json::json!({"user_id": "123"}),
            expiry: Utc::now() - Duration::hours(1),
            created: Utc::now() - Duration::hours(2),
        };
        assert!(
            data.is_expired(),
            "Token with past expiry should be expired"
        );
    }
}