use super::*;
#[cfg(feature = "alloc")]
use {alloc::vec::Vec, core::slice::from_raw_parts_mut};
/// A byte-offset cursor over an in-memory buffer `T`.
///
/// The cursor only records an absolute offset (`pos`) into `inner`; neither the
/// constructors nor [`Cursor::set_position`] check that offset against the buffer's
/// length.
pub struct Cursor<T> {
    inner: T,
    pos: usize,
}

impl<T> Cursor<T> {
    /// Wraps `inner` with the cursor at offset `0`.
    pub const fn new(inner: T) -> Self {
        Self::new_at(inner, 0)
    }

    /// Wraps `inner` with the cursor at offset `pos`.
    ///
    /// `pos` is stored as-is and may point past the end of `inner`.
    pub const fn new_at(inner: T, pos: usize) -> Self {
        Cursor { inner, pos }
    }

    /// Moves the cursor to the absolute offset `pos` (not validated).
    pub const fn set_position(&mut self, pos: usize) {
        self.pos = pos;
    }

    /// Consumes the cursor, returning the wrapped buffer and discarding the offset.
    pub fn into_inner(self) -> T {
        let Cursor { inner, .. } = self;
        inner
    }

    /// Returns the current absolute offset into the wrapped buffer.
    pub const fn position(&self) -> usize {
        self.pos
    }
}
impl<T> Cursor<T>
where
    T: AsRef<[u8]>,
{
    /// Returns the unread tail of the buffer, from the current offset to the end.
    ///
    /// An offset at or past the end yields an empty slice instead of panicking.
    #[inline]
    fn cur_slice(&self) -> &[u8] {
        let bytes = self.inner.as_ref();
        let start = bytes.len().min(self.pos);
        bytes.split_at(start).1
    }

    /// Returns how many bytes remain unread (`0` when the offset is past the end).
    #[inline]
    fn cur_len(&self) -> usize {
        self.cur_slice().len()
    }

    /// Returns the next `mid` unread bytes and advances the offset past them.
    ///
    /// # Errors
    ///
    /// Returns `read_size_limit(mid)` when fewer than `mid` bytes remain; the offset
    /// is left unchanged in that case.
    #[inline]
    fn consume_slice_checked(&mut self, mid: usize) -> ReadResult<&[u8]> {
        // Borrow only `self.inner` here so `self.pos` can still be updated below.
        let bytes = self.inner.as_ref();
        let (_, unread) = bytes.split_at(bytes.len().min(self.pos));
        if let Some(head) = unread.get(..mid) {
            // SAFETY: `head` exists, so either `pos <= len` and `pos + mid <= len`, or
            // `pos > len` and `mid == 0`; in both cases the sum cannot overflow.
            self.pos = unsafe { self.pos.unchecked_add(mid) };
            Ok(head)
        } else {
            Err(read_size_limit(mid))
        }
    }
}
impl<'a, T> Reader<'a> for Cursor<T>
where
    T: AsRef<[u8]>,
{
    type Trusted<'b>
        = TrustedSliceReader<'a, 'b>
    where
        Self: 'b;

    /// Returns up to `n_bytes` unread bytes without advancing the cursor.
    ///
    /// When fewer bytes remain, all of them (possibly none) are returned; this never
    /// returns an error.
    #[inline]
    fn fill_buf(&mut self, n_bytes: usize) -> ReadResult<&[u8]> {
        let remaining = self.cur_slice();
        Ok(remaining.get(..n_bytes).unwrap_or(remaining))
    }

    /// Returns exactly `n_bytes` unread bytes without advancing the cursor.
    ///
    /// # Errors
    ///
    /// Returns `read_size_limit(n_bytes)` when fewer than `n_bytes` bytes remain.
    #[inline]
    fn fill_exact(&mut self, n_bytes: usize) -> ReadResult<&[u8]> {
        self.cur_slice()
            .get(..n_bytes)
            .ok_or_else(|| read_size_limit(n_bytes))
    }

    /// Advances the cursor by `amt` without any bounds check.
    #[inline]
    unsafe fn consume_unchecked(&mut self, amt: usize) {
        // SAFETY: the caller upholds `Reader::consume_unchecked`'s contract
        // (presumably that `amt` unread bytes exist), which keeps this sum in range.
        let advanced = unsafe { self.pos.unchecked_add(amt) };
        self.pos = advanced;
    }

    /// Advances the cursor by `amt` bytes.
    ///
    /// # Errors
    ///
    /// Returns `read_size_limit(amt)` (without moving) when fewer than `amt` bytes remain.
    fn consume(&mut self, amt: usize) -> ReadResult<()> {
        if self.cur_len() >= amt {
            // SAFETY: at least `amt` unread bytes remain, so `pos + amt` is at most the
            // buffer length (or `amt == 0`), and cannot overflow.
            unsafe { self.consume_unchecked(amt) };
            Ok(())
        } else {
            Err(read_size_limit(amt))
        }
    }

    /// Takes the next `n_bytes` bytes (advancing the cursor past them) and wraps
    /// them in a trusted slice reader.
    ///
    /// # Errors
    ///
    /// Returns `read_size_limit(n_bytes)` when fewer than `n_bytes` bytes remain.
    #[inline]
    unsafe fn as_trusted_for(&mut self, n_bytes: usize) -> ReadResult<Self::Trusted<'_>> {
        let window = self.consume_slice_checked(n_bytes)?;
        Ok(TrustedSliceReader::new(window))
    }
}
mod uninit_slice {
    //! Shared helpers for cursors that write into uninitialized byte slices.
    use super::*;

