round_pipers 0.1.0

A way to pipe ndarrays using circular buffers
Documentation
use anyhow::{bail, Result};
use nix::{
    fcntl::{fcntl, SealFlag, F_ADD_SEALS},
    sys::{
        memfd::{memfd_create, MemFdCreateFlag},
        mman::{mmap, munmap, MapFlags, ProtFlags},
    },
    unistd::ftruncate,
};
use std::any::type_name;
use std::ffi::{c_void, CString};
use std::mem::size_of;
use std::num::NonZeroUsize;
use std::os::fd::{AsRawFd, OwnedFd};
use std::path::Path;
use std::ptr::NonNull;
use std::slice::{from_raw_parts, from_raw_parts_mut};

/// A byte ring whose memfd-backed storage is mapped twice in a row.
///
/// The same `nbytes` bytes of storage appear at `rwptr` and again at `rwptr + nbytes`. A
/// window starting anywhere in the first copy can therefore run past its end without wrapping.
pub(crate) struct CircularBuffer {
    /// Length in bytes of one copy of the storage (the size the memfd is truncated to).
    nbytes: usize,
    /// Start of the whole `2 * nbytes` region that holds both copies.
    rwptr: NonNull<c_void>,
    /// Start of the second copy, placed `nbytes` bytes after `rwptr`.
    rwptr_copy: NonNull<c_void>,
    /// The memfd backing both mappings; closed when the buffer is dropped.
    _fd: OwnedFd,
}
impl Drop for CircularBuffer {
    /// Unmaps the second copy, then the full `2 * nbytes` region that started at `rwptr`.
    ///
    /// `munmap` failures are ignored on purpose: a destructor has no way to report them.
    fn drop(&mut self) {
        let regions = [
            (self.rwptr_copy, self.nbytes),
            (self.rwptr, self.nbytes * 2),
        ];
        for (start, len) in regions {
            // SAFETY: both regions were mapped by `CircularBuffer::new`, and every view handed
            // out borrows `self`, so no reference into them can outlive this call.
            let _ = unsafe { munmap(start, len) };
        }
    }
}
impl CircularBuffer {
    pub(crate) fn new(path: impl AsRef<Path>, size: usize) -> Result<CircularBuffer> {
        let fd = memfd_create(
            &CString::new((&path.as_ref()).as_os_str().as_encoded_bytes())?,
            MemFdCreateFlag::MFD_ALLOW_SEALING,
        )?;
        let page_size = 4096;
        let nbytes = (size / page_size + 1) * page_size;
        ftruncate(&fd, nbytes.try_into()?)?;
        let rwptr = unsafe {
            mmap(
                None,
                (nbytes * 2).try_into()?, //Overallocate how much we're requesting so that the next mmap can be contiguous
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                MapFlags::MAP_SHARED,
                &fd,
                0,
            )?
        };
        let desired_pointer = (rwptr.as_ptr() as usize) + nbytes;
        let rwptr_copy = unsafe {
            mmap(
                Some(
                    NonZeroUsize::new(desired_pointer).expect("desired pointer has to be non-zero"),
                ),
                nbytes.try_into()?,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                MapFlags::MAP_SHARED | MapFlags::MAP_FIXED,
                &fd,
                0,
            )?
        };
        assert_eq!(rwptr_copy.as_ptr() as usize, desired_pointer);

        //Make it so the file can't grow, shrink, be written to (except by the two mmaps above)
        fcntl(fd.as_raw_fd(), F_ADD_SEALS(SealFlag::F_SEAL_SHRINK))?;
        fcntl(fd.as_raw_fd(), F_ADD_SEALS(SealFlag::F_SEAL_GROW))?;
        fcntl(fd.as_raw_fd(), F_ADD_SEALS(SealFlag::F_SEAL_FUTURE_WRITE))?;

        //Make it so the seals can't be changed.
        fcntl(fd.as_raw_fd(), F_ADD_SEALS(SealFlag::F_SEAL_SEAL))?;

        Ok(CircularBuffer {
            _fd: fd,
            rwptr,
            rwptr_copy,
            nbytes,
        })
    }
    pub(crate) fn view<T>(&self) -> Result<&[T]> {
        let size = size_of::<T>();
        if self.nbytes % size != 0 {
            bail!(
                "Can not divide our buffer size {} by the size of {} which is {}",
                self.nbytes,
                type_name::<T>(),
                size
            );
        }

        Ok(unsafe { from_raw_parts(self.rwptr.as_ptr() as *mut T, self.nbytes * 2 / size) })
    }
    pub(crate) fn view_mut<T>(&self) -> Result<&mut [T]> {
        let size = size_of::<T>();
        if self.nbytes % size != 0 {
            bail!(
                "Can not divide our buffer size {} by the size of {} which is {}",
                self.nbytes,
                type_name::<T>(),
                size
            );
        }

        Ok(unsafe { from_raw_parts_mut(self.rwptr.as_ptr() as *mut T, self.nbytes * 2 / size) })
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Byte expected at position `idx` of the double view.
    ///
    /// The pattern repeats every 256 bytes, and one copy (4096 bytes) is a multiple of 256. So
    /// the value written at `idx` in the lower copy equals `pattern(idx + 4096)`.
    fn pattern(idx: usize) -> u8 {
        (idx + 1) as u8
    }

    #[test]
    fn test_circular_buffer() -> Result<()> {
        let buf = CircularBuffer::new(Path::new("junk"), 1234)?;

        let rw: &mut [u8] = buf.view_mut()?;
        assert_eq!(rw.len(), 8192);
        // Write only the lower copy; the upper copy must mirror it.
        rw[..4096]
            .iter_mut()
            .enumerate()
            .for_each(|(idx, byte)| *byte = pattern(idx));
        for (idx, &byte) in rw.iter().enumerate() {
            assert_eq!(byte, pattern(idx));
        }

        let ro: &[u8] = buf.view()?;
        for (idx, &byte) in ro.iter().enumerate() {
            assert_eq!(byte, pattern(idx));
        }
        Ok(())
    }
}