1use crate::{
12 errors::Result,
13 storage::{AuthStorage, SessionData},
14 tokens::AuthToken,
15};
16use async_trait::async_trait;
17use std::{
18 collections::HashMap,
19 sync::{Arc, RwLock},
20 time::{Duration, Instant},
21};
22use tokio::time;
23
24#[derive(Clone)]
26pub struct InMemoryStorage {
27 tokens: Arc<RwLock<HashMap<String, TimestampedToken>>>,
28 access_tokens: Arc<RwLock<HashMap<String, String>>>, user_tokens: Arc<RwLock<HashMap<String, Vec<String>>>>, sessions: Arc<RwLock<HashMap<String, TimestampedSession>>>,
31 kv_store: Arc<RwLock<HashMap<String, TimestampedValue>>>,
32 cleanup_interval: Duration,
33 default_ttl: Duration,
34}
35
36#[derive(Clone)]
37struct TimestampedToken {
38 token: AuthToken,
39 expires_at: Instant,
40}
41
42#[derive(Clone)]
43struct TimestampedSession {
44 session: SessionData,
45 expires_at: Instant,
46}
47
/// A raw key-value payload paired with the monotonic deadline this storage enforces.
#[derive(Clone)]
struct TimestampedValue {
    /// Owned copy of the bytes handed to `store_kv`.
    value: Vec<u8>,
    /// Entries with `expires_at <= now` are hidden on read and purged by cleanup.
    expires_at: Instant,
}
53
54impl InMemoryStorage {
55 pub fn new() -> Self {
57 let storage = Self {
58 tokens: Arc::new(RwLock::new(HashMap::new())),
59 access_tokens: Arc::new(RwLock::new(HashMap::new())),
60 user_tokens: Arc::new(RwLock::new(HashMap::new())),
61 sessions: Arc::new(RwLock::new(HashMap::new())),
62 kv_store: Arc::new(RwLock::new(HashMap::new())),
63 cleanup_interval: Duration::from_secs(300), default_ttl: Duration::from_secs(3600), };
66
67 storage.start_cleanup_task();
69 storage
70 }
71
72 pub fn with_config(cleanup_interval: Duration, default_ttl: Duration) -> Self {
74 let mut storage = Self::new();
75 storage.cleanup_interval = cleanup_interval;
76 storage.default_ttl = default_ttl;
77 storage
78 }
79
80 fn start_cleanup_task(&self) {
81 let tokens = self.tokens.clone();
82 let access_tokens = self.access_tokens.clone();
83 let user_tokens = self.user_tokens.clone();
84 let sessions = self.sessions.clone();
85 let kv_store = self.kv_store.clone();
86 let interval = self.cleanup_interval;
87
88 tokio::spawn(async move {
89 let mut cleanup_timer = time::interval(interval);
90
91 loop {
92 cleanup_timer.tick().await;
93 Self::cleanup_expired_data(
94 &tokens,
95 &access_tokens,
96 &user_tokens,
97 &sessions,
98 &kv_store,
99 );
100 }
101 });
102 }
103
104 fn cleanup_expired_data(
105 tokens: &Arc<RwLock<HashMap<String, TimestampedToken>>>,
106 access_tokens: &Arc<RwLock<HashMap<String, String>>>,
107 user_tokens: &Arc<RwLock<HashMap<String, Vec<String>>>>,
108 sessions: &Arc<RwLock<HashMap<String, TimestampedSession>>>,
109 kv_store: &Arc<RwLock<HashMap<String, TimestampedValue>>>,
110 ) {
111 let now = Instant::now();
112
113 {
115 let mut tokens_guard = tokens.write().unwrap();
116 let mut access_tokens_guard = access_tokens.write().unwrap();
117 let mut user_tokens_guard = user_tokens.write().unwrap();
118
119 let expired_tokens: Vec<String> = tokens_guard
120 .iter()
121 .filter(|(_, timestamped)| timestamped.expires_at <= now)
122 .map(|(id, _)| id.clone())
123 .collect();
124
125 for token_id in expired_tokens {
126 if let Some(timestamped) = tokens_guard.remove(&token_id) {
127 access_tokens_guard.remove(×tamped.token.access_token);
129
130 if let Some(user_token_list) =
132 user_tokens_guard.get_mut(×tamped.token.user_id)
133 {
134 user_token_list.retain(|id| id != &token_id);
135 if user_token_list.is_empty() {
136 user_tokens_guard.remove(×tamped.token.user_id);
137 }
138 }
139 }
140 }
141 }
142
143 {
145 let mut sessions_guard = sessions.write().unwrap();
146 sessions_guard.retain(|_, timestamped| timestamped.expires_at > now);
147 }
148
149 {
151 let mut kv_guard = kv_store.write().unwrap();
152 kv_guard.retain(|_, timestamped| timestamped.expires_at > now);
153 }
154 }
155
156 fn calculate_expiry(&self, token: &AuthToken) -> Instant {
157 let now = chrono::Utc::now();
158 if token.expires_at > now {
159 let duration = (token.expires_at - now).num_seconds() as u64;
160 Instant::now() + Duration::from_secs(duration)
161 } else {
162 Instant::now() + self.default_ttl
163 }
164 }
165}
166
167impl Default for InMemoryStorage {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173#[async_trait]
174impl AuthStorage for InMemoryStorage {
175 async fn store_token(&self, token: &AuthToken) -> Result<()> {
176 let expires_at = self.calculate_expiry(token);
177 let timestamped_token = TimestampedToken {
178 token: token.clone(),
179 expires_at,
180 };
181
182 {
183 let mut tokens = self.tokens.write().unwrap();
184 tokens.insert(token.token_id.clone(), timestamped_token);
185 }
186
187 {
188 let mut access_tokens = self.access_tokens.write().unwrap();
189 access_tokens.insert(token.access_token.clone(), token.token_id.clone());
190 }
191
192 {
193 let mut user_tokens = self.user_tokens.write().unwrap();
194 user_tokens
195 .entry(token.user_id.clone())
196 .or_default()
197 .push(token.token_id.clone());
198 }
199
200 Ok(())
201 }
202
203 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
204 let tokens = self.tokens.read().unwrap();
205 if let Some(timestamped) = tokens.get(token_id) {
206 if timestamped.expires_at > Instant::now() {
207 Ok(Some(timestamped.token.clone()))
208 } else {
209 Ok(None)
210 }
211 } else {
212 Ok(None)
213 }
214 }
215
216 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
217 let token_id_opt = {
218 let access_tokens = self.access_tokens.read().unwrap();
219 access_tokens.get(access_token).cloned()
220 };
221 if let Some(token_id) = token_id_opt {
222 self.get_token(&token_id).await
223 } else {
224 Ok(None)
225 }
226 }
227
228 async fn update_token(&self, token: &AuthToken) -> Result<()> {
229 self.store_token(token).await
231 }
232
233 async fn delete_token(&self, token_id: &str) -> Result<()> {
234 let removed_token = {
235 let mut tokens = self.tokens.write().unwrap();
236 tokens.remove(token_id)
237 };
238
239 if let Some(timestamped) = removed_token {
240 {
242 let mut access_tokens = self.access_tokens.write().unwrap();
243 access_tokens.remove(×tamped.token.access_token);
244 }
245
246 {
248 let mut user_tokens = self.user_tokens.write().unwrap();
249 if let Some(user_token_list) = user_tokens.get_mut(×tamped.token.user_id) {
250 user_token_list.retain(|id| id != token_id);
251 if user_token_list.is_empty() {
252 user_tokens.remove(×tamped.token.user_id);
253 }
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
262 let user_tokens = self.user_tokens.read().unwrap();
263 let tokens = self.tokens.read().unwrap();
264 let now = Instant::now();
265
266 match user_tokens.get(user_id) {
267 Some(token_ids) => {
268 let mut result = Vec::new();
269 for token_id in token_ids {
270 if let Some(timestamped) = tokens.get(token_id)
271 && timestamped.expires_at > now
272 {
273 result.push(timestamped.token.clone());
274 }
275 }
276 Ok(result)
277 }
278 None => Ok(Vec::new()),
279 }
280 }
281
282 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
283 let expires_at = Instant::now() + self.default_ttl;
284 let timestamped_session = TimestampedSession {
285 session: data.clone(),
286 expires_at,
287 };
288
289 let mut sessions = self.sessions.write().unwrap();
290 sessions.insert(session_id.to_string(), timestamped_session);
291 Ok(())
292 }
293
294 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
295 let sessions = self.sessions.read().unwrap();
296 if let Some(timestamped) = sessions.get(session_id) {
297 if timestamped.expires_at > Instant::now() {
298 Ok(Some(timestamped.session.clone()))
299 } else {
300 Ok(None)
301 }
302 } else {
303 Ok(None)
304 }
305 }
306
307 async fn delete_session(&self, session_id: &str) -> Result<()> {
308 let mut sessions = self.sessions.write().unwrap();
309 sessions.remove(session_id);
310 Ok(())
311 }
312
313 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
314 let sessions = self.sessions.read().unwrap();
315 let now = Instant::now();
316
317 let user_sessions: Vec<SessionData> = sessions
318 .values()
319 .filter_map(|timestamped| {
320 if timestamped.session.user_id == user_id && timestamped.expires_at > now {
321 Some(timestamped.session.clone())
322 } else {
323 None
324 }
325 })
326 .collect();
327
328 Ok(user_sessions)
329 }
330
331 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
332 let expires_at = Instant::now() + ttl.unwrap_or(self.default_ttl);
333 let timestamped_value = TimestampedValue {
334 value: value.to_vec(),
335 expires_at,
336 };
337
338 let mut kv_store = self.kv_store.write().unwrap();
339 kv_store.insert(key.to_string(), timestamped_value);
340 Ok(())
341 }
342
343 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
344 let kv_store = self.kv_store.read().unwrap();
345 if let Some(timestamped) = kv_store.get(key) {
346 if timestamped.expires_at > Instant::now() {
347 Ok(Some(timestamped.value.clone()))
348 } else {
349 Ok(None)
350 }
351 } else {
352 Ok(None)
353 }
354 }
355
356 async fn delete_kv(&self, key: &str) -> Result<()> {
357 let mut kv_store = self.kv_store.write().unwrap();
358 kv_store.remove(key);
359 Ok(())
360 }
361
362 async fn cleanup_expired(&self) -> Result<()> {
363 Self::cleanup_expired_data(
364 &self.tokens,
365 &self.access_tokens,
366 &self.user_tokens,
367 &self.sessions,
368 &self.kv_store,
369 );
370 Ok(())
371 }
372
373 async fn count_active_sessions(&self) -> Result<u64> {
374 let sessions = self.sessions.read().unwrap();
375 let now = Instant::now();
376
377 let active_count = sessions
378 .values()
379 .filter(|timestamped| timestamped.expires_at > now)
380 .count() as u64;
381
382 Ok(active_count)
383 }
384}
385
/// Builder-style configuration for [`InMemoryStorage`].
///
/// Defaults: a 300-second cleanup interval and a 3600-second default TTL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InMemoryConfig {
    /// Interval handed to `InMemoryStorage::with_config` for the background cleanup task.
    pub cleanup_interval: Duration,
    /// Default TTL handed to `InMemoryStorage::with_config`.
    pub default_ttl: Duration,
}

