use dashmap::DashMap;
use mcpkit_core::capability::ClientCapabilities;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::sync::broadcast;
/// Per-client session state tracked by [`SessionStore`].
#[derive(Debug, Clone)]
pub struct Session {
    /// Unique session identifier.
    pub id: String,
    /// Instant at which the session was constructed.
    pub created_at: Instant,
    /// Last recorded activity; compared against a timeout by [`Session::is_expired`].
    pub last_active: Instant,
    /// Set to `true` by [`Session::mark_initialized`].
    pub initialized: bool,
    /// Capabilities recorded by [`Session::mark_initialized`], if any.
    pub client_capabilities: Option<ClientCapabilities>,
}

impl Session {
    /// Creates an uninitialized session whose creation and last-activity
    /// timestamps are both the current instant.
    #[must_use]
    pub fn new(id: String) -> Self {
        let created_at = Instant::now();
        Self {
            id,
            created_at,
            last_active: created_at,
            initialized: false,
            client_capabilities: None,
        }
    }

    /// Returns `true` once at least `timeout` has passed since the last activity.
    #[must_use]
    pub fn is_expired(&self, timeout: Duration) -> bool {
        let idle_for = self.last_active.elapsed();
        timeout <= idle_for
    }

    /// Records activity by resetting `last_active` to the current instant.
    pub fn touch(&mut self) {
        let now = Instant::now();
        self.last_active = now;
    }

    /// Flags the session as initialized and records the client's capabilities,
    /// replacing any previously stored value.
    pub fn mark_initialized(&mut self, capabilities: Option<ClientCapabilities>) {
        self.client_capabilities = capabilities;
        self.initialized = true;
    }
}
/// An event retained by an [`EventStore`] so it can be replayed later.
#[derive(Debug, Clone)]
pub struct StoredEvent {
    /// Event identifier; used as the replay cursor by [`EventStore::get_events_after`].
    pub id: String,
    /// Caller-supplied event type label.
    pub event_type: String,
    /// Event payload.
    pub data: String,
    /// Instant at which this value was constructed; used for age-based eviction.
    pub stored_at: Instant,
}

impl StoredEvent {
    /// Builds an event stamped with the current instant.
    #[must_use]
    pub fn new(id: String, event_type: impl Into<String>, data: impl Into<String>) -> Self {
        let event_type: String = event_type.into();
        let data: String = data.into();
        let stored_at = Instant::now();
        Self {
            id,
            event_type,
            data,
            stored_at,
        }
    }
}
/// Retention limits for an [`EventStore`].
#[derive(Debug, Clone)]
pub struct EventStoreConfig {
    /// Maximum number of events kept; the oldest are evicted first.
    pub max_events: usize,
    /// Events older than this are evicted on the next insertion or cleanup.
    pub max_age: Duration,
}

impl Default for EventStoreConfig {
    /// 1000 events, retained for at most 300 seconds.
    fn default() -> Self {
        Self::new(1000, Duration::from_secs(300))
    }
}

impl EventStoreConfig {
    /// Builds a configuration from explicit limits.
    #[must_use]
    pub const fn new(max_events: usize, max_age: Duration) -> Self {
        Self { max_events, max_age }
    }

    /// Returns a copy of this configuration with `max_events` replaced.
    #[must_use]
    pub const fn with_max_events(self, max_events: usize) -> Self {
        Self { max_events, ..self }
    }

    /// Returns a copy of this configuration with `max_age` replaced.
    #[must_use]
    pub const fn with_max_age(self, max_age: Duration) -> Self {
        Self { max_age, ..self }
    }
}
/// Bounded, time-limited buffer of [`StoredEvent`]s kept so that a session can
/// replay messages it missed.
///
/// Events are appended in arrival order. Every insertion evicts the oldest
/// events beyond [`EventStoreConfig::max_events`] and any front events older
/// than [`EventStoreConfig::max_age`]; [`EventStore::cleanup_expired`] applies
/// the age limit on demand.
#[derive(Debug)]
pub struct EventStore {
    /// Stored events, oldest at the front.
    events: RwLock<VecDeque<StoredEvent>>,
    /// Retention limits applied on insertion and cleanup.
    config: EventStoreConfig,
    /// Source of the numeric suffix produced by [`EventStore::next_event_id`].
    next_id: AtomicU64,
}

