1use async_trait::async_trait;
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9use redis::{Client, AsyncCommands};
10use crate::{AuthError, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Session {
15 pub id: String,
16 pub user_id: String,
17 pub created_at: u64,
18 pub expires_at: u64,
19 pub data: HashMap<String, serde_json::Value>,
20}
21
22impl Session {
23 pub fn new(user_id: String, ttl_secs: u64) -> Self {
24 let now = SystemTime::now()
25 .duration_since(UNIX_EPOCH)
26 .unwrap()
27 .as_secs();
28
29 Self {
30 id: Uuid::new_v4().to_string(),
31 user_id,
32 created_at: now,
33 expires_at: now + ttl_secs,
34 data: HashMap::new(),
35 }
36 }
37
38 pub fn is_expired(&self) -> bool {
39 let now = SystemTime::now()
40 .duration_since(UNIX_EPOCH)
41 .unwrap()
42 .as_secs();
43 now >= self.expires_at
44 }
45
46 pub fn renew(&mut self, ttl_secs: u64) {
47 let now = SystemTime::now()
48 .duration_since(UNIX_EPOCH)
49 .unwrap()
50 .as_secs();
51 self.expires_at = now + ttl_secs;
52 }
53
54 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
55 self.data.insert(key, value);
56 }
57
58 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
59 self.data.get(key)
60 }
61}
62
63#[async_trait]
65pub trait SessionStore: Send + Sync {
66 async fn create(&self, session: Session) -> Result<String>;
67 async fn get(&self, session_id: &str) -> Result<Option<Session>>;
68 async fn update(&self, session: Session) -> Result<()>;
69 async fn delete(&self, session_id: &str) -> Result<()>;
70 async fn cleanup_expired(&self) -> Result<usize>;
71}
72
73pub struct InMemorySessionStore {
75 sessions: Arc<RwLock<HashMap<String, Session>>>,
76}
77
78impl InMemorySessionStore {
79 pub fn new() -> Self {
80 Self {
81 sessions: Arc::new(RwLock::new(HashMap::new())),
82 }
83 }
84}
85
86impl Default for InMemorySessionStore {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92#[async_trait]
93impl SessionStore for InMemorySessionStore {
94 async fn create(&self, session: Session) -> Result<String> {
95 let session_id = session.id.clone();
96 let mut sessions = self.sessions.write().await;
97 sessions.insert(session_id.clone(), session);
98 Ok(session_id)
99 }
100
101 async fn get(&self, session_id: &str) -> Result<Option<Session>> {
102 let sessions = self.sessions.read().await;
103 Ok(sessions.get(session_id).cloned())
104 }
105
106 async fn update(&self, session: Session) -> Result<()> {
107 let mut sessions = self.sessions.write().await;
108 sessions.insert(session.id.clone(), session);
109 Ok(())
110 }
111
112 async fn delete(&self, session_id: &str) -> Result<()> {
113 let mut sessions = self.sessions.write().await;
114 sessions.remove(session_id);
115 Ok(())
116 }
117
118 async fn cleanup_expired(&self) -> Result<usize> {
119 let mut sessions = self.sessions.write().await;
120 let initial_count = sessions.len();
121 sessions.retain(|_, session| !session.is_expired());
122 Ok(initial_count - sessions.len())
123 }
124}
125
126pub struct RedisSessionStore {
128 client: Client,
129 prefix: String,
130}
131
132impl RedisSessionStore {
133 pub fn new(url: &str, prefix: &str) -> Result<Self> {
134 let client = Client::open(url)
135 .map_err(|e| AuthError::HashError(e.to_string()))?;
136
137 Ok(Self {
138 client,
139 prefix: prefix.to_string(),
140 })
141 }
142
143 fn session_key(&self, session_id: &str) -> String {
144 format!("{}:{}", self.prefix, session_id)
145 }
146}
147
148#[async_trait]
149impl SessionStore for RedisSessionStore {
150 async fn create(&self, session: Session) -> Result<String> {
151 let session_id = session.id.clone();
152 let key = self.session_key(&session_id);
153 let ttl = session.expires_at - session.created_at;
154
155 let mut conn = self.client.get_multiplexed_async_connection()
156 .await
157 .map_err(|e| AuthError::HashError(e.to_string()))?;
158
159 let data = serde_json::to_string(&session)
160 .map_err(|e| AuthError::HashError(e.to_string()))?;
161
162 let _: () = conn.set_ex(&key, data, ttl)
163 .await
164 .map_err(|e| AuthError::HashError(e.to_string()))?;
165
166 Ok(session_id)
167 }
168
169 async fn get(&self, session_id: &str) -> Result<Option<Session>> {
170 let key = self.session_key(session_id);
171
172 let mut conn = self.client.get_multiplexed_async_connection()
173 .await
174 .map_err(|e| AuthError::HashError(e.to_string()))?;
175
176 let result: Option<String> = conn.get(&key)
177 .await
178 .map_err(|e| AuthError::HashError(e.to_string()))?;
179
180 if let Some(data) = result {
181 let session: Session = serde_json::from_str(&data)
182 .map_err(|e| AuthError::HashError(e.to_string()))?;
183
184 if session.is_expired() {
185 self.delete(session_id).await?;
186 return Ok(None);
187 }
188
189 Ok(Some(session))
190 } else {
191 Ok(None)
192 }
193 }
194
195 async fn update(&self, session: Session) -> Result<()> {
196 self.create(session).await?;
197 Ok(())
198 }
199
200 async fn delete(&self, session_id: &str) -> Result<()> {
201 let key = self.session_key(session_id);
202
203 let mut conn = self.client.get_multiplexed_async_connection()
204 .await
205 .map_err(|e| AuthError::HashError(e.to_string()))?;
206
207 let _: () = conn.del(&key)
208 .await
209 .map_err(|e| AuthError::HashError(e.to_string()))?;
210
211 Ok(())
212 }
213
214 async fn cleanup_expired(&self) -> Result<usize> {
215 Ok(0)
217 }
218}
219
220#[derive(Clone)]
222pub struct SessionManager {
223 store: Arc<dyn SessionStore>,
224}
225
226impl SessionManager {
227 pub fn new(store: Arc<dyn SessionStore>) -> Self {
228 Self { store }
229 }
230
231 pub fn new_memory() -> Self {
232 Self::new(Arc::new(InMemorySessionStore::new()))
233 }
234
235 pub fn new_redis(url: &str, prefix: &str) -> Result<Self> {
236 Ok(Self::new(Arc::new(RedisSessionStore::new(url, prefix)?)))
237 }
238
239 pub async fn create(&self, session: Session) -> Result<String> {
240 self.store.create(session).await
241 }
242
243 pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
244 self.store.get(session_id).await
245 }
246
247 pub async fn update(&self, session: Session) -> Result<()> {
248 self.store.update(session).await
249 }
250
251 pub async fn delete(&self, session_id: &str) -> Result<()> {
252 self.store.delete(session_id).await
253 }
254}