fraiseql_server/auth/
session.rs1#[cfg(test)]
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10#[cfg(test)]
11use crate::auth::error::AuthError;
12use crate::auth::error::Result;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionData {
17 pub user_id: String,
19 pub issued_at: u64,
21 pub expires_at: u64,
23 pub refresh_token_hash: String,
25}
26
27impl SessionData {
28 pub fn is_expired(&self) -> bool {
30 let now = std::time::SystemTime::now()
31 .duration_since(std::time::UNIX_EPOCH)
32 .unwrap_or_default()
33 .as_secs();
34 self.expires_at <= now
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct TokenPair {
41 pub access_token: String,
43 pub refresh_token: String,
45 pub expires_in: u64,
47}
48
49#[async_trait]
79pub trait SessionStore: Send + Sync {
80 async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair>;
92
93 async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData>;
104
105 async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()>;
113
114 async fn revoke_all_sessions(&self, user_id: &str) -> Result<()>;
122}
123
124pub fn hash_token(token: &str) -> String {
126 let mut hasher = Sha256::new();
127 hasher.update(token.as_bytes());
128 format!("{:x}", hasher.finalize())
129}
130
131pub fn generate_refresh_token() -> String {
133 use base64::Engine;
134 let mut rng = rand::thread_rng();
135 let random_bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
136 base64::engine::general_purpose::STANDARD.encode(&random_bytes)
137}
138
/// Process-local `SessionStore` used only in tests.
///
/// Sessions are kept in a `DashMap` keyed by refresh-token hash and vanish
/// when the store is dropped.
#[cfg(test)]
pub struct InMemorySessionStore {
    sessions: Arc<dashmap::DashMap<String, SessionData>>,
}

#[cfg(test)]
impl InMemorySessionStore {
    /// Builds a store with no sessions.
    pub fn new() -> Self {
        let sessions = dashmap::DashMap::new();
        Self {
            sessions: Arc::new(sessions),
        }
    }

    /// Drops every stored session.
    pub fn clear(&self) {
        self.sessions.clear()
    }

    /// Number of stored sessions; expired sessions are still counted.
    pub fn len(&self) -> usize {
        self.sessions.len()
    }

    /// `true` when the store holds no sessions at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

#[cfg(test)]
impl Default for InMemorySessionStore {
    /// Same as `InMemorySessionStore::new`.
    fn default() -> Self {
        Self::new()
    }
}
176
#[cfg(test)]
#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Issues a new token pair and records a session keyed by the hash of the
    /// new refresh token.
    ///
    /// The access token is a placeholder built from its own random value. It
    /// intentionally does not embed the refresh token, so exposing an access
    /// token never reveals the longer-lived refresh secret.
    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
        let refresh_token = generate_refresh_token();
        let refresh_token_hash = hash_token(&refresh_token);

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let session = SessionData {
            user_id: user_id.to_string(),
            issued_at: now,
            expires_at,
            refresh_token_hash: refresh_token_hash.clone(),
        };
        self.sessions.insert(refresh_token_hash, session);

        // An `expires_at` already in the past reports 0 instead of underflowing.
        let expires_in = expires_at.saturating_sub(now);

        // Independent random value: the access token must not leak the refresh
        // token it is issued alongside.
        let access_token = format!("access_token_{}", generate_refresh_token());

        Ok(TokenPair {
            access_token,
            refresh_token,
            expires_in,
        })
    }

    /// Returns a copy of the session stored under `refresh_token_hash`.
    ///
    /// Expiry is not checked here; expired sessions are still returned.
    ///
    /// # Errors
    /// `AuthError::TokenNotFound` if no session has that hash.
    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
        self.sessions
            .get(refresh_token_hash)
            .map(|entry| entry.value().clone())
            .ok_or(AuthError::TokenNotFound)
    }

    /// Removes the session stored under `refresh_token_hash`.
    ///
    /// # Errors
    /// `AuthError::SessionError` if no session has that hash.
    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
        self.sessions
            .remove(refresh_token_hash)
            .map(|_| ())
            // Build the error lazily so the message is only allocated on a miss.
            .ok_or_else(|| AuthError::SessionError {
                message: "Session not found".to_string(),
            })
    }

    /// Removes every session owned by `user_id`; succeeds even if there are none.
    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
        // `retain` drops matching entries in place, instead of collecting keys
        // first and then removing them one by one.
        self.sessions
            .retain(|_, session| session.user_id != user_id);
        Ok(())
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds (0 if the clock is before the epoch),
    /// computed the same way as the code under test.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_hash_token() {
        let token = "my_secret_token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);

        // Hashing is deterministic...
        assert_eq!(hash1, hash2);

        // ...and different inputs produce different digests.
        let different_hash = hash_token("different_token");
        assert_ne!(hash1, different_hash);
    }

    #[test]
    fn test_hash_token_known_vector() {
        // Standard SHA-256 test vector for the message "abc".
        assert_eq!(
            hash_token("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        // Two lowercase hex digits per digest byte.
        assert_eq!(hash_token("").len(), 64);
    }

    #[test]
    fn test_generate_refresh_token() {
        let token1 = generate_refresh_token();
        let token2 = generate_refresh_token();

        assert_ne!(token1, token2);
        assert!(!token1.is_empty());
        assert!(!token2.is_empty());
        // 32 random bytes encode to 44 padded base64 characters.
        assert_eq!(token1.len(), 44);
    }

    #[test]
    fn test_session_data_not_expired() {
        let now = now_secs();

        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now,
            expires_at: now + 3600,
            refresh_token_hash: "hash".to_string(),
        };

        assert!(!session.is_expired());
    }

    #[test]
    fn test_session_data_expired() {
        let now = now_secs();

        // saturating_sub keeps the test from panicking on a clock near the epoch.
        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now.saturating_sub(3600),
            expires_at: now.saturating_sub(100),
            refresh_token_hash: "hash".to_string(),
        };

        assert!(session.is_expired());
    }

    #[tokio::test]
    async fn test_in_memory_store_create_session() {
        let store = InMemorySessionStore::new();

        let result = store.create_session("user123", now_secs() + 3600).await;
        assert!(result.is_ok());

        let tokens = result.unwrap();
        assert!(!tokens.access_token.is_empty());
        assert!(!tokens.refresh_token.is_empty());
        assert!(tokens.expires_in > 0);
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store_create_already_expired_session() {
        let store = InMemorySessionStore::new();

        // An expiry in the past saturates to zero remaining seconds.
        let tokens = store.create_session("user123", 0).await.unwrap();
        assert_eq!(tokens.expires_in, 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_session() {
        let store = InMemorySessionStore::new();

        let tokens = store.create_session("user123", now_secs() + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        let session = store.get_session(&refresh_token_hash).await;
        assert!(session.is_ok());
        assert_eq!(session.unwrap().user_id, "user123");
    }

    #[tokio::test]
    async fn test_in_memory_store_get_unknown_session_fails() {
        let store = InMemorySessionStore::new();
        assert!(store.get_session(&hash_token("missing")).await.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_session() {
        let store = InMemorySessionStore::new();

        let tokens = store.create_session("user123", now_secs() + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        assert!(store.revoke_session(&refresh_token_hash).await.is_ok());

        let session = store.get_session(&refresh_token_hash).await;
        assert!(session.is_err());
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_unknown_session_fails() {
        let store = InMemorySessionStore::new();
        assert!(store.revoke_session(&hash_token("missing")).await.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_all_sessions() {
        let store = InMemorySessionStore::new();
        let expires_at = now_secs() + 3600;

        // Two sessions for one user, one for another.
        let tokens1 = store.create_session("user123", expires_at).await.unwrap();
        let tokens2 = store.create_session("user123", expires_at).await.unwrap();
        let tokens3 = store.create_session("user456", expires_at).await.unwrap();

        assert_eq!(store.len(), 3);

        assert!(store.revoke_all_sessions("user123").await.is_ok());
        assert_eq!(store.len(), 1);

        // The other user's session survives.
        let hash3 = hash_token(&tokens3.refresh_token);
        assert!(store.get_session(&hash3).await.is_ok());

        // Both of user123's sessions are gone.
        let hash1 = hash_token(&tokens1.refresh_token);
        let hash2 = hash_token(&tokens2.refresh_token);
        assert!(store.get_session(&hash1).await.is_err());
        assert!(store.get_session(&hash2).await.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_all_without_sessions_is_ok() {
        let store = InMemorySessionStore::new();
        assert!(store.revoke_all_sessions("nobody").await.is_ok());
        assert!(store.is_empty());
    }
}
384}