use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{Future, FutureExt, StreamExt};
use msg_common::span::{EnterSpan as _, WithSpan};
use tokio_util::codec::Framed;
use tracing::Instrument;
use crate::{ConnOptions, ConnectionHookErased, ConnectionState, ExponentialBackoff, hooks};
use msg_transport::{Address, MeteredIo, Transport};
use msg_wire::reqrep;
/// Failure of a single connection attempt made by the connection manager.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnectionError<T>
where
    T: std::error::Error,
{
    /// The transport failed. I/O errors surfaced while running a connection hook are also
    /// folded into this variant by `from_erased_hook`.
    #[error("transport error: {0}")]
    Transport(#[source] T),
    /// The connection hook itself returned an error.
    #[error("connection hook error: {0}")]
    Hook(#[source] Box<dyn std::error::Error + Send + Sync>),
}
impl<T: std::error::Error + From<std::io::Error>> ConnectionError<T> {
pub(crate) fn from_erased_hook(
err: hooks::Error<Box<dyn std::error::Error + Send + Sync>>,
) -> Self {
match err {
hooks::Error::Io(io_err) => Self::Transport(T::from(io_err)),
hooks::Error::Hook(e) => Self::Hook(e),
}
}
}
/// Boxed, `Send` future for one connection attempt: the transport connect followed by the
/// optional connection hook. Resolves to the ready-to-use IO or the reason the attempt failed.
type ConnTask<Io, E> =
    Pin<Box<dyn Future<Output = Result<Io, ConnectionError<E>>> + Send>>;

/// A connected IO stream, metered with transport stats and framed with the req/rep codec.
pub(crate) type Conn<Io, S, A> = Framed<MeteredIo<Io, S, A>, reqrep::Codec>;

/// Connection state over a [`Conn`], backed off with an [`ExponentialBackoff`] while inactive.
pub(crate) type ConnCtl<Io, S, A> = ConnectionState<Conn<Io, S, A>, ExponentialBackoff, A>;
/// Owns the transport and connection state for a single remote address, and drives
/// (re)connection attempts from [`ConnManager::poll`].
pub(crate) struct ConnManager<T, A>
where
    T: Transport<A>,
    A: Address,
{
    /// Options used to build a fresh [`ExponentialBackoff`] in
    /// [`ConnManager::reset_connection`].
    options: ConnOptions,
    /// The in-flight connection attempt, if any, tagged with the span it was started in.
    conn_task: Option<WithSpan<ConnTask<T::Io, T::Error>>>,
    /// Connection state: `Active` with a framed channel, or `Inactive` with a backoff.
    conn_ctl: ConnCtl<T::Io, T::Stats, A>,
    /// Transport used to open new connections.
    transport: T,
    /// Remote address targeted by every connection attempt.
    addr: A,
    /// Shared stats handle handed to each new [`MeteredIo`].
    transport_stats: Arc<arc_swap::ArcSwap<T::Stats>>,
    /// Optional hook run on every freshly connected IO before it is framed.
    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    /// Parent span for the per-attempt `connect` spans.
    span: tracing::Span,
}
impl<T, A> ConnManager<T, A>
where
T: Transport<A>,
A: Address,
{
pub(crate) fn new(
options: ConnOptions,
transport: T,
addr: A,
conn_ctl: ConnCtl<T::Io, T::Stats, A>,
transport_stats: Arc<arc_swap::ArcSwap<T::Stats>>,
hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
span: tracing::Span,
) -> Self {
Self { options, conn_task: None, conn_ctl, transport, addr, transport_stats, hook, span }
}
fn try_connect(&mut self) {
let connect = self.transport.connect(self.addr.clone());
let hook = self.hook.clone();
let task = async move {
let io = connect.await.map_err(ConnectionError::Transport)?;
let Some(hook) = hook else {
return Ok(io);
};
hook.on_connection(io).await.map_err(ConnectionError::from_erased_hook)
}
.in_current_span();
self.conn_task = Some(WithSpan::current(Box::pin(task)));
}
#[inline]
pub(crate) fn reset_connection(&mut self) {
self.conn_ctl = ConnectionState::Inactive {
addr: self.addr.clone(),
backoff: ExponentialBackoff::from(&self.options),
};
}
#[inline]
pub(crate) fn active_connection(&mut self) -> Option<&mut Conn<T::Io, T::Stats, A>> {
if let ConnectionState::Active { ref mut channel } = self.conn_ctl {
Some(channel)
} else {
None
}
}
#[allow(clippy::type_complexity)]
pub(crate) fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<&mut Conn<T::Io, T::Stats, A>>> {
loop {
if let Some(ref mut conn_task) = self.conn_task &&
let Poll::Ready(result) = conn_task.poll_unpin(cx).enter()
{
self.conn_task = None;
match result.inner {
Ok(io) => {
tracing::info!("connected");
let metered = MeteredIo::new(io, self.transport_stats.clone());
let framed = Framed::new(metered, reqrep::Codec::new());
self.conn_ctl = ConnectionState::Active { channel: framed };
}
Err(e) => {
tracing::error!(?e, "failed to connect");
}
}
}
if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl {
let Poll::Ready(item) = backoff.poll_next_unpin(cx) else {
return Poll::Pending;
};
let _span = tracing::info_span!(parent: &self.span, "connect").entered();
if let Some(duration) = item {
if self.conn_task.is_none() {
tracing::debug!(backoff = ?duration, "trying connection");
self.try_connect();
} else {
tracing::debug!(
backoff = ?duration,
"not retrying as there is already a connection task"
);
}
} else {
tracing::error!("exceeded maximum number of retries, terminating connection");
return Poll::Ready(None);
}
}
if let ConnectionState::Active { ref mut channel } = self.conn_ctl {
return Poll::Ready(Some(channel));
}
}
}
}