use async_trait::async_trait;
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::error::{Error, Result, StorageError};
use crate::random::generate_random_base64_url;
/// A server-side session owned by a single user.
///
/// All timestamps are Unix times in whole seconds (`Utc::now().timestamp()`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Random identifier produced by `generate_random_base64_url` when the session is created.
    pub id: String,
    /// Identifier of the user that owns this session.
    pub user_id: String,
    /// Creation time (Unix seconds).
    pub created_at: i64,
    /// Time of the most recent `touch` (Unix seconds); equals `created_at` initially.
    pub last_accessed_at: i64,
    /// Expiry time (Unix seconds); the session counts as expired once now is strictly past it.
    pub expires_at: i64,
    /// User-Agent recorded at creation, if one was supplied.
    pub user_agent: Option<String>,
    /// Client IP address recorded at creation, if one was supplied.
    pub ip_address: Option<String>,
    /// Free-form JSON data; the metadata helper methods expect this to be a JSON object.
    pub metadata: serde_json::Value,
}
impl Session {
    /// Creates a session for `user_id` that expires `expires_in` from now.
    ///
    /// The id comes from `generate_random_base64_url(32)`. `created_at` and
    /// `last_accessed_at` are both set to the current Unix time in seconds, and
    /// metadata starts as an empty JSON object.
    ///
    /// # Errors
    /// Propagates any failure from the random id generator.
    fn new(user_id: impl Into<String>, expires_in: Duration) -> Result<Self> {
        let now = Utc::now().timestamp();
        let session_id = generate_random_base64_url(32)?;
        Ok(Self {
            id: session_id,
            user_id: user_id.into(),
            created_at: now,
            last_accessed_at: now,
            expires_at: now + expires_in.num_seconds(),
            user_agent: None,
            ip_address: None,
            metadata: serde_json::Value::Object(serde_json::Map::new()),
        })
    }

    /// Returns `true` once the current time is strictly past `expires_at`.
    #[inline]
    pub fn is_expired(&self) -> bool {
        Utc::now().timestamp() > self.expires_at
    }

    /// Returns `true` while the session has not expired.
    #[inline]
    pub fn is_valid(&self) -> bool {
        !self.is_expired()
    }

    /// Seconds remaining until expiry, clamped at zero.
    pub fn time_to_live(&self) -> i64 {
        (self.expires_at - Utc::now().timestamp()).max(0)
    }

    /// Records an access at the current time. Does not change `expires_at`.
    pub fn touch(&mut self) {
        self.last_accessed_at = Utc::now().timestamp();
    }

    /// Resets the expiry to `duration` from now.
    ///
    /// The new expiry is measured from now, not added to the previous `expires_at`.
    pub fn extend(&mut self, duration: Duration) {
        self.expires_at = Utc::now().timestamp() + duration.num_seconds();
    }

    /// Stores `value` under `key` in the metadata object.
    ///
    /// Values that fail to serialize are silently dropped. A `null` metadata value is
    /// first replaced by an empty object. Any other non-object metadata makes this a no-op.
    pub fn set_metadata<T: Serialize>(&mut self, key: impl Into<String>, value: T) {
        if let Ok(json_value) = serde_json::to_value(value)
            && let Some(obj) = self.metadata_object_mut()
        {
            obj.insert(key.into(), json_value);
        }
    }

