/// A byte sink that accepts whole slices at a time.
///
/// Unlike `std::io::Write::write`, there is no byte count in the result: a call
/// either writes all of the given bytes or reports an error.
pub trait IoWrite {
    /// Error reported when a write cannot be completed.
    type Error: core::error::Error;

    /// Writes every byte of `bytes` to the sink.
    ///
    /// # Errors
    /// Returns `Self::Error` if the bytes could not all be written.
    fn write(&mut self, bytes: &[u8]) -> Result<(), Self::Error>;
}
/// Error returned by the slice-backed writers when the destination has no room left.
#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub enum WError {
    /// Fewer bytes of free space remain than the write needs.
    BufferFull,
}
impl core::fmt::Display for WError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            WError::BufferFull => "Buffer is full",
        };
        f.write_str(message)
    }
}
impl core::error::Error for WError {}
/// Writer that fills a caller-provided byte slice from front to back.
///
/// `buf[..cursor]` holds the bytes written so far and `buf[cursor..]` is still free.
pub struct SliceWriter<'a> {
    buf: &'a mut [u8],
    cursor: usize,
}
impl<'a> SliceWriter<'a> {
    /// Creates a writer whose first write lands at `buf[0]`.
    pub fn new(buf: &'a mut [u8]) -> Self {
        SliceWriter { cursor: 0, buf }
    }

    /// Returns how many bytes can still be written.
    fn len(&self) -> usize {
        let total = self.buf.len();
        total - self.cursor
    }
}
impl IoWrite for SliceWriter<'_> {
    type Error = WError;

    /// Copies `buf` into the free space starting at the cursor, then advances the cursor.
    ///
    /// # Errors
    /// Returns [`WError::BufferFull`] when fewer than `buf.len()` bytes remain. In that
    /// case nothing is written and the cursor stays where it was.
    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        if buf.len() > self.len() {
            return Err(WError::BufferFull);
        }
        let start = self.cursor;
        let end = start + buf.len();
        self.buf[start..end].copy_from_slice(buf);
        self.cursor = end;
        Ok(())
    }
}
#[cfg(all(not(test), not(feature = "std")))]
impl IoWrite for &mut [u8] {
    type Error = WError;

    /// Copies `buf` to the front of the slice and shrinks the slice to the part after it.
    ///
    /// # Errors
    /// Returns [`WError::BufferFull`] when the slice is shorter than `buf`. The slice is
    /// left untouched in that case, so a later, smaller write can still succeed (matching
    /// `SliceWriter`, whose cursor does not move on error).
    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        // Check capacity *before* taking the slice out of `self`. Taking first and
        // bailing out with `?` would leave `*self` as an empty slice, silently discarding
        // all remaining space after a single failed write.
        if self.len() < buf.len() {
            return Err(WError::BufferFull);
        }
        // Move the slice out (leaving `&mut []` behind) so both halves keep the full
        // outer lifetime and `rest` can be stored back into `*self`.
        let (written, rest) = core::mem::take(self).split_at_mut(buf.len());
        written.copy_from_slice(buf);
        *self = rest;
        Ok(())
    }
}
#[cfg(all(not(test), feature = "alloc", not(feature = "std")))]
mod alloc_without_std {
    //! `IoWrite` impls for `Vec<u8>` used when `alloc` is enabled without `std`.
    //! Under `std` (or in tests), vectors are covered by the blanket impl over
    //! `std::io::Write` instead. Both impls delegate to `VecRefWriter`.
    use super::{vec_writer::VecRefWriter, IoWrite};

    impl IoWrite for alloc::vec::Vec<u8> {
        type Error = core::convert::Infallible;
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            let mut appender = VecRefWriter::new(self);
            appender.write(buf)
        }
    }

    impl IoWrite for &mut alloc::vec::Vec<u8> {
        type Error = core::convert::Infallible;
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            let target: &mut alloc::vec::Vec<u8> = self;
            VecRefWriter::new(target).write(buf)
        }
    }
}
#[cfg(feature = "alloc")]
mod vec_writer {
    use super::IoWrite;

