contextvm_sdk/transport/server/
correlation_store.rs1use std::collections::{HashMap, HashSet};
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use lru::LruCache;
9use tokio::sync::RwLock;
10
11use crate::core::constants::DEFAULT_LRU_SIZE;
12
13#[derive(Debug, Clone)]
15pub struct RouteEntry {
16 pub client_pubkey: String,
18 pub original_request_id: serde_json::Value,
20 pub progress_token: Option<String>,
22 pub wrap_kind: Option<u16>,
25 pub registered_at: Instant,
27}
28
29struct Inner {
31 routes: LruCache<String, RouteEntry>,
33 progress_token_to_event: HashMap<String, String>,
35 client_event_ids: HashMap<String, HashSet<String>>,
37}
38
39impl Inner {
40 fn new(max_routes: usize) -> Self {
41 let routes =
42 LruCache::new(NonZeroUsize::new(max_routes).unwrap_or(NonZeroUsize::new(1).unwrap()));
43 Self {
44 routes,
45 progress_token_to_event: HashMap::new(),
46 client_event_ids: HashMap::new(),
47 }
48 }
49
50 fn cleanup_indexes(&mut self, event_id: &str, route: &RouteEntry) {
52 if let Some(ref token) = route.progress_token {
53 self.progress_token_to_event.remove(token);
54 }
55 if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) {
56 set.remove(event_id);
57 if set.is_empty() {
58 self.client_event_ids.remove(&route.client_pubkey);
59 }
60 }
61 }
62
63 fn remove_route(&mut self, event_id: &str) -> Option<RouteEntry> {
65 let route = self.routes.pop(event_id)?;
66 self.cleanup_indexes(event_id, &route);
67 Some(route)
68 }
69}
70
71#[derive(Clone)]
76pub struct ServerEventRouteStore {
77 inner: Arc<RwLock<Inner>>,
78}
79
80impl Default for ServerEventRouteStore {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl ServerEventRouteStore {
87 pub fn new() -> Self {
88 Self {
89 inner: Arc::new(RwLock::new(Inner::new(DEFAULT_LRU_SIZE))),
90 }
91 }
92
93 pub fn with_max_routes(max_routes: usize) -> Self {
96 Self {
97 inner: Arc::new(RwLock::new(Inner::new(max_routes))),
98 }
99 }
100
101 pub async fn register(
103 &self,
104 event_id: String,
105 client_pubkey: String,
106 original_request_id: serde_json::Value,
107 progress_token: Option<String>,
108 ) {
109 let mut inner = self.inner.write().await;
110
111 inner
113 .client_event_ids
114 .entry(client_pubkey.clone())
115 .or_default()
116 .insert(event_id.clone());
117
118 if let Some(ref token) = progress_token {
120 inner
121 .progress_token_to_event
122 .insert(token.clone(), event_id.clone());
123 }
124
125 let evicted = inner.routes.push(
127 event_id.clone(),
128 RouteEntry {
129 client_pubkey,
130 original_request_id,
131 progress_token,
132 wrap_kind: None,
133 registered_at: Instant::now(),
134 },
135 );
136
137 if let Some((evicted_key, evicted_route)) = evicted {
138 if evicted_key != event_id {
139 inner.cleanup_indexes(&evicted_key, &evicted_route);
141 }
142 }
143 }
144
145 pub async fn get(&self, event_id: &str) -> Option<String> {
147 self.inner
148 .read()
149 .await
150 .routes
151 .peek(event_id)
152 .map(|r| r.client_pubkey.clone())
153 }
154
155 pub async fn get_route(&self, event_id: &str) -> Option<RouteEntry> {
157 self.inner.read().await.routes.peek(event_id).cloned()
158 }
159
160 pub async fn pop(&self, event_id: &str) -> Option<RouteEntry> {
162 self.inner.write().await.remove_route(event_id)
163 }
164
165 pub async fn remove_for_client(&self, client_pubkey: &str) -> usize {
167 let mut inner = self.inner.write().await;
168
169 let event_ids = match inner.client_event_ids.remove(client_pubkey) {
170 Some(ids) => ids,
171 None => return 0,
172 };
173
174 let count = event_ids.len();
175 for event_id in &event_ids {
176 if let Some(route) = inner.routes.pop(event_id.as_str()) {
177 if let Some(ref token) = route.progress_token {
178 inner.progress_token_to_event.remove(token);
179 }
180 }
181 }
182 count
183 }
184
185 pub async fn has_event_route(&self, event_id: &str) -> bool {
187 self.inner.read().await.routes.contains(event_id)
188 }
189
190 pub async fn has_active_routes_for_client(&self, client_pubkey: &str) -> bool {
192 self.inner
193 .read()
194 .await
195 .client_event_ids
196 .get(client_pubkey)
197 .is_some_and(|set| !set.is_empty())
198 }
199
200 pub async fn get_event_id_by_progress_token(&self, token: &str) -> Option<String> {
202 self.inner
203 .read()
204 .await
205 .progress_token_to_event
206 .get(token)
207 .cloned()
208 }
209
210 pub async fn has_progress_token(&self, token: &str) -> bool {
212 self.inner
213 .read()
214 .await
215 .progress_token_to_event
216 .contains_key(token)
217 }
218
219 pub async fn event_route_count(&self) -> usize {
221 self.inner.read().await.routes.len()
222 }
223
224 pub async fn progress_token_count(&self) -> usize {
226 self.inner.read().await.progress_token_to_event.len()
227 }
228
229 pub async fn sweep_stale_routes(&self, timeout: Duration) -> Vec<String> {
233 let now = Instant::now();
234 let mut inner = self.inner.write().await;
235 let mut expired_keys = Vec::new();
236
237 for (key, entry) in inner.routes.iter() {
238 if now.duration_since(entry.registered_at) >= timeout {
239 expired_keys.push(key.clone());
240 }
241 }
242
243 for key in &expired_keys {
244 inner.remove_route(key);
245 }
246 expired_keys
247 }
248
249 pub async fn clear(&self) {
250 let mut inner = self.inner.write().await;
251 inner.routes.clear();
252 inner.progress_token_to_event.clear();
253 inner.client_event_ids.clear();
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Registers `event_id` for `client` with a string request id and no progress token.
    async fn register_plain(
        store: &ServerEventRouteStore,
        event_id: &str,
        client: &str,
        request_id: &str,
    ) {
        store
            .register(event_id.to_owned(), client.to_owned(), json!(request_id), None)
            .await;
    }

    #[tokio::test]
    async fn pop_on_empty_returns_none() {
        let store = ServerEventRouteStore::new();
        let popped = store.pop("nonexistent").await;
        assert!(popped.is_none());
    }

    #[tokio::test]
    async fn get_returns_without_removing() {
        let store = ServerEventRouteStore::new();
        register_plain(&store, "e1", "pk1", "r1").await;
        // Reading twice proves `get` leaves the route in place.
        for _ in 0..2 {
            assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
        }
    }

    #[tokio::test]
    async fn pop_removes_entry() {
        let store = ServerEventRouteStore::new();
        register_plain(&store, "e1", "pk1", "r1").await;

        let route = store.pop("e1").await.expect("e1 was just registered");
        assert_eq!(route.client_pubkey, "pk1");
        assert!(store.pop("e1").await.is_none());
    }

    #[tokio::test]
    async fn remove_for_client_only_removes_matching() {
        let store = ServerEventRouteStore::new();
        for (event_id, client, request_id) in
            [("e1", "pk1", "r1"), ("e2", "pk2", "r2"), ("e3", "pk1", "r3")]
        {
            register_plain(&store, event_id, client, request_id).await;
        }

        assert_eq!(store.remove_for_client("pk1").await, 2);

        for gone in ["e1", "e3"] {
            assert!(store.get(gone).await.is_none());
        }
        assert_eq!(store.get("e2").await.as_deref(), Some("pk2"));
    }

    #[tokio::test]
    async fn remove_for_client_noop_when_no_match() {
        let store = ServerEventRouteStore::new();
        register_plain(&store, "e1", "pk1", "r1").await;

        assert_eq!(store.remove_for_client("pk_other").await, 0);
        assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
    }

    #[tokio::test]
    async fn clear_empties_store() {
        let store = ServerEventRouteStore::new();
        register_plain(&store, "e1", "pk1", "r1").await;
        register_plain(&store, "e2", "pk2", "r2").await;

        store.clear().await;

        for event_id in ["e1", "e2"] {
            assert!(store.get(event_id).await.is_none());
        }
    }

    #[tokio::test]
    async fn default_store_is_bounded() {
        let store = ServerEventRouteStore::new();
        // One registration past capacity must evict the oldest route.
        for i in 0..=DEFAULT_LRU_SIZE {
            store
                .register(format!("e{i}"), "pk1".into(), json!(i), None)
                .await;
        }

        assert_eq!(store.event_route_count().await, DEFAULT_LRU_SIZE);
        assert!(!store.has_event_route("e0").await);
        let newest = format!("e{DEFAULT_LRU_SIZE}");
        assert!(store.has_event_route(&newest).await);
    }

    #[tokio::test]
    async fn sweep_stale_routes_removes_only_expired() {
        let store = ServerEventRouteStore::new();
        store
            .register("old".into(), "pk1".into(), json!(1), Some("tok1".into()))
            .await;

        // Age "old" past the sweep threshold before "fresh" is registered.
        tokio::time::sleep(Duration::from_millis(20)).await;

        store
            .register("fresh".into(), "pk2".into(), json!(2), None)
            .await;

        let swept = store.sweep_stale_routes(Duration::from_millis(10)).await;
        assert_eq!(swept, vec!["old".to_string()]);

        assert!(!store.has_event_route("old").await);
        assert!(store.has_event_route("fresh").await);
        // Sweeping must also clear the swept route's secondary indexes.
        assert!(!store.has_progress_token("tok1").await);
        assert!(!store.has_active_routes_for_client("pk1").await);
    }

    #[tokio::test]
    async fn sweep_stale_routes_returns_zero_when_nothing_expired() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!(1), None)
            .await;

        let swept = store.sweep_stale_routes(Duration::from_secs(60)).await;
        assert!(swept.is_empty());
        assert!(store.has_event_route("e1").await);
    }
}