use std::mem::{size_of, MaybeUninit};
use bytes::BytesMut;
use zerocopy::{AsBytes, FromBytes};
/// Read-side view of a byte buffer that can be handed to an I/O operation.
///
/// Implementors expose a raw pointer to their storage plus two lengths: how
/// many bytes are initialized and how many bytes the storage can hold.
///
/// # Safety
///
/// NOTE(review): the trait is `unsafe`, but its exact contract is not written
/// down here. Presumably the pointer returned by `stable_ptr` must remain
/// valid when the value itself is moved, and `bytes_init() <= bytes_total()`
/// must hold — confirm against the code that consumes these buffers.
pub unsafe trait IoBuf: Unpin + 'static {
/// Returns a raw pointer to the first byte of the buffer.
fn stable_ptr(&self) -> *const u8;
/// Returns the number of bytes, starting at `stable_ptr`, that are initialized.
fn bytes_init(&self) -> usize;
/// Returns the total number of bytes of storage behind `stable_ptr`.
fn bytes_total(&self) -> usize;
}
/// Write-side extension of [`IoBuf`]: a buffer an I/O operation can fill.
///
/// # Safety
///
/// NOTE(review): the trait is `unsafe`, but its exact contract is not written
/// down here. Presumably the pointer returned by `stable_mut_ptr` must remain
/// valid for writes of up to `bytes_total()` bytes even if the value is moved
/// — confirm against the code that consumes these buffers.
pub unsafe trait IoBufMut: IoBuf {
/// Returns a raw mutable pointer to the first byte of the buffer.
fn stable_mut_ptr(&mut self) -> *mut u8;
/// Records how many bytes of the buffer are initialized.
///
/// # Safety
///
/// NOTE(review): presumably the caller must guarantee that the first `pos`
/// bytes behind `stable_mut_ptr` really have been initialized and that
/// `pos` does not exceed `bytes_total()` — confirm.
unsafe fn set_init(&mut self, pos: usize);
}
/// A boxed buffer is writable whenever the buffer it owns is; every call is
/// forwarded to the boxed value.
unsafe impl<T> IoBufMut for Box<T>
where
    T: IoBufMut,
{
    /// Forwards to the boxed value's `stable_mut_ptr`.
    fn stable_mut_ptr(&mut self) -> *mut u8 {
        T::stable_mut_ptr(&mut **self)
    }

    /// Forwards to the boxed value's `set_init`.
    ///
    /// # Safety
    ///
    /// Same requirements as `T::set_init`; the arguments are passed through
    /// unchanged.
    unsafe fn set_init(&mut self, pos: usize) {
        // SAFETY: the caller upholds `T::set_init`'s contract (see above).
        unsafe { T::set_init(&mut **self, pos) }
    }
}
unsafe impl IoBuf for BytesMut {
fn stable_ptr(&self) -> *const u8 {
self.as_ptr()
}
fn bytes_init(&self) -> usize {
self.len()
}
fn bytes_total(&self) -> usize {
self.capacity()
}
}
unsafe impl IoBufMut for BytesMut {
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.as_mut_ptr()
}
unsafe fn set_init(&mut self, pos: usize) {
unsafe { self.set_len(pos) }
}
}
/// A boxed buffer is readable whenever the buffer it owns is; every call is
/// forwarded to the boxed value.
unsafe impl<T> IoBuf for Box<T>
where
    T: IoBuf,
{
    /// Forwards to the boxed value's `stable_ptr`.
    fn stable_ptr(&self) -> *const u8 {
        T::stable_ptr(&**self)
    }

    /// Forwards to the boxed value's `bytes_init`.
    fn bytes_init(&self) -> usize {
        T::bytes_init(&**self)
    }

    /// Forwards to the boxed value's `bytes_total`.
    fn bytes_total(&self) -> usize {
        T::bytes_total(&**self)
    }
}
/// Inline, possibly-uninitialized storage for a `T`, paired with a count of
/// how many of its bytes have been marked as initialized.
pub struct ZeroCopyBuf<T> {
/// Number of bytes of `inner` marked initialized; the value is treated as
/// complete only when this equals `size_of::<T>()`.
init: usize,
/// Storage for the value; only read as a `T` once `init` says it is complete.
inner: MaybeUninit<T>,
}
/// A heap-allocated `T` paired with a count of how many of its bytes have
/// been marked as initialized.
pub struct ZeroCopyBoxIoBuf<T> {
/// The boxed value whose bytes are exposed as the buffer.
inner: Box<T>,
/// Number of bytes of `inner` marked initialized; the value is treated as
/// complete only when this equals `size_of::<T>()`.
init: usize,
}
impl<T> ZeroCopyBoxIoBuf<T> {
    /// Wraps `inner` and marks all `size_of::<T>()` bytes as initialized.
    pub fn new(inner: Box<T>) -> Self {
        let init = size_of::<T>();
        Self { inner, init }
    }