impl EventStore {
    /// Upper bound on the buffer preallocated by [`EventStore::new`].
    ///
    /// Preallocating the full `max_events` would reserve that much memory for
    /// every session up front, and panics with "capacity overflow" for very
    /// large limits such as `usize::MAX`. The deque still grows on demand up
    /// to `max_events`.
    const INITIAL_CAPACITY_LIMIT: usize = 64;

    /// Creates an empty store with the given retention limits.
    #[must_use]
    pub fn new(config: EventStoreConfig) -> Self {
        let initial_capacity = config.max_events.min(Self::INITIAL_CAPACITY_LIMIT);
        Self {
            events: RwLock::new(VecDeque::with_capacity(initial_capacity)),
            config,
            next_id: AtomicU64::new(1),
        }
    }

    /// Creates a store using [`EventStoreConfig::default`].
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(EventStoreConfig::default())
    }

    /// Returns a new identifier of the form `evt-N`. `N` starts at 1 and is
    /// incremented by one on every call for this store.
    #[must_use]
    pub fn next_event_id(&self) -> String {
        // Only uniqueness of the counter value matters, so `Relaxed` suffices.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        format!("evt-{id}")
    }

    /// Stores an event under a freshly generated id and returns that id.
    ///
    /// Like [`EventStore::store`], this is synchronous and may block the
    /// calling thread while the lock is contended.
    pub fn store_auto_id(&self, event_type: impl Into<String>, data: impl Into<String>) -> String {
        let id = self.next_event_id();
        self.store(id.clone(), event_type, data);
        id
    }

    /// Synchronously stores an event, then applies the retention limits.
    ///
    /// # Blocking
    ///
    /// The async lock is acquired with `futures::executor::block_on`, which
    /// parks the calling thread until the lock is available. From async code
    /// prefer [`EventStore::store_async`].
    pub fn store(
        &self,
        id: impl Into<String>,
        event_type: impl Into<String>,
        data: impl Into<String>,
    ) {
        let event = StoredEvent::new(id.into(), event_type, data);
        let mut events = futures::executor::block_on(self.events.write());
        self.push_and_prune(&mut events, event);
    }

    /// Asynchronously stores an event, then applies the retention limits.
    pub async fn store_async(
        &self,
        id: impl Into<String>,
        event_type: impl Into<String>,
        data: impl Into<String>,
    ) {
        let event = StoredEvent::new(id.into(), event_type, data);
        let mut events = self.events.write().await;
        self.push_and_prune(&mut events, event);
    }

    /// Returns every event stored after the event with id `last_event_id`.
    ///
    /// If no stored event has that id (unknown, or already evicted), all
    /// retained events are returned.
    pub async fn get_events_after(&self, last_event_id: &str) -> Vec<StoredEvent> {
        let events = self.events.read().await;
        let start_idx = events
            .iter()
            .position(|e| e.id == last_event_id)
            .map_or(0, |i| i + 1);
        events.iter().skip(start_idx).cloned().collect()
    }

    /// Returns a copy of every retained event, oldest first.
    pub async fn get_all_events(&self) -> Vec<StoredEvent> {
        let events = self.events.read().await;
        events.iter().cloned().collect()
    }

    /// Number of retained events.
    pub async fn len(&self) -> usize {
        self.events.read().await.len()
    }

    /// Whether no events are retained.
    pub async fn is_empty(&self) -> bool {
        self.events.read().await.is_empty()
    }

    /// Removes every retained event. The id counter is not reset.
    pub async fn clear(&self) {
        self.events.write().await.clear();
    }

    /// Evicts events older than the configured `max_age`.
    pub async fn cleanup_expired(&self) {
        let mut events = self.events.write().await;
        self.evict_expired(&mut events);
    }

    /// Appends `event`, then evicts from the front until both the count and
    /// the age limits hold.
    fn push_and_prune(&self, events: &mut VecDeque<StoredEvent>, event: StoredEvent) {
        events.push_back(event);
        while events.len() > self.config.max_events {
            events.pop_front();
        }
        self.evict_expired(events);
    }

    /// Pops front events while they are older than `max_age`.
    ///
    /// Events are appended in arrival order, so scanning stops at the first
    /// front event that is still fresh.
    fn evict_expired(&self, events: &mut VecDeque<StoredEvent>) {
        let now = Instant::now();
        while let Some(front) = events.front() {
            if now.duration_since(front.stored_at) > self.config.max_age {
                events.pop_front();
            } else {
                break;
            }
        }
    }
}
/// Routes messages to sessions over per-session broadcast channels and keeps a
/// per-session [`EventStore`] so clients can replay missed messages.
#[derive(Debug)]
pub struct SessionManager {
    /// Broadcast sender for each session, keyed by session id.
    sessions: DashMap<String, broadcast::Sender<String>>,
    /// Replay buffer for each session, keyed by session id.
    event_stores: DashMap<String, Arc<EventStore>>,
    /// Configuration used for every event store this manager creates.
    event_store_config: EventStoreConfig,
}

