use super::error::CapacityExceeded;
use bytes::buf::UninitSlice;
use bytes::BufMut;
use core::mem;
use core::mem::MaybeUninit;
/// A fixed-capacity write sink over a caller-provided byte buffer.
///
/// A write that does not fit does not panic; the overflow is recorded and
/// reported later when the written bytes are extracted.
pub struct SafeBytesSlice<'a> {
    /// Backing storage. The prefix `..bytes_written` holds the committed bytes
    /// (meaningful only while `cap_exceeded` is false).
    slice: &'a mut [mem::MaybeUninit<u8>],
    /// Set once any write failed for lack of space; never cleared afterwards.
    cap_exceeded: bool,
    /// Count of bytes committed from the start of `slice`.
    bytes_written: usize,
}
impl<'a> From<&'a mut [u8]> for SafeBytesSlice<'a> {
    /// Wraps an already-initialized byte buffer.
    ///
    /// The buffer is reinterpreted as `[MaybeUninit<u8>]`; its existing
    /// contents are treated as scratch space to be overwritten.
    fn from(bytes: &'a mut [u8]) -> Self {
        let len = bytes.len();
        let ptr = bytes.as_mut_ptr().cast::<mem::MaybeUninit<u8>>();
        // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and
        // `ptr`/`len` come from a `&'a mut [u8]` that is moved into this call, so
        // the new slice is the only live access path for `'a`. Soundness also
        // relies on nothing ever storing an uninitialized value back into the
        // slice, so the caller's `[u8]` is still fully initialized afterwards.
        let uninit: &'a mut [mem::MaybeUninit<u8>] =
            unsafe { core::slice::from_raw_parts_mut(ptr, len) };
        Self::from(uninit)
    }
}
impl<'a> From<&'a mut [MaybeUninit<u8>]> for SafeBytesSlice<'a> {
fn from(slice: &'a mut [MaybeUninit<u8>]) -> Self {
Self {
slice,
bytes_written: 0,
cap_exceeded: false,
}
}
}
impl<'a> SafeBytesSlice<'a> {
    /// Number of bytes committed so far.
    ///
    /// Only meaningful while capacity has not been exceeded. After an overflow
    /// the counter is pinned to the buffer length even though not all of those
    /// bytes were written; debug builds assert against that misuse.
    // Currently has no callers in this module; the `allow` keeps it available
    // without a warning.
    #[allow(dead_code)]
    fn bytes_written(&self) -> usize {
        debug_assert!(
            !self.cap_exceeded,
            "bytes_written is not meaningful after capacity was exceeded"
        );
        self.bytes_written
    }

