1use crate::errors::{Result, StorageError};
4use crate::tokens::AuthToken;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11#[async_trait]
13pub trait AuthStorage: Send + Sync {
14 async fn store_token(&self, token: &AuthToken) -> Result<()>;
16
17 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>>;
19
20 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>>;
22
23 async fn update_token(&self, token: &AuthToken) -> Result<()>;
25
26 async fn delete_token(&self, token_id: &str) -> Result<()>;
28
29 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>>;
31
32 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()>;
34
35 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>>;
37
38 async fn delete_session(&self, session_id: &str) -> Result<()>;
40
41 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()>;
43
44 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>>;
46
47 async fn delete_kv(&self, key: &str) -> Result<()>;
49
50 async fn cleanup_expired(&self) -> Result<()>;
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SessionData {
57 pub session_id: String,
59
60 pub user_id: String,
62
63 pub created_at: chrono::DateTime<chrono::Utc>,
65
66 pub expires_at: chrono::DateTime<chrono::Utc>,
68
69 pub last_activity: chrono::DateTime<chrono::Utc>,
71
72 pub ip_address: Option<String>,
74
75 pub user_agent: Option<String>,
77
78 pub data: HashMap<String, serde_json::Value>,
80}
81
82type KvValue = (Vec<u8>, Option<chrono::DateTime<chrono::Utc>>);
84
85#[derive(Debug, Clone)]
87pub struct MemoryStorage {
88 tokens: Arc<RwLock<HashMap<String, AuthToken>>>,
89 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
90 kv_store: Arc<RwLock<HashMap<String, KvValue>>>,
91}
92
/// Redis-backed [`AuthStorage`], compiled only with the `redis-storage` feature.
#[cfg(feature = "redis-storage")]
#[derive(Debug, Clone)]
pub struct RedisStorage {
    /// Client from which an async connection is obtained for each operation.
    client: redis::Client,
    /// String prepended to every key this store reads or writes.
    key_prefix: String,
}
100
101impl Default for MemoryStorage {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl MemoryStorage {
108 pub fn new() -> Self {
110 Self {
111 tokens: Arc::new(RwLock::new(HashMap::new())),
112 sessions: Arc::new(RwLock::new(HashMap::new())),
113 kv_store: Arc::new(RwLock::new(HashMap::new())),
114 }
115 }
116}
117
118#[async_trait]
119impl AuthStorage for MemoryStorage {
120 async fn store_token(&self, token: &AuthToken) -> Result<()> {
121 let mut tokens = self.tokens.write().unwrap();
122 tokens.insert(token.token_id.clone(), token.clone());
123 Ok(())
124 }
125
126 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
127 let tokens = self.tokens.read().unwrap();
128 Ok(tokens.get(token_id).cloned())
129 }
130
131 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
132 let tokens = self.tokens.read().unwrap();
133 Ok(tokens.values()
134 .find(|token| token.access_token == access_token)
135 .cloned())
136 }
137
138 async fn update_token(&self, token: &AuthToken) -> Result<()> {
139 let mut tokens = self.tokens.write().unwrap();
140 tokens.insert(token.token_id.clone(), token.clone());
141 Ok(())
142 }
143
144 async fn delete_token(&self, token_id: &str) -> Result<()> {
145 let mut tokens = self.tokens.write().unwrap();
146 tokens.remove(token_id);
147 Ok(())
148 }
149
150 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
151 let tokens = self.tokens.read().unwrap();
152 Ok(tokens.values()
153 .filter(|token| token.user_id == user_id)
154 .cloned()
155 .collect())
156 }
157
158 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
159 let mut sessions = self.sessions.write().unwrap();
160 sessions.insert(session_id.to_string(), data.clone());
161 Ok(())
162 }
163
164 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
165 let sessions = self.sessions.read().unwrap();
166 let session = sessions.get(session_id).cloned();
167
168 if let Some(ref session) = session {
170 if chrono::Utc::now() > session.expires_at {
171 return Ok(None);
172 }
173 }
174
175 Ok(session)
176 }
177
178 async fn delete_session(&self, session_id: &str) -> Result<()> {
179 let mut sessions = self.sessions.write().unwrap();
180 sessions.remove(session_id);
181 Ok(())
182 }
183
184 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
185 let mut kv_store = self.kv_store.write().unwrap();
186 let expires_at = ttl.map(|ttl| {
187 chrono::Utc::now() + chrono::Duration::from_std(ttl).unwrap()
188 });
189 kv_store.insert(key.to_string(), (value.to_vec(), expires_at));
190 Ok(())
191 }
192
193 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
194 let kv_store = self.kv_store.read().unwrap();
195 if let Some((value, expires_at)) = kv_store.get(key) {
196 if let Some(expires_at) = expires_at {
198 if chrono::Utc::now() > *expires_at {
199 return Ok(None);
200 }
201 }
202 Ok(Some(value.clone()))
203 } else {
204 Ok(None)
205 }
206 }
207
208 async fn delete_kv(&self, key: &str) -> Result<()> {
209 let mut kv_store = self.kv_store.write().unwrap();
210 kv_store.remove(key);
211 Ok(())
212 }
213
214 async fn cleanup_expired(&self) -> Result<()> {
215 let now = chrono::Utc::now();
216
217 {
219 let mut tokens = self.tokens.write().unwrap();
220 tokens.retain(|_, token| !token.is_expired());
221 }
222
223 {
225 let mut sessions = self.sessions.write().unwrap();
226 sessions.retain(|_, session| now <= session.expires_at);
227 }
228
229 {
231 let mut kv_store = self.kv_store.write().unwrap();
232 kv_store.retain(|_, (_, expires_at)| {
233 expires_at.is_none_or(|exp| now <= exp)
234 });
235 }
236
237 Ok(())
238 }
239}
240
#[cfg(feature = "redis-storage")]
impl RedisStorage {
    /// Creates a Redis-backed store for `redis_url`. Every key it touches is
    /// prefixed with `key_prefix`.
    ///
    /// # Errors
    /// Returns a connection-failed storage error when `redis::Client::open`
    /// rejects `redis_url`.
    pub fn new(redis_url: &str, key_prefix: impl Into<String>) -> Result<Self> {
        match redis::Client::open(redis_url) {
            Ok(client) => Ok(Self {
                client,
                key_prefix: key_prefix.into(),
            }),
            Err(e) => Err(
                StorageError::connection_failed(format!("Redis connection failed: {e}")).into(),
            ),
        }
    }

