// dambi 0.1.1
//
// Single-threaded (!Send + !Sync) primitives
// Documentation
use std::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use std::cell::Cell;
use std::marker::PhantomData;
use std::ptr::NonNull;

/// Bookkeeping record stored at the very start of every buffer allocation.
///
/// `#[repr(C)]` pins the field order; the payload bytes start `HEADER_SIZE`
/// bytes past the beginning of this struct.
#[repr(C)]
struct ByteBufferHeader {
    /// Number of live handles pointing at this allocation. Stored in a `Cell`
    /// so it can be updated through a shared reference.
    refcount: Cell<u32>,
    /// Payload size in bytes, as requested when the buffer was allocated.
    capacity: u32,
}

/// Bytes occupied by the header, i.e. the payload's offset within the allocation.
const HEADER_SIZE: usize = Layout::new::<ByteBufferHeader>().size();
/// Alignment of the header, used as the alignment of the whole allocation.
const HEADER_ALIGN: usize = Layout::new::<ByteBufferHeader>().align();

/// Computes the allocation layout for one header followed by `payload` bytes.
///
/// # Panics
/// Panics with "ByteBuffer: layout overflow" when `HEADER_SIZE + payload`
/// overflows `usize`, and with "ByteBuffer: layout invariant" when
/// `Layout::from_size_align` rejects the resulting size.
fn layout_for(payload: usize) -> Layout {
    let total = match HEADER_SIZE.checked_add(payload) {
        Some(total) => total,
        None => panic!("ByteBuffer: layout overflow"),
    };
    Layout::from_size_align(total, HEADER_ALIGN).expect("ByteBuffer: layout invariant")
}

/// Allocates a buffer with room for `cap` payload bytes and initialises its header.
///
/// The returned header has `refcount == 1` and `capacity == cap`; the payload
/// bytes that follow it are left uninitialised.
///
/// # Panics
/// Panics if `cap` exceeds `u32::MAX` (the header stores the capacity as a
/// `u32`) or if the layout cannot be formed. Allocation failure is routed to
/// `handle_alloc_error`.
///
/// # Safety
/// The body relies on no caller-side precondition.
/// NOTE(review): presumably marked `unsafe` because the caller becomes
/// responsible for releasing the allocation through the refcount helpers;
/// confirm intent.
unsafe fn alloc_header(cap: usize) -> NonNull<ByteBufferHeader> {
    assert!(
        cap <= u32::MAX as usize,
        "ByteBuffer: capacity too large ({cap}, max {})",
        u32::MAX
    );
    let layout = layout_for(cap);
    // SAFETY: the layout always includes the header, so its size is non-zero.
    let block = match NonNull::new(unsafe { alloc(layout) }) {
        Some(block) => block.cast::<ByteBufferHeader>(),
        None => handle_alloc_error(layout),
    };
    let fresh = ByteBufferHeader {
        refcount: Cell::new(1),
        // Lossless: `cap <= u32::MAX` was asserted above.
        capacity: cap as u32,
    };
    // SAFETY: `block` is freshly allocated with the header's alignment and at
    // least `HEADER_SIZE` bytes; only the header slot is written here.
    unsafe { block.as_ptr().write(fresh) };
    block
}

/// Drops the header in place and returns the whole allocation to the allocator.
///
/// # Safety
/// `ptr` must point to a live header whose allocation was laid out by
/// `layout_for(capacity)`, and the caller must hold the last reference to it.
/// The pointer must not be used after this call.
unsafe fn dealloc_buffer(ptr: NonNull<ByteBufferHeader>) {
    let raw = ptr.as_ptr();
    // SAFETY: per the contract above this is the sole drop + free of the
    // header and its trailing payload.
    unsafe {
        // The capacity must be read before the header is dropped: it is needed
        // to rebuild the layout the block was allocated with.
        let layout = layout_for((*raw).capacity as usize);
        raw.drop_in_place();
        dealloc(raw.cast::<u8>(), layout);
    }
}

