#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::rc::Rc;
#[cfg(feature = "alloc")]
use alloc::string::String;
#[cfg(feature = "alloc")]
use alloc::sync::Arc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::slice;
use crate::error::InvalidInput;
/// An owned byte buffer that can be allocated uninitialized, filled in place, and then
/// converted into its initialized form.
///
/// The protocol used by this module is: `Buffer::uninit` to obtain writable storage, write
/// through `BufMut::as_uninit_mut`, commit with `BufMut::advance`, and finally convert with
/// `Buffer::assume_init`.
#[allow(private_bounds, reason = "Sealed trait.")]
pub trait Buffer: Sized {
    /// Writable, possibly-uninitialized storage that backs `Self` while it is being filled.
    #[doc(hidden)]
    type Uninit: BufMut;
    /// Allocates uninitialized storage able to hold `len` bytes.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `Self` cannot hold exactly `len` bytes; in this module only
    /// the fixed-size array impl does this (it rejects every `len != N`).
    #[doc(hidden)]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput>;
    /// Converts filled storage into the initialized buffer.
    ///
    /// # Safety
    ///
    /// Every byte the resulting buffer exposes must have been written. For arrays and the
    /// `Box`/`Arc`/`Rc` slice impls that is every element allocated by `uninit`; for `Vec<u8>`
    /// the vector's length must already have been extended with `BufMut::advance`.
    #[allow(unsafe_code, reason = "XXX")]
    #[doc(hidden)]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self;
}
impl<const N: usize> Buffer for [u8; N] {
    /// The array is filled in place as `MaybeUninit` slots of the same length.
    type Uninit = [MaybeUninit<u8>; N];

    /// Returns `N` uninitialized slots; any other requested length is rejected because the
    /// array size is fixed at compile time.
    #[inline]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput> {
        if len == N {
            Ok([MaybeUninit::uninit(); N])
        } else {
            Err(InvalidInput)
        }
    }

    #[allow(unsafe_code, reason = "Reads back an array the caller guarantees is fully written.")]
    #[inline]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self {
        let initialized: *const Self = (&raw const uninit).cast();
        // SAFETY: `[MaybeUninit<u8>; N]` and `[u8; N]` have identical size and alignment, the
        // pointer refers to a live local, and the caller guarantees all `N` bytes are written.
        unsafe { initialized.read() }
    }
}
#[cfg(feature = "alloc")]
impl Buffer for Vec<u8> {
    /// A `Vec` is its own uninitialized form: its spare capacity is the writable region and
    /// its length records how much has been committed.
    type Uninit = Vec<u8>;

    /// Returns an empty vector with room for at least `len` bytes; never fails.
    #[inline]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput> {
        let empty = Vec::with_capacity(len);
        Ok(empty)
    }

    #[allow(unsafe_code, reason = "Implements an `unsafe` trait method.")]
    #[inline]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self {
        // The vector's length is only ever extended by `BufMut::advance`, so there is nothing
        // left to convert here.
        uninit
    }
}
#[cfg(feature = "alloc")]
impl Buffer for Box<[u8]> {
    /// Filled as a boxed slice of `MaybeUninit` slots.
    type Uninit = Box<[MaybeUninit<u8>]>;

    /// Allocates exactly `len` uninitialized slots; never fails.
    #[inline]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput> {
        let slots = Box::<[u8]>::new_uninit_slice(len);
        Ok(slots)
    }

    #[allow(unsafe_code, reason = "Converts a fully written boxed slice to initialized bytes.")]
    #[inline]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self {
        // SAFETY: the caller guarantees every slot of `uninit` has been written.
        unsafe { uninit.assume_init() }
    }
}
#[cfg(feature = "alloc")]
impl Buffer for Arc<[u8]> {
    /// Filled as an `Arc` slice of `MaybeUninit` slots.
    type Uninit = Arc<[MaybeUninit<u8>]>;

    /// Allocates exactly `len` uninitialized slots; never fails.
    #[inline]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput> {
        let slots = Arc::<[u8]>::new_uninit_slice(len);
        Ok(slots)
    }

    #[allow(unsafe_code, reason = "Converts a fully written `Arc` slice to initialized bytes.")]
    #[inline]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self {
        // SAFETY: the caller guarantees every slot of `uninit` has been written.
        unsafe { uninit.assume_init() }
    }
}
#[cfg(feature = "alloc")]
impl Buffer for Rc<[u8]> {
    /// Filled as an `Rc` slice of `MaybeUninit` slots.
    type Uninit = Rc<[MaybeUninit<u8>]>;

