use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use lru::LruCache;
use tokio::sync::RwLock;
use crate::core::constants::DEFAULT_LRU_SIZE;
#[derive(Debug, Clone)]
pub struct PendingRequest {
pub original_id: serde_json::Value,
pub is_initialize: bool,
pub registered_at: Instant,
}
#[derive(Clone)]
pub struct ClientCorrelationStore {
pending_requests: Arc<RwLock<LruCache<String, PendingRequest>>>,
}
impl Default for ClientCorrelationStore {
fn default() -> Self {
Self::new()
}
}
impl ClientCorrelationStore {
pub fn new() -> Self {
Self::with_max_pending(DEFAULT_LRU_SIZE)
}
pub fn with_max_pending(max_pending: usize) -> Self {
Self {
pending_requests: Arc::new(RwLock::new(LruCache::new(
NonZeroUsize::new(max_pending).unwrap_or(NonZeroUsize::new(1).unwrap()),
))),
}
}
pub async fn register(
&self,
event_id: String,
original_id: serde_json::Value,
is_initialize: bool,
) {
self.pending_requests.write().await.push(
event_id,
PendingRequest {
original_id,
is_initialize,
registered_at: Instant::now(),
},
);
}
pub async fn is_initialize_request(&self, event_id: &str) -> bool {
self.pending_requests
.read()
.await
.peek(event_id)
.is_some_and(|r| r.is_initialize)
}
pub async fn contains(&self, event_id: &str) -> bool {
self.pending_requests.read().await.contains(event_id)
}
pub async fn remove(&self, event_id: &str) -> bool {
self.pending_requests.write().await.pop(event_id).is_some()
}
pub async fn get_original_id(&self, event_id: &str) -> Option<serde_json::Value> {
self.pending_requests
.read()
.await
.peek(event_id)
.map(|r| r.original_id.clone())
}
pub async fn count(&self) -> usize {
self.pending_requests.read().await.len()
}
pub async fn sweep_expired(&self, timeout: Duration) -> usize {
let now = Instant::now();
let mut cache = self.pending_requests.write().await;
let mut expired_keys = Vec::new();
for (key, entry) in cache.iter() {
if now.duration_since(entry.registered_at) >= timeout {
expired_keys.push(key.clone());
}
}
let count = expired_keys.len();
for key in expired_keys {
cache.pop(&key);
}
count
}
pub async fn clear(&self) {
self.pending_requests.write().await.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn remove_nonexistent_is_noop() {
let store = ClientCorrelationStore::new();
assert!(!store.remove("nonexistent").await);
assert!(!store.contains("nonexistent").await);
}
#[tokio::test]
async fn contains_after_clear() {
let store = ClientCorrelationStore::new();
store
.register("e1".into(), serde_json::Value::Null, false)
.await;
store
.register("e2".into(), serde_json::Value::Null, false)
.await;
assert!(store.contains("e1").await);
store.clear().await;
assert!(!store.contains("e1").await);
assert!(!store.contains("e2").await);
}
#[tokio::test]
async fn register_and_remove_roundtrip() {
let store = ClientCorrelationStore::new();
store
.register("e1".into(), serde_json::Value::Null, false)
.await;
assert!(store.contains("e1").await);
assert!(store.remove("e1").await);
assert!(!store.contains("e1").await);
}
#[tokio::test]
async fn default_store_is_bounded() {
let store = ClientCorrelationStore::new();
for i in 0..=DEFAULT_LRU_SIZE {
store
.register(format!("e{i}"), serde_json::Value::Null, false)
.await;
}
assert_eq!(store.count().await, DEFAULT_LRU_SIZE);
assert!(!store.contains("e0").await);
assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await);
}
#[tokio::test]
async fn sweep_expired_removes_only_stale_entries() {
let store = ClientCorrelationStore::new();
store
.register("old".into(), serde_json::json!(1), false)
.await;
tokio::time::sleep(Duration::from_millis(20)).await;
store
.register("fresh".into(), serde_json::json!(2), false)
.await;
let swept = store.sweep_expired(Duration::from_millis(10)).await;
assert_eq!(swept, 1);
assert!(!store.contains("old").await);
assert!(store.contains("fresh").await);
}
#[tokio::test]
async fn sweep_expired_returns_zero_when_nothing_expired() {
let store = ClientCorrelationStore::new();
store
.register("e1".into(), serde_json::Value::Null, false)
.await;
let swept = store.sweep_expired(Duration::from_secs(60)).await;
assert_eq!(swept, 0);
assert!(store.contains("e1").await);
}
}