use crate::frame::FrameError;
use crate::region::{AllocOrd, Push, RegionBuffer, RegionError, Seed, REGION_SIZE};
use alloc::vec::Vec;
use core::mem::MaybeUninit;
/// Errors produced by a [`Writer`] while handing out output space.
#[derive(Debug, thiserror::Error)]
pub enum WriteError {
    /// The writer cannot provide the requested number of bytes
    /// (the fixed-size slice writers return this when too few bytes remain).
    #[error("Requested buffer out of bounds")]
    OutOfBounds,
    /// An underlying I/O error; only compiled when the `std` feature is enabled.
    /// Its message and source are forwarded unchanged (`transparent`), and
    /// `#[from]` lets `?` convert a `std::io::Error` into this variant.
    #[cfg(feature = "std")]
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
/// Result alias used by [`Writer`] operations.
pub type WriteResult<T> = Result<T, WriteError>;
/// A byte sink that hands out uninitialized output space and is then told how
/// much of that space was actually filled.
///
/// Writing is a two-step protocol:
/// 1. [`Writer::allocate`] requests a window of (possibly uninitialized) space.
/// 2. The caller initializes a prefix of that window.
/// 3. [`Writer::commit`] makes that prefix part of the output.
pub trait Writer {
    /// Requests a writable window for the next `n` bytes of output.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::OutOfBounds`] when the writer cannot supply `n`
    /// bytes, or another [`WriteError`] from the underlying sink.
    fn allocate(&mut self, n: usize) -> WriteResult<&mut [MaybeUninit<u8>]>;
    /// Marks the next `n` bytes as written and advances past them.
    ///
    /// # Safety
    ///
    /// These are the requirements that the `Vec<u8>` (`set_len`) and slice
    /// (`get_unchecked_mut`) implementations in this module rely on:
    /// - `n` must not exceed the space made available by the most recent
    ///   successful [`Writer::allocate`] call.
    /// - Those `n` bytes must have been initialized through that window.
    unsafe fn commit(&mut self, n: usize);
}
/// Lets a mutable borrow of any [`Writer`] be used wherever a `Writer` is
/// taken by value. Both halves of the protocol are delegated to the
/// referenced writer.
impl<T: Writer> Writer for &mut T {
    fn allocate(&mut self, n: usize) -> WriteResult<&mut [MaybeUninit<u8>]> {
        // `self: &mut &mut T` deref-coerces to the `&mut T` the inner impl expects.
        T::allocate(self, n)
    }

    unsafe fn commit(&mut self, n: usize) {
        // SAFETY: the caller upholds `commit`'s contract, and the space being
        // committed was handed out by this same inner writer.
        unsafe { T::commit(self, n) }
    }
}
/// Writes into a borrowed, already-initialized byte slice, front to back.
///
/// Each `commit(n)` shrinks the slice from the front, so the remaining slice
/// is always the unwritten tail.
impl Writer for &mut [u8] {
    /// Returns the first `n` bytes of the remaining slice as the write window.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::OutOfBounds`] if fewer than `n` bytes remain.
    #[inline]
    fn allocate(&mut self, n: usize) -> WriteResult<&mut [MaybeUninit<u8>]> {
        // The window is the *prefix* `..n`: `commit(n)` advances past exactly
        // these bytes, matching the `Vec<u8>` implementation. (Slicing with
        // `n..` would hand out the wrong bytes and the wrong length.)
        let window = self.get_mut(..n).ok_or(WriteError::OutOfBounds)?;
        // SAFETY: `u8` and `MaybeUninit<u8>` have the same size and alignment,
        // and the new slice covers exactly `window`'s memory for the same borrow.
        // NOTE(review): this view would let safe code store `MaybeUninit::uninit()`
        // into bytes the slice's owner later reads as `u8`. Callers must only
        // write initialized values into the window.
        Ok(unsafe { core::slice::from_raw_parts_mut(window.as_mut_ptr().cast(), window.len()) })
    }