impl Default for SessionManager {
    fn default() -> Self {
        Self::new()
    }
}

impl SessionManager {
    /// Capacity passed to `broadcast::channel` for each new session.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a manager whose event stores use [`EventStoreConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_event_store_config(EventStoreConfig::default())
    }

    /// Creates a manager whose event stores use `config`.
    #[must_use]
    pub fn with_event_store_config(config: EventStoreConfig) -> Self {
        Self {
            sessions: DashMap::new(),
            event_stores: DashMap::new(),
            event_store_config: config,
        }
    }

    /// Builds a fresh, empty event store using this manager's configuration.
    fn new_event_store(&self) -> Arc<EventStore> {
        Arc::new(EventStore::new(self.event_store_config.clone()))
    }

    /// Registers a new session with a random UUID v4 id.
    ///
    /// Returns the id and a receiver subscribed to the session's channel.
    #[must_use]
    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
        let id = uuid::Uuid::new_v4().to_string();
        let (tx, rx) = broadcast::channel(Self::CHANNEL_CAPACITY);
        // Register the event store before the sender. Once the session is
        // visible in `sessions`, `broadcast_with_storage` must already find its
        // store; otherwise the message would be delivered but not recorded.
        self.event_stores.insert(id.clone(), self.new_event_store());
        self.sessions.insert(id.clone(), tx);
        (id, rx)
    }

    /// Returns an additional receiver for the session, if it exists.
    #[must_use]
    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
        self.sessions.get(id).map(|tx| tx.subscribe())
    }

    /// Returns a shared handle to the session's event store, if it exists.
    ///
    /// The DashMap guard is released before returning, so the handle may be
    /// used freely, including across `.await`.
    #[must_use]
    pub fn get_event_store(&self, id: &str) -> Option<Arc<EventStore>> {
        self.event_stores.get(id).map(|store| Arc::clone(store.value()))
    }

    /// Sends `message` to the session without recording it.
    ///
    /// Returns `true` if the session exists. A send with no active receivers is
    /// ignored and still counts as success.
    #[must_use]
    pub fn send_to_session(&self, id: &str, message: String) -> bool {
        if let Some(tx) = self.sessions.get(id) {
            let _ = tx.send(message);
            true
        } else {
            false
        }
    }

    /// Records `message` in the session's event store, then sends it to the
    /// session.
    ///
    /// Returns the generated event id, or `None` if the session does not exist.
    /// A missing event store is created on demand. Storing is synchronous (see
    /// [`EventStore::store`]).
    #[must_use]
    pub fn send_to_session_with_storage(
        &self,
        session_id: &str,
        event_type: impl Into<String>,
        message: String,
    ) -> Option<String> {
        // Clone the sender so no DashMap guard is held during the blocking
        // store write below.
        let tx = self.sessions.get(session_id)?.value().clone();
        // Fast path: an existing store. Otherwise `entry` makes get-or-create
        // atomic, so concurrent callers cannot each create a store and have one
        // overwrite the other (losing its event).
        let store = self.get_event_store(session_id).unwrap_or_else(|| {
            Arc::clone(
                self.event_stores
                    .entry(session_id.to_string())
                    .or_insert_with(|| self.new_event_store())
                    .value(),
            )
        });
        let event_id = store.store_auto_id(event_type, message.clone());
        // No receivers is fine: the event is already stored for replay.
        let _ = tx.send(message);
        Some(event_id)
    }

    /// Sends `message` to every session without recording it.
    pub fn broadcast(&self, message: String) {
        for entry in &self.sessions {
            let _ = entry.value().send(message.clone());
        }
    }

    /// Records `message` in each session's event store (when present) and
    /// sends it to every session.
    pub fn broadcast_with_storage(&self, event_type: impl Into<String> + Clone, message: String) {
        for entry in &self.sessions {
            if let Some(store) = self.get_event_store(entry.key()) {
                store.store_auto_id(event_type.clone(), message.clone());
            }
            let _ = entry.value().send(message.clone());
        }
    }

    /// Removes the session's sender and its event store.
    pub fn remove_session(&self, id: &str) {
        self.sessions.remove(id);
        self.event_stores.remove(id);
    }

    /// Number of registered sessions.
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Evicts expired events from every session's store.
    pub async fn cleanup_expired_events(&self) {
        // Snapshot the handles first. DashMap shard guards must not be held
        // across `.await`, or another task touching the same shard can deadlock.
        let stores: Vec<Arc<EventStore>> = self
            .event_stores
            .iter()
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        for store in stores {
            store.cleanup_expired().await;
        }
    }

    /// Returns the session's events stored after `last_event_id` (see
    /// [`EventStore::get_events_after`]), or `None` if the session has no store.
    pub async fn get_events_for_replay(
        &self,
        session_id: &str,
        last_event_id: &str,
    ) -> Option<Vec<StoredEvent>> {
        // `get_event_store` releases the DashMap guard before we await.
        let store = self.get_event_store(session_id)?;
        Some(store.get_events_after(last_event_id).await)
    }
}
/// Thread-safe registry of [`Session`]s with idle-timeout expiry.
#[derive(Debug)]
pub struct SessionStore {
    /// Sessions keyed by their id.
    sessions: DashMap<String, Session>,
    /// Idle time after which [`SessionStore::cleanup_expired`] drops a session.
    timeout: Duration,
}