    /// Wraps `inner` but marks zero bytes as initialized, so the value has to
    /// be filled completely before `into_inner` will hand it back.
    pub fn new_uninit(inner: Box<T>) -> Self {
        Self { inner, init: 0 }
    }

    /// True when the initialized-byte count covers the whole `T`.
    fn is_init(&self) -> bool {
        size_of::<T>() == self.init
    }

    /// Unwraps the boxed value.
    ///
    /// # Panics
    ///
    /// Panics if fewer or more than `size_of::<T>()` bytes are marked
    /// initialized.
    pub fn into_inner(self) -> Box<T> {
        assert!(self.is_init());
        let Self { inner, .. } = self;
        inner
    }
}
/// Exposes the boxed value's bytes as a readable I/O buffer.
unsafe impl<T> IoBuf for ZeroCopyBoxIoBuf<T>
where
    T: AsBytes + Unpin + 'static,
{
    /// Pointer to the first byte of the boxed value.
    fn stable_ptr(&self) -> *const u8 {
        let value: &T = &self.inner;
        value.as_bytes().as_ptr()
    }

    /// The number of bytes currently marked as initialized.
    fn bytes_init(&self) -> usize {
        self.init
    }

    /// Always the size of `T`.
    fn bytes_total(&self) -> usize {
        size_of::<T>()
    }
}
unsafe impl<T: AsBytes + FromBytes + Unpin + 'static> IoBufMut for ZeroCopyBoxIoBuf<T> {
fn stable_mut_ptr(&mut self) -> *mut u8 {
T::as_bytes_mut(&mut self.inner).as_mut_ptr()
}
unsafe fn set_init(&mut self, pos: usize) {
self.init = pos;
}
}
impl<T> ZeroCopyBuf<T> {
    /// Wraps an existing value and marks all `size_of::<T>()` bytes as
    /// initialized.
    pub fn new_init(inner: T) -> Self {
        let inner = MaybeUninit::new(inner);
        Self {
            init: size_of::<T>(),
            inner,
        }
    }

    /// Creates empty storage with zero bytes marked as initialized.
    pub fn new_uninit() -> Self {
        let inner = MaybeUninit::uninit();
        Self { init: 0, inner }
    }