    /// Deserializes the metadata entry under `key`.
    ///
    /// Returns `None` if the entry is missing or does not deserialize as `T`.
    pub fn get_metadata<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        // `&Value` is itself a deserializer, so the stored value need not be cloned.
        self.metadata
            .get(key)
            .and_then(|v| T::deserialize(v).ok())
    }

    /// Borrows the raw JSON stored under `key`, if any.
    pub fn get_metadata_raw(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }

    /// Returns `true` if a metadata entry exists under `key`.
    pub fn has_metadata(&self, key: &str) -> bool {
        self.metadata.get(key).is_some()
    }

    /// Removes and returns the metadata entry under `key`.
    pub fn remove_metadata(&mut self, key: &str) -> Option<serde_json::Value> {
        self.metadata
            .as_object_mut()
            .and_then(|obj| obj.remove(key))
    }

    /// Replaces all metadata with an empty JSON object.
    pub fn clear_metadata(&mut self) {
        self.metadata = serde_json::Value::Object(serde_json::Map::new());
    }

    /// Lists the metadata keys; empty when metadata is not an object.
    pub fn metadata_keys(&self) -> Vec<&str> {
        self.metadata
            .as_object()
            .map(|obj| obj.keys().map(|k| k.as_str()).collect())
            .unwrap_or_default()
    }

    /// Merges the top-level entries of `other` (if it is an object) into the metadata.
    ///
    /// Entries in `other` overwrite existing keys. A non-object `other` is ignored.
    pub fn merge_metadata(&mut self, other: serde_json::Value) {
        // `other` is owned, so move its entries instead of cloning each key and value.
        if let serde_json::Value::Object(other_obj) = other
            && let Some(current) = self.metadata_object_mut()
        {
            current.extend(other_obj);
        }
    }

    /// Gives mutable access to the metadata map.
    ///
    /// A `null` value is upgraded to an empty object first, because `null` carries no
    /// data and would otherwise make every write a silent no-op. Returns `None` when the
    /// metadata holds some other non-object value.
    fn metadata_object_mut(&mut self) -> Option<&mut serde_json::Map<String, serde_json::Value>> {
        if self.metadata.is_null() {
            self.metadata = serde_json::Value::Object(serde_json::Map::new());
        }
        self.metadata.as_object_mut()
    }
}
/// Tunables for [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Lifetime of new sessions; also the window applied on sliding/explicit renewal.
    pub expiration: Duration,
    /// When `true`, `SessionManager::get` touches the session and resets its expiry.
    pub sliding_expiration: bool,
    /// Maximum sessions per user before older ones are evicted; `0` means unlimited.
    pub max_sessions_per_user: usize,
    /// Length requested for generated session ids. The unit (bytes vs. characters)
    /// depends on `generate_random_base64_url` — TODO confirm.
    pub id_length: usize,
    /// When `true`, `SessionManager::validate` compares the stored and request IP addresses.
    pub validate_ip: bool,
    /// When `true`, `SessionManager::validate` compares the stored and request User-Agents.
    pub validate_user_agent: bool,
}
impl Default for SessionConfig {
    /// One-day sliding sessions, at most five per user, an id length of 32,
    /// and no IP or User-Agent pinning.
    fn default() -> Self {
        let expiration = Duration::hours(24);
        Self {
            validate_ip: false,
            validate_user_agent: false,
            id_length: 32,
            max_sessions_per_user: 5,
            sliding_expiration: true,
            expiration,
        }
    }
}
impl SessionConfig {
    /// Returns the default configuration (same as `SessionConfig::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the lifetime of new sessions.
    ///
    /// This is also the window applied when a session is renewed.
    pub fn with_expiration(mut self, duration: Duration) -> Self {
        self.expiration = duration;
        self
    }

    /// Enables or disables sliding expiration.
    pub fn with_sliding_expiration(mut self, enabled: bool) -> Self {
        self.sliding_expiration = enabled;
        self
    }

    /// Sets the per-user session cap; `0` disables the cap.
    pub fn with_max_sessions_per_user(mut self, max: usize) -> Self {
        self.max_sessions_per_user = max;
        self
    }

    /// Sets the requested session-id length (`id_length`).
    ///
    /// Added so that every field has a matching builder.
    pub fn with_id_length(mut self, length: usize) -> Self {
        self.id_length = length;
        self
    }

    /// Enables or disables IP-address comparison during validation.
    pub fn with_ip_validation(mut self, enabled: bool) -> Self {
        self.validate_ip = enabled;
        self
    }

