use std::io::{self, Read, Write};
use std::num::NonZeroU8;
use std::time::Instant;
use crate::{DecodeError, Message};
/// Buffer sizing and timeout settings for a [`MessageStream`].
#[derive(Debug, Clone, Copy)]
pub struct StreamConfig {
    /// Capacity below which `shrink_buffers` will not shrink the receive buffer.
    pub rx_buf_min_size: usize,
    /// Upper bound on received bytes held at once, as a multiple of the message type's
    /// `MAX_SIZE`: each read is capped so retained partial-message bytes plus newly read bytes
    /// stay within it.
    pub rx_buf_max_size: MaxMessageSizeMultiple,
    /// Capacity below which `shrink_buffers` will not shrink the transmit buffer.
    pub tx_buf_min_size: usize,
    /// Upper bound on queued outgoing bytes, as a multiple of the message type's `MAX_SIZE`.
    pub tx_buf_max_size: MaxMessageSizeMultiple,
    /// How long the oldest queued message may wait, with no write in the meantime, before the
    /// stream is reported as stale.
    pub tx_timeout: std::time::Duration,
    /// Connection timeout. NOTE(review): not read in this file — presumably consumed by the
    /// code that establishes connections; confirm.
    pub connect_timeout: std::time::Duration,
}
/// A buffer limit expressed as a whole number of maximum-size messages.
///
/// The multiplier is a [`NonZeroU8`], so the resulting limit always admits at least one
/// message of `M::MAX_SIZE` bytes.
#[derive(Debug, Clone, Copy)]
pub struct MaxMessageSizeMultiple(pub NonZeroU8);