impl SessionStore {
    /// Creates an empty store whose sessions expire after `timeout` of inactivity.
    #[must_use]
    pub fn new(timeout: Duration) -> Self {
        Self {
            sessions: DashMap::new(),
            timeout,
        }
    }

    /// Creates an empty store with a one-hour (3600 s) idle timeout.
    #[must_use]
    pub fn with_default_timeout() -> Self {
        Self::new(Duration::from_secs(3600))
    }

    /// Creates a new session with a random UUID v4 id and returns that id.
    #[must_use]
    pub fn create(&self) -> String {
        let id = uuid::Uuid::new_v4().to_string();
        self.sessions.insert(id.clone(), Session::new(id.clone()));
        id
    }

    /// Returns a clone of the session, if it exists.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<Session> {
        self.sessions.get(id).map(|entry| entry.value().clone())
    }

    /// Refreshes the session's last-activity time. Unknown ids are ignored.
    pub fn touch(&self, id: &str) {
        if let Some(mut session) = self.sessions.get_mut(id) {
            session.touch();
        }
    }

    /// Applies `f` to the session in place. `f` is not called for unknown ids.
    ///
    /// The entry stays locked while `f` runs, so calling back into this store
    /// from `f` can deadlock.
    pub fn update<F>(&self, id: &str, f: F)
    where
        F: FnOnce(&mut Session),
    {
        if let Some(mut session) = self.sessions.get_mut(id) {
            f(&mut session);
        }
    }

    /// Drops every session that has been idle for at least the configured timeout.
    pub fn cleanup_expired(&self) {
        let timeout = self.timeout;
        self.sessions.retain(|_, s| !s.is_expired(timeout));
    }

    /// Removes and returns the session, if it exists.
    #[must_use]
    pub fn remove(&self, id: &str) -> Option<Session> {
        self.sessions.remove(id).map(|(_, s)| s)
    }

    /// Number of sessions currently stored (including expired ones not yet
    /// cleaned up).
    #[must_use]
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Spawns a background Tokio task that calls
    /// [`SessionStore::cleanup_expired`] every `interval`.
    ///
    /// The task holds only a weak reference. It exits on its first tick after
    /// the last `Arc<SessionStore>` is dropped, instead of keeping the store
    /// (and itself) alive forever.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
        let weak_store = Arc::downgrade(self);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                // Upgrade only for the duration of one sweep so the task never
                // owns the store between ticks.
                match weak_store.upgrade() {
                    Some(store) => store.cleanup_expired(),
                    None => break,
                }
            }
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-123".to_string());
        assert_eq!(session.id, "test-123");
        assert!(!session.initialized);
        assert!(session.client_capabilities.is_none());
    }

    #[test]
    fn test_session_expiry() -> Result<(), Box<dyn std::error::Error>> {
        let mut session = Session::new("test".to_string());
        assert!(!session.is_expired(Duration::from_secs(60)));
        session.last_active = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .ok_or("Failed to subtract duration")?;
        assert!(session.is_expired(Duration::from_secs(60)));
        Ok(())
    }

    #[test]
    fn test_session_store() {
        let store = SessionStore::new(Duration::from_secs(60));
        let id = store.create();
        assert!(store.get(&id).is_some());
        store.touch(&id);
        let _ = store.remove(&id);
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_session_store_update_and_cleanup() -> Result<(), Box<dyn std::error::Error>> {
        let store = SessionStore::new(Duration::from_secs(60));
        let id = store.create();
        store.update(&id, |s| s.mark_initialized(None));
        assert_eq!(store.get(&id).map(|s| s.initialized), Some(true));

        // A fresh session survives cleanup.
        store.cleanup_expired();
        assert_eq!(store.session_count(), 1);

        // A session idle past the timeout is removed.
        let past = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .ok_or("Failed to subtract duration")?;
        store.update(&id, |s| s.last_active = past);
        store.cleanup_expired();
        assert_eq!(store.session_count(), 0);
        assert!(store.get(&id).is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, mut rx) = manager.create_session();
        assert!(manager.send_to_session(&id, "test message".to_string()));
        let msg = rx.recv().await?;
        assert_eq!(msg, "test message");
        manager.remove_session(&id);
        assert!(!manager.send_to_session(&id, "another".to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_get_receiver() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();
        let mut extra = manager.get_receiver(&id).ok_or("Receiver not found")?;
        assert!(manager.send_to_session(&id, "hi".to_string()));
        assert_eq!(extra.recv().await?, "hi");
        assert!(manager.get_receiver("missing").is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_event_store_creation() {
        let store = EventStore::with_defaults();
        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_event_store_store_and_retrieve() {
        let store = EventStore::with_defaults();
        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;
        assert_eq!(store.len().await, 3);
        let all_events = store.get_all_events().await;
        assert_eq!(all_events.len(), 3);
        assert_eq!(all_events[0].id, "evt-1");
        assert_eq!(all_events[1].id, "evt-2");
        assert_eq!(all_events[2].id, "evt-3");
    }

    #[tokio::test]
    async fn test_event_store_get_events_after() {
        let store = EventStore::with_defaults();
        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;
        let events = store.get_events_after("evt-1").await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].id, "evt-2");
        assert_eq!(events[1].id, "evt-3");
        let events = store.get_events_after("evt-2").await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, "evt-3");
        let events = store.get_events_after("evt-3").await;
        assert_eq!(events.len(), 0);
        let events = store.get_events_after("unknown").await;
        assert_eq!(events.len(), 3);
    }

    #[tokio::test]
    async fn test_event_store_auto_id() {
        let store = EventStore::with_defaults();
        let id1 = store.store_auto_id("message", "data1");
        let id2 = store.store_auto_id("message", "data2");
        assert!(id1.starts_with("evt-"));
        assert!(id2.starts_with("evt-"));
        assert_ne!(id1, id2);
        assert_eq!(store.len().await, 2);
    }

    #[tokio::test]
    async fn test_event_store_max_events_limit() {
        let config = EventStoreConfig::new(3, Duration::from_secs(300));
        let store = EventStore::new(config);
        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;
        store.store_async("evt-4", "message", "data4").await;
        assert_eq!(store.len().await, 3);
        let events = store.get_all_events().await;
        assert_eq!(events[0].id, "evt-2");
        assert_eq!(events[1].id, "evt-3");
        assert_eq!(events[2].id, "evt-4");
    }

    #[tokio::test]
    async fn test_event_store_clear() {
        let store = EventStore::with_defaults();
        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        assert_eq!(store.len().await, 2);
        store.clear().await;
        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_with_event_store() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();
        let store = manager.get_event_store(&id);
        assert!(store.is_some());
        let store = store.ok_or("Event store not found")?;
        assert!(store.is_empty().await);
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_send_with_storage() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, mut rx) = manager.create_session();
        let event_id =
            manager.send_to_session_with_storage(&id, "message", "test data".to_string());
        assert!(event_id.is_some());
        let msg = rx.recv().await?;
        assert_eq!(msg, "test data");
        let store = manager
            .get_event_store(&id)
            .ok_or("Event store not found")?;
        assert_eq!(store.len().await, 1);
        let events = store.get_all_events().await;
        assert_eq!(events[0].data, "test data");
        assert_eq!(events[0].event_type, "message");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_broadcast_with_storage() -> Result<(), Box<dyn std::error::Error>>
    {
        let manager = SessionManager::new();
        let (id_a, mut rx_a) = manager.create_session();
        let (id_b, mut rx_b) = manager.create_session();
        manager.broadcast_with_storage("message", "hello".to_string());
        assert_eq!(rx_a.recv().await?, "hello");
        assert_eq!(rx_b.recv().await?, "hello");
        for id in [&id_a, &id_b] {
            let store = manager.get_event_store(id).ok_or("Event store not found")?;
            let events = store.get_all_events().await;
            assert_eq!(events.len(), 1);
            assert_eq!(events[0].data, "hello");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_replay() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();
        let _ = manager.send_to_session_with_storage(&id, "message", "msg1".to_string());
        let evt2 = manager.send_to_session_with_storage(&id, "message", "msg2".to_string());
        let _ = manager.send_to_session_with_storage(&id, "message", "msg3".to_string());
        let events = manager
            .get_events_for_replay(&id, &evt2.ok_or("Failed to get event ID")?)
            .await
            .ok_or("Failed to get events for replay")?;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "msg3");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_replay_unknown_session() {
        let manager = SessionManager::new();
        assert!(manager
            .get_events_for_replay("missing", "evt-1")
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_session_manager_cleanup_keeps_fresh_events(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();
        let _ = manager.send_to_session_with_storage(&id, "message", "fresh".to_string());
        manager.cleanup_expired_events().await;
        let store = manager.get_event_store(&id).ok_or("Event store not found")?;
        assert_eq!(store.len().await, 1);
        Ok(())
    }

    #[test]
    fn test_event_store_config() {
        let config = EventStoreConfig::default();
        assert_eq!(config.max_events, 1000);
        assert_eq!(config.max_age, Duration::from_secs(300));
        let config = EventStoreConfig::new(500, Duration::from_secs(120))
            .with_max_events(600)
            .with_max_age(Duration::from_secs(180));
        assert_eq!(config.max_events, 600);
        assert_eq!(config.max_age, Duration::from_secs(180));
    }

    #[test]
    fn test_stored_event() {
        let event = StoredEvent::new("evt-123".to_string(), "message", "test data");
        assert_eq!(event.id, "evt-123");
        assert_eq!(event.event_type, "message");
        assert_eq!(event.data, "test data");
    }
}