    /// Returns the writable tail of `inner` starting at `pos`.
    ///
    /// A `pos` at or past the end yields an empty slice instead of panicking.
    #[inline]
    pub(super) fn cur_slice_mut(
        inner: &mut [MaybeUninit<u8>],
        pos: usize,
    ) -> &mut [MaybeUninit<u8>] {
        let start = if pos > inner.len() { inner.len() } else { pos };
        let (_, tail) = inner.split_at_mut(start);
        tail
    }

    /// Returns exactly `len` writable elements of `inner` starting at `pos`.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(len)` when fewer than `len` elements are available.
    #[inline]
    pub(super) fn get_slice_mut_checked(
        inner: &mut [MaybeUninit<u8>],
        pos: usize,
        len: usize,
    ) -> WriteResult<&mut [MaybeUninit<u8>]> {
        cur_slice_mut(inner, pos)
            .get_mut(..len)
            .ok_or_else(|| write_size_limit(len))
    }

    /// Copies `src` into `inner` at `*pos` and advances `*pos` past the copied bytes.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(src.len())` (leaving `*pos` unchanged) when `src`
    /// does not fit.
    pub(super) fn write(
        inner: &mut [MaybeUninit<u8>],
        pos: &mut usize,
        src: &[u8],
    ) -> WriteResult<()> {
        let count = src.len();
        let dst = get_slice_mut_checked(inner, *pos, count)?;
        // SAFETY: `dst` holds exactly `count` elements, `MaybeUninit<u8>` has the layout
        // of `u8`, and the shared `src` cannot overlap the exclusively borrowed `dst`.
        unsafe { ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) };
        // SAFETY: `dst` existed, so `*pos + count <= inner.len()` or `count == 0`.
        *pos = unsafe { pos.unchecked_add(count) };
        Ok(())
    }

    /// Reserves the next `n_bytes` elements of `inner` at `*pos`, advances `*pos`
    /// past them, and returns a trusted writer over exactly that window.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(n_bytes)` (leaving `*pos` unchanged) when the window
    /// does not fit.
    #[inline]
    pub(super) fn as_trusted_for<'a>(
        inner: &'a mut [MaybeUninit<u8>],
        pos: &mut usize,
        n_bytes: usize,
    ) -> WriteResult<TrustedSliceWriter<'a>> {
        let window = get_slice_mut_checked(inner, *pos, n_bytes)?;
        // SAFETY: `window` existed, so `*pos + n_bytes <= inner.len()` or `n_bytes == 0`.
        *pos = unsafe { pos.unchecked_add(n_bytes) };
        Ok(TrustedSliceWriter::new(window))
    }
}
impl Writer for Cursor<&mut [MaybeUninit<u8>]> {
    type Trusted<'b>
        = TrustedSliceWriter<'b>
    where
        Self: 'b;