    /// Enables or disables User-Agent comparison during validation.
    pub fn with_user_agent_validation(mut self, enabled: bool) -> Self {
        self.validate_user_agent = enabled;
        self
    }

    /// Preset: 1-hour, non-sliding sessions, up to 10 per user.
    ///
    /// All other fields take their default values.
    pub fn short_lived() -> Self {
        Self {
            expiration: Duration::hours(1),
            sliding_expiration: false,
            max_sessions_per_user: 10,
            ..Default::default()
        }
    }

    /// Preset: 30-day sliding sessions, up to 3 per user.
    ///
    /// All other fields take their default values.
    pub fn long_lived() -> Self {
        Self {
            expiration: Duration::days(30),
            sliding_expiration: true,
            max_sessions_per_user: 3,
            ..Default::default()
        }
    }
}
/// Persistence backend for [`Session`]s, used by [`SessionManager`].
///
/// The `Send + Sync` bound lets a store be shared as `Arc<dyn SessionStore>`.
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Persists `session`; the in-memory store overwrites any entry with the same id.
    async fn save(&self, session: &Session) -> Result<()>;
    /// Loads a session by id, yielding `Ok(None)` when it does not exist.
    async fn get(&self, session_id: &str) -> Result<Option<Session>>;
    /// Overwrites an existing session.
    ///
    /// The in-memory store returns a `StorageError::NotFound` error for an unknown id.
    async fn update(&self, session: &Session) -> Result<()>;
    /// Removes a session by id; the in-memory store treats a missing id as success.
    async fn delete(&self, session_id: &str) -> Result<()>;
    /// Returns every stored session owned by `user_id`, expired ones included.
    async fn get_by_user(&self, user_id: &str) -> Result<Vec<Session>>;
    /// Removes every session owned by `user_id` and returns how many were removed.
    async fn delete_by_user(&self, user_id: &str) -> Result<usize>;
    /// Removes sessions whose `expires_at` is in the past and returns how many were removed.
    async fn cleanup_expired(&self) -> Result<usize>;
    /// Returns the number of stored sessions, including expired ones not yet cleaned up.
    async fn count(&self) -> Result<usize>;
}
#[derive(Debug, Default)]
pub struct InMemorySessionStore {
sessions: RwLock<HashMap<String, Session>>,
}
impl InMemorySessionStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl SessionStore for InMemorySessionStore {
async fn save(&self, session: &Session) -> Result<()> {
let mut sessions = self
.sessions
.write()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
sessions.insert(session.id.clone(), session.clone());
Ok(())
}
async fn get(&self, session_id: &str) -> Result<Option<Session>> {
let sessions = self
.sessions
.read()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
Ok(sessions.get(session_id).cloned())
}
async fn update(&self, session: &Session) -> Result<()> {
let mut sessions = self
.sessions
.write()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
if sessions.contains_key(&session.id) {
sessions.insert(session.id.clone(), session.clone());
Ok(())
} else {
Err(Error::Storage(StorageError::NotFound(format!(
"session {}",
session.id
))))
}
}
async fn delete(&self, session_id: &str) -> Result<()> {
let mut sessions = self
.sessions
.write()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
sessions.remove(session_id);
Ok(())
}
async fn get_by_user(&self, user_id: &str) -> Result<Vec<Session>> {
let sessions = self
.sessions
.read()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
Ok(sessions
.values()
.filter(|s| s.user_id == user_id)
.cloned()
.collect())
}
async fn delete_by_user(&self, user_id: &str) -> Result<usize> {
let mut sessions = self
.sessions
.write()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
let to_delete: Vec<String> = sessions
.iter()
.filter(|(_, s)| s.user_id == user_id)
.map(|(id, _)| id.clone())
.collect();
let count = to_delete.len();
for id in to_delete {
sessions.remove(&id);
}
Ok(count)
}
async fn cleanup_expired(&self) -> Result<usize> {
let mut sessions = self
.sessions
.write()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
let now = Utc::now().timestamp();
let to_delete: Vec<String> = sessions
.iter()
.filter(|(_, s)| s.expires_at < now)
.map(|(id, _)| id.clone())
.collect();
let count = to_delete.len();
for id in to_delete {
sessions.remove(&id);
}
Ok(count)
}
async fn count(&self) -> Result<usize> {
let sessions = self
.sessions
.read()
.map_err(|_| Error::Storage(StorageError::OperationFailed("lock poisoned".into())))?;
Ok(sessions.len())
}
}
/// High-level session API: creation, lookup, validation, renewal and cleanup.
pub struct SessionManager {
    /// Backing store, shared behind an `Arc` so any `SessionStore` implementation can be plugged in.
    store: Arc<dyn SessionStore>,
    /// Policy applied to every session this manager handles.
    config: SessionConfig,
}
impl SessionManager {
pub fn new(config: SessionConfig) -> Self {
Self {
store: Arc::new(InMemorySessionStore::new()),
config,
}
}
pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
Self { store, config }
}
pub async fn create(&self, user_id: impl Into<String>) -> Result<Session> {
let user_id = user_id.into();
self.enforce_max_sessions(&user_id).await?;
let session = Session::new(&user_id, self.config.expiration)?;
self.store.save(&session).await?;
Ok(session)
}
pub async fn create_with_options(
&self,
user_id: impl Into<String>,
options: CreateSessionOptions,
) -> Result<Session> {
let user_id = user_id.into();
self.enforce_max_sessions(&user_id).await?;
let expiration = options.custom_expiration.unwrap_or(self.config.expiration);
let mut session = Session::new(&user_id, expiration)?;
session.user_agent = options.user_agent;
session.ip_address = options.ip_address;
if let Some(metadata) = options.metadata {
session.metadata = metadata;
}
self.store.save(&session).await?;
Ok(session)
}
pub async fn get(&self, session_id: &str) -> Option<Session> {
let session = self.store.get(session_id).await.ok()??;
if session.is_expired() {
let _ = self.store.delete(session_id).await;
return None;
}
if self.config.sliding_expiration {
let mut updated = session.clone();
updated.touch();
updated.extend(self.config.expiration);
let _ = self.store.update(&updated).await;
return Some(updated);
}
Some(session)
}
pub async fn validate(
&self,
session_id: &str,
ip_address: Option<&str>,
user_agent: Option<&str>,
) -> Result<Session> {
let session = self
.store
.get(session_id)
.await?
.ok_or_else(|| Error::Storage(StorageError::NotFound("session".into())))?;
if session.is_expired() {
self.store.delete(session_id).await?;
return Err(Error::validation("session expired"));
}
if self.config.validate_ip
&& let (Some(stored_ip), Some(request_ip)) = (&session.ip_address, ip_address)
&& stored_ip != request_ip
{
return Err(Error::validation("IP address mismatch"));
}
if self.config.validate_user_agent
&& let (Some(stored_ua), Some(request_ua)) = (&session.user_agent, user_agent)
&& stored_ua != request_ua
{
return Err(Error::validation("User-Agent mismatch"));
}
Ok(session)
}
pub async fn update(&self, session: &Session) -> Result<()> {
self.store.update(session).await
}
pub async fn destroy(&self, session_id: &str) -> Result<()> {
self.store.delete(session_id).await
}
pub async fn destroy_all_for_user(&self, user_id: &str) -> Result<usize> {
self.store.delete_by_user(user_id).await
}
pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>> {
self.store.get_by_user(user_id).await
}
pub async fn cleanup(&self) -> Result<usize> {
self.store.cleanup_expired().await
}
pub async fn count(&self) -> Result<usize> {
self.store.count().await
}
pub async fn refresh(&self, session_id: &str) -> Result<Session> {
let mut session = self
.store
.get(session_id)
.await?
.ok_or_else(|| Error::Storage(StorageError::NotFound("session".into())))?;
if session.is_expired() {
self.store.delete(session_id).await?;
return Err(Error::validation("session expired"));
}
session.touch();
session.extend(self.config.expiration);
self.store.update(&session).await?;
Ok(session)
}
pub fn config(&self) -> &SessionConfig {
&self.config
}
async fn enforce_max_sessions(&self, user_id: &str) -> Result<()> {
if self.config.max_sessions_per_user == 0 {
return Ok(());
}
let sessions = self.store.get_by_user(user_id).await?;
if sessions.len() >= self.config.max_sessions_per_user {
if let Some(oldest) = sessions.iter().min_by_key(|s| s.created_at) {
self.store.delete(&oldest.id).await?;
}
}
Ok(())
}
}
/// Optional per-session settings for `SessionManager::create_with_options`.
#[derive(Debug, Clone, Default)]
pub struct CreateSessionOptions {
    /// User-Agent stored on the session (used by User-Agent validation).
    pub user_agent: Option<String>,
    /// Client IP stored on the session (used by IP validation).
    pub ip_address: Option<String>,
    /// Initial metadata; replaces the default empty object when set.
    pub metadata: Option<serde_json::Value>,
    /// Lifetime overriding `SessionConfig::expiration` for this session.
    pub custom_expiration: Option<Duration>,
}
impl CreateSessionOptions {
    /// Returns options with every field unset, so the manager's defaults apply.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records the client's User-Agent on the new session.
    pub fn with_user_agent(self, user_agent: impl Into<String>) -> Self {
        Self {
            user_agent: Some(user_agent.into()),
            ..self
        }
    }

