use alloy_primitives::B256;
use futures::{ready, Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde_json::value::RawValue;
use std::{pin::Pin, task};
use tokio::sync::broadcast;
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
#[derive(Debug)]
pub struct RawSubscription {
pub rx: broadcast::Receiver<Box<RawValue>>,
pub local_id: B256,
}
impl RawSubscription {
pub const fn local_id(&self) -> &B256 {
&self.local_id
}
pub fn blocking_recv(&mut self) -> Result<Box<RawValue>, broadcast::error::RecvError> {
self.rx.blocking_recv()
}
pub fn is_empty(&self) -> bool {
self.rx.is_empty()
}
pub fn len(&self) -> usize {
self.rx.len()
}
pub async fn recv(&mut self) -> Result<Box<RawValue>, broadcast::error::RecvError> {
self.rx.recv().await
}
pub fn resubscribe(&self) -> Self {
Self { rx: self.rx.resubscribe(), local_id: self.local_id }
}
pub fn same_channel(&self, other: &Self) -> bool {
self.rx.same_channel(&other.rx)
}
pub fn try_recv(&mut self) -> Result<Box<RawValue>, broadcast::error::TryRecvError> {
self.rx.try_recv()
}
pub fn into_stream(self) -> BroadcastStream<Box<RawValue>> {
self.rx.into()
}
pub fn into_typed<T>(self) -> Subscription<T> {
self.into()
}
}
/// A notification from a subscription, classified by whether it deserialized
/// as the expected type `T`.
#[derive(Debug)]
pub enum SubscriptionItem<T> {
    /// The notification deserialized successfully as `T`.
    Item(T),
    /// The notification could not be deserialized as `T`; the raw JSON is kept unchanged.
    Other(Box<RawValue>),
}

impl<T: DeserializeOwned> From<Box<RawValue>> for SubscriptionItem<T> {
    /// Tries to deserialize `value` as `T`.
    ///
    /// On failure, the payload is logged at trace level and returned as
    /// [`SubscriptionItem::Other`].
    fn from(value: Box<RawValue>) -> Self {
        match serde_json::from_str::<T>(value.get()) {
            Err(_) => {
                trace!(value = value.get(), "Received unexpected value in subscription.");
                Self::Other(value)
            }
            Ok(item) => Self::Item(item),
        }
    }
}
#[derive(Debug)]
#[must_use]
pub struct Subscription<T> {
pub(crate) inner: RawSubscription,
_pd: std::marker::PhantomData<T>,
}
impl<T> From<RawSubscription> for Subscription<T> {
fn from(inner: RawSubscription) -> Self {
Self { inner, _pd: std::marker::PhantomData }
}
}
impl<T> Subscription<T> {
pub const fn local_id(&self) -> &B256 {
self.inner.local_id()
}
pub fn into_raw(self) -> RawSubscription {
self.inner
}
pub const fn inner(&self) -> &RawSubscription {
&self.inner
}
pub const fn inner_mut(&mut self) -> &mut RawSubscription {
&mut self.inner
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn resubscribe_inner(&self) -> RawSubscription {
self.inner.resubscribe()
}
pub fn resubscribe(&self) -> Self {
self.inner.resubscribe().into()
}
pub fn same_channel<U>(&self, other: &Subscription<U>) -> bool {
self.inner.same_channel(&other.inner)
}
}
impl<T: DeserializeOwned> Subscription<T> {
pub fn blocking_recv_any(
&mut self,
) -> Result<SubscriptionItem<T>, broadcast::error::RecvError> {
self.inner.blocking_recv().map(Into::into)
}
pub async fn recv_any(&mut self) -> Result<SubscriptionItem<T>, broadcast::error::RecvError> {
self.inner.recv().await.map(Into::into)
}
pub fn try_recv_any(&mut self) -> Result<SubscriptionItem<T>, broadcast::error::TryRecvError> {
self.inner.try_recv().map(Into::into)
}
pub fn into_stream(self) -> SubscriptionStream<T> {
SubscriptionStream {
id: self.inner.local_id,
inner: self.inner.into_stream(),
_pd: std::marker::PhantomData,
}
}
pub fn into_result_stream(self) -> SubResultStream<T> {
SubResultStream {
id: self.inner.local_id,
inner: self.inner.into_stream(),
_pd: std::marker::PhantomData,
}
}
pub fn into_any_stream(self) -> SubAnyStream<T> {
SubAnyStream {
id: self.inner.local_id,
inner: self.inner.into_stream(),
_pd: std::marker::PhantomData,
}
}
pub fn blocking_recv(&mut self) -> Result<T, broadcast::error::RecvError> {
loop {
match self.blocking_recv_any()? {
SubscriptionItem::Item(item) => return Ok(item),
SubscriptionItem::Other(_) => continue,
}
}
}
pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
loop {
match self.recv_any().await? {
SubscriptionItem::Item(item) => return Ok(item),
SubscriptionItem::Other(_) => continue,
}
}
}
pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
loop {
match self.try_recv_any()? {
SubscriptionItem::Item(item) => return Ok(item),
SubscriptionItem::Other(_) => continue,
}
}
}
pub fn blocking_recv_result(
&mut self,
) -> Result<Result<T, serde_json::Error>, broadcast::error::RecvError> {
self.inner.blocking_recv().map(|value| serde_json::from_str(value.get()))
}
pub async fn recv_result(
&mut self,
) -> Result<Result<T, serde_json::Error>, broadcast::error::RecvError> {
self.inner.recv().await.map(|value| serde_json::from_str(value.get()))
}
pub fn try_recv_result(
&mut self,
) -> Result<Result<T, serde_json::Error>, broadcast::error::TryRecvError> {
self.inner.try_recv().map(|value| serde_json::from_str(value.get()))
}
}
#[derive(Debug)]
pub struct SubAnyStream<T> {
id: B256,
inner: BroadcastStream<Box<RawValue>>,
_pd: std::marker::PhantomData<fn() -> T>,
}
impl<T> SubAnyStream<T> {
pub const fn id(&self) -> &B256 {
&self.id
}
}
impl<T: DeserializeOwned> Stream for SubAnyStream<T> {
type Item = SubscriptionItem<T>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Option<Self::Item>> {
loop {
match ready!(self.inner.poll_next_unpin(cx)) {
Some(Ok(value)) => return task::Poll::Ready(Some(value.into())),
Some(Err(err @ BroadcastStreamRecvError::Lagged(_))) => {
debug!(%err, %self.id, "stream lagged");
continue;
}
None => return task::Poll::Ready(None),
}
}
}
}
/// A [`Stream`] yielding only those notifications that deserialize as `T`.
///
/// Payloads that fail to deserialize, and lag errors from the underlying broadcast
/// stream, are logged at debug level and skipped.
#[derive(Debug)]
pub struct SubscriptionStream<T> {
    id: B256,
    inner: BroadcastStream<Box<RawValue>>,
    _pd: std::marker::PhantomData<fn() -> T>,
}

