use std::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use std::cell::Cell;
use std::marker::PhantomData;
use std::ptr::NonNull;
/// Header stored at the start of every buffer allocation; the payload bytes
/// follow it directly, `HEADER_SIZE` bytes after the base pointer.
///
/// `#[repr(C)]` pins the field order so the header's size and alignment are
/// predictable.
#[repr(C)]
struct ByteBufferHeader {
    /// Number of live `ByteBuffer`/`ByteBufferMut` handles; `Cell` lets shared
    /// handles update it through `&self`.
    refcount: Cell<u32>,
    /// Payload length in bytes, not counting the header itself.
    capacity: u32,
}

/// Offset of the payload from the start of the allocation.
const HEADER_SIZE: usize = Layout::new::<ByteBufferHeader>().size();
/// Alignment used for every allocation (that of the header).
const HEADER_ALIGN: usize = Layout::new::<ByteBufferHeader>().align();

/// Computes the allocation layout for a header followed by `payload` bytes.
///
/// # Panics
///
/// Panics with `"ByteBuffer: layout overflow"` if `HEADER_SIZE + payload` overflows
/// `usize`, or with `"ByteBuffer: layout invariant"` if `Layout::from_size_align`
/// rejects the resulting size.
fn layout_for(payload: usize) -> Layout {
    let total = match payload.checked_add(HEADER_SIZE) {
        Some(total) => total,
        None => panic!("ByteBuffer: layout overflow"),
    };
    Layout::from_size_align(total, HEADER_ALIGN).expect("ByteBuffer: layout invariant")
}
/// Allocates one block holding a header followed by `cap` payload bytes.
///
/// The header is written with a refcount of 1 and `capacity == cap`; the payload
/// bytes are left uninitialised. The returned pointer carries that single reference.
///
/// # Panics
///
/// Panics if `cap` does not fit in the header's `u32` capacity field, or if
/// `layout_for(cap)` panics. Allocation failure is reported through
/// [`handle_alloc_error`].
///
/// # Safety
///
/// NOTE(review): no caller precondition is visible here; the body only performs a
/// raw allocation and a header write. The returned reference should eventually be
/// released with `refcount_dec`, otherwise the block leaks.
unsafe fn alloc_header(cap: usize) -> NonNull<ByteBufferHeader> {
    if cap > u32::MAX as usize {
        panic!("ByteBuffer: capacity too large ({cap}, max {})", u32::MAX);
    }
    let layout = layout_for(cap);
    // SAFETY: `layout` is never zero-sized because it always includes the header.
    let block = match NonNull::new(unsafe { alloc(layout) }) {
        Some(raw) => raw.cast::<ByteBufferHeader>(),
        None => handle_alloc_error(layout),
    };
    let header = ByteBufferHeader {
        refcount: Cell::new(1),
        // Lossless: the range check above guarantees `cap <= u32::MAX`.
        capacity: cap as u32,
    };
    // SAFETY: `block` is a fresh allocation with the header's size and alignment,
    // so it is valid for a single `ByteBufferHeader` write.
    unsafe { block.as_ptr().write(header) };
    block
}
/// Frees the allocation behind `ptr`, header and payload together.
///
/// # Safety
///
/// `ptr` must have been returned by `alloc_header` and not freed yet, and no handle
/// may use it afterwards. `refcount_dec` calls this once the count reaches zero.
unsafe fn dealloc_buffer(ptr: NonNull<ByteBufferHeader>) {
    let header = ptr.as_ptr();
    // SAFETY: the caller guarantees `header` points at a live, initialised header
    // at the base of an allocation created by `alloc_header`.
    unsafe {
        // Rebuild the exact layout `alloc_header` used, from the stored capacity.
        let layout = layout_for((*header).capacity as usize);
        header.drop_in_place();
        dealloc(header.cast::<u8>(), layout);
    }
}
/// Adds one reference to the buffer behind `ptr`.
///
/// # Panics
///
/// Panics with `"ByteBuffer: refcount overflow"` if the count is already
/// `u32::MAX`; the stored count is left unchanged in that case.
///
/// # Safety
///
/// `ptr` must point at a live header produced by `alloc_header`.
unsafe fn refcount_inc(ptr: NonNull<ByteBufferHeader>) {
    // SAFETY: the caller guarantees the header is live; `Cell` permits the update
    // through a shared reference.
    let count = unsafe { &ptr.as_ref().refcount };
    let current = count.get();
    let bumped = current
        .checked_add(1)
        .expect("ByteBuffer: refcount overflow");
    count.set(bumped);
}
/// Drops one reference to the buffer behind `ptr`, freeing it when none remain.
///
/// # Safety
///
/// `ptr` must point at a live header produced by `alloc_header`, and the caller
/// must own one of the counted references; the caller must not use `ptr` after
/// this call.
unsafe fn refcount_dec(ptr: NonNull<ByteBufferHeader>) {
    // SAFETY: the caller guarantees the header is live while it still owns a reference.
    let count = unsafe { &ptr.as_ref().refcount };
    let before = count.get();
    debug_assert!(before > 0, "ByteBuffer: drop with zero refcount");
    match before - 1 {
        // That was the final reference: release the whole block.
        // SAFETY: no other handle refers to the allocation any more.
        0 => unsafe { dealloc_buffer(ptr) },
        remaining => count.set(remaining),
    }
}
/// Writable handle to a refcounted byte buffer.
///
/// The buffer is one heap block: a `ByteBufferHeader` (refcount and capacity)
/// followed by `capacity()` payload bytes. Read-only `ByteBuffer` views can be handed
/// out with [`share`](Self::share) while this handle keeps write access. Call
/// [`ensure_unique_for_mutate`](Self::ensure_unique_for_mutate) before writing so that
/// outstanding views keep seeing the bytes they were created with.
///
/// The `*mut ()` marker makes the type `!Send` and `!Sync`, as required by the
/// non-atomic `Cell` refcount.
pub struct ByteBufferMut {
    ptr: NonNull<ByteBufferHeader>,
    _marker: PhantomData<*mut ()>,
}

