use crate::metrics::Measurement;
use crate::packs::PackUnpack;
use super::*;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use mio::{Events, Interest, Poll, Token};
use std::fmt::Display;
use std::io::Cursor;
use std::mem;
/// Reason a read or write step stopped before running to completion.
///
/// Used as the error type of the `StreamLooper` step methods so that `?` can
/// unwind out of a step. `Yield` is a control-flow signal, not a failure.
enum ShortCircuit {
/// No progress is possible right now: the socket reported `WouldBlock`, or
/// there was no queued message to send.
Yield,
/// An I/O error; `loop_until_error` returns it as the terminating error.
Err(std::io::Error),
/// Error returned by `Packs::pack` / `Packs::unpack` for a message body.
PacksError(Box<dyn std::error::Error>),
}
impl Display for ShortCircuit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ShortCircuit::Yield => write!(f, "Yielded"),
ShortCircuit::Err(error) => error.fmt(f),
ShortCircuit::PacksError(error) => error.fmt(f),
}
}
}
impl From<std::io::Error> for ShortCircuit {
fn from(value: std::io::Error) -> Self {
ShortCircuit::Err(value)
}
}
/// Fixed-size header that precedes every message body on the wire.
///
/// Encoded as 10 bytes: a big-endian `u16` version followed by a big-endian
/// `u64` size (see the `From` conversions to and from `[u8; 10]`).
struct MessageHeader {
/// Protocol version the sender stamped on the message.
version: u16,
/// Length in bytes of the body that follows the header.
size: u64,
}
impl MessageHeader {
fn from_slice(version: u16, bytes: &[u8]) -> Self {
Self {
version,
size: bytes.len() as u64,
}
}
}
impl From<MessageHeader> for [u8; 10] {
fn from(value: MessageHeader) -> Self {
let mut buf = std::io::Cursor::new(Vec::new());
buf.write_u16::<BigEndian>(value.version).unwrap();
buf.write_u64::<BigEndian>(value.size).unwrap();
buf.get_ref().as_slice().try_into().unwrap()
}
}
impl From<[u8; 10]> for MessageHeader {
fn from(value: [u8; 10]) -> Self {
let mut c = Cursor::new(&value);
let version = c.read_u16::<BigEndian>().unwrap();
let size = c.read_u64::<BigEndian>().unwrap();
Self { version, size }
}
}
/// State for one TCP connection that exchanges length-prefixed messages.
///
/// Incoming bodies are forwarded on `tx_reader`, outgoing bodies are taken
/// from `rx_writer`, and the error that ends the loop is sent on `tx_term`.
/// Read and write progress is kept in the fields below so a partially
/// transferred message can be resumed on a later pass.
pub(crate) struct StreamLooper {
// `cur` is stamped on outgoing headers; `min`/`max` bound the versions
// accepted for incoming messages.
versions: Versions,
// Largest incoming body accepted; `None` means no limit is enforced.
max_size: Option<NonZero<usize>>,
// Applied to bodies via `pack`/`unpack` whenever it is not empty.
packs: Packs,
stream: mio::net::TcpStream,
// Receives each completed incoming body.
tx_reader: Sender<Vec<u8>>,
// Source of bodies to send.
rx_writer: Receiver<Vec<u8>>,
// Receives the error that terminated `stream_loop`.
tx_term: Sender<std::io::Error>,
// True while a message body is partially read.
reading: bool,
// Version taken from the header of the message currently being read.
read_version: u16,
// Body buffer for the message being read, sized from the header.
read_buf: Vec<u8>,
// Number of body bytes read so far.
read_pos: usize,
// Total body bytes expected for the current message.
read_target: usize,
// True while a message body is partially written.
writing: bool,
// Body currently being written.
write_buf: Vec<u8>,
// Number of body bytes written so far.
write_pos: usize,
// Total body bytes to write for the current message.
write_target: usize,
// Receives `Measurement::Received` / `Measurement::Sent` byte counts.
metrics_tx: Sender<Measurement>,
}
/// Shuts down both directions of the socket when the looper is dropped.
impl Drop for StreamLooper {
    fn drop(&mut self) {
        use std::net::Shutdown;

        // Best effort: the shutdown result is deliberately discarded.
        let _ = self.stream.shutdown(Shutdown::Both);
    }
}
impl StreamLooper {
/// Builds a looper around `stream` with no read or write in flight.
///
/// `encapsulations` becomes the `packs` applied to message bodies; the
/// channel endpoints and configuration are stored as given.
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
    versions: Versions,
    max_size: Option<NonZero<usize>>,
    encapsulations: Packs,
    stream: TcpStream,
    tx_reader: Sender<Vec<u8>>,
    rx_writer: Receiver<Vec<u8>>,
    tx_term: Sender<std::io::Error>,
    metrics_tx: Sender<Measurement>,
) -> Self {
    // NOTE(review): mio's `from_std` is documented as not switching the socket
    // to non-blocking mode -- confirm every caller does so before calling `new`.
    let stream = mio::net::TcpStream::from_std(stream);

    Self {
        // Connection configuration.
        versions,
        max_size,
        packs: encapsulations,
        stream,
        // Channels to the rest of the crate.
        tx_reader,
        rx_writer,
        tx_term,
        metrics_tx,
        // Read side starts idle.
        reading: false,
        read_version: 0,
        read_buf: Vec::new(),
        read_pos: 0,
        read_target: 0,
        // Write side starts idle.
        writing: false,
        write_buf: Vec::new(),
        write_pos: 0,
        write_target: 0,
    }
}
/// Runs the connection until it fails, then reports the failure on `tx_term`.
///
/// Consumes the looper, so the socket is shut down (via `Drop`) on return.
pub(crate) fn stream_loop(mut self) {
    let failure = self.loop_until_error();
    // The send result is deliberately discarded.
    let _ = self.tx_term.send(failure);
}
/// Drives reads and writes until something fails and returns that failure.
///
/// Registers the socket for readable and writable events, then repeatedly
/// calls `try_process_buffers`. When both directions report `Yield`, it waits
/// on the poller before trying again.
///
/// The returned error is one of:
/// - the error from creating the poller or registering the socket,
/// - an I/O error raised by a read/write step,
/// - a pack/unpack failure, converted to an `io::Error` carrying its message,
/// - a poll failure, wrapped as `ConnectionAborted`.
fn loop_until_error(&mut self) -> std::io::Error {
    let mut events = Events::with_capacity(1024);
    let mut poll = match Poll::new() {
        Ok(p) => p,
        Err(e) => return e,
    };
    if let Err(e) = poll.registry().register(
        &mut self.stream,
        Token(0),
        Interest::READABLE | Interest::WRITABLE,
    ) {
        return e;
    }
    loop {
        match self.try_process_buffers() {
            Ok(()) => {}
            Err(ShortCircuit::Err(e)) => return e,
            Err(ShortCircuit::PacksError(e)) => {
                // `Box<dyn Error>` is not `Send + Sync`, so carry its message
                // instead of dropping the detail entirely.
                return std::io::Error::other(format!("pack error: {e}"));
            }
            Err(ShortCircuit::Yield) => {
                // NOTE(review): `events` is never cleared, so after the first
                // successful poll this loop is skipped and the outer loop
                // spins. Clearing it here would block until the next socket
                // event and could delay messages queued on `rx_writer`; a
                // waker or poll timeout is probably needed -- confirm intent.
                while events.is_empty() {
                    match poll.poll(&mut events, None) {
                        Ok(()) => {}
                        // A signal interrupted the wait; mio does not retry on
                        // its own, so simply poll again.
                        Err(e) if e.kind() == ErrorKind::Interrupted => {}
                        Err(e) => {
                            return std::io::Error::new(ErrorKind::ConnectionAborted, e);
                        }
                    }
                }
            }
        }
    }
}
/// Runs one read step followed by one write step.
///
/// Both steps always run, read first. The outcome is:
/// - `Err(Yield)` only when *both* directions yielded;
/// - the first non-`Yield` error (read takes precedence over write),
///   including `PacksError`, which was previously dropped on the floor and
///   left the caller's `PacksError` handling unreachable;
/// - `Ok(())` otherwise, meaning at least one direction did not yield.
fn try_process_buffers(&mut self) -> Result<(), ShortCircuit> {
    // Tuple fields are evaluated left to right, so the read runs first.
    match (self.read(), self.write()) {
        (Err(ShortCircuit::Yield), Err(ShortCircuit::Yield)) => Err(ShortCircuit::Yield),
        // Read yielded: the write result (success or hard failure) decides.
        (Err(ShortCircuit::Yield), write_res) => write_res,
        // Any other read error is fatal regardless of the write outcome.
        (Err(fatal), _) => Err(fatal),
        // Read succeeded and only the write yielded: this pass made progress.
        (Ok(()), Err(ShortCircuit::Yield)) => Ok(()),
        (Ok(()), write_res) => write_res,
    }
}
/// Advances the read side by one step: resumes the body in flight if there
/// is one, otherwise tries to start a new message.
fn read(&mut self) -> Result<(), ShortCircuit> {
    if self.reading {
        self.read_continue()
    } else {
        self.read_start()
    }
}
fn read_start(&mut self) -> Result<(), ShortCircuit> {
let Some(header) = self.check_for_header()? else {
return Ok(());
};
let size = header.size as usize;
if let Some(max_size) = self.max_size
&& size > max_size.get()
{
return Err(std::io::Error::new(
ErrorKind::ConnectionAborted,
"max packet size exceeded",
)
.into());
}
let mut buf = Vec::new();
if buf.try_reserve(size).is_err() {
return Err(
std::io::Error::new(ErrorKind::ConnectionAborted, "failed to allocate").into(),
);
}
buf.resize(size, 0);
self.reading = true;
self.read_version = header.version;
self.read_buf = buf;
self.read_target = size;
self.read_pos = 0;
self.read_continue()?;
Ok(())
}
/// Reads more of the current message body and finishes it once complete.
///
/// Performs a single `read` into the unfilled part of `read_buf`, reports the
/// byte count on `metrics_tx`, and calls `read_end` when the target is hit.
///
/// # Errors
/// - `Yield` when the socket would block;
/// - `UnexpectedEof` when the peer closed the connection before the body was
///   complete (previously this went unnoticed and the loop never terminated);
/// - any other I/O error from the socket, or an error from `read_end`.
fn read_continue(&mut self) -> Result<(), ShortCircuit> {
    let buf = &mut self.read_buf[self.read_pos..self.read_target];
    let wanted = buf.len();
    let count = match self.stream.read(buf) {
        // Zero bytes while still expecting data means end of stream. A
        // zero-length body (`wanted == 0`) legitimately reads zero bytes.
        Ok(0) if wanted > 0 => {
            return Err(std::io::Error::new(
                ErrorKind::UnexpectedEof,
                "connection closed mid-message",
            )
            .into());
        }
        Ok(n) => n,
        Err(e) if e.kind() == ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
        Err(e) => return Err(e.into()),
    };
    // Metrics are best effort; a closed metrics channel is ignored.
    let _ = self.metrics_tx.send(Measurement::Received(count));
    self.read_pos += count;
    if self.read_pos == self.read_target {
        self.read_end()?;
    }
    Ok(())
}
/// Completes the message in flight and hands its body to `tx_reader`.
///
/// The read side is reset to idle first. Messages whose version lies outside
/// `versions.min..=versions.max` are dropped silently. When `packs` is not
/// empty the body is unpacked before being forwarded; an unpack failure is
/// returned as `PacksError`. The result of the channel send is ignored.
fn read_end(&mut self) -> Result<(), ShortCircuit> {
    // Take what the finished message needs and return to idle before any of
    // the early exits below.
    let version = mem::replace(&mut self.read_version, 0);
    let mut payload = mem::take(&mut self.read_buf);
    self.reading = false;
    self.read_pos = 0;
    self.read_target = 0;

    let supported = version >= self.versions.min && version <= self.versions.max;
    if !supported {
        return Ok(());
    }

    if !self.packs.is_empty() {
        payload = self
            .packs
            .unpack(&payload)
            .map_err(ShortCircuit::PacksError)?;
    }
    let _ = self.tx_reader.send(payload);
    Ok(())
}
/// Consumes and returns the next message header if all 10 bytes are available.
///
/// The socket is peeked first so that a partially received header is left in
/// the socket buffer; in that case `Ok(None)` is returned and nothing is
/// consumed.
///
/// # Errors
/// - `Yield` when the socket would block;
/// - `UnexpectedEof` when the peek reports zero bytes, i.e. the peer closed
///   the connection (previously treated as "no header yet", so a closed
///   connection was never noticed by the read side);
/// - `ConnectionAborted` wrapping any other peek error, or a failure to
///   consume the header after a successful peek (previously ignored).
fn check_for_header(&mut self) -> Result<Option<MessageHeader>, ShortCircuit> {
    let mut buf = [0; 10];
    match self.stream.peek(&mut buf) {
        Ok(0) => Err(std::io::Error::new(
            ErrorKind::UnexpectedEof,
            "connection closed by peer",
        )
        .into()),
        Ok(available) if available == buf.len() => {
            // The bytes were just peeked, so this should not fail; if it does,
            // surface it rather than parsing a header that was never consumed.
            self.stream
                .read_exact(&mut buf)
                .map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, e))?;
            Ok(Some(buf.into()))
        }
        // Header only partially received; try again on a later pass.
        Ok(_) => Ok(None),
        Err(e) if e.kind() == ErrorKind::WouldBlock => Err(ShortCircuit::Yield),
        Err(e) => Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
    }
}
/// Advances the write side by one step: resumes the body in flight if there
/// is one, otherwise tries to pick up the next queued message.
fn write(&mut self) -> Result<(), ShortCircuit> {
    if self.writing {
        self.write_continue()
    } else {
        self.write_start()
    }
}
/// Picks up the next queued message and begins sending it.
///
/// Returns `Yield` when `rx_writer` is empty and a `ConnectionAborted` error
/// when it is disconnected. The body is packed first when `packs` is not
/// empty (a failure becomes `PacksError`). The header is then written in
/// full, followed by a first attempt at writing the body.
fn write_start(&mut self) -> Result<(), ShortCircuit> {
    let mut outgoing = match self.rx_writer.try_recv() {
        Ok(bytes) => bytes,
        Err(TryRecvError::Empty) => return Err(ShortCircuit::Yield),
        Err(closed @ TryRecvError::Disconnected) => {
            return Err(std::io::Error::new(ErrorKind::ConnectionAborted, closed).into());
        }
    };

    if !self.packs.is_empty() {
        outgoing = self
            .packs
            .pack(outgoing.as_slice())
            .map_err(ShortCircuit::PacksError)?;
    }

    // Mark the write as in flight before touching the socket.
    self.write_target = outgoing.len();
    self.write_buf = outgoing;
    self.write_pos = 0;
    self.writing = true;

    self.write_header()?;
    self.write_continue()
}
/// Writes the header for `write_buf` (current version plus body length) to
/// the socket, not returning until all 10 bytes are written or an error occurs.
fn write_header(&mut self) -> Result<(), ShortCircuit> {
    let header = MessageHeader::from_slice(self.versions.cur, &self.write_buf);
    let encoded = <[u8; 10]>::from(header);
    self.write_all_blocking(&encoded)
}
/// Writes all of `buf` to the socket, retrying immediately on `WouldBlock`.
///
/// Unlike the other write helpers this never yields: it keeps calling `write`
/// until every byte is accepted.
///
/// # Errors
/// - `BrokenPipe` when the socket accepts zero bytes;
/// - `ConnectionAborted` wrapping any other I/O error.
fn write_all_blocking(&mut self, buf: &[u8]) -> Result<(), ShortCircuit> {
    let mut sent = 0;
    while sent < buf.len() {
        match self.stream.write(&buf[sent..]) {
            Ok(0) => return Err(std::io::Error::new(ErrorKind::BrokenPipe, "").into()),
            Ok(n) => sent += n,
            // Not ready yet: loop around and try again straight away.
            Err(e) if e.kind() == ErrorKind::WouldBlock => {}
            Err(e) => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
        }
    }
    Ok(())
}
/// Writes more of the current message body and finishes it once complete.
///
/// Performs a single `write` of the unsent part of `write_buf`, reports the
/// byte count on `metrics_tx`, and calls `write_end` when the target is hit.
///
/// # Errors
/// - `Yield` when the socket would block;
/// - `BrokenPipe` when the socket accepts zero bytes although data is pending
///   (matching `write_all_blocking`; previously this made no progress and was
///   retried forever);
/// - `ConnectionAborted` wrapping any other I/O error.
fn write_continue(&mut self) -> Result<(), ShortCircuit> {
    let pending = &self.write_buf[self.write_pos..self.write_target];
    let count = match self.stream.write(pending) {
        // An empty body legitimately writes zero bytes; otherwise a zero-byte
        // write means the socket can no longer accept data.
        Ok(0) if !pending.is_empty() => {
            return Err(std::io::Error::new(
                ErrorKind::BrokenPipe,
                "socket accepted zero bytes",
            )
            .into());
        }
        Ok(n) => n,
        Err(e) if e.kind() == ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
        Err(e) => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
    };
    // Metrics are best effort; a closed metrics channel is ignored.
    let _ = self.metrics_tx.send(Measurement::Sent(count));
    self.write_pos += count;
    if self.write_pos == self.write_target {
        self.write_end();
    }
    Ok(())
}
/// Returns the write side to idle and releases the buffer that was just sent.
fn write_end(&mut self) {
    self.write_buf = Vec::new();
    (self.write_pos, self.write_target) = (0, 0);
    self.writing = false;
}
}