Skip to main content

contextvm_sdk/transport/server/
correlation_store.rs

1//! Server-side event route store for mapping event IDs to client routes.
2
3use std::collections::{HashMap, HashSet};
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use lru::LruCache;
9use tokio::sync::RwLock;
10
11use crate::core::constants::DEFAULT_LRU_SIZE;
12
13/// A route entry for an in-flight request.
14#[derive(Debug, Clone)]
15pub struct RouteEntry {
16    /// The client's public key that originated this request.
17    pub client_pubkey: String,
18    /// The original JSON-RPC request ID (before replacement with event ID).
19    pub original_request_id: serde_json::Value,
20    /// Optional progress token for this request.
21    pub progress_token: Option<String>,
22    /// The outer gift-wrap event kind that carried this request (e.g. 1059 or 21059).
23    /// Populated from the inbound event in a later PR; `None` until then.
24    pub wrap_kind: Option<u16>,
25    /// When the route was registered.
26    pub registered_at: Instant,
27}
28
29/// Internal state behind the lock.
30struct Inner {
31    /// Primary index: event_id → route entry (LRU-ordered).
32    routes: LruCache<String, RouteEntry>,
33    /// Secondary index: progress_token → event_id.
34    progress_token_to_event: HashMap<String, String>,
35    /// Secondary index: client_pubkey → set of event_ids.
36    client_event_ids: HashMap<String, HashSet<String>>,
37}
38
39impl Inner {
40    fn new(max_routes: usize) -> Self {
41        let routes =
42            LruCache::new(NonZeroUsize::new(max_routes).unwrap_or(NonZeroUsize::new(1).unwrap()));
43        Self {
44            routes,
45            progress_token_to_event: HashMap::new(),
46            client_event_ids: HashMap::new(),
47        }
48    }
49
50    /// Clean up secondary indexes for a removed route.
51    fn cleanup_indexes(&mut self, event_id: &str, route: &RouteEntry) {
52        if let Some(ref token) = route.progress_token {
53            self.progress_token_to_event.remove(token);
54        }
55        if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) {
56            set.remove(event_id);
57            if set.is_empty() {
58                self.client_event_ids.remove(&route.client_pubkey);
59            }
60        }
61    }
62
63    /// Remove a single route and clean up all secondary indexes.
64    fn remove_route(&mut self, event_id: &str) -> Option<RouteEntry> {
65        let route = self.routes.pop(event_id)?;
66        self.cleanup_indexes(event_id, &route);
67        Some(route)
68    }
69}
70
71/// Maps event IDs to full route entries for response routing on the server side.
72///
73/// An optional capacity limit enables LRU eviction; when the limit is reached
74/// the oldest entry is evicted and its secondary indexes are cleaned up.
75#[derive(Clone)]
76pub struct ServerEventRouteStore {
77    inner: Arc<RwLock<Inner>>,
78}
79
80impl Default for ServerEventRouteStore {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl ServerEventRouteStore {
87    pub fn new() -> Self {
88        Self {
89            inner: Arc::new(RwLock::new(Inner::new(DEFAULT_LRU_SIZE))),
90        }
91    }
92
93    /// Create a store with an upper bound on event routes.
94    /// When the limit is reached the oldest entry is evicted.
95    pub fn with_max_routes(max_routes: usize) -> Self {
96        Self {
97            inner: Arc::new(RwLock::new(Inner::new(max_routes))),
98        }
99    }
100
101    /// Register a route for an incoming request.
102    pub async fn register(
103        &self,
104        event_id: String,
105        client_pubkey: String,
106        original_request_id: serde_json::Value,
107        progress_token: Option<String>,
108    ) {
109        let mut inner = self.inner.write().await;
110
111        // Update client index.
112        inner
113            .client_event_ids
114            .entry(client_pubkey.clone())
115            .or_default()
116            .insert(event_id.clone());
117
118        // Update progress token index.
119        if let Some(ref token) = progress_token {
120            inner
121                .progress_token_to_event
122                .insert(token.clone(), event_id.clone());
123        }
124
125        // Insert into LRU; handle possible eviction.
126        let evicted = inner.routes.push(
127            event_id.clone(),
128            RouteEntry {
129                client_pubkey,
130                original_request_id,
131                progress_token,
132                wrap_kind: None,
133                registered_at: Instant::now(),
134            },
135        );
136
137        if let Some((evicted_key, evicted_route)) = evicted {
138            if evicted_key != event_id {
139                // A different entry was evicted due to capacity — clean up its indexes.
140                inner.cleanup_indexes(&evicted_key, &evicted_route);
141            }
142        }
143    }
144
145    /// Returns the client public key for the given event ID without removing it.
146    pub async fn get(&self, event_id: &str) -> Option<String> {
147        self.inner
148            .read()
149            .await
150            .routes
151            .peek(event_id)
152            .map(|r| r.client_pubkey.clone())
153    }
154
155    /// Returns the full route entry for the given event ID without removing it.
156    pub async fn get_route(&self, event_id: &str) -> Option<RouteEntry> {
157        self.inner.read().await.routes.peek(event_id).cloned()
158    }
159
160    /// Removes and returns the full route entry for the given event ID.
161    pub async fn pop(&self, event_id: &str) -> Option<RouteEntry> {
162        self.inner.write().await.remove_route(event_id)
163    }
164
165    /// Removes all routes for a given client public key. Returns the count removed.
166    pub async fn remove_for_client(&self, client_pubkey: &str) -> usize {
167        let mut inner = self.inner.write().await;
168
169        let event_ids = match inner.client_event_ids.remove(client_pubkey) {
170            Some(ids) => ids,
171            None => return 0,
172        };
173
174        let count = event_ids.len();
175        for event_id in &event_ids {
176            if let Some(route) = inner.routes.pop(event_id.as_str()) {
177                if let Some(ref token) = route.progress_token {
178                    inner.progress_token_to_event.remove(token);
179                }
180            }
181        }
182        count
183    }
184
185    /// Check whether a route exists for the given event ID.
186    pub async fn has_event_route(&self, event_id: &str) -> bool {
187        self.inner.read().await.routes.contains(event_id)
188    }
189
190    /// Check whether the given client has any active routes.
191    pub async fn has_active_routes_for_client(&self, client_pubkey: &str) -> bool {
192        self.inner
193            .read()
194            .await
195            .client_event_ids
196            .get(client_pubkey)
197            .is_some_and(|set| !set.is_empty())
198    }
199
200    /// Look up the event ID associated with a progress token.
201    pub async fn get_event_id_by_progress_token(&self, token: &str) -> Option<String> {
202        self.inner
203            .read()
204            .await
205            .progress_token_to_event
206            .get(token)
207            .cloned()
208    }
209
210    /// Check whether a progress token mapping exists.
211    pub async fn has_progress_token(&self, token: &str) -> bool {
212        self.inner
213            .read()
214            .await
215            .progress_token_to_event
216            .contains_key(token)
217    }
218
219    /// Number of event routes currently tracked.
220    pub async fn event_route_count(&self) -> usize {
221        self.inner.read().await.routes.len()
222    }
223
224    /// Number of progress token mappings currently tracked.
225    pub async fn progress_token_count(&self) -> usize {
226        self.inner.read().await.progress_token_to_event.len()
227    }
228
229    /// Remove all route entries older than `timeout`.
230    /// (Routes for expired sessions are already cleaned by `cleanup_sessions`.)
231    /// Returns the event IDs of the removed entries.
232    pub async fn sweep_stale_routes(&self, timeout: Duration) -> Vec<String> {
233        let now = Instant::now();
234        let mut inner = self.inner.write().await;
235        let mut expired_keys = Vec::new();
236
237        for (key, entry) in inner.routes.iter() {
238            if now.duration_since(entry.registered_at) >= timeout {
239                expired_keys.push(key.clone());
240            }
241        }
242
243        for key in &expired_keys {
244            inner.remove_route(key);
245        }
246        expired_keys
247    }
248
249    pub async fn clear(&self) {
250        let mut inner = self.inner.write().await;
251        inner.routes.clear();
252        inner.progress_token_to_event.clear();
253        inner.client_event_ids.clear();
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn pop_on_empty_returns_none() {
        let store = ServerEventRouteStore::new();
        assert!(store.pop("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn get_returns_without_removing() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), None)
            .await;
        assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
        assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
    }

    #[tokio::test]
    async fn pop_removes_entry() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), None)
            .await;
        let route = store.pop("e1").await.unwrap();
        assert_eq!(route.client_pubkey, "pk1");
        assert!(store.pop("e1").await.is_none());
    }

    #[tokio::test]
    async fn pop_cleans_secondary_indexes() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), Some("tok1".into()))
            .await;
        assert!(store.has_progress_token("tok1").await);
        assert!(store.has_active_routes_for_client("pk1").await);

        assert!(store.pop("e1").await.is_some());

        assert!(!store.has_progress_token("tok1").await);
        assert!(!store.has_active_routes_for_client("pk1").await);
    }

    #[tokio::test]
    async fn remove_for_client_only_removes_matching() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), None)
            .await;
        store
            .register("e2".into(), "pk2".into(), json!("r2"), None)
            .await;
        store
            .register("e3".into(), "pk1".into(), json!("r3"), None)
            .await;

        let removed = store.remove_for_client("pk1").await;
        assert_eq!(removed, 2);

        assert!(store.get("e1").await.is_none());
        assert!(store.get("e3").await.is_none());
        assert_eq!(store.get("e2").await.as_deref(), Some("pk2"));
    }

    #[tokio::test]
    async fn remove_for_client_clears_progress_tokens() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!(1), Some("tok1".into()))
            .await;
        store
            .register("e2".into(), "pk1".into(), json!(2), Some("tok2".into()))
            .await;
        store
            .register("e3".into(), "pk2".into(), json!(3), Some("tok3".into()))
            .await;

        assert_eq!(store.remove_for_client("pk1").await, 2);

        assert_eq!(store.progress_token_count().await, 1);
        assert!(!store.has_progress_token("tok1").await);
        assert!(!store.has_progress_token("tok2").await);
        assert_eq!(
            store.get_event_id_by_progress_token("tok3").await.as_deref(),
            Some("e3")
        );
        assert!(!store.has_active_routes_for_client("pk1").await);
        assert!(store.has_active_routes_for_client("pk2").await);
    }

    #[tokio::test]
    async fn remove_for_client_noop_when_no_match() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), None)
            .await;
        let removed = store.remove_for_client("pk_other").await;
        assert_eq!(removed, 0);
        assert_eq!(store.get("e1").await.as_deref(), Some("pk1"));
    }

    #[tokio::test]
    async fn clear_empties_store() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!("r1"), None)
            .await;
        store
            .register("e2".into(), "pk2".into(), json!("r2"), None)
            .await;
        store.clear().await;
        assert!(store.get("e1").await.is_none());
        assert!(store.get("e2").await.is_none());
    }

    #[tokio::test]
    async fn default_store_is_bounded() {
        let store = ServerEventRouteStore::new();
        for i in 0..=DEFAULT_LRU_SIZE {
            store
                .register(format!("e{i}"), "pk1".into(), json!(i), None)
                .await;
        }

        assert_eq!(store.event_route_count().await, DEFAULT_LRU_SIZE);
        assert!(!store.has_event_route("e0").await);
        assert!(store.has_event_route(&format!("e{DEFAULT_LRU_SIZE}")).await);
    }

    #[tokio::test]
    async fn capacity_eviction_cleans_evicted_indexes() {
        let store = ServerEventRouteStore::with_max_routes(1);
        store
            .register("e1".into(), "pk1".into(), json!(1), Some("tok1".into()))
            .await;
        store
            .register("e2".into(), "pk2".into(), json!(2), Some("tok2".into()))
            .await;

        assert!(!store.has_event_route("e1").await);
        assert!(!store.has_progress_token("tok1").await);
        assert!(!store.has_active_routes_for_client("pk1").await);

        assert!(store.has_event_route("e2").await);
        assert!(store.has_active_routes_for_client("pk2").await);
        assert_eq!(
            store.get_event_id_by_progress_token("tok2").await.as_deref(),
            Some("e2")
        );
    }

    #[tokio::test]
    async fn zero_capacity_is_clamped_to_one() {
        let store = ServerEventRouteStore::with_max_routes(0);
        store
            .register("e1".into(), "pk1".into(), json!(1), None)
            .await;
        store
            .register("e2".into(), "pk1".into(), json!(2), None)
            .await;

        assert_eq!(store.event_route_count().await, 1);
        assert!(!store.has_event_route("e1").await);
        assert!(store.has_event_route("e2").await);
    }

    #[tokio::test]
    async fn sweep_stale_routes_removes_only_expired() {
        let store = ServerEventRouteStore::new();

        // Insert a route that will age past the threshold.
        store
            .register("old".into(), "pk1".into(), json!(1), Some("tok1".into()))
            .await;

        tokio::time::sleep(Duration::from_millis(20)).await;

        // Insert a fresh route.
        store
            .register("fresh".into(), "pk2".into(), json!(2), None)
            .await;

        // Sweep with 10ms timeout — "old" should be removed, "fresh" should remain.
        let swept = store.sweep_stale_routes(Duration::from_millis(10)).await;
        assert_eq!(swept.len(), 1);
        assert_eq!(swept[0], "old");
        assert!(!store.has_event_route("old").await);
        assert!(store.has_event_route("fresh").await);
        // Secondary indexes should also be cleaned.
        assert!(!store.has_progress_token("tok1").await);
        assert!(!store.has_active_routes_for_client("pk1").await);
    }

    #[tokio::test]
    async fn sweep_stale_routes_returns_zero_when_nothing_expired() {
        let store = ServerEventRouteStore::new();
        store
            .register("e1".into(), "pk1".into(), json!(1), None)
            .await;

        let swept = store.sweep_stale_routes(Duration::from_secs(60)).await;
        assert!(swept.is_empty());
        assert!(store.has_event_route("e1").await);
    }
}