impl ByteBufferMut {
    /// Allocates a buffer with room for `cap` payload bytes and a refcount of 1.
    ///
    /// The payload is **not** initialised; write it through
    /// [`data_mut_ptr`](Self::data_mut_ptr) before reading it.
    ///
    /// # Panics
    ///
    /// Panics if `cap > u32::MAX` or if the total allocation size overflows.
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            // SAFETY: `alloc_header` returns a fresh header owning one reference,
            // which this handle takes over.
            ptr: unsafe { alloc_header(cap) },
            _marker: PhantomData,
        }
    }

    /// Payload capacity in bytes, as recorded in the header.
    #[inline]
    pub fn capacity(&self) -> u32 {
        // SAFETY: the header is live while this handle holds a reference.
        unsafe { self.ptr.as_ref().capacity }
    }

    /// Pointer to the first payload byte, immediately after the header.
    #[inline]
    pub fn data_ptr(&self) -> *const u8 {
        // SAFETY: every allocation is at least `HEADER_SIZE` bytes long.
        unsafe { (self.ptr.as_ptr() as *const u8).add(HEADER_SIZE) }
    }

    /// Mutable pointer to the first payload byte.
    ///
    /// Writes through this pointer are also visible to any `ByteBuffer` obtained from
    /// [`share`](Self::share), unless
    /// [`ensure_unique_for_mutate`](Self::ensure_unique_for_mutate) has since moved
    /// this handle to a private copy.
    #[inline]
    pub fn data_mut_ptr(&mut self) -> *mut u8 {
        // SAFETY: every allocation is at least `HEADER_SIZE` bytes long.
        unsafe { (self.ptr.as_ptr() as *mut u8).add(HEADER_SIZE) }
    }

    /// Returns a read-only view of the same allocation, incrementing the refcount.
    ///
    /// # Panics
    ///
    /// Panics if the refcount would overflow `u32::MAX`.
    pub fn share(&self) -> ByteBuffer {
        // SAFETY: the header is live while this handle holds a reference.
        unsafe { refcount_inc(self.ptr) };
        ByteBuffer {
            ptr: self.ptr,
            _marker: PhantomData,
        }
    }

    /// Converts this handle into a read-only `ByteBuffer` without changing the refcount.
    pub fn freeze(self) -> ByteBuffer {
        let ptr = self.ptr;
        // This handle's reference moves to the returned `ByteBuffer`, so skip our
        // `Drop` to avoid decrementing the count.
        std::mem::forget(self);
        ByteBuffer {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Current number of handles, this one included, referencing the allocation.
    #[inline]
    pub fn refcount(&self) -> u32 {
        // SAFETY: the header is live while this handle holds a reference.
        unsafe { self.ptr.as_ref().refcount.get() }
    }

    /// Makes this handle the sole reference to its allocation before a write.
    ///
    /// When the refcount is 1 this does nothing. Otherwise it allocates a new block
    /// with the same capacity, copies the first `keep` payload bytes into it, and
    /// switches this handle to it. The old allocation stays alive for the remaining
    /// handles. Payload bytes from index `keep` onward in the new allocation are
    /// uninitialised.
    ///
    /// # Panics
    ///
    /// Panics if a copy is required and `keep > self.capacity()`.
    #[inline]
    pub fn ensure_unique_for_mutate(&mut self, keep: usize) {
        if self.refcount() == 1 {
            return;
        }
        self.cow_swap(keep);
    }

    /// Slow path of `ensure_unique_for_mutate`: copy-on-write into a private allocation.
    #[cold]
    fn cow_swap(&mut self, keep: usize) {
        let cap = self.capacity() as usize;
        // This must be a real `assert!`, not `debug_assert!`: `keep` comes from a safe
        // public method, and in release builds an oversized value would make the copy
        // below read and write past the end of both allocations (UB).
        // The check runs before allocating, so a panic cannot leak the new block.
        assert!(
            keep <= cap,
            "ByteBufferMut::cow_swap: keep > cap (keep {}, cap {})",
            keep,
            cap
        );
        // SAFETY: `cap` was read from an existing header, so it fits in `u32`.
        let new_ptr = unsafe { alloc_header(cap) };
        if keep > 0 {
            // SAFETY: both payloads are `cap >= keep` bytes long and live in distinct
            // allocations, so both ranges are in bounds and cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    (self.ptr.as_ptr() as *const u8).add(HEADER_SIZE),
                    (new_ptr.as_ptr() as *mut u8).add(HEADER_SIZE),
                    keep,
                );
            }
        }
        // Give up our reference to the shared block. The refcount was above 1 (see
        // `ensure_unique_for_mutate`), so other handles keep it alive.
        // SAFETY: `self.ptr` is live and this handle owns one of its references.
        unsafe { refcount_dec(self.ptr) };
        self.ptr = new_ptr;
    }
}

