use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
};
use bytes::Bytes;
use futures::{Future, FutureExt, SinkExt, Stream, StreamExt, stream::FuturesUnordered};
use msg_common::span::{EnterSpan, SpanExt as _, WithSpan};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{
mpsc,
oneshot::{self, error::RecvError},
},
task::JoinSet,
time::Interval,
};
use tokio_stream::{StreamMap, StreamNotifyClose};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, trace, warn};
use crate::{ConnectionHookErased, RepOptions, Request, hooks, rep::SocketState};
use msg_transport::{Address, PeerAddress, Transport};
use msg_wire::{
compression::{Compressor, try_decompress_payload},
reqrep,
};
use super::RepError;
/// Health-check payload; a peer sending exactly these bytes is answered with [`PONG`] directly
/// by the connection loop instead of having the message forwarded to the socket.
const PING: &[u8; 4] = &[b'P', b'I', b'N', b'G'];
/// Payload of the reply written back for a [`PING`] health check.
const PONG: &[u8; 4] = &[b'P', b'O', b'N', b'G'];
/// Per-connection state for a single peer of the reply socket.
///
/// Reading yields requests from `conn`; responses to those requests are written back onto the
/// same connection, buffered and flushed according to `write_buffer_size` / `linger_timer`.
pub(crate) struct PeerState<T, A>
where
    T: AsyncRead + AsyncWrite,
    A: Address,
{
    /// Requests awaiting a response; each future resolves when the socket answers it.
    pending_requests: FuturesUnordered<WithSpan<PendingRequest>>,
    /// The peer connection, framed with the request/reply codec.
    conn: Framed<T, reqrep::Codec>,
    /// Optional timer whose ticks flush a non-empty write buffer.
    linger_timer: Option<Interval>,
    /// Write-buffer length at or above which the connection is flushed eagerly.
    write_buffer_size: usize,
    /// Peer address; stamped as the `source` of every request read from this connection.
    addr: A,
    /// Encoded response staged to be written to `conn` on the next loop iteration.
    pending_egress: Option<WithSpan<reqrep::Message>>,
    /// High-water mark: no new requests are read while this many are still unanswered.
    max_pending_responses: usize,
    /// Shared socket state (statistics counters).
    state: Arc<SocketState>,
    /// Optional compressor applied to outgoing responses.
    compressor: Option<Arc<dyn Compressor>>,
    /// Span entered while reading from the connection.
    span: tracing::Span,
}
/// Background future that drives a reply socket.
///
/// It accepts connections from `transport`, optionally runs them through a connection `hook`,
/// reads requests from every connected peer and forwards them to the socket via `to_socket`.
// The field types (e.g. `peer_states`, `hook_tasks`) are inherently nested generics; naming
// them through aliases would not make this struct clearer.
#[allow(clippy::type_complexity)]
pub(crate) struct RepDriver<T: Transport<A>, A: Address> {
    /// Transport used to accept incoming connections and to apply control commands.
    pub(crate) transport: T,
    /// Shared socket state; its statistics track active clients and rx/tx traffic.
    pub(crate) state: Arc<SocketState>,
    /// Socket options (write-buffer size and linger, max pending responses, max clients).
    pub(crate) options: Arc<RepOptions>,
    /// Connected peers keyed by address; `StreamNotifyClose` reports when a peer stream ends.
    pub(crate) peer_states: StreamMap<A, StreamNotifyClose<PeerState<T::Io, A>>>,
    /// Channel through which decompressed requests are handed to the socket.
    pub(crate) to_socket: mpsc::Sender<Request<A>>,
    /// Optional hook run on every accepted connection before it is registered as a peer.
    pub(crate) hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    /// Optional compressor handed to every peer for its outgoing responses.
    pub(crate) compressor: Option<Arc<dyn Compressor>>,
    /// Accept futures that have not completed yet.
    pub(crate) conn_tasks: FuturesUnordered<WithSpan<T::Accept>>,
    /// Spawned connection-hook tasks, each yielding the hooked stream and its peer address.
    pub(crate) hook_tasks: JoinSet<WithSpan<hooks::ErasedHookResult<(T::Io, A)>>>,
    /// Control commands forwarded to the transport.
    pub(crate) control_rx: mpsc::Receiver<T::Control>,
    /// Driver span; parent of the per-peer spans.
    pub(crate) span: tracing::Span,
}
impl<T, A> Future for RepDriver<T, A>
where
T: Transport<A>,
A: Address,
{
type Output = Result<(), RepError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
if let Poll::Ready(Some((peer, maybe_result))) = this.peer_states.poll_next_unpin(cx) {
let Some(result) = maybe_result.enter() else {
debug!(?peer, "peer disconnected");
this.state.stats.specific.decrement_active_clients();
continue;
};
match result.inner {
Ok(mut request) => {
debug!("received request");
let size = request.msg().len();
match try_decompress_payload(request.compression_type, request.msg) {
Ok(decompressed) => request.msg = decompressed,
Err(e) => {
debug!(?e, "failed to decompress message");
continue;
}
}
this.state.stats.specific.increment_rx(size);
if let Err(e) = this.to_socket.try_send(request) {
error!(?e, ?peer, "failed to send to socket, dropping request");
};
}
Err(e) => {
if e.is_connection_reset() {
trace!(?peer, "connection reset")
} else {
error!(?e, ?peer, "failed to receive message from peer");
}
}
}
continue;
}
if let Poll::Ready(Some(Ok(hook_result))) = this.hook_tasks.poll_join_next(cx).enter() {
match hook_result.inner {
Ok((stream, addr)) => {
info!(?addr, "connection hook passed");
let conn = Framed::new(stream, reqrep::Codec::new());
let span = tracing::info_span!(parent: this.span.clone(), "peer", ?addr);
let linger_timer = this.options.write_buffer_linger.map(|duration| {
let mut timer = tokio::time::interval(duration);
timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
timer
});
this.peer_states.insert(
addr.clone(),
StreamNotifyClose::new(PeerState {
span,
pending_requests: FuturesUnordered::new(),
conn,
linger_timer,
write_buffer_size: this.options.write_buffer_size,
addr,
pending_egress: None,
max_pending_responses: this.options.max_pending_responses,
state: Arc::clone(&this.state),
compressor: this.compressor.clone(),
}),
);
}
Err(e) => {
debug!(?e, "connection hook failed");
this.state.stats.specific.decrement_active_clients();
}
}
continue;
}
if let Poll::Ready(Some(conn)) = this.conn_tasks.poll_next_unpin(cx).enter() {
match conn.inner {
Ok(io) => {
if let Err(e) = this.on_accepted_connection(io) {
error!(?e, "failed to handle accepted connection");
this.state.stats.specific.decrement_active_clients();
}
}
Err(e) => {
debug!(?e, "failed to accept incoming connection");
this.state.stats.specific.decrement_active_clients();
}
}
continue;
}
if let Poll::Ready(Some(cmd)) = this.control_rx.poll_recv(cx) {
this.transport.on_control(cmd);
}
if let Poll::Ready(accept) = Pin::new(&mut this.transport).poll_accept(cx) {
let span = this.span.clone().entered();
let active_clients = this.state.stats.specific.active_clients();
if this.options.max_clients.is_some_and(|max| active_clients >= max) {
warn!(
active_clients,
"max connections reached, rejecting new incoming connection",
);
continue;
}
this.state.stats.specific.increment_active_clients();
this.conn_tasks.push(accept.with_span(span));
continue;
}
return Poll::Pending;
}
}
}
impl<T, A> RepDriver<T, A>
where
T: Transport<A>,
A: Address,
{
fn on_accepted_connection(&mut self, io: T::Io) -> Result<(), io::Error> {
let addr = io.peer_addr()?;
info!(?addr, "new connection");
let linger_timer = self.options.write_buffer_linger.map(|duration| {
let mut timer = tokio::time::interval(duration);
timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
timer
});
let Some(ref hook) = self.hook else {
self.peer_states.insert(
addr.clone(),
StreamNotifyClose::new(PeerState {
span: tracing::info_span!("peer", ?addr),
pending_requests: FuturesUnordered::new(),
conn: Framed::new(io, reqrep::Codec::new()),
linger_timer,
write_buffer_size: self.options.write_buffer_size,
addr,
pending_egress: None,
max_pending_responses: self.options.max_pending_responses,
state: Arc::clone(&self.state),
compressor: self.compressor.clone(),
}),
);
return Ok(());
};
let hook = Arc::clone(hook);
let span = tracing::info_span!("connection_hook", ?addr);
let fut = async move {
let stream = hook.on_connection(io).await?;
Ok((stream, addr))
};
self.hook_tasks.spawn(fut.with_span(span));
Ok(())
}
}
impl<T: AsyncRead + AsyncWrite + Unpin, A: Address> PeerState<T, A> {
    /// Best-effort drain of outgoing data after the peer's read side has ended.
    ///
    /// Writes a staged response (if any) into the codec's write buffer, then flushes it.
    /// Resolves to `Ready(())` once nothing is left to write or when sending/flushing fails
    /// (failures are logged, not propagated). Returns `Pending` while the flush is in progress.
    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        let leftover = self.pending_egress.take();
        let buffered = self.conn.write_buffer().len();