    /// Opens a new async connection; called once per storage operation.
    async fn get_connection(&self) -> Result<redis::aio::Connection> {
        let outcome = self.client.get_async_connection().await;
        outcome.map_err(|e| {
            let err =
                StorageError::connection_failed(format!("Failed to get Redis connection: {e}"));
            err.into()
        })
    }

    /// Builds the full Redis key: `key_prefix` immediately followed by `suffix`.
    fn key(&self, suffix: &str) -> String {
        let mut full = String::with_capacity(self.key_prefix.len() + suffix.len());
        full.push_str(&self.key_prefix);
        full.push_str(suffix);
        full
    }
}
265
// Key layout (each key is prefixed with `key_prefix`):
//   token:<token_id>          -> token JSON            (TTL = token lifetime)
//   access_token:<access>     -> token_id              (same TTL)
//   user_tokens:<user_id>     -> set of token ids      (no TTL)
//   session:<session_id>      -> session JSON          (TTL = until expires_at)
//   kv:<key>                  -> raw bytes             (optional TTL)
// Every method opens its own connection via `get_connection`.
#[cfg(feature = "redis-storage")]
#[async_trait]
impl AuthStorage for RedisStorage {
    /// Writes the token JSON, its access-token index entry, and adds the id to
    /// the owner's token set.
    ///
    /// The JSON and index share a TTL equal to the token's remaining lifetime
    /// (at least 1 s). The three commands are issued separately, not atomically.
    async fn store_token(&self, token: &AuthToken) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let token_json = serde_json::to_string(token)
            .map_err(|e| StorageError::serialization(format!("Token serialization failed: {e}")))?;

