alloy_pubsub/
frontend.rs

1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7    future::Future,
8    sync::{
9        atomic::{AtomicUsize, Ordering},
10        Arc,
11    },
12    task::{Context, Poll},
13};
14use tokio::sync::{mpsc, oneshot};
15
16/// A `PubSubFrontend` is [`Transport`] composed of a channel to a running
17/// PubSub service.
18///
19/// [`Transport`]: alloy_transport::Transport
20#[derive(Debug, Clone)]
21pub struct PubSubFrontend {
22    tx: mpsc::UnboundedSender<PubSubInstruction>,
23    /// The number of items to buffer in new subscription channels. Defaults to
24    /// 16. See [`tokio::sync::broadcast::channel`] for a description.
25    channel_size: Arc<AtomicUsize>,
26}
27
28impl PubSubFrontend {
29    /// Create a new frontend.
30    pub fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
31        Self { tx, channel_size: Arc::new(AtomicUsize::new(16)) }
32    }
33
34    /// Get the subscription ID for a local ID.
35    pub fn get_subscription(
36        &self,
37        id: B256,
38    ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
39        let backend_tx = self.tx.clone();
40        async move {
41            let (tx, rx) = oneshot::channel();
42            backend_tx
43                .send(PubSubInstruction::GetSub(id, tx))
44                .map_err(|_| TransportErrorKind::backend_gone())?;
45            rx.await
46                .map_err(|_| TransportErrorKind::backend_gone())?
47                .map_or_else(|| Err(TransportErrorKind::custom_str("subscription not found")), Ok)
48        }
49    }
50
51    /// Unsubscribe from a subscription.
52    pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
53        self.tx
54            .send(PubSubInstruction::Unsubscribe(id))
55            .map_err(|_| TransportErrorKind::backend_gone())
56    }
57
58    /// Send a request.
59    pub fn send(
60        &self,
61        req: SerializedRequest,
62    ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
63        let tx = self.tx.clone();
64        let channel_size = self.channel_size.load(Ordering::Relaxed);
65
66        async move {
67            let (in_flight, rx) = InFlight::new(req, channel_size);
68            tx.send(PubSubInstruction::Request(in_flight))
69                .map_err(|_| TransportErrorKind::backend_gone())?;
70            rx.await.map_err(|_| TransportErrorKind::backend_gone())?
71        }
72    }
73
74    /// Send a packet of requests, by breaking it up into individual requests.
75    ///
76    /// Once all responses are received, we return a single response packet.
77    pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
78        match req {
79            RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
80            RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
81                .map_ok(ResponsePacket::Batch)
82                .boxed(),
83        }
84    }
85
86    /// Get the currently configured channel size. This is the number of items
87    /// to buffer in new subscription channels. Defaults to 16. See
88    /// [`tokio::sync::broadcast`] for a description of relevant
89    /// behavior.
90    pub fn channel_size(&self) -> usize {
91        self.channel_size.load(Ordering::Relaxed)
92    }
93
94    /// Set the channel size. This is the number of items to buffer in new
95    /// subscription channels. Defaults to 16. See
96    /// [`tokio::sync::broadcast`] for a description of relevant
97    /// behavior.
98    pub fn set_channel_size(&self, channel_size: usize) {
99        debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
100        self.channel_size.store(channel_size, Ordering::Relaxed);
101    }
102}
103
104impl tower::Service<RequestPacket> for PubSubFrontend {
105    type Response = ResponsePacket;
106    type Error = TransportError;
107    type Future = TransportFut<'static>;
108
109    #[inline]
110    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111        let result =
112            if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
113        Poll::Ready(result)
114    }
115
116    #[inline]
117    fn call(&mut self, req: RequestPacket) -> Self::Future {
118        self.send_packet(req)
119    }
120}