use super::backends::SessionBackend;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use uuid::Uuid;
/// A key/value session whose contents are persisted through a [`SessionBackend`].
///
/// Values are kept in memory as `serde_json::Value`s and handed to the backend
/// as a whole map when the session is saved.
#[derive(Clone)]
pub struct Session<B>
where
    B: SessionBackend,
{
    /// Storage backend used to load, save and delete session data.
    backend: B,
    /// Key identifying this session in the backend; `None` until one is assigned.
    session_key: Option<String>,
    /// In-memory session contents, keyed by entry name.
    data: HashMap<String, Value>,
    /// Whether the session has changes that have not yet been saved.
    is_modified: bool,
    /// Whether the session contents have been read or written.
    is_accessed: bool,
    /// Timestamp of the most recently recorded activity, in UTC.
    last_activity: Option<chrono::DateTime<chrono::Utc>>,
    /// Idle timeout in seconds; also passed to the backend as the TTL on save.
    timeout: u64,
}
impl<B: SessionBackend> Session<B> {
pub fn new(backend: B) -> Self {
Self {
backend,
session_key: None,
data: HashMap::new(),
is_modified: false,
is_accessed: false,
last_activity: Some(chrono::Utc::now()),
timeout: 1800, }
}
pub async fn from_key(
backend: B,
session_key: String,
) -> Result<Self, super::backends::SessionError> {
let data: HashMap<String, Value> = backend
.load(&session_key)
.await?
.unwrap_or_else(HashMap::new);
Ok(Self {
backend,
session_key: Some(session_key),
data,
is_modified: false,
is_accessed: true,
last_activity: Some(chrono::Utc::now()),
timeout: 1800, })
}
pub fn get<T>(&mut self, key: &str) -> Result<Option<T>, serde_json::Error>
where
T: for<'de> Deserialize<'de>,
{
self.is_accessed = true;
self.last_activity = Some(chrono::Utc::now());
self.data
.get(key)
.map(|v| serde_json::from_value(v.clone()))
.transpose()
}
pub fn set<T>(&mut self, key: &str, value: T) -> Result<(), serde_json::Error>
where
T: Serialize,
{
let json_value = serde_json::to_value(value)?;
self.data.insert(key.to_string(), json_value);
self.is_modified = true;
self.is_accessed = true;
self.last_activity = Some(chrono::Utc::now());
Ok(())
}
pub fn delete(&mut self, key: &str) -> Option<Value> {
self.is_modified = true;
self.is_accessed = true;
self.data.remove(key)
}
pub fn contains_key(&self, key: &str) -> bool {
self.data.contains_key(key)
}
pub fn get_or_create_key(&mut self) -> &str {
if self.session_key.is_none() {
self.session_key = Some(Self::generate_key());
}
self.session_key.as_ref().unwrap()
}
pub fn generate_key() -> String {
Uuid::new_v4().to_string()
}
pub async fn flush(&mut self) -> Result<(), super::backends::SessionError> {
if let Some(old_key) = &self.session_key {
self.backend.delete(old_key).await?;
}
self.data.clear();
self.session_key = Some(Self::generate_key());
self.is_modified = true;
Ok(())
}
pub async fn cycle_key(&mut self) -> Result<(), super::backends::SessionError> {
if let Some(old_key) = &self.session_key {
self.backend.delete(old_key).await?;
}
self.session_key = Some(Self::generate_key());
self.is_modified = true;
Ok(())
}
pub async fn regenerate_id(&mut self) -> Result<(), super::backends::SessionError> {
self.cycle_key().await
}
pub async fn save(&mut self) -> Result<(), super::backends::SessionError> {
if !self.is_modified {
return Ok(());
}
let key = self.get_or_create_key().to_string();
self.backend
.save(&key, &self.data, Some(self.timeout))
.await?;
self.is_modified = false;
Ok(())
}
pub fn is_modified(&self) -> bool {
self.is_modified
}
pub fn is_accessed(&self) -> bool {
self.is_accessed
}
pub fn session_key(&self) -> Option<&str> {
self.session_key.as_deref()
}
pub fn keys(&mut self) -> Vec<String> {
self.is_accessed = true;
self.data.keys().cloned().collect()
}
pub fn values(&mut self) -> Vec<&Value> {
self.is_accessed = true;
self.data.values().collect()
}
pub fn items(&mut self) -> Vec<(&String, &Value)> {
self.is_accessed = true;
self.data.iter().collect()
}
pub fn clear(&mut self) {
self.data.clear();
self.is_modified = true;
self.is_accessed = true;
}
pub fn mark_modified(&mut self) {
self.is_modified = true;
}
pub fn mark_unmodified(&mut self) {
self.is_modified = false;
}
pub fn set_timeout(&mut self, timeout: u64) {
self.timeout = timeout;
}
pub fn get_timeout(&self) -> u64 {
self.timeout
}
pub fn update_activity(&mut self) {
self.last_activity = Some(chrono::Utc::now());
self.is_modified = true;
}
pub fn get_last_activity(&self) -> Option<chrono::DateTime<chrono::Utc>> {
self.last_activity
}
pub fn is_timed_out(&self) -> bool {
if let Some(last_activity) = self.last_activity {
let now = chrono::Utc::now();
let elapsed = now.signed_duration_since(last_activity);
elapsed.num_seconds() as u64 > self.timeout
} else {
false
}
}
pub fn validate_timeout(&self) -> Result<(), super::backends::SessionError> {
if self.is_timed_out() {
Err(super::backends::SessionError::SessionExpired)
} else {
Ok(())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sessions::InMemorySessionBackend;
    use std::sync::Arc;

    /// Builds a fresh, unsaved session over an empty in-memory backend.
    fn new_session() -> Session<InMemorySessionBackend> {
        Session::new(InMemorySessionBackend::new())
    }

    #[tokio::test]
    async fn test_session_set_get() {
        let mut session = new_session();
        session.set("user_id", 42).unwrap();
        let user_id: i32 = session.get("user_id").unwrap().unwrap();
        assert_eq!(user_id, 42);
    }

    #[tokio::test]
    async fn test_session_delete() {
        let mut session = new_session();
        session.set("key", "value").unwrap();
        assert!(session.contains_key("key"));
        // delete hands back the removed raw JSON value.
        assert_eq!(session.delete("key"), Some(Value::from("value")));
        assert!(!session.contains_key("key"));
    }

    #[tokio::test]
    async fn test_session_flush() {
        let mut session = new_session();
        session.set("data", "value").unwrap();
        let old_key = session.get_or_create_key().to_string();
        session.flush().await.unwrap();
        assert!(!session.contains_key("data"));
        assert_ne!(session.get_or_create_key(), old_key);
    }

    #[tokio::test]
    async fn test_session_cycle_key() {
        let mut session = new_session();
        session.set("user_id", 123).unwrap();
        let old_key = session.get_or_create_key().to_string();
        session.cycle_key().await.unwrap();
        // Data survives the key change; only the key differs.
        let user_id: i32 = session.get("user_id").unwrap().unwrap();
        assert_eq!(user_id, 123);
        assert_ne!(session.get_or_create_key(), old_key);
    }

    #[tokio::test]
    async fn test_session_is_modified() {
        let mut session = new_session();
        assert!(!session.is_modified());
        session.set("key", "value").unwrap();
        assert!(session.is_modified());
    }

    #[tokio::test]
    async fn test_session_is_accessed() {
        let mut session = new_session();
        assert!(!session.is_accessed());
        // Reading a missing key still counts as an access.
        assert!(session.get::<String>("key").unwrap().is_none());
        assert!(session.is_accessed());
    }

    #[tokio::test]
    async fn test_session_save() {
        let mut session = new_session();
        session.set("data", "test_value").unwrap();
        session.save().await.unwrap();
        let key = session.session_key().unwrap().to_string();
        let loaded: Option<HashMap<String, Value>> = session.backend.load(&key).await.unwrap();
        let loaded = loaded.expect("saved session should be present in the backend");
        assert_eq!(loaded.get("data"), Some(&Value::from("test_value")));
    }

    #[tokio::test]
    async fn test_session_keys() {
        let mut session = new_session();
        session.set("user_id", 123).unwrap();
        session.set("username", "alice").unwrap();
        session.set("role", "admin").unwrap();
        let keys = session.keys();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"user_id".to_string()));
        assert!(keys.contains(&"username".to_string()));
        assert!(keys.contains(&"role".to_string()));
    }

