use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::Plugin;
/// Bookkeeping record for a single tracked session.
///
/// All timestamps are whole seconds since the Unix epoch, as produced by
/// `now_secs`.
#[derive(Debug, Clone)]
// The `token` field is written on insert but never read back in this file,
// which would otherwise trigger the dead-code lint.
#[allow(dead_code)]
struct TrackedSession {
/// The session token; the same string is used as the key in the session map.
token: String,
/// Identifier of the user that owns the session; returned by a successful `check`.
user_id: String,
/// When the session was first tracked. `refresh` uses it to compute the hard lifetime cap.
created_at: u64,
/// Last time the session was seen by `check` or `refresh`; drives the idle timeout.
last_active: u64,
/// Absolute time at which the session stops being valid.
expires_at: u64,
}
/// Tracks sessions in memory and enforces both an absolute lifetime and an
/// idle timeout on them.
pub struct SessionExpiryPlugin {
/// Maximum session lifetime, in seconds, counted from when the session was tracked.
max_lifetime: u64,
/// Maximum allowed gap, in seconds, between two activity updates of a session.
idle_timeout: u64,
/// Live sessions keyed by token; the mutex allows mutation through `&self`.
sessions: Mutex<HashMap<String, TrackedSession>>,
}
impl SessionExpiryPlugin {
    // Every method below takes the internal mutex with `lock().unwrap()` and
    // therefore panics if the mutex has been poisoned by a panic in another
    // thread. All time arithmetic is saturating so that a wall clock stepping
    // backwards (or an extreme configured duration) cannot itself cause such
    // a panic while the lock is held.

    /// Creates a plugin with the default limits: a 24 hour maximum lifetime
    /// and a 2 hour idle timeout.
    pub fn new() -> Self {
        Self {
            max_lifetime: 86_400, // 24 h in seconds
            idle_timeout: 7_200,  // 2 h in seconds
            sessions: Mutex::new(HashMap::new()),
        }
    }

    /// Creates a plugin with custom limits.
    ///
    /// Both durations are truncated to whole seconds.
    pub fn with_timeouts(max_lifetime: Duration, idle_timeout: Duration) -> Self {
        Self {
            max_lifetime: max_lifetime.as_secs(),
            idle_timeout: idle_timeout.as_secs(),
            sessions: Mutex::new(HashMap::new()),
        }
    }

    /// Starts tracking `token` for `user_id`.
    ///
    /// The session is stamped with the current time and expires
    /// `max_lifetime` seconds from now. Tracking a token that is already
    /// present replaces the previous record.
    pub fn track(&self, token: &str, user_id: &str) {
        let now = now_secs();
        let session = TrackedSession {
            token: token.to_string(),
            user_id: user_id.to_string(),
            created_at: now,
            last_active: now,
            // Saturate instead of overflowing for very large lifetimes,
            // matching the arithmetic used by `refresh`.
            expires_at: now.saturating_add(self.max_lifetime),
        };
        self.sessions
            .lock()
            .unwrap()
            .insert(token.to_string(), session);
    }

    /// Validates `token` and, on success, records activity on the session.
    ///
    /// Returns the owning user id when the session is still valid.
    ///
    /// # Errors
    ///
    /// * `"Session not found"` - the token is not tracked.
    /// * `"Session expired"` - the current time has reached `expires_at`.
    /// * `"Session timed out due to inactivity"` - more than `idle_timeout`
    ///   seconds have passed since the last recorded activity.
    ///
    /// In the last two cases the session is also removed from the map.
    pub fn check(&self, token: &str) -> Result<String, String> {
        let now = now_secs();
        let mut sessions = self.sessions.lock().unwrap();
        let session = sessions.get_mut(token).ok_or("Session not found")?;

        // Expiry is inclusive (`>=`) so that `check`, `cleanup` and `refresh`
        // all agree on the instant a session stops being valid.
        if now >= session.expires_at {
            sessions.remove(token);
            return Err("Session expired".into());
        }

        // If the clock moved backwards, `last_active` may lie in the future;
        // treat that as zero idle time rather than underflowing.
        if now.saturating_sub(session.last_active) > self.idle_timeout {
            sessions.remove(token);
            return Err("Session timed out due to inactivity".into());
        }

        session.last_active = now;
        Ok(session.user_id.clone())
    }

    /// Stops tracking `token` immediately.
    ///
    /// Returns `true` if a session was removed, `false` if the token was unknown.
    pub fn expire(&self, token: &str) -> bool {
        self.sessions.lock().unwrap().remove(token).is_some()
    }

