Skip to main content

alloy_rpc_client/
poller.rs

1use crate::WeakClient;
2use alloy_json_rpc::{RpcRecv, RpcSend};
3use alloy_transport::utils::Spawnable;
4use futures::{ready, stream::FusedStream, Future, FutureExt, Stream, StreamExt};
5use serde::Serialize;
6use serde_json::value::RawValue;
7use std::{
8    borrow::Cow,
9    collections::HashSet,
10    marker::PhantomData,
11    ops::{Deref, DerefMut},
12    pin::Pin,
13    task::{Context, Poll},
14    time::Duration,
15};
16use tokio::sync::broadcast;
17use tokio_stream::wrappers::BroadcastStream;
18use tracing::Span;
19
20#[cfg(all(target_family = "wasm", target_os = "unknown"))]
21use wasmtimer::tokio::{sleep, Sleep};
22
23#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
24use tokio::time::{sleep, Sleep};
25
26/// A poller task builder.
27///
28/// This builder is used to create a poller task that repeatedly polls a method on a client and
29/// sends the responses to a channel. By default, it uses the client's configured poll interval, a
30/// channel size of 16, and no limit on the number of successful polls. This is all configurable.
31///
32/// The builder is consumed using the [`spawn`](Self::spawn) method, which returns a channel to
33/// receive the responses. The task will continue to poll until either the client or the channel is
34/// dropped.
35///
36/// The channel can be converted into a stream using the [`into_stream`](PollChannel::into_stream)
37/// method.
38///
39/// Alternatively, [`into_stream`](Self::into_stream) on the builder can be used to directly return
40/// a stream of responses on the current thread, instead of spawning a task.
41///
42/// # Examples
43///
44/// Poll `eth_blockNumber` every 5 seconds:
45///
46/// ```no_run
47/// # async fn example(client: alloy_rpc_client::RpcClient) -> Result<(), Box<dyn std::error::Error>> {
48/// use alloy_primitives::U64;
49/// use alloy_rpc_client::PollerBuilder;
50/// use futures_util::StreamExt;
51///
52/// let poller: PollerBuilder<(), U64> = client
53///     .prepare_static_poller("eth_blockNumber", ())
54///     .with_poll_interval(std::time::Duration::from_secs(5));
55/// let mut stream = poller.into_stream();
56/// while let Some(block_number) = stream.next().await {
57///    println!("polled block number: {block_number}");
58/// }
59/// # Ok(())
60/// # }
61/// ```
62#[derive(Debug)]
63#[must_use = "this builder does nothing unless you call `spawn` or `into_stream`"]
64pub struct PollerBuilder<Params, Resp> {
65    /// The client to poll with.
66    client: WeakClient,
67
68    /// Request Method
69    method: Cow<'static, str>,
70    params: Params,
71
72    // config options
73    channel_size: usize,
74    poll_interval: Duration,
75    limit: usize,
76    terminal_error_codes: HashSet<i64>,
77
78    _pd: PhantomData<fn() -> Resp>,
79}
80
81impl<Params, Resp> PollerBuilder<Params, Resp>
82where
83    Params: RpcSend + 'static,
84    Resp: RpcRecv,
85{
86    /// Create a new poller task.
87    pub fn new(client: WeakClient, method: impl Into<Cow<'static, str>>, params: Params) -> Self {
88        let poll_interval =
89            client.upgrade().map_or_else(|| Duration::from_secs(7), |c| c.poll_interval());
90        Self {
91            client,
92            method: method.into(),
93            params,
94            channel_size: 16,
95            poll_interval,
96            limit: usize::MAX,
97            terminal_error_codes: HashSet::default(),
98            _pd: PhantomData,
99        }
100    }
101
102    /// Returns the channel size for the poller task.
103    pub const fn channel_size(&self) -> usize {
104        self.channel_size
105    }
106
107    /// Sets the channel size for the poller task.
108    pub const fn set_channel_size(&mut self, channel_size: usize) {
109        self.channel_size = channel_size;
110    }
111
112    /// Sets the channel size for the poller task.
113    pub const fn with_channel_size(mut self, channel_size: usize) -> Self {
114        self.set_channel_size(channel_size);
115        self
116    }
117
118    /// Returns the limit on the number of successful polls.
119    pub const fn limit(&self) -> usize {
120        self.limit
121    }
122
123    /// Sets a limit on the number of successful polls.
124    pub fn set_limit(&mut self, limit: Option<usize>) {
125        self.limit = limit.unwrap_or(usize::MAX);
126    }
127
128    /// Sets a limit on the number of successful polls.
129    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
130        self.set_limit(limit);
131        self
132    }
133
134    /// Returns the error codes this poller terminates on.
135    pub fn terminal_error_codes(&self) -> impl IntoIterator<Item = &i64> {
136        self.terminal_error_codes.iter()
137    }
138
139    /// Sets the error codes this poller will terminate on.
140    pub fn set_terminal_error_codes<I>(&mut self, error_codes: I)
141    where
142        I: IntoIterator<Item = i64>,
143    {
144        self.terminal_error_codes = HashSet::from_iter(error_codes);
145    }
146
147    /// Sets the error codes this poller will terminate on.
148    pub fn with_terminal_error_codes<I>(mut self, error_codes: I) -> Self
149    where
150        I: IntoIterator<Item = i64>,
151    {
152        self.set_terminal_error_codes(error_codes);
153        self
154    }
155
156    /// Returns the duration between polls.
157    pub const fn poll_interval(&self) -> Duration {
158        self.poll_interval
159    }
160
161    /// Sets the duration between polls.
162    pub const fn set_poll_interval(&mut self, poll_interval: Duration) {
163        self.poll_interval = poll_interval;
164    }
165
166    /// Sets the duration between polls.
167    pub const fn with_poll_interval(mut self, poll_interval: Duration) -> Self {
168        self.set_poll_interval(poll_interval);
169        self
170    }
171
172    /// Starts the poller in a new task, returning a channel to receive the responses on.
173    pub fn spawn(self) -> PollChannel<Resp>
174    where
175        Resp: Clone,
176    {
177        let (tx, rx) = broadcast::channel(self.channel_size);
178        self.into_future(tx).spawn_task();
179        rx.into()
180    }
181
182    async fn into_future(self, tx: broadcast::Sender<Resp>)
183    where
184        Resp: Clone,
185    {
186        let mut stream = self.into_stream();
187        while let Some(resp) = stream.next().await {
188            if tx.send(resp).is_err() {
189                debug!("channel closed");
190                break;
191            }
192        }
193    }
194
195    /// Starts the poller and returns the stream of responses.
196    ///
197    /// Note that this does not spawn the poller on a separate task, thus all responses will be
198    /// polled on the current thread once this stream is polled.
199    pub fn into_stream(self) -> PollerStream<Resp> {
200        PollerStream::new(self)
201    }
202
203    /// Returns the [`WeakClient`] associated with the poller.
204    pub fn client(&self) -> WeakClient {
205        self.client.clone()
206    }
207}
208
209/// State for the polling stream.
210enum PollState<Resp> {
211    /// Poller is paused
212    Paused,
213    /// Waiting to start the next poll.
214    Waiting,
215    /// Currently polling for a response.
216    Polling(
217        alloy_transport::Pbf<
218            'static,
219            Resp,
220            alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
221        >,
222    ),
223    /// Sleeping between polls.
224    Sleeping(Pin<Box<Sleep>>),
225
226    /// Polling has finished due to an error.
227    Finished,
228}
229
230/// A stream of responses from polling an RPC method.
231///
232/// This stream polls the given RPC method at the specified interval and yields the responses.
233///
234/// # Examples
235///
236/// ```no_run
237/// # async fn example(client: alloy_rpc_client::RpcClient) -> Result<(), Box<dyn std::error::Error>> {
238/// use alloy_primitives::U64;
239/// use futures_util::StreamExt;
240///
241/// // Create a poller that fetches block numbers
242/// let poller = client
243///     .prepare_static_poller("eth_blockNumber", ())
244///     .with_poll_interval(std::time::Duration::from_secs(1));
245///
246/// // Convert the block number to a more useful format
247/// let mut stream = poller.into_stream().map(|block_num: U64| block_num.to::<u64>());
248///
249/// while let Some(block_number) = stream.next().await {
250///     println!("Current block: {}", block_number);
251/// }
252/// # Ok(())
253/// # }
254/// ```
255pub struct PollerStream<Resp, Output = Resp, Map = fn(Resp) -> Output> {
256    client: WeakClient,
257    method: Cow<'static, str>,
258    params: Box<RawValue>,
259    poll_interval: Duration,
260    limit: usize,
261    terminal_error_codes: HashSet<i64>,
262    poll_count: usize,
263    state: PollState<Resp>,
264    span: Span,
265    map: Map,
266    _pd: PhantomData<fn() -> Output>,
267}
268
269impl<Resp, Output, Map> std::fmt::Debug for PollerStream<Resp, Output, Map> {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        f.debug_struct("PollerStream")
272            .field("method", &self.method)
273            .field("poll_interval", &self.poll_interval)
274            .field("limit", &self.limit)
275            .field("poll_count", &self.poll_count)
276            .finish_non_exhaustive()
277    }
278}
279
280impl<Resp> PollerStream<Resp> {
281    fn new<Params: Serialize>(builder: PollerBuilder<Params, Resp>) -> Self {
282        let span = debug_span!("poller", method = %builder.method);
283
284        // Serialize params once
285        let params = serde_json::value::to_raw_value(&builder.params).unwrap_or_else(|err| {
286            error!(%err, "failed to serialize params during initialization");
287            // Fall back to empty params; subsequent polls may fail at the server according to
288            // the configured poll interval.
289            Box::<RawValue>::default()
290        });
291
292        Self {
293            client: builder.client,
294            method: builder.method,
295            params,
296            poll_interval: builder.poll_interval,
297            limit: builder.limit,
298            terminal_error_codes: builder.terminal_error_codes,
299            poll_count: 0,
300            state: PollState::Waiting,
301            span,
302            map: std::convert::identity,
303            _pd: PhantomData,
304        }
305    }
306
307    /// Get a reference to the [`WeakClient`] used by this poller.
308    pub fn client(&self) -> WeakClient {
309        self.client.clone()
310    }
311
312    /// Pauses the poller until it's unpaused.
313    ///
314    /// While paused the poller will not initiate new rpc requests
315    pub fn pause(&mut self) {
316        self.state = PollState::Paused;
317    }
318
319    /// Unpauses the poller.
320    ///
321    /// The poller will initiate new rpc requests once polled.
322    pub fn unpause(&mut self) {
323        if matches!(self.state, PollState::Paused) {
324            self.state = PollState::Waiting;
325        }
326    }
327}
328
329impl<Resp, Output, Map> PollerStream<Resp, Output, Map>
330where
331    Map: Fn(Resp) -> Output,
332{
333    /// Maps the responses using the provided function.
334    pub fn map<NewOutput, NewMap>(self, map: NewMap) -> PollerStream<Resp, NewOutput, NewMap>
335    where
336        NewMap: Fn(Resp) -> NewOutput,
337    {
338        PollerStream {
339            client: self.client,
340            method: self.method,
341            params: self.params,
342            poll_interval: self.poll_interval,
343            limit: self.limit,
344            terminal_error_codes: self.terminal_error_codes,
345            poll_count: self.poll_count,
346            state: self.state,
347            span: self.span,
348            map,
349            _pd: PhantomData,
350        }
351    }
352}
353
354impl<Resp, Output, Map> Stream for PollerStream<Resp, Output, Map>
355where
356    Resp: RpcRecv + 'static,
357    Map: Fn(Resp) -> Output + Unpin,
358{
359    type Item = Output;
360
361    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362        let this = self.get_mut();
363        let _guard = this.span.enter();
364
365        loop {
366            match &mut this.state {
367                PollState::Paused => return Poll::Pending,
368                PollState::Waiting => {
369                    // Check if we've reached the limit
370                    if this.poll_count >= this.limit {
371                        debug!("poll limit reached");
372                        this.state = PollState::Finished;
373                        continue;
374                    }
375
376                    // Check if client is still alive
377                    let Some(client) = this.client.upgrade() else {
378                        debug!("client dropped");
379                        this.state = PollState::Finished;
380                        continue;
381                    };
382
383                    // Start polling
384                    trace!("polling");
385                    let method = this.method.clone();
386                    let params = this.params.clone();
387                    let fut = Box::pin(async move { client.request(method, params).await });
388                    this.state = PollState::Polling(fut);
389                }
390                PollState::Polling(fut) => {
391                    match ready!(fut.poll_unpin(cx)) {
392                        Ok(resp) => {
393                            this.poll_count += 1;
394                            // Start sleeping before next poll
395                            trace!(duration=?this.poll_interval, "sleeping");
396                            let sleep = Box::pin(sleep(this.poll_interval));
397                            this.state = PollState::Sleeping(sleep);
398                            return Poll::Ready(Some((this.map)(resp)));
399                        }
400                        Err(err) => {
401                            error!(%err, "failed to poll");
402
403                            if let Some(resp) = err.as_error_resp() {
404                                // Check for terminal error codes if they are set
405                                if this.terminal_error_codes.contains(&resp.code) {
406                                    warn!("server returned terminal error code, stopping poller");
407                                    this.state = PollState::Finished;
408                                    continue;
409                                }
410
411                                // If no terminal error codes are set, check the message to see if
412                                // we should stop the poller. Error codes are not consistent
413                                // across reth/geth/nethermind, so we cannot check the error code.
414                                if resp.message.contains("filter not found")
415                                    && this.terminal_error_codes.is_empty()
416                                {
417                                    warn!("server has dropped the filter, stopping poller");
418                                    this.state = PollState::Finished;
419                                    continue;
420                                }
421                            }
422
423                            // Start sleeping before retry
424                            trace!(duration=?this.poll_interval, "sleeping after error");
425
426                            let sleep = Box::pin(sleep(this.poll_interval));
427                            this.state = PollState::Sleeping(sleep);
428                        }
429                    }
430                }
431                PollState::Sleeping(sleep) => {
432                    ready!(sleep.as_mut().poll(cx));
433                    this.state = PollState::Waiting;
434                }
435                PollState::Finished => {
436                    return Poll::Ready(None);
437                }
438            }
439        }
440    }
441}
442
443impl<Resp, Output, Map> FusedStream for PollerStream<Resp, Output, Map>
444where
445    Resp: RpcRecv + 'static,
446    Map: Fn(Resp) -> Output + Unpin,
447{
448    fn is_terminated(&self) -> bool {
449        matches!(self.state, PollState::Finished)
450    }
451}
452
453/// A channel yielding responses from a poller task.
454///
455/// This stream is backed by a coroutine, and will continue to produce responses
456/// until the poller task is dropped. The poller task is dropped when all
457/// [`RpcClient`] instances are dropped, or when all listening `PollChannel` are
458/// dropped.
459///
460/// The poller task also ignores errors from the server and deserialization
461/// errors, and will continue to poll until the client is dropped.
462///
463/// [`RpcClient`]: crate::RpcClient
464#[derive(Debug)]
465pub struct PollChannel<Resp> {
466    rx: broadcast::Receiver<Resp>,
467}
468
469impl<Resp> From<broadcast::Receiver<Resp>> for PollChannel<Resp> {
470    fn from(rx: broadcast::Receiver<Resp>) -> Self {
471        Self { rx }
472    }
473}
474
475impl<Resp> Deref for PollChannel<Resp> {
476    type Target = broadcast::Receiver<Resp>;
477
478    fn deref(&self) -> &Self::Target {
479        &self.rx
480    }
481}
482
483impl<Resp> DerefMut for PollChannel<Resp> {
484    fn deref_mut(&mut self) -> &mut Self::Target {
485        &mut self.rx
486    }
487}
488
489impl<Resp> PollChannel<Resp>
490where
491    Resp: RpcRecv + Clone,
492{
493    /// Resubscribe to the poller task.
494    pub fn resubscribe(&self) -> Self {
495        Self { rx: self.rx.resubscribe() }
496    }
497
498    /// Converts the poll channel into a stream.
499    pub fn into_stream(self) -> impl Stream<Item = Resp> + Unpin {
500        self.into_stream_raw().filter_map(|r| futures::future::ready(r.ok()))
501    }
502
503    /// Converts the poll channel into a stream that also yields
504    /// [lag errors](tokio_stream::wrappers::errors::BroadcastStreamRecvError).
505    pub fn into_stream_raw(self) -> BroadcastStream<Resp> {
506        self.rx.into()
507    }
508}
509
/// Compile-time check that the public poller types are `Unpin`.
///
/// `PollerStream`'s `Stream` impl relies on this through `Pin::get_mut`.
#[cfg(test)]
// Never called at runtime; making it `const` would add nothing.
#[allow(clippy::missing_const_for_fn)]
fn _assert_unpin() {
    fn _assert<T: Unpin>() {}
    _assert::<PollChannel<()>>();
    _assert::<PollerStream<()>>();
}