use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
use crate::error::FrameworkError;
/// In-memory state of a single session, as read from and written to a `SessionStore`.
#[derive(Clone, Debug, Default)]
pub struct SessionData {
    /// Session identifier (presumably the id used with `SessionStore::read` / `destroy` — confirm).
    pub id: String,
    /// Arbitrary key/value payload stored as JSON values. Flash entries also live here,
    /// under keys prefixed with `_flash.new.` and `_flash.old.`.
    pub data: HashMap<String, serde_json::Value>,
    /// Id of the user attached to this session, if any
    /// (presumably the authenticated user — confirm against callers).
    pub user_id: Option<i64>,
    /// Token stored with the session, presumably used for CSRF validation.
    pub csrf_token: String,
    /// Set by the mutating methods (`put`, `forget`, `flush`, `age_flash_data`) to signal
    /// unsaved changes; read by `is_dirty` and cleared by `mark_clean`.
    pub dirty: bool,
}
impl SessionData {
    /// Key prefix for flash values written during the current cycle.
    const FLASH_NEW_PREFIX: &'static str = "_flash.new.";
    /// Key prefix for flash values carried over from the previous cycle.
    const FLASH_OLD_PREFIX: &'static str = "_flash.old.";

    /// Creates a clean session with the given `id` and `csrf_token`.
    ///
    /// The data map starts empty and no user is attached.
    pub fn new(id: String, csrf_token: String) -> Self {
        Self {
            id,
            data: HashMap::new(),
            user_id: None,
            csrf_token,
            dirty: false,
        }
    }

    /// Looks up `key` and deserializes its stored JSON value into `T`.
    ///
    /// Returns `None` when the key is absent or when the stored value cannot be
    /// deserialized into `T`.
    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        // `&serde_json::Value` implements `Deserializer`, so deserialize straight from the
        // borrowed value instead of cloning the whole JSON tree first.
        self.data.get(key).and_then(|v| T::deserialize(v).ok())
    }

    /// Serializes `value` to JSON and stores it under `key`, marking the session dirty.
    ///
    /// Best-effort: if `value` cannot be serialized, nothing is stored and the session is
    /// not marked dirty.
    pub fn put<T: Serialize>(&mut self, key: &str, value: T) {
        if let Ok(v) = serde_json::to_value(value) {
            self.data.insert(key.to_string(), v);
            self.dirty = true;
        }
    }

    /// Removes `key` and returns its previous value, if any.
    ///
    /// The session is marked dirty only when a value was actually removed, so forgetting
    /// an absent key does not flag the session as changed.
    pub fn forget(&mut self, key: &str) -> Option<serde_json::Value> {
        let removed = self.data.remove(key);
        if removed.is_some() {
            self.dirty = true;
        }
        removed
    }

    /// Returns `true` if a value is stored under `key` (regardless of its type).
    pub fn has(&self, key: &str) -> bool {
        self.data.contains_key(key)
    }

    /// Stores `value` as a flash entry for `key`.
    ///
    /// It becomes readable via [`get_flash`](Self::get_flash) only after
    /// [`age_flash_data`](Self::age_flash_data) has run.
    pub fn flash<T: Serialize>(&mut self, key: &str, value: T) {
        self.put(&format!("{}{key}", Self::FLASH_NEW_PREFIX), value);
    }

    /// Reads and consumes the aged flash entry for `key`.
    ///
    /// Returns `None` if there is no aged entry or it cannot be deserialized into `T`.
    /// In the latter case the entry is left in place.
    pub fn get_flash<T: DeserializeOwned>(&mut self, key: &str) -> Option<T> {
        let flash_key = format!("{}{key}", Self::FLASH_OLD_PREFIX);
        let value = self.get(&flash_key)?;
        self.forget(&flash_key);
        Some(value)
    }

    /// Advances the flash lifecycle by one step (presumably once per request — confirm).
    ///
    /// Entries already carried over (`_flash.old.*`) are dropped. Entries flashed since
    /// the last call (`_flash.new.*`) are renamed to `_flash.old.*` so that
    /// [`get_flash`](Self::get_flash) can read them. The session is marked dirty only if
    /// some flash key was dropped or renamed.
    pub fn age_flash_data(&mut self) {
        // Drop the previous generation first so the renamed keys cannot collide with it.
        let before = self.data.len();
        self.data
            .retain(|k, _| !k.starts_with(Self::FLASH_OLD_PREFIX));
        let had_old = self.data.len() != before;

        // Rewrite only the leading prefix: `str::replace` would also rewrite any
        // `_flash.new.` sequence inside the caller-supplied part of the key, making the
        // entry unreachable through `get_flash`.
        let renames: Vec<(String, String)> = self
            .data
            .keys()
            .filter_map(|k| {
                k.strip_prefix(Self::FLASH_NEW_PREFIX)
                    .map(|rest| (k.clone(), format!("{}{rest}", Self::FLASH_OLD_PREFIX)))
            })
            .collect();
        let had_new = !renames.is_empty();
        for (new_key, old_key) in renames {
            if let Some(value) = self.data.remove(&new_key) {
                self.data.insert(old_key, value);
            }
        }

        if had_old || had_new {
            self.dirty = true;
        }
    }

    /// Clears all stored data and detaches the user, marking the session dirty.
    ///
    /// The session `id` and `csrf_token` are left unchanged.
    pub fn flush(&mut self) {
        self.data.clear();
        self.user_id = None;
        self.dirty = true;
    }

    /// Returns `true` if the session has been modified since it was created or last
    /// marked clean.
    pub fn is_dirty(&self) -> bool {
        self.dirty
    }

    /// Clears the dirty flag (presumably after the session has been persisted — confirm).
    pub fn mark_clean(&mut self) {
        self.dirty = false;
    }
}
/// Persistence backend for [`SessionData`].
///
/// Implementors must be `Send + Sync`. Every operation is async and reports failure
/// as a [`FrameworkError`].
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Loads the session identified by `id`.
    ///
    /// `Ok(None)` presumably means no session with that id exists (confirm per driver).
    async fn read(&self, id: &str) -> Result<Option<SessionData>, FrameworkError>;

    /// Persists `session`.
    async fn write(&self, session: &SessionData) -> Result<(), FrameworkError>;

    /// Deletes the session identified by `id`.
    async fn destroy(&self, id: &str) -> Result<(), FrameworkError>;

    /// Runs garbage collection over stored sessions.
    ///
    /// The returned count is presumably the number of sessions removed (confirm per driver).
    async fn gc(&self) -> Result<u64, FrameworkError>;

    /// Intended to delete every session of the given user, optionally sparing one
    /// session id, and return a count.
    ///
    /// # Errors
    ///
    /// The default implementation always returns an internal error. Drivers that support
    /// per-user lookup must override it.
    async fn destroy_for_user(
        &self,
        _user_id: i64,
        _except_session_id: Option<&str>,
    ) -> Result<u64, FrameworkError> {
        let reason = String::from("destroy_for_user not supported by this session driver");
        Err(FrameworkError::internal(reason))
    }
}