use crate::{
handle::ConnectionHandle,
ix::PubSubInstruction,
managers::{InFlight, RequestManager, SubscriptionManager},
PubSubConnect, PubSubFrontend, RawSubscription,
};
use alloy_json_rpc::{Id, PubSubItem, Request, Response, ResponsePayload, SubId};
use alloy_primitives::B256;
use alloy_transport::{
utils::{to_json_raw_value, Spawnable},
TransportErrorKind, TransportResult,
};
use serde_json::value::RawValue;
use tokio::sync::{mpsc, oneshot};
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
use wasmtimer::tokio::sleep;
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
use tokio::time::sleep;
/// The service half of a pubsub connection: owns the backend handle and all
/// bookkeeping needed to route responses/notifications and to survive
/// reconnects. Driven by the event loop spawned in [`PubSubService::spawn`].
#[derive(Debug)]
pub(crate) struct PubSubService<T> {
    /// Channel pair connecting this service to the active backend socket task.
    pub(crate) handle: ConnectionHandle,
    /// Connector used to establish (and re-establish) the backend connection.
    pub(crate) connector: T,
    /// Instruction stream coming in from the frontend(s).
    pub(crate) reqs: mpsc::UnboundedReceiver<PubSubInstruction>,
    /// Active subscriptions, keyed by both local and server-assigned ids.
    pub(crate) subs: SubscriptionManager,
    /// Requests that have been dispatched but not yet answered.
    pub(crate) in_flights: RequestManager,
}
impl<T: PubSubConnect> PubSubService<T> {
    /// Create a new service from a connector, spawn its event loop as a
    /// background task, and return a [`PubSubFrontend`] for issuing
    /// instructions to it.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial `connector.connect()` call fails.
    pub(crate) async fn connect(connector: T) -> TransportResult<PubSubFrontend> {
        let handle = connector.connect().await?;
        let (tx, reqs) = mpsc::unbounded_channel();
        let this = Self {
            handle,
            connector,
            reqs,
            subs: SubscriptionManager::default(),
            in_flights: Default::default(),
        };
        this.spawn();
        Ok(PubSubFrontend::new(tx))
    }

    /// Reconnect via the connector, install the fresh handle as the active
    /// one, and hand the previous (now stale) handle back to the caller for
    /// draining/shutdown.
    async fn get_new_backend(&mut self) -> TransportResult<ConnectionHandle> {
        let mut handle = self.connector.try_reconnect().await?;
        // After the swap, `self.handle` is the new backend and `handle` is
        // the old one.
        std::mem::swap(&mut self.handle, &mut handle);
        Ok(handle)
    }

    /// Reconnect the backend, then restore state on the new connection:
    /// drain any items still buffered on the old handle, re-dispatch all
    /// in-flight requests, and re-issue the subscription request for every
    /// active subscription.
    async fn reconnect(&mut self) -> TransportResult<()> {
        debug!("Reconnecting pubsub service backend");
        let mut old_handle = self.get_new_backend().await?;
        debug!("Draining old backend to_handle");
        // Process anything the old backend delivered before it went away, so
        // those responses/notifications aren't lost.
        while let Ok(item) = old_handle.from_socket.try_recv() {
            self.handle_item(item)?;
        }
        old_handle.shutdown();
        // Requests that never got a response are re-sent on the new backend.
        debug!(count = self.in_flights.len(), "Reissuing pending requests");
        for (_, in_flight) in self.in_flights.iter() {
            let msg = in_flight.request.serialized().to_owned();
            self.dispatch_request(msg)?;
        }
        // Server-assigned subscription ids are meaningless on the new
        // backend: drop them and subscribe again from scratch.
        debug!(count = self.subs.len(), "Re-starting active subscriptions");
        self.subs.drop_server_ids();
        for (_, sub) in self.subs.iter() {
            let req = sub.request().to_owned();
            // The response receiver is intentionally discarded here: the
            // subscription already exists locally, we only need the server
            // to start it again.
            let (in_flight, _) = InFlight::new(req.clone(), sub.tx.receiver_count());
            self.in_flights.insert(in_flight);
            let msg = req.into_serialized();
            self.dispatch_request(msg)?;
        }
        Ok(())
    }

    /// Send a serialized request to the active backend socket task.
    ///
    /// # Errors
    ///
    /// Returns a `backend_gone` transport error if the socket channel is
    /// closed (i.e. the backend task has hung up).
    fn dispatch_request(&self, brv: Box<RawValue>) -> TransportResult<()> {
        self.handle.to_socket.send(brv).map(drop).map_err(|_| TransportErrorKind::backend_gone())
    }

