cp2k-rs 0.2.0

Rust bindings for CP2K with Python interface
Documentation
//! POSIX shared memory helpers for zero-copy IPC of large arrays.
//!
//! The worker process creates a shared memory region, writes data into it,
//! and sends only the region name + metadata over the IPC socket. The client
//! opens the same region, reads the data, and unlinks it.

use std::ffi::CString;

/// A POSIX shared memory region backed by `shm_open` + `mmap`.
///
/// On drop, the memory is `munmap`'d and the descriptor closed, but the named
/// segment is **not** unlinked. The consumer must call [`ShmRegion::unlink`]
/// explicitly when done.
#[derive(Debug)]
pub struct ShmRegion {
    /// Name passed to `shm_open` (e.g. `/cp2k_12345_0`).
    name: String,
    /// Start address of the mapping returned by `mmap`.
    ptr: *mut u8,
    /// Length of the mapping in bytes.
    size: usize,
    /// Descriptor returned by `shm_open`; kept open until the region is dropped.
    fd: i32,
}

// SAFETY: A `ShmRegion` uniquely owns its mapping and its descriptor, and neither is
// tied to the thread that created it, so transferring ownership to another thread is
// sound. The memory is shared with other *processes* (MAP_SHARED), not process-local;
// `Sync` is deliberately not implemented, so within this process the pointer is only
// dereferenced through the single owning value.
unsafe impl Send for ShmRegion {}

impl ShmRegion {
    /// Create a new shared memory region of `size` bytes.
    ///
    /// The name is generated as `/cp2k_<pid>_<counter>` to avoid collisions between
    /// processes and between successive calls within the same process.
    ///
    /// # Errors
    ///
    /// Returns a description of the failure if `size` is zero or exceeds the range of
    /// `off_t`, or if `shm_open` (e.g. a segment with that name already exists),
    /// `ftruncate` or `mmap` fails. A segment created by this call is unlinked again
    /// on failure.
    pub fn create(size: usize) -> Result<Self, String> {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        // Relaxed suffices: only the uniqueness of each returned value matters.
        let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let name = format!("/cp2k_{}_{}", std::process::id(), id);
        Self::create_named(&name, size)
    }