    /// Writer that appends to a borrowed `Vec<u8>`; writing never fails.
    pub struct VecRefWriter<'a> {
        vec: &'a mut alloc::vec::Vec<u8>,
    }

    impl<'a> VecRefWriter<'a> {
        /// Wraps `vec`; written bytes are appended after its existing contents.
        pub fn new(vec: &'a mut alloc::vec::Vec<u8>) -> Self {
            VecRefWriter { vec }
        }
    }

    impl IoWrite for VecRefWriter<'_> {
        type Error = core::convert::Infallible;
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            let Self { vec } = self;
            vec.extend_from_slice(buf);
            Ok(())
        }
    }
}
#[cfg(feature = "alloc")]
pub use vec_writer::VecRefWriter;
#[cfg(any(test, feature = "std"))]
impl<W: std::io::Write> IoWrite for W {
    type Error = std::io::Error;

    /// Delegates to [`std::io::Write::write_all`] so the whole of `buf` is written or
    /// the underlying I/O error is returned.
    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        std::io::Write::write_all(self, buf)
    }
}
/// A slice of bytes produced by a reader.
///
/// `Borrowed` points into the input itself and lives for `'de`. `Copied` points into
/// storage owned by the reader (e.g. a scratch buffer) and is tied to the borrow `'a`
/// of that reader.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Reference<'de, 'a> {
    Borrowed(&'de [u8]),
    Copied(&'a [u8]),
}
impl Reference<'_, '_> {
    /// Returns the referenced bytes regardless of where they live.
    pub const fn as_bytes(&self) -> &[u8] {
        match *self {
            Self::Borrowed(bytes) => bytes,
            Self::Copied(bytes) => bytes,
        }
    }
}
impl PartialEq<[u8]> for Reference<'_, '_> {
    /// Compares only the byte contents; whether the data is borrowed or copied is ignored.
    fn eq(&self, other: &[u8]) -> bool {
        other == self.as_bytes()
    }
}
/// A byte source that hands out the next `len` bytes as a [`Reference`].
///
/// `'de` is the lifetime of data that an implementation can lend directly from its input.
pub trait IoRead<'de> {
    /// Error reported when the requested bytes cannot be produced.
    type Error: core::error::Error + 'static;

    /// Returns the next `len` bytes and advances past them.
    ///
    /// The result is [`Reference::Borrowed`] when the bytes are lent from the input for
    /// `'de`, or [`Reference::Copied`] when they live in storage owned by the reader and
    /// are valid only while `self` stays borrowed.
    ///
    /// # Errors
    /// Returns `Self::Error` when `len` bytes cannot be provided.
    fn read_slice<'a>(&'a mut self, len: usize) -> Result<Reference<'de, 'a>, Self::Error>;
}
/// Reader over an in-memory byte slice that lends out sub-slices without copying.
pub struct SliceReader<'de> {
    /// The bytes that have not been consumed yet.
    cursor: &'de [u8],
}
impl<'de> SliceReader<'de> {
    /// Starts reading at the beginning of `buf`.
    pub fn new(buf: &'de [u8]) -> Self {
        SliceReader { cursor: buf }
    }