impl MaxMessageSizeMultiple {
    /// Returns the limit in bytes for message type `M`: the multiplier times `M::MAX_SIZE`.
    pub fn compute<M: Message>(&self) -> usize {
        let factor = usize::from(self.0.get());
        M::MAX_SIZE * factor
    }
}
impl Default for StreamConfig {
    /// 4 KiB minimum buffer capacities, a receive limit of one maximum-size message, a
    /// transmit limit of two, a 30 s transmit timeout and a 5 s connect timeout.
    fn default() -> Self {
        let multiple = |n: u8| MaxMessageSizeMultiple(NonZeroU8::new(n).unwrap());
        let min_buffer = 4 * 1024;
        Self {
            rx_buf_min_size: min_buffer,
            rx_buf_max_size: multiple(1),
            tx_buf_min_size: min_buffer,
            tx_buf_max_size: multiple(2),
            tx_timeout: std::time::Duration::from_secs(30),
            connect_timeout: std::time::Duration::from_secs(5),
        }
    }
}
/// A framed message connection over a byte stream `T`.
///
/// Received bytes that do not yet form a whole message wait in `rx_msg_buf`; encoded outgoing
/// messages wait in `tx_msg_buf` until written, and `tx_queue_points` records when each of
/// them was queued.
#[derive(Debug)]
pub struct MessageStream<T: Read + Write> {
    /// Buffer limits and timeouts.
    config: StreamConfig,
    /// The underlying transport.
    stream: T,
    /// Received bytes belonging to a message that has not fully arrived.
    rx_msg_buf: Vec<u8>,
    /// Encoded messages not yet written to `stream`.
    tx_msg_buf: Vec<u8>,
    /// Queue time and unsent byte count for each message in `tx_msg_buf`.
    tx_queue_points: queue_points::Queue,
    /// Cached "connected" flag (set by `is_ready` on TCP streams).
    ready: bool,
    /// Timestamp passed to the most recent successful write.
    last_write: Instant,
}
/// Ways in which [`MessageStream::read`] can fail.
#[derive(Debug)]
pub enum ReadError {
    /// The decoder rejected the received bytes.
    MalformedMessage,
    /// The stream reported end of data (a read returned zero bytes).
    EndOfStream,
    /// Any other I/O error raised by the underlying stream.
    Error(io::Error),
}
impl<T: Read + Write> MessageStream<T> {
pub fn new(stream: T, config: StreamConfig) -> Self {
Self {
stream,
rx_msg_buf: Vec::new(),
tx_msg_buf: Vec::new(),
tx_queue_points: Default::default(),
ready: false,
last_write: Instant::now(),
config,
}
}
#[must_use]
pub fn read<M: Message, F: Fn(M, usize)>(
&mut self,
rx_buf: &mut [u8],
on_msg: F,
) -> Result<bool, ReadError> {
let preexisting = !self.rx_msg_buf.is_empty();
let max_buf_size = self.config.rx_buf_max_size.compute::<M>();
let limit = (max_buf_size - self.rx_msg_buf.len()).min(rx_buf.len());
let (total_read, read_result) = {
let buffer = &mut rx_buf[..limit];
let mut total_read: usize = 0;
let result = loop {
match self.stream.read(&mut buffer[total_read..]) {
Ok(0) if buffer.len() == 0 => break Ok(true),
Ok(0) => break Err(ReadError::EndOfStream),
Ok(read @ 1..) => {
total_read += read;
if total_read == buffer.len() {
break Ok(true);
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => break Ok(false),
Err(err) => break Err(ReadError::Error(err)),
}
};
(total_read, result)
};
let _decode_has_more = if !preexisting {
let (consumed, result) = decode_from_buffer(&mut rx_buf[..total_read], on_msg)?;
if consumed < total_read {
self.rx_msg_buf
.extend_from_slice(&rx_buf[consumed..total_read]);
}
result
} else {
self.rx_msg_buf.extend_from_slice(&rx_buf[..total_read]);
let (consumed, result) = decode_from_buffer(&mut &mut self.rx_msg_buf[..], on_msg)?;
self.rx_msg_buf.drain(..consumed);
result
};
read_result
}
#[must_use]
pub fn write(&mut self, now: Instant) -> io::Result<bool> {
if !self.has_queued_data() {
return Ok(false);
}
loop {
match self.try_write(now) {
Ok(written) => {
let has_more = self.has_queued_data();
log::trace!("wrote out {written} bytes, has more: {}", has_more);
if !has_more {
break Ok(false);
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
log::trace!("write would block");
break Ok(self.has_queued_data());
}
Err(err) => break Err(err),
}
}
}
#[must_use]
pub fn queue_message<M: Message>(&mut self, message: &M) -> bool {
let size_hint = message.size_hint().unwrap_or_default();
if size_hint + self.tx_msg_buf.len() >= self.config.tx_buf_max_size.compute::<M>() {
false
} else {
let encoded = message.encode(&mut self.tx_msg_buf);
self.tx_queue_points.append(encoded);
true
}
}
pub fn is_write_stale(&self, now: Instant) -> bool {
self.tx_queue_points.first().is_some_and(|t| {
let timeout = self.config.tx_timeout;
(now - t > self.config.tx_timeout) && (now - self.last_write > timeout)
})
}
pub fn shrink_buffers(&mut self) {
fn shrink(v: &mut Vec<u8>, min: usize) {
if v.capacity() > min {
let shrink_to = 3 * (v.capacity() / 4);
v.shrink_to(min.max(shrink_to));
}
}
shrink(&mut self.rx_msg_buf, self.config.rx_buf_min_size);
shrink(&mut self.tx_msg_buf, self.config.tx_buf_min_size);
self.tx_queue_points.shrink();
}
fn try_write(&mut self, now: Instant) -> io::Result<usize> {
let written = self.stream.write(&self.tx_msg_buf)?;
self.tx_msg_buf.drain(..written);
self.stream.flush()?;
self.last_write = now;
self.tx_queue_points.mark_write(written);
Ok(written)
}
#[inline(always)]
pub fn has_queued_data(&self) -> bool {
!self.tx_msg_buf.is_empty()
}
}
/// Decodes consecutive `M` messages from the start of `buffer`, passing each message and its
/// encoded length to `on_msg`.
///
/// Stops at the first incomplete message and returns `(bytes_consumed, false)`. The second
/// element is currently always `false`. Bytes past `bytes_consumed` belong to a message that
/// has not fully arrived yet.
///
/// # Errors
///
/// [`ReadError::MalformedMessage`] in two cases: the decoder rejects the data, or it reports a
/// message that consumed zero bytes. The zero-byte case would otherwise decode the same bytes
/// forever.
fn decode_from_buffer<M: Message, F: FnMut(M, usize)>(
    buffer: &[u8],
    mut on_msg: F,
) -> Result<(usize, bool), ReadError> {
    let mut cursor: usize = 0;
    loop {
        match M::decode(&buffer[cursor..]) {
            Ok((_, 0)) => return Err(ReadError::MalformedMessage),
            Ok((message, consumed)) => {
                cursor += consumed;
                on_msg(message, consumed);
            }
            Err(DecodeError::NotEnoughData) => return Ok((cursor, false)),
            Err(DecodeError::MalformedMessage) => return Err(ReadError::MalformedMessage),
        }
    }
}
impl MessageStream<mio::net::TcpStream> {
    /// Returns whether the TCP connection is usable, treating a successful `peer_addr()` as
    /// "connected". A positive answer is cached, so later calls skip the `peer_addr()` query.
    pub fn is_ready(&mut self) -> bool {
        self.ready = self.ready || self.stream.peer_addr().is_ok();
        self.ready
    }

    /// Shuts down both directions of the connection, consuming the stream.
    pub fn shutdown(self) -> io::Result<()> {
        let both = std::net::Shutdown::Both;
        self.stream.shutdown(both)
    }

    /// Returns the socket's pending error, if any. A failure while querying it is reported as
    /// "no error".
    pub fn take_error(&self) -> Option<io::Error> {
        match self.stream.take_error() {
            Ok(pending) => pending,
            Err(_) => None,
        }
    }

    /// Exposes the underlying socket as a `mio` event source (e.g. for registering it with a
    /// `mio` registry).
    pub fn as_source(&mut self) -> &mut impl mio::event::Source {
        &mut self.stream
    }
}
mod queue_points {
    //! Bookkeeping that maps the bytes of the transmit buffer back to the time each message
    //! was queued, so the age of the oldest unsent message can be checked.
    use std::collections::VecDeque;
    use std::time::Instant;

    /// One queued message: when it was queued and how many of its bytes are still unsent.
    #[derive(Debug)]
    struct Point {
        time: Instant,
        left: usize,
    }

    /// FIFO of queued messages, in the same order as their bytes in the transmit buffer.
    #[derive(Debug, Default)]
    pub struct Queue(VecDeque<Point>);

    impl Queue {
        /// Marks `n_written` bytes from the front of the queue as sent. Every message that is
        /// now fully sent is dropped, including zero-length messages at the front.
        ///
        /// # Panics
        ///
        /// If `n_written` exceeds the number of unsent bytes tracked by the queue.
        pub fn mark_write(&mut self, n_written: usize) {
            let mut remaining = n_written;
            while let Some(front) = self.0.front_mut() {
                let sent = remaining.min(front.left);
                front.left -= sent;
                remaining -= sent;
                if front.left > 0 {
                    break;
                }
                self.0.pop_front();
            }
            assert_eq!(remaining, 0, "marked more bytes as written than were queued");
        }

        /// Records a newly queued message of `size` encoded bytes, timestamped now.
        pub fn append(&mut self, size: usize) {
            self.0.push_back(Point {
                time: Instant::now(),
                left: size,
            })
        }

        /// Returns when the oldest not-yet-fully-sent message was queued, if any.
        pub fn first(&self) -> Option<Instant> {
            self.0.front().map(|p| p.time)
        }

        /// Shrinks the queue's capacity toward three quarters, never below 8 entries.
        pub fn shrink(&mut self) {
            if self.0.capacity() > 8 {
                self.0.shrink_to(8.max(3 * (self.0.capacity() / 4)));
            }
        }
    }

    #[cfg(test)]
    #[test]
    fn queue_behavior() {
        let mut queue = Queue::default();
        queue.append(10);
        queue.append(20);
        queue.append(30);
        assert_eq!(queue.0[0].left, 10);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);
        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 5);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);
        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 20);
        assert_eq!(queue.0[1].left, 30);
        queue.mark_write(25);
        assert_eq!(queue.0[0].left, 25);
        assert_eq!(queue.0.len(), 1);
        queue.mark_write(25);
        assert!(queue.first().is_none());
    }

    #[cfg(test)]
    #[test]
    fn zero_length_points_are_released() {
        let mut queue = Queue::default();
        queue.append(5);
        queue.append(0);
        queue.mark_write(5);
        assert!(queue.first().is_none());
    }
}
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::Cursor;

    use super::*;

    /// Fixed-width test message: a `u64` framed as exactly 8 little-endian bytes.
    #[derive(Debug, Eq, PartialEq)]
    struct Ping(u64);

    impl Message for Ping {
        const MAX_SIZE: usize = 8;

        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            let bytes = self.0.to_le_bytes();
            dest.write(&bytes).unwrap()
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            let Some(head) = buffer.get(..8) else {
                return Err(DecodeError::NotEnoughData);
            };
            let raw: [u8; 8] = head.try_into().unwrap();
            Ok((Ping(u64::from_le_bytes(raw)), 8))
        }
    }

    /// Builds an `on_msg` callback that checks every frame is 8 bytes long and records the
    /// decoded message in `sink`.
    fn record_into(sink: &RefCell<Vec<Ping>>) -> impl Fn(Ping, usize) + '_ {
        move |message, size| {
            assert_eq!(size, 8);
            sink.borrow_mut().push(message);
        }
    }

    #[test]
    fn reassemble_message_whole_reads() {
        let mut rx = [0; 1024];
        let mut wire = Cursor::new(Vec::new());
        for value in 0..2 {
            Ping(value).encode(&mut wire);
        }
        wire.set_position(0);
        let mut conn = MessageStream::new(&mut wire, StreamConfig::default());
        let received = RefCell::new(Vec::new());

        conn.read::<Ping, _>(&mut rx, record_into(&received)).unwrap();
        assert_eq!(received.borrow()[0], Ping(0));

        conn.read::<Ping, _>(&mut rx, record_into(&received)).unwrap();
        assert_eq!(received.borrow()[1], Ping(1));

        let outcome = conn.read::<Ping, _>(&mut rx, record_into(&received));
        assert!(matches!(outcome, Err(ReadError::EndOfStream)));
        assert_eq!(conn.stream.position(), 16);
        assert!(conn.rx_msg_buf.is_empty());
    }

    #[test]
    fn reassemble_message_partial_reads() {
        let mut rx = [0; 8];
        let mut wire = Cursor::new(Vec::new());
        let mut conn = MessageStream::new(&mut wire, StreamConfig::default());
        let mut serialized = Vec::new();
        Ping(u64::MAX - 1).encode(&mut serialized);
        Ping(u64::MAX).encode(&mut serialized);
        let (head, tail) = serialized.split_at(4);
        let received = RefCell::new(Vec::new());

        // Only half of the first frame is available: nothing is delivered, 4 bytes are kept.
        conn.stream.get_mut().extend_from_slice(head);
        let _ = conn.read::<Ping, _>(&mut rx, record_into(&received));
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 4);

        // The rest arrives; the next read completes the first frame.
        conn.stream.get_mut().extend_from_slice(tail);
        let _ = conn.read::<Ping, _>(&mut rx, record_into(&received));
        assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));

        let _ = conn.read::<Ping, _>(&mut rx, record_into(&received));
        assert_eq!(received.borrow()[1], Ping(u64::MAX));
    }

    #[test]
    fn send_message() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_max_size: MaxMessageSizeMultiple(3.try_into().unwrap()),
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);
        for value in 0..3 {
            assert!(connection.queue_message(&Ping(value)));
        }
        let expected = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), 24);
        assert_eq!(wire.into_inner(), expected);
    }

    #[test]
    fn send_message_buf_full() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_min_size: 1,
            tx_buf_max_size: MaxMessageSizeMultiple(1.try_into().unwrap()),
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);
        assert!(connection.queue_message(&Ping(0)));
        assert!(!connection.queue_message(&Ping(1)));
        let queued_len = connection.tx_msg_buf.len();
        let expected = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), queued_len as u64);
        assert_eq!(wire.into_inner(), expected);
    }
}