    /// Dispatch a frontend request to the backend and record it as
    /// in-flight so its response can be routed back to the requester.
    fn service_request(&mut self, in_flight: InFlight) -> TransportResult<()> {
        let brv = in_flight.request();
        self.dispatch_request(brv.serialized().to_owned())?;
        self.in_flights.insert(in_flight);
        Ok(())
    }

    /// Look up a subscription by its local id and send the result (if any)
    /// back to the requester. A dropped requester is silently ignored.
    fn service_get_sub(&self, local_id: B256, tx: oneshot::Sender<Option<RawSubscription>>) {
        let _ = tx.send(self.subs.get_subscription(local_id));
    }

    /// Unsubscribe a local subscription: if the server knows about it, send
    /// an `eth_unsubscribe` request, then remove it from local tracking
    /// either way.
    fn service_unsubscribe(&mut self, local_id: B256) -> TransportResult<()> {
        if let Some(server_id) = self.subs.server_id_for(&local_id) {
            // NOTE(review): the request id is a constant `1` — the
            // unsubscribe is fire-and-forget; its response appears not to be
            // tracked via `in_flights`.
            let req = Request::new("eth_unsubscribe", Id::Number(1), [server_id]);
            let brv = req.serialize().expect("no ser error").take_request();
            self.dispatch_request(brv)?;
        }
        self.subs.remove_sub(local_id);
        Ok(())
    }

    /// Route a single frontend instruction to the appropriate handler.
    fn service_ix(&mut self, ix: PubSubInstruction) -> TransportResult<()> {
        trace!(?ix, "servicing instruction");
        match ix {
            PubSubInstruction::Request(in_flight) => self.service_request(in_flight),
            PubSubInstruction::GetSub(alias, tx) => {
                self.service_get_sub(alias, tx);
                Ok(())
            }
            PubSubInstruction::Unsubscribe(alias) => self.service_unsubscribe(alias),
        }
    }

    /// Handle one item received from the backend socket: responses are
    /// matched against in-flight requests, notifications are forwarded to
    /// the subscription manager.
    fn handle_item(&mut self, item: PubSubItem) -> TransportResult<()> {
        match item {
            PubSubItem::Response(resp) => match self.in_flights.handle_response(resp) {
                // `Some` means this response carried a server-assigned
                // subscription id and needs subscription bookkeeping.
                Some((server_id, in_flight)) => self.handle_sub_response(in_flight, server_id),
                None => Ok(()),
            },
            PubSubItem::Notification(notification) => {
                self.subs.notify(notification);
                Ok(())
            }
        }
    }

    /// Record a successful subscription response: upsert the subscription
    /// under the server-assigned id, then answer the original requester with
    /// the *local* id so callers always see a stable alias that survives
    /// reconnects.
    fn handle_sub_response(
        &mut self,
        in_flight: InFlight,
        server_id: SubId,
    ) -> TransportResult<()> {
        let request = in_flight.request;
        let id = request.id().clone();
        let sub = self.subs.upsert(request, server_id, in_flight.channel_size);
        // Serialize the local alias, not the server id, into the response.
        let ser_alias = to_json_raw_value(sub.local_id())?;
        // The requester may have gone away; a failed send is fine.
        let _ =
            in_flight.tx.send(Ok(Response { id, payload: ResponsePayload::Success(ser_alias) }));
        Ok(())
    }

    /// Attempt [`Self::reconnect`] up to `max_retries` times, sleeping
    /// `retry_interval` between attempts. Returns the last reconnect error
    /// if every attempt fails.
    async fn reconnect_with_retries(&mut self) -> TransportResult<()> {
        let mut retry_count = 0;
        let max_retries = self.handle.max_retries;
        let interval = self.handle.retry_interval;
        loop {
            match self.reconnect().await {
                Ok(()) => break Ok(()),
                Err(e) => {
                    retry_count += 1;
                    if retry_count >= max_retries {
                        error!("Reconnect failed after {max_retries} attempts, shutting down: {e}");
                        break Err(e);
                    }
                    warn!(
                        "Reconnection attempt {retry_count}/{max_retries} failed: {e}. \
                         Retrying in {:?}s...",
                        interval.as_secs_f64(),
                    );
                    sleep(interval).await;
                }
            }
        }
    }

