// leptos_ws_pro/rpc/correlation.rs

//! Real RPC request/response correlation system.
//!
//! Correlates outgoing RPC requests with the responses that arrive over the WebSocket,
//! enforcing a per-manager timeout and supporting cancellation.

5use crate::rpc::{RpcError, RpcResponse};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use tokio::sync::oneshot;
10use serde::{Deserialize, Serialize};
11
12/// Pending RPC request awaiting response
13struct PendingRequest {
14    /// Channel to send response back to caller
15    response_tx: oneshot::Sender<Result<RpcResponse<serde_json::Value>, RpcError>>,
16    /// When this request times out
17    timeout_at: Instant,
18    /// Method name for debugging
19    method: String,
20}
21
22/// RPC Correlation Manager handles request/response correlation
23#[derive(Clone)]
24pub struct RpcCorrelationManager {
25    /// Map of request ID -> pending request
26    pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
27    /// Default timeout for requests
28    default_timeout: Duration,
29}
30
31impl RpcCorrelationManager {
32    /// Create new correlation manager with default 30-second timeout
33    pub fn new() -> Self {
34        Self::with_timeout(Duration::from_secs(30))
35    }
36
37    /// Create correlation manager with custom timeout
38    pub fn with_timeout(timeout: Duration) -> Self {
39        Self {
40            pending_requests: Arc::new(Mutex::new(HashMap::new())),
41            default_timeout: timeout,
42        }
43    }
44
45    /// Register a new pending request
46    /// Returns a receiver that will get the response when it arrives
47    pub fn register_request(
48        &self,
49        request_id: String,
50        method: String,
51    ) -> oneshot::Receiver<Result<RpcResponse<serde_json::Value>, RpcError>> {
52        let (response_tx, response_rx) = oneshot::channel();
53
54        let pending_request = PendingRequest {
55            response_tx,
56            timeout_at: Instant::now() + self.default_timeout,
57            method,
58        };
59
60        {
61            let mut pending = self.pending_requests.lock().unwrap();
62            pending.insert(request_id, pending_request);
63        }
64
65        response_rx
66    }
67
68    /// Handle incoming RPC response, correlating it with pending request
69    pub fn handle_response(&self, response: RpcResponse<serde_json::Value>) -> Result<(), RpcError> {
70        let mut pending = self.pending_requests.lock().unwrap();
71
72        if let Some(pending_request) = pending.remove(&response.id) {
73            // Check if request has timed out
74            if Instant::now() > pending_request.timeout_at {
75                return Err(RpcError {
76                    code: -32603,
77                    message: format!("Request {} timed out", response.id),
78                    data: None,
79                });
80            }
81
82            // Send response back to caller
83            match pending_request.response_tx.send(Ok(response)) {
84                Ok(_) => Ok(()),
85                Err(_) => Err(RpcError {
86                    code: -32603,
87                    message: "Caller dropped request before response arrived".to_string(),
88                    data: None,
89                }),
90            }
91        } else {
92            Err(RpcError {
93                code: -32603,
94                message: format!("No pending request found for ID: {}", response.id),
95                data: None,
96            })
97        }
98    }
99
100    /// Handle incoming RPC error response
101    pub fn handle_error_response(&self, request_id: String, error: RpcError) -> Result<(), RpcError> {
102        let mut pending = self.pending_requests.lock().unwrap();
103
104        if let Some(pending_request) = pending.remove(&request_id) {
105            // Send error back to caller
106            match pending_request.response_tx.send(Err(error)) {
107                Ok(_) => Ok(()),
108                Err(_) => Err(RpcError {
109                    code: -32603,
110                    message: "Caller dropped request before error response arrived".to_string(),
111                    data: None,
112                }),
113            }
114        } else {
115            Err(RpcError {
116                code: -32603,
117                message: format!("No pending request found for error response ID: {}", request_id),
118                data: None,
119            })
120        }
121    }
122
123    /// Clean up expired/timed out requests
124    /// Returns number of requests cleaned up
125    pub fn cleanup_expired(&self) -> usize {
126        let mut pending = self.pending_requests.lock().unwrap();
127        let now = Instant::now();
128
129        let expired_ids: Vec<String> = pending
130            .iter()
131            .filter(|(_, request)| now > request.timeout_at)
132            .map(|(id, _)| id.clone())
133            .collect();
134
135        let cleanup_count = expired_ids.len();
136
137        for id in expired_ids {
138            if let Some(expired_request) = pending.remove(&id) {
139                let timeout_error = RpcError {
140                    code: -32603,
141                    message: format!("Request {} timed out after {:?}", id, self.default_timeout),
142                    data: Some(serde_json::json!({
143                        "method": expired_request.method,
144                        "timeout_duration_secs": self.default_timeout.as_secs()
145                    })),
146                };
147
148                // Try to notify caller of timeout (may fail if caller dropped)
149                let _ = expired_request.response_tx.send(Err(timeout_error));
150            }
151        }
152
153        cleanup_count
154    }
155
156    /// Get number of currently pending requests
157    pub fn pending_count(&self) -> usize {
158        self.pending_requests.lock().unwrap().len()
159    }
160
161    /// Get list of pending request IDs (for debugging)
162    pub fn pending_request_ids(&self) -> Vec<String> {
163        self.pending_requests.lock().unwrap().keys().cloned().collect()
164    }
165
166    /// Cancel a specific pending request
167    pub fn cancel_request(&self, request_id: &str) -> bool {
168        let mut pending = self.pending_requests.lock().unwrap();
169
170        if let Some(cancelled_request) = pending.remove(request_id) {
171            let cancel_error = RpcError {
172                code: -32603,
173                message: format!("Request {} was cancelled", request_id),
174                data: None,
175            };
176
177            // Notify caller of cancellation
178            let _ = cancelled_request.response_tx.send(Err(cancel_error));
179            true
180        } else {
181            false
182        }
183    }
184
185    /// Cancel all pending requests
186    pub fn cancel_all(&self) -> usize {
187        let mut pending = self.pending_requests.lock().unwrap();
188        let count = pending.len();
189
190        for (request_id, cancelled_request) in pending.drain() {
191            let cancel_error = RpcError {
192                code: -32603,
193                message: format!("Request {} was cancelled due to shutdown", request_id),
194                data: None,
195            };
196
197            let _ = cancelled_request.response_tx.send(Err(cancel_error));
198        }
199
200        count
201    }
202}
203
204impl Default for RpcCorrelationManager {
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
210/// Background task that periodically cleans up expired requests
211pub struct CorrelationCleanupTask {
212    manager: RpcCorrelationManager,
213    cleanup_interval: Duration,
214}
215
216impl CorrelationCleanupTask {
217    /// Create new cleanup task
218    pub fn new(manager: RpcCorrelationManager) -> Self {
219        Self {
220            manager,
221            cleanup_interval: Duration::from_secs(10), // Clean up every 10 seconds
222        }
223    }
224
225    /// Create cleanup task with custom interval
226    pub fn with_interval(manager: RpcCorrelationManager, interval: Duration) -> Self {
227        Self {
228            manager,
229            cleanup_interval: interval,
230        }
231    }
232
233    /// Run the cleanup task (should be spawned as background task)
234    pub async fn run(&self) {
235        let mut interval = tokio::time::interval(self.cleanup_interval);
236
237        loop {
238            interval.tick().await;
239            let cleaned_up = self.manager.cleanup_expired();
240
241            if cleaned_up > 0 {
242                tracing::debug!("Cleaned up {} expired RPC requests", cleaned_up);
243            }
244        }
245    }
246}
247
248/// Statistics about correlation manager performance
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct CorrelationStats {
251    pub pending_requests: usize,
252    pub total_requests_processed: u64,
253    pub total_timeouts: u64,
254    pub total_cancellations: u64,
255    pub average_response_time_ms: f64,
256}
257
258impl CorrelationStats {
259    pub fn new() -> Self {
260        Self {
261            pending_requests: 0,
262            total_requests_processed: 0,
263            total_timeouts: 0,
264            total_cancellations: 0,
265            average_response_time_ms: 0.0,
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Build a successful response carrying `id`.
    fn success_response(id: &str) -> RpcResponse<serde_json::Value> {
        RpcResponse {
            id: id.to_string(),
            result: Some(serde_json::json!({"success": true})),
            error: None,
        }
    }

    #[tokio::test]
    async fn test_correlation_manager_basic() {
        let manager = RpcCorrelationManager::new();

        // Register a request
        let request_id = "test_123".to_string();
        let method = "test_method".to_string();
        let response_rx = manager.register_request(request_id.clone(), method.clone());

        assert_eq!(manager.pending_count(), 1);

        // Handle the simulated response
        assert!(manager.handle_response(success_response(&request_id)).is_ok());

        // Should have received response
        let result = response_rx.await.unwrap();
        assert!(result.is_ok());
        let rpc_response = result.unwrap();
        assert_eq!(rpc_response.id, request_id);

        // Should no longer be pending
        assert_eq!(manager.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_correlation_manager_timeout() {
        let manager = RpcCorrelationManager::with_timeout(Duration::from_millis(100));

        // Register a request
        let request_id = "timeout_test".to_string();
        let method = "timeout_method".to_string();
        let response_rx = manager.register_request(request_id.clone(), method);

        // Wait for timeout
        sleep(Duration::from_millis(200)).await;

        // Clean up expired requests
        let cleaned_up = manager.cleanup_expired();
        assert_eq!(cleaned_up, 1);

        // Should have received timeout error
        let result = response_rx.await.unwrap();
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.message.contains("timed out"));
    }

    #[tokio::test]
    async fn test_cleanup_expired_keeps_live_requests() {
        let manager = RpcCorrelationManager::new();
        let _response_rx = manager.register_request("live".to_string(), "m".to_string());

        assert_eq!(manager.cleanup_expired(), 0);
        assert_eq!(manager.pending_count(), 1);
    }

    #[tokio::test]
    async fn test_correlation_manager_error_response() {
        let manager = RpcCorrelationManager::new();

        // Register a request
        let request_id = "error_test".to_string();
        let method = "error_method".to_string();
        let response_rx = manager.register_request(request_id.clone(), method);

        // Simulate error response
        let error = RpcError {
            code: 404,
            message: "Method not found".to_string(),
            data: None,
        };

        assert!(manager.handle_error_response(request_id, error.clone()).is_ok());

        // Should have received error
        let result = response_rx.await.unwrap();
        assert!(result.is_err());
        let received_error = result.unwrap_err();
        assert_eq!(received_error.code, 404);
        assert_eq!(received_error.message, "Method not found");
    }

    #[tokio::test]
    async fn test_handle_response_without_pending_request() {
        let manager = RpcCorrelationManager::new();

        let error = manager
            .handle_response(success_response("unknown_id"))
            .unwrap_err();
        assert!(error.message.contains("No pending request found"));
        assert_eq!(manager.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_handle_error_response_without_pending_request() {
        let manager = RpcCorrelationManager::new();
        let error = RpcError {
            code: 404,
            message: "Method not found".to_string(),
            data: None,
        };

        let outcome = manager.handle_error_response("unknown_id".to_string(), error);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_correlation_manager_cancellation() {
        let manager = RpcCorrelationManager::new();

        // Register a request
        let request_id = "cancel_test".to_string();
        let method = "cancel_method".to_string();
        let response_rx = manager.register_request(request_id.clone(), method);

        // Cancel the request
        let cancelled = manager.cancel_request(&request_id);
        assert!(cancelled);

        // Should have received cancellation error
        let result = response_rx.await.unwrap();
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.message.contains("cancelled"));
    }

    #[tokio::test]
    async fn test_cancel_unknown_request_returns_false() {
        let manager = RpcCorrelationManager::new();
        assert!(!manager.cancel_request("never_registered"));
    }

    #[tokio::test]
    async fn test_cancel_all_notifies_every_caller() {
        let manager = RpcCorrelationManager::new();
        let first_rx = manager.register_request("first".to_string(), "m".to_string());
        let second_rx = manager.register_request("second".to_string(), "m".to_string());

        let mut ids = manager.pending_request_ids();
        ids.sort();
        assert_eq!(ids, vec!["first".to_string(), "second".to_string()]);

        assert_eq!(manager.cancel_all(), 2);
        assert_eq!(manager.pending_count(), 0);

        for rx in vec![first_rx, second_rx] {
            let error = rx.await.unwrap().unwrap_err();
            assert!(error.message.contains("cancelled due to shutdown"));
        }
    }
}