use std::io::{self, Read};
use crate::MAX_PACKET_LEN;
/// Default transport read limit in bytes: 1024 * 1024 (1 MiB).
///
/// NOTE(review): not referenced in this part of the file; presumably meant to be
/// passed as `max_bytes` to `read_all_with_limit` — confirm at the call sites.
pub const DEFAULT_TRANSPORT_READ_LIMIT: usize = 1024 * 1024;
/// Category of a transport-level failure.
///
/// Each variant maps to a dotted string identifier via [`TransportErrorCode::code`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportErrorCode {
    /// Identified as `transport.oversized_input`.
    OversizedInput,
    /// Identified as `transport.invalid_frame`.
    InvalidFrame,
    /// Identified as `transport.unexpected_eof`.
    UnexpectedEof,
    /// Identified as `transport.timeout`.
    Timeout,
    /// Identified as `transport.io`.
    Io,
}

impl TransportErrorCode {
    /// Returns the `transport.`-prefixed string identifier for this category.
    ///
    /// Usable in `const` contexts; the returned string is a `'static` literal.
    #[must_use]
    pub const fn code(self) -> &'static str {
        match self {
            TransportErrorCode::OversizedInput => "transport.oversized_input",
            TransportErrorCode::InvalidFrame => "transport.invalid_frame",
            TransportErrorCode::UnexpectedEof => "transport.unexpected_eof",
            TransportErrorCode::Timeout => "transport.timeout",
            TransportErrorCode::Io => "transport.io",
        }
    }
}
/// A source from which packets can be received in batches.
pub trait PacketSource {
/// Error type returned when receiving fails.
type Error;
/// Receives a batch of packets, each returned as an owned byte vector.
///
/// # Errors
///
/// Returns `Self::Error` under conditions defined by the implementation.
fn recv_packets(&mut self) -> Result<Vec<Vec<u8>>, Self::Error>;
}
/// A destination that accepts packets one at a time.
pub trait PacketSink {
/// Error type returned when sending fails.
type Error;
/// Sends a single packet, given as a borrowed byte slice.
///
/// # Errors
///
/// Returns `Self::Error` under conditions defined by the implementation.
fn send_packet(&mut self, packet: &[u8]) -> Result<(), Self::Error>;
}
/// Reads `reader` to its end, rejecting input longer than `max_bytes`.
///
/// At most `max_bytes + 1` bytes (saturating) are pulled from `reader`: the one
/// extra byte is what reveals that the stream is over the limit without
/// buffering the rest of it.
///
/// # Errors
///
/// Propagates any I/O error from the underlying read, and returns
/// [`oversized_input_error`] when more than `max_bytes` bytes were read.
pub fn read_all_with_limit(mut reader: impl Read, max_bytes: usize) -> io::Result<Vec<u8>> {
    // One byte of headroom past the limit so "exactly at the limit" and
    // "over the limit" can be told apart after the read.
    let read_cap = max_bytes.saturating_add(1) as u64;
    let mut buffer = Vec::new();
    reader.by_ref().take(read_cap).read_to_end(&mut buffer)?;
    if buffer.len() <= max_bytes {
        Ok(buffer)
    } else {
        Err(oversized_input_error())
    }
}
/// Builds the `io::Error` used to report input that exceeds a size limit.
///
/// The error has kind [`io::ErrorKind::InvalidData`] and carries the string from
/// `TransportErrorCode::OversizedInput.code()` as its payload.
#[must_use]
pub fn oversized_input_error() -> io::Error {
    let message = TransportErrorCode::OversizedInput.code();
    io::Error::new(io::ErrorKind::InvalidData, message)
}
/// Borrowed view over a byte buffer in which each non-empty line is one packet.
///
/// Lines are delimited by `\n`. A single trailing `\r` is removed from each line
/// (so `\r\n` endings are accepted), and lines that are empty after that are
/// skipped. Returned packets are sub-slices of the input; nothing is copied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LineTransport<'a> {
    /// Raw input buffer that the returned packets borrow from.
    input: &'a [u8],
}

impl<'a> LineTransport<'a> {
    /// Wraps `input` without copying or validating it.
    #[must_use]
    pub fn new(input: &'a [u8]) -> Self {
        Self { input }
    }

    /// Returns every packet in the input, with no length limit applied.
    #[must_use]
    pub fn packets(&self) -> Vec<&'a [u8]> {
        self.non_empty_lines().collect()
    }

    /// Returns every packet in the input, enforcing a per-packet length limit.
    ///
    /// The length is measured after the trailing `\r` has been removed.
    ///
    /// # Errors
    ///
    /// Returns [`oversized_input_error`] at the first packet longer than
    /// `max_packet_len` bytes; no packets are returned in that case.
    pub fn packets_with_limit(&self, max_packet_len: usize) -> io::Result<Vec<&'a [u8]>> {
        // Collecting into `io::Result<Vec<_>>` stops at the first `Err`.
        self.non_empty_lines()
            .map(|line| {
                if line.len() > max_packet_len {
                    Err(oversized_input_error())
                } else {
                    Ok(line)
                }
            })
            .collect()
    }

    /// Yields the input's lines with a trailing `\r` removed, skipping empty ones.
    ///
    /// This is the single definition of the framing rule shared by
    /// [`Self::packets`] and [`Self::packets_with_limit`].
    fn non_empty_lines(&self) -> impl Iterator<Item = &'a [u8]> {
        // Copy the reference out so the iterator borrows the buffer for `'a`,
        // not for the (shorter) borrow of `self`.
        let input: &'a [u8] = self.input;
        input
            .split(|byte| *byte == b'\n')
            .map(trim_trailing_carriage_return)
            .filter(|line| !line.is_empty())
    }
}
/// Exposes the line-delimited input as a packet source.
impl PacketSource for LineTransport<'_> {
    type Error = io::Error;

    /// Returns owned copies of every packet, limiting each to `MAX_PACKET_LEN` bytes.
    ///
    /// The input is not consumed: each call re-parses the same buffer and
    /// returns the same packets.
    ///
    /// # Errors
    ///
    /// Propagates the error from `packets_with_limit` when a packet is longer
    /// than `MAX_PACKET_LEN`.
    fn recv_packets(&mut self) -> Result<Vec<Vec<u8>>, Self::Error> {
        let borrowed = self.packets_with_limit(MAX_PACKET_LEN)?;
        let owned = borrowed.iter().map(|packet| packet.to_vec()).collect();
        Ok(owned)
    }
}
/// In-memory sink: every sent packet is copied and appended to the vector.
impl PacketSink for Vec<Vec<u8>> {
    type Error = io::Error;

    /// Appends an owned copy of `packet`; this implementation always returns `Ok(())`.
    fn send_packet(&mut self, packet: &[u8]) -> Result<(), Self::Error> {
        let owned = Vec::from(packet);
        self.push(owned);
        Ok(())
    }
}
/// Returns `line` without its final byte if that byte is `\r`; otherwise returns `line` unchanged.
///
/// Only one `\r` is removed, and only from the end of the slice.
fn trim_trailing_carriage_return(line: &[u8]) -> &[u8] {
    match line {
        [head @ .., b'\r'] => head,
        _ => line,
    }
}