use crate::WeakClient;
use alloy_json_rpc::{RpcRecv, RpcSend};
use alloy_transport::utils::Spawnable;
use futures::{ready, stream::FusedStream, Future, FutureExt, Stream, StreamExt};
use serde::Serialize;
use serde_json::value::RawValue;
use std::{
borrow::Cow,
collections::HashSet,
marker::PhantomData,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use tracing::Span;
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
use wasmtimer::tokio::{sleep, Sleep};
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
use tokio::time::{sleep, Sleep};
/// Builder for a poller that repeatedly calls a JSON-RPC method.
///
/// Use [`spawn`](Self::spawn) to run the poller as a background task that feeds a broadcast
/// channel, or [`into_stream`](Self::into_stream) to obtain a [`PollerStream`] to poll directly.
#[derive(Debug)]
#[must_use = "this builder does nothing unless you call `spawn` or `into_stream`"]
pub struct PollerBuilder<Params, Resp> {
    /// Weak handle to the client used to issue requests.
    client: WeakClient,
    /// JSON-RPC method name to call on every poll.
    method: Cow<'static, str>,
    /// Params sent with every request.
    params: Params,
    /// Capacity of the broadcast channel created by `spawn`.
    channel_size: usize,
    /// Delay between polls.
    poll_interval: Duration,
    /// Maximum number of successful polls; `usize::MAX` means "no limit".
    limit: usize,
    /// Server error codes that stop the poller when returned.
    terminal_error_codes: HashSet<i64>,
    /// Records the response type without storing a value of it.
    _pd: PhantomData<fn() -> Resp>,
}
impl<Params, Resp> PollerBuilder<Params, Resp>
where
    Params: RpcSend + 'static,
    Resp: RpcRecv,
{
    /// Creates a new poller builder that will call `method` with `params`.
    ///
    /// Defaults:
    /// - The poll interval is taken from the client, or 7 seconds if the client has already
    ///   been dropped.
    /// - The channel size is 16.
    /// - There is no poll limit.
    /// - No terminal error codes are configured.
    pub fn new(client: WeakClient, method: impl Into<Cow<'static, str>>, params: Params) -> Self {
        let poll_interval =
            client.upgrade().map_or_else(|| Duration::from_secs(7), |c| c.poll_interval());
        Self {
            client,
            method: method.into(),
            params,
            channel_size: 16,
            poll_interval,
            limit: usize::MAX,
            terminal_error_codes: HashSet::default(),
            _pd: PhantomData,
        }
    }

    /// Returns the capacity of the broadcast channel created by [`spawn`](Self::spawn).
    pub const fn channel_size(&self) -> usize {
        self.channel_size
    }

    /// Sets the broadcast channel capacity used by [`spawn`](Self::spawn).
    ///
    /// A value of 0 is treated as 1 when the channel is created.
    pub const fn set_channel_size(&mut self, channel_size: usize) {
        self.channel_size = channel_size;
    }

    /// Builder-style variant of [`set_channel_size`](Self::set_channel_size).
    pub const fn with_channel_size(mut self, channel_size: usize) -> Self {
        self.set_channel_size(channel_size);
        self
    }

    /// Returns the maximum number of successful polls; `usize::MAX` means unlimited.
    pub const fn limit(&self) -> usize {
        self.limit
    }

    /// Sets the maximum number of successful polls; `None` removes the limit.
    pub fn set_limit(&mut self, limit: Option<usize>) {
        self.limit = limit.unwrap_or(usize::MAX);
    }

    /// Builder-style variant of [`set_limit`](Self::set_limit).
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.set_limit(limit);
        self
    }

    /// Returns the configured terminal error codes.
    pub fn terminal_error_codes(&self) -> impl IntoIterator<Item = &i64> {
        self.terminal_error_codes.iter()
    }

    /// Sets the JSON-RPC error codes that stop the poller when returned by the server.
    ///
    /// When this set is non-empty, it replaces the default behaviour of stopping on an error
    /// whose message contains "filter not found".
    pub fn set_terminal_error_codes<I>(&mut self, error_codes: I)
    where
        I: IntoIterator<Item = i64>,
    {
        self.terminal_error_codes = HashSet::from_iter(error_codes);
    }

    /// Builder-style variant of [`set_terminal_error_codes`](Self::set_terminal_error_codes).
    pub fn with_terminal_error_codes<I>(mut self, error_codes: I) -> Self
    where
        I: IntoIterator<Item = i64>,
    {
        self.set_terminal_error_codes(error_codes);
        self
    }

    /// Returns the delay between polls.
    pub const fn poll_interval(&self) -> Duration {
        self.poll_interval
    }

    /// Sets the delay between polls.
    pub const fn set_poll_interval(&mut self, poll_interval: Duration) {
        self.poll_interval = poll_interval;
    }

    /// Builder-style variant of [`set_poll_interval`](Self::set_poll_interval).
    pub const fn with_poll_interval(mut self, poll_interval: Duration) -> Self {
        self.set_poll_interval(poll_interval);
        self
    }

    /// Spawns the poller as a background task and returns a channel of its responses.
    ///
    /// The task ends when the underlying stream ends or when sending to the channel fails.
    pub fn spawn(self) -> PollChannel<Resp>
    where
        Resp: Clone,
    {
        // `broadcast::channel` panics on a capacity of zero. Clamp the capacity so that a
        // `channel_size` of 0 yields a single-slot channel instead of a panic.
        let (tx, rx) = broadcast::channel(self.channel_size.max(1));
        self.into_future(tx).spawn_task();
        rx.into()
    }

    /// Forwards every item of the poll stream to `tx`.
    ///
    /// Stops when the stream ends or when a send fails (the channel has no receivers).
    async fn into_future(self, tx: broadcast::Sender<Resp>)
    where
        Resp: Clone,
    {
        let mut stream = self.into_stream();
        while let Some(resp) = stream.next().await {
            if tx.send(resp).is_err() {
                debug!("channel closed");
                break;
            }
        }
    }

    /// Converts the builder into a [`PollerStream`] without spawning a task.
    pub fn into_stream(self) -> PollerStream<Resp> {
        PollerStream::new(self)
    }

    /// Returns a clone of the weak client handle.
    pub fn client(&self) -> WeakClient {
        self.client.clone()
    }
}
/// States of the [`PollerStream`] state machine.
enum PollState<Resp> {
    /// Polling is suspended until `unpause` moves the stream back to `Waiting`.
    Paused,
    /// Ready to issue the next request (after checking the limit and the client).
    Waiting,
    /// A request is in flight; holds the boxed request future.
    Polling(
        alloy_transport::Pbf<
            'static,
            Resp,
            alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
        >,
    ),
    /// Waiting out `poll_interval` before returning to `Waiting`.
    Sleeping(Pin<Box<Sleep>>),
    /// The stream has ended; every further poll yields `None`.
    Finished,
}
/// A [`Stream`] that repeatedly calls a JSON-RPC method and yields each successful response,
/// passed through a mapping function.
///
/// Created by [`PollerBuilder::into_stream`]. `Output` and `Map` default to the identity
/// mapping; use [`PollerStream::map`] to change them.
#[must_use = "streams do nothing unless polled"]
pub struct PollerStream<Resp, Output = Resp, Map = fn(Resp) -> Output> {
    /// Weak handle to the client; the stream ends once it can no longer be upgraded.
    client: WeakClient,
    /// JSON-RPC method name.
    method: Cow<'static, str>,
    /// Request params, serialized once when the stream is created.
    params: Box<RawValue>,
    /// Delay between polls, and before retrying after an error.
    poll_interval: Duration,
    /// Maximum number of successful polls before the stream ends.
    limit: usize,
    /// Server error codes that end the stream.
    terminal_error_codes: HashSet<i64>,
    /// Number of successful polls made so far.
    poll_count: usize,
    /// Current state of the poll state machine.
    state: PollState<Resp>,
    /// Tracing span entered while the stream is being polled.
    span: Span,
    /// Function applied to each response before it is yielded.
    map: Map,
    /// Records the `Output` type without storing a value of it.
    _pd: PhantomData<fn() -> Output>,
}
impl<Resp, Output, Map> std::fmt::Debug for PollerStream<Resp, Output, Map> {
    /// Shows the stream's configuration and progress.
    ///
    /// The client, params, state, span and mapping function are omitted, so the output is
    /// marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { method, poll_interval, limit, poll_count, .. } = self;
        let mut out = f.debug_struct("PollerStream");
        out.field("method", method);
        out.field("poll_interval", poll_interval);
        out.field("limit", limit);
        out.field("poll_count", poll_count);
        out.finish_non_exhaustive()
    }
}
impl<Resp> PollerStream<Resp> {
    /// Creates an un-mapped stream from `builder`, serializing its params once up front.
    ///
    /// If serialization fails, the error is logged and the params fall back to
    /// `Box::<RawValue>::default()`; polling still proceeds with those params.
    fn new<Params: Serialize>(builder: PollerBuilder<Params, Resp>) -> Self {
        let span = debug_span!("poller", method = %builder.method);
        let params = serde_json::value::to_raw_value(&builder.params).unwrap_or_else(|err| {
            error!(%err, "failed to serialize params during initialization");
            Box::<RawValue>::default()
        });
        Self {
            client: builder.client,
            method: builder.method,
            params,
            poll_interval: builder.poll_interval,
            limit: builder.limit,
            terminal_error_codes: builder.terminal_error_codes,
            poll_count: 0,
            state: PollState::Waiting,
            span,
            map: std::convert::identity,
            _pd: PhantomData,
        }
    }
}

