use crate::{ix::PubSubInstruction, managers::InFlight, RawSubscription};
use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest};
use alloy_primitives::B256;
use alloy_transport::{TransportError, TransportErrorKind, TransportFut, TransportResult};
use futures::{future::try_join_all, FutureExt, TryFutureExt};
use std::{
future::Future,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, debug_span, Instrument};
/// Handle that forwards requests and subscription instructions to a pubsub backend
/// over an unbounded channel.
///
/// Clones share both the instruction channel and the channel-size setting (the latter
/// lives behind an `Arc`), so `set_channel_size` on one clone is visible to all.
#[derive(Debug, Clone)]
pub struct PubSubFrontend {
// Sending half of the channel that carries `PubSubInstruction`s to the backend.
tx: mpsc::UnboundedSender<PubSubInstruction>,
// Channel size handed to every `InFlight` request created by `send`; shared between clones.
channel_size: Arc<AtomicUsize>,
}
impl PubSubFrontend {
    /// Create a new frontend that talks to the backend through `tx`.
    ///
    /// The channel size handed to each [`InFlight`] request starts at 16; change it with
    /// [`Self::set_channel_size`].
    pub fn new(tx: mpsc::UnboundedSender<PubSubInstruction>) -> Self {
        Self { tx, channel_size: Arc::new(AtomicUsize::new(16)) }
    }

    /// Ask the backend for the existing subscription identified by `id`.
    ///
    /// The returned future owns a clone of the sender, so it does not borrow `self`.
    ///
    /// # Errors
    ///
    /// - `backend_gone` if the instruction cannot be sent or the backend drops the reply
    ///   channel without answering.
    /// - A custom "subscription not found" error if the backend answers with `None`.
    pub fn get_subscription(
        &self,
        id: B256,
    ) -> impl Future<Output = TransportResult<RawSubscription>> + Send + 'static {
        let backend_tx = self.tx.clone();
        async move {
            let (tx, rx) = oneshot::channel();
            backend_tx
                .send(PubSubInstruction::GetSub(id, tx))
                .map_err(|_| TransportErrorKind::backend_gone())?;
            rx.await
                .map_err(|_| TransportErrorKind::backend_gone())?
                .ok_or_else(|| TransportErrorKind::custom_str("subscription not found"))
        }
    }

    /// Queue an `Unsubscribe` instruction for the subscription `id`.
    ///
    /// `Ok(())` only means the instruction was enqueued; it does not wait for the backend.
    ///
    /// # Errors
    ///
    /// `backend_gone` if the backend's receiving half of the channel has been dropped.
    pub fn unsubscribe(&self, id: B256) -> TransportResult<()> {
        self.tx
            .send(PubSubInstruction::Unsubscribe(id))
            .map_err(|_| TransportErrorKind::backend_gone())
    }

    /// Send a single serialized request to the backend and wait for its response.
    ///
    /// The channel size is read when this method is called, not when the future is first
    /// polled. The future is instrumented with a `request` span carrying the method name.
    ///
    /// # Errors
    ///
    /// `backend_gone` if the request cannot be queued or the backend drops the response
    /// channel. Otherwise, whatever result the backend delivers is returned as-is.
    pub fn send(
        &self,
        req: SerializedRequest,
    ) -> impl Future<Output = TransportResult<Response>> + Send + 'static {
        let tx = self.tx.clone();
        let channel_size = self.channel_size.load(Ordering::Relaxed);
        let method_name = req.method_clone();
        async move {
            debug!("sending request to backend");
            let (in_flight, rx) = InFlight::new(req, channel_size);
            tx.send(PubSubInstruction::Request(in_flight))
                .map_err(|_| TransportErrorKind::backend_gone())?;
            let resp = rx.await.map_err(|_| TransportErrorKind::backend_gone())?;
            // Only log the full response body at TRACE; at DEBUG log just success/failure.
            // Fully qualified because `trace` is not among this file's `tracing` imports.
            if tracing::enabled!(tracing::Level::TRACE) {
                tracing::trace!(?resp, "retrieved response");
            } else {
                debug!(resp=?resp.as_ref().map(|_| ()), "retrieved response");
            }
            resp
        }
        .instrument(debug_span!("request", %method_name))
    }

    /// Send a request packet, dispatching each request of a batch individually.
    ///
    /// Batch responses are collected in request order by `try_join_all`. The first failing
    /// request fails the whole batch, and the remaining futures are dropped.
    pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> {
        match req {
            RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(),
            RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req)))
                .map_ok(ResponsePacket::Batch)
                .boxed(),
        }
    }

    /// Current channel size handed to new [`InFlight`] requests.
    pub fn channel_size(&self) -> usize {
        self.channel_size.load(Ordering::Relaxed)
    }

    /// Set the channel size used for requests sent after this call.
    ///
    /// The setting is shared with every clone of this frontend. A value of zero is
    /// rejected only by a `debug_assert`, so it is checked in debug builds only.
    pub fn set_channel_size(&self, channel_size: usize) {
        debug_assert_ne!(channel_size, 0, "channel size must be non-zero");
        self.channel_size.store(channel_size, Ordering::Relaxed);
    }
}
impl tower::Service<RequestPacket> for PubSubFrontend {
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let result =
if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) };
Poll::Ready(result)
}
#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.send_packet(req)
}
}