ethers_providers/stream/tx_stream.rs

1use crate::{
2    FilterWatcher, JsonRpcClient, Middleware, Provider, ProviderError, PubsubClient,
3    SubscriptionStream,
4};
5use ethers_core::types::{Transaction, TxHash};
6use futures_core::{stream::Stream, Future};
7use futures_util::{
8    stream::{FuturesUnordered, StreamExt},
9    FutureExt,
10};
11use std::{
12    collections::VecDeque,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17/// Errors `TransactionStream` can throw
18#[derive(Debug, thiserror::Error)]
19pub enum GetTransactionError {
20    #[error("Failed to get transaction `{0}`: {1}")]
21    ProviderError(TxHash, ProviderError),
22    /// `get_transaction` resulted in a `None`
23    #[error("Transaction `{0}` not found")]
24    NotFound(TxHash),
25}
26
27impl From<GetTransactionError> for ProviderError {
28    fn from(err: GetTransactionError) -> Self {
29        match err {
30            GetTransactionError::ProviderError(_, err) => err,
31            err @ GetTransactionError::NotFound(_) => ProviderError::CustomError(err.to_string()),
32        }
33    }
34}
35
36#[cfg(not(target_arch = "wasm32"))]
37pub(crate) type TransactionFut<'a> = Pin<Box<dyn Future<Output = TransactionResult> + Send + 'a>>;
38
39#[cfg(target_arch = "wasm32")]
40pub(crate) type TransactionFut<'a> = Pin<Box<dyn Future<Output = TransactionResult> + 'a>>;
41
42pub(crate) type TransactionResult = Result<Transaction, GetTransactionError>;
43
44/// Drains a stream of transaction hashes and yields entire `Transaction`.
45#[must_use = "streams do nothing unless polled"]
46pub struct TransactionStream<'a, P, St> {
47    /// Currently running futures pending completion.
48    pub(crate) pending: FuturesUnordered<TransactionFut<'a>>,
49    /// Temporary buffered transaction that get started as soon as another future finishes.
50    pub(crate) buffered: VecDeque<TxHash>,
51    /// The provider that gets the transaction
52    pub(crate) provider: &'a Provider<P>,
53    /// A stream of transaction hashes.
54    pub(crate) stream: St,
55    /// Marks if the stream is done
56    stream_done: bool,
57    /// max allowed futures to execute at once.
58    pub(crate) max_concurrent: usize,
59}
60
61impl<'a, P: JsonRpcClient, St> TransactionStream<'a, P, St> {
62    /// Create a new `TransactionStream` instance
63    pub fn new(provider: &'a Provider<P>, stream: St, max_concurrent: usize) -> Self {
64        Self {
65            pending: Default::default(),
66            buffered: Default::default(),
67            provider,
68            stream,
69            stream_done: false,
70            max_concurrent,
71        }
72    }
73
74    /// Push a future into the set
75    pub(crate) fn push_tx(&mut self, tx: TxHash) {
76        let fut = self.provider.get_transaction(tx).then(move |res| match res {
77            Ok(Some(tx)) => futures_util::future::ok(tx),
78            Ok(None) => futures_util::future::err(GetTransactionError::NotFound(tx)),
79            Err(err) => futures_util::future::err(GetTransactionError::ProviderError(tx, err)),
80        });
81        self.pending.push(Box::pin(fut));
82    }
83}
84
85impl<'a, P, St> Stream for TransactionStream<'a, P, St>
86where
87    P: JsonRpcClient,
88    St: Stream<Item = TxHash> + Unpin + 'a,
89{
90    type Item = TransactionResult;
91
92    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93        let this = self.get_mut();
94
95        // drain buffered transactions first
96        while this.pending.len() < this.max_concurrent {
97            if let Some(tx) = this.buffered.pop_front() {
98                this.push_tx(tx);
99            } else {
100                break
101            }
102        }
103
104        if !this.stream_done {
105            loop {
106                match Stream::poll_next(Pin::new(&mut this.stream), cx) {
107                    Poll::Ready(Some(tx)) => {
108                        if this.pending.len() < this.max_concurrent {
109                            this.push_tx(tx);
110                        } else {
111                            this.buffered.push_back(tx);
112                        }
113                    }
114                    Poll::Ready(None) => {
115                        this.stream_done = true;
116                        break
117                    }
118                    _ => break,
119                }
120            }
121        }
122
123        // poll running futures
124        if let tx @ Poll::Ready(Some(_)) = this.pending.poll_next_unpin(cx) {
125            return tx
126        }
127
128        if this.stream_done && this.pending.is_empty() {
129            // all done
130            return Poll::Ready(None)
131        }
132
133        Poll::Pending
134    }
135}
136
137impl<'a, P> FilterWatcher<'a, P, TxHash>
138where
139    P: JsonRpcClient,
140{
141    /// Returns a stream that yields the `Transaction`s for the transaction hashes this stream
142    /// yields.
143    ///
144    /// This internally calls `Provider::get_transaction` with every new transaction.
145    /// No more than n futures will be buffered at any point in time, and less than n may also be
146    /// buffered depending on the state of each future.
147    pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> {
148        TransactionStream::new(self.provider, self, n)
149    }
150}
151
152impl<'a, P> SubscriptionStream<'a, P, TxHash>
153where
154    P: PubsubClient,
155{
156    /// Returns a stream that yields the `Transaction`s for the transaction hashes this stream
157    /// yields.
158    ///
159    /// This internally calls `Provider::get_transaction` with every new transaction.
160    /// No more than n futures will be buffered at any point in time, and less than n may also be
161    /// buffered depending on the state of each future.
162    pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> {
163        TransactionStream::new(self.provider, self, n)
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{stream::tx_stream, Http};
    use ethers_core::{types::TransactionRequest, utils::Anvil};
    use std::collections::HashSet;

    /// Sends `num_txs` transactions and checks that both the polling (`FilterWatcher`) and the
    /// websocket (`SubscriptionStream`) pending-transaction streams yield exactly the
    /// transactions that were sent.
    #[tokio::test]
    #[cfg(feature = "ws")]
    async fn can_stream_pending_transactions() {
        use ethers_core::types::{Transaction, TransactionReceipt};
        use futures_util::{FutureExt, StreamExt};
        use std::time::Duration;

        let num_txs = 5;
        let anvil = Anvil::new().block_time(2u64).spawn();
        let provider = Provider::<Http>::try_from(anvil.endpoint())
            .unwrap()
            .interval(Duration::from_millis(1000));
        let ws = crate::Ws::connect(anvil.ws_endpoint()).await.unwrap();
        let ws_provider = Provider::new(ws);

        let accounts = provider.get_accounts().await.unwrap();
        let tx = TransactionRequest::new().from(accounts[0]).to(accounts[0]).value(1e18 as u64);

        let mut sending = futures_util::future::join_all(
            std::iter::repeat(tx.clone())
                .take(num_txs)
                .enumerate()
                .map(|(nonce, tx)| tx.nonce(nonce))
                .map(|tx| async {
                    provider.send_transaction(tx, None).await.unwrap().await.unwrap().unwrap()
                }),
        )
        .fuse();

        let mut watch_tx_stream = provider
            .watch_pending_transactions()
            .await
            .unwrap()
            .transactions_unordered(num_txs)
            .fuse();

        let mut sub_tx_stream =
            ws_provider.subscribe_pending_txs().await.unwrap().transactions_unordered(2).fuse();

        let mut sent: Option<Vec<TransactionReceipt>> = None;
        let mut watch_received: Vec<Transaction> = Vec::with_capacity(num_txs);
        let mut sub_received: Vec<Transaction> = Vec::with_capacity(num_txs);

        // Keep driving the sender and both streams until every stream has seen all
        // transactions and the receipts of the sent transactions are available.
        loop {
            futures_util::select! {
                txs = sending => {
                    sent = Some(txs)
                },
                tx = watch_tx_stream.next() => watch_received.push(tx.unwrap().unwrap()),
                tx = sub_tx_stream.next() => sub_received.push(tx.unwrap().unwrap()),
            };
            if watch_received.len() == num_txs && sub_received.len() == num_txs {
                if let Some(ref sent) = sent {
                    assert_eq!(sent.len(), watch_received.len());
                    let sent_txs =
                        sent.iter().map(|tx| tx.transaction_hash).collect::<HashSet<_>>();
                    assert_eq!(sent_txs, watch_received.iter().map(|tx| tx.hash).collect());
                    assert_eq!(sent_txs, sub_received.iter().map(|tx| tx.hash).collect());
                    break
                }
            }
        }
    }

    /// Feeds the hashes of a few mined transactions through a `TransactionStream` and checks
    /// that every transaction is fetched.
    #[tokio::test]
    async fn can_stream_transactions() {
        let anvil = Anvil::new().block_time(2u64).spawn();
        let provider =
            Provider::<Http>::try_from(anvil.endpoint()).unwrap().with_sender(anvil.addresses()[0]);

        let accounts = provider.get_accounts().await.unwrap();

        let tx = TransactionRequest::new().from(accounts[0]).to(accounts[0]).value(1e18 as u64);
        let txs = vec![tx.clone().nonce(0u64), tx.clone().nonce(1u64), tx.clone().nonce(2u64)];

        let txs =
            futures_util::future::join_all(txs.into_iter().map(|tx| async {
                provider.send_transaction(tx, None).await.unwrap().await.unwrap()
            }))
            .await;

        // Read the hash through a reference instead of cloning each receipt.
        let stream = tx_stream::TransactionStream::new(
            &provider,
            futures_util::stream::iter(
                txs.iter().map(|receipt| receipt.as_ref().unwrap().transaction_hash),
            ),
            10,
        );
        let res =
            stream.collect::<Vec<_>>().await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(res.len(), txs.len());
        assert_eq!(
            res.into_iter().map(|tx| tx.hash).collect::<HashSet<_>>(),
            txs.into_iter().map(|tx| tx.unwrap().transaction_hash).collect()
        );
    }
}
267        );
268    }
269}