use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use lru::LruCache;
use tokio::sync::RwLock;
use crate::core::constants::DEFAULT_LRU_SIZE;
#[derive(Debug, Clone)]
pub struct RouteEntry {
pub client_pubkey: String,
pub original_request_id: serde_json::Value,
pub progress_token: Option<String>,
pub wrap_kind: Option<u16>,
pub registered_at: Instant,
}
struct Inner {
routes: LruCache<String, RouteEntry>,
progress_token_to_event: HashMap<String, String>,
client_event_ids: HashMap<String, HashSet<String>>,
}
impl Inner {
fn new(max_routes: usize) -> Self {
let routes =
LruCache::new(NonZeroUsize::new(max_routes).unwrap_or(NonZeroUsize::new(1).unwrap()));
Self {
routes,
progress_token_to_event: HashMap::new(),
client_event_ids: HashMap::new(),
}
}
fn cleanup_indexes(&mut self, event_id: &str, route: &RouteEntry) {
if let Some(ref token) = route.progress_token {
self.progress_token_to_event.remove(token);
}
if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) {
set.remove(event_id);
if set.is_empty() {
self.client_event_ids.remove(&route.client_pubkey);
}
}
}
fn remove_route(&mut self, event_id: &str) -> Option<RouteEntry> {
let route = self.routes.pop(event_id)?;
self.cleanup_indexes(event_id, &route);
Some(route)
}
}
#[derive(Clone)]
pub struct ServerEventRouteStore {
inner: Arc<RwLock<Inner>>,
}
impl Default for ServerEventRouteStore {
fn default() -> Self {
Self::new()
}
}
impl ServerEventRouteStore {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(Inner::new(DEFAULT_LRU_SIZE))),
}
}
pub fn with_max_routes(max_routes: usize) -> Self {
Self {
inner: Arc::new(RwLock::new(Inner::new(max_routes))),
}
}
pub async fn register(
&self,
event_id: String,
client_pubkey: String,
original_request_id: serde_json::Value,
progress_token: Option<String>,
) {
let mut inner = self.inner.write().await;
inner
.client_event_ids
.entry(client_pubkey.clone())
.or_default()
.insert(event_id.clone());
if let Some(ref token) = progress_token {
inner
.progress_token_to_event
.insert(token.clone(), event_id.clone());
}
let evicted = inner.routes.push(
event_id.clone(),
RouteEntry {
client_pubkey,
original_request_id,
progress_token,
wrap_kind: None,
registered_at: Instant::now(),
},
);
if let Some((evicted_key, evicted_route)) = evicted {
if evicted_key != event_id {
inner.cleanup_indexes(&evicted_key, &evicted_route);
}
}
}
pub async fn get(&self, event_id: &str) -> Option<String> {
self.inner
.read()
.await
.routes
.peek(event_id)
.map(|r| r.client_pubkey.clone())
}
pub async fn get_route(&self, event_id: &str) -> Option<RouteEntry> {
self.inner.read().await.routes.peek(event_id).cloned()
}
pub async fn pop(&self, event_id: &str) -> Option<RouteEntry> {
self.inner.write().await.remove_route(event_id)
}
pub async fn remove_for_client(&self, client_pubkey: &str) -> usize {
let mut inner = self.inner.write().await;
let event_ids = match inner.client_event_ids.remove(client_pubkey) {
Some(ids) => ids,
None => return 0,
};
let count = event_ids.len();
for event_id in &event_ids {
if let Some(route) = inner.routes.pop(event_id.as_str()) {
if let Some(ref token) = route.progress_token {
inner.progress_token_to_event.remove(token);
}
}
}
count
}
pub async fn has_event_route(&self, event_id: &str) -> bool {
self.inner.read().await.routes.contains(event_id)
}
pub async fn has_active_routes_for_client(&self, client_pubkey: &str) -> bool {
self.inner
.read()
.await
.client_event_ids
.get(client_pubkey)
.is_some_and(|set| !set.is_empty())
}
pub async fn get_event_id_by_progress_token(&self, token: &str) -> Option<String> {
self.inner
.read()
.await
.progress_token_to_event
.get(token)
.cloned()
}
pub async fn has_progress_token(&self, token: &str) -> bool {
self.inner
.read()
.await
.progress_token_to_event
.contains_key(token)
}
pub async fn event_route_count(&self) -> usize {
self.inner.read().await.routes.len()
}
pub async fn progress_token_count(&self) -> usize {
self.inner.read().await.progress_token_to_event.len()
}
pub async fn sweep_stale_routes(&self, timeout: Duration) -> Vec<String> {
let now = Instant::now();
let mut inner = self.inner.write().await;
let mut expired_keys = Vec::new();
for (key, entry) in inner.routes.iter() {
if now.duration_since(entry.registered_at) >= timeout {
expired_keys.push(key.clone());
}
}
for key in &expired_keys {
inner.remove_route(key);
}
expired_keys
}
pub async fn clear(&self) {
let mut inner = self.inner.write().await;
inner.routes.clear();
inner.progress_token_to_event.clear();
inner.client_event_ids.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[tokio::test]
async fn pop_on_empty_returns_none() {
let store = ServerEventRouteStore::new();
assert!(store.pop("nonexistent").await.is_none());
}
#[tokio::test]
async fn get_returns_without_removing() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!("r1"), None)
.await;
assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
}
#[tokio::test]
async fn pop_removes_entry() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!("r1"), None)
.await;
let route = store.pop("e1").await.unwrap();
assert_eq!(route.client_pubkey, "pk1");
assert!(store.pop("e1").await.is_none());
}
#[tokio::test]
async fn remove_for_client_only_removes_matching() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!("r1"), None)
.await;
store
.register("e2".into(), "pk2".into(), json!("r2"), None)
.await;
store
.register("e3".into(), "pk1".into(), json!("r3"), None)
.await;
let removed = store.remove_for_client("pk1").await;
assert_eq!(removed, 2);
assert!(store.get("e1").await.is_none());
assert!(store.get("e3").await.is_none());
assert_eq!(store.get("e2").await.as_deref(), Some("pk2"));
}
#[tokio::test]
async fn remove_for_client_noop_when_no_match() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!("r1"), None)
.await;
let removed = store.remove_for_client("pk_other").await;
assert_eq!(removed, 0);
assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
}
#[tokio::test]
async fn clear_empties_store() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!("r1"), None)
.await;
store
.register("e2".into(), "pk2".into(), json!("r2"), None)
.await;
store.clear().await;
assert!(store.get("e1").await.is_none());
assert!(store.get("e2").await.is_none());
}
#[tokio::test]
async fn default_store_is_bounded() {
let store = ServerEventRouteStore::new();
for i in 0..=DEFAULT_LRU_SIZE {
store
.register(format!("e{i}"), "pk1".into(), json!(i), None)
.await;
}
assert_eq!(store.event_route_count().await, DEFAULT_LRU_SIZE);
assert!(!store.has_event_route("e0").await);
assert!(store.has_event_route(&format!("e{DEFAULT_LRU_SIZE}")).await);
}
#[tokio::test]
async fn sweep_stale_routes_removes_only_expired() {
let store = ServerEventRouteStore::new();
store
.register("old".into(), "pk1".into(), json!(1), Some("tok1".into()))
.await;
tokio::time::sleep(Duration::from_millis(20)).await;
store
.register("fresh".into(), "pk2".into(), json!(2), None)
.await;
let swept = store.sweep_stale_routes(Duration::from_millis(10)).await;
assert_eq!(swept.len(), 1);
assert_eq!(swept[0], "old");
assert!(!store.has_event_route("old").await);
assert!(store.has_event_route("fresh").await);
assert!(!store.has_progress_token("tok1").await);
assert!(!store.has_active_routes_for_client("pk1").await);
}
#[tokio::test]
async fn sweep_stale_routes_returns_zero_when_nothing_expired() {
let store = ServerEventRouteStore::new();
store
.register("e1".into(), "pk1".into(), json!(1), None)
.await;
let swept = store.sweep_stale_routes(Duration::from_secs(60)).await;
assert!(swept.is_empty());
assert!(store.has_event_route("e1").await);
}
}