// chainrpc_core/batch.rs
//! Transport-agnostic auto-batching engine.
//!
//! Coalesces multiple `JsonRpcRequest`s arriving within a time window and
//! flushes them as a single batch call. Each caller gets their response back
//! via a `oneshot` channel.
//!
//! # Usage
//! ```rust,no_run
//! use chainrpc_core::batch::BatchingTransport;
//! use chainrpc_core::transport::RpcTransport;
//! use std::sync::Arc;
//! use std::time::Duration;
//!
//! fn example(inner: Arc<dyn RpcTransport>) {
//!     let batcher = BatchingTransport::new(inner, Duration::from_millis(5));
//! }
//! ```

19use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::{mpsc, oneshot};
24use tokio::time;
25
26use crate::error::TransportError;
27use crate::request::{JsonRpcRequest, JsonRpcResponse};
28use crate::transport::{HealthStatus, RpcTransport};
29
/// One-shot channel on which a caller receives its request's outcome.
type ResponseSender = oneshot::Sender<Result<JsonRpcResponse, TransportError>>;
31
/// A queued request paired with the channel used to deliver its
/// response (or error) back to the caller that submitted it.
struct BatchItem {
    // Request to include in the next flushed batch.
    req: JsonRpcRequest,
    // Sender the flush loop uses to return the result to the caller.
    tx: ResponseSender,
}
36
/// Auto-batching transport wrapper.
///
/// Spawns a background flush task that groups pending requests and sends
/// them to the wrapped transport as a single batch call, flushing once per
/// `window` interval after the first request of a batch arrives.
pub struct BatchingTransport {
    // Wrapped transport; also queried directly for `health()` / `url()`.
    inner: Arc<dyn RpcTransport>,
    // Producer side of the queue drained by the background flush task.
    tx: mpsc::UnboundedSender<BatchItem>,
    // Kept only for introspection; the flush task owns its own copy of
    // the window, hence `dead_code`.
    #[allow(dead_code)]
    window: Duration,
}
47
48impl BatchingTransport {
49    /// Create a new batching transport wrapping `inner`.
50    pub fn new(inner: Arc<dyn RpcTransport>, window: Duration) -> Arc<Self> {
51        let (tx, rx) = mpsc::unbounded_channel::<BatchItem>();
52        let batcher = Arc::new(Self {
53            inner: inner.clone(),
54            tx,
55            window,
56        });
57
58        // Spawn background flush task
59        let flush_inner = inner;
60        let flush_window = window;
61        tokio::spawn(async move {
62            flush_loop(rx, flush_inner, flush_window).await;
63        });
64
65        batcher
66    }
67}
68
69async fn flush_loop(
70    mut rx: mpsc::UnboundedReceiver<BatchItem>,
71    transport: Arc<dyn RpcTransport>,
72    window: Duration,
73) {
74    loop {
75        // Wait for the first item
76        let first = match rx.recv().await {
77            Some(item) => item,
78            None => break, // channel closed
79        };
80
81        let mut batch = vec![first];
82
83        // Collect all items that arrive within the window
84        let deadline = time::sleep(window);
85        tokio::pin!(deadline);
86
87        loop {
88            tokio::select! {
89                _ = &mut deadline => break,
90                item = rx.recv() => {
91                    match item {
92                        Some(i) => batch.push(i),
93                        None => break,
94                    }
95                }
96            }
97        }
98
99        if batch.len() == 1 {
100            // Single item — skip batch overhead
101            let item = batch.remove(0);
102            let result = transport.send(item.req).await;
103            let _ = item.tx.send(result);
104        } else {
105            // True batch
106            let reqs: Vec<JsonRpcRequest> = batch.iter().map(|b| b.req.clone()).collect();
107            match transport.send_batch(reqs).await {
108                Ok(responses) => {
109                    // Match responses to senders by position
110                    for (item, resp) in batch.into_iter().zip(responses.into_iter()) {
111                        let _ = item.tx.send(Ok(resp));
112                    }
113                }
114                Err(e) => {
115                    // Broadcast error to all callers
116                    let msg = e.to_string();
117                    for item in batch {
118                        let _ = item.tx.send(Err(TransportError::Http(msg.clone())));
119                    }
120                }
121            }
122        }
123    }
124}
125
126#[async_trait]
127impl RpcTransport for BatchingTransport {
128    async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
129        let (tx, rx) = oneshot::channel();
130        self.tx
131            .send(BatchItem { req, tx })
132            .map_err(|_| TransportError::Other("batcher channel closed".into()))?;
133        rx.await
134            .map_err(|_| TransportError::Other("batcher task dropped".into()))?
135    }
136
137    fn health(&self) -> HealthStatus {
138        self.inner.health()
139    }
140
141    fn url(&self) -> &str {
142        self.inner.url()
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::RpcId;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Mock transport that counts how often each send path is used, so a
    /// test can assert whether the batcher took the single-request fast
    /// path (`send`) or the true batch path (`send_batch`).
    struct CountingTransport {
        // Number of single `send` calls received.
        send_count: AtomicU64,
        // Number of `send_batch` calls received.
        batch_count: AtomicU64,
    }

    impl CountingTransport {
        // Both counters start at zero.
        fn new() -> Self {
            Self {
                send_count: AtomicU64::new(0),
                batch_count: AtomicU64::new(0),
            }
        }
    }

    #[async_trait]
    impl RpcTransport for CountingTransport {
        // Always succeeds with a canned "0x1" result and a fixed id.
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.send_count.fetch_add(1, Ordering::SeqCst);
            Ok(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::json!("0x1")),
                error: None,
            })
        }

        // Echoes one canned response per request, preserving each
        // request's id and the original order.
        async fn send_batch(
            &self,
            reqs: Vec<JsonRpcRequest>,
        ) -> Result<Vec<JsonRpcResponse>, TransportError> {
            self.batch_count.fetch_add(1, Ordering::SeqCst);
            Ok(reqs
                .iter()
                .map(|r| JsonRpcResponse {
                    jsonrpc: "2.0".into(),
                    id: r.id.clone(),
                    result: Some(serde_json::json!("0x1")),
                    error: None,
                })
                .collect())
        }

        fn url(&self) -> &str {
            "mock://counting"
        }
    }

    /// A lone request inside the window must go through `send`, never
    /// `send_batch` (the batcher skips batch overhead for one item).
    #[tokio::test]
    async fn single_request_bypasses_batch() {
        let inner = Arc::new(CountingTransport::new());
        let batcher = BatchingTransport::new(inner.clone(), Duration::from_millis(50));

        let req = JsonRpcRequest::new(1, "eth_blockNumber", vec![]);
        let resp = batcher.send(req).await.unwrap();
        assert!(resp.result.is_some());

        // Wait for flush: give the background task time to settle before
        // inspecting counters (timing-based; generous vs. the 50ms window).
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(inner.send_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.batch_count.load(Ordering::SeqCst), 0);
    }
}
213}