//! pipenet 0.2.4
//!
//! Non-blocking TCP stream wrapper using channels.
use crate::metrics::Measurement;
use crate::packs::PackUnpack;

use super::*;

use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use mio::{Events, Interest, Poll, Token};
use std::fmt::Display;
use std::io::Cursor;
use std::mem;

/// Non-success outcome of a single non-blocking processing step.
///
/// `Yield` means the step could not make progress right now and the caller
/// should wait for readiness; the remaining variants carry an actual error.
enum ShortCircuit {
    /// No progress is possible until the socket (or channel) is ready again.
    Yield,
    /// An I/O failure.
    Err(std::io::Error),
    /// A failure reported by the packs pipeline while packing or unpacking.
    PacksError(Box<dyn std::error::Error>),
}

impl Display for ShortCircuit {
    /// Shows the wrapped error's own message, or the fixed text `Yielded`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Yield => f.write_str("Yielded"),
            // Delegate with the caller's formatter so any flags are forwarded.
            Self::Err(io_err) => Display::fmt(io_err, f),
            Self::PacksError(packs_err) => Display::fmt(packs_err, f),
        }
    }
}

impl From<std::io::Error> for ShortCircuit {
    /// Lets `?` lift plain I/O errors into [`ShortCircuit::Err`].
    fn from(io_err: std::io::Error) -> Self {
        Self::Err(io_err)
    }
}

// Fixed 10-byte prefix sent before every message body: a big-endian `u16`
// version followed by the big-endian `u64` body length.
//
// On receive, bodies whose version lies outside the configured min..=max
// range are discarded.
struct MessageHeader {
    version: u16,
    size: u64,
}

impl MessageHeader {
    /// Builds the header describing `bytes` as a body sent under `version`.
    fn from_slice(version: u16, bytes: &[u8]) -> Self {
        let size = bytes.len() as u64;
        Self { version, size }
    }
}

impl From<MessageHeader> for [u8; 10] {
    /// Serialises the header into its on-wire form.
    fn from(header: MessageHeader) -> Self {
        let mut encoded = [0u8; 10];
        {
            let mut writer = Cursor::new(&mut encoded[..]);
            // 2 + 8 bytes exactly fill the array, so neither write can run short.
            writer.write_u16::<BigEndian>(header.version).unwrap();
            writer.write_u64::<BigEndian>(header.size).unwrap();
        }
        encoded
    }
}

impl From<[u8; 10]> for MessageHeader {
    /// Parses the on-wire form produced by the conversion above.
    fn from(raw: [u8; 10]) -> Self {
        let mut reader = Cursor::new(raw);
        // The array holds exactly the 10 bytes consumed here, so the reads cannot fail.
        let version = reader.read_u16::<BigEndian>().unwrap();
        let size = reader.read_u64::<BigEndian>().unwrap();
        Self { version, size }
    }
}

/// Owns one TCP connection and shuttles length-prefixed messages between it and
/// a set of channels, driven by a mio `Poll` in [`StreamLooper::stream_loop`].
///
/// Each message on the wire is a 10-byte header followed by its body; the
/// `read_*` / `write_*` fields track a partially transferred body so work can
/// resume after the socket reports `WouldBlock`.
pub(crate) struct StreamLooper {
    /// `cur` is stamped on outgoing headers; `min`/`max` bound accepted incoming versions.
    versions: Versions,
    /// Optional upper bound on an incoming body length.
    max_size: Option<NonZero<usize>>,
    /// Packs applied to outgoing bodies and undone on incoming ones, when non-empty.
    packs: Packs,
    /// The socket, in mio form so it can be registered with a `Poll`.
    stream: mio::net::TcpStream,
    /// Receives each accepted incoming body.
    tx_reader: Sender<Vec<u8>>,
    /// Yields bodies to send.
    rx_writer: Receiver<Vec<u8>>,
    /// Receives the error that ended the loop.
    tx_term: Sender<std::io::Error>,
    /// True while an incoming body is only partially read.
    reading: bool,
    /// Header version of the body currently being read.
    read_version: u16,
    /// Buffer for the incoming body, sized from its header.
    read_buf: Vec<u8>,
    /// Number of body bytes read so far.
    read_pos: usize,
    /// Expected body length taken from the header.
    read_target: usize,
    /// True while an outgoing body is only partially written.
    writing: bool,
    /// Outgoing body (after packing) being written.
    write_buf: Vec<u8>,
    /// Number of body bytes written so far.
    write_pos: usize,
    /// Total body length to write.
    write_target: usize,
    /// Receives `Received`/`Sent` byte counts for body reads and writes.
    metrics_tx: Sender<Measurement>,
}

