1use crate::rpc::{RpcError, RpcResponse};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use tokio::sync::oneshot;
10use serde::{Deserialize, Serialize};
11
12struct PendingRequest {
14 response_tx: oneshot::Sender<Result<RpcResponse<serde_json::Value>, RpcError>>,
16 timeout_at: Instant,
18 method: String,
20}
21
22#[derive(Clone)]
24pub struct RpcCorrelationManager {
25 pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
27 default_timeout: Duration,
29}
30
31impl RpcCorrelationManager {
32 pub fn new() -> Self {
34 Self::with_timeout(Duration::from_secs(30))
35 }
36
37 pub fn with_timeout(timeout: Duration) -> Self {
39 Self {
40 pending_requests: Arc::new(Mutex::new(HashMap::new())),
41 default_timeout: timeout,
42 }
43 }
44
45 pub fn register_request(
48 &self,
49 request_id: String,
50 method: String,
51 ) -> oneshot::Receiver<Result<RpcResponse<serde_json::Value>, RpcError>> {
52 let (response_tx, response_rx) = oneshot::channel();
53
54 let pending_request = PendingRequest {
55 response_tx,
56 timeout_at: Instant::now() + self.default_timeout,
57 method,
58 };
59
60 {
61 let mut pending = self.pending_requests.lock().unwrap();
62 pending.insert(request_id, pending_request);
63 }
64
65 response_rx
66 }
67
68 pub fn handle_response(&self, response: RpcResponse<serde_json::Value>) -> Result<(), RpcError> {
70 let mut pending = self.pending_requests.lock().unwrap();
71
72 if let Some(pending_request) = pending.remove(&response.id) {
73 if Instant::now() > pending_request.timeout_at {
75 return Err(RpcError {
76 code: -32603,
77 message: format!("Request {} timed out", response.id),
78 data: None,
79 });
80 }
81
82 match pending_request.response_tx.send(Ok(response)) {
84 Ok(_) => Ok(()),
85 Err(_) => Err(RpcError {
86 code: -32603,
87 message: "Caller dropped request before response arrived".to_string(),
88 data: None,
89 }),
90 }
91 } else {
92 Err(RpcError {
93 code: -32603,
94 message: format!("No pending request found for ID: {}", response.id),
95 data: None,
96 })
97 }
98 }
99
100 pub fn handle_error_response(&self, request_id: String, error: RpcError) -> Result<(), RpcError> {
102 let mut pending = self.pending_requests.lock().unwrap();
103
104 if let Some(pending_request) = pending.remove(&request_id) {
105 match pending_request.response_tx.send(Err(error)) {
107 Ok(_) => Ok(()),
108 Err(_) => Err(RpcError {
109 code: -32603,
110 message: "Caller dropped request before error response arrived".to_string(),
111 data: None,
112 }),
113 }
114 } else {
115 Err(RpcError {
116 code: -32603,
117 message: format!("No pending request found for error response ID: {}", request_id),
118 data: None,
119 })
120 }
121 }
122
123 pub fn cleanup_expired(&self) -> usize {
126 let mut pending = self.pending_requests.lock().unwrap();
127 let now = Instant::now();
128
129 let expired_ids: Vec<String> = pending
130 .iter()
131 .filter(|(_, request)| now > request.timeout_at)
132 .map(|(id, _)| id.clone())
133 .collect();
134
135 let cleanup_count = expired_ids.len();
136
137 for id in expired_ids {
138 if let Some(expired_request) = pending.remove(&id) {
139 let timeout_error = RpcError {
140 code: -32603,
141 message: format!("Request {} timed out after {:?}", id, self.default_timeout),
142 data: Some(serde_json::json!({
143 "method": expired_request.method,
144 "timeout_duration_secs": self.default_timeout.as_secs()
145 })),
146 };
147
148 let _ = expired_request.response_tx.send(Err(timeout_error));
150 }
151 }
152
153 cleanup_count
154 }
155
156 pub fn pending_count(&self) -> usize {
158 self.pending_requests.lock().unwrap().len()
159 }
160
161 pub fn pending_request_ids(&self) -> Vec<String> {
163 self.pending_requests.lock().unwrap().keys().cloned().collect()
164 }
165
166 pub fn cancel_request(&self, request_id: &str) -> bool {
168 let mut pending = self.pending_requests.lock().unwrap();
169
170 if let Some(cancelled_request) = pending.remove(request_id) {
171 let cancel_error = RpcError {
172 code: -32603,
173 message: format!("Request {} was cancelled", request_id),
174 data: None,
175 };
176
177 let _ = cancelled_request.response_tx.send(Err(cancel_error));
179 true
180 } else {
181 false
182 }
183 }
184
185 pub fn cancel_all(&self) -> usize {
187 let mut pending = self.pending_requests.lock().unwrap();
188 let count = pending.len();
189
190 for (request_id, cancelled_request) in pending.drain() {
191 let cancel_error = RpcError {
192 code: -32603,
193 message: format!("Request {} was cancelled due to shutdown", request_id),
194 data: None,
195 };
196
197 let _ = cancelled_request.response_tx.send(Err(cancel_error));
198 }
199
200 count
201 }
202}
203
204impl Default for RpcCorrelationManager {
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
210pub struct CorrelationCleanupTask {
212 manager: RpcCorrelationManager,
213 cleanup_interval: Duration,
214}
215
216impl CorrelationCleanupTask {
217 pub fn new(manager: RpcCorrelationManager) -> Self {
219 Self {
220 manager,
221 cleanup_interval: Duration::from_secs(10), }
223 }
224
225 pub fn with_interval(manager: RpcCorrelationManager, interval: Duration) -> Self {
227 Self {
228 manager,
229 cleanup_interval: interval,
230 }
231 }
232
233 pub async fn run(&self) {
235 let mut interval = tokio::time::interval(self.cleanup_interval);
236
237 loop {
238 interval.tick().await;
239 let cleaned_up = self.manager.cleanup_expired();
240
241 if cleaned_up > 0 {
242 tracing::debug!("Cleaned up {} expired RPC requests", cleaned_up);
243 }
244 }
245 }
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct CorrelationStats {
251 pub pending_requests: usize,
252 pub total_requests_processed: u64,
253 pub total_timeouts: u64,
254 pub total_cancellations: u64,
255 pub average_response_time_ms: f64,
256}
257
258impl CorrelationStats {
259 pub fn new() -> Self {
260 Self {
261 pending_requests: 0,
262 total_requests_processed: 0,
263 total_timeouts: 0,
264 total_cancellations: 0,
265 average_response_time_ms: 0.0,
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// A response for a registered ID reaches the waiting receiver and the
    /// request leaves the pending table.
    #[tokio::test]
    async fn test_correlation_manager_basic() {
        let manager = RpcCorrelationManager::new();
        let id = String::from("test_123");

        let receiver = manager.register_request(id.clone(), String::from("test_method"));
        assert_eq!(manager.pending_count(), 1);

        let routed = manager.handle_response(RpcResponse {
            id: id.clone(),
            result: Some(serde_json::json!({"success": true})),
            error: None,
        });
        assert!(routed.is_ok());

        let delivered = receiver.await.unwrap();
        assert!(delivered.is_ok());
        assert_eq!(delivered.unwrap().id, id);

        assert_eq!(manager.pending_count(), 0);
    }

    /// A request left unanswered past its deadline is swept by
    /// `cleanup_expired`, and its receiver gets a timeout error.
    #[tokio::test]
    async fn test_correlation_manager_timeout() {
        let manager = RpcCorrelationManager::with_timeout(Duration::from_millis(100));
        let receiver = manager.register_request(
            String::from("timeout_test"),
            String::from("timeout_method"),
        );

        // Wait well past the 100 ms deadline before sweeping.
        sleep(Duration::from_millis(200)).await;

        assert_eq!(manager.cleanup_expired(), 1);

        let delivered = receiver.await.unwrap();
        assert!(delivered.is_err());
        assert!(delivered.unwrap_err().message.contains("timed out"));
    }

    /// An error response is forwarded unchanged to the waiting receiver.
    #[tokio::test]
    async fn test_correlation_manager_error_response() {
        let manager = RpcCorrelationManager::new();
        let id = String::from("error_test");
        let receiver = manager.register_request(id.clone(), String::from("error_method"));

        let not_found = RpcError {
            code: 404,
            message: String::from("Method not found"),
            data: None,
        };
        assert!(manager.handle_error_response(id, not_found).is_ok());

        let delivered = receiver.await.unwrap();
        assert!(delivered.is_err());

        let forwarded = delivered.unwrap_err();
        assert_eq!(forwarded.code, 404);
        assert_eq!(forwarded.message, "Method not found");
    }

    /// Cancelling a pending request reports success and delivers a
    /// cancellation error to the receiver.
    #[tokio::test]
    async fn test_correlation_manager_cancellation() {
        let manager = RpcCorrelationManager::new();
        let id = String::from("cancel_test");
        let receiver = manager.register_request(id.clone(), String::from("cancel_method"));

        assert!(manager.cancel_request(&id));

        let delivered = receiver.await.unwrap();
        assert!(delivered.is_err());
        assert!(delivered.unwrap_err().message.contains("cancelled"));
    }
}