    /// Returns the unconsumed remainder of the input, with the input's lifetime.
    pub fn rest(&self) -> &'de [u8] {
        let remaining: &'de [u8] = self.cursor;
        remaining
    }
}
/// Error returned by `SliceReader` and `IterReader` when the input ends before a read completes.
#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub enum RError {
    /// Fewer bytes remain than were requested.
    BufferEmpty,
}
impl core::fmt::Display for RError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::BufferEmpty => "Buffer is empty",
        })
    }
}
impl core::error::Error for RError {}
impl<'de> IoRead<'de> for SliceReader<'de> {
    type Error = RError;

    /// Lends out the next `len` bytes of the input without copying.
    ///
    /// # Errors
    /// Returns [`RError::BufferEmpty`] when fewer than `len` bytes remain. Nothing is
    /// consumed in that case.
    #[inline]
    fn read_slice<'a>(&'a mut self, len: usize) -> Result<Reference<'de, 'a>, Self::Error> {
        let Some((head, tail)) = self.cursor.split_at_checked(len) else {
            return Err(RError::BufferEmpty);
        };
        self.cursor = tail;
        Ok(Reference::Borrowed(head))
    }
}
#[cfg(feature = "alloc")]
mod iter_reader {
    use super::{IoRead, RError, Reference};

    /// Reader that pulls bytes from an iterator.
    ///
    /// Each read copies the requested bytes into an internal scratch buffer and returns a
    /// [`Reference::Copied`] pointing into it.
    pub struct IterReader<I> {
        it: I,
        /// Scratch buffer, cleared and refilled on every read; its allocation is reused.
        buf: alloc::vec::Vec<u8>,
    }

    impl<I> IterReader<I>
    where
        I: Iterator<Item = u8>,
    {
        /// Creates a reader over `it` with an empty scratch buffer.
        pub fn new(it: I) -> Self {
            Self {
                it,
                buf: alloc::vec::Vec::new(),
            }
        }
    }

    impl<'de, I> IoRead<'de> for IterReader<I>
    where
        I: Iterator<Item = u8>,
    {
        type Error = RError;

        /// Pulls the next `len` bytes from the iterator into the scratch buffer.
        ///
        /// # Errors
        /// Returns [`RError::BufferEmpty`] if the iterator ends before yielding `len`
        /// bytes. The bytes pulled before that point have already been consumed from the
        /// iterator.
        fn read_slice<'a>(&'a mut self, len: usize) -> Result<Reference<'de, 'a>, Self::Error> {
            self.buf.clear();
            // No up-front `reserve(len)`. `len` is caller-supplied and may be far larger
            // than what the iterator can produce (a huge value would even panic with
            // "capacity overflow"). `extend` sizes the buffer from the `take` adaptor's
            // `size_hint`, so memory only grows with bytes that actually arrive.
            self.buf.extend(self.it.by_ref().take(len));
            if self.buf.len() != len {
                return Err(RError::BufferEmpty);
            }
            Ok(Reference::Copied(self.buf.as_slice()))
        }
    }
}
#[cfg(feature = "alloc")]
pub use iter_reader::IterReader;
#[cfg(feature = "std")]
mod std_reader {
    use super::IoRead;
    use std::io::Read;

    /// Reader adapting any [`std::io::Read`] source to [`IoRead`].
    ///
    /// Every read copies into an internal scratch buffer, so results are always
    /// `Reference::Copied`.
    pub struct StdReader<R> {
        reader: R,
        /// Scratch buffer, cleared and refilled on every read; its allocation is reused.
        buf: std::vec::Vec<u8>,
    }

    impl<R> StdReader<R>
    where
        R: std::io::Read,
    {
        /// Wraps `reader` with an empty scratch buffer.
        pub fn new(reader: R) -> Self {
            Self {
                reader,
                buf: std::vec::Vec::new(),
            }
        }
    }

