1use futures::channel::oneshot::{channel, Sender};
5use web3::{
6 error::{Error as Web3Error, TransportError},
7 helpers::{self},
8 types::{BlockId, BlockNumber, Bytes, CallRequest},
9 BatchTransport as Web3BatchTransport,
10};
11
12pub struct CallBatch<T: Web3BatchTransport> {
14 inner: T,
15 requests: Vec<(Request, CompletionHandler)>,
16}
17
18type Request = (CallRequest, Option<BlockId>);
19type CompletionHandler = Sender<Result<Bytes, Web3Error>>;
20
21impl<T: Web3BatchTransport> CallBatch<T> {
22 pub fn new(inner: T) -> Self {
24 Self {
25 inner,
26 requests: Default::default(),
27 }
28 }
29
30 pub fn push(
37 &mut self,
38 call: CallRequest,
39 block: Option<BlockId>,
40 ) -> impl std::future::Future<Output = Result<Bytes, Web3Error>> {
41 let (tx, rx) = channel();
42 self.requests.push(((call, block), tx));
43 async move {
44 rx.await.unwrap_or_else(|_| {
45 Err(Web3Error::Transport(TransportError::Message(
46 "Batch has been dropped without executing".to_owned(),
47 )))
48 })
49 }
50 }
51
52 pub async fn execute_all(self, batch_size: usize) {
55 let Self { inner, requests } = self;
56 let mut iterator = requests.into_iter().peekable();
57 while iterator.peek().is_some() {
58 let (requests, senders): (Vec<_>, Vec<_>) = iterator.by_ref().take(batch_size).unzip();
59
60 let batch_result = inner
62 .send_batch(requests.iter().map(|(request, block)| {
63 let req = helpers::serialize(request);
64 let block =
65 helpers::serialize(&block.unwrap_or_else(|| BlockNumber::Latest.into()));
66 let (id, request) = inner.prepare("eth_call", vec![req, block]);
67 (id, request)
68 }))
69 .await;
70
71 for (i, sender) in senders.into_iter().enumerate() {
73 let _ = match &batch_result {
74 Ok(results) => sender.send(
75 results
76 .get(i)
77 .unwrap_or(&Err(Web3Error::Decoder(
78 "Batch result did not contain enough responses".to_owned(),
79 )))
80 .clone()
81 .and_then(helpers::decode),
82 ),
83 Err(err) => sender.send(Err(Web3Error::Transport(TransportError::Message(
84 format!("Batch failed with: {}", err),
85 )))),
86 };
87 }
88 }
89 }
90}
91
#[cfg(test)]
mod tests {
    use futures::future::join_all;
    use serde_json::json;

    use super::*;
    use crate::test::prelude::FutureTestExt;
    use crate::test::transport::TestTransport;

    /// Two queued calls are answered, in order, by a single batch response.
    #[test]
    fn batches_calls() {
        let mut transport = TestTransport::new();
        transport.add_response(json!(["0x01", "0x02"]));

        let mut batch = CallBatch::new(transport);
        let calls: Vec<_> = (0..2)
            .map(|_| batch.push(CallRequest::default(), None))
            .collect();

        batch.execute_all(usize::MAX).immediate();

        let decoded: Vec<Vec<u8>> = join_all(calls)
            .immediate()
            .into_iter()
            .map(|result| result.unwrap().0)
            .collect();
        assert_eq!(decoded[0], vec![1u8]);
        assert_eq!(decoded[1], vec![2u8]);
    }

    /// A call whose batch is dropped before execution resolves to a transport error.
    #[test]
    fn resolves_calls_to_error_if_dropped() {
        let mut batch = CallBatch::new(TestTransport::new());
        let call = batch.push(CallRequest::default(), None);
        drop(batch);

        let error = call.immediate().unwrap_err();
        assert!(matches!(error, Web3Error::Transport(_)));
    }

    /// When the batch request itself fails, the queued call reports that failure.
    #[test]
    fn fails_all_calls_if_batch_fails() {
        // No response is queued on the transport, so the batch is expected to fail.
        let mut batch = CallBatch::new(TestTransport::new());
        let call = batch.push(CallRequest::default(), None);

        batch.execute_all(usize::MAX).immediate();

        let error = call.immediate().unwrap_err();
        if let Web3Error::Transport(TransportError::Message(reason)) = error {
            assert!(reason.starts_with("Batch failed with:"))
        } else {
            panic!("Wrong Error type")
        }
    }

    /// Three calls with a batch size of two are served by two consecutive responses.
    #[test]
    fn splits_batch_into_multiple_calls() {
        let mut transport = TestTransport::new();
        transport.add_response(json!(["0x01", "0x02"]));
        transport.add_response(json!(["0x03"]));

        let mut batch = CallBatch::new(transport);
        let calls: Vec<_> = (0..3)
            .map(|_| batch.push(CallRequest::default(), None))
            .collect();

        batch.execute_all(2).immediate();

        let decoded: Vec<Vec<u8>> = join_all(calls)
            .immediate()
            .into_iter()
            .map(|result| result.unwrap().0)
            .collect();
        assert_eq!(decoded[0], vec![1u8]);
        assert_eq!(decoded[1], vec![2u8]);
        assert_eq!(decoded[2], vec![3u8]);
    }
}