use crate::multitenancy::TenantId;
use std::collections::HashMap;
use std::time::{Duration, Instant, SystemTime};
/// A single user session scoped to a tenant.
///
/// Idle expiry is tracked with the monotonic `last_activity` instant, while
/// `created_at` records the wall-clock creation time.
#[derive(Debug, Clone)]
pub struct Session {
/// Opaque identifier assigned by `Session::new`.
pub session_id: String,
/// Identifier of the user that owns this session.
pub user_id: String,
/// Tenant the session belongs to.
pub tenant_id: TenantId,
/// Wall-clock creation time; `Session::duration` measures from here.
pub created_at: SystemTime,
/// Time of the last activity; reset by `Session::refresh`.
pub last_activity: Instant,
/// Maximum idle time after `last_activity` before the session counts as expired.
pub timeout: Duration,
/// Free-form string key/value data attached to the session.
pub attributes: HashMap<String, String>,
/// Client IP address, if one was recorded.
pub ip_address: Option<String>,
/// Client user-agent string, if one was recorded.
pub user_agent: Option<String>,
/// Cleared by `Session::invalidate`; an inactive session always reports expired.
pub active: bool,
}
impl Session {
    /// Builds a new active session for `user_id` inside `tenant_id`.
    ///
    /// The session gets a freshly generated id, a one-hour idle timeout, an
    /// empty attribute map and no client metadata; both timestamps are taken now.
    pub fn new(user_id: impl Into<String>, tenant_id: impl Into<String>) -> Self {
        let session_id = generate_session_id();
        Self {
            session_id,
            user_id: user_id.into(),
            tenant_id: tenant_id.into(),
            created_at: SystemTime::now(),
            last_activity: Instant::now(),
            timeout: Duration::from_secs(60 * 60),
            attributes: HashMap::new(),
            ip_address: None,
            user_agent: None,
            active: true,
        }
    }

    /// Returns the session with its idle timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Returns the session with the client IP address recorded.
    pub fn with_ip_address(self, ip: impl Into<String>) -> Self {
        Self {
            ip_address: Some(ip.into()),
            ..self
        }
    }

    /// Returns the session with the client user-agent string recorded.
    pub fn with_user_agent(self, ua: impl Into<String>) -> Self {
        Self {
            user_agent: Some(ua.into()),
            ..self
        }
    }

    /// Stores `value` under `key`, overwriting any earlier value.
    pub fn set_attribute(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let (name, val) = (key.into(), value.into());
        self.attributes.insert(name, val);
    }

    /// Looks up the attribute stored under `key`.
    pub fn get_attribute(&self, key: &str) -> Option<&String> {
        self.attributes.get(key)
    }

    /// Removes and returns the attribute stored under `key`, if any.
    pub fn remove_attribute(&mut self, key: &str) -> Option<String> {
        self.attributes.remove(key)
    }

    /// Reports whether the session is unusable: it was invalidated, or more
    /// than `timeout` has passed since the last activity.
    pub fn is_expired(&self) -> bool {
        if !self.active {
            return true;
        }
        self.last_activity.elapsed() > self.timeout
    }

    /// Records activity right now, restarting the idle timer.
    pub fn refresh(&mut self) {
        self.last_activity = Instant::now();
    }

    /// Idle time left before expiry; zero once the session is expired.
    pub fn time_remaining(&self) -> Duration {
        if self.is_expired() {
            return Duration::ZERO;
        }
        let idle = self.last_activity.elapsed();
        self.timeout.saturating_sub(idle)
    }

    /// Deactivates the session; it reports expired from now on.
    pub fn invalidate(&mut self) {
        self.active = false;
    }

    /// Wall-clock age of the session, or zero if the system clock now reads
    /// earlier than `created_at`.
    pub fn duration(&self) -> Duration {
        self.created_at.elapsed().unwrap_or_default()
    }
}
/// In-memory session registry with a per-user index of session ids.
///
/// Holds plain `HashMap`s and mutates through `&mut self`; it has no internal
/// locking of its own.
#[derive(Debug)]
pub struct SessionStore {
/// All stored sessions keyed by `session_id`.
sessions: HashMap<String, Session>,
/// Session ids per `user_id`, in creation order (oldest first).
user_sessions: HashMap<String, Vec<String>>,
/// Per-user session cap enforced by `create_session`.
max_sessions_per_user: usize,
/// Timeout given to sessions created by this store.
default_timeout: Duration,
/// When `false`, `create_session` invalidates the user's other sessions first.
allow_concurrent: bool,
}
impl SessionStore {
    /// Per-user session cap used by [`SessionStore::new`].
    const DEFAULT_MAX_SESSIONS_PER_USER: usize = 5;
    /// Idle timeout (one hour) used by [`SessionStore::new`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3600);

