use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use uvb_core::TenantId;
use uvb_storage_api::{SessionError, SessionRecord, SessionStore};
/// Process-local [`SessionStore`] backed by a `HashMap` keyed by session id.
///
/// The map lives behind `Arc<RwLock<..>>` (a tokio `RwLock`). Cloning the store
/// (via the derived `Clone`) produces another handle to the *same* map, so all
/// clones observe each other's writes. Nothing is persisted outside memory.
#[derive(Clone)]
pub struct InMemorySessionStore {
    // Session id -> record. Shared between clones through the `Arc`.
    sessions: Arc<RwLock<HashMap<String, SessionRecord>>>,
}
impl InMemorySessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for InMemorySessionStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Stores `record` under `record.id` and returns that id.
    ///
    /// NOTE(review): an existing session with the same id is silently replaced
    /// rather than rejected; confirm that callers guarantee unique ids.
    async fn create(&self, record: SessionRecord) -> Result<String, SessionError> {
        let id = record.id.clone();
        self.sessions.write().await.insert(id.clone(), record);
        Ok(id)
    }

    /// Looks up a session by id.
    ///
    /// Returns `Ok(None)` when no session has this id, and
    /// `Err(SessionError::Expired)` when its `expires_at` is at or before the
    /// current time. Expired records are left in place; see `cleanup_expired`.
    /// Otherwise, returns a clone of the stored record.
    async fn get(&self, id: &str) -> Result<Option<SessionRecord>, SessionError> {
        let sessions = self.sessions.read().await;
        match sessions.get(id) {
            None => Ok(None),
            Some(session) if session.expires_at <= SystemTime::now() => {
                Err(SessionError::Expired)
            }
            Some(session) => Ok(Some(session.clone())),
        }
    }

    /// Replaces the stored record that has the same `id` as `record`.
    ///
    /// # Errors
    /// `SessionError::NotFound` if no session with that id exists. Expiry is
    /// not checked here.
    async fn update(&self, record: SessionRecord) -> Result<(), SessionError> {
        let mut sessions = self.sessions.write().await;
        // One lookup: overwrite the existing slot in place rather than
        // `contains_key` followed by `insert`.
        let slot = sessions.get_mut(&record.id).ok_or(SessionError::NotFound)?;
        *slot = record;
        Ok(())
    }

    /// Removes the session with the given id.
    ///
    /// # Errors
    /// `SessionError::NotFound` if no such session exists.
    async fn delete(&self, id: &str) -> Result<(), SessionError> {
        // The write guard is a temporary, so the lock is released at the end
        // of this statement, before the result is built.
        let removed = self.sessions.write().await.remove(id);
        removed.map(|_| ()).ok_or(SessionError::NotFound)
    }

    /// Removes every session, expired or not, belonging to `user_id` within
    /// `tenant_id`, and returns how many were removed.
    async fn delete_by_user(
        &self,
        user_id: &str,
        tenant_id: &TenantId,
    ) -> Result<usize, SessionError> {
        let mut sessions = self.sessions.write().await;
        let before = sessions.len();
        // `retain` drops matches in place, with no temporary list of cloned keys.
        sessions.retain(|_, s| !(s.user_id == user_id && &s.tenant_id == tenant_id));
        Ok(before - sessions.len())
    }

    /// Pushes the session's `expires_at` forward by `duration`.
    ///
    /// The new expiry is measured from the current `expires_at`, not from now.
    /// If the sum is not representable as a `SystemTime`, the expiry is left
    /// unchanged and `Ok(())` is still returned.
    /// NOTE(review): already-expired sessions are extended as well, possibly
    /// remaining expired; confirm this matches the trait contract.
    ///
    /// # Errors
    /// `SessionError::NotFound` if no such session exists.
    async fn extend(&self, id: &str, duration: Duration) -> Result<(), SessionError> {
        let mut sessions = self.sessions.write().await;
        let session = sessions.get_mut(id).ok_or(SessionError::NotFound)?;
        if let Some(extended) = session.expires_at.checked_add(duration) {
            session.expires_at = extended;
        }
        Ok(())
    }

    /// Sets the session's `last_activity_at` to the current time.
    ///
    /// `expires_at` is not modified.
    ///
    /// # Errors
    /// `SessionError::NotFound` if no such session exists.
    async fn touch(&self, id: &str) -> Result<(), SessionError> {
        let mut sessions = self.sessions.write().await;
        let session = sessions.get_mut(id).ok_or(SessionError::NotFound)?;
        session.last_activity_at = SystemTime::now();
        Ok(())
    }

    /// Deletes every session whose `expires_at` is at or before now, and
    /// returns how many were deleted.
    async fn cleanup_expired(&self) -> Result<usize, SessionError> {
        let mut sessions = self.sessions.write().await;
        // Sampled after acquiring the lock, so the cutoff reflects when the
        // purge actually runs.
        let now = SystemTime::now();
        let before = sessions.len();
        sessions.retain(|_, s| s.expires_at > now);
        Ok(before - sessions.len())
    }

    /// Returns clones of the non-expired sessions for `user_id` within
    /// `tenant_id`.
    ///
    /// Order is unspecified because it follows `HashMap` iteration order.
    async fn list_by_user(
        &self,
        user_id: &str,
        tenant_id: &TenantId,
    ) -> Result<Vec<SessionRecord>, SessionError> {
        let sessions = self.sessions.read().await;
        let now = SystemTime::now();
        Ok(sessions
            .values()
            .filter(|s| s.user_id == user_id && &s.tenant_id == tenant_id && s.expires_at > now)
            .cloned()
            .collect())
    }
}