use std::{
collections::VecDeque,
future::Future,
io::IoSlice,
pin::{pin, Pin},
task::{Context, Poll},
};
use bytes::Buf;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::{
encoding::Codec,
interrupted,
message_reactor::{MessageReactor, ReactorStatus},
would_block, Decoder, DeserializeError, Encoder,
};
/// A full-duplex connection that pumps bytes between an async `stream` and a
/// `reactor`, using `codec` to decode inbound frames and encode outbound ones.
///
/// Outbound application messages arrive through `outbound_messages`; decoded
/// inbound messages are delivered to the reactor. The connection is driven by
/// polling (see the `Future` impl in this file).
pub struct Connection<
    TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static,
    TCodec: Codec,
    TReactor: MessageReactor<Inbound = <TCodec as Decoder>::Message, Outbound = <TCodec as Encoder>::Message>,
> {
    // Underlying byte stream (e.g. a TCP stream).
    stream: TStream,
    // Channel of application-level messages waiting to be serialized and sent.
    outbound_messages: spillway::Receiver<TReactor::LogicalOutbound>,
    // Serialized buffers queued for vectored writes, oldest first.
    send_buffer: VecDeque<<TCodec as Encoder>::Serialized>,
    // End of the filled region of `receive_buffer`: bytes
    // `0..receive_buffer_unread_index` have been read from the stream but not
    // yet consumed by the decoder.
    receive_buffer_unread_index: usize,
    // Reusable read buffer; grown in `buffer_allocation_increment` steps up to
    // `max_buffer_length`.
    receive_buffer: Vec<u8>,
    // Upper bound on `receive_buffer` growth and on a single inbound message size.
    max_buffer_length: usize,
    // Maximum number of serialized buffers held in `send_buffer` at once.
    max_queued_send_messages: usize,
    // Step size used when enlarging `receive_buffer`.
    buffer_allocation_increment: usize,
    // Encoder/decoder for the wire format.
    codec: TCodec,
    // Application logic invoked for each inbound/outbound message.
    reactor: TReactor,
}
impl<
    TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static,
    TCodec: Codec,
    TReactor: MessageReactor<
        Inbound = <TCodec as Decoder>::Message,
        Outbound = <TCodec as Encoder>::Message,
    >,
> std::fmt::Display for Connection<TStream, TCodec, TReactor>
{
    /// Renders a compact diagnostic summary of buffer occupancy: how far the
    /// receive buffer is filled, its capacity, and how many queued outbound
    /// buffers (and total pending bytes) are waiting to be written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Total unsent payload across every queued serialized buffer.
        let pending_bytes: usize = self
            .send_buffer
            .iter()
            .map(|buffer| buffer.remaining())
            .sum();
        write!(
            f,
            "Connection: {{read{{end: {read_end}, capacity: {read_capacity}}}, write{{queue: {write_queue}, length: {write_length}}} }}",
            read_end = self.receive_buffer_unread_index,
            read_capacity = self.receive_buffer.len(),
            write_queue = self.send_buffer.len(),
            write_length = pending_bytes,
        )
    }
}
/// Unconditional `Unpin` for `Connection`: the connection value itself may be
/// moved freely between polls; pinning is only re-asserted locally (via
/// `map_unchecked_mut` in the `Future` impl) when polling the reactor.
///
/// NOTE(review): this impl does not require any field type to be `Unpin`;
/// presumably the reactor is the only field whose pinning matters and it is
/// handled by the unsafe projection — confirm that no field relies on a
/// stable address across polls.
impl<
    TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static,
    TCodec: Codec,
    TReactor: MessageReactor<
        Inbound = <TCodec as Decoder>::Message,
        Outbound = <TCodec as Encoder>::Message,
    >,
> Unpin for Connection<TStream, TCodec, TReactor>
{
}
impl<
    TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
    TCodec: Codec,
    TReactor: MessageReactor<
        Inbound = <TCodec as Decoder>::Message,
        Outbound = <TCodec as Encoder>::Message,
    >,