        let token_key = self.key(&format!("token:{}", token.token_id));
        let access_token_key = self.key(&format!("access_token:{}", token.access_token));
        let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));

        let ttl = token.time_until_expiry().as_secs().max(1);

        redis::cmd("SETEX")
            .arg(&token_key)
            .arg(ttl)
            .arg(&token_json)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to store token: {e}")))?;

        redis::cmd("SETEX")
            .arg(&access_token_key)
            .arg(ttl)
            .arg(&token.token_id)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| {
                StorageError::operation_failed(format!("Failed to store access token mapping: {e}"))
            })?;

        redis::cmd("SADD")
            .arg(&user_tokens_key)
            .arg(&token.token_id)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| {
                StorageError::operation_failed(format!("Failed to add token to user set: {e}"))
            })?;

        Ok(())
    }

    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let token_key = self.key(&format!("token:{token_id}"));

        let token_json: Option<String> = redis::cmd("GET")
            .arg(&token_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get token: {e}")))?;

        if let Some(json) = token_json {
            let token: AuthToken = serde_json::from_str(&json).map_err(|e| {
                StorageError::serialization(format!("Token deserialization failed: {e}"))
            })?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }

    /// Resolves `access_token` through its index entry, then loads the token.
    ///
    /// The result is only returned if the stored token still carries
    /// `access_token`. A stale index entry (left behind after the token's
    /// access token was rotated) therefore cannot resolve to the new token.
    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let access_token_key = self.key(&format!("access_token:{access_token}"));

        let token_id: Option<String> = redis::cmd("GET")
            .arg(&access_token_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| {
                StorageError::operation_failed(format!("Failed to get access token mapping: {e}"))
            })?;

        let Some(token_id) = token_id else {
            return Ok(None);
        };

        let token = self.get_token(&token_id).await?;
        Ok(token.filter(|token| token.access_token == access_token))
    }

    /// Re-stores `token` after removing the previously stored version's
    /// access-token index entry when the access token has changed.
    ///
    /// Without the removal, the old access token would keep mapping to this
    /// token id until its TTL ran out.
    async fn update_token(&self, token: &AuthToken) -> Result<()> {
        if let Some(previous) = self.get_token(&token.token_id).await? {
            if previous.access_token != token.access_token {
                let mut conn = self.get_connection().await?;
                let stale_key = self.key(&format!("access_token:{}", previous.access_token));
                redis::cmd("DEL")
                    .arg(&stale_key)
                    .query_async::<_, ()>(&mut conn)
                    .await
                    .map_err(|e| {
                        StorageError::operation_failed(format!(
                            "Failed to delete stale access token mapping: {e}"
                        ))
                    })?;
            }
        }

        self.store_token(token).await
    }

    /// Deletes the token JSON, its access-token index entry and its id from the
    /// owner's set. This is a no-op when the token JSON is no longer present.
    async fn delete_token(&self, token_id: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;

        if let Some(token) = self.get_token(token_id).await? {
            let token_key = self.key(&format!("token:{token_id}"));
            let access_token_key = self.key(&format!("access_token:{}", token.access_token));
            let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));

            redis::cmd("DEL")
                .arg(&token_key)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| StorageError::operation_failed(format!("Failed to delete token: {e}")))?;

            redis::cmd("DEL")
                .arg(&access_token_key)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| {
                    StorageError::operation_failed(format!(
                        "Failed to delete access token mapping: {e}"
                    ))
                })?;

            redis::cmd("SREM")
                .arg(&user_tokens_key)
                .arg(token_id)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| {
                    StorageError::operation_failed(format!(
                        "Failed to remove token from user set: {e}"
                    ))
                })?;
        }

        Ok(())
    }

    /// Loads every token id in the user's set; ids whose token JSON has
    /// expired (the set itself has no TTL) are skipped.
    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let user_tokens_key = self.key(&format!("user_tokens:{user_id}"));

        let token_ids: Vec<String> = redis::cmd("SMEMBERS")
            .arg(&user_tokens_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get user tokens: {e}")))?;

        let mut tokens = Vec::with_capacity(token_ids.len());
        for token_id in token_ids {
            if let Some(token) = self.get_token(&token_id).await? {
                tokens.push(token);
            }
        }

        Ok(tokens)
    }

    /// Stores the session JSON with a TTL of the seconds remaining until
    /// `expires_at` (at least 1 s).
    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        let session_json = serde_json::to_string(data).map_err(|e| {
            StorageError::serialization(format!("Session serialization failed: {e}"))
        })?;

        let ttl = (data.expires_at - chrono::Utc::now()).num_seconds().max(1);

        redis::cmd("SETEX")
            .arg(&session_key)
            .arg(ttl)
            .arg(&session_json)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to store session: {e}")))?;

        Ok(())
    }

    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        let session_json: Option<String> = redis::cmd("GET")
            .arg(&session_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get session: {e}")))?;

        if let Some(json) = session_json {
            let session: SessionData = serde_json::from_str(&json).map_err(|e| {
                StorageError::serialization(format!("Session deserialization failed: {e}"))
            })?;
            Ok(Some(session))
        } else {
            Ok(None)
        }
    }

    async fn delete_session(&self, session_id: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        redis::cmd("DEL")
            .arg(&session_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete session: {e}")))?;

        Ok(())
    }

    /// Stores raw bytes, with a millisecond-precision TTL when `ttl` is given.
    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        if let Some(ttl) = ttl {
            // PSETEX takes milliseconds. The previous `SETEX ... ttl.as_secs()`
            // truncated sub-second TTLs to 0, which Redis rejects as an invalid
            // expire time. Clamp to >= 1 ms so tiny TTLs still store (and
            // promptly expire) the value.
            let ttl_ms = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX).max(1);
            redis::cmd("PSETEX")
                .arg(&storage_key)
                .arg(ttl_ms)
                .arg(value)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| {
                    StorageError::operation_failed(format!("Failed to store KV with TTL: {e}"))
                })?;
        } else {
            redis::cmd("SET")
                .arg(&storage_key)
                .arg(value)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| StorageError::operation_failed(format!("Failed to store KV: {e}")))?;
        }

        Ok(())
    }

    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        let value: Option<Vec<u8>> = redis::cmd("GET")
            .arg(&storage_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get KV: {e}")))?;

        Ok(value)
    }

    async fn delete_kv(&self, key: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        redis::cmd("DEL")
            .arg(&storage_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete KV: {e}")))?;

        Ok(())
    }

    /// No-op: tokens, sessions and TTL'd KV entries are expired by Redis key
    /// TTLs. Note that `user_tokens:<user>` sets have no TTL and are not
    /// pruned here.
    async fn cleanup_expired(&self) -> Result<()> {
        Ok(())
    }
}
513
514impl SessionData {
515 pub fn new(
517 session_id: impl Into<String>,
518 user_id: impl Into<String>,
519 expires_in: Duration,
520 ) -> Self {
521 let now = chrono::Utc::now();
522
523 Self {
524 session_id: session_id.into(),
525 user_id: user_id.into(),
526 created_at: now,
527 expires_at: now + chrono::Duration::from_std(expires_in).unwrap(),
528 last_activity: now,
529 ip_address: None,
530 user_agent: None,
531 data: HashMap::new(),
532 }
533 }
534
535 pub fn is_expired(&self) -> bool {
537 chrono::Utc::now() > self.expires_at
538 }
539
540 pub fn update_activity(&mut self) {
542 self.last_activity = chrono::Utc::now();
543 }
544
545 pub fn with_metadata(
547 mut self,
548 ip_address: Option<String>,
549 user_agent: Option<String>,
550 ) -> Self {
551 self.ip_address = ip_address;
552 self.user_agent = user_agent;
553 self
554 }
555
556 pub fn set_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
558 self.data.insert(key.into(), value);
559 }
560
561 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
563 self.data.get(key)
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::AuthToken;

    /// Round-trips a token through every token operation of `MemoryStorage`.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        let token = AuthToken::new("user123", "token123", Duration::from_secs(3600), "test");

        storage.store_token(&token).await.unwrap();

        let retrieved = storage.get_token(&token.token_id).await.unwrap().unwrap();
        assert_eq!(retrieved.user_id, "user123");

        let retrieved = storage
            .get_token_by_access_token(&token.access_token)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(retrieved.token_id, token.token_id);

        let user_tokens = storage.list_user_tokens("user123").await.unwrap();
        assert_eq!(user_tokens.len(), 1);

        storage.delete_token(&token.token_id).await.unwrap();
        let retrieved = storage.get_token(&token.token_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    /// Stores, reads back and deletes a live session.
    #[tokio::test]
    async fn test_session_storage() {
        let storage = MemoryStorage::new();

        let session = SessionData::new("session123", "user123", Duration::from_secs(3600))
            .with_metadata(
                Some("192.168.1.1".to_string()),
                Some("Test Agent".to_string()),
            );

        storage.store_session(&session.session_id, &session).await.unwrap();

        let retrieved = storage.get_session(&session.session_id).await.unwrap().unwrap();
        assert_eq!(retrieved.user_id, "user123");
        assert_eq!(retrieved.ip_address, Some("192.168.1.1".to_string()));

        storage.delete_session(&session.session_id).await.unwrap();
        let retrieved = storage.get_session(&session.session_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    /// Stores, reads back and deletes a KV entry with a TTL.
    #[tokio::test]
    async fn test_kv_storage() {
        let storage = MemoryStorage::new();

        let key = "test_key";
        let value = b"test_value";

        storage.store_kv(key, value, Some(Duration::from_secs(3600))).await.unwrap();

        let retrieved = storage.get_kv(key).await.unwrap().unwrap();
        assert_eq!(retrieved, value);

        storage.delete_kv(key).await.unwrap();
        let retrieved = storage.get_kv(key).await.unwrap();
        assert!(retrieved.is_none());
    }

    /// An expired session is hidden by `get_session` and removed by `cleanup_expired`.
    #[tokio::test]
    async fn test_expired_session_is_hidden_and_cleaned_up() {
        let storage = MemoryStorage::new();

        let mut session = SessionData::new("expired", "user123", Duration::from_secs(3600));
        session.expires_at = chrono::Utc::now() - chrono::Duration::seconds(60);
        assert!(session.is_expired());

        storage.store_session(&session.session_id, &session).await.unwrap();
        assert!(storage.get_session(&session.session_id).await.unwrap().is_none());

        storage.cleanup_expired().await.unwrap();
        assert!(storage.sessions.read().unwrap().is_empty());
    }

    /// An expired KV entry is hidden by `get_kv` and removed by `cleanup_expired`,
    /// while an entry without a TTL survives cleanup.
    #[tokio::test]
    async fn test_expired_kv_is_hidden_and_cleaned_up() {
        let storage = MemoryStorage::new();

        // Insert an already-expired entry directly; `store_kv` only takes relative TTLs.
        let past = chrono::Utc::now() - chrono::Duration::seconds(60);
        storage
            .kv_store
            .write()
            .unwrap()
            .insert("stale".to_string(), (b"old".to_vec(), Some(past)));
        storage.store_kv("forever", b"keep", None).await.unwrap();

        assert!(storage.get_kv("stale").await.unwrap().is_none());
        assert_eq!(storage.get_kv("forever").await.unwrap().unwrap(), b"keep");

        storage.cleanup_expired().await.unwrap();

        let kv = storage.kv_store.read().unwrap();
        assert!(!kv.contains_key("stale"));
        assert!(kv.contains_key("forever"));
    }
}