    /// Removes every session that is past its expiry time or has been idle
    /// for longer than `idle_timeout`, and returns how many were removed.
    pub fn cleanup(&self) -> usize {
        let now = now_secs();
        let mut sessions = self.sessions.lock().unwrap();
        let before = sessions.len();
        sessions.retain(|_, s| {
            let not_expired = s.expires_at > now;
            // Saturating for the same clock-skew reason as in `check`.
            let not_idle = now.saturating_sub(s.last_active) <= self.idle_timeout;
            not_expired && not_idle
        });
        before - sessions.len()
    }

    /// Returns the number of sessions currently held in the map.
    ///
    /// Sessions that have lapsed but have not yet been removed by `check`,
    /// `refresh` or `cleanup` are still counted.
    pub fn active_count(&self) -> usize {
        self.sessions.lock().unwrap().len()
    }

    /// Records activity on `token` and recomputes its expiry.
    ///
    /// The expiry can never move past the hard cap of
    /// `created_at + max_lifetime`. Returns `false` when the token is unknown
    /// or the hard cap has been reached; in the latter case the session is
    /// removed.
    ///
    /// NOTE(review): unlike `check`, this does not evaluate the idle timeout,
    /// so refreshing an idle session makes it active again - confirm this is
    /// intended.
    pub fn refresh(&self, token: &str) -> bool {
        let now = now_secs();
        let mut sessions = self.sessions.lock().unwrap();
        if let Some(session) = sessions.get_mut(token) {
            let hard_cap = session.created_at.saturating_add(self.max_lifetime);
            if now >= hard_cap {
                sessions.remove(token);
                return false;
            }
            session.last_active = now;
            let proposed = now.saturating_add(self.max_lifetime);
            session.expires_at = proposed.min(hard_cap);
            true
        } else {
            false
        }
    }
}
// Implements the crate's `Plugin` trait; only `name` is provided here.
impl Plugin for SessionExpiryPlugin {
/// Returns the fixed identifier `"session-expiry"` for this plugin.
fn name(&self) -> &str {
"session-expiry"
}
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of an error.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: fall back to a zero duration.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly tracked token resolves to the user it was registered for.
    #[test]
    fn track_and_check() {
        let expiry = SessionExpiryPlugin::new();
        expiry.track("token-1", "user-1");
        assert_eq!(expiry.check("token-1"), Ok(String::from("user-1")));
    }

    /// Checking a token that was never tracked yields an error.
    #[test]
    fn unknown_token_fails() {
        let expiry = SessionExpiryPlugin::new();
        let outcome = expiry.check("unknown");
        assert!(outcome.is_err());
    }

    /// Explicitly expiring a session reports success and invalidates the token.
    #[test]
    fn expire_session() {
        let expiry = SessionExpiryPlugin::new();
        expiry.track("token-1", "user-1");
        let was_removed = expiry.expire("token-1");
        assert!(was_removed);
        let after = expiry.check("token-1");
        assert!(after.is_err());
    }

    /// The count starts at zero and grows by one per distinct tracked token.
    #[test]
    fn active_count() {
        let expiry = SessionExpiryPlugin::new();
        assert_eq!(expiry.active_count(), 0);
        for (token, user) in [("t1", "u1"), ("t2", "u2")] {
            expiry.track(token, user);
        }
        assert_eq!(expiry.active_count(), 2);
    }

    /// Refreshing a live session succeeds and leaves it checkable.
    #[test]
    fn refresh_extends_lifetime() {
        let expiry = SessionExpiryPlugin::new();
        expiry.track("t1", "u1");
        let refreshed = expiry.refresh("t1");
        assert!(refreshed);
        assert!(expiry.check("t1").is_ok());
    }

    /// Refreshing a token that is not tracked reports `false`.
    #[test]
    fn refresh_unknown_returns_false() {
        let expiry = SessionExpiryPlugin::new();
        let refreshed = expiry.refresh("unknown");
        assert!(!refreshed);
    }

    /// With generous limits, cleanup leaves a freshly tracked session alone.
    #[test]
    fn cleanup_removes_expired() {
        let one_day = Duration::from_secs(86400);
        let expiry = SessionExpiryPlugin::with_timeouts(one_day, one_day);
        expiry.track("t1", "u1");
        assert_eq!(expiry.cleanup(), 0);
        assert_eq!(expiry.active_count(), 1);
    }

    /// A plugin built with custom limits accepts a freshly tracked session.
    #[test]
    fn custom_timeouts() {
        let lifetime = Duration::from_secs(3600);
        let idle = Duration::from_secs(600);
        let expiry = SessionExpiryPlugin::with_timeouts(lifetime, idle);
        expiry.track("t1", "u1");
        assert!(expiry.check("t1").is_ok());
    }
}