> Future for Connection<TStream, TCodec, TReactor>
{
    type Output = ();
    /// Drives one iteration of the connection state machine, in order:
    /// 1. read + decode inbound bytes (`poll_receive`),
    /// 2. poll the reactor (which may request disconnect),
    /// 3. serialize and flush outbound buffers (`poll_writev_buffers`).
    ///
    /// Resolves to `Ready(())` when the connection should be torn down:
    /// stream closed, I/O error, or reactor-requested disconnect.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        log::trace!("polling receive");
        // Ready from poll_receive means the read side terminated (close/error).
        if self.as_mut().poll_receive(context).is_ready() {
            return Poll::Ready(());
        }
        // SAFETY: structural pin projection of the `reactor` field — the
        // `Pin<&mut Self>` is narrowed to `Pin<&mut TReactor>` and the reactor
        // is never moved out of the projected reference.
        let structurally_pinned_reactor =
            unsafe { self.as_mut().map_unchecked_mut(|me| &mut me.reactor) };
        if structurally_pinned_reactor.poll(context).is_break() {
            log::debug!("reactor requested disconnect");
            return Poll::Ready(());
        }
        log::trace!("polling write");
        match self.poll_writev_buffers(context) {
            Ok(false) => {
                // Nothing to write right now; waker registered as needed.
                log::trace!("write stream is empty or registered for wake when writable");
            }
            Ok(true) => {
                log::debug!("write stream closed");
                return Poll::Ready(());
            }
            Err(e) => {
                log::warn!("error while writing to tcp stream: {e:?}");
                return Poll::Ready(());
            }
        }
        Poll::Pending
    }
}
impl<
    TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static,
    TCodec: Codec,
    TReactor: MessageReactor<
        Inbound = <TCodec as Decoder>::Message,
        Outbound = <TCodec as Encoder>::Message,
    >,