    /// Allocates exactly `len` uninitialized slots; never fails.
    #[inline]
    fn uninit(len: usize) -> Result<Self::Uninit, InvalidInput> {
        let slots = Rc::<[u8]>::new_uninit_slice(len);
        Ok(slots)
    }

    #[allow(unsafe_code, reason = "Converts a fully written `Rc` slice to initialized bytes.")]
    #[inline]
    unsafe fn assume_init(uninit: Self::Uninit) -> Self {
        // SAFETY: the caller guarantees every slot of `uninit` has been written.
        unsafe { uninit.assume_init() }
    }
}
/// A buffer type that can be built from a [`Buffer`] of bytes.
///
/// Every [`Buffer`] is a `StringBuffer` of itself. With the `alloc` feature, `String`,
/// `Box<str>`, `Arc<str>` and `Rc<str>` are built from `Vec<u8>`, `Box<[u8]>`, `Arc<[u8]>` and
/// `Rc<[u8]>` respectively.
#[allow(private_bounds, reason = "Sealed trait.")]
pub trait StringBuffer {
    /// The byte buffer this type is converted from.
    #[doc(hidden)]
    type Bytes: Buffer;
    /// Reinterprets `buf` as `Self` without validating its contents.
    ///
    /// # Safety
    ///
    /// `buf` must contain valid UTF-8. The `str`-backed impls rely on this. The blanket impl
    /// for byte buffers returns `buf` unchanged.
    #[allow(unsafe_code, reason = "XXX")]
    #[doc(hidden)]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self;
}
impl<T> StringBuffer for T
where
    T: Buffer,
{
    /// A byte buffer is its own backing storage.
    type Bytes = T;

    #[allow(unsafe_code, reason = "Implements an `unsafe` trait method.")]
    #[inline]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self {
        // Raw byte buffers carry no UTF-8 invariant, so the conversion is the identity.
        buf
    }
}
#[cfg(feature = "alloc")]
impl StringBuffer for String {
    /// A `String` is built from the `Vec<u8>` that backs it.
    type Bytes = Vec<u8>;

    #[allow(unsafe_code, reason = "Builds a `String` from bytes the caller vouches for.")]
    #[inline]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self {
        // SAFETY: the caller guarantees `buf` is valid UTF-8. This path resolves to the
        // inherent `String::from_utf8_unchecked` (inherent items take precedence over trait
        // items), so it does not recurse into this trait method.
        unsafe { String::from_utf8_unchecked(buf) }
    }
}
#[cfg(feature = "alloc")]
impl StringBuffer for Box<str> {
    /// A `Box<str>` is built from the boxed byte slice that backs it.
    type Bytes = Box<[u8]>;

    /// Reinterprets the boxed bytes as a boxed string slice, reusing the same allocation.
    ///
    /// # Safety
    ///
    /// `buf` must contain valid UTF-8.
    #[allow(unsafe_code, reason = "Delegates to `from_boxed_utf8_unchecked`.")]
    #[inline]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self {
        // SAFETY: forwarded to the caller, who guarantees `buf` is valid UTF-8. std's dedicated
        // helper replaces the hand-written `into_raw` / pointer-cast / `from_raw` round-trip.
        unsafe { alloc::str::from_boxed_utf8_unchecked(buf) }
    }
}
#[cfg(feature = "alloc")]
impl StringBuffer for Arc<str> {
    /// An `Arc<str>` is built from the `Arc` byte slice that backs it.
    type Bytes = Arc<[u8]>;

    #[allow(unsafe_code, reason = "Reinterprets an `Arc<[u8]>` allocation as `Arc<str>`.")]
    #[inline]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self {
        let bytes: *const [u8] = Arc::into_raw(buf);
        // SAFETY: `str` and `[u8]` share layout and pointer metadata, the pointer was just
        // produced by `Arc::into_raw`, and the caller guarantees the bytes are valid UTF-8.
        unsafe { Arc::from_raw(bytes as *const str) }
    }
}
#[cfg(feature = "alloc")]
impl StringBuffer for Rc<str> {
    /// An `Rc<str>` is built from the `Rc` byte slice that backs it.
    type Bytes = Rc<[u8]>;