    /// Returns `true` if no bytes have been committed yet.
    // Currently has no callers in this module; kept alongside `bytes_written`.
    #[inline]
    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.bytes_written() == 0
    }

    /// Consumes the wrapper and returns the bytes written so far.
    ///
    /// # Errors
    ///
    /// Returns [`CapacityExceeded`] if any write did not fit into the buffer.
    /// In that case the contents are incomplete and are not exposed.
    pub fn try_into_bytes(self) -> Result<&'a [u8], CapacityExceeded> {
        if self.is_exceed() {
            return Err(CapacityExceeded {});
        }
        let written = &self.slice[..self.bytes_written];
        // SAFETY: `MaybeUninit<u8>` is layout-compatible with `u8`. While no
        // overflow has been recorded, every byte in `..bytes_written` has been
        // initialized: `put_slice` copies initialized bytes in before advancing,
        // and `advance_mut`'s contract requires callers to initialize the bytes
        // they commit. The memory is borrowed for `'a`, and `self` (the only
        // handle to it) is consumed here, so no aliasing `&mut` remains.
        Ok(unsafe { core::slice::from_raw_parts(written.as_ptr().cast::<u8>(), written.len()) })
    }

    /// Returns `true` once a write has failed for lack of space.
    ///
    /// The flag is sticky: later writes never clear it.
    #[must_use]
    pub fn is_exceed(&self) -> bool {
        self.cap_exceeded
    }
}
unsafe impl<'a> BufMut for SafeBytesSlice<'a> {
fn remaining_mut(&self) -> usize {
debug_assert!(self.bytes_written <= self.slice.len());
self.slice.len() - self.bytes_written
}
unsafe fn advance_mut(&mut self, cnt: usize) {
let new_bytes_written = self.bytes_written + cnt;
if new_bytes_written > self.slice.len() {
self.bytes_written = self.slice.len(); self.cap_exceeded = true;
} else {
self.bytes_written = new_bytes_written;
}
}
fn bytes_mut(&mut self) -> &mut UninitSlice {
let bytes = &mut self.slice[self.bytes_written..];
let len = bytes.len();
let ptr = bytes.as_mut_ptr() as *mut _;
unsafe { UninitSlice::from_raw_parts_mut(ptr, len) }
}
fn put_slice(&mut self, src: &[u8]) {
use core::ptr;
let src_len = src.len();
if self.remaining_mut() < src_len {
self.bytes_written = self.slice.len(); self.cap_exceeded = true;
return;
}
unsafe {
let dst = self.bytes_mut();
ptr::copy_nonoverlapping(src[..].as_ptr(), dst.as_mut_ptr() as *mut u8, src_len);
self.advance_mut(src_len);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `amount` bytes of `0xFF` one at a time through the `BufMut` API.
    fn fill_with_ff(buf: &mut dyn BufMut, amount: usize) {
        for _ in 0..amount {
            buf.put_u8(0xFF);
        }
    }

    /// Compares with the plain `&mut [u8]` sink. Both accept the same writes,
    /// but `SafeBytesSlice` hands back the written prefix directly.
    #[test]
    fn usefulness() {
        let mut data = [0u8; 32];
        let mut slice = &mut data[..];
        let slice_len = slice.len();
        fill_with_ff(&mut slice, 27);
        let n = slice_len - slice.len();
        assert_eq!(n, 27);
        assert!(data[..n].iter().all(|&b| b == 0xFF));

        let mut raw = [0u8; 32];
        let mut safe_slice = SafeBytesSlice::from(&mut raw[..]);
        fill_with_ff(&mut safe_slice, 27);
        let written = safe_slice
            .try_into_bytes()
            .expect("27 bytes fit into a 32-byte buffer");
        assert_eq!(written.len(), 27);
        assert!(written.iter().all(|&b| b == 0xFF));
    }

    #[test]
    fn naive_test() {
        let mut static_data = [0u8; 32];
        let mut safe_slice = SafeBytesSlice::from(&mut static_data[..]);
        fill_with_ff(&mut safe_slice, 32);
        assert!(!safe_slice.is_exceed());
        let written = safe_slice
            .try_into_bytes()
            .expect("exactly filling the buffer must not exceed capacity");
        assert_eq!(written.len(), 32);
        assert!(written.iter().all(|&b| b == 0xFF));

        let mut safe_slice = SafeBytesSlice::from(&mut static_data[..]);
        fill_with_ff(&mut safe_slice, 33);
        assert!(safe_slice.is_exceed());
        assert!(safe_slice.try_into_bytes().is_err());
    }

    #[test]
    fn empty_write_yields_empty_bytes() {
        let mut raw = [0u8; 4];
        let safe_slice = SafeBytesSlice::from(&mut raw[..]);
        assert!(!safe_slice.is_exceed());
        assert!(safe_slice
            .try_into_bytes()
            .expect("nothing was written")
            .is_empty());
    }

    #[test]
    fn oversized_put_slice_writes_nothing_and_flags_exceeded() {
        let mut raw = [0u8; 4];
        let mut safe_slice = SafeBytesSlice::from(&mut raw[..]);
        safe_slice.put_slice(&[1, 2]);
        safe_slice.put_slice(&[3, 4, 5]);
        assert!(safe_slice.is_exceed());
        assert!(safe_slice.try_into_bytes().is_err());
        // The rejected write must not have touched the buffer.
        assert_eq!(raw, [1, 2, 0, 0]);
    }

    #[test]
    fn works_over_uninitialized_storage() {
        let mut raw = [MaybeUninit::<u8>::uninit(); 8];
        let mut safe_slice = SafeBytesSlice::from(&mut raw[..]);
        safe_slice.put_slice(b"abc");
        let written = safe_slice.try_into_bytes().expect("3 bytes fit into 8");
        assert_eq!(written, &b"abc"[..]);
    }
}