1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7 future::Future,
8 sync::atomic::{AtomicUsize, Ordering},
9 task::{Context, Poll},
10};
11use tokio::sync::{mpsc, oneshot};
12
13#[derive(Debug)]
18pub struct PubSubFrontend {
19 tx: mpsc::UnboundedSender<PubSubInstruction>,
20 channel_size: AtomicUsize,
23}
24
25impl Clone for PubSubFrontend {
26 fn clone(&self) -> Self {
27 let channel_size = self.channel_size.load(Ordering::Relaxed);
28 Self { tx: self.tx.clone(), channel_size: AtomicUsize::new(channel_size) }
29 }
30}
31
32impl PubSubFrontend {
33 pub(crate) const fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
35 Self { tx, channel_size: AtomicUsize::new(16) }
36 }
37
38 pub fn get_subscription(
40 &self,
41 id: B256,
42 ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
43 let backend_tx = self.tx.clone();
44 async move {
45 let (tx, rx) = oneshot::channel();
46 backend_tx
47 .send(PubSubInstruction::GetSub(id, tx))
48 .map_err(|_| TransportErrorKind::backend_gone())?;
49 rx.await.map_err(|_| TransportErrorKind::backend_gone())
50 }
51 }
52
53 pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
55 self.tx
56 .send(PubSubInstruction::Unsubscribe(id))
57 .map_err(|_| TransportErrorKind::backend_gone())
58 }
59
60 pub fn send(
62 &self,
63 req: SerializedRequest,
64 ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
65 let tx = self.tx.clone();
66 let channel_size = self.channel_size.load(Ordering::Relaxed);
67
68 async move {
69 let (in_flight, rx) = InFlight::new(req, channel_size);
70 tx.send(PubSubInstruction::Request(in_flight))
71 .map_err(|_| TransportErrorKind::backend_gone())?;
72 rx.await.map_err(|_| TransportErrorKind::backend_gone())?
73 }
74 }
75
76 pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
80 match req {
81 RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
82 RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
83 .map_ok(ResponsePacket::Batch)
84 .boxed(),
85 }
86 }
87
88 pub fn channel_size(&self) -> usize {
93 self.channel_size.load(Ordering::Relaxed)
94 }
95
96 pub fn set_channel_size(&self, channel_size: usize) {
101 debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
102 self.channel_size.store(channel_size, Ordering::Relaxed);
103 }
104}
105
106impl tower::Service<RequestPacket> for PubSubFrontend {
107 type Response = ResponsePacket;
108 type Error = TransportError;
109 type Future = TransportFut<'static>;
110
111 #[inline]
112 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113 let result =
114 if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
115 Poll::Ready(result)
116 }
117
118 #[inline]
119 fn call(&mut self, req: RequestPacket) -> Self::Future {
120 self.send_packet(req)
121 }
122}