    #[tokio::test]
    async fn test_session_keys_marks_accessed() {
        let mut session = new_session();
        assert!(!session.is_accessed());
        session.set("key", "value").unwrap();
        // Reset so the assertion below observes only the effect of keys().
        session.is_accessed = false;
        let _keys = session.keys();
        assert!(session.is_accessed());
    }

    #[tokio::test]
    async fn test_session_values() {
        let mut session = new_session();
        session.set("count", 42).unwrap();
        session.set("name", "test").unwrap();
        let values = session.values();
        assert_eq!(values.len(), 2);
        assert!(values.iter().any(|v| v.as_i64() == Some(42)));
        assert!(values.iter().any(|v| v.as_str() == Some("test")));
    }

    #[tokio::test]
    async fn test_session_values_marks_accessed() {
        let mut session = new_session();
        assert!(!session.is_accessed());
        session.set("key", "value").unwrap();
        session.is_accessed = false;
        let _values = session.values();
        assert!(session.is_accessed());
    }

    #[tokio::test]
    async fn test_session_items() {
        let mut session = new_session();
        session.set("user_id", 123).unwrap();
        session.set("role", "admin").unwrap();
        let items = session.items();
        assert_eq!(items.len(), 2);
        let user_id_item = items.iter().find(|(k, _)| k.as_str() == "user_id").unwrap();
        assert_eq!(user_id_item.1.as_i64().unwrap(), 123);
        let role_item = items.iter().find(|(k, _)| k.as_str() == "role").unwrap();
        assert_eq!(role_item.1.as_str().unwrap(), "admin");
    }