    /// Records the client's IP address on the new session.
    pub fn with_ip_address(self, ip: impl Into<String>) -> Self {
        Self {
            ip_address: Some(ip.into()),
            ..self
        }
    }

    /// Uses `metadata` as the session's initial metadata.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Serializes `data` into the initial metadata.
    ///
    /// If serialization fails, the metadata is reset to `None`, discarding anything set
    /// earlier.
    pub fn with_metadata_from<T: Serialize>(self, data: T) -> Self {
        Self {
            metadata: serde_json::to_value(data).ok(),
            ..self
        }
    }

    /// Overrides the configured lifetime for this session only.
    pub fn with_expiration(self, duration: Duration) -> Self {
        Self {
            custom_expiration: Some(duration),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Timestamps have whole-second resolution, so short sleeps cannot move them, and
    // `std::thread::sleep` would block the tokio runtime in async tests. The tests
    // therefore assert `>=` relationships without sleeping.

    #[test]
    fn test_session_creation() {
        let session = Session::new("user123", Duration::hours(1)).unwrap();
        assert_eq!(session.user_id, "user123");
        assert!(!session.is_expired());
        assert!(session.is_valid());
    }

    #[test]
    fn test_session_expiration() {
        let mut session = Session::new("user123", Duration::seconds(-1)).unwrap();
        // Push expiry strictly into the past; `is_expired` uses a strict `>`.
        session.expires_at = Utc::now().timestamp() - 1;
        assert!(session.is_expired());
        assert!(!session.is_valid());
    }

    #[test]
    fn test_session_metadata() {
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.set_metadata("role", "admin");
        let role: Option<String> = session.get_metadata("role");
        assert_eq!(role, Some("admin".to_string()));
        session.set_metadata("count", 42);
        let count: Option<i32> = session.get_metadata("count");
        assert_eq!(count, Some(42));
        session.set_metadata("tags", vec!["a", "b", "c"]);
        let tags: Option<Vec<String>> = session.get_metadata("tags");
        assert_eq!(
            tags,
            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()])
        );
        assert!(session.has_metadata("role"));
        assert!(!session.has_metadata("nonexistent"));
        session.remove_metadata("role");
        assert!(!session.has_metadata("role"));
        let keys = session.metadata_keys();
        assert!(keys.contains(&"count"));
        assert!(keys.contains(&"tags"));
    }