    /// Pairs this buffer with `f`, a function that picks the byte slice of
    /// the buffer to expose, and returns the combination as a [`MapSlice`].
    pub fn map_slice<F>(self, f: F) -> MapSlice<T, F>
    where
        for<'a> F: Fn(&'a Self) -> &'a [u8] + Unpin + 'static,
    {
        MapSlice { f, inner: self }
    }

    /// True when exactly `size_of::<T>()` bytes are marked as initialized.
    #[inline]
    pub fn is_init(&self) -> bool {
        size_of::<T>() == self.init
    }

    /// Borrows the contained value.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not fully initialized.
    pub fn get_ref(&self) -> &T {
        assert!(self.is_init());
        // SAFETY: the assertion above guarantees every byte of `T` has been
        // marked initialized. NOTE(review): this also relies on whoever
        // filled the buffer having written a valid `T` — confirm.
        unsafe { self.inner.assume_init_ref() }
    }

    /// Mutably borrows the contained value.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not fully initialized.
    pub fn get_mut(&mut self) -> &mut T {
        assert!(self.is_init());
        // SAFETY: see `get_ref`; the same initialization check applies.
        unsafe { self.inner.assume_init_mut() }
    }

    /// Consumes the buffer and returns the contained value.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not fully initialized.
    pub fn into_inner(self) -> T {
        assert!(self.is_init());
        // SAFETY: see `get_ref`; the same initialization check applies.
        unsafe { self.inner.assume_init() }
    }

    /// Marks the buffer as holding zero initialized bytes. The stored bytes
    /// themselves are left untouched.
    pub fn deinit(&mut self) {
        self.init = 0;
    }
}
/// A `ZeroCopyBuf<T>` bundled with a function that selects which bytes of it
/// are exposed.
pub struct MapSlice<T, F> {
/// The wrapped buffer.
inner: ZeroCopyBuf<T>,
/// Function applied to `inner` to obtain the exposed byte slice.
f: F,
}
impl<T, F> MapSlice<T, F> {
    /// Discards the mapping function and gives back the wrapped buffer.
    pub(crate) fn into_inner(self) -> ZeroCopyBuf<T> {
        let Self { inner, .. } = self;
        inner
    }
}
/// Exposes the slice chosen by the mapping function as a readable I/O buffer.
///
/// The mapping function is re-evaluated on every call. Initialized and total
/// length are both the length of the slice it returns.
unsafe impl<T, F> IoBuf for MapSlice<T, F>
where
    for<'a> F: Fn(&'a ZeroCopyBuf<T>) -> &'a [u8] + Unpin + 'static,
    T: Unpin + 'static + AsBytes,
{
    /// Pointer to the start of the mapped slice.
    fn stable_ptr(&self) -> *const u8 {
        let view: &[u8] = (self.f)(&self.inner);
        view.as_ptr()
    }

    /// Length of the mapped slice.
    fn bytes_init(&self) -> usize {
        let view: &[u8] = (self.f)(&self.inner);
        view.len()
    }

    /// Length of the mapped slice (same as `bytes_init`).
    fn bytes_total(&self) -> usize {
        let view: &[u8] = (self.f)(&self.inner);
        view.len()
    }
}
/// Exposes the inline storage as a readable I/O buffer.
unsafe impl<T> IoBuf for ZeroCopyBuf<T>
where
    T: AsBytes + Unpin + 'static,
{
    /// Pointer to the first byte of the storage, whether or not it has been
    /// initialized yet.
    fn stable_ptr(&self) -> *const u8 {
        self.inner.as_ptr().cast::<u8>()
    }

    /// The number of bytes currently marked as initialized.
    fn bytes_init(&self) -> usize {
        self.init
    }

    /// Always the size of `T`.
    fn bytes_total(&self) -> usize {
        size_of::<T>()
    }
}
/// Exposes the inline storage as a writable I/O buffer.
///
/// NOTE(review): `ZeroCopyBoxIoBuf` additionally requires `FromBytes` for its
/// `IoBufMut` impl, while this one only requires `AsBytes`. If arbitrary bytes
/// can be written here and later read back through `get_ref`, the missing
/// bound may matter — confirm whether it was left out on purpose.
unsafe impl<T> IoBufMut for ZeroCopyBuf<T>
where
    T: AsBytes + Unpin + 'static,
{
    /// Mutable pointer to the first byte of the storage.
    fn stable_mut_ptr(&mut self) -> *mut u8 {
        self.inner.as_mut_ptr().cast::<u8>()
    }

    /// Records that `pos` bytes of the storage are initialized.
    ///
    /// # Panics
    ///
    /// Panics if `pos` is larger than `size_of::<T>()`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the caller must guarantee that the first
    /// `pos` bytes really have been written — confirm.
    unsafe fn set_init(&mut self, pos: usize) {
        assert!(pos <= size_of::<T>());
        self.init = pos;
    }
}
/// `Vec<u8>` as a writable I/O buffer.
unsafe impl IoBufMut for Vec<u8> {
    /// Mutable pointer to the vector's storage.
    fn stable_mut_ptr(&mut self) -> *mut u8 {
        Vec::as_mut_ptr(self)
    }

    /// Grows the vector's length to `init_len`. A value that is not larger
    /// than the current length leaves the vector unchanged.
    ///
    /// # Safety
    ///
    /// When the length grows, `Vec::set_len`'s requirements apply to
    /// `init_len`.
    unsafe fn set_init(&mut self, init_len: usize) {
        if init_len > self.len() {
            // SAFETY: the caller upholds `Vec::set_len`'s contract.
            unsafe { self.set_len(init_len) };
        }
    }
}
/// `Vec<u8>` as a readable I/O buffer: its length counts as initialized and
/// its capacity as the total size.
unsafe impl IoBuf for Vec<u8> {
    /// Pointer to the vector's storage.
    fn stable_ptr(&self) -> *const u8 {
        Vec::as_ptr(self)
    }

    /// The vector's length.
    fn bytes_init(&self) -> usize {
        Vec::len(self)
    }

    /// The vector's capacity.
    fn bytes_total(&self) -> usize {
        Vec::capacity(self)
    }
}