1use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::{mpsc, oneshot};
24use tokio::time;
25
26use crate::error::TransportError;
27use crate::request::{JsonRpcRequest, JsonRpcResponse};
28use crate::transport::{HealthStatus, RpcTransport};
29
/// One-shot channel end on which the flush task hands a single caller its
/// outcome: either the response or the transport error.
type ResponseSender = oneshot::Sender<Result<JsonRpcResponse, TransportError>>;
31
/// A queued request paired with the channel used to return its outcome.
struct BatchItem {
    /// The request to forward to the wrapped transport.
    req: JsonRpcRequest,
    /// One-shot sender that receives the response or the transport error.
    tx: ResponseSender,
}
36
/// Transport wrapper that queues outgoing requests for a background flush
/// task, which forwards them to the wrapped transport singly or as a batch.
pub struct BatchingTransport {
    /// The wrapped transport; `health` and `url` delegate to it.
    inner: Arc<dyn RpcTransport>,
    /// Sending half of the queue consumed by the background flush task.
    tx: mpsc::UnboundedSender<BatchItem>,
    /// Collection window handed to the flush task at construction. It is stored
    /// here but not read afterwards, hence the `dead_code` allowance.
    #[allow(dead_code)]
    window: Duration,
}
47
48impl BatchingTransport {
49 pub fn new(inner: Arc<dyn RpcTransport>, window: Duration) -> Arc<Self> {
51 let (tx, rx) = mpsc::unbounded_channel::<BatchItem>();
52 let batcher = Arc::new(Self {
53 inner: inner.clone(),
54 tx,
55 window,
56 });
57
58 let flush_inner = inner;
60 let flush_window = window;
61 tokio::spawn(async move {
62 flush_loop(rx, flush_inner, flush_window).await;
63 });
64
65 batcher
66 }
67}
68
69async fn flush_loop(
70 mut rx: mpsc::UnboundedReceiver<BatchItem>,
71 transport: Arc<dyn RpcTransport>,
72 window: Duration,
73) {
74 loop {
75 let first = match rx.recv().await {
77 Some(item) => item,
78 None => break, };
80
81 let mut batch = vec![first];
82
83 let deadline = time::sleep(window);
85 tokio::pin!(deadline);
86
87 loop {
88 tokio::select! {
89 _ = &mut deadline => break,
90 item = rx.recv() => {
91 match item {
92 Some(i) => batch.push(i),
93 None => break,
94 }
95 }
96 }
97 }
98
99 if batch.len() == 1 {
100 let item = batch.remove(0);
102 let result = transport.send(item.req).await;
103 let _ = item.tx.send(result);
104 } else {
105 let reqs: Vec<JsonRpcRequest> = batch.iter().map(|b| b.req.clone()).collect();
107 match transport.send_batch(reqs).await {
108 Ok(responses) => {
109 for (item, resp) in batch.into_iter().zip(responses.into_iter()) {
111 let _ = item.tx.send(Ok(resp));
112 }
113 }
114 Err(e) => {
115 let msg = e.to_string();
117 for item in batch {
118 let _ = item.tx.send(Err(TransportError::Http(msg.clone())));
119 }
120 }
121 }
122 }
123 }
124}
125
126#[async_trait]
127impl RpcTransport for BatchingTransport {
128 async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
129 let (tx, rx) = oneshot::channel();
130 self.tx
131 .send(BatchItem { req, tx })
132 .map_err(|_| TransportError::Other("batcher channel closed".into()))?;
133 rx.await
134 .map_err(|_| TransportError::Other("batcher task dropped".into()))?
135 }
136
137 fn health(&self) -> HealthStatus {
138 self.inner.health()
139 }
140
141 fn url(&self) -> &str {
142 self.inner.url()
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::RpcId;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test double that counts single and batched calls and always succeeds.
    struct CountingTransport {
        send_count: AtomicU64,
        batch_count: AtomicU64,
    }

    impl CountingTransport {
        /// Starts with both counters at zero.
        fn new() -> Self {
            let send_count = AtomicU64::new(0);
            let batch_count = AtomicU64::new(0);
            Self {
                send_count,
                batch_count,
            }
        }
    }

    /// Builds the canned success response (`"0x1"`) for the given id.
    fn canned_response(id: RpcId) -> JsonRpcResponse {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id,
            result: Some(serde_json::json!("0x1")),
            error: None,
        }
    }

    #[async_trait]
    impl RpcTransport for CountingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.send_count.fetch_add(1, Ordering::SeqCst);
            Ok(canned_response(RpcId::Number(1)))
        }

        async fn send_batch(
            &self,
            reqs: Vec<JsonRpcRequest>,
        ) -> Result<Vec<JsonRpcResponse>, TransportError> {
            self.batch_count.fetch_add(1, Ordering::SeqCst);

            // One response per request, echoing each request's id.
            let mut responses = Vec::with_capacity(reqs.len());
            for req in &reqs {
                responses.push(canned_response(req.id.clone()));
            }
            Ok(responses)
        }

        fn url(&self) -> &str {
            "mock://counting"
        }
    }

    /// A lone request must go through `send`, never through `send_batch`.
    #[tokio::test]
    async fn single_request_bypasses_batch() {
        let counting = Arc::new(CountingTransport::new());
        let window = Duration::from_millis(50);
        let batcher = BatchingTransport::new(counting.clone(), window);

        let request = JsonRpcRequest::new(1, "eth_blockNumber", vec![]);
        let response = batcher.send(request).await.unwrap();
        assert!(response.result.is_some());

        // Wait past the batching window before inspecting the counters.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let singles = counting.send_count.load(Ordering::SeqCst);
        let batches = counting.batch_count.load(Ordering::SeqCst);
        assert_eq!(singles, 1);
        assert_eq!(batches, 0);
    }
}