    /// Writes `src` at the current offset, failing if it does not fit in the slice.
    #[inline]
    fn write(&mut self, src: &[u8]) -> WriteResult<()> {
        let Self { inner, pos } = self;
        uninit_slice::write(inner, pos, src)
    }

    /// Reserves the next `n_bytes` of the slice for a trusted writer and advances
    /// the offset past them.
    #[inline]
    unsafe fn as_trusted_for(&mut self, n_bytes: usize) -> WriteResult<Self::Trusted<'_>> {
        let Self { inner, pos } = self;
        uninit_slice::as_trusted_for(inner, pos, n_bytes)
    }
}
impl<const N: usize> Cursor<&mut MaybeUninit<[u8; N]>> {
    /// Views an uninitialized byte array as an array of individually uninitialized
    /// bytes, so it can be driven by the slice-based write helpers.
    #[inline(always)]
    pub(super) const fn transpose(inner: &mut MaybeUninit<[u8; N]>) -> &mut [MaybeUninit<u8>; N] {
        let whole: *mut MaybeUninit<[u8; N]> = inner;
        // SAFETY: `MaybeUninit<[u8; N]>` and `[MaybeUninit<u8>; N]` have identical size
        // and alignment, and any bit pattern (including uninitialized) is valid for both.
        // The result reborrows `inner`, inheriting its lifetime and exclusivity.
        unsafe { &mut *whole.cast::<[MaybeUninit<u8>; N]>() }
    }
}
impl<const N: usize> Writer for Cursor<&mut MaybeUninit<[u8; N]>> {
    type Trusted<'b>
        = TrustedSliceWriter<'b>
    where
        Self: 'b;

    /// Writes `src` at the current offset, failing if it does not fit in the array.
    #[inline]
    fn write(&mut self, src: &[u8]) -> WriteResult<()> {
        let bytes = Self::transpose(self.inner);
        uninit_slice::write(bytes, &mut self.pos, src)
    }

    /// Reserves the next `n_bytes` of the array for a trusted writer and advances
    /// the offset past them.
    #[inline]
    unsafe fn as_trusted_for(&mut self, n_bytes: usize) -> WriteResult<Self::Trusted<'_>> {
        let bytes = Self::transpose(self.inner);
        uninit_slice::as_trusted_for(bytes, &mut self.pos, n_bytes)
    }
}
#[cfg(feature = "alloc")]
mod vec {
    //! Shared helpers backing the `Writer` impls for cursors over `Vec<u8>`.
    use super::*;

    /// Prepares `inner` for writing `needed` bytes at `pos`.
    ///
    /// On success:
    /// - `inner.capacity() >= pos + needed`, and
    /// - `inner.len() >= pos`. Any gap between the old length and `pos` is filled
    ///   with `0`.
    ///
    /// The zero-fill mirrors `std::io::Cursor<Vec<u8>>`. Without it, a cursor
    /// positioned past the end would later `set_len` over never-initialized bytes,
    /// which is undefined behavior.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(needed)` if `pos + needed` overflows `usize`; `inner`
    /// is left untouched in that case.
    #[inline]
    pub(super) fn maybe_grow(inner: &mut Vec<u8>, pos: usize, needed: usize) -> WriteResult<()> {
        let Some(required) = pos.checked_add(needed) else {
            return Err(write_size_limit(needed));
        };
        if required > inner.capacity() {
            grow(inner, required);
        }
        if pos > inner.len() {
            // Capacity is already >= `required` >= `pos`, so this never reallocates.
            inner.resize(pos, 0);
        }

        #[cold]
        fn grow(inner: &mut Vec<u8>, required: usize) {
            // SAFETY: the caller checked `required > inner.capacity()`, and
            // `inner.capacity() >= inner.len()`, so this cannot underflow.
            let additional = unsafe { required.unchecked_sub(inner.len()) };
            inner.reserve(additional);
        }

        Ok(())
    }

