alloy_pubsub/
sub.rs

1use alloy_primitives::B256;
2use futures::{ready, Stream, StreamExt};
3use serde::de::DeserializeOwned;
4use serde_json::value::RawValue;
5use std::{pin::Pin, task};
6use tokio::sync::broadcast;
7use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
8
9/// A Subscription is a feed of notifications from the server, identified by a
10/// local ID.
11///
12/// This type is mostly a wrapper around [`broadcast::Receiver`], and exposes
13/// the same methods.
14#[derive(Debug)]
15pub struct RawSubscription {
16    /// The channel via which notifications are received.
17    pub rx: broadcast::Receiver<Box<RawValue>>,
18    /// The local ID of the subscription.
19    pub local_id: B256,
20}
21
22impl RawSubscription {
23    /// Get the local ID of the subscription.
24    pub const fn local_id(&self) -> &B256 {
25        &self.local_id
26    }
27
28    /// Wrapper for [`blocking_recv`]. Block the current thread until a message
29    /// is available.
30    ///
31    /// [`blocking_recv`]: broadcast::Receiver::blocking_recv
32    pub fn blocking_recv(&mut self) -> Result<Box<RawValue>, broadcast::error::RecvError> {
33        self.rx.blocking_recv()
34    }
35
36    /// Returns `true` if the broadcast channel is empty (i.e. there are
37    /// currently no notifications to receive).
38    pub fn is_empty(&self) -> bool {
39        self.rx.is_empty()
40    }
41
42    /// Returns the number of messages in the broadcast channel that this
43    /// receiver has yet to receive.
44    pub fn len(&self) -> usize {
45        self.rx.len()
46    }
47
48    /// Wrapper for [`recv`]. Await an item from the channel.
49    ///
50    /// [`recv`]: broadcast::Receiver::recv
51    pub async fn recv(&mut self) -> Result<Box<RawValue>, broadcast::error::RecvError> {
52        self.rx.recv().await
53    }
54
55    /// Wrapper for [`resubscribe`]. Create a new Subscription, starting from
56    /// the current tail element.
57    ///
58    /// [`resubscribe`]: broadcast::Receiver::resubscribe
59    pub fn resubscribe(&self) -> Self {
60        Self { rx: self.rx.resubscribe(), local_id: self.local_id }
61    }
62
63    /// Wrapper for [`same_channel`]. Returns `true` if the two subscriptions
64    /// share the same broadcast channel.
65    ///
66    /// [`same_channel`]: broadcast::Receiver::same_channel
67    pub fn same_channel(&self, other: &Self) -> bool {
68        self.rx.same_channel(&other.rx)
69    }
70
71    /// Wrapper for [`try_recv`]. Attempt to receive a message from the channel
72    /// without awaiting.
73    ///
74    /// [`try_recv`]: broadcast::Receiver::try_recv
75    pub fn try_recv(&mut self) -> Result<Box<RawValue>, broadcast::error::TryRecvError> {
76        self.rx.try_recv()
77    }
78
79    /// Convert the subscription into a stream.
80    pub fn into_stream(self) -> BroadcastStream<Box<RawValue>> {
81        self.rx.into()
82    }
83
84    /// Convert into a typed subscription.
85    pub fn into_typed<T>(self) -> Subscription<T> {
86        self.into()
87    }
88}
89
90/// An item in a typed [`Subscription`]. This is either the expected type, or
91/// some serialized value of another type.
92#[derive(Debug)]
93pub enum SubscriptionItem<T> {
94    /// The expected item.
95    Item(T),
96    /// Some other value.
97    Other(Box<RawValue>),
98}
99
100impl<T: DeserializeOwned> From<Box<RawValue>> for SubscriptionItem<T> {
101    fn from(value: Box<RawValue>) -> Self {
102        serde_json::from_str(value.get()).map_or_else(
103            |_| {
104                trace!(value = value.get(), "Received unexpected value in subscription.");
105                Self::Other(value)
106            },
107            |item| Self::Item(item),
108        )
109    }
110}
111
112/// A Subscription is a feed of notifications from the server of a specific
113/// type `T`, identified by a local ID.
114///
115/// For flexibility, we expose three similar APIs:
116/// - The [`Subscription::recv`] method and its variants will discard any notifications of
117///   unexpected types.
118/// - The [`Subscription::recv_any`] and its variants will yield unexpected types as
119///   [`SubscriptionItem::Other`].
120/// - The [`Subscription::recv_result`] and its variants will attempt to deserialize the
121///   notifications and yield the `serde_json::Result` of the deserialization.
122#[derive(Debug)]
123#[must_use]
124pub struct Subscription<T> {
125    pub(crate) inner: RawSubscription,
126    _pd: std::marker::PhantomData<T>,
127}
128
129impl<T> From<RawSubscription> for Subscription<T> {
130    fn from(inner: RawSubscription) -> Self {
131        Self { inner, _pd: std::marker::PhantomData }
132    }
133}
134
135impl<T> Subscription<T> {
136    /// Get the local ID of the subscription.
137    pub const fn local_id(&self) -> &B256 {
138        self.inner.local_id()
139    }
140
141    /// Convert the subscription into its inner [`RawSubscription`].
142    pub fn into_raw(self) -> RawSubscription {
143        self.inner
144    }
145
146    /// Get a reference to the inner subscription.
147    pub const fn inner(&self) -> &RawSubscription {
148        &self.inner
149    }
150
151    /// Get a mutable reference to the inner subscription.
152    pub const fn inner_mut(&mut self) -> &mut RawSubscription {
153        &mut self.inner
154    }
155
156    /// Returns `true` if the broadcast channel is empty (i.e. there are
157    /// currently no notifications to receive).
158    pub fn is_empty(&self) -> bool {
159        self.inner.is_empty()
160    }
161
162    /// Returns the number of messages in the broadcast channel that this
163    /// receiver has yet to receive.
164    ///
165    /// NB: This count may include messages of unexpected types that will be
166    /// discarded upon receipt.
167    pub fn len(&self) -> usize {
168        self.inner.len()
169    }
170
171    /// Wrapper for [`resubscribe`]. Create a new [`RawSubscription`], starting
172    /// from the current tail element.
173    ///
174    /// [`resubscribe`]: broadcast::Receiver::resubscribe
175    pub fn resubscribe_inner(&self) -> RawSubscription {
176        self.inner.resubscribe()
177    }
178
179    /// Wrapper for [`resubscribe`]. Create a new `Subscription`, starting from
180    /// the current tail element.
181    ///
182    /// [`resubscribe`]: broadcast::Receiver::resubscribe
183    pub fn resubscribe(&self) -> Self {
184        self.inner.resubscribe().into()
185    }
186
187    /// Wrapper for [`same_channel`]. Returns `true` if the two subscriptions
188    /// share the same broadcast channel.
189    ///
190    /// [`same_channel`]: broadcast::Receiver::same_channel
191    pub fn same_channel<U>(&self, other: &Subscription<U>) -> bool {
192        self.inner.same_channel(&other.inner)
193    }
194}
195
196impl<T: DeserializeOwned> Subscription<T> {
197    /// Wrapper for [`blocking_recv`], may produce unexpected values. Block the
198    /// current thread until a message is available.
199    ///
200    /// [`blocking_recv`]: broadcast::Receiver::blocking_recv
201    pub fn blocking_recv_any(
202        &mut self,
203    ) -> Result<SubscriptionItem<T>, broadcast::error::RecvError> {
204        self.inner.blocking_recv().map(Into::into)
205    }
206
207    /// Wrapper for [`recv`], may produce unexpected values. Await an item from
208    /// the channel.
209    ///
210    /// [`recv`]: broadcast::Receiver::recv
211    pub async fn recv_any(&mut self) -> Result<SubscriptionItem<T>, broadcast::error::RecvError> {
212        self.inner.recv().await.map(Into::into)
213    }
214
215    /// Wrapper for [`try_recv`]. Attempt to receive a message from the channel
216    /// without awaiting.
217    ///
218    /// [`try_recv`]: broadcast::Receiver::try_recv
219    pub fn try_recv_any(&mut self) -> Result<SubscriptionItem<T>, broadcast::error::TryRecvError> {
220        self.inner.try_recv().map(Into::into)
221    }
222
223    /// Convert the subscription into a stream.
224    ///
225    /// Errors are logged and ignored.
226    pub fn into_stream(self) -> SubscriptionStream<T> {
227        SubscriptionStream {
228            id: self.inner.local_id,
229            inner: self.inner.into_stream(),
230            _pd: std::marker::PhantomData,
231        }
232    }
233
234    /// Convert the subscription into a stream that returns deserialization results.
235    pub fn into_result_stream(self) -> SubResultStream<T> {
236        SubResultStream {
237            id: self.inner.local_id,
238            inner: self.inner.into_stream(),
239            _pd: std::marker::PhantomData,
240        }
241    }
242
243    /// Convert the subscription into a stream that may yield unexpected types.
244    pub fn into_any_stream(self) -> SubAnyStream<T> {
245        SubAnyStream {
246            id: self.inner.local_id,
247            inner: self.inner.into_stream(),
248            _pd: std::marker::PhantomData,
249        }
250    }
251
252    /// Wrapper for [`blocking_recv`]. Block the current thread until a message
253    /// of the expected type is available.
254    ///
255    /// [`blocking_recv`]: broadcast::Receiver::blocking_recv
256    pub fn blocking_recv(&mut self) -> Result<T, broadcast::error::RecvError> {
257        loop {
258            match self.blocking_recv_any()? {
259                SubscriptionItem::Item(item) => return Ok(item),
260                SubscriptionItem::Other(_) => continue,
261            }
262        }
263    }
264
265    /// Wrapper for [`recv`]. Await an item of the expected type from the
266    /// channel.
267    ///
268    /// [`recv`]: broadcast::Receiver::recv
269    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
270        loop {
271            match self.recv_any().await? {
272                SubscriptionItem::Item(item) => return Ok(item),
273                SubscriptionItem::Other(_) => continue,
274            }
275        }
276    }
277
278    /// Wrapper for [`try_recv`]. Attempt to receive a message of the expected
279    /// type from the channel without awaiting.
280    ///
281    /// [`try_recv`]: broadcast::Receiver::try_recv
282    pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
283        loop {
284            match self.try_recv_any()? {
285                SubscriptionItem::Item(item) => return Ok(item),
286                SubscriptionItem::Other(_) => continue,
287            }
288        }
289    }
290
291    /// Wrapper for [`blocking_recv`]. Block the current thread until a message
292    /// is available, deserializing the message and returning the result.
293    ///
294    /// [`blocking_recv`]: broadcast::Receiver::blocking_recv
295    pub fn blocking_recv_result(
296        &mut self,
297    ) -> Result<Result<T, serde_json::Error>, broadcast::error::RecvError> {
298        self.inner.blocking_recv().map(|value| serde_json::from_str(value.get()))
299    }
300
301    /// Wrapper for [`recv`]. Await an item from the channel, deserializing the
302    /// message and returning the result.
303    ///
304    /// [`recv`]: broadcast::Receiver::recv
305    pub async fn recv_result(
306        &mut self,
307    ) -> Result<Result<T, serde_json::Error>, broadcast::error::RecvError> {
308        self.inner.recv().await.map(|value| serde_json::from_str(value.get()))
309    }
310
311    /// Wrapper for [`try_recv`]. Attempt to receive a message from the channel
312    /// without awaiting, deserializing the message and returning the result.
313    ///
314    /// [`try_recv`]: broadcast::Receiver::try_recv
315    pub fn try_recv_result(
316        &mut self,
317    ) -> Result<Result<T, serde_json::Error>, broadcast::error::TryRecvError> {
318        self.inner.try_recv().map(|value| serde_json::from_str(value.get()))
319    }
320}
321
322/// A stream of notifications from the server, identified by a local ID. This
323/// stream may yield unexpected types.
324#[derive(Debug)]
325pub struct SubAnyStream<T> {
326    id: B256,
327    inner: BroadcastStream<Box<RawValue>>,
328    _pd: std::marker::PhantomData<fn() -> T>,
329}
330
331impl<T> SubAnyStream<T> {
332    /// Get the local ID of the subscription.
333    pub const fn id(&self) -> &B256 {
334        &self.id
335    }
336}
337
338impl<T: DeserializeOwned> Stream for SubAnyStream<T> {
339    type Item = SubscriptionItem<T>;
340
341    fn poll_next(
342        mut self: Pin<&mut Self>,
343        cx: &mut task::Context<'_>,
344    ) -> task::Poll<Option<Self::Item>> {
345        loop {
346            match ready!(self.inner.poll_next_unpin(cx)) {
347                Some(Ok(value)) => return task::Poll::Ready(Some(value.into())),
348                Some(Err(err @ BroadcastStreamRecvError::Lagged(_))) => {
349                    // This is OK.
350                    debug!(%err, %self.id, "stream lagged");
351                    continue;
352                }
353                None => return task::Poll::Ready(None),
354            }
355        }
356    }
357}
358
359/// A stream of notifications from the server, identified by a local ID. This
360/// stream will yield only the expected type, discarding any notifications of
361/// unexpected types.
362#[derive(Debug)]
363pub struct SubscriptionStream<T> {
364    id: B256,
365    inner: BroadcastStream<Box<RawValue>>,
366    _pd: std::marker::PhantomData<fn() -> T>,
367}
368
369impl<T> SubscriptionStream<T> {
370    /// Get the local ID of the subscription.
371    pub const fn id(&self) -> &B256 {
372        &self.id
373    }
374}
375
376impl<T: DeserializeOwned> Stream for SubscriptionStream<T> {
377    type Item = T;
378
379    fn poll_next(
380        mut self: Pin<&mut Self>,
381        cx: &mut task::Context<'_>,
382    ) -> task::Poll<Option<Self::Item>> {
383        loop {
384            match ready!(self.inner.poll_next_unpin(cx)) {
385                Some(Ok(value)) => match serde_json::from_str(value.get()) {
386                    Ok(item) => return task::Poll::Ready(Some(item)),
387                    Err(err) => {
388                        debug!(value = ?value.get(), %err, %self.id, "failed deserializing subscription item");
389                        error!(%err, %self.id, "failed deserializing subscription item");
390                        continue;
391                    }
392                },
393                Some(Err(err @ BroadcastStreamRecvError::Lagged(_))) => {
394                    // This is OK.
395                    debug!(%err, %self.id, "stream lagged");
396                    continue;
397                }
398                None => return task::Poll::Ready(None),
399            }
400        }
401    }
402}
403
404/// A stream of notifications from the server, identified by a local ID.
405///
406/// This stream will attempt to deserialize the notifications and yield the [`serde_json::Result`]
407/// of the deserialization.
408#[derive(Debug)]
409pub struct SubResultStream<T> {
410    id: B256,
411    inner: BroadcastStream<Box<RawValue>>,
412    _pd: std::marker::PhantomData<fn() -> T>,
413}
414
415impl<T> SubResultStream<T> {
416    /// Get the local ID of the subscription.
417    pub const fn id(&self) -> &B256 {
418        &self.id
419    }
420}
421
422impl<T: DeserializeOwned> Stream for SubResultStream<T> {
423    type Item = serde_json::Result<T>;
424
425    fn poll_next(
426        mut self: Pin<&mut Self>,
427        cx: &mut task::Context<'_>,
428    ) -> task::Poll<Option<Self::Item>> {
429        loop {
430            match ready!(self.inner.poll_next_unpin(cx)) {
431                Some(Ok(value)) => {
432                    return task::Poll::Ready(Some(serde_json::from_str(value.get())))
433                }
434                Some(Err(err @ BroadcastStreamRecvError::Lagged(_))) => {
435                    // This is OK.
436                    debug!(%err, %self.id, "stream lagged");
437                    continue;
438                }
439                None => return task::Poll::Ready(None),
440            }
441        }
442    }
443}