/// Adds one reference to the buffer behind `ptr`.
///
/// # Panics
/// Panics with "ByteBuffer: refcount overflow" if the count is already `u32::MAX`.
///
/// # Safety
/// `ptr` must point to a live, initialised header for the duration of the call.
unsafe fn refcount_inc(ptr: NonNull<ByteBufferHeader>) {
    // SAFETY: the caller guarantees the header is live.
    let count = unsafe { &ptr.as_ref().refcount };
    let bumped = count
        .get()
        .checked_add(1)
        .expect("ByteBuffer: refcount overflow");
    count.set(bumped);
}

/// Releases one reference to the buffer behind `ptr`, freeing the allocation
/// when that was the last one.
///
/// # Safety
/// `ptr` must point to a live header with a refcount greater than zero, and the
/// caller must give up the reference it is releasing: once the count reaches
/// zero the allocation is gone.
unsafe fn refcount_dec(ptr: NonNull<ByteBufferHeader>) {
    // SAFETY: the caller guarantees the header is live.
    let count = unsafe { &ptr.as_ref().refcount };
    let prev = count.get();
    debug_assert!(prev > 0, "ByteBuffer: drop with zero refcount");
    match prev - 1 {
        // SAFETY: the count just reached zero, so this was the last reference.
        0 => unsafe { dealloc_buffer(ptr) },
        remaining => count.set(remaining),
    }
}

/// Handle to a reference-counted byte buffer that exposes a mutable payload pointer.
///
/// Dropping the handle releases one reference to the underlying allocation.
pub struct ByteBufferMut {
    /// Header at the start of the allocation; the payload bytes follow it.
    ptr: NonNull<ByteBufferHeader>,
    /// Raw-pointer marker: `*mut ()` is neither `Send` nor `Sync`.
    _marker: PhantomData<*mut ()>,
}

impl ByteBufferMut {
    /// Allocates a buffer with room for `cap` payload bytes.
    ///
    /// The payload is left uninitialised and the new handle holds the only
    /// reference.
    ///
    /// # Panics
    /// Panics if `cap` exceeds `u32::MAX`.
    pub fn with_capacity(cap: usize) -> Self {
        // SAFETY: alloc_header writes a valid header.
        Self {
            ptr: unsafe { alloc_header(cap) },
            _marker: PhantomData,
        }
    }

    /// Returns the payload capacity in bytes.
    #[inline]
    pub fn capacity(&self) -> u32 {
        // SAFETY: header live for self's lifetime.
        unsafe { self.ptr.as_ref().capacity }
    }

    /// Returns a read-only pointer to the first payload byte.
    ///
    /// The payload may be uninitialised; reading through the pointer is the
    /// caller's responsibility.
    #[inline]
    pub fn data_ptr(&self) -> *const u8 {
        // SAFETY: tail of the allocation.
        unsafe { (self.ptr.as_ptr() as *const u8).add(HEADER_SIZE) }
    }

    /// Returns a mutable pointer to the first payload byte.
    ///
    /// This does not check whether the buffer is shared; call
    /// [`ensure_unique_for_mutate`](Self::ensure_unique_for_mutate) first if
    /// other handles may exist.
    #[inline]
    pub fn data_mut_ptr(&mut self) -> *mut u8 {
        // SAFETY: tail of the allocation.
        unsafe { (self.ptr.as_ptr() as *mut u8).add(HEADER_SIZE) }
    }

    /// Creates a read-only handle to the same allocation, bumping the refcount.
    ///
    /// # Panics
    /// Panics if the refcount would overflow `u32`.
    pub fn share(&self) -> ByteBuffer {
        // SAFETY: header live; inc the refcount for the new handle.
        unsafe { refcount_inc(self.ptr) };
        ByteBuffer {
            ptr: self.ptr,
            _marker: PhantomData,
        }
    }

