1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7 future::Future,
8 sync::{
9 atomic::{AtomicUsize, Ordering},
10 Arc,
11 },
12 task::{Context, Poll},
13};
14use tokio::sync::{mpsc, oneshot};
15
16#[derive(Debug, Clone)]
21pub struct PubSubFrontend {
22 tx: mpsc::UnboundedSender<PubSubInstruction>,
23 channel_size: Arc<AtomicUsize>,
26}
27
28impl PubSubFrontend {
29 pub fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
31 Self { tx, channel_size: Arc::new(AtomicUsize::new(16)) }
32 }
33
34 pub fn get_subscription(
36 &self,
37 id: B256,
38 ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
39 let backend_tx = self.tx.clone();
40 async move {
41 let (tx, rx) = oneshot::channel();
42 backend_tx
43 .send(PubSubInstruction::GetSub(id, tx))
44 .map_err(|_| TransportErrorKind::backend_gone())?;
45 rx.await
46 .map_err(|_| TransportErrorKind::backend_gone())?
47 .map_or_else(|| Err(TransportErrorKind::custom_str("subscription not found")), Ok)
48 }
49 }
50
51 pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
53 self.tx
54 .send(PubSubInstruction::Unsubscribe(id))
55 .map_err(|_| TransportErrorKind::backend_gone())
56 }
57
58 pub fn send(
60 &self,
61 req: SerializedRequest,
62 ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
63 let tx = self.tx.clone();
64 let channel_size = self.channel_size.load(Ordering::Relaxed);
65
66 async move {
67 let (in_flight, rx) = InFlight::new(req, channel_size);
68 tx.send(PubSubInstruction::Request(in_flight))
69 .map_err(|_| TransportErrorKind::backend_gone())?;
70 rx.await.map_err(|_| TransportErrorKind::backend_gone())?
71 }
72 }
73
74 pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
78 match req {
79 RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
80 RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
81 .map_ok(ResponsePacket::Batch)
82 .boxed(),
83 }
84 }
85
86 pub fn channel_size(&self) -> usize {
91 self.channel_size.load(Ordering::Relaxed)
92 }
93
94 pub fn set_channel_size(&self, channel_size: usize) {
99 debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
100 self.channel_size.store(channel_size, Ordering::Relaxed);
101 }
102}
103
104impl tower::Service<RequestPacket> for PubSubFrontend {
105 type Response = ResponsePacket;
106 type Error = TransportError;
107 type Future = TransportFut<'static>;
108
109 #[inline]
110 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111 let result =
112 if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
113 Poll::Ready(result)
114 }
115
116 #[inline]
117 fn call(&mut self, req: RequestPacket) -> Self::Future {
118 self.send_packet(req)
119 }
120}