    /// Create a shared memory region with a specific name.
    ///
    /// The segment is created exclusively (`O_EXCL`) with mode `0o600`, sized with
    /// `ftruncate`, and mapped read/write.
    fn create_named(name: &str, size: usize) -> Result<Self, String> {
        // mmap(2) rejects zero-length mappings with EINVAL; fail before creating anything.
        if size == 0 {
            return Err(format!("cannot create shared memory region {name} with size 0"));
        }
        // `ftruncate` takes a signed `off_t`; a plain `as` cast could wrap to a negative length.
        let len = libc::off_t::try_from(size)
            .map_err(|_| format!("shared memory size {size} does not fit in off_t"))?;
        let c_name = CString::new(name).map_err(|e| e.to_string())?;

        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
        let fd = unsafe {
            libc::shm_open(
                c_name.as_ptr(),
                libc::O_CREAT | libc::O_RDWR | libc::O_EXCL,
                0o600,
            )
        };
        if fd < 0 {
            return Err(format!(
                "shm_open({name}) failed: {}",
                std::io::Error::last_os_error()
            ));
        }

        // Tear down the half-built segment on an error path. `errno` is captured
        // before cleanup so that `close`/`shm_unlink` cannot overwrite it.
        let abort = |what: &str| -> String {
            let err = std::io::Error::last_os_error();
            // SAFETY: `fd` came from the successful `shm_open` above and is closed only
            // here, on an error path; `c_name` is still a valid C string.
            unsafe {
                libc::close(fd);
                libc::shm_unlink(c_name.as_ptr());
            }
            format!("{what} failed: {err}")
        };

        // SAFETY: `fd` is an open descriptor owned by this function.
        if unsafe { libc::ftruncate(fd, len) } != 0 {
            return Err(abort("ftruncate"));
        }

        // SAFETY: `fd` is open read/write and the object was just sized to `size` bytes.
        let ptr = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                size,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_SHARED,
                fd,
                0,
            )
        };
        if ptr == libc::MAP_FAILED {
            return Err(abort("mmap"));
        }

        Ok(ShmRegion {
            name: name.to_string(),
            ptr: ptr.cast::<u8>(),
            size,
            fd,
        })
    }

    /// Open an existing shared memory region (read-only).
    ///
    /// `size` is the number of bytes to map, typically received from the creating
    /// process. The segment's actual length is verified first. Reading pages of a
    /// `MAP_SHARED` mapping that lie beyond the end of the underlying object raises
    /// `SIGBUS`, so a mismatched size must be reported as an error, not mapped.
    ///
    /// # Errors
    ///
    /// Returns a description of the failure if `size` is zero, if `shm_open`, `fstat`
    /// or `mmap` fails, or if the segment is shorter than `size` bytes.
    pub fn open_readonly(name: &str, size: usize) -> Result<Self, String> {
        // mmap(2) rejects zero-length mappings with EINVAL.
        if size == 0 {
            return Err(format!("cannot map shared memory region {name} with size 0"));
        }
        let c_name = CString::new(name).map_err(|e| e.to_string())?;

        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
        let fd = unsafe { libc::shm_open(c_name.as_ptr(), libc::O_RDONLY, 0) };
        if fd < 0 {
            return Err(format!(
                "shm_open({name}) failed: {}",
                std::io::Error::last_os_error()
            ));
        }

        let close_fd = || {
            // SAFETY: `fd` came from the successful `shm_open` above and is closed at
            // most once, on an error path.
            unsafe { libc::close(fd) };
        };

        match Self::fd_len(fd) {
            Ok(actual) if actual >= size as i128 => {}
            Ok(actual) => {
                close_fd();
                return Err(format!(
                    "shared memory region {name} is {actual} bytes, smaller than the requested {size}"
                ));
            }
            Err(err) => {
                close_fd();
                return Err(format!("fstat({name}) failed: {err}"));
            }
        }

        // SAFETY: `fd` is open for reading and the object holds at least `size` bytes.
        let ptr = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                size,
                libc::PROT_READ,
                libc::MAP_SHARED,
                fd,
                0,
            )
        };
        if ptr == libc::MAP_FAILED {
            let err = std::io::Error::last_os_error();
            close_fd();
            return Err(format!("mmap failed: {err}"));
        }

        Ok(ShmRegion {
            name: name.to_string(),
            ptr: ptr.cast::<u8>(),
            size,
            fd,
        })
    }

    /// The name of the shared memory region (e.g. `/cp2k_12345_0`).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Size in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get a mutable pointer to the shared memory (for writing data).
    ///
    /// Only regions obtained from [`ShmRegion::create`] are mapped writable. A region
    /// from [`ShmRegion::open_readonly`] is mapped `PROT_READ`, and writing through
    /// this pointer would fault.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr
    }

    /// Get a slice view of the shared memory.
    pub fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` is the start of a live mapping of `size` bytes that stays mapped
        // until `self` is dropped, and the returned borrow cannot outlive `self`. The
        // constructors guarantee the backing object covers all `size` bytes.
        // NOTE(review): another process may write the segment concurrently; callers are
        // assumed to read only after the writer has finished.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }

    /// Interpret the shared memory as a slice of `f64`.
    ///
    /// Trailing bytes that do not make up a whole `f64` are ignored.
    pub fn as_f64_slice(&self) -> &[f64] {
        let n = self.size / std::mem::size_of::<f64>();
        // mmap returns page-aligned addresses, which satisfies `f64` alignment.
        debug_assert_eq!(self.ptr.align_offset(std::mem::align_of::<f64>()), 0);
        // SAFETY: same mapping invariants as `as_slice`. `n * 8 <= size`, the pointer is
        // suitably aligned, and every bit pattern is a valid `f64`.
        unsafe { std::slice::from_raw_parts(self.ptr.cast::<f64>(), n) }
    }

    /// Unlink (delete) the shared memory segment from the filesystem namespace.
    ///
    /// The backing memory remains accessible until all mappings are dropped. This is
    /// best-effort: a failing `shm_unlink` (e.g. the name was already removed) is
    /// ignored.
    pub fn unlink(&self) {
        if let Ok(c_name) = CString::new(self.name.as_str()) {
            // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
            unsafe {
                libc::shm_unlink(c_name.as_ptr());
            }
        }
    }

    /// Current length in bytes of the object behind `fd`.
    ///
    /// The value is widened to `i128` so it can be compared against any `usize`
    /// without overflow.
    fn fd_len(fd: i32) -> std::io::Result<i128> {
        let mut st = std::mem::MaybeUninit::<libc::stat>::uninit();
        // SAFETY: `st` provides writable storage for one `libc::stat`.
        if unsafe { libc::fstat(fd, st.as_mut_ptr()) } != 0 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `fstat` returned 0, so it fully initialised `st`.
        let st = unsafe { st.assume_init() };
        Ok(i128::from(st.st_size))
    }
}

impl Drop for ShmRegion {
    /// Release this process's view of the segment by unmapping the memory and then
    /// closing the descriptor. The named segment itself is left in place; removing
    /// it is the job of `unlink`.
    fn drop(&mut self) {
        let (addr, len, fd) = (self.ptr.cast::<libc::c_void>(), self.size, self.fd);
        // SAFETY: `addr`/`len` describe the mapping created by this region's
        // constructor, and it is unmapped exactly once, here.
        unsafe { libc::munmap(addr, len) };
        // SAFETY: `fd` is the descriptor opened by the constructor and is closed only here.
        unsafe { libc::close(fd) };
    }
}