contextvm_sdk/transport/client/
correlation_store.rs1use std::num::NonZeroUsize;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use lru::LruCache;
8use tokio::sync::RwLock;
9
10use crate::core::constants::DEFAULT_LRU_SIZE;
11
12#[derive(Debug, Clone)]
14pub struct PendingRequest {
15 pub original_id: serde_json::Value,
17 pub is_initialize: bool,
19 pub registered_at: Instant,
21}
22
23#[derive(Clone)]
28pub struct ClientCorrelationStore {
29 pending_requests: Arc<RwLock<LruCache<String, PendingRequest>>>,
30}
31
32impl Default for ClientCorrelationStore {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl ClientCorrelationStore {
39 pub fn new() -> Self {
40 Self::with_max_pending(DEFAULT_LRU_SIZE)
41 }
42
43 pub fn with_max_pending(max_pending: usize) -> Self {
46 Self {
47 pending_requests: Arc::new(RwLock::new(LruCache::new(
48 NonZeroUsize::new(max_pending).unwrap_or(NonZeroUsize::new(1).unwrap()),
49 ))),
50 }
51 }
52
53 pub async fn register(
55 &self,
56 event_id: String,
57 original_id: serde_json::Value,
58 is_initialize: bool,
59 ) {
60 self.pending_requests.write().await.push(
61 event_id,
62 PendingRequest {
63 original_id,
64 is_initialize,
65 registered_at: Instant::now(),
66 },
67 );
68 }
69
70 pub async fn is_initialize_request(&self, event_id: &str) -> bool {
72 self.pending_requests
73 .read()
74 .await
75 .peek(event_id)
76 .is_some_and(|r| r.is_initialize)
77 }
78
79 pub async fn contains(&self, event_id: &str) -> bool {
80 self.pending_requests.read().await.contains(event_id)
81 }
82
83 pub async fn remove(&self, event_id: &str) -> bool {
85 self.pending_requests.write().await.pop(event_id).is_some()
86 }
87
88 pub async fn get_original_id(&self, event_id: &str) -> Option<serde_json::Value> {
90 self.pending_requests
91 .read()
92 .await
93 .peek(event_id)
94 .map(|r| r.original_id.clone())
95 }
96
97 pub async fn count(&self) -> usize {
99 self.pending_requests.read().await.len()
100 }
101
102 pub async fn sweep_expired(&self, timeout: Duration) -> usize {
104 let now = Instant::now();
105 let mut cache = self.pending_requests.write().await;
106 let mut expired_keys = Vec::new();
107
108 for (key, entry) in cache.iter() {
109 if now.duration_since(entry.registered_at) >= timeout {
110 expired_keys.push(key.clone());
111 }
112 }
113
114 let count = expired_keys.len();
115 for key in expired_keys {
116 cache.pop(&key);
117 }
118 count
119 }
120
121 pub async fn clear(&self) {
122 self.pending_requests.write().await.clear();
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn remove_nonexistent_is_noop() {
        let store = ClientCorrelationStore::new();
        assert!(!store.remove("nonexistent").await);
        assert!(!store.contains("nonexistent").await);
    }

    #[tokio::test]
    async fn contains_after_clear() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        store
            .register("e2".into(), serde_json::Value::Null, false)
            .await;
        assert!(store.contains("e1").await);
        store.clear().await;
        assert!(!store.contains("e1").await);
        assert!(!store.contains("e2").await);
    }

    #[tokio::test]
    async fn register_and_remove_roundtrip() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        assert!(store.contains("e1").await);
        assert!(store.remove("e1").await);
        assert!(!store.contains("e1").await);
    }

    #[tokio::test]
    async fn default_store_is_bounded() {
        let store = ClientCorrelationStore::new();
        for i in 0..=DEFAULT_LRU_SIZE {
            store
                .register(format!("e{i}"), serde_json::Value::Null, false)
                .await;
        }

        assert_eq!(store.count().await, DEFAULT_LRU_SIZE);
        assert!(!store.contains("e0").await);
        assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await);
    }

    #[tokio::test]
    async fn zero_capacity_is_clamped_to_one() {
        let store = ClientCorrelationStore::with_max_pending(0);
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;
        store
            .register("e2".into(), serde_json::Value::Null, false)
            .await;

        assert_eq!(store.count().await, 1);
        assert!(!store.contains("e1").await);
        assert!(store.contains("e2").await);
    }

    #[tokio::test]
    async fn lookups_report_initialize_flag_and_original_id() {
        let store = ClientCorrelationStore::new();
        store
            .register("init".into(), serde_json::json!(7), true)
            .await;
        store
            .register("call".into(), serde_json::json!("abc"), false)
            .await;

        assert!(store.is_initialize_request("init").await);
        assert!(!store.is_initialize_request("call").await);
        assert!(!store.is_initialize_request("missing").await);

        assert_eq!(
            store.get_original_id("init").await,
            Some(serde_json::json!(7))
        );
        assert_eq!(
            store.get_original_id("call").await,
            Some(serde_json::json!("abc"))
        );
        assert_eq!(store.get_original_id("missing").await, None);
    }

    #[tokio::test]
    async fn register_same_event_id_replaces_entry() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::json!(1), true)
            .await;
        store
            .register("e1".into(), serde_json::json!(2), false)
            .await;

        assert_eq!(store.count().await, 1);
        assert_eq!(
            store.get_original_id("e1").await,
            Some(serde_json::json!(2))
        );
        assert!(!store.is_initialize_request("e1").await);
    }

    #[tokio::test]
    async fn sweep_expired_removes_only_stale_entries() {
        // `tokio::time::sleep` does not return before its duration elapses, so
        // `old` is at least STALE_AFTER old when swept. `fresh` is only swept
        // if the test stalls for a full TIMEOUT between registering and
        // sweeping; the wide margin keeps this stable on loaded CI machines.
        const STALE_AFTER: Duration = Duration::from_millis(50);
        const TIMEOUT: Duration = Duration::from_millis(25);

        let store = ClientCorrelationStore::new();

        store
            .register("old".into(), serde_json::json!(1), false)
            .await;

        tokio::time::sleep(STALE_AFTER).await;

        store
            .register("fresh".into(), serde_json::json!(2), false)
            .await;

        let swept = store.sweep_expired(TIMEOUT).await;
        assert_eq!(swept, 1);
        assert!(!store.contains("old").await);
        assert!(store.contains("fresh").await);
    }

    #[tokio::test]
    async fn sweep_expired_returns_zero_when_nothing_expired() {
        let store = ClientCorrelationStore::new();
        store
            .register("e1".into(), serde_json::Value::Null, false)
            .await;

        let swept = store.sweep_expired(Duration::from_secs(60)).await;
        assert_eq!(swept, 0);
        assert!(store.contains("e1").await);
    }
}