// These accessors do not depend on the mapping function. They live in a generic impl so that
// streams produced by `map` can still be paused, resumed and inspected.
impl<Resp, Output, Map> PollerStream<Resp, Output, Map> {
    /// Returns a clone of the weak client handle.
    pub fn client(&self) -> WeakClient {
        self.client.clone()
    }

    /// Pauses polling until [`unpause`](Self::unpause) is called.
    ///
    /// An in-flight request or pending sleep is dropped. While paused, polling the stream
    /// returns `Poll::Pending` without registering a waker. The stream must therefore be polled
    /// again explicitly after unpausing.
    ///
    /// Pausing a stream that has already finished has no effect.
    pub fn pause(&mut self) {
        // A finished stream has already yielded `None` and reports `is_terminated() == true`.
        // Overwriting that state would let `unpause` restart polling, for example after the
        // server returned a terminal error code.
        if !matches!(self.state, PollState::Finished) {
            self.state = PollState::Paused;
        }
    }

    /// Resumes a paused stream; the next poll issues a fresh request.
    ///
    /// Has no effect if the stream is not paused.
    pub fn unpause(&mut self) {
        if matches!(self.state, PollState::Paused) {
            self.state = PollState::Waiting;
        }
    }
}
impl<Resp, Output, Map> PollerStream<Resp, Output, Map>
where
    Map: Fn(Resp) -> Output,
{
    /// Replaces the function applied to each response before it is yielded.
    ///
    /// The new function receives the raw `Resp`, not the current `Output`. The previous
    /// mapping is discarded rather than composed with the new one. All polling state and
    /// configuration are carried over unchanged.
    pub fn map<NewOutput, NewMap>(self, map: NewMap) -> PollerStream<Resp, NewOutput, NewMap>
    where
        NewMap: Fn(Resp) -> NewOutput,
    {
        let Self {
            client,
            method,
            params,
            poll_interval,
            limit,
            terminal_error_codes,
            poll_count,
            state,
            span,
            map: _,
            _pd: _,
        } = self;

        PollerStream {
            client,
            method,
            params,
            poll_interval,
            limit,
            terminal_error_codes,
            poll_count,
            state,
            span,
            map,
            _pd: PhantomData,
        }
    }
}
/// Drives the poll loop: issue a request, yield the mapped response, sleep, repeat.
///
/// Only successful responses are yielded. Failed requests are logged and retried after
/// `poll_interval`. The stream ends (yields `None`) when any of these happens:
/// - `limit` successful polls have been made;
/// - the client has been dropped;
/// - the server returns an error whose code is in `terminal_error_codes`;
/// - no terminal codes are configured and an error message contains "filter not found".
///
/// NOTE(review): the `Paused` arm returns `Poll::Pending` without registering `cx`'s waker. A
/// task awaiting this stream is not woken by `unpause` and must re-poll explicitly.
impl<Resp, Output, Map> Stream for PollerStream<Resp, Output, Map>
where
    Resp: RpcRecv + 'static,
    Map: Fn(Resp) -> Output + Unpin,
{
    type Item = Output;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` needs `Self: Unpin`: the request future and sleep are boxed, and
        // `Map: Unpin` is required above.
        let this = self.get_mut();
        // Keep the poller span entered for the duration of this poll.
        let _guard = this.span.enter();
        // Step the state machine until it either yields an item, ends, or is pending.
        loop {
            match &mut this.state {
                PollState::Paused => return Poll::Pending,
                PollState::Waiting => {
                    // Only successful polls increment `poll_count`, so errors don't use up the limit.
                    if this.poll_count >= this.limit {
                        debug!("poll limit reached");
                        this.state = PollState::Finished;
                        continue;
                    }
                    let Some(client) = this.client.upgrade() else {
                        debug!("client dropped");
                        this.state = PollState::Finished;
                        continue;
                    };
                    trace!("polling");
                    let method = this.method.clone();
                    let params = this.params.clone();
                    // The future owns the upgraded client for as long as the request runs.
                    let fut = Box::pin(async move { client.request(method, params).await });
                    this.state = PollState::Polling(fut);
                }
                PollState::Polling(fut) => {
                    // `ready!` returns `Poll::Pending` early; the request future holds the waker.
                    match ready!(fut.poll_unpin(cx)) {
                        Ok(resp) => {
                            this.poll_count += 1;
                            trace!(duration=?this.poll_interval, "sleeping");
                            // Arm the sleep before yielding so the next poll waits out the interval.
                            let sleep = Box::pin(sleep(this.poll_interval));
                            this.state = PollState::Sleeping(sleep);
                            return Poll::Ready(Some((this.map)(resp)));
                        }
                        Err(err) => {
                            error!(%err, "failed to poll");
                            if let Some(resp) = err.as_error_resp() {
                                // User-configured terminal codes take precedence.
                                if this.terminal_error_codes.contains(&resp.code) {
                                    warn!("server returned terminal error code, stopping poller");
                                    this.state = PollState::Finished;
                                    continue;
                                }
                                // Default heuristic: match on the message text. It applies only
                                // when no terminal codes were configured.
                                if resp.message.contains("filter not found")
                                    && this.terminal_error_codes.is_empty()
                                {
                                    warn!("server has dropped the filter, stopping poller");
                                    this.state = PollState::Finished;
                                    continue;
                                }
                            }
                            // Any other error: back off for one interval, then retry.
                            trace!(duration=?this.poll_interval, "sleeping after error");
                            let sleep = Box::pin(sleep(this.poll_interval));
                            this.state = PollState::Sleeping(sleep);
                        }
                    }
                }
                PollState::Sleeping(sleep) => {
                    // The timer registers the waker while the interval has not elapsed.
                    ready!(sleep.as_mut().poll(cx));
                    this.state = PollState::Waiting;
                }
                PollState::Finished => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
/// The stream counts as terminated once its state machine has reached `PollState::Finished`.
impl<Resp, Output, Map> FusedStream for PollerStream<Resp, Output, Map>
where
    Resp: RpcRecv + 'static,
    Map: Fn(Resp) -> Output + Unpin,
{
    /// Returns `true` once the stream has finished and will only yield `None`.
    fn is_terminated(&self) -> bool {
        matches!(self.state, PollState::Finished)
    }
}
/// Receiving half of a spawned poller; a thin wrapper around a [`broadcast::Receiver`].
///
/// Dereferences to the inner receiver, so the receiver's methods are available directly.
#[derive(Debug)]
pub struct PollChannel<Resp> {
    /// Underlying broadcast receiver.
    rx: broadcast::Receiver<Resp>,
}
impl<Resp> From<broadcast::Receiver<Resp>> for PollChannel<Resp> {
fn from(rx: broadcast::Receiver<Resp>) -> Self {
Self { rx }
}
}
impl<Resp> Deref for PollChannel<Resp> {
    type Target = broadcast::Receiver<Resp>;

    /// Borrows the wrapped broadcast receiver.
    fn deref(&self) -> &broadcast::Receiver<Resp> {
        let Self { rx } = self;
        rx
    }
}
impl<Resp> DerefMut for PollChannel<Resp> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.rx
}
}
impl<Resp> PollChannel<Resp>
where
Resp: RpcRecv + Clone,
{
pub fn resubscribe(&self) -> Self {
Self { rx: self.rx.resubscribe() }
}
pub fn into_stream(self) -> impl Stream<Item = Resp> + Unpin {
self.into_stream_raw().filter_map(|r| futures::future::ready(r.ok()))
}
pub fn into_stream_raw(self) -> BroadcastStream<Resp> {
self.rx.into()
}
}
#[cfg(test)]
// This is a compile-time-only type check. Silence clippy's suggestion to make it `const`.
#[allow(clippy::missing_const_for_fn)]
fn _assert_unpin() {
    // Compiles only when `T: Unpin`.
    fn _requires_unpin<T: Unpin>() {}
    _requires_unpin::<PollChannel<()>>();
}