use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use crate::error::TransportError;
use crate::request::{JsonRpcRequest, JsonRpcResponse};
use crate::transport::RpcTransport;
/// Outcome a coalescing leader publishes to its followers.
///
/// Errors travel as their `Display` text (the leader converts them with
/// `to_string()`), while the leader itself keeps the original error value.
type DedupResult = Result<JsonRpcResponse, String>;

/// Wrapper around an [`RpcTransport`] that coalesces identical concurrent requests.
///
/// While a request is in flight, further requests that map to the same
/// `dedup_key` wait for its outcome instead of being sent to `inner` again.
pub struct DedupTransport {
    /// Transport that actually performs the upstream requests.
    inner: Arc<dyn RpcTransport>,
    /// In-flight requests keyed by `dedup_key`. Each value is a receiver on
    /// which the leader publishes `Some(outcome)` once it has finished.
    pending: Mutex<HashMap<u64, watch::Receiver<Option<DedupResult>>>>,
}
/// Drop guard owned by the leader of a coalesced request.
///
/// Removes the leader's entry from `pending` when dropped. `Drop` runs on
/// every exit path: a normal return, a panic unwinding out of the inner
/// transport, or the leader's future being dropped mid-`.await` (for example
/// by a timeout or `select!`). Without it, a cancelled leader would leave its
/// key mapped to a receiver whose sender no longer exists, and every later
/// request for that key would fail immediately.
struct PendingGuard<'a> {
    owner: &'a DedupTransport,
    key: u64,
}

impl Drop for PendingGuard<'_> {
    fn drop(&mut self) {
        // Recover from poisoning instead of unwrapping: this may run while a
        // panic is already unwinding, and a second panic would abort.
        let mut pending = self
            .owner
            .pending
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        pending.remove(&self.key);
    }
}

impl DedupTransport {
    /// Wraps `inner` so that identical concurrent requests share one upstream call.
    pub fn new(inner: Arc<dyn RpcTransport>) -> Self {
        Self {
            inner,
            pending: Mutex::new(HashMap::new()),
        }
    }

    /// Sends `req`, coalescing it with an identical request already in flight.
    ///
    /// Two requests are identical when `dedup_key` yields the same key for
    /// their method and params.
    ///
    /// - The first caller for a key (the leader) forwards the request to the
    ///   inner transport and publishes the outcome.
    /// - Callers that arrive while it is in flight (followers) wait for that
    ///   outcome instead of sending their own request.
    /// - When the leader finishes, or is cancelled, the key is cleared, so a
    ///   later identical request goes upstream again.
    ///
    /// # Errors
    ///
    /// - The leader returns the inner transport's error unchanged.
    /// - Followers receive `TransportError::Other` carrying that error's
    ///   `Display` text.
    /// - Followers receive `TransportError::Other("dedup: sender dropped
    ///   without result")` if the leader stopped before publishing a result.
    pub async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
        let key = dedup_key(&req.method, &req.params);

        // Decide leader vs. follower in one critical section so that two
        // callers can never both become leader for the same key.
        let (tx, rx) = watch::channel(None);
        let follower_rx = {
            let mut pending = self.pending.lock().unwrap();
            if let Some(existing) = pending.get(&key) {
                Some(existing.clone())
            } else {
                pending.insert(key, rx);
                None
            }
        };
        if let Some(mut rx) = follower_rx {
            return self.wait_for_result(&mut rx).await;
        }

        // We are the leader. Arm the cleanup guard before the first `.await`
        // so the entry is removed even if this future is dropped mid-flight.
        let cleanup = PendingGuard { owner: self, key };

        let result = self.inner.send(req).await;
        // Followers get a clone of the response, or the error's message.
        let shared: DedupResult = match &result {
            Ok(resp) => Ok(resp.clone()),
            Err(e) => Err(e.to_string()),
        };
        // Publish first, then unregister. Followers already hold their own
        // receiver clones, so they still observe the value after removal.
        let _ = tx.send(Some(shared));
        drop(cleanup);