    #[tokio::test]
    async fn test_session_items_marks_accessed() {
        let mut session = new_session();
        assert!(!session.is_accessed());
        session.set("key", "value").unwrap();
        session.is_accessed = false;
        let _items = session.items();
        assert!(session.is_accessed());
    }

    #[tokio::test]
    async fn test_session_clear() {
        let mut session = new_session();
        session.set("user_id", 123).unwrap();
        session.set("username", "alice").unwrap();
        assert_eq!(session.keys().len(), 2);
        assert!(session.is_modified());
        session.clear();
        assert_eq!(session.keys().len(), 0);
        assert!(session.is_modified());
        assert!(session.is_accessed());
    }

    #[tokio::test]
    async fn test_session_clear_preserves_session_key() {
        let mut session = new_session();
        session.set("data", "value").unwrap();
        let session_key = session.get_or_create_key().to_string();
        session.clear();
        assert_eq!(session.get_or_create_key(), session_key);
    }

    #[tokio::test]
    async fn test_session_mark_modified() {
        let mut session = new_session();
        assert!(!session.is_modified());
        session.mark_modified();
        assert!(session.is_modified());
    }

    #[tokio::test]
    async fn test_session_mark_unmodified() {
        let mut session = new_session();
        session.set("key", "value").unwrap();
        assert!(session.is_modified());
        session.mark_unmodified();
        assert!(!session.is_modified());
    }

    /// Test double that delegates to an in-memory backend and records the
    /// `ttl` argument of the most recent `save` call.
    ///
    /// The outer `Option` of the recorded value is `None` until `save` has
    /// been called at least once. Clones share the same recording slot via `Arc`.
    #[derive(Clone)]
    struct TtlRecordingBackend {
        inner: InMemorySessionBackend,
        recorded_ttl: Arc<std::sync::Mutex<Option<Option<u64>>>>,
    }

    impl TtlRecordingBackend {
        fn new() -> Self {
            Self {
                inner: InMemorySessionBackend::new(),
                recorded_ttl: Arc::new(std::sync::Mutex::new(None)),
            }
        }

        /// Returns the last recorded TTL (`None` if `save` was never called).
        fn recorded_ttl(&self) -> Option<Option<u64>> {
            *self.recorded_ttl.lock().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl super::super::backends::SessionBackend for TtlRecordingBackend {
        async fn load<T>(
            &self,
            session_key: &str,
        ) -> Result<Option<T>, super::super::backends::SessionError>
        where
            T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
        {
            self.inner.load(session_key).await
        }

        async fn save<T>(
            &self,
            session_key: &str,
            data: &T,
            ttl: Option<u64>,
        ) -> Result<(), super::super::backends::SessionError>
        where
            T: Serialize + Send + Sync,
        {
            // The std mutex guard is a temporary dropped at the end of this
            // statement, so it is never held across the `.await` below.
            *self.recorded_ttl.lock().unwrap() = Some(ttl);
            self.inner.save(session_key, data, ttl).await
        }

        async fn delete(
            &self,
            session_key: &str,
        ) -> Result<(), super::super::backends::SessionError> {
            self.inner.delete(session_key).await
        }

        async fn exists(
            &self,
            session_key: &str,
        ) -> Result<bool, super::super::backends::SessionError> {
            self.inner.exists(session_key).await
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_save_forwards_configured_ttl_to_backend() {
        let backend = TtlRecordingBackend::new();
        let mut session = Session::new(backend.clone());
        session.set_timeout(7200);
        session.set("key", "value").unwrap();
        session.save().await.unwrap();
        let recorded = backend.recorded_ttl();
        assert_eq!(
            recorded,
            Some(Some(7200)),
            "Session::save() should forward the configured timeout (7200) to backend.save()"
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_save_forwards_default_ttl_to_backend() {
        let backend = TtlRecordingBackend::new();
        let mut session = Session::new(backend.clone());
        session.set("key", "value").unwrap();
        session.save().await.unwrap();
        let recorded = backend.recorded_ttl();
        assert_eq!(
            recorded,
            Some(Some(1800)),
            "Session::save() should forward the default timeout (1800) to backend.save()"
        );
    }

    #[tokio::test]
    async fn test_session_mark_unmodified_prevents_save() {
        let mut session = new_session();
        session.set("data", "value").unwrap();
        assert!(session.is_modified());
        session.mark_unmodified();
        session.save().await.unwrap();
        // save() was a no-op, so no key was ever generated.
        assert!(session.session_key().is_none());
    }
}