impl Drop for ByteBufferMut {
    /// Releases this handle's reference, freeing the allocation if it was the last one.
    fn drop(&mut self) {
        // SAFETY: `self.ptr` is live and this handle owns one of its references.
        unsafe { refcount_dec(self.ptr) };
    }
}
/// Shared, read-only handle to a refcounted byte buffer.
///
/// Cloning bumps the non-atomic refcount stored in the header instead of copying
/// bytes. The allocation is freed when the last handle referencing it is dropped,
/// including any `ByteBufferMut` sharing it. The `*mut ()` marker keeps the type
/// `!Send`/`!Sync`.
pub struct ByteBuffer {
    ptr: NonNull<ByteBufferHeader>,
    _marker: PhantomData<*mut ()>,
}

impl ByteBuffer {
    /// Builds a buffer whose capacity is `slice.len()` and whose payload is a copy of `slice`.
    ///
    /// # Panics
    ///
    /// Panics if `slice.len()` exceeds `u32::MAX`.
    pub fn from_slice(slice: &[u8]) -> Self {
        let len = slice.len();
        let mut staging = ByteBufferMut::with_capacity(len);
        if len != 0 {
            let dst = staging.data_mut_ptr();
            // SAFETY: `dst` points at a fresh payload of exactly `len` bytes, which
            // cannot overlap the borrowed `slice`.
            unsafe { dst.copy_from_nonoverlapping(slice.as_ptr(), len) };
        }
        staging.freeze()
    }