        tracing::debug!("dedup: completed request (key={key:#018x})");
        result
    }

    /// Number of distinct request keys that currently have a leader in flight.
    pub fn in_flight_count(&self) -> usize {
        let pending = self.pending.lock().unwrap();
        pending.len()
    }

    /// Waits on `rx` until the leader publishes an outcome, then returns a copy of it.
    ///
    /// Returns `TransportError::Other("dedup: sender dropped without result")`
    /// if the sender is dropped before a value is published.
    async fn wait_for_result(
        &self,
        rx: &mut watch::Receiver<Option<DedupResult>>,
    ) -> Result<JsonRpcResponse, TransportError> {
        loop {
            // Scope the borrow so its read guard is released before awaiting.
            {
                let val = rx.borrow();
                if let Some(ref result) = *val {
                    tracing::debug!("dedup: coalesced request");
                    return match result {
                        Ok(resp) => Ok(resp.clone()),
                        Err(msg) => Err(TransportError::Other(msg.clone())),
                    };
                }
            }
            if rx.changed().await.is_err() {
                return Err(TransportError::Other(
                    "dedup: sender dropped without result".into(),
                ));
            }
        }
    }
}
/// Derives the coalescing key for a request from its method name and params.
///
/// - The key is a `DefaultHasher` digest of the method followed by the JSON
///   serialization of `params`.
/// - If serialization fails, the params contribute an empty string.
/// - The key is a 64-bit hash rather than the full request text, so two
///   distinct requests that happen to collide would be coalesced.
fn dedup_key(method: &str, params: &[serde_json::Value]) -> u64 {
    let serialized = serde_json::to_string(params).unwrap_or_default();
    let mut state = DefaultHasher::new();
    // Tuples hash their fields in order, so this feeds the hasher exactly
    // `method` then `serialized`.
    (method, serialized.as_str()).hash(&mut state);
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::{JsonRpcRequest, JsonRpcResponse, RpcId};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;
    use tokio::task::JoinHandle;

    /// Mock transport that counts its calls and sleeps for `delay` before
    /// answering with a fixed successful response.
    struct SlowCountingTransport {
        call_count: AtomicU64,
        delay: Duration,
    }

    impl SlowCountingTransport {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicU64::new(0),
                delay,
            }
        }

        /// Number of requests that reached this transport.
        fn calls(&self) -> u64 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RpcTransport for SlowCountingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(self.delay).await;
            let reply = JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::Value::String("0x1".into())),
                error: None,
            };
            Ok(reply)
        }

        fn url(&self) -> &str {
            "mock://slow"
        }
    }

    fn make_req(method: &str) -> JsonRpcRequest {
        JsonRpcRequest::new(1, method, vec![])
    }

    /// Builds a mock transport with `delay_ms` latency and a dedup layer over it.
    fn harness(delay_ms: u64) -> (Arc<SlowCountingTransport>, Arc<DedupTransport>) {
        let transport = Arc::new(SlowCountingTransport::new(Duration::from_millis(delay_ms)));
        let dedup = Arc::new(DedupTransport::new(transport.clone()));
        (transport, dedup)
    }

    /// Sends a request for `method` through `dedup` on its own spawned task.
    fn spawn_send(
        dedup: &Arc<DedupTransport>,
        method: &'static str,
    ) -> JoinHandle<Result<JsonRpcResponse, TransportError>> {
        let dedup = Arc::clone(dedup);
        tokio::spawn(async move { dedup.send(make_req(method)).await })
    }

    #[tokio::test]
    async fn two_concurrent_identical_requests_trigger_one_send() {
        let (transport, dedup) = harness(100);
        let (r1, r2) = tokio::join!(
            spawn_send(&dedup, "eth_chainId"),
            spawn_send(&dedup, "eth_chainId"),
        );
        assert!(r1.unwrap().is_ok());
        assert!(r2.unwrap().is_ok());
        assert_eq!(transport.calls(), 1);
    }

    #[tokio::test]
    async fn different_requests_go_through_independently() {
        let (transport, dedup) = harness(50);
        let (r1, r2) = tokio::join!(
            spawn_send(&dedup, "eth_chainId"),
            spawn_send(&dedup, "net_version"),
        );
        assert!(r1.unwrap().is_ok());
        assert!(r2.unwrap().is_ok());
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn cleanup_after_completion() {
        let (_transport, dedup) = harness(10);
        dedup.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(dedup.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn sequential_same_requests_both_go_through() {
        let (transport, dedup) = harness(1);
        for _ in 0..2 {
            dedup.send(make_req("eth_chainId")).await.unwrap();
        }
        assert_eq!(transport.calls(), 2);
    }

    #[test]
    fn dedup_key_deterministic() {
        assert_eq!(dedup_key("eth_chainId", &[]), dedup_key("eth_chainId", &[]));
    }

    #[test]
    fn dedup_key_differs_by_method() {
        assert_ne!(dedup_key("eth_chainId", &[]), dedup_key("net_version", &[]));
    }
}