    #[test]
    fn test_session_metadata_complex_types() {
        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
        struct UserProfile {
            name: String,
            age: u32,
        }
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        let profile = UserProfile {
            name: "张三".to_string(),
            age: 30,
        };
        session.set_metadata("profile", profile.clone());
        let retrieved: Option<UserProfile> = session.get_metadata("profile");
        assert_eq!(retrieved, Some(profile));
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new(SessionConfig::default());
        let session = manager.create("user123").await.unwrap();
        assert_eq!(session.user_id, "user123");
    }

    #[tokio::test]
    async fn test_session_manager_get() {
        let manager = SessionManager::new(SessionConfig::default());
        let session = manager.create("user123").await.unwrap();
        let retrieved = manager.get(&session.id).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().user_id, "user123");
    }

    #[tokio::test]
    async fn test_get_removes_expired_session() {
        let manager = SessionManager::new(SessionConfig::default());
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.expires_at = Utc::now().timestamp() - 100;
        manager.store.save(&session).await.unwrap();
        assert!(manager.get(&session.id).await.is_none());
        // `get` deletes the expired session it encountered.
        assert_eq!(manager.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_session_manager_destroy() {
        let manager = SessionManager::new(SessionConfig::default());
        let session = manager.create("user123").await.unwrap();
        manager.destroy(&session.id).await.unwrap();
        assert!(manager.get(&session.id).await.is_none());
    }

    #[tokio::test]
    async fn test_session_manager_destroy_all_for_user() {
        // A cap of 0 means unlimited, so both sessions of user123 survive creation.
        let config = SessionConfig::default().with_max_sessions_per_user(0);
        let manager = SessionManager::new(config);
        manager.create("user123").await.unwrap();
        manager.create("user123").await.unwrap();
        manager.create("user456").await.unwrap();
        let count = manager.destroy_all_for_user("user123").await.unwrap();
        assert_eq!(count, 2);
        let remaining = manager.get_user_sessions("user123").await.unwrap();
        assert!(remaining.is_empty());
    }

    #[tokio::test]
    async fn test_session_manager_max_sessions() {
        let config = SessionConfig::default().with_max_sessions_per_user(2);
        let manager = SessionManager::new(config);
        let s1 = manager.create("user123").await.unwrap();
        let s2 = manager.create("user123").await.unwrap();
        let s3 = manager.create("user123").await.unwrap();
        let sessions = manager.get_user_sessions("user123").await.unwrap();
        assert_eq!(sessions.len(), 2);
        let session_ids: Vec<_> = sessions.iter().map(|s| s.id.clone()).collect();
        assert!(session_ids.contains(&s3.id));
        // s1 and s2 may share a `created_at` second, so either one may be the evicted session.
        let deleted_count = [&s1.id, &s2.id]
            .iter()
            .filter(|id| !session_ids.contains(*id))
            .count();
        assert_eq!(deleted_count, 1);
    }

    #[tokio::test]
    async fn test_session_manager_sliding_expiration() {
        let config = SessionConfig::default()
            .with_expiration(Duration::hours(1))
            .with_sliding_expiration(true);
        let manager = SessionManager::new(config);
        let session = manager.create("user123").await.unwrap();
        let original_expires = session.expires_at;
        let retrieved = manager.get(&session.id).await.unwrap();
        assert!(retrieved.expires_at >= original_expires);
    }

    #[tokio::test]
    async fn test_session_manager_validate() {
        let config = SessionConfig::default()
            .with_ip_validation(true)
            .with_user_agent_validation(true);
        let manager = SessionManager::new(config);
        let options = CreateSessionOptions::new()
            .with_ip_address("192.168.1.1")
            .with_user_agent("TestBrowser");
        let session = manager
            .create_with_options("user123", options)
            .await
            .unwrap();
        assert!(
            manager
                .validate(&session.id, Some("192.168.1.1"), Some("TestBrowser"))
                .await
                .is_ok()
        );
        assert!(
            manager
                .validate(&session.id, Some("10.0.0.1"), Some("TestBrowser"))
                .await
                .is_err()
        );
        assert!(
            manager
                .validate(&session.id, Some("192.168.1.1"), Some("OtherBrowser"))
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let manager = SessionManager::new(SessionConfig::default());
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.expires_at = Utc::now().timestamp() - 100;
        manager.store.save(&session).await.unwrap();
        manager.create("user456").await.unwrap();
        let cleaned = manager.cleanup().await.unwrap();
        assert_eq!(cleaned, 1);
        assert_eq!(manager.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemorySessionStore::new();
        let session = Session::new("user123", Duration::hours(1)).unwrap();
        store.save(&session).await.unwrap();
        let retrieved = store.get(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        store.delete(&session.id).await.unwrap();
        let retrieved = store.get(&session.id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_update_missing_session_fails() {
        let store = InMemorySessionStore::new();
        let session = Session::new("user123", Duration::hours(1)).unwrap();
        assert!(store.update(&session).await.is_err());
        store.save(&session).await.unwrap();
        assert!(store.update(&session).await.is_ok());
    }

    #[test]
    fn test_session_time_to_live() {
        let session = Session::new("user123", Duration::hours(1)).unwrap();
        let ttl = session.time_to_live();
        assert!(ttl > 3500 && ttl <= 3600);
    }

    #[test]
    fn test_session_touch() {
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        let original_accessed = session.last_accessed_at;
        session.touch();
        assert!(session.last_accessed_at >= original_accessed);
    }

    #[test]
    fn test_create_session_options() {
        let options = CreateSessionOptions::new()
            .with_user_agent("Mozilla/5.0")
            .with_ip_address("192.168.1.1")
            .with_expiration(Duration::hours(2));
        assert_eq!(options.user_agent, Some("Mozilla/5.0".to_string()));
        assert_eq!(options.ip_address, Some("192.168.1.1".to_string()));
        assert!(options.custom_expiration.is_some());
    }

    #[test]
    fn test_session_config_presets() {
        let short = SessionConfig::short_lived();
        assert_eq!(short.expiration, Duration::hours(1));
        assert!(!short.sliding_expiration);
        let long = SessionConfig::long_lived();
        assert_eq!(long.expiration, Duration::days(30));
        assert!(long.sliding_expiration);
    }

    #[tokio::test]
    async fn test_refresh_session() {
        let manager = SessionManager::new(SessionConfig::default());
        let session = manager.create("user123").await.unwrap();
        let refreshed = manager.refresh(&session.id).await.unwrap();
        assert!(refreshed.last_accessed_at >= session.last_accessed_at);
    }

    #[tokio::test]
    async fn test_refresh_expired_session_fails() {
        let manager = SessionManager::new(SessionConfig::default());
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.expires_at = Utc::now().timestamp() - 100;
        manager.store.save(&session).await.unwrap();
        assert!(manager.refresh(&session.id).await.is_err());
        // `refresh` deletes the expired session.
        assert_eq!(manager.count().await.unwrap(), 0);
    }

    #[test]
    fn test_session_clear_metadata() {
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.set_metadata("key1", "value1");
        session.set_metadata("key2", "value2");
        assert!(session.has_metadata("key1"));
        session.clear_metadata();
        assert!(!session.has_metadata("key1"));
        assert!(session.metadata_keys().is_empty());
    }

    #[test]
    fn test_session_merge_metadata() {
        let mut session = Session::new("user123", Duration::hours(1)).unwrap();
        session.set_metadata("existing", "value");
        let additional = serde_json::json!({
            "new_key": "new_value",
            "another": 123
        });
        session.merge_metadata(additional);
        assert_eq!(
            session.get_metadata::<String>("existing"),
            Some("value".to_string())
        );
        assert_eq!(
            session.get_metadata::<String>("new_key"),
            Some("new_value".to_string())
        );
        assert_eq!(session.get_metadata::<i32>("another"), Some(123));
    }
}