        // Nothing staged and nothing buffered: the connection can be closed right away.
        if leftover.is_none() && buffered == 0 {
            debug!("flushed everything, closing connection");
            return Poll::Ready(());
        }

        debug!(has_pending = ?leftover.is_some(), write_buffer_size = ?buffered, "found data to send");

        if let Some(Err(e)) = leftover.map(|msg| self.conn.start_send_unpin(msg.inner)) {
            error!(?e, "failed to send final message to socket, closing");
            return Poll::Ready(());
        }

        match self.conn.poll_flush_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(()),
            Poll::Ready(Err(e)) => {
                error!(?e, "failed to flush on shutdown, giving up");
                Poll::Ready(())
            }
        }
    }
}
impl<T: AsyncRead + AsyncWrite + Unpin, A: Address + Unpin> Stream for PeerState<T, A> {
type Item = WithSpan<Result<Request<A>, RepError>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
if let Some(msg) = this.pending_egress.take().enter() {
let msg_len = msg.size();
debug!(msg_id = msg.id(), "sending response");
match this.conn.start_send_unpin(msg.inner) {
Ok(_) => {
this.state.stats.specific.increment_tx(msg_len);
continue;
}
Err(e) => {
this.state.stats.specific.increment_failed_requests();
error!(err = ?e, peer = ?this.addr, "failed to send message to socket, closing...");
return Poll::Ready(None);
}
}
}
if this.conn.write_buffer().len() >= this.write_buffer_size {
if let Poll::Ready(Err(e)) = this.conn.poll_flush_unpin(cx) {
error!(err = ?e, peer = ?this.addr, "failed to flush connection, closing...");
return Poll::Ready(None);
}
if let Some(ref mut linger_timer) = this.linger_timer {
linger_timer.reset();
}
}
if let Some(ref mut linger_timer) = this.linger_timer &&
!this.conn.write_buffer().is_empty() &&
linger_timer.poll_tick(cx).is_ready() &&
let Poll::Ready(Err(e)) = this.conn.poll_flush_unpin(cx)
{
error!(err = ?e, peer = ?this.addr, "failed to flush connection, closing...");
return Poll::Ready(None);
}
if this.pending_egress.is_none() &&
let Poll::Ready(Some(result)) = this.pending_requests.poll_next_unpin(cx).enter()
{
match result.inner {
Err(_) => tracing::error!("response channel closed unexpectedly"),
Ok(Response { msg_id, mut response }) => {
let mut compression_type = 0;
let len_before = response.len();
if let Some(ref compressor) = this.compressor {
match compressor.compress(&response) {
Ok(compressed) => {
response = compressed;
compression_type = compressor.compression_type() as u8;
}
Err(e) => {
error!(?e, "failed to compress message");
continue;
}
}
debug!(
msg_id,
len_before,
len_after = response.len(),
"compressed message"
)
}
debug!(msg_id, "received response to send");
let msg = reqrep::Message::new(msg_id, compression_type, response);
this.pending_egress = Some(msg.with_span(result.span));
continue;
}
}
}
let under_hwm = this.pending_requests.len() < this.max_pending_responses;
if under_hwm {
let _g = this.span.clone().entered();
match this.conn.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => {
let span = tracing::info_span!("request").entered();
trace!(?result, "received message");
let msg = match result {
Ok(m) => m,
Err(e) => {
return Poll::Ready(Some(Err(e.into()).with_span(span.clone())));
}
};
if msg.payload().as_ref() == PING {
debug!("received ping healthcheck, responding pong");
let msg = reqrep::Message::new(0, 0, PONG.as_ref().into());
if let Err(e) = this.conn.start_send_unpin(msg) {
error!(?e, "failed to send pong response");
}
continue;
}
let (tx, rx) = oneshot::channel();
this.pending_requests.push(
PendingRequest { msg_id: msg.id(), response: rx }
.with_span(span.clone()),
);
let request = Request {
source: this.addr.clone(),
response: tx,
compression_type: msg.header().compression_type(),
msg: msg.into_payload(),
};
return Poll::Ready(Some(Ok(request).with_span(span)));
}
Poll::Ready(None) => {
debug!("framed closed, sending and flushing leftover data if any");
if this.poll_shutdown(cx).is_ready() {
return Poll::Ready(None);
}
}
Poll::Pending => {}
}
} else {
trace!(
hwm = this.max_pending_responses,
pending = this.pending_requests.len(),
"at high-water mark, not polling from underlying connection until responses drain"
);
}
return Poll::Pending;
}
}
}
/// A request awaiting its response from the socket.
///
/// Resolves (as a future) to a [`Response`] carrying the same `msg_id`.
struct PendingRequest {
    /// Message id of the request, copied onto the response.
    msg_id: u32,
    /// Receiving half of the channel on which the response payload is delivered.
    response: oneshot::Receiver<Bytes>,
}
/// A response payload paired with the id of the request it answers.
struct Response {
    /// Message id of the originating request.
    msg_id: u32,
    /// Response payload as produced by the socket (before any compression).
    response: Bytes,
}
impl Future for PendingRequest {
type Output = Result<Response, RecvError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.response.poll_unpin(cx) {
Poll::Ready(Ok(response)) => {
Poll::Ready(Ok(Response { msg_id: self.msg_id, response }))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}