1use crate::WeakClient;
2use alloy_json_rpc::{RpcRecv, RpcSend};
3use alloy_transport::utils::Spawnable;
4use futures::{ready, stream::FusedStream, Future, FutureExt, Stream, StreamExt};
5use serde::Serialize;
6use serde_json::value::RawValue;
7use std::{
8 borrow::Cow,
9 collections::HashSet,
10 marker::PhantomData,
11 ops::{Deref, DerefMut},
12 pin::Pin,
13 task::{Context, Poll},
14 time::Duration,
15};
16use tokio::sync::broadcast;
17use tokio_stream::wrappers::BroadcastStream;
18use tracing::Span;
19
20#[cfg(all(target_family = "wasm", target_os = "unknown"))]
21use wasmtimer::tokio::{sleep, Sleep};
22
23#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
24use tokio::time::{sleep, Sleep};
25
/// Builder for a poller that repeatedly calls an RPC `method` with fixed `params`.
///
/// Call [`spawn`](Self::spawn) to run the poller as a background task that feeds a broadcast
/// channel, or [`into_stream`](Self::into_stream) to drive it yourself as a [`Stream`].
#[derive(Debug)]
#[must_use = "this builder does nothing unless you call `spawn` or `into_stream`"]
pub struct PollerBuilder<Params, Resp> {
    /// Weak handle to the RPC client; the poller stops once it can no longer be upgraded.
    client: WeakClient,

    /// RPC method invoked on every poll.
    method: Cow<'static, str>,
    /// Parameters sent with every request.
    params: Params,

    /// Capacity of the broadcast channel created by `spawn`.
    channel_size: usize,
    /// Delay between consecutive polls (also applied after a failed poll).
    poll_interval: Duration,
    /// Maximum number of successful polls; `usize::MAX` means "no limit".
    limit: usize,
    /// JSON-RPC error codes that stop the poller when returned by the server.
    terminal_error_codes: HashSet<i64>,

