use async_trait::async_trait;
use dashmap::DashMap;
use serde::{Serialize, de::DeserializeOwned};
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
/// Asynchronous storage for session payloads, keyed by session id.
///
/// Payloads are handled as raw strings by the required methods; [`SessionStore::get_typed`]
/// and [`SessionStore::set_typed`] layer JSON (de)serialization on top of them.
#[async_trait]
pub trait SessionStore: Clone + Send + Sync + 'static {
    /// Returns the raw payload stored under `session_id`, if any.
    async fn get(&self, session_id: &str) -> Option<String>;

    /// Stores the raw payload `data` under `session_id` with time-to-live `ttl`.
    async fn set(&self, session_id: &str, data: &str, ttl: Duration);

    /// Removes whatever is stored under `session_id`.
    async fn remove(&self, session_id: &str);

    /// Reports whether a session is currently stored under `session_id`.
    async fn exists(&self, session_id: &str) -> bool;

    /// Resets the time-to-live of the session stored under `session_id` to `ttl`.
    async fn refresh(&self, session_id: &str, ttl: Duration);

    /// Fetches the payload under `session_id` and decodes it from JSON into `T`.
    ///
    /// Returns `None` both when nothing is stored and when the stored payload is not valid
    /// JSON for `T`; the decode error itself is discarded.
    async fn get_typed<T: DeserializeOwned + Send + Sync>(&self, session_id: &str) -> Option<T> {
        self.get(session_id)
            .await
            .and_then(|raw| serde_json::from_str::<T>(&raw).ok())
    }

    /// Encodes `data` as JSON and stores it under `session_id` with time-to-live `ttl`.
    ///
    /// Returns `true` once the encoded payload has been handed to [`SessionStore::set`], or
    /// `false` (storing nothing) when serialization fails.
    async fn set_typed<T: Serialize + Send + Sync>(
        &self,
        session_id: &str,
        data: &T,
        ttl: Duration,
    ) -> bool {
        if let Ok(encoded) = serde_json::to_string(data) {
            self.set(session_id, &encoded, ttl).await;
            true
        } else {
            false
        }
    }
}
/// A stored session payload paired with the deadline after which it counts as expired.
#[derive(Clone, Debug)]
struct SessionEntry {
    /// Raw payload exactly as it was passed to `set`.
    data: String,
    /// The entry is considered live only while the current instant is strictly before this.
    expires_at: Instant,
}
/// In-process [`SessionStore`] backed by a concurrent [`DashMap`].
///
/// Cloning is cheap: every clone shares the same underlying map through an [`Arc`], so a
/// session written through one clone is visible through all of them.
///
/// Expired sessions are hidden from lookups but may remain in memory until they are purged,
/// e.g. by the optional background sweeper or by an explicit
/// [`MemorySessionStore::cleanup_expired`] call.
#[derive(Debug, Clone)]
pub struct MemorySessionStore {
    /// Session id -> payload and expiry deadline, shared by all clones.
    storage: Arc<DashMap<String, SessionEntry>>,
    /// Period of the background sweeper, when one is started.
    cleanup_interval: Duration,
    /// Whether a background sweeper was requested at construction time.
    // Never read after construction (the constructor acts on its argument directly); the
    // field is kept so the derived `Debug` output reports the store's configuration.
    #[allow(dead_code)]
    auto_cleanup: bool,
}
impl MemorySessionStore {
pub fn new() -> Self {
Self::with_config(Duration::from_secs(60), true)
}
pub fn with_config(cleanup_interval: Duration, auto_cleanup: bool) -> Self {
let store = Self { storage: Arc::new(DashMap::new()), cleanup_interval, auto_cleanup };
if auto_cleanup {
store.start_cleanup_task();
}
store
}
fn start_cleanup_task(&self) {
let storage = Arc::clone(&self.storage);
let interval = self.cleanup_interval;
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
storage.retain(|_, entry| entry.expires_at > now);
}
});
}
pub fn cleanup_expired(&self) {
let now = Instant::now();
self.storage.retain(|_, entry| entry.expires_at > now);
}
pub fn len(&self) -> usize {
self.storage.len()
}
pub fn is_empty(&self) -> bool {
self.storage.is_empty()
}
pub fn clear(&self) {
self.storage.clear();
}
}
impl Default for MemorySessionStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SessionStore for MemorySessionStore {
async fn get(&self, session_id: &str) -> Option<String> {
self.storage.get(session_id).filter(|entry| entry.expires_at > Instant::now()).map(|entry| entry.data.clone())
}
async fn set(&self, session_id: &str, data: &str, ttl: Duration) {
let entry = SessionEntry { data: data.to_string(), expires_at: Instant::now() + ttl };
self.storage.insert(session_id.to_string(), entry);
}
async fn remove(&self, session_id: &str) {
self.storage.remove(session_id);
}
async fn exists(&self, session_id: &str) -> bool {
self.storage.get(session_id).map(|entry| entry.expires_at > Instant::now()).unwrap_or(false)
}
async fn refresh(&self, session_id: &str, ttl: Duration) {
if let Some(mut entry) = self.storage.get_mut(session_id) {
entry.expires_at = Instant::now() + ttl;
}
}
}