> Drop for Connection<TStream, TCodec, TReactor>
{
    /// Logs connection teardown for diagnostics; all resources (stream,
    /// buffers, channel receiver) are released by their own `Drop` impls.
    fn drop(&mut self) {
        log::debug!("connection dropped")
    }
}
/// Outcome of one attempt to read bytes from the stream or decode them.
#[derive(Debug)]
enum ReadBufferState {
    // Nothing further can be done right now; the task is registered for wakeup.
    Pending,
    // Progress was (or may be) possible — another read/decode pass is warranted.
    MoreToRead,
    // The peer closed the stream or the reactor requested disconnect.
    Disconnected,
    // An unrecoverable I/O or protocol error occurred.
    Error(std::io::Error),
}
impl<
TStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
TCodec: Codec,
TReactor: MessageReactor<
Inbound = <TCodec as Decoder>::Message,
Outbound = <TCodec as Encoder>::Message,
>,
> Connection<TStream, TCodec, TReactor>
{
#[allow(clippy::type_complexity, clippy::too_many_arguments)]
pub fn new(
stream: TStream,
codec: TCodec,
max_buffer_length: usize,
buffer_allocation_increment: usize,
max_queued_send_messages: usize,
outbound_messages: spillway::Receiver<TReactor::LogicalOutbound>,
reactor: TReactor,
) -> Self {
Self {
stream,
outbound_messages,
send_buffer: Default::default(),
receive_buffer: Vec::new(),
max_buffer_length,
max_queued_send_messages,
receive_buffer_unread_index: 0,
buffer_allocation_increment,
codec,
reactor,
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn poll_read_inbound(&mut self, context: &mut Context<'_>) -> ReadBufferState {
if self.receive_buffer.len() < self.max_buffer_length
&& self.receive_buffer.len() - self.receive_buffer_unread_index
< self.buffer_allocation_increment
{
self.receive_buffer.resize(
self.receive_buffer.len() + self.buffer_allocation_increment,
0,
);
}
if 0 < self.receive_buffer.len() - self.receive_buffer_unread_index {
self.poll_read_from_stream(context)
} else {
log::debug!("receive is full {self}");
ReadBufferState::MoreToRead
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn read_inbound_messages_and_react(&mut self) -> ReadBufferState {
let mut buffer_cursor = 0;
let state = loop {
if buffer_cursor == self.receive_buffer_unread_index {
break ReadBufferState::Pending;
} else if self.receive_buffer_unread_index < buffer_cursor {
break ReadBufferState::Error(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"buffer cursor is beyond the end of the receive buffer. Deserializer must not consume more than the buffer length",
));
}
let buffer = &self.receive_buffer[buffer_cursor..self.receive_buffer_unread_index];
log::trace!("decode {buffer:?}");
match self.codec.decode(buffer) {
Ok((length, message)) => {
buffer_cursor += length;
if self.reactor.on_inbound_message(message) == ReactorStatus::Disconnect {
log::debug!("reactor requested disconnect");
return ReadBufferState::Disconnected;
}
}
Err(e) => match e {
DeserializeError::IncompleteBuffer { next_message_size } => {
if self.max_buffer_length < next_message_size {
log::error!("tried to receive message that is too long. Resetting connection - max: {}, requested: {}", self.max_buffer_length, next_message_size);
return ReadBufferState::Disconnected;
}
log::debug!("waiting for the next message of length {next_message_size}");
break ReadBufferState::Pending;
}
DeserializeError::InvalidBuffer => {
log::error!("message was invalid - broken stream");
return ReadBufferState::Disconnected;
}
DeserializeError::SkipMessage { distance } => {
if self.receive_buffer_unread_index - buffer_cursor < distance {
log::trace!("cannot skip yet, need to read more. Skipping: {distance}, remaining:{}", self.receive_buffer_unread_index - buffer_cursor);
break ReadBufferState::Pending;
}
log::debug!("skipping message of length {distance}");
buffer_cursor += distance;
}
},
};
};
if buffer_cursor != 0 && buffer_cursor == self.receive_buffer_unread_index {
log::trace!("read buffer complete - resetting: {self}");
self.receive_buffer_unread_index = 0;
} else if buffer_cursor != 0 {
log::trace!("read buffer partially consumed - shifting: {self}");
self.receive_buffer
.copy_within(buffer_cursor..self.receive_buffer_unread_index, 0);
self.receive_buffer_unread_index -= buffer_cursor;
}
state
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn poll_read_from_stream(&mut self, context: &mut Context<'_>) -> ReadBufferState {
let mut buffer = ReadBuf::new(&mut self.receive_buffer[self.receive_buffer_unread_index..]);
match pin!(&mut self.stream).poll_read(context, &mut buffer) {
Poll::Ready(Ok(_)) => {
let distance = buffer.filled().len();
if distance == 0 {
log::debug!("read 0 bytes, stream is closed");
ReadBufferState::Disconnected
} else {
self.receive_buffer_unread_index += distance;
log::trace!(
"read from stream: {distance}b, total: {}b",
self.receive_buffer_unread_index
);
ReadBufferState::MoreToRead
}
}
Poll::Ready(Err(ref err)) if would_block(err) => {
log::debug!("read everything. No longer readable");
ReadBufferState::Pending
}
Poll::Ready(Err(ref err)) if interrupted(err) => {
log::trace!("interrupted, so try again later");
ReadBufferState::MoreToRead
}
Poll::Ready(Err(err)) => {
log::warn!("error while reading from tcp stream: {err:?}");
ReadBufferState::Error(err)
}
Poll::Pending => {
log::debug!("pending on read stream");
ReadBufferState::Pending
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn poll_serialize_outbound_messages(&mut self, context: &mut Context<'_>) -> Poll<()> {
let max_outbound = self.max_queued_send_messages - self.send_buffer.len();
if max_outbound == 0 {
log::debug!("send is full: {self}");
return Poll::Pending;
}
let start_len = self.send_buffer.len();
for _ in 0..max_outbound {
let message = match self.outbound_messages.poll_next(context) {
Poll::Pending => {
log::debug!("no more messages to serialize, and we are pending for more");
break;
}
Poll::Ready(None) => {
log::info!("outbound message channel was closed");
return Poll::Ready(());
}
Poll::Ready(Some(next)) => next,
};
let message = self.reactor.on_outbound_message(message);
let buffer = self.codec.encode(message);
log::trace!(
"serialized message and enqueueing outbound buffer: {}b",
buffer.remaining()
);
self.send_buffer.push_back(buffer);
}
let new_len = self.send_buffer.len();
if start_len != new_len {
log::debug!("serialized {} messages", new_len - start_len);
}
Poll::Pending
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn poll_writev_buffers(
&mut self,
context: &mut Context<'_>,
) -> std::result::Result<bool, std::io::Error> {
loop {
if self.poll_serialize_outbound_messages(context).is_ready() {
log::debug!("outbound channel closed");
return Ok(true);
}
break if self.send_buffer.is_empty() {
log::debug!("send buffer is empty");
Ok(false)
} else {
const UIO_MAXIOV: usize = 256;
let mut stack_buffers = [IoSlice::new(&[]); UIO_MAXIOV];
let mut filled = 0;
for buf in self.send_buffer.iter() {
if UIO_MAXIOV <= filled {
break;
}
let n = buf.chunks_vectored(&mut stack_buffers[filled..]);
if n == 0 {
break;
}
filled += n;
}
let buffers = &stack_buffers[..filled];
#[cfg(feature = "tracing")]
let span = tracing::span!(tracing::Level::INFO, "writing", buffers = buffers.len());
#[cfg(feature = "tracing")]
let span_guard = span.enter();
let poll = pin!(&mut self.stream).poll_write_vectored(context, buffers);
#[cfg(feature = "tracing")]
drop(span_guard);
match poll {
Poll::Pending => {
log::debug!("writev not ready - waiting for wake");
Ok(false)
}
Poll::Ready(Ok(0)) => {
log::info!("write stream was closed");
Ok(true)
}
Poll::Ready(Ok(written)) => {
log::debug!("writev sent {written}");
self.advance_send_buffers(written);
continue;
}
Poll::Ready(Err(ref err)) if would_block(err) => {
log::debug!("would block - no longer writable");
continue;
}
Poll::Ready(Err(ref err)) if interrupted(err) => {
log::debug!("write interrupted - try again later");
continue;
}
Poll::Ready(Err(err)) => {
log::warn!(
"error while writing to tcp stream: {err:?}, buffers: {}, {}b: {:?}",
buffers.len(),
buffers.iter().map(|b| b.len()).sum::<usize>(),
buffers.iter().map(|b| b.len()).collect::<Vec<_>>()
);
Err(err)
}
}
};
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
fn advance_send_buffers(&mut self, total_written: usize) {
let mut written = total_written;
while 0 < written {
if let Some(mut front) = self.send_buffer.pop_front() {
let remaining = front.remaining();
if remaining <= written {
written -= remaining;
log::trace!("returning consumed buffer after sending final {remaining}b");
self.codec.return_buffer(front);
} else {
log::debug!("after writing {total_written}b, advancing partially written buffer of {remaining}b by {written}b");
front.advance(written);
self.send_buffer.push_front(front);
break;
}
} else {
log::error!("rotated all buffers but {written} bytes unaccounted for");
break;
}
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
fn poll_receive(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<()> {
loop {
match self.poll_read_inbound(context) {
ReadBufferState::Pending => {
log::debug!("consumed all that I can from the read stream for now {self}");
return Poll::Pending;
}
ReadBufferState::MoreToRead => {
log::debug!("more to read");
self.read_inbound_messages_and_react();
continue;
}
ReadBufferState::Disconnected => {
log::info!("read connection closed");
return Poll::Ready(());
}
ReadBufferState::Error(e) => {
log::warn!("error while reading from tcp stream: {e:?}");
return Poll::Ready(());
}
}
}
}
}