alloy_pubsub/
frontend.rs

1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7    future::Future,
8    sync::atomic::{AtomicUsize, Ordering},
9    task::{Context, Poll},
10};
11use tokio::sync::{mpsc, oneshot};
12
13/// A `PubSubFrontend` is [`Transport`] composed of a channel to a running
14/// PubSub service.
15///
16/// [`Transport`]: alloy_transport::Transport
17#[derive(Debug)]
18pub struct PubSubFrontend {
19    tx: mpsc::UnboundedSender<PubSubInstruction>,
20    /// The number of items to buffer in new subscription channels. Defaults to
21    /// 16. See [`tokio::sync::broadcast::channel`] for a description.
22    channel_size: AtomicUsize,
23}
24
25impl Clone for PubSubFrontend {
26    fn clone(&self) -> Self {
27        let channel_size = self.channel_size.load(Ordering::Relaxed);
28        Self { tx: self.tx.clone(), channel_size: AtomicUsize::new(channel_size) }
29    }
30}
31
32impl PubSubFrontend {
33    /// Create a new frontend.
34    pub(crate) const fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
35        Self { tx, channel_size: AtomicUsize::new(16) }
36    }
37
38    /// Get the subscription ID for a local ID.
39    pub fn get_subscription(
40        &self,
41        id: B256,
42    ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
43        let backend_tx = self.tx.clone();
44        async move {
45            let (tx, rx) = oneshot::channel();
46            backend_tx
47                .send(PubSubInstruction::GetSub(id, tx))
48                .map_err(|_| TransportErrorKind::backend_gone())?;
49            rx.await.map_err(|_| TransportErrorKind::backend_gone())
50        }
51    }
52
53    /// Unsubscribe from a subscription.
54    pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
55        self.tx
56            .send(PubSubInstruction::Unsubscribe(id))
57            .map_err(|_| TransportErrorKind::backend_gone())
58    }
59
60    /// Send a request.
61    pub fn send(
62        &self,
63        req: SerializedRequest,
64    ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
65        let tx = self.tx.clone();
66        let channel_size = self.channel_size.load(Ordering::Relaxed);
67
68        async move {
69            let (in_flight, rx) = InFlight::new(req, channel_size);
70            tx.send(PubSubInstruction::Request(in_flight))
71                .map_err(|_| TransportErrorKind::backend_gone())?;
72            rx.await.map_err(|_| TransportErrorKind::backend_gone())?
73        }
74    }
75
76    /// Send a packet of requests, by breaking it up into individual requests.
77    ///
78    /// Once all responses are received, we return a single response packet.
79    pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
80        match req {
81            RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
82            RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
83                .map_ok(ResponsePacket::Batch)
84                .boxed(),
85        }
86    }
87
88    /// Get the currently configured channel size. This is the number of items
89    /// to buffer in new subscription channels. Defaults to 16. See
90    /// [`tokio::sync::broadcast`] for a description of relevant
91    /// behavior.
92    pub fn channel_size(&self) -> usize {
93        self.channel_size.load(Ordering::Relaxed)
94    }
95
96    /// Set the channel size. This is the number of items to buffer in new
97    /// subscription channels. Defaults to 16. See
98    /// [`tokio::sync::broadcast`] for a description of relevant
99    /// behavior.
100    pub fn set_channel_size(&self, channel_size: usize) {
101        debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
102        self.channel_size.store(channel_size, Ordering::Relaxed);
103    }
104}
105
106impl tower::Service<RequestPacket> for PubSubFrontend {
107    type Response = ResponsePacket;
108    type Error = TransportError;
109    type Future = TransportFut<'static>;
110
111    #[inline]
112    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113        let result =
114            if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
115        Poll::Ready(result)
116    }
117
118    #[inline]
119    fn call(&mut self, req: RequestPacket) -> Self::Future {
120        self.send_packet(req)
121    }
122}