impl<T> SubscriptionStream<T> {
    /// Returns the local ID of the subscription backing this stream.
    pub const fn id(&self) -> &B256 {
        &self.id
    }
}

impl<T: DeserializeOwned> Stream for SubscriptionStream<T> {
    type Item = T;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        loop {
            // First obtain a raw payload, skipping lag notifications.
            let raw = match ready!(self.inner.poll_next_unpin(cx)) {
                Some(Ok(raw)) => raw,
                Some(Err(err @ BroadcastStreamRecvError::Lagged(_))) => {
                    debug!(%err, %self.id, "stream lagged");
                    continue;
                }
                None => return task::Poll::Ready(None),
            };

            // Then try to deserialize it; undecodable payloads are dropped.
            match serde_json::from_str::<T>(raw.get()) {
                Ok(item) => return task::Poll::Ready(Some(item)),
                Err(err) => {
                    debug!(value = ?raw.get(), %err, %self.id, "failed deserializing subscription item");
                }
            }
        }
    }
}
/// A [`Stream`] yielding the result of deserializing each notification as `T`.
///
/// Lag errors from the underlying broadcast stream are logged at debug level and
/// skipped.
#[derive(Debug)]
pub struct SubResultStream<T> {
    id: B256,
    inner: BroadcastStream<Box<RawValue>>,
    _pd: std::marker::PhantomData<fn() -> T>,
}

impl<T> SubResultStream<T> {
    /// Returns the local ID of the subscription backing this stream.
    pub const fn id(&self) -> &B256 {
        &self.id
    }
}

impl<T: DeserializeOwned> Stream for SubResultStream<T> {
    type Item = serde_json::Result<T>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        // Each payload is returned immediately with its deserialization result;
        // lag notifications just cause another poll.
        while let Some(received) = ready!(self.inner.poll_next_unpin(cx)) {
            match received {
                Ok(raw) => return task::Poll::Ready(Some(serde_json::from_str(raw.get()))),
                Err(err @ BroadcastStreamRecvError::Lagged(_)) => {
                    debug!(%err, %self.id, "stream lagged");
                }
            }
        }
        // The inner stream ended: the broadcast channel is closed.
        task::Poll::Ready(None)
    }
}