    /// Creates an empty store.
    ///
    /// The defaults are: at most 5 sessions per user, a one-hour session
    /// timeout, and concurrent sessions allowed.
    pub fn new() -> Self {
        SessionStore {
            sessions: HashMap::new(),
            user_sessions: HashMap::new(),
            max_sessions_per_user: Self::DEFAULT_MAX_SESSIONS_PER_USER,
            default_timeout: Self::DEFAULT_TIMEOUT,
            allow_concurrent: true,
        }
    }

    /// Sets the maximum number of sessions a single user may hold at once.
    ///
    /// A cap of `0` behaves like `1`: every earlier session of the user is
    /// evicted before a new one is created.
    pub fn with_max_sessions(mut self, max: usize) -> Self {
        self.max_sessions_per_user = max;
        self
    }

    /// Sets the timeout given to sessions created by this store.
    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
        self.default_timeout = timeout;
        self
    }

    /// Makes [`create_session`](Self::create_session) invalidate all of a
    /// user's existing sessions before creating a new one.
    pub fn without_concurrent_sessions(mut self) -> Self {
        self.allow_concurrent = false;
        self
    }

    /// Creates, stores, and returns a new session for `user_id` in `tenant_id`.
    ///
    /// If concurrent sessions are disabled, the user's existing sessions are
    /// invalidated first. If the user is then at the per-user cap, sessions are
    /// evicted until one more fits. Dead (expired or invalidated) sessions go
    /// before live ones, and the oldest goes first within each group.
    ///
    /// The returned value is a clone. Later changes must go through the store
    /// (e.g. [`get_session_mut`](Self::get_session_mut)).
    pub fn create_session(&mut self, user_id: &str, tenant_id: &str) -> Session {
        if !self.allow_concurrent {
            self.invalidate_user_sessions(user_id);
        }
        self.make_room_for(user_id);

        let session = Session::new(user_id, tenant_id).with_timeout(self.default_timeout);
        let session_id = session.session_id.clone();
        self.sessions.insert(session_id.clone(), session.clone());
        self.user_sessions
            .entry(user_id.to_string())
            .or_default()
            .push(session_id);
        session
    }

    /// Returns the stored session with this id, whether or not it has expired.
    pub fn get_session(&self, session_id: &str) -> Option<&Session> {
        self.sessions.get(session_id)
    }

    /// Mutable access to a stored session.
    ///
    /// Changing `user_id` through this reference is not reflected in the
    /// per-user index.
    pub fn get_session_mut(&mut self, session_id: &str) -> Option<&mut Session> {
        self.sessions.get_mut(session_id)
    }

    /// Returns the session if it exists and has not expired, refreshing its
    /// activity timestamp.
    ///
    /// Expired sessions are left in the store; see
    /// [`cleanup_expired`](Self::cleanup_expired).
    pub fn validate_session(&mut self, session_id: &str) -> Option<&Session> {
        let session = self.sessions.get_mut(session_id)?;
        if session.is_expired() {
            return None;
        }
        session.refresh();
        Some(&*session)
    }

    /// Removes a session from the store and from its owner's index entry.
    ///
    /// The owner's index entry is dropped once it becomes empty, so the index
    /// does not accumulate keys for users who have no sessions left.
    pub fn remove_session(&mut self, session_id: &str) -> Option<Session> {
        let session = self.sessions.remove(session_id)?;
        self.detach_from_user(&session.user_id, session_id);
        Some(session)
    }

    /// Marks every session indexed under `user_id` as inactive.
    ///
    /// The sessions stay in the store.
    pub fn invalidate_user_sessions(&mut self, user_id: &str) {
        // `user_sessions` and `sessions` are disjoint fields, so the shared
        // borrow of the index can coexist with mutable access to the sessions.
        if let Some(ids) = self.user_sessions.get(user_id) {
            for id in ids {
                if let Some(session) = self.sessions.get_mut(id) {
                    session.invalidate();
                }
            }
        }
    }

    /// Returns the user's stored sessions in creation order, including
    /// expired and invalidated ones.
    pub fn get_user_sessions(&self, user_id: &str) -> Vec<&Session> {
        self.user_sessions
            .get(user_id)
            .map(|session_ids| {
                session_ids
                    .iter()
                    .filter_map(|id| self.sessions.get(id))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Counts the user's stored sessions that have not expired.
    pub fn get_active_session_count(&self, user_id: &str) -> usize {
        self.user_sessions
            .get(user_id)
            .map_or(0, |ids| ids.iter().filter(|id| !self.is_dead(id.as_str())).count())
    }

    /// Removes every expired or invalidated session.
    ///
    /// It also prunes the per-user index, dropping users left with no sessions.
    pub fn cleanup_expired(&mut self) {
        self.sessions.retain(|_, session| !session.is_expired());
        // Rebuild the index from what survived. This also drops ids that no
        // longer resolve to a stored session.
        let sessions = &self.sessions;
        self.user_sessions.retain(|_, ids| {
            ids.retain(|id| sessions.contains_key(id));
            !ids.is_empty()
        });
    }

    /// Total number of stored sessions, including expired ones.
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Number of stored sessions that have not expired.
    pub fn active_session_count(&self) -> usize {
        self.sessions.values().filter(|s| !s.is_expired()).count()
    }

    /// Returns `true` when `session_id` is missing from the store or refers to
    /// an expired or invalidated session.
    fn is_dead(&self, session_id: &str) -> bool {
        match self.sessions.get(session_id) {
            Some(session) => session.is_expired(),
            None => true,
        }
    }

    /// Evicts `user_id`'s sessions until one more fits under the per-user cap.
    ///
    /// The victim is the oldest dead session if one exists, otherwise the
    /// oldest session overall. Each iteration removes one id from the user's
    /// index entry. The loop therefore terminates even when the entry holds
    /// ids that no longer resolve to a stored session.
    fn make_room_for(&mut self, user_id: &str) {
        loop {
            let victim = match self.user_sessions.get(user_id) {
                Some(ids) if !ids.is_empty() && ids.len() >= self.max_sessions_per_user => ids
                    .iter()
                    .find(|id| self.is_dead(id.as_str()))
                    .unwrap_or(&ids[0])
                    .clone(),
                _ => return,
            };
            self.sessions.remove(&victim);
            self.detach_from_user(user_id, &victim);
        }
    }

    /// Removes `session_id` from `user_id`'s index entry.
    ///
    /// The entry itself is deleted once it becomes empty.
    fn detach_from_user(&mut self, user_id: &str, session_id: &str) {
        if let Some(ids) = self.user_sessions.get_mut(user_id) {
            ids.retain(|id| id != session_id);
            if ids.is_empty() {
                self.user_sessions.remove(user_id);
            }
        }
    }
}
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
/// Per-request view of a session.
///
/// It holds the session itself, copies of its user and tenant ids, and
/// request-scoped string data.
#[derive(Debug, Clone)]
pub struct SessionContext {
/// The session this context was built from.
pub session: Session,
/// Copy of `session.user_id`, taken by `SessionContext::from_session`.
pub user_id: String,
/// Copy of `session.tenant_id`, taken by `SessionContext::from_session`.
pub tenant_id: TenantId,
/// Request-scoped key/value data.
pub request_data: HashMap<String, String>,
}
impl SessionContext {
pub fn from_session(session: Session) -> Self {
SessionContext {
user_id: session.user_id.clone(),
tenant_id: session.tenant_id.clone(),
session,
request_data: HashMap::new(),
}
}
pub fn set_request_data(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.request_data.insert(key.into(), value.into());
}
pub fn get_request_data(&self, key: &str) -> Option<&String> {
self.request_data.get(key)
}
}
/// Generates a fresh opaque session identifier.
///
/// The result is `"sess_"` followed by 64 lowercase hex digits encoding 32
/// random bytes obtained from `rand::rng()`.
fn generate_session_id() -> String {
    // `fill_bytes` is defined on `RngCore`; importing `Rng` alone does not put
    // supertrait methods in scope.
    use rand::RngCore;
    use std::fmt::Write as _;

    const PREFIX: &str = "sess_";
    let mut bytes = [0u8; 32];
    rand::rng().fill_bytes(&mut bytes);

    // Hex-encode into one pre-sized buffer instead of allocating a `String`
    // per byte.
    let mut id = String::with_capacity(PREFIX.len() + 2 * bytes.len());
    id.push_str(PREFIX);
    for byte in &bytes {
        // Writing into a `String` never fails.
        let _ = write!(id, "{:02x}", byte);
    }
    id
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads an attribute as `&str` so assertions can compare against literals.
    fn attr<'a>(session: &'a Session, key: &str) -> Option<&'a str> {
        session.get_attribute(key).map(String::as_str)
    }

    #[test]
    fn test_session_creation() {
        let s = Session::new("user1", "tenant_a");
        assert!(s.session_id.starts_with("sess_"));
        assert_eq!(s.user_id, "user1");
        assert_eq!(s.tenant_id, "tenant_a");
        assert!(s.active);
        assert!(!s.is_expired());
    }

    #[test]
    fn test_session_expiration() {
        let short_lived =
            Session::new("user1", "tenant_a").with_timeout(Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(10));
        assert!(short_lived.is_expired());
    }

    #[test]
    fn test_session_refresh() {
        let mut s = Session::new("user1", "tenant_a").with_timeout(Duration::from_secs(1));
        std::thread::sleep(Duration::from_millis(100));
        let before = s.time_remaining();
        s.refresh();
        // Refreshing restarts the idle timer, so more time must be left.
        assert!(s.time_remaining() > before);
    }

    #[test]
    fn test_session_attributes() {
        let mut s = Session::new("user1", "tenant_a");
        s.set_attribute("theme", "dark");
        s.set_attribute("language", "en");
        assert_eq!(attr(&s, "theme"), Some("dark"));
        assert_eq!(attr(&s, "language"), Some("en"));
        assert!(attr(&s, "missing").is_none());
        assert_eq!(s.remove_attribute("theme").as_deref(), Some("dark"));
        assert!(attr(&s, "theme").is_none());
    }

    #[test]
    fn test_session_store() {
        let mut store = SessionStore::new();
        let id = store.create_session("user1", "tenant_a").session_id;
        assert!(store.get_session(&id).is_some());
        assert!(store.validate_session(&id).is_some());
        assert!(store.remove_session(&id).is_some());
        assert!(store.get_session(&id).is_none());
    }

    #[test]
    fn test_session_store_max_sessions() {
        let mut store = SessionStore::new().with_max_sessions(2);
        let ids: Vec<String> = (0..3)
            .map(|_| store.create_session("user1", "tenant_a").session_id)
            .collect();
        // The oldest session is evicted once the cap of two is exceeded.
        assert!(store.get_session(&ids[0]).is_none());
        assert!(ids[1..].iter().all(|id| store.get_session(id).is_some()));
    }

    #[test]
    fn test_session_store_no_concurrent() {
        let mut store = SessionStore::new().without_concurrent_sessions();
        let first = store.create_session("user1", "tenant_a").session_id;
        let second = store.create_session("user1", "tenant_a").session_id;
        // The earlier session must be either gone or deactivated.
        assert!(!matches!(store.get_session(&first), Some(s) if s.active));
        assert!(store.get_session(&second).is_some());
    }

    #[test]
    fn test_session_store_cleanup() {
        let mut store = SessionStore::new().with_default_timeout(Duration::from_millis(1));
        for &(user, tenant) in &[("user1", "tenant_a"), ("user2", "tenant_b")] {
            store.create_session(user, tenant);
        }
        std::thread::sleep(Duration::from_millis(10));
        store.cleanup_expired();
        assert_eq!(store.session_count(), 0);
    }

    #[test]
    fn test_user_sessions() {
        let mut store = SessionStore::new();
        for &(user, tenant) in &[
            ("user1", "tenant_a"),
            ("user1", "tenant_a"),
            ("user2", "tenant_b"),
        ] {
            store.create_session(user, tenant);
        }
        assert_eq!(store.get_user_sessions("user1").len(), 2);
        assert_eq!(store.get_user_sessions("user2").len(), 1);
        assert_eq!(store.get_active_session_count("user1"), 2);
    }

    #[test]
    fn test_session_context() {
        let mut ctx = SessionContext::from_session(Session::new("user1", "tenant_a"));
        assert_eq!(ctx.user_id, "user1");
        assert_eq!(ctx.tenant_id, "tenant_a");
        ctx.set_request_data("correlation_id", "abc123");
        assert_eq!(
            ctx.get_request_data("correlation_id").map(String::as_str),
            Some("abc123")
        );
    }
}