#[cfg(test)]
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
#[cfg(test)]
use crate::error::AuthError;
use crate::error::Result;
/// Server-side record of an issued refresh-token session.
///
/// Field order is part of the serialized (serde) representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Identifier of the user this session belongs to.
    pub user_id: String,
    /// Creation time; the in-memory store fills this with Unix time in whole seconds.
    pub issued_at: u64,
    /// Expiry time in Unix seconds; the session counts as expired once the
    /// current time is at or past this value (see `is_expired`).
    pub expires_at: u64,
    /// Hash of the refresh token (the in-memory store uses `hash_token`, i.e.
    /// lowercase hex SHA-256); the raw refresh token is not kept here.
    pub refresh_token_hash: String,
}
impl SessionData {
    /// Reports whether this session has reached its expiry time.
    ///
    /// "Now" is the current Unix time in whole seconds. If the system clock
    /// reads earlier than the Unix epoch, "now" falls back to `0`. A session
    /// whose `expires_at` equals the current second is already expired.
    pub fn is_expired(&self) -> bool {
        let current_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs());
        current_secs >= self.expires_at
    }
}
/// Tokens handed back to the client when a session is created.
///
/// Field order is part of the serialized (serde) representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
    /// Access token issued alongside the refresh token.
    pub access_token: String,
    /// Opaque refresh token; stores look sessions up by its hash, not by this value.
    pub refresh_token: String,
    /// Seconds remaining until the session expires; the in-memory store
    /// computes `expires_at - now`, saturating at `0`.
    pub expires_in: u64,
}
/// Async storage backend for refresh-token sessions.
///
/// Sessions are addressed by the hash of their refresh token (compare
/// `hash_token`) rather than by the raw token. The `Send + Sync` bound lets a
/// single store be shared between tasks and threads.
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Creates a session for `user_id` that expires at `expires_at` and
    /// returns the newly issued tokens.
    ///
    /// NOTE(review): `expires_at` is presumably a Unix timestamp in seconds —
    /// the unit `SessionData::is_expired` compares against; confirm per backend.
    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair>;
    /// Fetches the session stored under `refresh_token_hash`.
    ///
    /// Whether expired sessions are filtered out is implementation-defined;
    /// callers can check `SessionData::is_expired` themselves.
    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData>;
    /// Revokes the single session stored under `refresh_token_hash`.
    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()>;
    /// Revokes every session belonging to `user_id`.
    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()>;
}
/// Returns the SHA-256 digest of `token` as a lowercase hexadecimal string.
///
/// The result is deterministic (equal inputs give equal outputs) and is what
/// session stores use as the lookup key for a refresh token, so the raw token
/// never has to be stored.
pub fn hash_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    format!("{digest:x}")
}
/// Generates a fresh random refresh token.
///
/// Draws 32 bytes from the operating-system RNG (`OsRng`) and encodes them
/// with the standard, padded base64 alphabet, giving a 44-character string.
///
/// NOTE(review): the standard alphabet may emit `+`, `/` and `=`; if tokens
/// travel in URLs, a URL-safe engine may be preferable.
pub fn generate_refresh_token() -> String {
    use base64::Engine;
    use rand::{rngs::OsRng, Rng};

    // A fixed-size stack array; no heap buffer is needed for the raw bytes.
    let raw: [u8; 32] = std::array::from_fn(|_| OsRng.gen());
    base64::engine::general_purpose::STANDARD.encode(raw)
}
/// Test-only `SessionStore` that keeps sessions in a concurrent in-process map.
#[cfg(test)]
pub struct InMemorySessionStore {
    /// Sessions keyed by refresh-token hash (as produced by `hash_token`).
    sessions: Arc<dashmap::DashMap<String, SessionData>>,
}
#[cfg(test)]
impl InMemorySessionStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        let backing_map = dashmap::DashMap::new();
        Self {
            sessions: Arc::new(backing_map),
        }
    }

    /// Removes every stored session.
    pub fn clear(&self) {
        self.sessions.clear();
    }

    /// Number of sessions currently stored.
    pub fn len(&self) -> usize {
        self.sessions.len()
    }

    /// `true` when the store holds no sessions.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
