1#[cfg(test)]
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8
9#[cfg(test)]
10use crate::error::AuthError;
11use crate::error::Result;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SessionData {
16 pub user_id: String,
18 pub issued_at: u64,
20 pub expires_at: u64,
22 pub refresh_token_hash: String,
24}
25
26impl SessionData {
27 pub fn is_expired(&self) -> bool {
29 let now = std::time::SystemTime::now()
30 .duration_since(std::time::UNIX_EPOCH)
31 .unwrap_or_default()
32 .as_secs();
33 self.expires_at <= now
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct TokenPair {
40 pub access_token: String,
42 pub refresh_token: String,
44 pub expires_in: u64,
46}
47
48#[async_trait]
110pub trait SessionStore: Send + Sync {
111 async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair>;
123
124 async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData>;
135
136 async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()>;
144
145 async fn revoke_all_sessions(&self, user_id: &str) -> Result<()>;
153}
154
155pub fn hash_token(token: &str) -> String {
161 let mut hasher = Sha256::new();
162 hasher.update(token.as_bytes());
163 format!("{:x}", hasher.finalize())
164}
165
166pub fn generate_refresh_token() -> String {
171 use base64::Engine;
172 use rand::{Rng, rngs::OsRng};
173 let random_bytes: Vec<u8> = (0..32).map(|_| OsRng.gen()).collect();
175 base64::engine::general_purpose::STANDARD.encode(&random_bytes)
176}
177
/// Test-only [`SessionStore`] that keeps every session in a process-local
/// `dashmap::DashMap`.
#[cfg(test)]
pub struct InMemorySessionStore {
    /// Sessions keyed by the [`hash_token`] digest of their refresh token.
    sessions: Arc<dashmap::DashMap<String, SessionData>>,
}
183
#[cfg(test)]
impl InMemorySessionStore {
    /// Builds a store that holds no sessions.
    pub fn new() -> Self {
        let sessions = Arc::new(dashmap::DashMap::new());
        Self { sessions }
    }

    /// Drops every stored session.
    pub fn clear(&self) {
        self.sessions.clear()
    }

    /// Number of sessions currently stored.
    ///
    /// This type never prunes expired entries. They keep counting until they are
    /// revoked or the store is cleared.
    pub fn len(&self) -> usize {
        self.sessions.len()
    }

    /// `true` when no sessions are stored.
    pub fn is_empty(&self) -> bool {
        self.sessions.is_empty()
    }
}
208
#[cfg(test)]
impl Default for InMemorySessionStore {
    /// Same as [`InMemorySessionStore::new`]: an empty store.
    fn default() -> Self {
        InMemorySessionStore::new()
    }
}
215
#[cfg(test)]
#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Creates a session for `user_id` that expires at `expires_at` (Unix seconds).
    ///
    /// A new random refresh token is generated. Only its [`hash_token`] digest is
    /// stored, both as the map key and inside the [`SessionData`]. The returned access
    /// token is a test placeholder of the form `access_token_<refresh token>`.
    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
        let refresh_token = generate_refresh_token();
        let refresh_token_hash = hash_token(&refresh_token);

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let session = SessionData {
            user_id: user_id.to_string(),
            issued_at: now,
            expires_at,
            refresh_token_hash: refresh_token_hash.clone(),
        };
        self.sessions.insert(refresh_token_hash, session);

        // Saturate so that an `expires_at` already in the past reports 0 instead of
        // underflowing.
        let expires_in = expires_at.saturating_sub(now);
        let access_token = format!("access_token_{refresh_token}");

        Ok(TokenPair {
            access_token,
            refresh_token,
            expires_in,
        })
    }

    /// Returns a copy of the session stored under `refresh_token_hash`.
    ///
    /// # Errors
    /// `AuthError::TokenNotFound` when no session is stored under that hash.
    ///
    /// NOTE(review): expiry is not checked here, so expired sessions are still
    /// returned. Confirm that callers check `is_expired` themselves.
    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
        self.sessions
            .get(refresh_token_hash)
            .map(|entry| entry.value().clone())
            .ok_or(AuthError::TokenNotFound)
    }

    /// Removes the session stored under `refresh_token_hash`.
    ///
    /// # Errors
    /// `AuthError::SessionError` ("Session not found") when nothing was stored under
    /// that hash.
    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
        self.sessions
            .remove(refresh_token_hash)
            .map(|_| ())
            // Build the error lazily so the success path does not allocate a message.
            .ok_or_else(|| AuthError::SessionError {
                message: "Session not found".to_string(),
            })
    }

    /// Removes every session owned by `user_id`. Succeeds even if there were none.
    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
        // A single in-place pass. This replaces collecting the matching keys into a
        // Vec and then removing each one with a second lookup.
        self.sessions
            .retain(|_, session| session.user_id != user_id);
        Ok(())
    }
}
281
// Test code may `unwrap` freely: a panic during setup is exactly the failure we want
// reported.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    // The tests exercise nearly every item in the parent module.
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Current time in seconds since the Unix epoch, or 0 if the clock reads earlier.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_hash_token() {
        let token = "my_secret_token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);

        // Same input, same digest.
        assert_eq!(hash1, hash2);

        // Different input, different digest.
        let different_hash = hash_token("different_token");
        assert_ne!(hash1, different_hash);
    }

    #[test]
    fn test_hash_token_known_vector() {
        // SHA-256("abc") from the FIPS 180-2 examples, as lowercase hex.
        assert_eq!(
            hash_token("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_generate_refresh_token() {
        let token1 = generate_refresh_token();
        let token2 = generate_refresh_token();

        // Two independent 32-byte random draws should never collide in practice.
        assert_ne!(token1, token2);
        // 32 bytes in padded standard base64 is always 44 characters.
        assert_eq!(token1.len(), 44);
        assert_eq!(token2.len(), 44);
    }

    #[test]
    fn test_session_data_not_expired() {
        let now = now_secs();

        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now,
            expires_at: now + 3600,
            refresh_token_hash: "hash".to_string(),
        };

        assert!(!session.is_expired());
    }

    #[test]
    fn test_session_data_expired() {
        let now = now_secs();

        // `saturating_sub` keeps the setup from underflowing if the clock reads 0.
        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now.saturating_sub(3600),
            expires_at: now.saturating_sub(100),
            refresh_token_hash: "hash".to_string(),
        };

        assert!(session.is_expired());
    }

    #[tokio::test]
    async fn test_in_memory_store_create_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store
            .create_session("user123", now + 3600)
            .await
            .unwrap_or_else(|e| panic!("expected Ok from create_session: {e}"));
        assert!(!tokens.access_token.is_empty());
        assert!(!tokens.refresh_token.is_empty());
        assert!(tokens.expires_in > 0);
        assert!(tokens.expires_in <= 3600);
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store.create_session("user123", now + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        let session = store
            .get_session(&refresh_token_hash)
            .await
            .unwrap_or_else(|e| panic!("expected Ok from get_session: {e}"));
        assert_eq!(session.user_id, "user123");
        assert_eq!(session.expires_at, now + 3600);
        assert_eq!(session.refresh_token_hash, refresh_token_hash);
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store.create_session("user123", now + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        store
            .revoke_session(&refresh_token_hash)
            .await
            .unwrap_or_else(|e| panic!("expected Ok from revoke_session: {e}"));

        let session = store.get_session(&refresh_token_hash).await;
        assert!(
            matches!(session, Err(AuthError::TokenNotFound)),
            "expected TokenNotFound after revocation, got: {session:?}"
        );
        assert!(store.is_empty());

        assert!(
            store.revoke_session(&refresh_token_hash).await.is_err(),
            "revoking an already-revoked session should fail"
        );
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_all_sessions() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens1 = store.create_session("user123", now + 3600).await.unwrap();
        let tokens2 = store.create_session("user123", now + 3600).await.unwrap();
        let tokens3 = store.create_session("user456", now + 3600).await.unwrap();

        assert_eq!(store.len(), 3);

        store
            .revoke_all_sessions("user123")
            .await
            .unwrap_or_else(|e| panic!("expected Ok from revoke_all_sessions: {e}"));

        // Only user456's session should remain.
        assert_eq!(store.len(), 1);

        let hash3 = hash_token(&tokens3.refresh_token);
        store
            .get_session(&hash3)
            .await
            .unwrap_or_else(|e| panic!("expected user456 session to still exist: {e}"));

        let hash1 = hash_token(&tokens1.refresh_token);
        let hash2 = hash_token(&tokens2.refresh_token);
        assert!(
            matches!(store.get_session(&hash1).await, Err(AuthError::TokenNotFound)),
            "expected user123 session 1 to be revoked"
        );
        assert!(
            matches!(store.get_session(&hash2).await, Err(AuthError::TokenNotFound)),
            "expected user123 session 2 to be revoked"
        );
    }
}