    /// Advances `*pos` by `len`, extending `inner`'s length to the new position when
    /// it moves past the current end.
    ///
    /// # Safety
    ///
    /// - `*pos + len` must not overflow and must be `<= inner.capacity()`.
    /// - Every byte in `inner.len()..*pos + len` must already be initialized.
    pub(super) unsafe fn add_len(inner: &mut Vec<u8>, pos: &mut usize, len: usize) {
        // SAFETY: non-overflow is guaranteed by the caller.
        let next_pos = unsafe { pos.unchecked_add(len) };
        if next_pos > inner.len() {
            // SAFETY: the caller guarantees `next_pos <= capacity` and that every newly
            // exposed byte is initialized.
            unsafe { inner.set_len(next_pos) };
        }
        *pos = next_pos;
    }

    /// Copies `src` into `inner` at `*pos`, growing the vector as needed, and
    /// advances `*pos` past the copied bytes.
    ///
    /// Writing inside the existing contents overwrites in place. Writing past the end
    /// first zero-fills the gap between the old length and `*pos`.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(src.len())` if `*pos + src.len()` overflows `usize`.
    pub(super) fn write(inner: &mut Vec<u8>, pos: &mut usize, src: &[u8]) -> WriteResult<()> {
        maybe_grow(inner, *pos, src.len())?;
        // SAFETY: `maybe_grow` ensured `capacity >= *pos + src.len()`, so the destination
        // lies within the allocation; the shared `src` cannot overlap the exclusively
        // borrowed `inner`.
        unsafe { ptr::copy_nonoverlapping(src.as_ptr(), inner.as_mut_ptr().add(*pos), src.len()) };
        // SAFETY: `maybe_grow` made `inner.len() >= *pos` (zero-filling any gap) and
        // the copy initialized `*pos..*pos + src.len()`, so every byte up to the new
        // position is initialized and within capacity.
        unsafe { add_len(inner, pos, src.len()) };
        Ok(())
    }

    /// Prepares `inner` for an unchecked write of `n_bytes` at `*pos`. Returns a
    /// [`TrustedVecWriter`] over the vector's whole capacity.
    ///
    /// Bytes written through the returned writer are not reflected in `inner.len()`
    /// until [`finish`] is called. Any gap before `*pos` is zero-filled here, so
    /// committing the writes never exposes uninitialized memory.
    ///
    /// # Errors
    ///
    /// Returns `write_size_limit(n_bytes)` if `*pos + n_bytes` overflows `usize`.
    #[inline]
    pub(super) fn as_trusted_for<'a>(
        inner: &'a mut Vec<u8>,
        pos: &'a mut usize,
        n_bytes: usize,
    ) -> WriteResult<TrustedVecWriter<'a>> {
        maybe_grow(inner, *pos, n_bytes)?;
        let capacity = inner.capacity();
        // SAFETY: the pointer is valid for `capacity` elements of the vector's
        // allocation; `MaybeUninit<u8>` has the layout of `u8` and tolerates the
        // uninitialized spare capacity; `inner` stays exclusively borrowed for `'a`.
        let buf = unsafe {
            from_raw_parts_mut(inner.as_mut_ptr().cast::<MaybeUninit<u8>>(), capacity)
        };
        Ok(TrustedVecWriter::new(buf, pos))
    }

    /// Commits bytes written through a [`TrustedVecWriter`] by extending `inner`'s
    /// length to `*pos` when it is past the current end.
    ///
    /// NOTE(review): this trusts that every byte in `inner.len()..*pos` was written by a
    /// trusted writer. If the position is moved past the end with
    /// `set_position`/`new_at` and never written, that assumption is violated. Tracking
    /// an initialized high-water mark would close this remaining gap.
    #[inline]
    pub(super) fn finish(inner: &mut Vec<u8>, pos: &mut usize) {
        if *pos > inner.len() {
            debug_assert!(
                *pos <= inner.capacity(),
                "cursor position is beyond the vector's capacity"
            );
            // SAFETY: relies on the trusted-writer contract described above: the bytes
            // up to `*pos` were reserved and written through `as_trusted_for`.
            unsafe { inner.set_len(*pos) };
        }
    }
}
#[cfg(feature = "alloc")]
/// A bounds-check-free writer over a vector's full capacity, handed out by the
/// `Vec`-backed cursors' `as_trusted_for`.
///
/// It writes at `*pos` and advances that shared position; it never updates the
/// vector's length itself.
pub struct TrustedVecWriter<'a> {
    inner: &'a mut [MaybeUninit<u8>],
    pos: &'a mut usize,
}
#[cfg(feature = "alloc")]
impl<'a> TrustedVecWriter<'a> {
    /// Builds a trusted writer over `inner`, writing at (and advancing) `*pos`.
    ///
    /// NOTE(review): this constructor is safe and public, yet writes through the
    /// returned writer are not bounds-checked in release builds. The caller must ensure
    /// every write stays within `inner`; consider restricting its visibility.
    pub fn new(inner: &'a mut [MaybeUninit<u8>], pos: &'a mut usize) -> Self {
        TrustedVecWriter { inner, pos }
    }
}
#[cfg(feature = "alloc")]
impl<'a> Writer for TrustedVecWriter<'a> {
    type Trusted<'b>
        = TrustedVecWriter<'b>
    where
        Self: 'b;

