use crate::error::{Error, Result};
use crate::shared::TransportMessage;
use crate::types::RequestId;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use uuid::Uuid;
/// Persistence backend for transport events, used to replay messages when a
/// session resumes.
///
/// The `Send + Sync` bound lets implementations be shared behind an
/// `Arc<dyn EventStore>`, which is how `ResumptionManager` holds one.
#[async_trait]
pub trait EventStore: Send + Sync {
    /// Appends `event` to the store.
    async fn store_event(&self, event: StoredEvent) -> Result<()>;
    /// Returns events stored after the event whose id is `event_id`, at most
    /// `limit` of them (`None` means no limit).
    ///
    /// NOTE(review): the result for an unknown `event_id` is implementation
    /// defined; `InMemoryEventStore` returns everything it still retains.
    async fn get_events_since(
        &self,
        event_id: &str,
        limit: Option<usize>,
    ) -> Result<Vec<StoredEvent>>;
    /// Returns the id of the latest stored event, or `None` if there is none.
    async fn get_latest_event_id(&self) -> Result<Option<String>>;
    /// Removes events whose timestamp is earlier than `timestamp` and returns
    /// how many were removed.
    async fn clear_events_before(&self, timestamp: DateTime<Utc>) -> Result<usize>;
    /// Creates a token capturing the current position in the event stream.
    async fn create_resumption_token(&self) -> Result<ResumptionToken>;
    /// Looks up a token previously returned by `create_resumption_token`.
    /// `Ok(None)` means the token is not (or no longer) recognised.
    async fn validate_resumption_token(&self, token: &str) -> Result<Option<ResumptionState>>;
}
/// A single transport message captured by an [`EventStore`], together with
/// the metadata needed to replay it in order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredEvent {
    /// Unique event id; `ResumptionManager` assigns a random UUID v4 string.
    pub id: String,
    /// When the event was recorded; `InMemoryEventStore` prunes and clears
    /// events by comparing against this value.
    pub timestamp: DateTime<Utc>,
    /// The recorded message.
    pub message: TransportMessage,
    /// Whether the message was recorded as inbound or outbound.
    pub direction: MessageDirection,
    /// Session the event belongs to.
    pub session_id: String,
    /// Position in the recording session's sequence; `ResumptionManager`
    /// assigns these from a counter that starts at 0.
    pub sequence: u64,
}
/// Which way a recorded message travelled.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum MessageDirection {
    /// Recorded via `ResumptionManager::record_inbound` (presumably a message
    /// received from the peer).
    Inbound,
    /// Recorded via `ResumptionManager::record_outbound`. These are the
    /// events `ResumptionManager::get_pending_events` returns.
    Outbound,
}
/// Token handed out so a client can later resume from a known point in the
/// event stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResumptionToken {
    /// Opaque token string to pass back to
    /// `EventStore::validate_resumption_token` (a UUID v4 string in
    /// `InMemoryEventStore`).
    pub token: String,
    /// Session of the event the token was created from.
    pub session_id: String,
    /// Id of the newest event at creation time; replay starts after it.
    pub last_event_id: String,
    /// When the token was issued.
    pub created_at: DateTime<Utc>,
    /// Advertised expiry time. Whether it is enforced depends on the
    /// `EventStore` implementation.
    pub expires_at: DateTime<Utc>,
    /// Free-form extra data; `InMemoryEventStore` leaves it empty.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// State remembered for a resumption token and returned by
