Skip to main content

chainrpc_core/
dedup.rs

//! Request deduplication — coalesces identical in-flight RPC requests.
//!
//! When multiple callers issue the same request concurrently, only one actual
//! transport call is made. Every waiter receives a clone of the response; if
//! the call fails, the other waiters receive the error as its string message.
//!
//! # Key design
//!
//! - A request is identified by `hash(method, params)` (same scheme as cache).
//! - In-flight tracking uses `tokio::sync::watch` channels: the first caller
//!   creates a channel and sends the result; subsequent callers subscribe.
//! - After the result is broadcast, the entry is removed from the pending map.
12
13use std::collections::hash_map::DefaultHasher;
14use std::collections::HashMap;
15use std::hash::{Hash, Hasher};
16use std::sync::{Arc, Mutex};
17
18use tokio::sync::watch;
19
20use crate::error::TransportError;
21use crate::request::{JsonRpcRequest, JsonRpcResponse};
22use crate::transport::RpcTransport;
23
24// ---------------------------------------------------------------------------
25// DedupTransport
26// ---------------------------------------------------------------------------
27
28/// Deduplicates identical in-flight RPC requests.
29///
30/// If two tasks call `send()` with the same `(method, params)` at the same
31/// time, only one transport call is made.  Both tasks receive a clone of the
32/// response (or the same error message).
33pub struct DedupTransport {
34    inner: Arc<dyn RpcTransport>,
35    /// Map from request-key to a watch receiver.
36    ///
37    /// The channel starts with `None` and is set to `Some(result)` once the
38    /// in-flight request completes.
39    pending: Mutex<HashMap<u64, watch::Receiver<Option<DedupResult>>>>,
40}
41
42/// The result type stored inside the watch channel.
43///
44/// We cannot clone `TransportError` (it doesn't derive Clone), so we
45/// represent errors as a string and re-wrap them on the receiving side.
46type DedupResult = Result<JsonRpcResponse, String>;
47
48impl DedupTransport {
49    /// Wrap an inner transport with request deduplication.
50    pub fn new(inner: Arc<dyn RpcTransport>) -> Self {
51        Self {
52            inner,
53            pending: Mutex::new(HashMap::new()),
54        }
55    }
56
57    /// Send a request, deduplicating identical in-flight requests.
58    pub async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
59        let key = dedup_key(&req.method, &req.params);
60
61        // Fast path: check if there is already an in-flight request.
62        // We extract the receiver (if any) while holding the lock, then
63        // drop the lock before awaiting.
64        let existing_rx = {
65            let pending = self.pending.lock().unwrap();
66            pending.get(&key).cloned()
67        };
68
69        if let Some(mut rx) = existing_rx {
70            return self.wait_for_result(&mut rx).await;
71        }
72
73        // Slow path: we are the first caller for this key.
74        let (tx, rx) = watch::channel(None);
75
76        // Double-check under the write lock: another task may have inserted
77        // between our read and this write.
78        let coalesce_rx = {
79            let mut pending = self.pending.lock().unwrap();
80            if let Some(existing) = pending.get(&key) {
81                Some(existing.clone())
82            } else {
83                pending.insert(key, rx);
84                None
85            }
86        };
87
88        if let Some(mut rx) = coalesce_rx {
89            return self.wait_for_result(&mut rx).await;
90        }
91
92        // Perform the actual request.
93        let result = self.inner.send(req).await;
94
95        // Broadcast the result to all waiters.
96        let dedup_result: DedupResult = match &result {
97            Ok(resp) => Ok(resp.clone()),
98            Err(e) => Err(e.to_string()),
99        };
100        // Ignore send errors (no receivers left).
101        let _ = tx.send(Some(dedup_result));
102
103        // Clean up the pending map.
104        {
105            let mut pending = self.pending.lock().unwrap();
106            pending.remove(&key);
107        }
108
109        tracing::debug!("dedup: completed request (key={key:#018x})");
110        result
111    }
112
113    /// Number of currently in-flight deduplicated requests.
114    pub fn in_flight_count(&self) -> usize {
115        let pending = self.pending.lock().unwrap();
116        pending.len()
117    }
118
119    // -- internal -----------------------------------------------------------
120
121    async fn wait_for_result(
122        &self,
123        rx: &mut watch::Receiver<Option<DedupResult>>,
124    ) -> Result<JsonRpcResponse, TransportError> {
125        // Wait until the value changes from `None` to `Some(...)`.
126        loop {
127            // Check the current value first.
128            {
129                let val = rx.borrow();
130                if let Some(ref result) = *val {
131                    tracing::debug!("dedup: coalesced request");
132                    return match result {
133                        Ok(resp) => Ok(resp.clone()),
134                        Err(msg) => Err(TransportError::Other(msg.clone())),
135                    };
136                }
137            }
138
139            // Wait for the next change.
140            if rx.changed().await.is_err() {
141                // Sender dropped without sending — should not happen.
142                return Err(TransportError::Other(
143                    "dedup: sender dropped without result".into(),
144                ));
145            }
146        }
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Hashing helper
152// ---------------------------------------------------------------------------
153
154fn dedup_key(method: &str, params: &[serde_json::Value]) -> u64 {
155    let mut hasher = DefaultHasher::new();
156    method.hash(&mut hasher);
157    let params_str = serde_json::to_string(params).unwrap_or_default();
158    params_str.hash(&mut hasher);
159    hasher.finish()
160}
161
162// ---------------------------------------------------------------------------
163// Tests
164// ---------------------------------------------------------------------------
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::{JsonRpcRequest, JsonRpcResponse, RpcId};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// A mock transport that counts calls and has an optional delay.
    struct SlowCountingTransport {
        call_count: AtomicU64,
        delay: std::time::Duration,
    }

    impl SlowCountingTransport {
        fn new(delay: std::time::Duration) -> Self {
            Self {
                call_count: AtomicU64::new(0),
                delay,
            }
        }

        fn calls(&self) -> u64 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RpcTransport for SlowCountingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(self.delay).await;
            Ok(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::Value::String("0x1".into())),
                error: None,
            })
        }

        fn url(&self) -> &str {
            "mock://slow"
        }
    }

    /// A mock transport that counts calls and always fails after a delay.
    struct FailingTransport {
        call_count: AtomicU64,
        delay: std::time::Duration,
    }

    impl FailingTransport {
        fn new(delay: std::time::Duration) -> Self {
            Self {
                call_count: AtomicU64::new(0),
                delay,
            }
        }

        fn calls(&self) -> u64 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RpcTransport for FailingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(self.delay).await;
            Err(TransportError::Other("boom".into()))
        }

        fn url(&self) -> &str {
            "mock://failing"
        }
    }

    fn make_req(method: &str) -> JsonRpcRequest {
        JsonRpcRequest::new(1, method, vec![])
    }

    #[tokio::test]
    async fn two_concurrent_identical_requests_trigger_one_send() {
        let transport = Arc::new(SlowCountingTransport::new(
            std::time::Duration::from_millis(100),
        ));
        let dedup = Arc::new(DedupTransport::new(transport.clone()));

        let d1 = dedup.clone();
        let d2 = dedup.clone();

        let (r1, r2) = tokio::join!(
            tokio::spawn(async move { d1.send(make_req("eth_chainId")).await }),
            tokio::spawn(async move { d2.send(make_req("eth_chainId")).await }),
        );

        assert!(r1.unwrap().is_ok());
        assert!(r2.unwrap().is_ok());
        // Only one actual call to the inner transport.
        assert_eq!(transport.calls(), 1);
    }

    #[tokio::test]
    async fn different_requests_go_through_independently() {
        let transport = Arc::new(SlowCountingTransport::new(
            std::time::Duration::from_millis(50),
        ));
        let dedup = Arc::new(DedupTransport::new(transport.clone()));

        let d1 = dedup.clone();
        let d2 = dedup.clone();

        let (r1, r2) = tokio::join!(
            tokio::spawn(async move { d1.send(make_req("eth_chainId")).await }),
            tokio::spawn(async move { d2.send(make_req("net_version")).await }),
        );

        assert!(r1.unwrap().is_ok());
        assert!(r2.unwrap().is_ok());
        // Two different methods = two transport calls.
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn concurrent_identical_failures_are_shared() {
        let transport = Arc::new(FailingTransport::new(std::time::Duration::from_millis(50)));
        let dedup = Arc::new(DedupTransport::new(transport.clone()));

        let d1 = dedup.clone();
        let d2 = dedup.clone();

        let (r1, r2) = tokio::join!(
            tokio::spawn(async move { d1.send(make_req("eth_chainId")).await }),
            tokio::spawn(async move { d2.send(make_req("eth_chainId")).await }),
        );

        // Both the leader and the follower observe the failure...
        assert!(r1.unwrap().is_err());
        assert!(r2.unwrap().is_err());
        // ...from a single transport call, and nothing is left pending.
        assert_eq!(transport.calls(), 1);
        assert_eq!(dedup.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn cleanup_after_completion() {
        let transport = Arc::new(SlowCountingTransport::new(
            std::time::Duration::from_millis(10),
        ));
        let dedup = DedupTransport::new(transport.clone());

        dedup.send(make_req("eth_chainId")).await.unwrap();
        // After completion the pending map should be empty.
        assert_eq!(dedup.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn sequential_same_requests_both_go_through() {
        let transport = Arc::new(SlowCountingTransport::new(
            std::time::Duration::from_millis(1),
        ));
        let dedup = DedupTransport::new(transport.clone());

        // Sequential (not concurrent) same requests should each hit transport.
        dedup.send(make_req("eth_chainId")).await.unwrap();
        dedup.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(transport.calls(), 2);
    }

    #[test]
    fn dedup_key_deterministic() {
        let k1 = dedup_key("eth_chainId", &[]);
        let k2 = dedup_key("eth_chainId", &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn dedup_key_differs_by_method() {
        let k1 = dedup_key("eth_chainId", &[]);
        let k2 = dedup_key("net_version", &[]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn dedup_key_differs_by_params() {
        let k1 = dedup_key(
            "eth_getBalance",
            &[serde_json::Value::String("0xabc".into())],
        );
        let k2 = dedup_key(
            "eth_getBalance",
            &[serde_json::Value::String("0xdef".into())],
        );
        assert_ne!(k1, k2);
    }
}