use dashmap::DashMap;
use mcpkit_core::capability::ClientCapabilities;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
/// Per-client state kept for one session; stored by id inside [`SessionStore`].
#[derive(Clone, Debug)]
pub struct Session {
    /// Identifier the session was created with.
    pub id: String,
    /// Instant at which [`Session::new`] built this session.
    pub created_at: Instant,
    /// Last recorded activity; [`Session::is_expired`] measures idle time from here.
    pub last_active: Instant,
    /// `false` until [`Session::mark_initialized`] is called.
    pub initialized: bool,
    /// Capabilities recorded by [`Session::mark_initialized`], if any were supplied.
    pub client_capabilities: Option<ClientCapabilities>,
}
impl Session {
    /// Builds an uninitialized session with no capabilities.
    ///
    /// Both `created_at` and `last_active` are set to the same "now" instant.
    #[must_use]
    pub fn new(id: String) -> Self {
        let created_at = Instant::now();
        Self {
            id,
            created_at,
            last_active: created_at,
            initialized: false,
            client_capabilities: None,
        }
    }

    /// Reports whether the session has been idle for at least `timeout`.
    ///
    /// The comparison is inclusive, so a zero `timeout` always reports expired.
    #[must_use]
    pub fn is_expired(&self, timeout: Duration) -> bool {
        let idle_for = self.last_active.elapsed();
        idle_for >= timeout
    }

    /// Records activity by moving `last_active` to the current instant.
    pub fn touch(&mut self) {
        let now = Instant::now();
        self.last_active = now;
    }

    /// Flags the session as initialized and stores `capabilities`.
    ///
    /// Passing `None` clears any previously stored capabilities.
    pub fn mark_initialized(&mut self, capabilities: Option<ClientCapabilities>) {
        self.client_capabilities = capabilities;
        self.initialized = true;
    }
}
/// Routes outgoing `String` messages to sessions, one `tokio` broadcast channel per session id.
pub struct SessionManager {
    /// Session id → sending half of that session's broadcast channel.
    sessions: DashMap<String, broadcast::Sender<String>>,
}
impl Default for SessionManager {
fn default() -> Self {
Self::new()
}
}
impl SessionManager {
    /// Buffer size of every per-session broadcast channel.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a manager with no registered sessions.
    #[must_use]
    pub fn new() -> Self {
        let sessions = DashMap::new();
        Self { sessions }
    }