    /// Spawn the service event loop as a background task, consuming `self`.
    ///
    /// The `biased` select polls in a fixed priority order: backend items
    /// first, then backend errors, then frontend instructions. A closed
    /// socket, a backend error, or a `backend_gone` dispatch failure all
    /// trigger [`Self::reconnect_with_retries`]. The loop exits when retries
    /// are exhausted, a non-recoverable error occurs, or the frontend
    /// instruction channel closes.
    pub(crate) fn spawn(mut self) {
        let fut = async move {
            let result: TransportResult<()> = loop {
                tokio::select! {
                    biased;
                    // Drain backend items before anything else.
                    item_opt = self.handle.from_socket.recv() => {
                        // A `None` recv means the backend hung up; reconnect.
                        if let Some(item) = item_opt {
                            if let Err(e) = self.handle_item(item) {
                                break Err(e)
                            }
                        } else if let Err(e) = self.reconnect_with_retries().await {
                            break Err(e)
                        }
                    }
                    _ = &mut self.handle.error => {
                        error!("Pubsub service backend error.");
                        if let Err(e) = self.reconnect_with_retries().await {
                            break Err(e)
                        }
                    }
                    req_opt = self.reqs.recv() => {
                        if let Some(req) = req_opt {
                            if let Err(err) = self.service_ix(req) {
                                // Only a vanished backend is recoverable;
                                // any other error shuts the service down.
                                if err
                                    .as_transport_err()
                                    .is_some_and(TransportErrorKind::is_backend_gone)
                                {
                                    if let Err(e) = self.reconnect_with_retries().await {
                                        break Err(e)
                                    }
                                } else {
                                    break Err(err)
                                }
                            }
                        } else {
                            // All frontends dropped; nothing left to serve.
                            info!("Pubsub service request channel closed. Shutting down.");
                            break Ok(())
                        }
                    }
                }
            };
            if let Err(err) = result {
                error!(%err, "pubsub service reconnection error");
            }
        };
        fut.spawn_task();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ConnectionInterface;
    use alloy_json_rpc::Request;
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };
    use tokio::time::timeout;

    /// Connector whose `try_reconnect` hands out a pre-made handle exactly
    /// once; subsequent calls fail because the slot has been `take`n.
    #[derive(Clone, Debug, Default)]
    struct MockConnect(Arc<Mutex<Option<ConnectionHandle>>>);

    impl PubSubConnect for MockConnect {
        fn is_local(&self) -> bool {
            true
        }

        /// Unused here: the test constructs the service struct directly
        /// instead of going through `PubSubService::connect`.
        async fn connect(&self) -> TransportResult<ConnectionHandle> {
            Err(TransportErrorKind::custom_str("connect is not used in this test"))
        }

        /// Yield the stored handle on the first call; error afterwards.
        async fn try_reconnect(&self) -> TransportResult<ConnectionHandle> {
            self.0
                .lock()
                .expect("poisoned mutex")
                .take()
                .ok_or_else(|| TransportErrorKind::custom_str("missing mock connection handle"))
        }
    }

    /// A request dispatched while the backend is gone must trigger a
    /// reconnect, after which subsequent requests reach the new backend.
    #[tokio::test]
    async fn reconnects_after_request_dispatch_hits_backend_gone() {
        // Simulate a dead backend: drop its request receiver so dispatches
        // fail with `backend_gone`, but keep the rest of the interface alive
        // so the service doesn't observe a shutdown through other channels.
        let (dead_handle, dead_interface) = ConnectionHandle::new();
        let ConnectionInterface { from_frontend, to_frontend, error, shutdown } = dead_interface;
        drop(from_frontend);
        let _keep_dead_backend_alive = (to_frontend, error, shutdown);
        // The replacement backend the mock connector will hand out on
        // reconnect.
        let (reconnected_handle, mut reconnected_interface) = ConnectionHandle::new();
        let connector = MockConnect(Arc::new(Mutex::new(Some(reconnected_handle))));
        let (tx, reqs) = mpsc::unbounded_channel();
        let service = PubSubService {
            handle: dead_handle,
            connector,
            reqs,
            subs: SubscriptionManager::default(),
            in_flights: RequestManager::default(),
        };
        service.spawn();
        // First request hits the dead socket; its response channel should
        // resolve (with an error) rather than hang.
        let first = Request::new("eth_blockNumber", Id::Number(1), ()).serialize().unwrap();
        let (in_flight, rx) = InFlight::new(first, 16);
        tx.send(PubSubInstruction::Request(in_flight)).unwrap();
        timeout(Duration::from_secs(1), rx)
            .await
            .expect("failed request should resolve promptly")
            .expect_err("raced request should be dropped when the backend is gone");
        // The failure triggers the reconnect; the second request must be
        // delivered to the new backend byte-for-byte.
        let second = Request::new("eth_chainId", Id::Number(2), ()).serialize().unwrap();
        let expected = second.serialized().get().to_owned();
        let (in_flight, _rx) = InFlight::new(second, 16);
        tx.send(PubSubInstruction::Request(in_flight)).unwrap();
        let dispatched =
            timeout(Duration::from_secs(1), reconnected_interface.recv_from_frontend())
                .await
                .expect("request should be dispatched after reconnect")
                .expect("new backend should receive the request");
        assert_eq!(dispatched.get(), expected);
    }
}