1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7 future::Future,
8 sync::{
9 atomic::{AtomicUsize, Ordering},
10 Arc,
11 },
12 task::{Context, Poll},
13};
14use tokio::sync::{mpsc, oneshot};
15
16#[derive(Debug, Clone)]
21pub struct PubSubFrontend {
22 tx: mpsc::UnboundedSender<PubSubInstruction>,
23 channel_size: Arc<AtomicUsize>,
26}
27
28impl PubSubFrontend {
29 pub fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
31 Self { tx, channel_size: Arc::new(AtomicUsize::new(16)) }
32 }
33
34 pub fn get_subscription(
36 &self,
37 id: B256,
38 ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
39 let backend_tx = self.tx.clone();
40 async move {
41 let (tx, rx) = oneshot::channel();
42 backend_tx
43 .send(PubSubInstruction::GetSub(id, tx))
44 .map_err(|_| TransportErrorKind::backend_gone())?;
45 rx.await.map_err(|_| TransportErrorKind::backend_gone())
46 }
47 }
48
49 pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
51 self.tx
52 .send(PubSubInstruction::Unsubscribe(id))
53 .map_err(|_| TransportErrorKind::backend_gone())
54 }
55
56 pub fn send(
58 &self,
59 req: SerializedRequest,
60 ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
61 let tx = self.tx.clone();
62 let channel_size = self.channel_size.load(Ordering::Relaxed);
63
64 async move {
65 let (in_flight, rx) = InFlight::new(req, channel_size);
66 tx.send(PubSubInstruction::Request(in_flight))
67 .map_err(|_| TransportErrorKind::backend_gone())?;
68 rx.await.map_err(|_| TransportErrorKind::backend_gone())?
69 }
70 }
71
72 pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
76 match req {
77 RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
78 RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
79 .map_ok(ResponsePacket::Batch)
80 .boxed(),
81 }
82 }
83
84 pub fn channel_size(&self) -> usize {
89 self.channel_size.load(Ordering::Relaxed)
90 }
91
92 pub fn set_channel_size(&self, channel_size: usize) {
97 debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
98 self.channel_size.store(channel_size, Ordering::Relaxed);
99 }
100}
101
102impl tower::Service<RequestPacket> for PubSubFrontend {
103 type Response = ResponsePacket;
104 type Error = TransportError;
105 type Future = TransportFut<'static>;
106
107 #[inline]
108 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 let result =
110 if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
111 Poll::Ready(result)
112 }
113
114 #[inline]
115 fn call(&mut self, req: RequestPacket) -> Self::Future {
116 self.send_packet(req)
117 }
118}