impl Default for InMemoryConfig {
    fn default() -> Self {
        Self {
            cleanup_interval: Duration::from_secs(300),
            default_ttl: Duration::from_secs(3600),
        }
    }
}
400
401impl InMemoryConfig {
402 pub fn new() -> Self {
403 Self::default()
404 }
405
406 pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
407 self.cleanup_interval = interval;
408 self
409 }
410
411 pub fn with_default_ttl(mut self, ttl: Duration) -> Self {
412 self.default_ttl = ttl;
413 self
414 }
415
416 pub fn build(self) -> InMemoryStorage {
417 InMemoryStorage::with_config(self.cleanup_interval, self.default_ttl)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::helpers::create_test_token;

    #[tokio::test]
    async fn test_in_memory_token_operations() {
        let storage = InMemoryStorage::new();
        let token = create_test_token("test_user");

        storage.store_token(&token).await.unwrap();

        // Primary lookup by token id.
        let by_id = storage.get_token(&token.token_id).await.unwrap();
        assert!(by_id.is_some());
        assert_eq!(by_id.unwrap().token_id, token.token_id);

        // Secondary lookup through the access-token index.
        let by_access = storage
            .get_token_by_access_token(&token.access_token)
            .await
            .unwrap();
        assert!(by_access.is_some());
        assert_eq!(by_access.unwrap().access_token, token.access_token);

        // The per-user index reports exactly the one stored token.
        let listed = storage.list_user_tokens(&token.user_id).await.unwrap();
        assert_eq!(listed.len(), 1);

        // Once deleted, the token can no longer be fetched.
        storage.delete_token(&token.token_id).await.unwrap();
        let after_delete = storage.get_token(&token.token_id).await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_expiration() {
        let cleanup_every = Duration::from_millis(100);
        let fallback_ttl = Duration::from_millis(200);
        let storage = InMemoryStorage::with_config(cleanup_every, fallback_ttl);

        let key = "test_key";
        let payload = b"test_value";

        storage
            .store_kv(key, payload, Some(Duration::from_millis(50)))
            .await
            .unwrap();

        // Visible immediately after storing.
        assert!(storage.get_kv(key).await.unwrap().is_some());

        // Outlive the 50 ms TTL; the entry must then read as missing.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(storage.get_kv(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_session_operations() {
        let storage = InMemoryStorage::new();

        let session_id = "test_session";
        let session_data = SessionData {
            session_id: session_id.to_string(),
            user_id: "test_user".to_string(),
            created_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            last_activity: chrono::Utc::now(),
            ip_address: None,
            user_agent: None,
            data: [("key".to_string(), serde_json::json!("value"))]
                .into_iter()
                .collect(),
        };

        storage
            .store_session(session_id, &session_data)
            .await
            .unwrap();

        let fetched = storage.get_session(session_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().user_id, session_data.user_id);

        storage.delete_session(session_id).await.unwrap();
        let after_delete = storage.get_session(session_id).await.unwrap();
        assert!(after_delete.is_none());
    }
}
520
521