impl Default for InMemorySessionStore {
    /// Same as `InMemorySessionStore::new`: an empty store.
    fn default() -> Self {
        InMemorySessionStore::new()
    }
}
#[cfg(test)]
#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Issues a new random refresh token for `user_id` and stores the session
    /// under the token's hash.
    ///
    /// `issued_at` is the current Unix time in seconds (or `0` if the clock
    /// reads before the epoch). `expires_in` in the returned pair is
    /// `expires_at - now`, saturating at `0` when `expires_at` is in the past.
    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
        let refresh_token = generate_refresh_token();
        let refresh_token_hash = hash_token(&refresh_token);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let session = SessionData {
            user_id: user_id.to_string(),
            issued_at: now,
            expires_at,
            refresh_token_hash: refresh_token_hash.clone(),
        };
        self.sessions.insert(refresh_token_hash, session);
        let expires_in = expires_at.saturating_sub(now);
        // NOTE(review): this placeholder access token embeds the raw refresh token.
        // Acceptable only because this store is test-only; a real backend must not do this.
        let access_token = format!("access_token_{refresh_token}");
        Ok(TokenPair {
            access_token,
            refresh_token,
            expires_in,
        })
    }

    /// Returns a copy of the session stored under `refresh_token_hash`.
    ///
    /// Expiry is not checked here; callers can use `SessionData::is_expired`.
    ///
    /// # Errors
    /// `AuthError::TokenNotFound` if no session is stored under that hash.
    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
        self.sessions
            .get(refresh_token_hash)
            .map(|entry| entry.value().clone())
            .ok_or(AuthError::TokenNotFound)
    }

    /// Removes the session stored under `refresh_token_hash`.
    ///
    /// # Errors
    /// `AuthError::SessionError` ("Session not found") if no such session exists.
    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
        self.sessions
            .remove(refresh_token_hash)
            .map(|_| ())
            // Lazily build the error so the message String is only allocated on failure.
            .ok_or_else(|| AuthError::SessionError {
                message: "Session not found".to_string(),
            })
    }

    /// Removes every session whose `user_id` matches. This never fails, and
    /// succeeds even when the user has no sessions.
    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
        // A single `retain` pass drops matching entries shard by shard. This avoids
        // cloning every matching key and then removing the keys in a second pass.
        self.sessions
            .retain(|_, session| session.user_id != user_id);
        Ok(())
    }
}
#[cfg(test)]
// Panicking on an unexpected `Err` is the intended failure mode in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Current Unix time in whole seconds, or `0` if the clock reads before the epoch.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_hash_token() {
        let token = "my_secret_token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);
        assert_eq!(hash1, hash2);
        let different_hash = hash_token("different_token");
        assert_ne!(hash1, different_hash);
        // Known SHA-256 test vector (FIPS 180-2, "abc").
        assert_eq!(
            hash_token("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        assert_eq!(hash1.len(), 64);
    }

    #[test]
    fn test_generate_refresh_token() {
        use base64::Engine;
        let token1 = generate_refresh_token();
        let token2 = generate_refresh_token();
        assert_ne!(token1, token2);
        assert!(!token1.is_empty());
        assert!(!token2.is_empty());
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&token1)
            .unwrap_or_else(|e| panic!("refresh token should be valid base64: {e}"));
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_session_data_not_expired() {
        let now = now_secs();
        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now,
            expires_at: now + 3600,
            refresh_token_hash: "hash".to_string(),
        };
        assert!(!session.is_expired());
    }

    #[test]
    fn test_session_data_expired() {
        let now = now_secs();
        // `saturating_sub` avoids an underflow panic if the clock fell back to 0.
        let session = SessionData {
            user_id: "user123".to_string(),
            issued_at: now.saturating_sub(3600),
            expires_at: now.saturating_sub(100),
            refresh_token_hash: "hash".to_string(),
        };
        assert!(session.is_expired());
    }

    #[tokio::test]
    async fn test_in_memory_store_create_session() {
        let store = InMemorySessionStore::new();
        let result = store.create_session("user123", now_secs() + 3600).await;
        let tokens = result.unwrap_or_else(|e| panic!("expected Ok from create_session: {e}"));
        assert!(!tokens.access_token.is_empty());
        assert!(!tokens.refresh_token.is_empty());
        assert!(tokens.expires_in > 0);
        assert!(tokens.expires_in <= 3600);
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_session() {
        let store = InMemorySessionStore::new();
        let expires_at = now_secs() + 3600;
        let tokens = store.create_session("user123", expires_at).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);
        let session = store
            .get_session(&refresh_token_hash)
            .await
            .unwrap_or_else(|e| panic!("expected Ok from get_session: {e}"));
        assert_eq!(session.user_id, "user123");
        assert_eq!(session.expires_at, expires_at);
        assert_eq!(session.refresh_token_hash, refresh_token_hash);
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_session() {
        let store = InMemorySessionStore::new();
        let tokens = store.create_session("user123", now_secs() + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);
        store
            .revoke_session(&refresh_token_hash)
            .await
            .unwrap_or_else(|e| panic!("expected Ok from revoke_session: {e}"));
        let session = store.get_session(&refresh_token_hash).await;
        assert!(
            matches!(session, Err(AuthError::TokenNotFound)),
            "expected TokenNotFound after revocation, got: {session:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_unknown_session() {
        let store = InMemorySessionStore::new();
        let result = store.revoke_session(&hash_token("never_issued")).await;
        assert!(
            matches!(result, Err(AuthError::SessionError { .. })),
            "expected SessionError when revoking an unknown session, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_all_sessions() {
        let store = InMemorySessionStore::new();
        let expires_at = now_secs() + 3600;
        let tokens1 = store.create_session("user123", expires_at).await.unwrap();
        let tokens2 = store.create_session("user123", expires_at).await.unwrap();
        let tokens3 = store.create_session("user456", expires_at).await.unwrap();
        assert_eq!(store.len(), 3);
        store
            .revoke_all_sessions("user123")
            .await
            .unwrap_or_else(|e| panic!("expected Ok from revoke_all_sessions: {e}"));
        assert_eq!(store.len(), 1);
        let hash3 = hash_token(&tokens3.refresh_token);
        store
            .get_session(&hash3)
            .await
            .unwrap_or_else(|e| panic!("expected user456 session to still exist: {e}"));
        let hash1 = hash_token(&tokens1.refresh_token);
        let hash2 = hash_token(&tokens2.refresh_token);
        assert!(
            matches!(store.get_session(&hash1).await, Err(AuthError::TokenNotFound)),
            "expected user123 session 1 to be revoked"
        );
        assert!(
            matches!(store.get_session(&hash2).await, Err(AuthError::TokenNotFound)),
            "expected user123 session 2 to be revoked"
        );
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = InMemorySessionStore::new();
        assert!(store.is_empty());
        store.create_session("user123", now_secs() + 3600).await.unwrap();
        assert!(!store.is_empty());
        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }
}