Skip to main content

contextvm_sdk/transport/client/
correlation_store.rs

1//! Client-side correlation store for tracking pending request event IDs.
2
3use std::num::NonZeroUsize;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use lru::LruCache;
8use tokio::sync::RwLock;
9
10use crate::core::constants::DEFAULT_LRU_SIZE;
11
12/// A pending request tracked by the correlation store.
13#[derive(Debug, Clone)]
14pub struct PendingRequest {
15    /// The original JSON-RPC request ID before event-ID replacement.
16    pub original_id: serde_json::Value,
17    /// Whether this request is an `initialize` handshake.
18    pub is_initialize: bool,
19    /// When the request was registered.
20    pub registered_at: Instant,
21}
22
23/// Tracks pending request event IDs and their original request IDs on the client side.
24///
25/// An optional capacity limit enables LRU eviction of the oldest entry when the
26/// store is full.
27#[derive(Clone)]
28pub struct ClientCorrelationStore {
29    pending_requests: Arc<RwLock<LruCache<String, PendingRequest>>>,
30}
31
32impl Default for ClientCorrelationStore {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl ClientCorrelationStore {
39    pub fn new() -> Self {
40        Self::with_max_pending(DEFAULT_LRU_SIZE)
41    }
42
43    /// Create a store with an upper bound on pending requests.
44    /// When the limit is reached the oldest entry is evicted.
45    pub fn with_max_pending(max_pending: usize) -> Self {
46        Self {
47            pending_requests: Arc::new(RwLock::new(LruCache::new(
48                NonZeroUsize::new(max_pending).unwrap_or(NonZeroUsize::new(1).unwrap()),
49            ))),
50        }
51    }
52
53    /// Register a pending request with its original JSON-RPC request ID.
54    pub async fn register(
55        &self,
56        event_id: String,
57        original_id: serde_json::Value,
58        is_initialize: bool,
59    ) {
60        self.pending_requests.write().await.push(
61            event_id,
62            PendingRequest {
63                original_id,
64                is_initialize,
65                registered_at: Instant::now(),
66            },
67        );
68    }
69
70    /// Check whether a given event ID corresponds to an `initialize` request.
71    pub async fn is_initialize_request(&self, event_id: &str) -> bool {
72        self.pending_requests
73            .read()
74            .await
75            .peek(event_id)
76            .is_some_and(|r| r.is_initialize)
77    }
78
79    pub async fn contains(&self, event_id: &str) -> bool {
80        self.pending_requests.read().await.contains(event_id)
81    }
82
83    /// Remove a pending request. Returns `true` if the key existed.
84    pub async fn remove(&self, event_id: &str) -> bool {
85        self.pending_requests.write().await.pop(event_id).is_some()
86    }
87
88    /// Retrieve the original request ID for a given event ID without removing it.
89    pub async fn get_original_id(&self, event_id: &str) -> Option<serde_json::Value> {
90        self.pending_requests
91            .read()
92            .await
93            .peek(event_id)
94            .map(|r| r.original_id.clone())
95    }
96
97    /// Number of pending requests currently tracked.
98    pub async fn count(&self) -> usize {
99        self.pending_requests.read().await.len()
100    }
101
102    /// Remove all entries older than `timeout`. Returns the number of entries removed.
103    pub async fn sweep_expired(&self, timeout: Duration) -> usize {
104        let now = Instant::now();
105        let mut cache = self.pending_requests.write().await;
106        let mut expired_keys = Vec::new();
107
108        for (key, entry) in cache.iter() {
109            if now.duration_since(entry.registered_at) >= timeout {
110                expired_keys.push(key.clone());
111            }
112        }
113
114        let count = expired_keys.len();
115        for key in expired_keys {
116            cache.pop(&key);
117        }
118        count
119    }
120
121    pub async fn clear(&self) {
122        self.pending_requests.write().await.clear();
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn remove_nonexistent_is_noop() {
        let store = ClientCorrelationStore::new();
        assert!(!store.remove("nonexistent").await);
        assert!(!store.contains("nonexistent").await);
    }

    #[tokio::test]
    async fn contains_after_clear() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        store
            .register("e2".into(), serde_json::Value::Null, false)
            .await;
        assert!(store.contains("e1").await);
        store.clear().await;
        assert!(!store.contains("e1").await);
        assert!(!store.contains("e2").await);
    }

    #[tokio::test]
    async fn register_and_remove_roundtrip() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        assert!(store.contains("e1").await);
        assert!(store.remove("e1").await);
        assert!(!store.contains("e1").await);
    }

    #[tokio::test]
    async fn default_store_is_bounded() {
        let store = ClientCorrelationStore::new();
        for i in 0..=DEFAULT_LRU_SIZE {
            store
                .register(format!("e{i}"), serde_json::Value::Null, false)
                .await;
        }

        assert_eq!(store.count().await, DEFAULT_LRU_SIZE);
        assert!(!store.contains("e0").await);
        assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await);
    }

    #[tokio::test]
    async fn zero_capacity_is_clamped_to_one() {
        let store = ClientCorrelationStore::with_max_pending(0);
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        store
            .register("e2".into(), serde_json::Value::Null, false)
            .await;

        assert_eq!(store.count().await, 1);
        assert!(!store.contains("e1").await);
        assert!(store.contains("e2").await);
    }

    #[tokio::test]
    async fn lookups_report_metadata_without_consuming_entries() {
        let store = ClientCorrelationStore::new();
        store
            .register("init".into(), serde_json::json!("req-1"), true)
            .await;
        store
            .register("call".into(), serde_json::json!(7), false)
            .await;

        assert!(store.is_initialize_request("init").await);
        assert!(!store.is_initialize_request("call").await);
        assert!(!store.is_initialize_request("missing").await);

        assert_eq!(
            store.get_original_id("init").await,
            Some(serde_json::json!("req-1"))
        );
        assert_eq!(store.get_original_id("call").await, Some(serde_json::json!(7)));
        assert_eq!(store.get_original_id("missing").await, None);

        // Neither lookup removes anything.
        assert_eq!(store.count().await, 2);
    }

    #[tokio::test]
    async fn clones_share_the_same_store() {
        let store = ClientCorrelationStore::new();
        let handle = store.clone();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;

        assert!(handle.contains("e1").await);
        assert!(handle.remove("e1").await);
        assert!(!store.contains("e1").await);
    }

    #[tokio::test]
    async fn sweep_expired_removes_only_stale_entries() {
        let store = ClientCorrelationStore::new();

        // Insert an entry that will be "old" by the time we sweep.
        store
            .register("old".into(), serde_json::json!(1), false)
            .await;

        // `sleep` waits at least this long, so "old" is guaranteed to exceed the threshold.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Insert a fresh entry.
        store
            .register("fresh".into(), serde_json::json!(2), false)
            .await;

        // The threshold is well below the sleep but leaves ample headroom for "fresh", so a
        // scheduling hiccup between registering and sweeping does not make it stale.
        let swept = store.sweep_expired(Duration::from_millis(50)).await;
        assert_eq!(swept, 1);
        assert!(!store.contains("old").await);
        assert!(store.contains("fresh").await);
    }

    #[tokio::test]
    async fn sweep_expired_returns_zero_when_nothing_expired() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;

        let swept = store.sweep_expired(Duration::from_secs(60)).await;
        assert_eq!(swept, 0);
        assert!(store.contains("e1").await);
    }
}