/// `EventStore::validate_resumption_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResumptionState {
    /// Session the token was created for.
    pub session_id: String,
    /// Id of the newest event when the token was created.
    pub last_event_id: String,
    /// Presumably requests still awaiting a response (TODO confirm);
    /// `InMemoryEventStore` always leaves this empty.
    pub pending_requests: Vec<RequestId>,
    /// The sequence number following the last event's (its `sequence + 1`).
    pub next_sequence: u64,
}
#[derive(Debug)]
pub struct InMemoryEventStore {
events: Arc<RwLock<VecDeque<StoredEvent>>>,
tokens: Arc<RwLock<HashMap<String, ResumptionState>>>,
max_events: usize,
max_age: chrono::Duration,
}
impl InMemoryEventStore {
pub fn new(max_events: usize, max_age: chrono::Duration) -> Self {
Self {
events: Arc::new(RwLock::new(VecDeque::new())),
tokens: Arc::new(RwLock::new(HashMap::new())),
max_events,
max_age,
}
}
fn cleanup(&self) {
let mut events = self.events.write();
while events.len() > self.max_events {
events.pop_front();
}
let cutoff = Utc::now() - self.max_age;
while let Some(event) = events.front() {
if event.timestamp < cutoff {
events.pop_front();
} else {
break;
}
}
}
}
#[async_trait]
impl EventStore for InMemoryEventStore {
async fn store_event(&self, event: StoredEvent) -> Result<()> {
{
let mut events = self.events.write();
events.push_back(event);
}
self.cleanup();
Ok(())
}
async fn get_events_since(
&self,
event_id: &str,
limit: Option<usize>,
) -> Result<Vec<StoredEvent>> {
let events = self.events.read();
let start_idx = events
.iter()
.position(|e| e.id == event_id)
.map_or(0, |idx| idx + 1);
let limit = limit.unwrap_or(usize::MAX);
Ok(events.iter().skip(start_idx).take(limit).cloned().collect())
}
async fn get_latest_event_id(&self) -> Result<Option<String>> {
let events = self.events.read();
Ok(events.back().map(|e| e.id.clone()))
}
async fn clear_events_before(&self, timestamp: DateTime<Utc>) -> Result<usize> {
let mut events = self.events.write();
let initial_len = events.len();
while let Some(event) = events.front() {
if event.timestamp < timestamp {
events.pop_front();
} else {
break;
}
}
Ok(initial_len - events.len())
}
async fn create_resumption_token(&self) -> Result<ResumptionToken> {
let token_id = Uuid::new_v4().to_string();
let events = self.events.read();
let last_event = events
.back()
.ok_or_else(|| Error::Internal("No events to create resumption token".into()))?;
let state = ResumptionState {
session_id: last_event.session_id.clone(),
last_event_id: last_event.id.clone(),
pending_requests: Vec::new(),
next_sequence: last_event.sequence + 1,
};
self.tokens.write().insert(token_id.clone(), state);
Ok(ResumptionToken {
token: token_id,
session_id: last_event.session_id.clone(),
last_event_id: last_event.id.clone(),
created_at: Utc::now(),
expires_at: Utc::now() + chrono::Duration::hours(24),
metadata: HashMap::new(),
})
}
async fn validate_resumption_token(&self, token: &str) -> Result<Option<ResumptionState>> {
let tokens = self.tokens.read();
Ok(tokens.get(token).cloned())
}
}
/// Records one session's inbound and outbound messages into an
/// [`EventStore`] and restores position from resumption tokens.
pub struct ResumptionManager {
    /// Backing store; an `Arc`, so it may be shared with other owners.
    event_store: Arc<dyn EventStore>,
    /// Session id stamped on every event this manager records.
    session_id: String,
    /// Sequence number to assign to the next recorded event.
    sequence_counter: Arc<RwLock<u64>>,
}
impl std::fmt::Debug for ResumptionManager {
    /// Shows the session id and sequence counter. The event store is left out
    /// because `EventStore` only requires `Send + Sync`, not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ResumptionManager");
        builder.field("session_id", &self.session_id);
        builder.field("sequence_counter", &self.sequence_counter);
        builder.finish_non_exhaustive()
    }
}
impl ResumptionManager {
pub fn new(event_store: Arc<dyn EventStore>, session_id: String) -> Self {
Self {
event_store,
session_id,
sequence_counter: Arc::new(RwLock::new(0)),
}
}
pub async fn record_outbound(&self, message: TransportMessage) -> Result<()> {
let sequence = {
let mut counter = self.sequence_counter.write();
let seq = *counter;
*counter += 1;
seq
};
let event = StoredEvent {
id: Uuid::new_v4().to_string(),
timestamp: Utc::now(),
message,
direction: MessageDirection::Outbound,
session_id: self.session_id.clone(),
sequence,
};
self.event_store.store_event(event).await
}
pub async fn record_inbound(&self, message: TransportMessage) -> Result<()> {
let sequence = {
let mut counter = self.sequence_counter.write();
let seq = *counter;
*counter += 1;
seq
};
let event = StoredEvent {
id: Uuid::new_v4().to_string(),
timestamp: Utc::now(),
message,
direction: MessageDirection::Inbound,
session_id: self.session_id.clone(),
sequence,
};
self.event_store.store_event(event).await
}
pub async fn create_token(&self) -> Result<ResumptionToken> {
self.event_store.create_resumption_token().await
}
pub async fn resume_from_token(&self, token: &str) -> Result<Vec<StoredEvent>> {
let state = self
.event_store
.validate_resumption_token(token)
.await?
.ok_or_else(|| Error::Internal("Invalid or expired resumption token".into()))?;
*self.sequence_counter.write() = state.next_sequence;
self.event_store
.get_events_since(&state.last_event_id, None)
.await
}
pub async fn get_pending_events(&self, since_event_id: &str) -> Result<Vec<StoredEvent>> {
let events = self
.event_store
.get_events_since(since_event_id, None)
.await?;
Ok(events
.into_iter()
.filter(|e| e.direction == MessageDirection::Outbound)
.collect())
}
}
/// Tuning knobs for an event store.
///
/// NOTE(review): nothing in the code visible here reads this struct.
/// `InMemoryEventStore::new` takes `max_events`/`max_age` directly, and
/// `auto_cleanup`/`cleanup_interval` are never consumed. Confirm that callers
/// elsewhere use them.
#[derive(Debug, Clone)]
pub struct EventStoreConfig {
    /// Maximum number of events to retain.
    pub max_events: usize,
    /// Maximum age of retained events.
    pub max_age: chrono::Duration,
    /// Whether periodic cleanup should run (presumably at `cleanup_interval`).
    pub auto_cleanup: bool,
    /// Interval between periodic cleanups.
    pub cleanup_interval: std::time::Duration,
}
impl Default for EventStoreConfig {
fn default() -> Self {
Self {
max_events: 10000,
max_age: chrono::Duration::hours(24),
auto_cleanup: true,
cleanup_interval: std::time::Duration::from_secs(300),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `initialize` request advertising `protocol_version`.
    fn initialize_message(protocol_version: impl ToString) -> TransportMessage {
        TransportMessage::Request {
            id: RequestId::String("req-1".to_string()),
            request: crate::types::Request::Client(Box::new(
                crate::types::ClientRequest::Initialize(crate::types::InitializeRequest {
                    protocol_version: protocol_version.to_string(),
                    capabilities: crate::types::ClientCapabilities::default(),
                    client_info: crate::types::Implementation::new("test", "1.0.0"),
                }),
            )),
        }
    }

    /// Builds an outbound "session-1" event with the given id, sequence and timestamp.
    fn stored_event(id: &str, sequence: u64, timestamp: DateTime<Utc>) -> StoredEvent {
        StoredEvent {
            id: id.to_string(),
            timestamp,
            message: initialize_message(crate::DEFAULT_PROTOCOL_VERSION),
            direction: MessageDirection::Outbound,
            session_id: "session-1".to_string(),
            sequence,
        }
    }

    /// Collects event ids in order, for compact assertions.
    fn ids(events: &[StoredEvent]) -> Vec<&str> {
        events.iter().map(|e| e.id.as_str()).collect()
    }

    /// Creates a store with room for 100 events and a one-hour age limit.
    fn roomy_store() -> InMemoryEventStore {
        InMemoryEventStore::new(100, chrono::Duration::hours(1))
    }

    #[tokio::test]
    async fn test_in_memory_event_store() {
        let store = roomy_store();
        store
            .store_event(stored_event("test-1", 0, Utc::now()))
            .await
            .unwrap();
        let events = store
            .get_events_since("non-existent", Some(10))
            .await
            .unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, "test-1");
        let latest_id = store.get_latest_event_id().await.unwrap();
        assert_eq!(latest_id, Some("test-1".to_string()));
    }

    #[tokio::test]
    async fn test_get_events_since_known_id_respects_limit() {
        let store = roomy_store();
        for (seq, id) in (0u64..).zip(["e-0", "e-1", "e-2"]) {
            store
                .store_event(stored_event(id, seq, Utc::now()))
                .await
                .unwrap();
        }
        let limited = store.get_events_since("e-0", Some(1)).await.unwrap();
        assert_eq!(ids(&limited), ["e-1"]);
        let unlimited = store.get_events_since("e-0", None).await.unwrap();
        assert_eq!(ids(&unlimited), ["e-1", "e-2"]);
        assert!(store.get_events_since("e-2", None).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_store_caps_retained_events() {
        let store = InMemoryEventStore::new(2, chrono::Duration::hours(1));
        for (seq, id) in (0u64..).zip(["e-0", "e-1", "e-2"]) {
            store
                .store_event(stored_event(id, seq, Utc::now()))
                .await
                .unwrap();
        }
        let events = store.get_events_since("non-existent", None).await.unwrap();
        assert_eq!(ids(&events), ["e-1", "e-2"]);
    }

    #[tokio::test]
    async fn test_clear_events_before() {
        let store = roomy_store();
        let now = Utc::now();
        store
            .store_event(stored_event("old", 0, now - chrono::Duration::minutes(10)))
            .await
            .unwrap();
        store
            .store_event(stored_event("new", 1, now))
            .await
            .unwrap();
        let removed = store
            .clear_events_before(now - chrono::Duration::minutes(5))
            .await
            .unwrap();
        assert_eq!(removed, 1);
        let events = store.get_events_since("non-existent", None).await.unwrap();
        assert_eq!(ids(&events), ["new"]);
    }

    #[tokio::test]
    async fn test_create_token_requires_events() {
        let store = roomy_store();
        assert!(store.create_resumption_token().await.is_err());
    }

    #[tokio::test]
    async fn test_resume_rejects_unknown_token() {
        let manager = ResumptionManager::new(Arc::new(roomy_store()), "session-1".to_string());
        assert!(manager.resume_from_token("no-such-token").await.is_err());
    }

    #[tokio::test]
    async fn test_pending_events_are_outbound_only() {
        let manager = ResumptionManager::new(Arc::new(roomy_store()), "session-1".to_string());
        manager
            .record_outbound(initialize_message("2024-11-05"))
            .await
            .unwrap();
        manager
            .record_inbound(initialize_message("2024-11-05"))
            .await
            .unwrap();
        let pending = manager.get_pending_events("non-existent").await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].direction, MessageDirection::Outbound);
        assert_eq!(pending[0].sequence, 0);
    }

    #[tokio::test]
    async fn test_resumption_manager() {
        let manager = ResumptionManager::new(Arc::new(roomy_store()), "session-1".to_string());
        let msg = initialize_message("2024-11-05");
        manager.record_outbound(msg.clone()).await.unwrap();
        manager.record_inbound(msg).await.unwrap();
        let token = manager.create_token().await.unwrap();
        assert!(!token.token.is_empty());
        let events = manager.resume_from_token(&token.token).await.unwrap();
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_resume_replays_events_after_token() {
        let manager = ResumptionManager::new(Arc::new(roomy_store()), "session-1".to_string());
        manager
            .record_outbound(initialize_message("2024-11-05"))
            .await
            .unwrap();
        let token = manager.create_token().await.unwrap();
        manager
            .record_outbound(initialize_message("2024-11-05"))
            .await
            .unwrap();
        let events = manager.resume_from_token(&token.token).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].sequence, 1);
    }
}