    /// Converts this handle into a read-only [`ByteBuffer`].
    ///
    /// The reference held by `self` is transferred to the returned handle, so
    /// the refcount does not change.
    pub fn freeze(self) -> ByteBuffer {
        let ptr = self.ptr;
        // Skip `Drop`: the reference now belongs to the returned handle.
        std::mem::forget(self);
        ByteBuffer {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Returns the current number of handles to the allocation.
    #[inline]
    pub fn refcount(&self) -> u32 {
        // SAFETY: header live for self's lifetime.
        unsafe { self.ptr.as_ref().refcount.get() }
    }

    /// Makes this handle the sole owner of its buffer before a write.
    ///
    /// If the buffer is already unique this is a no-op. Otherwise a new buffer
    /// of the same capacity is allocated, the first `keep` payload bytes are
    /// copied into it, and this handle's reference to the old buffer is released.
    ///
    /// # Panics
    /// Panics if the buffer is shared and `keep` exceeds the capacity.
    #[inline]
    pub fn ensure_unique_for_mutate(&mut self, keep: usize) {
        if self.refcount() == 1 {
            return;
        }
        self.cow_swap(keep);
    }

    /// Slow path of `ensure_unique_for_mutate`: moves this handle onto a fresh
    /// private buffer carrying over the first `keep` payload bytes.
    #[cold]
    fn cow_swap(&mut self, keep: usize) {
        let cap = self.capacity() as usize;
        // A hard assert, not `debug_assert!`: `keep` comes straight from a safe
        // public method, and an oversized value would otherwise make the copy
        // below run past both payloads in release builds. Checked before the
        // allocation so a panic leaks nothing.
        assert!(keep <= cap, "ByteBufferMut::cow_swap: keep > cap");
        // SAFETY: alloc_header writes a valid header (refcount=1).
        let new_ptr = unsafe { alloc_header(cap) };
        if keep > 0 {
            // SAFETY: src and dst are distinct allocations; `keep <= cap`
            // (asserted above) bounds both reads and writes to their payload
            // regions.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    (self.ptr.as_ptr() as *const u8).add(HEADER_SIZE),
                    (new_ptr.as_ptr() as *mut u8).add(HEADER_SIZE),
                    keep,
                );
            }
        }
        // SAFETY: self.ptr was a live handle.
        unsafe { refcount_dec(self.ptr) };
        self.ptr = new_ptr;
    }
}

impl Drop for ByteBufferMut {
    /// Releases this handle's reference; the allocation is freed if it was the last.
    fn drop(&mut self) {
        let header = self.ptr;
        // SAFETY: the handle kept the header alive up to this point and never
        // touches it again.
        unsafe {
            refcount_dec(header);
        }
    }
}

/// Read-only handle to a reference-counted byte buffer.
///
/// Cloning adds a reference to the same allocation; dropping releases one.
pub struct ByteBuffer {
    /// Header at the start of the allocation; the payload bytes follow it.
    ptr: NonNull<ByteBufferHeader>,
    /// Raw-pointer marker: `*mut ()` is neither `Send` nor `Sync`.
    _marker: PhantomData<*mut ()>,
}

impl ByteBuffer {
    /// Builds a read-only buffer holding a copy of `slice`.
    ///
    /// The capacity equals `slice.len()`.
    ///
    /// # Panics
    /// Panics if `slice.len()` exceeds `u32::MAX`.
    pub fn from_slice(slice: &[u8]) -> Self {
        let len = slice.len();
        let mut staging = ByteBufferMut::with_capacity(len);
        if len != 0 {
            let dst = staging.data_mut_ptr();
            // SAFETY: `staging` is freshly allocated (refcount == 1) with
            // capacity `len`, and cannot overlap the borrowed `slice`.
            unsafe { slice.as_ptr().copy_to_nonoverlapping(dst, len) };
        }
        staging.freeze()
    }

    /// Returns a read-only pointer to the first payload byte.
    #[inline]
    pub fn data_ptr(&self) -> *const u8 {
        let base: *const u8 = self.ptr.cast::<u8>().as_ptr();
        // SAFETY: the payload starts `HEADER_SIZE` bytes into the allocation.
        unsafe { base.add(HEADER_SIZE) }
    }

