1use std::collections::hash_map::DefaultHasher;
14use std::collections::HashMap;
15use std::hash::{Hash, Hasher};
16use std::sync::{Arc, Mutex};
17
18use tokio::sync::watch;
19
20use crate::error::TransportError;
21use crate::request::{JsonRpcRequest, JsonRpcResponse};
22use crate::transport::RpcTransport;
23
24pub struct DedupTransport {
34 inner: Arc<dyn RpcTransport>,
35 pending: Mutex<HashMap<u64, watch::Receiver<Option<DedupResult>>>>,
40}
41
42type DedupResult = Result<JsonRpcResponse, String>;
47
48impl DedupTransport {
49 pub fn new(inner: Arc<dyn RpcTransport>) -> Self {
51 Self {
52 inner,
53 pending: Mutex::new(HashMap::new()),
54 }
55 }
56
57 pub async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
59 let key = dedup_key(&req.method, &req.params);
60
61 let existing_rx = {
65 let pending = self.pending.lock().unwrap();
66 pending.get(&key).cloned()
67 };
68
69 if let Some(mut rx) = existing_rx {
70 return self.wait_for_result(&mut rx).await;
71 }
72
73 let (tx, rx) = watch::channel(None);
75
76 let coalesce_rx = {
79 let mut pending = self.pending.lock().unwrap();
80 if let Some(existing) = pending.get(&key) {
81 Some(existing.clone())
82 } else {
83 pending.insert(key, rx);
84 None
85 }
86 };
87
88 if let Some(mut rx) = coalesce_rx {
89 return self.wait_for_result(&mut rx).await;
90 }
91
92 let result = self.inner.send(req).await;
94
95 let dedup_result: DedupResult = match &result {
97 Ok(resp) => Ok(resp.clone()),
98 Err(e) => Err(e.to_string()),
99 };
100 let _ = tx.send(Some(dedup_result));
102
103 {
105 let mut pending = self.pending.lock().unwrap();
106 pending.remove(&key);
107 }
108
109 tracing::debug!("dedup: completed request (key={key:#018x})");
110 result
111 }
112
113 pub fn in_flight_count(&self) -> usize {
115 let pending = self.pending.lock().unwrap();
116 pending.len()
117 }
118
119 async fn wait_for_result(
122 &self,
123 rx: &mut watch::Receiver<Option<DedupResult>>,
124 ) -> Result<JsonRpcResponse, TransportError> {
125 loop {
127 {
129 let val = rx.borrow();
130 if let Some(ref result) = *val {
131 tracing::debug!("dedup: coalesced request");
132 return match result {
133 Ok(resp) => Ok(resp.clone()),
134 Err(msg) => Err(TransportError::Other(msg.clone())),
135 };
136 }
137 }
138
139 if rx.changed().await.is_err() {
141 return Err(TransportError::Other(
143 "dedup: sender dropped without result".into(),
144 ));
145 }
146 }
147 }
148}
149
150fn dedup_key(method: &str, params: &[serde_json::Value]) -> u64 {
155 let mut hasher = DefaultHasher::new();
156 method.hash(&mut hasher);
157 let params_str = serde_json::to_string(params).unwrap_or_default();
158 params_str.hash(&mut hasher);
159 hasher.finish()
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::{JsonRpcRequest, JsonRpcResponse, RpcId};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// Result type returned by `DedupTransport::send`.
    type Reply = Result<JsonRpcResponse, TransportError>;

    /// Mock transport that counts calls and sleeps for `delay` before answering.
    struct SlowCountingTransport {
        call_count: AtomicU64,
        delay: Duration,
    }

    impl SlowCountingTransport {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicU64::new(0),
                delay,
            }
        }

        /// How many requests reached this transport.
        fn calls(&self) -> u64 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RpcTransport for SlowCountingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(self.delay).await;
            Ok(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::Value::String("0x1".into())),
                error: None,
            })
        }

        fn url(&self) -> &str {
            "mock://slow"
        }
    }

    fn make_req(method: &str) -> JsonRpcRequest {
        JsonRpcRequest::new(1, method, vec![])
    }

    /// Builds a mock transport with the given delay and a dedup layer over it.
    fn setup(delay_ms: u64) -> (Arc<SlowCountingTransport>, Arc<DedupTransport>) {
        let transport = Arc::new(SlowCountingTransport::new(Duration::from_millis(delay_ms)));
        let dedup = Arc::new(DedupTransport::new(Arc::clone(&transport)));
        (transport, dedup)
    }

    /// Sends two requests concurrently from separate spawned tasks.
    async fn send_pair(
        dedup: &Arc<DedupTransport>,
        first: &'static str,
        second: &'static str,
    ) -> (Reply, Reply) {
        let a = Arc::clone(dedup);
        let b = Arc::clone(dedup);
        let (ra, rb) = tokio::join!(
            tokio::spawn(async move { a.send(make_req(first)).await }),
            tokio::spawn(async move { b.send(make_req(second)).await }),
        );
        (
            ra.expect("first request task panicked"),
            rb.expect("second request task panicked"),
        )
    }

    #[tokio::test]
    async fn two_concurrent_identical_requests_trigger_one_send() {
        let (transport, dedup) = setup(100);

        let (first, second) = send_pair(&dedup, "eth_chainId", "eth_chainId").await;

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(transport.calls(), 1);
    }

    #[tokio::test]
    async fn different_requests_go_through_independently() {
        let (transport, dedup) = setup(50);

        let (first, second) = send_pair(&dedup, "eth_chainId", "net_version").await;

        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn cleanup_after_completion() {
        let (_transport, dedup) = setup(10);

        dedup.send(make_req("eth_chainId")).await.unwrap();

        assert_eq!(dedup.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn sequential_same_requests_both_go_through() {
        let (transport, dedup) = setup(1);

        for _ in 0..2 {
            dedup.send(make_req("eth_chainId")).await.unwrap();
        }

        assert_eq!(transport.calls(), 2);
    }

    #[test]
    fn dedup_key_deterministic() {
        assert_eq!(dedup_key("eth_chainId", &[]), dedup_key("eth_chainId", &[]));
    }

    #[test]
    fn dedup_key_differs_by_method() {
        assert_ne!(dedup_key("eth_chainId", &[]), dedup_key("net_version", &[]));
    }
}