    #[allow(unsafe_code, reason = "Reinterprets an `Rc<[u8]>` allocation as `Rc<str>`.")]
    #[inline]
    unsafe fn from_utf8_unchecked(buf: Self::Bytes) -> Self {
        let bytes: *const [u8] = Rc::into_raw(buf);
        // SAFETY: `str` and `[u8]` share layout and pointer metadata, the pointer was just
        // produced by `Rc::into_raw`, and the caller guarantees the bytes are valid UTF-8.
        unsafe { Rc::from_raw(bytes as *const str) }
    }
}
/// Writable byte storage that is filled through a `MaybeUninit` view and then committed.
pub trait BufMut {
    /// Returns the writable region as possibly-uninitialized bytes.
    ///
    /// For `Vec<u8>` this is the spare capacity past the current length. For every other impl
    /// in this module it is the whole buffer, including bytes that are already initialized.
    ///
    /// NOTE(review): this method is safe, yet the `[u8]`, `[u8; N]`, `Box<[u8]>`, `Arc<[u8]>`
    /// and `Rc<[u8]>` impls return this view over initialized bytes. Safe code could store
    /// `MaybeUninit::uninit()` through it and later read the original bytes, which is undefined
    /// behavior. Consider making this method `unsafe`.
    #[doc(hidden)]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>];
    /// Marks the first `additional` bytes of the writable region as initialized and returns
    /// them.
    ///
    /// Only the `Vec<u8>` impl records progress (by growing its length). The slice-, array- and
    /// pointer-backed impls keep no cursor, so each call starts again from the first byte.
    ///
    /// # Safety
    ///
    /// `additional` must not exceed the length of the region returned by `as_uninit_mut`, and
    /// those `additional` bytes must have been written.
    #[allow(unsafe_code, reason = "XXX")]
    #[doc(hidden)]
    unsafe fn advance(&mut self, additional: usize) -> &[u8];
}
impl<T> BufMut for &mut T
where
T: BufMut + ?Sized,
{
#[inline]
fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
BufMut::as_uninit_mut(*self)
}
#[inline]
#[allow(unsafe_code, reason = "XXX")]
unsafe fn advance(&mut self, additional: usize) -> &[u8] {
unsafe { BufMut::advance(*self, additional) }
}
}
impl BufMut for [MaybeUninit<u8>] {
    /// Already the uninitialized view; returned unchanged.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        self
    }

    /// Returns the first `additional` slots as bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Reads slots the caller guarantees are written.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        let start: *const u8 = self.as_ptr().cast();
        // SAFETY: `MaybeUninit<u8>` has the layout of `u8`, and the caller guarantees the first
        // `additional` slots lie within this slice and have been written.
        unsafe { slice::from_raw_parts(start, additional) }
    }
}
impl<const N: usize> BufMut for [MaybeUninit<u8>; N] {
    /// Exposes all `N` slots.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        self.as_mut_slice()
    }

    /// Returns the first `additional` slots as bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Reads slots the caller guarantees are written.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= N);
        let start: *const u8 = self.as_ptr().cast();
        // SAFETY: `MaybeUninit<u8>` has the layout of `u8`, and the caller guarantees
        // `additional <= N` and that those slots have been written.
        unsafe { slice::from_raw_parts(start, additional) }
    }
}
impl BufMut for [u8] {
#[inline]
fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
#[allow(unsafe_code, reason = "XXX")]
unsafe {
&mut *((&raw mut *self) as *mut [MaybeUninit<u8>])
}
}
#[allow(unsafe_code, reason = "XXX")]
#[inline]
unsafe fn advance(&mut self, additional: usize) -> &[u8] {
debug_assert!(additional <= self.len());
unsafe { slice::from_raw_parts(self.as_ptr(), additional) }
}
}
impl<const N: usize> BufMut for [u8; N] {
#[inline]
fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
#[allow(unsafe_code, reason = "XXX")]
unsafe {
&mut *((&raw mut *self).cast::<[MaybeUninit<u8>; N]>())
}
}
#[allow(unsafe_code, reason = "XXX")]
#[inline]
unsafe fn advance(&mut self, additional: usize) -> &[u8] {
debug_assert!(additional <= N);
unsafe { slice::from_raw_parts(self.as_ptr(), additional) }
}
}
#[cfg(feature = "alloc")]
impl BufMut for Vec<u8> {
    /// Exposes the spare capacity `len..capacity`; the initialized prefix is never handed out.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        self.spare_capacity_mut()
    }

    /// Extends the length by `additional` and returns the newly committed bytes.
    ///
    /// # Safety
    ///
    /// `additional` must not exceed the spare capacity, and the first `additional` bytes of the
    /// spare capacity must have been written.
    #[allow(unsafe_code, reason = "Commits written spare capacity with `set_len`.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        let len = self.len();
        // Same debug-time bound check every other `BufMut` impl performs. A violation here
        // would make `set_len` exceed the capacity, so the slice below would read beyond the
        // allocation. `capacity() >= len` always holds, so the subtraction cannot underflow.
        debug_assert!(additional <= self.capacity() - len);
        // SAFETY: by the caller's contract `len + additional <= capacity()` and the bytes in
        // `len..len + additional` have been written.
        unsafe {
            self.set_len(len + additional);
        }
        &self[len..]
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Box<[MaybeUninit<u8>]> {
    /// Exposes every slot of the boxed slice.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        &mut **self
    }

    /// Returns the first `additional` slots as bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Reads slots the caller guarantees are written.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        let first: *const u8 = self.as_ptr().cast();
        // SAFETY: `MaybeUninit<u8>` has the layout of `u8`, and the caller guarantees the first
        // `additional` slots lie within the allocation and have been written.
        unsafe { slice::from_raw_parts(first, additional) }
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Box<[u8]> {
    /// Reinterprets the boxed (already initialized) bytes as `MaybeUninit<u8>`.
    ///
    /// NOTE(review): the returned view aliases initialized bytes, so safe code could store
    /// `MaybeUninit::uninit()` through it and de-initialize them. Callers should only write
    /// initialized values.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        let bytes: *mut [u8] = &mut **self;
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the pointer comes from a
        // live, uniquely borrowed boxed slice.
        #[allow(unsafe_code, reason = "Reinterprets `[u8]` as `[MaybeUninit<u8>]`.")]
        unsafe {
            &mut *(bytes as *mut [MaybeUninit<u8>])
        }
    }

    /// Returns the first `additional` bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Unchecked prefix slice; bound is the caller's contract.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        // SAFETY: the caller guarantees `additional <= self.len()`.
        unsafe { self.get_unchecked(..additional) }
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Arc<[u8]> {
    /// Obtains unique access via `Arc::make_mut` (which may clone the bytes when the `Arc` is
    /// shared) and reinterprets them as `MaybeUninit<u8>`.
    ///
    /// NOTE(review): the returned view aliases initialized bytes, so safe code could store
    /// `MaybeUninit::uninit()` through it and de-initialize them. Callers should only write
    /// initialized values.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        let unique: *mut [u8] = Arc::make_mut(self);
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and `make_mut` returned a
        // unique borrow of the slice that lives as long as `self` is borrowed.
        #[allow(unsafe_code, reason = "Reinterprets `[u8]` as `[MaybeUninit<u8>]`.")]
        unsafe {
            &mut *(unique as *mut [MaybeUninit<u8>])
        }
    }

    /// Returns the first `additional` bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Unchecked prefix slice; bound is the caller's contract.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        // SAFETY: the caller guarantees `additional <= self.len()`.
        unsafe { self.get_unchecked(..additional) }
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Arc<[MaybeUninit<u8>]> {
    /// Returns the slots mutably via `Arc::make_mut`, which may first clone them into a fresh
    /// allocation when this `Arc` is shared.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        Arc::make_mut(self)
    }

    /// Returns the first `additional` slots as bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Reads slots the caller guarantees are written.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        let first: *const u8 = self.as_ptr().cast();
        // SAFETY: `MaybeUninit<u8>` has the layout of `u8`, and the caller guarantees the first
        // `additional` slots lie within the allocation and have been written.
        unsafe { slice::from_raw_parts(first, additional) }
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Rc<[u8]> {
    /// Obtains unique access via `Rc::make_mut` (which may clone the bytes when the `Rc` is
    /// shared) and reinterprets them as `MaybeUninit<u8>`.
    ///
    /// NOTE(review): the returned view aliases initialized bytes, so safe code could store
    /// `MaybeUninit::uninit()` through it and de-initialize them. Callers should only write
    /// initialized values.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        let unique: *mut [u8] = Rc::make_mut(self);
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and `make_mut` returned a
        // unique borrow of the slice that lives as long as `self` is borrowed.
        #[allow(unsafe_code, reason = "Reinterprets `[u8]` as `[MaybeUninit<u8>]`.")]
        unsafe {
            &mut *(unique as *mut [MaybeUninit<u8>])
        }
    }

    /// Returns the first `additional` bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Unchecked prefix slice; bound is the caller's contract.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        // SAFETY: the caller guarantees `additional <= self.len()`.
        unsafe { self.get_unchecked(..additional) }
    }
}
#[cfg(feature = "alloc")]
impl BufMut for Rc<[MaybeUninit<u8>]> {
    /// Returns the slots mutably via `Rc::make_mut`, which may first clone them into a fresh
    /// allocation when this `Rc` is shared.
    #[inline]
    fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        Rc::make_mut(self)
    }

    /// Returns the first `additional` slots as bytes; keeps no cursor.
    #[allow(unsafe_code, reason = "Reads slots the caller guarantees are written.")]
    #[inline]
    unsafe fn advance(&mut self, additional: usize) -> &[u8] {
        debug_assert!(additional <= self.len());
        let first: *const u8 = self.as_ptr().cast();
        // SAFETY: `MaybeUninit<u8>` has the layout of `u8`, and the caller guarantees the first
        // `additional` slots lie within the allocation and have been written.
        unsafe { slice::from_raw_parts(first, additional) }
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Filler byte for the raw-buffer tests.
    const FILL: u8 = 0x09;
    /// Filler byte for the string tests; a repeated ASCII byte is always valid UTF-8.
    const ASCII: u8 = b'a';

    /// Writes `byte` into the first `len` writable slots of `buf`, commits them with
    /// `BufMut::advance`, and checks that the committed slice holds exactly those bytes.
    fn fill_and_commit<B: BufMut>(buf: &mut B, len: usize, byte: u8) {
        let writable = buf.as_uninit_mut();
        assert!(
            writable.len() >= len,
            "buffer should expose at least {len} writable bytes, got {}",
            writable.len()
        );
        writable.iter_mut().take(len).for_each(|slot| {
            slot.write(byte);
        });
        #[allow(unsafe_code, reason = "The first `len` slots were written just above.")]
        let committed = unsafe { buf.advance(len) };
        assert_eq!(committed.len(), len);
        assert!(committed.iter().all(|&b| b == byte));
    }

    /// Builds a `B` of `len` bytes, all equal to `byte`, through the full `Buffer` protocol.
    #[allow(unsafe_code, reason = "Every byte is written by `fill_and_commit` first.")]
    fn filled_buffer<B: Buffer>(len: usize, byte: u8) -> B {
        let mut uninit = B::uninit(len).expect("Buffer::uninit");
        fill_and_commit(&mut uninit, len, byte);
        unsafe { B::assume_init(uninit) }
    }

    #[test]
    fn test_buf_mut() {
        fn test_buf_mut<B: BufMut>(mut buf: B, additional: usize) {
            fill_and_commit(&mut buf, additional, FILL);
        }
        // Exercise each buffer by value and through the `&mut T` forwarding impl.
        macro_rules! test_buf_mut {
            ($buf:expr, $additional:expr) => {
                test_buf_mut($buf, $additional);
                test_buf_mut(&mut $buf, $additional);
            };
        }

        // Zero-length buffers.
        test_buf_mut!(&mut [0u8; 0][..], 0);
        test_buf_mut!(&mut [MaybeUninit::<u8>::uninit(); 0][..], 0);
        test_buf_mut!(&mut [0u8; 0], 0);
        test_buf_mut!(&mut [MaybeUninit::<u8>::uninit(); 0], 0);
        test_buf_mut!(Vec::with_capacity(0), 0);
        test_buf_mut!(Box::<[u8]>::new_uninit_slice(0), 0);
        test_buf_mut!(Arc::<[u8]>::new_uninit_slice(0), 0);
        test_buf_mut!(Rc::<[u8]>::new_uninit_slice(0), 0);

        // 16-byte buffers, committing all, all-but-one, and none of the bytes.
        for additional in [16, 15, 0] {
            test_buf_mut!(&mut [0u8; 16][..], additional);
            test_buf_mut!(&mut [MaybeUninit::<u8>::uninit(); 16][..], additional);
            test_buf_mut!(&mut [0u8; 16], additional);
            test_buf_mut!(&mut [MaybeUninit::<u8>::uninit(); 16], additional);
            test_buf_mut!([0u8; 16], additional);
            test_buf_mut!([MaybeUninit::<u8>::uninit(); 16], additional);
            test_buf_mut!(Vec::with_capacity(16), additional);
            test_buf_mut!(Box::<[u8]>::from(&[0u8; 16][..]), additional);
            test_buf_mut!(Box::<[u8]>::new_uninit_slice(16), additional);
            test_buf_mut!(Arc::<[u8]>::from(&[0u8; 16][..]), additional);
            test_buf_mut!(Arc::<[u8]>::new_uninit_slice(16), additional);
            test_buf_mut!(Rc::<[u8]>::from(&[0u8; 16][..]), additional);
            test_buf_mut!(Rc::<[u8]>::new_uninit_slice(16), additional);
        }
    }

    #[test]
    fn test_buffer() {
        fn test_buffer<B: Buffer + AsRef<[u8]>>(len: usize) {
            let buf: B = filled_buffer(len, FILL);
            let bytes: &[u8] = buf.as_ref();
            assert_eq!(bytes.len(), len);
            assert!(bytes.iter().all(|&b| b == FILL));
        }
        test_buffer::<[u8; 0]>(0);
        test_buffer::<[u8; 1]>(1);
        test_buffer::<[u8; 8]>(8);
        test_buffer::<[u8; 16]>(16);
        test_buffer::<Vec<u8>>(0);
        test_buffer::<Vec<u8>>(1);
        test_buffer::<Vec<u8>>(16);
        test_buffer::<Box<[u8]>>(0);
        test_buffer::<Box<[u8]>>(16);
        test_buffer::<Arc<[u8]>>(0);
        test_buffer::<Arc<[u8]>>(16);
        test_buffer::<Rc<[u8]>>(0);
        test_buffer::<Rc<[u8]>>(16);
    }

    #[test]
    fn test_buffer_uninit_rejects_wrong_len() {
        assert!(<[u8; 16]>::uninit(17).is_err());
        assert!(<[u8; 16]>::uninit(16).is_ok());
        assert!(<[u8; 16]>::uninit(15).is_err());
        assert!(<[u8; 16]>::uninit(1).is_err());
        assert!(<[u8; 16]>::uninit(0).is_err());
        assert!(<[u8; 0]>::uninit(0).is_ok());
        assert!(<[u8; 0]>::uninit(1).is_err());
    }

    #[test]
    fn test_string_buffer() {
        /// Builds an `S` of `len` copies of `ASCII` through `StringBuffer`.
        #[allow(unsafe_code, reason = "The bytes are all ASCII, hence valid UTF-8.")]
        fn filled_string<S: StringBuffer>(len: usize) -> S {
            let bytes = filled_buffer::<S::Bytes>(len, ASCII);
            unsafe { S::from_utf8_unchecked(bytes) }
        }
        fn test_string_buffer_bytes<S: StringBuffer + AsRef<[u8]>>(len: usize) {
            let buf: S = filled_string(len);
            let bytes: &[u8] = buf.as_ref();
            assert_eq!(bytes.len(), len);
            assert!(bytes.iter().all(|&b| b == ASCII));
        }
        fn test_string_buffer_str<S: StringBuffer + core::ops::Deref<Target = str>>(len: usize) {
            let string: S = filled_string(len);
            let s: &str = &string;
            assert_eq!(s.len(), len);
            assert!(s.bytes().all(|b| b == ASCII));
            assert!(s.chars().all(|c| c == char::from(ASCII)));
        }
        test_string_buffer_bytes::<[u8; 0]>(0);
        test_string_buffer_bytes::<[u8; 16]>(16);
        test_string_buffer_bytes::<Vec<u8>>(0);
        test_string_buffer_bytes::<Vec<u8>>(16);
        test_string_buffer_bytes::<Box<[u8]>>(0);
        test_string_buffer_bytes::<Box<[u8]>>(16);
        test_string_buffer_bytes::<Arc<[u8]>>(0);
        test_string_buffer_bytes::<Arc<[u8]>>(16);
        test_string_buffer_bytes::<Rc<[u8]>>(0);
        test_string_buffer_bytes::<Rc<[u8]>>(16);
        test_string_buffer_str::<String>(0);
        test_string_buffer_str::<String>(16);
        test_string_buffer_str::<Box<str>>(0);
        test_string_buffer_str::<Box<str>>(16);
        test_string_buffer_str::<Arc<str>>(0);
        test_string_buffer_str::<Arc<str>>(16);
        test_string_buffer_str::<Rc<str>>(0);
        test_string_buffer_str::<Rc<str>>(16);
    }
}