    /// Returns the payload capacity in bytes.
    #[inline]
    pub fn capacity(&self) -> u32 {
        // SAFETY: the header stays live for as long as this handle exists.
        let header = unsafe { self.ptr.as_ref() };
        header.capacity
    }

    /// Returns the current number of handles to the allocation.
    #[inline]
    pub fn refcount(&self) -> u32 {
        // SAFETY: the header stays live for as long as this handle exists.
        let header = unsafe { self.ptr.as_ref() };
        header.refcount.get()
    }
}

impl Clone for ByteBuffer {
    /// Returns another handle to the same allocation, bumping the refcount.
    ///
    /// No payload bytes are copied.
    #[inline]
    fn clone(&self) -> Self {
        let shared = self.ptr;
        // SAFETY: `self` keeps the header live; the count sits in a `Cell<u32>`
        // so it can be bumped through a shared reference.
        unsafe { refcount_inc(shared) };
        Self {
            ptr: shared,
            _marker: PhantomData,
        }
    }
}

impl Drop for ByteBuffer {
    /// Releases this handle's reference; the allocation is freed if it was the last.
    fn drop(&mut self) {
        let header = self.ptr;
        // SAFETY: the handle kept the header alive up to this point and never
        // touches it again.
        unsafe {
            refcount_dec(header);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows `len` payload bytes starting at `ptr` as a slice.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of `len` initialised bytes for as long as
    /// the returned slice is in use.
    unsafe fn peek<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
        // SAFETY: forwarded to the caller.
        unsafe { std::slice::from_raw_parts(ptr, len) }
    }

    /// A frozen buffer reports the source length and holds the source bytes.
    #[test]
    fn roundtrip_small() {
        let frozen = ByteBuffer::from_slice(b"hello");
        assert_eq!(frozen.capacity(), 5);
        // SAFETY: `from_slice` initialised exactly `capacity()` bytes.
        let contents = unsafe { peek(frozen.data_ptr(), frozen.capacity() as usize) };
        assert_eq!(contents, b"hello");
    }

    /// Cloning bumps the refcount and aliases the same payload.
    #[test]
    fn clone_shares_buffer() {
        let first = ByteBuffer::from_slice(b"abc");
        let second = first.clone();
        assert_eq!(first.refcount(), 2);
        assert_eq!(first.data_ptr(), second.data_ptr());
    }

    /// After `ensure_unique_for_mutate`, writes through the mutable handle must
    /// not be visible through a previously shared view.
    #[test]
    fn cow_isolates_outstanding_view() {
        let mut writer = ByteBufferMut::with_capacity(8);
        // SAFETY: refcount==1 fresh; capacity is 8 bytes.
        #[allow(clippy::manual_c_str_literals)]
        unsafe {
            std::ptr::copy_nonoverlapping(b"hello\0\0\0".as_ptr(), writer.data_mut_ptr(), 8)
        };
        let reader = writer.share();
        assert_eq!(writer.refcount(), 2);

        writer.ensure_unique_for_mutate(5);
        assert_eq!(writer.refcount(), 1);
        assert_eq!(reader.refcount(), 1);
        assert_ne!(writer.data_ptr(), reader.data_ptr());

        // SAFETY: the first 5 bytes of the shared buffer were initialised above.
        let before = unsafe { peek(reader.data_ptr(), 5) };
        assert_eq!(before, b"hello");

        // SAFETY: refcount==1, sole writer.
        unsafe {
            std::ptr::copy_nonoverlapping(b"WORLD".as_ptr(), writer.data_mut_ptr(), 5);
        }
        // SAFETY: the 5 bytes were just written.
        let written = unsafe { peek(writer.data_ptr(), 5) };
        assert_eq!(written, b"WORLD");
        // SAFETY: the reader's buffer is untouched and still initialised.
        let after = unsafe { peek(reader.data_ptr(), 5) };
        assert_eq!(after, b"hello", "outstanding view must remain stable");
    }
}