1use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
2use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
3use alloy_primitives::B256;
4use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
5use futures::{future::try_join_all, FutureExt, TryFutureExt};
6use std::{
7    future::Future,
8    sync::{
9        atomic::{AtomicUsize, Ordering},
10        Arc,
11    },
12    task::{Context, Poll},
13};
14use tokio::sync::{mpsc, oneshot};
15
16#[derive(Debug, Clone)]
21pub struct PubSubFrontend {
22    tx: mpsc::UnboundedSender<PubSubInstruction>,
23    channel_size: Arc<AtomicUsize>,
26}
27
28impl PubSubFrontend {
29    pub fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
31        Self { tx, channel_size: Arc::new(AtomicUsize::new(16)) }
32    }
33
34    pub fn get_subscription(
36        &self,
37        id: B256,
38    ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
39        let backend_tx = self.tx.clone();
40        async move {
41            let (tx, rx) = oneshot::channel();
42            backend_tx
43                .send(PubSubInstruction::GetSub(id, tx))
44                .map_err(|_| TransportErrorKind::backend_gone())?;
45            rx.await.map_err(|_| TransportErrorKind::backend_gone())
46        }
47    }
48
49    pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
51        self.tx
52            .send(PubSubInstruction::Unsubscribe(id))
53            .map_err(|_| TransportErrorKind::backend_gone())
54    }
55
56    pub fn send(
58        &self,
59        req: SerializedRequest,
60    ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
61        let tx = self.tx.clone();
62        let channel_size = self.channel_size.load(Ordering::Relaxed);
63
64        async move {
65            let (in_flight, rx) = InFlight::new(req, channel_size);
66            tx.send(PubSubInstruction::Request(in_flight))
67                .map_err(|_| TransportErrorKind::backend_gone())?;
68            rx.await.map_err(|_| TransportErrorKind::backend_gone())?
69        }
70    }
71
72    pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
76        match req {
77            RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
78            RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
79                .map_ok(ResponsePacket::Batch)
80                .boxed(),
81        }
82    }
83
84    pub fn channel_size(&self) -> usize {
89        self.channel_size.load(Ordering::Relaxed)
90    }
91
92    pub fn set_channel_size(&self, channel_size: usize) {
97        debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
98        self.channel_size.store(channel_size, Ordering::Relaxed);
99    }
100}
101
102impl tower::Service<RequestPacket> for PubSubFrontend {
103    type Response = ResponsePacket;
104    type Error = TransportError;
105    type Future = TransportFut<'static>;
106
107    #[inline]
108    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        let result =
110            if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
111        Poll::Ready(result)
112    }
113
114    #[inline]
115    fn call(&mut self, req: RequestPacket) -> Self::Future {
116        self.send_packet(req)
117    }
118}