    /// Copies `src` into the buffer at the current position and advances the position.
    ///
    /// There is no release-mode bounds check: soundness relies on the trusted-writer
    /// contract under which this writer was created (the caller reserved room for every
    /// byte it writes). The `debug_assert!` surfaces contract violations in debug
    /// builds instead of silently writing out of bounds.
    fn write(&mut self, src: &[u8]) -> WriteResult<()> {
        debug_assert!(
            self.inner
                .len()
                .checked_sub(*self.pos)
                .is_some_and(|room| room >= src.len()),
            "TrustedVecWriter: write exceeds the reserved buffer"
        );
        // SAFETY: per the trusted-writer contract, `*self.pos + src.len()` does not
        // exceed `self.inner.len()`, so the destination range lies inside `self.inner`;
        // `MaybeUninit<u8>` has the layout of `u8`, and the shared `src` cannot alias
        // the exclusively borrowed buffer.
        unsafe {
            ptr::copy_nonoverlapping(
                src.as_ptr().cast(),
                self.inner.as_mut_ptr().add(*self.pos),
                src.len(),
            )
        };
        // SAFETY: bounded by `self.inner.len()` under the same contract, so no overflow.
        *self.pos = unsafe { self.pos.unchecked_add(src.len()) };
        Ok(())
    }

    /// Returns a writer sharing this writer's buffer and position.
    ///
    /// `_n_bytes` is not checked here. This presumably relies on the reservation made
    /// when this writer was created also covering the nested writes — NOTE(review):
    /// confirm against the `Writer::as_trusted_for` contract.
    #[inline]
    unsafe fn as_trusted_for(&mut self, _n_bytes: usize) -> WriteResult<Self::Trusted<'_>> {
        Ok(TrustedVecWriter::new(self.inner, self.pos))
    }
}
#[cfg(feature = "alloc")]
impl Writer for Cursor<&mut Vec<u8>> {
    type Trusted<'b>
        = TrustedVecWriter<'b>
    where
        Self: 'b;

    /// Writes `src` at the current offset, growing the borrowed vector as needed.
    #[inline]
    fn write(&mut self, src: &[u8]) -> WriteResult<()> {
        let Self { inner, pos } = self;
        vec::write(inner, pos, src)
    }

    /// Commits any bytes written through a trusted writer into the vector's length.
    #[inline]
    fn finish(&mut self) -> WriteResult<()> {
        let Self { inner, pos } = self;
        vec::finish(inner, pos);
        Ok(())
    }