    /// Advances past the first `n` bytes of the remaining slice.
    ///
    /// # Safety
    ///
    /// `n` must not exceed the remaining length of the slice.
    #[inline]
    unsafe fn commit(&mut self, n: usize) {
        debug_assert!(n <= self.len(), "commit past the end of the slice");
        // SAFETY: the caller guarantees `n <= self.len()`, so `n..` is in bounds.
        *self = unsafe { core::mem::take(self).get_unchecked_mut(n..) };
    }
}
/// Writes into a borrowed slice of possibly-uninitialized bytes, front to back.
///
/// Each `commit(n)` shrinks the slice from the front, so the remaining slice
/// is always the unwritten tail.
impl Writer for &mut [MaybeUninit<u8>] {
    /// Returns the first `n` slots of the remaining slice as the write window.
    ///
    /// # Errors
    ///
    /// Returns [`WriteError::OutOfBounds`] if fewer than `n` slots remain.
    #[inline]
    fn allocate(&mut self, n: usize) -> WriteResult<&mut [MaybeUninit<u8>]> {
        // Prefix `..n`, not suffix `n..`: `commit(n)` advances past exactly
        // these slots, matching the `Vec<u8>` implementation.
        self.get_mut(..n).ok_or(WriteError::OutOfBounds)
    }

    /// Advances past the first `n` slots of the remaining slice.
    ///
    /// # Safety
    ///
    /// `n` must not exceed the remaining length of the slice.
    #[inline]
    unsafe fn commit(&mut self, n: usize) {
        debug_assert!(n <= self.len(), "commit past the end of the slice");
        // SAFETY: the caller guarantees `n <= self.len()`, so `n..` is in bounds.
        *self = unsafe { core::mem::take(self).get_unchecked_mut(n..) };
    }
}
/// Appends to a growable `Vec<u8>`; allocation never runs out of bounds.
impl Writer for Vec<u8> {
    /// Reserves room for `n` more bytes and returns exactly `n` slots of the
    /// vector's spare capacity. `reserve` may reallocate the vector.
    fn allocate(&mut self, n: usize) -> WriteResult<&mut [MaybeUninit<u8>]> {
        self.reserve(n);
        // After `reserve(n)`, `capacity() - len() >= n`, so the split is in range.
        let (window, _rest) = self.spare_capacity_mut().split_at_mut(n);
        Ok(window)
    }