    impl<'de, R> IoRead<'de> for StdReader<R>
    where
        R: std::io::Read,
    {
        type Error = std::io::Error;

        /// Reads exactly `len` bytes from the underlying reader.
        ///
        /// # Errors
        /// Returns an error of kind [`std::io::ErrorKind::UnexpectedEof`] if the reader
        /// ends before `len` bytes, or any error the reader itself reports. Bytes read
        /// before the failure have been consumed from the reader.
        fn read_slice<'a>(
            &'a mut self,
            len: usize,
        ) -> Result<super::Reference<'de, 'a>, Self::Error> {
            self.buf.clear();
            // Read through a `take` limit instead of pre-sizing the buffer to `len`.
            // `len` is caller-supplied and may be bogus or hostile. `read_to_end` grows
            // the buffer only as bytes actually arrive, so a huge length cannot force a
            // huge allocation before EOF is detected.
            let limit: u64 = std::convert::TryFrom::try_from(len).unwrap_or(u64::MAX);
            Read::take(&mut self.reader, limit).read_to_end(&mut self.buf)?;
            if self.buf.len() != len {
                // Same error kind `read_exact` reports for a short read.
                return Err(std::io::ErrorKind::UnexpectedEof.into());
            }
            Ok(super::Reference::Copied(self.buf.as_slice()))
        }
    }
}
#[cfg(feature = "std")]
pub use std_reader::StdReader;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn buffer_full() {
        let buf: &mut [u8] = &mut [0u8];
        let mut writer = SliceWriter::new(buf);
        // Assert the exact error instead of relying on `#[should_panic]`, which would
        // also pass on an unrelated panic (e.g. a slice index out of bounds in `write`).
        assert_eq!(writer.write(&[1, 2]), Err(WError::BufferFull));
        // The rejected write must not have consumed any space.
        assert_eq!(writer.write(&[7]), Ok(()));
        assert_eq!(writer.write(&[8]), Err(WError::BufferFull));
    }

    #[test]
    fn slice_writer_fills_buffer_in_order() {
        let mut storage = [0u8; 4];
        let mut writer = SliceWriter::new(&mut storage);
        assert_eq!(writer.write(&[1, 2]), Ok(()));
        assert_eq!(writer.write(&[3, 4]), Ok(()));
        // An empty write always fits, even into a full buffer.
        assert_eq!(writer.write(&[]), Ok(()));
        assert_eq!(writer.write(&[5]), Err(WError::BufferFull));
        assert_eq!(storage, [1, 2, 3, 4]);
    }

    #[test]
    fn std_writer_appends_via_write_all() {
        let mut out = std::vec::Vec::new();
        IoWrite::write(&mut out, &[1, 2, 3]).unwrap();
        IoWrite::write(&mut out, &[4]).unwrap();
        assert_eq!(out, [1u8, 2, 3, 4]);
    }

    #[test]
    fn slice_reader_reads_and_advances() {
        let input: &[u8] = &[1, 2, 3, 4, 5];
        let mut reader = SliceReader::new(input);
        {
            let a = reader.read_slice(2).expect("read 2 bytes");
            assert_eq!(a.as_bytes(), &[1, 2]);
        }
        let b = reader.read_slice(3).expect("read 3 bytes");
        assert_eq!(b.as_bytes(), &[3, 4, 5]);
        assert_eq!(reader.rest(), &[]);
    }

    #[test]
    fn slice_reader_returns_error_on_overshoot() {
        let input: &[u8] = &[10, 20];
        let mut reader = SliceReader::new(input);
        let first = reader.read_slice(2).expect("read 2 bytes");
        assert_eq!(first.as_bytes(), &[10, 20]);
        assert!(matches!(reader.read_slice(1), Err(RError::BufferEmpty)));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn iter_reader_reads_exact_length() {
        let it = [7u8, 8, 9, 10].into_iter();
        let mut reader = IterReader::new(it);
        {
            let part1 = reader.read_slice(3).expect("read 3 bytes");
            assert_eq!(part1.as_bytes(), &[7, 8, 9]);
        }
        let part2 = reader.read_slice(1).expect("read 1 byte");
        assert_eq!(part2.as_bytes(), &[10]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn iter_reader_returns_error_when_insufficient() {
        let it = [1u8, 2].into_iter();
        let mut reader = IterReader::new(it);
        assert!(matches!(reader.read_slice(3), Err(RError::BufferEmpty)));
    }
}