1use crate::{
2 FilterWatcher, JsonRpcClient, Middleware, Provider, ProviderError, PubsubClient,
3 SubscriptionStream,
4};
5use ethers_core::types::{Transaction, TxHash};
6use futures_core::{stream::Stream, Future};
7use futures_util::{
8 stream::{FuturesUnordered, StreamExt},
9 FutureExt,
10};
11use std::{
12 collections::VecDeque,
13 pin::Pin,
14 task::{Context, Poll},
15};
16
17#[derive(Debug, thiserror::Error)]
19pub enum GetTransactionError {
20 #[error("Failed to get transaction `{0}`: {1}")]
21 ProviderError(TxHash, ProviderError),
22 #[error("Transaction `{0}` not found")]
24 NotFound(TxHash),
25}
26
27impl From<GetTransactionError> for ProviderError {
28 fn from(err: GetTransactionError) -> Self {
29 match err {
30 GetTransactionError::ProviderError(_, err) => err,
31 err @ GetTransactionError::NotFound(_) => ProviderError::CustomError(err.to_string()),
32 }
33 }
34}
35
36#[cfg(not(target_arch = "wasm32"))]
37pub(crate) type TransactionFut<'a> = Pin<Box<dyn Future<Output = TransactionResult> + Send + 'a>>;
38
39#[cfg(target_arch = "wasm32")]
40pub(crate) type TransactionFut<'a> = Pin<Box<dyn Future<Output = TransactionResult> + 'a>>;
41
42pub(crate) type TransactionResult = Result<Transaction, GetTransactionError>;
43
/// Stream that resolves the transaction hashes yielded by an upstream stream into full
/// transactions.
///
/// Each hash is looked up through `provider`, with at most `max_concurrent` lookups in flight.
/// Results are yielded in the order the lookups complete, not in the order the hashes arrived.
#[must_use = "streams do nothing unless polled"]
pub struct TransactionStream<'a, P, St> {
    /// In-flight lookups; a `FuturesUnordered` yields whichever finishes first.
    pub(crate) pending: FuturesUnordered<TransactionFut<'a>>,
    /// Hashes received while `pending` was full, queued (oldest first) until a slot frees up.
    pub(crate) buffered: VecDeque<TxHash>,
    /// Provider used to fetch each transaction.
    pub(crate) provider: &'a Provider<P>,
    /// Upstream source of transaction hashes.
    pub(crate) stream: St,
    /// Set once `stream` has returned `None`; the upstream is not polled again afterwards.
    stream_done: bool,
    /// Upper bound on the number of lookups held in `pending`.
    pub(crate) max_concurrent: usize,
}
60
61impl<'a, P: JsonRpcClient, St> TransactionStream<'a, P, St> {
62 pub fn new(provider: &'a Provider<P>, stream: St, max_concurrent: usize) -> Self {
64 Self {
65 pending: Default::default(),
66 buffered: Default::default(),
67 provider,
68 stream,
69 stream_done: false,
70 max_concurrent,
71 }
72 }
73
74 pub(crate) fn push_tx(&mut self, tx: TxHash) {
76 let fut = self.provider.get_transaction(tx).then(move |res| match res {
77 Ok(Some(tx)) => futures_util::future::ok(tx),
78 Ok(None) => futures_util::future::err(GetTransactionError::NotFound(tx)),
79 Err(err) => futures_util::future::err(GetTransactionError::ProviderError(tx, err)),
80 });
81 self.pending.push(Box::pin(fut));
82 }
83}
84
85impl<'a, P, St> Stream for TransactionStream<'a, P, St>
86where
87 P: JsonRpcClient,
88 St: Stream<Item = TxHash> + Unpin + 'a,
89{
90 type Item = TransactionResult;
91
92 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93 let this = self.get_mut();
94
95 while this.pending.len() < this.max_concurrent {
97 if let Some(tx) = this.buffered.pop_front() {
98 this.push_tx(tx);
99 } else {
100 break
101 }
102 }
103
104 if !this.stream_done {
105 loop {
106 match Stream::poll_next(Pin::new(&mut this.stream), cx) {
107 Poll::Ready(Some(tx)) => {
108 if this.pending.len() < this.max_concurrent {
109 this.push_tx(tx);
110 } else {
111 this.buffered.push_back(tx);
112 }
113 }
114 Poll::Ready(None) => {
115 this.stream_done = true;
116 break
117 }
118 _ => break,
119 }
120 }
121 }
122
123 if let tx @ Poll::Ready(Some(_)) = this.pending.poll_next_unpin(cx) {
125 return tx
126 }
127
128 if this.stream_done && this.pending.is_empty() {
129 return Poll::Ready(None)
131 }
132
133 Poll::Pending
134 }
135}
136
137impl<'a, P> FilterWatcher<'a, P, TxHash>
138where
139 P: JsonRpcClient,
140{
141 pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> {
148 TransactionStream::new(self.provider, self, n)
149 }
150}
151
152impl<'a, P> SubscriptionStream<'a, P, TxHash>
153where
154 P: PubsubClient,
155{
156 pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> {
163 TransactionStream::new(self.provider, self, n)
164 }
165}
166
/// Integration tests for `TransactionStream`; they spawn a local Anvil node.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{stream::tx_stream, Http};
    use ethers_core::{types::TransactionRequest, utils::Anvil};
    use std::collections::HashSet;

    /// Sends `num_txs` self-transfers and checks that both the polling filter watcher (HTTP)
    /// and the pending-tx subscription (WS) resolve exactly the hashes of the sent txs.
    #[tokio::test]
    #[cfg(feature = "ws")]
    async fn can_stream_pending_transactions() {
        use ethers_core::types::{Transaction, TransactionReceipt};
        use futures_util::{FutureExt, StreamExt};
        use std::time::Duration;

        let num_txs = 5;
        // A 2s block time keeps the sent txs pending long enough to be observed by the streams.
        let geth = Anvil::new().block_time(2u64).spawn();
        let provider = Provider::<Http>::try_from(geth.endpoint())
            .unwrap()
            .interval(Duration::from_millis(1000));
        let ws = crate::Ws::connect(geth.ws_endpoint()).await.unwrap();
        let ws_provider = Provider::new(ws);

        let accounts = provider.get_accounts().await.unwrap();
        let tx = TransactionRequest::new().from(accounts[0]).to(accounts[0]).value(1e18 as u64);

        // Consecutive explicit nonces (0..num_txs); each future resolves to the mined receipt.
        let mut sending = futures_util::future::join_all(
            std::iter::repeat(tx.clone())
                .take(num_txs)
                .enumerate()
                .map(|(nonce, tx)| tx.nonce(nonce))
                .map(|tx| async {
                    provider.send_transaction(tx, None).await.unwrap().await.unwrap().unwrap()
                }),
        )
        .fuse();

        let mut watch_tx_stream = provider
            .watch_pending_transactions()
            .await
            .unwrap()
            .transactions_unordered(num_txs)
            .fuse();

        // A smaller concurrency limit (2) also exercises the buffering path.
        let mut sub_tx_stream =
            ws_provider.subscribe_pending_txs().await.unwrap().transactions_unordered(2).fuse();

        let mut sent: Option<Vec<TransactionReceipt>> = None;
        let mut watch_received: Vec<Transaction> = Vec::with_capacity(num_txs);
        let mut sub_received: Vec<Transaction> = Vec::with_capacity(num_txs);

        // Drive sending and both streams concurrently until every tx has been seen by both
        // streams and all receipts are in.
        loop {
            futures_util::select! {
                txs = sending => {
                    sent = Some(txs)
                },
                tx = watch_tx_stream.next() => watch_received.push(tx.unwrap().unwrap()),
                tx = sub_tx_stream.next() => sub_received.push(tx.unwrap().unwrap()),
            };
            if watch_received.len() == num_txs && sub_received.len() == num_txs {
                if let Some(ref sent) = sent {
                    assert_eq!(sent.len(), watch_received.len());
                    // Compare as sets: `transactions_unordered` may yield in any order.
                    let sent_txs =
                        sent.iter().map(|tx| tx.transaction_hash).collect::<HashSet<_>>();
                    assert_eq!(sent_txs, watch_received.iter().map(|tx| tx.hash).collect());
                    assert_eq!(sent_txs, sub_received.iter().map(|tx| tx.hash).collect());
                    break
                }
            }
        }
    }

    /// Feeds the hashes of three mined txs through `TransactionStream` directly and checks
    /// that every one resolves successfully to the matching transaction.
    #[tokio::test]
    async fn can_stream_transactions() {
        let anvil = Anvil::new().block_time(2u64).spawn();
        let provider =
            Provider::<Http>::try_from(anvil.endpoint()).unwrap().with_sender(anvil.addresses()[0]);

        let accounts = provider.get_accounts().await.unwrap();

        let tx = TransactionRequest::new().from(accounts[0]).to(accounts[0]).value(1e18 as u64);
        let txs = vec![tx.clone().nonce(0u64), tx.clone().nonce(1u64), tx.clone().nonce(2u64)];

        // Each element is the (optional) receipt of the corresponding mined tx.
        let txs =
            futures_util::future::join_all(txs.into_iter().map(|tx| async {
                provider.send_transaction(tx, None).await.unwrap().await.unwrap()
            }))
            .await;

        let stream = tx_stream::TransactionStream::new(
            &provider,
            futures_util::stream::iter(txs.iter().cloned().map(|tx| tx.unwrap().transaction_hash)),
            10,
        );
        // Fails the test if any lookup returned an error.
        let res =
            stream.collect::<Vec<_>>().await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(res.len(), txs.len());
        assert_eq!(
            res.into_iter().map(|tx| tx.hash).collect::<HashSet<_>>(),
            txs.into_iter().map(|tx| tx.unwrap().transaction_hash).collect()
        );
    }
}