    /// Extends the vector's length by `n`.
    ///
    /// # Safety
    ///
    /// The `n` spare-capacity bytes directly after the current length must
    /// have been initialized (e.g. through the window from `allocate`).
    #[inline]
    unsafe fn commit(&mut self, n: usize) {
        let grown = self.len() + n;
        // SAFETY: per the caller's contract, bytes `len()..grown` are
        // initialized and lie within the reserved capacity.
        unsafe { self.set_len(grown) }
    }
}
/// Copies all of `src` into the front of `dst`, initializing those slots.
///
/// Slots of `dst` beyond `src.len()` are left untouched.
///
/// # Panics
///
/// Panics with "Attempt to write with overflow" if `src` is longer than `dst`.
#[inline]
pub(crate) fn write_to_uninit(src: &[u8], dst: &mut [MaybeUninit<u8>]) {
    assert!(src.len() <= dst.len(), "Attempt to write with overflow");
    // `zip` stops at `src`'s end; the length check above makes that the whole of `src`.
    for (slot, &byte) in dst.iter_mut().zip(src) {
        slot.write(byte);
    }
}
mod private {
pub trait Sealed {}
impl<W: super::Writer> Sealed for super::StreamEncoder<W> {}
}
/// Byte sink that [`Encode::encode`] implementations write their output into.
///
/// The trait is sealed through `private::Sealed`, which cannot be named
/// outside this module. In this file it is implemented for [`StreamEncoder`].
pub trait EncodeBearer: private::Sealed {
    /// Appends `bytes` to the encoded output.
    ///
    /// # Errors
    ///
    /// Returns a [`RegionError`] if the bytes cannot be accepted.
    fn write(&mut self, bytes: &[u8]) -> Result<(), RegionError>;
}
/// Buffered encoder: bytes written through [`EncodeBearer::write`] are staged
/// in a [`RegionBuffer`] and handed to a [`Push`] built from the destination
/// writer `W` and a [`Seed`].
pub struct StreamEncoder<W>
where
    W: Writer,
{
    /// Staging area for written bytes. It is handed to `authority` via
    /// `RegionBuffer::pass` when it lacks room for a write, and on `flush`.
    region_buffer: RegionBuffer,
    /// Destination writer plus its current seed; replaced by `relocate*`.
    authority: Push<W>,
}
impl<W> StreamEncoder<W>
where
    W: Writer,
{
    /// Creates an encoder that writes to `dst`.
    ///
    /// The staging buffer is sized by `ord.cap()`; `dst` and `seed` build the
    /// [`Push`] that receives passed regions.
    pub fn new<E>(dst: W, seed: Seed, ord: AllocOrd<E>) -> Self
    where
        E: Encode,
    {
        let region_buffer = RegionBuffer::new(ord.cap());
        let authority = Push::new(dst, seed);
        Self {
            region_buffer,
            authority,
        }
    }

    /// Passes staged data to the current writer, then switches to `src`,
    /// keeping the current seed.
    ///
    /// # Errors
    ///
    /// Propagates the [`RegionError`] from passing the staged data. In that
    /// case the current writer is kept and `src` is dropped.
    pub fn relocate(&mut self, src: W) -> Result<(), RegionError> {
        let current_seed = self.authority.seed();
        self.relocate_with_seed(src, current_seed)
    }

    /// Passes staged data to the current writer, then switches to `src` with
    /// a new `seed`.
    ///
    /// # Errors
    ///
    /// Propagates the [`RegionError`] from passing the staged data. In that
    /// case the current writer is kept and `src` is dropped.
    pub fn relocate_with_seed(&mut self, src: W, seed: Seed) -> Result<(), RegionError> {
        // Same pass as `flush`; the writer is swapped only if it succeeds.
        self.flush()?;
        self.authority = Push::new(src, seed);
        Ok(())
    }

    /// Passes the staged region to the current writer.
    ///
    /// NOTE(review): unlike `EncodeBearer::write`, this does not call
    /// `RegionBuffer::swap` afterwards. Confirm that `pass` leaves the region
    /// reusable, otherwise the same bytes could be passed twice.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `RegionBuffer::pass`.
    pub fn flush(&mut self) -> Result<(), RegionError> {
        let Self {
            region_buffer,
            authority,
        } = self;
        region_buffer.pass(authority)?;
        Ok(())
    }
}
impl<W> EncodeBearer for StreamEncoder<W>
where
    W: Writer,
{
    /// Stages `bytes` in the current region. If `bytes` would not fit, the
    /// region is first passed to the writer and then swapped.
    ///
    /// # Errors
    ///
    /// - [`RegionError::OutOfBounds`] if `bytes` is longer than [`REGION_SIZE`].
    ///   Nothing is passed or staged in that case.
    /// - Any error from passing the region to the writer. The region is then
    ///   not swapped and `bytes` is not staged.
    fn write(&mut self, bytes: &[u8]) -> Result<(), RegionError> {
        let len = bytes.len();
        if len > REGION_SIZE {
            return Err(RegionError::OutOfBounds);
        }

        let has_room = self.region_buffer.remaining_cap() >= len;
        if !has_room {
            self.region_buffer.pass(&mut self.authority)?;
            self.region_buffer.swap();
        }

        // SAFETY: NOTE(review) — this assumes `write_nonoverlapping` requires
        // `remaining_cap() >= bytes.len()`, and that a freshly swapped region
        // has room for any input of at most `REGION_SIZE` bytes. The buffer is
        // sized by `ord.cap()` in `new`, not by `REGION_SIZE`; confirm the two
        // agree.
        unsafe {
            self.region_buffer.write_nonoverlapping(bytes);
        }
        Ok(())
    }
}
/// Errors returned by [`Encode::encode`].
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// Failure while encoding a frame; `#[from]` lets `?` convert a [`FrameError`].
    #[error("Frame encoding error: \"{0}\"")]
    FrameError(#[from] FrameError),
    /// Failure reported by the [`EncodeBearer`] (for example from
    /// [`EncodeBearer::write`]); `#[from]` lets `?` convert a [`RegionError`].
    #[error("Encode bearer error: \"{0}\"")]
    BearerError(#[from] RegionError),
    /// Any other failure, described by a static message that is displayed as-is.
    #[error("{0}")]
    Other(&'static str),
}
/// Types that can serialize themselves into an [`EncodeBearer`].
pub trait Encode {
    /// Writes the encoded form of `src` into `bearer`.
    ///
    /// This is an associated function, not a method: call it as
    /// `T::encode(bearer, &value)`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`]. Bearer failures ([`RegionError`]) convert
    /// into [`EncodeError::BearerError`] via `?`.
    fn encode(bearer: &mut impl EncodeBearer, src: &Self) -> Result<(), EncodeError>;
    /// Size of `self`'s encoded representation — presumably in bytes; TODO confirm.
    fn size_of(&self) -> usize;
}