    // `fn() -> Resp` records the response type without storing a `Resp`, so the builder's
    // auto traits (`Send`/`Sync`) do not depend on `Resp`.
    _pd: PhantomData<fn() -> Resp>,
}
80
81impl<Params, Resp> PollerBuilder<Params, Resp>
82where
83 Params: RpcSend + 'static,
84 Resp: RpcRecv,
85{
86 pub fn new(client: WeakClient, method: impl Into<Cow<'static, str>>, params: Params) -> Self {
88 let poll_interval =
89 client.upgrade().map_or_else(|| Duration::from_secs(7), |c| c.poll_interval());
90 Self {
91 client,
92 method: method.into(),
93 params,
94 channel_size: 16,
95 poll_interval,
96 limit: usize::MAX,
97 terminal_error_codes: HashSet::default(),
98 _pd: PhantomData,
99 }
100 }
101
102 pub const fn channel_size(&self) -> usize {
104 self.channel_size
105 }
106
107 pub const fn set_channel_size(&mut self, channel_size: usize) {
109 self.channel_size = channel_size;
110 }
111
112 pub const fn with_channel_size(mut self, channel_size: usize) -> Self {
114 self.set_channel_size(channel_size);
115 self
116 }
117
118 pub const fn limit(&self) -> usize {
120 self.limit
121 }
122
123 pub fn set_limit(&mut self, limit: Option<usize>) {
125 self.limit = limit.unwrap_or(usize::MAX);
126 }
127
128 pub fn with_limit(mut self, limit: Option<usize>) -> Self {
130 self.set_limit(limit);
131 self
132 }
133
134 pub fn terminal_error_codes(&self) -> impl IntoIterator<Item = &i64> {
136 self.terminal_error_codes.iter()
137 }
138
139 pub fn set_terminal_error_codes<I>(&mut self, error_codes: I)
141 where
142 I: IntoIterator<Item = i64>,
143 {
144 self.terminal_error_codes = HashSet::from_iter(error_codes);
145 }
146
147 pub fn with_terminal_error_codes<I>(mut self, error_codes: I) -> Self
149 where
150 I: IntoIterator<Item = i64>,
151 {
152 self.set_terminal_error_codes(error_codes);
153 self
154 }
155
156 pub const fn poll_interval(&self) -> Duration {
158 self.poll_interval
159 }
160
161 pub const fn set_poll_interval(&mut self, poll_interval: Duration) {
163 self.poll_interval = poll_interval;
164 }
165
166 pub const fn with_poll_interval(mut self, poll_interval: Duration) -> Self {
168 self.set_poll_interval(poll_interval);
169 self
170 }
171
172 pub fn spawn(self) -> PollChannel<Resp>
174 where
175 Resp: Clone,
176 {
177 let (tx, rx) = broadcast::channel(self.channel_size);
178 self.into_future(tx).spawn_task();
179 rx.into()
180 }
181
182 async fn into_future(self, tx: broadcast::Sender<Resp>)
183 where
184 Resp: Clone,
185 {
186 let mut stream = self.into_stream();
187 while let Some(resp) = stream.next().await {
188 if tx.send(resp).is_err() {
189 debug!("channel closed");
190 break;
191 }
192 }
193 }
194
195 pub fn into_stream(self) -> PollerStream<Resp> {
200 PollerStream::new(self)
201 }
202
203 pub fn client(&self) -> WeakClient {
205 self.client.clone()
206 }
207}
208
209enum PollState<Resp> {
211 Paused,
213 Waiting,
215 Polling(
217 alloy_transport::Pbf<
218 'static,
219 Resp,
220 alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
221 >,
222 ),
223 Sleeping(Pin<Box<Sleep>>),
225
226 Finished,
228}
229
/// A [`Stream`] that repeatedly issues one RPC request and yields every successful response,
/// passed through the `Map` function (identity by default).
///
/// Created by [`PollerBuilder::into_stream`]. Use [`PollerStream::map`] to change the output
/// type.
pub struct PollerStream<Resp, Output = Resp, Map = fn(Resp) -> Output> {
    // Weak handle; the stream finishes when it can no longer be upgraded.
    client: WeakClient,
    method: Cow<'static, str>,
    // Params serialized once at construction and cloned into each request.
    params: Box<RawValue>,
    poll_interval: Duration,
    // Maximum number of successful polls (`usize::MAX` = unlimited).
    limit: usize,
    terminal_error_codes: HashSet<i64>,
    // Number of successful polls so far; compared against `limit`.
    poll_count: usize,
    state: PollState<Resp>,
    // Tracing span entered on every `poll_next`.
    span: Span,
    // Applied to every successful response before it is yielded.
    map: Map,
    _pd: PhantomData<fn() -> Output>,
}
268
269impl<Resp, Output, Map> std::fmt::Debug for PollerStream<Resp, Output, Map> {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 f.debug_struct("PollerStream")
272 .field("method", &self.method)
273 .field("poll_interval", &self.poll_interval)
274 .field("limit", &self.limit)
275 .field("poll_count", &self.poll_count)
276 .finish_non_exhaustive()
277 }
278}
279
280impl<Resp> PollerStream<Resp> {
281 fn new<Params: Serialize>(builder: PollerBuilder<Params, Resp>) -> Self {
282 let span = debug_span!("poller", method = %builder.method);
283
284 let params = serde_json::value::to_raw_value(&builder.params).unwrap_or_else(|err| {
286 error!(%err, "failed to serialize params during initialization");
287 Box::<RawValue>::default()
290 });
291
292 Self {
293 client: builder.client,
294 method: builder.method,
295 params,
296 poll_interval: builder.poll_interval,
297 limit: builder.limit,
298 terminal_error_codes: builder.terminal_error_codes,
299 poll_count: 0,
300 state: PollState::Waiting,
301 span,
302 map: std::convert::identity,
303 _pd: PhantomData,
304 }
305 }
306
307 pub fn client(&self) -> WeakClient {
309 self.client.clone()
310 }
311
312 pub fn pause(&mut self) {
316 self.state = PollState::Paused;
317 }
318
319 pub fn unpause(&mut self) {
323 if matches!(self.state, PollState::Paused) {
324 self.state = PollState::Waiting;
325 }
326 }
327}
328
329impl<Resp, Output, Map> PollerStream<Resp, Output, Map>
330where
331 Map: Fn(Resp) -> Output,
332{
333 pub fn map<NewOutput, NewMap>(self, map: NewMap) -> PollerStream<Resp, NewOutput, NewMap>
335 where
336 NewMap: Fn(Resp) -> NewOutput,
337 {
338 PollerStream {
339 client: self.client,
340 method: self.method,
341 params: self.params,
342 poll_interval: self.poll_interval,
343 limit: self.limit,
344 terminal_error_codes: self.terminal_error_codes,
345 poll_count: self.poll_count,
346 state: self.state,
347 span: self.span,
348 map,
349 _pd: PhantomData,
350 }
351 }
352}
353
/// Drives the polling state machine.
///
/// Each cycle is `Waiting -> Polling -> Sleeping -> Waiting`. A successful response is yielded
/// (after `map`) as the sleep starts. A failed request is logged and followed by the same
/// sleep, unless the error is terminal. The stream ends (`Finished`) in these cases:
/// - the number of successful polls reaches `limit`;
/// - the client can no longer be upgraded;
/// - the server returns a code in `terminal_error_codes`;
/// - no terminal codes are configured and the error message contains "filter not found".
impl<Resp, Output, Map> Stream for PollerStream<Resp, Output, Map>
where
    Resp: RpcRecv + 'static,
    Map: Fn(Resp) -> Output + Unpin,
{
    type Item = Output;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let _guard = this.span.enter();

        // Loop so that state transitions which don't need to wait are handled in one call.
        loop {
            match &mut this.state {
                // NOTE(review): returns Pending without registering `cx.waker()`. The stream
                // relies on the owner polling again after `unpause`; nothing wakes the task.
                PollState::Paused => return Poll::Pending,
                PollState::Waiting => {
                    // `poll_count` only counts successful polls (incremented in the `Ok` arm).
                    if this.poll_count >= this.limit {
                        debug!("poll limit reached");
                        this.state = PollState::Finished;
                        continue;
                    }

                    let Some(client) = this.client.upgrade() else {
                        debug!("client dropped");
                        this.state = PollState::Finished;
                        continue;
                    };

                    trace!("polling");
                    // Clone method/params so the boxed future owns everything it needs ('static).
                    let method = this.method.clone();
                    let params = this.params.clone();
                    let fut = Box::pin(async move { client.request(method, params).await });
                    this.state = PollState::Polling(fut);
                }
                PollState::Polling(fut) => {
                    match ready!(fut.poll_unpin(cx)) {
                        Ok(resp) => {
                            this.poll_count += 1;
                            trace!(duration=?this.poll_interval, "sleeping");
                            // Arm the sleep before yielding, so the next poll resumes in `Sleeping`.
                            let sleep = Box::pin(sleep(this.poll_interval));
                            this.state = PollState::Sleeping(sleep);
                            return Poll::Ready(Some((this.map)(resp)));
                        }
                        Err(err) => {
                            error!(%err, "failed to poll");

                            if let Some(resp) = err.as_error_resp() {
                                if this.terminal_error_codes.contains(&resp.code) {
                                    warn!("server returned terminal error code, stopping poller");
                                    this.state = PollState::Finished;
                                    continue;
                                }

                                // Message-based fallback, used only when the caller has not
                                // configured explicit terminal error codes.
                                if resp.message.contains("filter not found")
                                    && this.terminal_error_codes.is_empty()
                                {
                                    warn!("server has dropped the filter, stopping poller");
                                    this.state = PollState::Finished;
                                    continue;
                                }
                            }

                            trace!(duration=?this.poll_interval, "sleeping after error");

                            // Non-terminal error: back off for one interval, then retry.
                            let sleep = Box::pin(sleep(this.poll_interval));
                            this.state = PollState::Sleeping(sleep);
                        }
                    }
                }
                PollState::Sleeping(sleep) => {
                    ready!(sleep.as_mut().poll(cx));
                    this.state = PollState::Waiting;
                }
                PollState::Finished => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
442
/// Reports the stream as terminated while it is in its finished state. It reaches that state
/// when the poll limit is reached, the client is dropped, or a terminal error is received.
impl<Resp, Output, Map> FusedStream for PollerStream<Resp, Output, Map>
where
    Resp: RpcRecv + 'static,
    Map: Fn(Resp) -> Output + Unpin,
{
    fn is_terminated(&self) -> bool {
        matches!(self.state, PollState::Finished)
    }
}
452
/// Receiving end of a poller started with [`PollerBuilder::spawn`].
///
/// Wraps a [`broadcast::Receiver`] and dereferences to it, so receiver methods can be called
/// directly on a `PollChannel`.
#[derive(Debug)]
pub struct PollChannel<Resp> {
    rx: broadcast::Receiver<Resp>,
}
468
469impl<Resp> From<broadcast::Receiver<Resp>> for PollChannel<Resp> {
470 fn from(rx: broadcast::Receiver<Resp>) -> Self {
471 Self { rx }
472 }
473}
474
475impl<Resp> Deref for PollChannel<Resp> {
476 type Target = broadcast::Receiver<Resp>;
477
478 fn deref(&self) -> &Self::Target {
479 &self.rx
480 }
481}
482
483impl<Resp> DerefMut for PollChannel<Resp> {
484 fn deref_mut(&mut self) -> &mut Self::Target {
485 &mut self.rx
486 }
487}
488
489impl<Resp> PollChannel<Resp>
490where
491 Resp: RpcRecv + Clone,
492{
493 pub fn resubscribe(&self) -> Self {
495 Self { rx: self.rx.resubscribe() }
496 }
497
498 pub fn into_stream(self) -> impl Stream<Item = Resp> + Unpin {
500 self.into_stream_raw().filter_map(|r| futures::future::ready(r.ok()))
501 }
502
503 pub fn into_stream_raw(self) -> BroadcastStream<Resp> {
506 self.rx.into()
507 }
508}
509
// Compile-time check only: the test build fails if `PollChannel` ever stops being `Unpin`.
// Clippy would otherwise suggest turning this never-called helper into a `const fn`.
#[cfg(test)]
#[allow(clippy::missing_const_for_fn)]
fn _assert_unpin() {
    fn _requires_unpin<T: Unpin>() {}
    _requires_unpin::<PollChannel<()>>();
}