impl Drop for StreamLooper {
    /// Shuts down both directions of the socket when the looper goes away.
    fn drop(&mut self) {
        use std::net::Shutdown;
        // Best effort only: `drop` has no way to report a failure, so any
        // error from the shutdown is discarded.
        let _ignored = self.stream.shutdown(Shutdown::Both);
    }
}

impl StreamLooper {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        versions: Versions,
        max_size: Option<NonZero<usize>>,
        encapsulations: Packs,
        stream: TcpStream,
        tx_reader: Sender<Vec<u8>>,
        rx_writer: Receiver<Vec<u8>>,
        tx_term: Sender<std::io::Error>,
        metrics_tx: Sender<Measurement>,
    ) -> Self {
        Self {
            versions,
            max_size,
            packs: encapsulations,
            stream: mio::net::TcpStream::from_std(stream),
            tx_reader,
            rx_writer,
            tx_term,
            reading: false,
            read_version: 0,
            read_target: 0,
            read_buf: Vec::new(),
            read_pos: 0,
            writing: false,
            write_buf: Vec::new(),
            write_target: 0,
            write_pos: 0,
            metrics_tx,
        }
    }

    pub(crate) fn stream_loop(mut self) {
        let e = self.loop_until_error();
        let _ = self.tx_term.send(e);
    }

    fn loop_until_error(&mut self) -> std::io::Error {
        let mut events = Events::with_capacity(1024);
        let mut poll = match Poll::new() {
            Ok(p) => p,
            Err(e) => return e,
        };
        if let Err(e) = poll.registry().register(
            &mut self.stream,
            Token(0),
            Interest::READABLE | Interest::WRITABLE,
        ) {
            return e;
        };
        loop {
            match self.try_process_buffers() {
                Ok(_) => {}
                Err(ShortCircuit::Err(e)) => return e,
                Err(ShortCircuit::PacksError(_)) => return std::io::Error::other("packe error"),
                Err(ShortCircuit::Yield) => {
                    // We just care to get one event, no matter if for read
                    // or write as the process does both in a row anway.
                    while events.is_empty() {
                        if let Err(e) = poll.poll(&mut events, None) {
                            return std::io::Error::new(ErrorKind::ConnectionAborted, e);
                        };
                    }
                }
            }
        }
    }

    fn try_process_buffers(&mut self) -> Result<(), ShortCircuit> {
        let read_res = self.read();
        let write_res = self.write();

        // This only yields if both read and write yield.
        if let Err(ShortCircuit::Yield) = read_res
            && let Err(ShortCircuit::Yield) = write_res
        {
            return Err(ShortCircuit::Yield);
        }

        // Either or errors get thrown up
        if let Err(ShortCircuit::Err(e)) = read_res {
            return Err(ShortCircuit::Err(e));
        }
        if let Err(ShortCircuit::Err(e)) = write_res {
            return Err(ShortCircuit::Err(e));
        }

        // All other cases are consider as 'continue'
        Ok(())
    }

    fn read(&mut self) -> Result<(), ShortCircuit> {
        if !self.reading {
            self.read_start()?;
        } else {
            self.read_continue()?;
        }

        Ok(())
    }

    fn read_start(&mut self) -> Result<(), ShortCircuit> {
        // See if there is a header ready first, skip otherwise.
        let Some(header) = self.check_for_header()? else {
            return Ok(());
        };

        // Must commit to read the buffer now
        let size = header.size as usize;
        if let Some(max_size) = self.max_size
            && size > max_size.get()
        {
            return Err(std::io::Error::new(
                ErrorKind::ConnectionAborted,
                "max packet size exceeded",
            )
            .into());
        }
        let mut buf = Vec::new();
        if buf.try_reserve(size).is_err() {
            // Kill the pipe if a stream requests an unreasonable or unfeasible
            // allocation.
            return Err(
                std::io::Error::new(ErrorKind::ConnectionAborted, "failed to allocate").into(),
            );
        }
        buf.resize(size, 0);

        self.reading = true;
        self.read_version = header.version;
        self.read_buf = buf;
        self.read_target = size;
        self.read_pos = 0;

        self.read_continue()?;

        Ok(())
    }

    // This is a continuation of piping into the read buffer for MessageData
    fn read_continue(&mut self) -> Result<(), ShortCircuit> {
        let buf = &mut self.read_buf[self.read_pos..self.read_target];
        let op = self.stream.read(buf);
        let count = match op {
            Ok(n) => n,
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
                _ => return Err(e.into()),
            },
        };
        let _ = self.metrics_tx.send(Measurement::Received(count));
        self.read_pos += count;

        if self.read_pos == self.read_target {
            self.read_end()?;
        }

        Ok(())
    }

    // Send a message taking the current read buffer marking read end.
    fn read_end(&mut self) -> Result<(), ShortCircuit> {
        let version = self.read_version;
        self.reading = false;
        self.read_target = 0;
        self.read_pos = 0;
        self.read_version = 0;
        let mut buf = Vec::new();
        mem::swap(&mut buf, &mut self.read_buf);
        // Discard unsupported versions.
        if version > self.versions.max || version < self.versions.min {
            return Ok(());
        }
        if !self.packs.is_empty() {
            buf = self.packs.unpack(&buf).map_err(ShortCircuit::PacksError)?;
        }
        let _ = self.tx_reader.send(buf);
        Ok(())
    }

    // This function will return None if there are not enough bytes to read.
    fn check_for_header(&mut self) -> Result<Option<MessageHeader>, ShortCircuit> {
        let mut buf = [0; 10];
        match self.stream.peek(&mut buf) {
            Ok(read) => {
                if read == 10 {
                    // Just peeked same amount, this should not fail.
                    let _ = self.stream.read_exact(&mut buf);
                    return Ok(Some(buf.into()));
                }
                // Not enough bytes yet, mark as empty to be called again later.
                Ok(None)
            }
            Err(e) => match e.kind() {
                // Avoid blocking and just return empty instead
                ErrorKind::WouldBlock => Err(ShortCircuit::Yield),
                _ => Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
            },
        }
    }

    fn write(&mut self) -> Result<(), ShortCircuit> {
        if !self.writing {
            self.write_start()?;
        } else {
            self.write_continue()?;
        }

        Ok(())
    }

    fn write_start(&mut self) -> Result<(), ShortCircuit> {
        let fetch = self.rx_writer.try_recv();
        let mut msg = match fetch {
            Ok(msg) => msg,
            Err(e) => match e {
                TryRecvError::Empty => return Err(ShortCircuit::Yield),
                TryRecvError::Disconnected => {
                    return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into());
                }
            },
        };
        if !self.packs.is_empty() {
            msg = self
                .packs
                .pack(msg.as_slice())
                .map_err(ShortCircuit::PacksError)?;
        }
        self.write_buf = msg;
        self.writing = true;
        self.write_pos = 0;
        self.write_target = self.write_buf.len();

        self.write_header()?;
        self.write_continue()?;

        Ok(())
    }

    fn write_header(&mut self) -> Result<(), ShortCircuit> {
        let header: [u8; 10] = MessageHeader::from_slice(self.versions.cur, &self.write_buf).into();
        self.write_all_blocking(&header)?;
        Ok(())
    }

    // This is the only real blocking operation in the whole module
    fn write_all_blocking(&mut self, mut buf: &[u8]) -> Result<(), ShortCircuit> {
        while !buf.is_empty() {
            match self.stream.write(buf) {
                Ok(0) => {
                    return Err(std::io::Error::new(ErrorKind::BrokenPipe, "").into());
                }
                Ok(n) => buf = &buf[n..],
                Err(e) => match e.kind() {
                    ErrorKind::WouldBlock => continue,
                    _ => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
                },
            }
        }
        Ok(())
    }

    fn write_continue(&mut self) -> Result<(), ShortCircuit> {
        let buf = &self.write_buf[self.write_pos..self.write_target];
        let op = self.stream.write(buf);
        let count = match op {
            Ok(n) => n,
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return Err(ShortCircuit::Yield),
                _ => return Err(std::io::Error::new(ErrorKind::ConnectionAborted, e).into()),
            },
        };
        let _ = self.metrics_tx.send(Measurement::Sent(count));
        self.write_pos += count;

        if self.write_pos == self.write_target {
            self.write_end();
        }

        Ok(())
    }

    fn write_end(&mut self) {
        self.writing = false;
        self.write_pos = 0;
        self.write_target = 0;
        self.write_buf = Vec::new();
    }
}