oxidite_auth/
session.rs

1use async_trait::async_trait;
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9use redis::{Client, AsyncCommands};
10use crate::{AuthError, Result};
11
12/// Session data
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Session {
15    pub id: String,
16    pub user_id: String,
17    pub created_at: u64,
18    pub expires_at: u64,
19    pub data: HashMap<String, serde_json::Value>,
20}
21
22impl Session {
23    pub fn new(user_id: String, ttl_secs: u64) -> Self {
24        let now = SystemTime::now()
25            .duration_since(UNIX_EPOCH)
26            .unwrap()
27            .as_secs();
28
29        Self {
30            id: Uuid::new_v4().to_string(),
31            user_id,
32            created_at: now,
33            expires_at: now + ttl_secs,
34            data: HashMap::new(),
35        }
36    }
37
38    pub fn is_expired(&self) -> bool {
39        let now = SystemTime::now()
40            .duration_since(UNIX_EPOCH)
41            .unwrap()
42            .as_secs();
43        now >= self.expires_at
44    }
45
46    pub fn renew(&mut self, ttl_secs: u64) {
47        let now = SystemTime::now()
48            .duration_since(UNIX_EPOCH)
49            .unwrap()
50            .as_secs();
51        self.expires_at = now + ttl_secs;
52    }
53
54    pub fn set_data(&mut self, key: String, value: serde_json::Value) {
55        self.data.insert(key, value);
56    }
57
58    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
59        self.data.get(key)
60    }
61}
62
63/// Session storage trait
64#[async_trait]
65pub trait SessionStore: Send + Sync {
66    async fn create(&self, session: Session) -> Result<String>;
67    async fn get(&self, session_id: &str) -> Result<Option<Session>>;
68    async fn update(&self, session: Session) -> Result<()>;
69    async fn delete(&self, session_id: &str) -> Result<()>;
70    async fn cleanup_expired(&self) -> Result<usize>;
71}
72
73/// In-memory session store
74pub struct InMemorySessionStore {
75    sessions: Arc<RwLock<HashMap<String, Session>>>,
76}
77
78impl InMemorySessionStore {
79    pub fn new() -> Self {
80        Self {
81            sessions: Arc::new(RwLock::new(HashMap::new())),
82        }
83    }
84}
85
86impl Default for InMemorySessionStore {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92#[async_trait]
93impl SessionStore for InMemorySessionStore {
94    async fn create(&self, session: Session) -> Result<String> {
95        let session_id = session.id.clone();
96        let mut sessions = self.sessions.write().await;
97        sessions.insert(session_id.clone(), session);
98        Ok(session_id)
99    }
100
101    async fn get(&self, session_id: &str) -> Result<Option<Session>> {
102        let sessions = self.sessions.read().await;
103        Ok(sessions.get(session_id).cloned())
104    }
105
106    async fn update(&self, session: Session) -> Result<()> {
107        let mut sessions = self.sessions.write().await;
108        sessions.insert(session.id.clone(), session);
109        Ok(())
110    }
111
112    async fn delete(&self, session_id: &str) -> Result<()> {
113        let mut sessions = self.sessions.write().await;
114        sessions.remove(session_id);
115        Ok(())
116    }
117
118    async fn cleanup_expired(&self) -> Result<usize> {
119        let mut sessions = self.sessions.write().await;
120        let initial_count = sessions.len();
121        sessions.retain(|_, session| !session.is_expired());
122        Ok(initial_count - sessions.len())
123    }
124}
125
126/// Redis session store
127pub struct RedisSessionStore {
128    client: Client,
129    prefix: String,
130}
131
132impl RedisSessionStore {
133    pub fn new(url: &str, prefix: &str) -> Result<Self> {
134        let client = Client::open(url)
135            .map_err(|e| AuthError::HashError(e.to_string()))?;
136        
137        Ok(Self {
138            client,
139            prefix: prefix.to_string(),
140        })
141    }
142
143    fn session_key(&self, session_id: &str) -> String {
144        format!("{}:{}", self.prefix, session_id)
145    }
146}
147
148#[async_trait]
149impl SessionStore for RedisSessionStore {
150    async fn create(&self, session: Session) -> Result<String> {
151        let session_id = session.id.clone();
152        let key = self.session_key(&session_id);
153        let ttl = session.expires_at - session.created_at;
154        
155        let mut conn = self.client.get_multiplexed_async_connection()
156            .await
157            .map_err(|e| AuthError::HashError(e.to_string()))?;
158        
159        let data = serde_json::to_string(&session)
160            .map_err(|e| AuthError::HashError(e.to_string()))?;
161        
162        let _: () = conn.set_ex(&key, data, ttl)
163            .await
164            .map_err(|e| AuthError::HashError(e.to_string()))?;
165        
166        Ok(session_id)
167    }
168
169    async fn get(&self, session_id: &str) -> Result<Option<Session>> {
170        let key = self.session_key(session_id);
171        
172        let mut conn = self.client.get_multiplexed_async_connection()
173            .await
174            .map_err(|e| AuthError::HashError(e.to_string()))?;
175        
176        let result: Option<String> = conn.get(&key)
177            .await
178            .map_err(|e| AuthError::HashError(e.to_string()))?;
179        
180        if let Some(data) = result {
181            let session: Session = serde_json::from_str(&data)
182                .map_err(|e| AuthError::HashError(e.to_string()))?;
183            
184            if session.is_expired() {
185                self.delete(session_id).await?;
186                return Ok(None);
187            }
188            
189            Ok(Some(session))
190        } else {
191            Ok(None)
192        }
193    }
194
195    async fn update(&self, session: Session) -> Result<()> {
196        self.create(session).await?;
197        Ok(())
198    }
199
200    async fn delete(&self, session_id: &str) -> Result<()> {
201        let key = self.session_key(session_id);
202        
203        let mut conn = self.client.get_multiplexed_async_connection()
204            .await
205            .map_err(|e| AuthError::HashError(e.to_string()))?;
206        
207        let _: () = conn.del(&key)
208            .await
209            .map_err(|e| AuthError::HashError(e.to_string()))?;
210        
211        Ok(())
212    }
213
214    async fn cleanup_expired(&self) -> Result<usize> {
215        // Redis automatically expires keys with TTL, so no cleanup needed
216        Ok(0)
217    }
218}
219
220/// Session Manager
221#[derive(Clone)]
222pub struct SessionManager {
223    store: Arc<dyn SessionStore>,
224}
225
226impl SessionManager {
227    pub fn new(store: Arc<dyn SessionStore>) -> Self {
228        Self { store }
229    }
230    
231    pub fn new_memory() -> Self {
232        Self::new(Arc::new(InMemorySessionStore::new()))
233    }
234    
235    pub fn new_redis(url: &str, prefix: &str) -> Result<Self> {
236        Ok(Self::new(Arc::new(RedisSessionStore::new(url, prefix)?)))
237    }
238    
239    pub async fn create(&self, session: Session) -> Result<String> {
240        self.store.create(session).await
241    }
242    
243    pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
244        self.store.get(session_id).await
245    }
246    
247    pub async fn update(&self, session: Session) -> Result<()> {
248        self.store.update(session).await
249    }
250    
251    pub async fn delete(&self, session_id: &str) -> Result<()> {
252        self.store.delete(session_id).await
253    }
254}