    /// Registers a new session under a random UUID v4 id.
    ///
    /// Returns the id together with a receiver already subscribed to the session's channel.
    #[must_use]
    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
        let (sender, receiver) = broadcast::channel(Self::CHANNEL_CAPACITY);
        let session_id = uuid::Uuid::new_v4().to_string();
        self.sessions.insert(session_id.clone(), sender);
        (session_id, receiver)
    }

    /// Subscribes an additional receiver to an existing session's channel.
    ///
    /// Returns `None` when no session is registered under `id`. The new receiver only
    /// sees messages sent after this call.
    #[must_use]
    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
        self.sessions
            .get(id)
            .map(|entry| entry.value().subscribe())
    }

    /// Sends `message` to the session registered under `id`.
    ///
    /// Returns `true` whenever the session exists, even when nothing is delivered.
    /// The channel's send error (raised when no receiver is currently subscribed)
    /// is deliberately ignored. Returns `false` only for an unknown id.
    #[must_use]
    pub fn send_to_session(&self, id: &str, message: String) -> bool {
        match self.sessions.get(id) {
            Some(sender) => {
                // Best effort: a missing subscriber is not treated as a failure.
                let _ = sender.send(message);
                true
            }
            None => false,
        }
    }

    /// Sends a clone of `message` to every registered session, ignoring delivery failures.
    pub fn broadcast(&self, message: String) {
        self.sessions.iter().for_each(|entry| {
            let _ = entry.value().send(message.clone());
        });
    }

    /// Unregisters the session and drops its channel sender; unknown ids are ignored.
    pub fn remove_session(&self, id: &str) {
        let _ = self.sessions.remove(id);
    }

    /// Number of currently registered sessions.
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }
}
/// Concurrent map of [`Session`]s keyed by id, with idle-timeout based expiry.
pub struct SessionStore {
    /// Session id → session state.
    sessions: DashMap<String, Session>,
    /// Idle duration after which [`SessionStore::cleanup_expired`] drops a session.
    timeout: Duration,
}
impl SessionStore {
    /// Idle timeout used by [`SessionStore::with_default_timeout`] (one hour).
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3600);

    /// Creates an empty store whose sessions expire after `timeout` of inactivity.
    #[must_use]
    pub fn new(timeout: Duration) -> Self {
        Self {
            sessions: DashMap::new(),
            timeout,
        }
    }

    /// Creates an empty store that uses [`SessionStore::DEFAULT_TIMEOUT`].
    #[must_use]
    pub fn with_default_timeout() -> Self {
        Self::new(Self::DEFAULT_TIMEOUT)
    }

    /// Creates and stores a fresh [`Session`] under a random UUID v4 id, then returns that id.
    #[must_use]
    pub fn create(&self) -> String {
        let id = uuid::Uuid::new_v4().to_string();
        self.sessions.insert(id.clone(), Session::new(id.clone()));
        id
    }

    /// Returns a cloned snapshot of the session, if present.
    ///
    /// Expiry is not checked here. A session past its timeout is still returned until
    /// [`SessionStore::cleanup_expired`] removes it.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<Session> {
        self.sessions.get(id).map(|r| r.clone())
    }

    /// Refreshes the session's activity timestamp; unknown ids are ignored.
    pub fn touch(&self, id: &str) {
        if let Some(mut session) = self.sessions.get_mut(id) {
            session.touch();
        }
    }

    /// Applies `f` to the stored session in place.
    ///
    /// `f` is not called when no session exists under `id`.
    pub fn update<F>(&self, id: &str, f: F)
    where
        F: FnOnce(&mut Session),
    {
        if let Some(mut session) = self.sessions.get_mut(id) {
            f(&mut session);
        }
    }

    /// Removes every session whose idle time has reached the store's timeout.
    pub fn cleanup_expired(&self) {
        let timeout = self.timeout;
        self.sessions.retain(|_, s| !s.is_expired(timeout));
    }

    /// Removes and returns the session, if present.
    #[must_use]
    pub fn remove(&self, id: &str) -> Option<Session> {
        self.sessions.remove(id).map(|(_, s)| s)
    }

    /// Number of stored sessions, including expired ones not yet cleaned up.
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Spawns a background task that calls [`SessionStore::cleanup_expired`] every `interval`.
    ///
    /// The task holds only a weak reference to the store, so it does not keep the store
    /// alive. Once the last `Arc<SessionStore>` is dropped, the task exits at its next
    /// wake-up.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because `tokio::spawn` requires one.
    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
        // A strong `Arc` moved into this endless loop would keep the store (and every
        // session in it) alive forever, and the task could never stop.
        let weak_store = Arc::downgrade(self);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                // Upgrade only for the duration of one sweep; the strong handle is
                // dropped before the next sleep so it never pins the store.
                match weak_store.upgrade() {
                    Some(store) => store.cleanup_expired(),
                    None => break,
                }
            }
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-123".to_string());
        assert_eq!(session.id, "test-123");
        assert!(!session.initialized);
        assert!(session.client_capabilities.is_none());
    }

    #[test]
    fn test_session_expiry() {
        let mut session = Session::new("test".to_string());
        assert!(!session.is_expired(Duration::from_secs(60)));
        session.last_active = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .unwrap();
        assert!(session.is_expired(Duration::from_secs(60)));
    }

    #[test]
    fn test_session_touch_resets_expiry() {
        let mut session = Session::new("touch".to_string());
        session.last_active = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .unwrap();
        assert!(session.is_expired(Duration::from_secs(60)));
        session.touch();
        assert!(!session.is_expired(Duration::from_secs(60)));
    }

    #[test]
    fn test_session_mark_initialized() {
        let mut session = Session::new("init".to_string());
        session.mark_initialized(None);
        assert!(session.initialized);
        assert!(session.client_capabilities.is_none());
    }

    #[test]
    fn test_session_store() {
        let store = SessionStore::new(Duration::from_secs(60));
        let id = store.create();
        assert!(store.get(&id).is_some());
        store.touch(&id);
        // Touching an unknown id is a harmless no-op.
        store.touch("missing");
        assert!(store.remove(&id).is_some());
        assert!(store.get(&id).is_none());
        assert!(store.remove(&id).is_none());
    }

    #[test]
    fn test_session_store_update() {
        let store = SessionStore::with_default_timeout();
        let id = store.create();
        store.update(&id, |s| s.mark_initialized(None));
        assert!(store.get(&id).expect("session should exist").initialized);
        // Updating an unknown id must not create a session.
        store.update("missing", |s| s.initialized = true);
        assert_eq!(store.session_count(), 1);
    }

    #[test]
    fn test_session_store_cleanup_expired() {
        // With a zero timeout every session is immediately expired.
        let expiring = SessionStore::new(Duration::ZERO);
        let _ = expiring.create();
        expiring.cleanup_expired();
        assert_eq!(expiring.session_count(), 0);

        let lasting = SessionStore::with_default_timeout();
        let _ = lasting.create();
        lasting.cleanup_expired();
        assert_eq!(lasting.session_count(), 1);
    }

    #[tokio::test]
    async fn test_session_manager() {
        let manager = SessionManager::new();
        let (id, mut rx) = manager.create_session();
        assert!(manager.send_to_session(&id, "test message".to_string()));
        let msg = rx.recv().await.expect("Should receive message");
        assert_eq!(msg, "test message");
        manager.remove_session(&id);
        assert!(!manager.send_to_session(&id, "another".to_string()));
    }

    #[tokio::test]
    async fn test_session_manager_broadcast() {
        let manager = SessionManager::new();
        let (_, mut rx_a) = manager.create_session();
        let (id_b, _rx_b_initial) = manager.create_session();
        let mut rx_b = manager
            .get_receiver(&id_b)
            .expect("receiver for existing session");
        assert!(manager.get_receiver("missing").is_none());
        assert_eq!(manager.session_count(), 2);

        manager.broadcast("hello".to_string());
        assert_eq!(rx_a.recv().await.expect("session a receives"), "hello");
        assert_eq!(rx_b.recv().await.expect("session b receives"), "hello");
    }
}