    /// Pointer to the first payload byte, just past the header.
    #[inline]
    pub fn data_ptr(&self) -> *const u8 {
        let base = self.ptr.cast::<u8>().as_ptr() as *const u8;
        // SAFETY: every allocation extends at least `HEADER_SIZE` bytes past its base.
        unsafe { base.add(HEADER_SIZE) }
    }

    /// Payload capacity in bytes, as stored in the header.
    #[inline]
    pub fn capacity(&self) -> u32 {
        // SAFETY: the header stays live while this handle holds a reference.
        let header = unsafe { self.ptr.as_ref() };
        header.capacity
    }

    /// Current number of handles referencing this allocation.
    #[inline]
    pub fn refcount(&self) -> u32 {
        // SAFETY: the header stays live while this handle holds a reference.
        let header = unsafe { self.ptr.as_ref() };
        header.refcount.get()
    }
}

impl Clone for ByteBuffer {
    /// Returns another handle to the same allocation and increments the refcount.
    ///
    /// # Panics
    ///
    /// Panics if the refcount would overflow `u32::MAX`.
    #[inline]
    fn clone(&self) -> Self {
        let ptr = self.ptr;
        // SAFETY: `ptr` is live because `self` holds a reference to it.
        unsafe { refcount_inc(ptr) };
        ByteBuffer {
            ptr,
            _marker: PhantomData,
        }
    }
}

impl Drop for ByteBuffer {
    /// Releases this handle's reference, freeing the allocation if it was the last one.
    fn drop(&mut self) {
        // SAFETY: `self.ptr` is live and this handle owns one of its references.
        unsafe { refcount_dec(self.ptr) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies `len` bytes starting at `ptr` into an owned vector for comparison.
    ///
    /// # Safety
    ///
    /// `ptr` must point at `len` initialised bytes of a live buffer.
    unsafe fn snapshot(ptr: *const u8, len: usize) -> Vec<u8> {
        unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec()
    }

    #[test]
    fn roundtrip_small() {
        let frozen = ByteBuffer::from_slice(b"hello");
        assert_eq!(frozen.capacity(), 5);
        let contents = unsafe { snapshot(frozen.data_ptr(), frozen.capacity() as usize) };
        assert_eq!(contents, b"hello");
    }

    #[test]
    fn clone_shares_buffer() {
        let original = ByteBuffer::from_slice(b"abc");
        let copy = original.clone();
        assert_eq!(original.refcount(), 2);
        assert_eq!(original.data_ptr(), copy.data_ptr());
    }

    #[test]
    fn cow_isolates_outstanding_view() {
        // "hello" followed by three zero bytes, filling the whole 8-byte payload.
        let mut seed = [0u8; 8];
        seed[..5].copy_from_slice(b"hello");

        let mut writer = ByteBufferMut::with_capacity(8);
        unsafe { writer.data_mut_ptr().copy_from_nonoverlapping(seed.as_ptr(), seed.len()) };

        let view = writer.share();
        assert_eq!(writer.refcount(), 2);

        writer.ensure_unique_for_mutate(5);
        assert_eq!(writer.refcount(), 1);
        assert_eq!(view.refcount(), 1);
        assert_ne!(writer.data_ptr(), view.data_ptr());
        assert_eq!(unsafe { snapshot(view.data_ptr(), 5) }, b"hello");

        unsafe { writer.data_mut_ptr().copy_from_nonoverlapping(b"WORLD".as_ptr(), 5) };
        assert_eq!(unsafe { snapshot(writer.data_ptr(), 5) }, b"WORLD");
        assert_eq!(
            unsafe { snapshot(view.data_ptr(), 5) },
            b"hello",
            "outstanding view must remain stable"
        );
    }
}