    /// Reserves room for `n_bytes` at the current offset and returns a trusted writer
    /// over the vector's capacity.
    #[inline]
    unsafe fn as_trusted_for(&mut self, n_bytes: usize) -> WriteResult<Self::Trusted<'_>> {
        let Self { inner, pos } = self;
        vec::as_trusted_for(inner, pos, n_bytes)
    }
}
#[cfg(feature = "alloc")]
impl Writer for Cursor<Vec<u8>> {
    type Trusted<'b>
        = TrustedVecWriter<'b>
    where
        Self: 'b;

    /// Writes `src` at the current offset, growing the owned vector as needed.
    #[inline]
    fn write(&mut self, src: &[u8]) -> WriteResult<()> {
        let Self { inner, pos } = self;
        vec::write(inner, pos, src)
    }

    /// Commits any bytes written through a trusted writer into the vector's length.
    #[inline]
    fn finish(&mut self) -> WriteResult<()> {
        let Self { inner, pos } = self;
        vec::finish(inner, pos);
        Ok(())
    }

    /// Reserves room for `n_bytes` at the current offset and returns a trusted writer
    /// over the vector's capacity.
    #[inline]
    unsafe fn as_trusted_for(&mut self, n_bytes: usize) -> WriteResult<Self::Trusted<'_>> {
        let Self { inner, pos } = self;
        vec::as_trusted_for(inner, pos, n_bytes)
    }
}
#[cfg(all(test, feature = "alloc"))]
mod tests {
#![allow(clippy::arithmetic_side_effects)]
use {super::*, crate::proptest_config::proptest_cfg, alloc::vec, proptest::prelude::*};
proptest! {
#![proptest_config(proptest_cfg())]
#[test]
fn cursor_read_no_panic_no_ub_check(bytes in any::<Vec<u8>>(), pos in any::<usize>()) {
let mut cursor = Cursor::new_at(&bytes, pos);
let buf = cursor.fill_buf(bytes.len()).unwrap();
if pos > bytes.len() {
prop_assert_eq!(buf, &[]);
} else {
prop_assert_eq!(buf, &bytes[pos..]);
}
let res = cursor.fill_exact(bytes.len());
if pos > bytes.len() && !bytes.is_empty() {
prop_assert!(matches!(res, Err(ReadError::ReadSizeLimit(x)) if x == bytes.len()));
} else {
prop_assert_eq!(res.unwrap(), &bytes[pos.min(bytes.len())..]);
}
}
#[test]
fn cursor_zero_len_ops_ok(bytes in any::<Vec<u8>>(), pos in any::<usize>()) {
let mut cursor = Cursor::new_at(&bytes, pos);
let start = cursor.position();
let fe = cursor.fill_exact(0).unwrap();
prop_assert_eq!(fe.len(), 0);
prop_assert_eq!(cursor.position(), start);
prop_assert!(cursor.consume(0).is_ok());
prop_assert_eq!(cursor.position(), start);
let start2 = cursor.position();
let mut trusted = unsafe { <Cursor<_> as Reader>::as_trusted_for(&mut cursor, 0) }.unwrap();
prop_assert_eq!(trusted.fill_buf(1).unwrap(), &[]);
prop_assert_eq!(trusted.fill_exact(0).unwrap().len(), 0);
prop_assert_eq!(cursor.position(), start2);
}
#[test]
fn cursor_as_trusted_for_remaining_advances_to_len(bytes in any::<Vec<u8>>(), pos in any::<usize>()) {
let len = bytes.len();
let pos = if len == 0 { 0 } else { pos % (len + 1) };
let mut cursor = Cursor::new_at(&bytes, pos);
let remaining = len.saturating_sub(pos);
let _trusted = unsafe { <Cursor<_> as Reader>::as_trusted_for(&mut cursor, remaining) }.unwrap();
prop_assert_eq!(cursor.position(), len);
}
#[test]
fn cursor_extremal_pos_max_zero_len_ok(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new_at(&bytes, usize::MAX);
prop_assert_eq!(cursor.fill_buf(1).unwrap(), &[]);
prop_assert!(matches!(cursor.peek(), Err(ReadError::ReadSizeLimit(1))));
let start = cursor.position();
prop_assert!(cursor.fill_exact(0).is_ok());
prop_assert!(cursor.consume(0).is_ok());
let _trusted = unsafe { <Cursor<_> as Reader>::as_trusted_for(&mut cursor, 0) }.unwrap();
prop_assert_eq!(cursor.position(), start);
}
#[test]
fn uninit_slice_write_no_panic_no_ub_check(bytes in any::<Vec<u8>>(), pos in any::<usize>()) {
let mut output: Vec<u8> = Vec::with_capacity(bytes.len());
let mut cursor = Cursor::new_at(output.spare_capacity_mut(), pos);
let res = cursor.write(&bytes);
if pos > bytes.len() && !bytes.is_empty() {
prop_assert!(matches!(res, Err(WriteError::WriteSizeLimit(x)) if x == bytes.len()));
} else if pos == 0 {
prop_assert_eq!(output, bytes);
}
}
#[test]
fn vec_write_no_panic_no_ub_check(bytes in any::<Vec<u8>>(), pos in any::<u16>()) {
let pos = pos as usize;
let mut output: Vec<u8> = Vec::new();
let mut cursor = Cursor::new_at(&mut output, pos);
cursor.write(&bytes).unwrap();
prop_assert_eq!(&output[pos..], &bytes);
}
#[test]
fn cursor_write_vec_new(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(Vec::new());
cursor.write(&bytes).unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
let mut vec = Vec::with_capacity(bytes.len());
let mut cursor = Cursor::new(vec.spare_capacity_mut());
cursor.write(&bytes).unwrap();
unsafe { vec.set_len(bytes.len()) };
prop_assert_eq!(&vec, &bytes);
}
#[test]
fn cursor_write_existing_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![0; bytes.len()]);
cursor.write(&bytes).unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
}
#[test]
fn cursor_write_existing_grow_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![0; bytes.len() / 2]);
cursor.write(&bytes).unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
}
#[test]
fn cursor_write_partial_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![1; bytes.len()]);
let half = bytes.len() - bytes.len() / 2;
cursor.write(&bytes[..half]).unwrap();
prop_assert_eq!(&cursor.inner[..half], &bytes[..half]);
prop_assert_eq!(&cursor.inner[half..], &vec![1; bytes.len() - half]);
cursor.write(&bytes[half..]).unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
}
#[test]
fn cursor_write_trusted_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![1; bytes.len()]);
let half = bytes.len() - bytes.len() / 2;
cursor.write(&bytes[..half]).unwrap();
unsafe { <Cursor<_> as Writer>::as_trusted_for(&mut cursor, bytes.len() - half) }
.unwrap()
.write(&bytes[half..])
.unwrap();
cursor.finish().unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
}
#[test]
fn cursor_write_trusted_grow_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![1; bytes.len() / 2]);
let half = bytes.len() - bytes.len() / 2;
cursor.write(&bytes[..half]).unwrap();
unsafe { <Cursor<_> as Writer>::as_trusted_for(&mut cursor, bytes.len() - half) }
.unwrap()
.write(&bytes[half..])
.unwrap();
cursor.finish().unwrap();
prop_assert_eq!(&cursor.inner, &bytes);
}
#[test]
fn cursor_write_trusted_oversized_vec(bytes in any::<Vec<u8>>()) {
let mut cursor = Cursor::new(vec![1; bytes.len() * 2]);
let half = bytes.len() - bytes.len() / 2;
cursor.write(&bytes[..half]).unwrap();
unsafe { <Cursor<_> as Writer>::as_trusted_for(&mut cursor, bytes.len() - half) }
.unwrap()
.write(&bytes[half..])
.unwrap();
cursor.finish().unwrap();
prop_assert_eq!(&cursor.inner[..bytes.len()], &bytes);
prop_assert_eq!(&cursor.inner[bytes.len()..], &vec![1; bytes.len()]);
}
}
}