use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
pub struct MemfdBuffer {
ptr: *mut u8,
len: usize,
mmap_len: usize,
fd: RawFd,
}
unsafe impl Send for MemfdBuffer {}
unsafe impl Sync for MemfdBuffer {}
impl MemfdBuffer {
#[cfg(target_os = "linux")]
pub fn new(name: &str, size: usize) -> io::Result<Self> {
use std::ffi::CString;
if size == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Size must be greater than 0",
));
}
let total_size = size + 64;
let c_name = CString::new(name).map_err(|e| {
io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid name: {}", e))
})?;
let fd = unsafe { libc::memfd_create(c_name.as_ptr(), 0) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
if unsafe { libc::ftruncate(fd, total_size as libc::off_t) } < 0 {
let err = io::Error::last_os_error();
unsafe { libc::close(fd) };
return Err(err);
}
let mmap_ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
total_size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
fd,
0,
)
};
if mmap_ptr == libc::MAP_FAILED {
let err = io::Error::last_os_error();
unsafe { libc::close(fd) };
return Err(err);
}
let base_addr = mmap_ptr as usize;
let alignment_offset = (64 - (base_addr % 64)) % 64;
let aligned_ptr = unsafe { (mmap_ptr as *mut u8).add(alignment_offset) };
Ok(Self {
ptr: aligned_ptr,
len: size,
mmap_len: total_size,
fd,
})
}
#[cfg(target_os = "linux")]
pub fn reopen(creator_pid: u32, fd: RawFd, size: usize) -> io::Result<Self> {
use std::fs::OpenOptions;
let path = format!("/proc/{}/fd/{}", creator_pid, fd);
let file = OpenOptions::new().read(true).write(true).open(&path)?;
let metadata = file.metadata()?;
let total_size = metadata.len() as usize;
if total_size < size {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Memfd too small: expected at least {} bytes, got {}",
size, total_size
),
));
}
let our_fd = file.as_raw_fd();
let mmap_ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
total_size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
our_fd,
0,
)
};
if mmap_ptr == libc::MAP_FAILED {
return Err(io::Error::last_os_error());
}
let owned_fd = unsafe {
let dup_fd = libc::dup(our_fd);
if dup_fd < 0 {
libc::munmap(mmap_ptr, total_size);
return Err(io::Error::last_os_error());
}
dup_fd
};
let base_addr = mmap_ptr as usize;
let alignment_offset = (64 - (base_addr % 64)) % 64;
let aligned_ptr = unsafe { (mmap_ptr as *mut u8).add(alignment_offset) };
Ok(Self {
ptr: aligned_ptr,
len: size,
mmap_len: total_size,
fd: owned_fd,
})
}
#[inline]
pub fn fd(&self) -> RawFd {
self.fd
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
}
impl AsRef<[u8]> for MemfdBuffer {
#[inline]
fn as_ref(&self) -> &[u8] {
self.as_slice()
}
}
impl AsMut<[u8]> for MemfdBuffer {
#[inline]
fn as_mut(&mut self) -> &mut [u8] {
self.as_mut_slice()
}
}
impl Drop for MemfdBuffer {
fn drop(&mut self) {
let base_addr = self.ptr as usize;
let alignment_offset = base_addr % 64;
let mmap_ptr = unsafe { self.ptr.sub(alignment_offset) };
unsafe {
libc::munmap(mmap_ptr as *mut libc::c_void, self.mmap_len);
}
unsafe {
libc::close(self.fd);
}
}
}
use crate::Vec64;
use crate::structs::shared_buffer::SharedBuffer;
use crate::structs::shared_buffer::internal::vtable::Vtable;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
#[repr(C)]
pub(crate) struct OwnedMemfd {
pub(crate) ref_cnt: AtomicUsize,
pub(crate) owner: MemfdBuffer,
}
unsafe fn memfd_clone(h: &AtomicPtr<()>, p: *const u8, l: usize) -> SharedBuffer {
let raw = h.load(Ordering::Acquire);
assert!(!raw.is_null());
let ref_cnt = unsafe { &*(raw as *const AtomicUsize) };
ref_cnt.fetch_add(1, Ordering::Relaxed);
SharedBuffer {
ptr: p,
len: l,
data: AtomicPtr::new(raw),
vtable: &MEMFD_VT,
}
}
unsafe fn memfd_drop(h: &mut AtomicPtr<()>, _p: *const u8, _l: usize) {
let raw = h.load(Ordering::Acquire);
if raw.is_null() {
return;
}
let ref_cnt = unsafe { &*(raw as *const AtomicUsize) };
if ref_cnt.fetch_sub(1, Ordering::AcqRel) == 1 {
drop(unsafe { Box::from_raw(raw as *mut OwnedMemfd) });
}
}
unsafe fn memfd_is_unique(h: &AtomicPtr<()>) -> bool {
let raw = h.load(Ordering::Acquire);
if raw.is_null() {
return false;
}
let ref_cnt = unsafe { &*(raw as *const AtomicUsize) };
ref_cnt.load(Ordering::Acquire) == 1
}
unsafe fn memfd_get_fd(h: &AtomicPtr<()>) -> Option<i32> {
let raw = h.load(Ordering::Acquire);
if raw.is_null() {
return None;
}
let owned = unsafe { &*(raw as *const OwnedMemfd) };
Some(owned.owner.fd)
}
pub(crate) static MEMFD_VT: Vtable = Vtable {
clone: memfd_clone,
drop: memfd_drop,
is_unique: memfd_is_unique,
to_vec: |_, p, l| unsafe { std::slice::from_raw_parts(p, l) }.to_vec(),
to_vec64: |_, p, l| {
let mut v = Vec64::with_capacity(l);
unsafe {
v.extend_from_slice(std::slice::from_raw_parts(p, l));
}
v
},
memfd_fd: memfd_get_fd,
};
impl SharedBuffer {
pub fn from_memfd_owner(memfd: MemfdBuffer) -> Self {
let raw: *mut OwnedMemfd = Box::into_raw(Box::new(OwnedMemfd {
ref_cnt: AtomicUsize::new(1),
owner: memfd,
}));
let buf = unsafe { (*raw).owner.as_ref() };
Self {
ptr: buf.as_ptr(),
len: buf.len(),
data: AtomicPtr::new(raw.cast()),
vtable: &MEMFD_VT,
}
}
#[inline]
pub fn memfd_fd(&self) -> Option<i32> {
unsafe { (self.vtable.memfd_fd)(&self.data) }
}
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
use super::*;
#[test]
fn test_create_memfd() {
let memfd = MemfdBuffer::new("test_buffer", 1024).expect("Failed to create memfd");
assert_eq!(memfd.len(), 1024);
assert!(!memfd.is_empty());
assert!(memfd.fd() >= 0);
}
#[test]
fn test_alignment() {
let memfd = MemfdBuffer::new("align_test", 1024).expect("Failed to create memfd");
let ptr_addr = memfd.as_slice().as_ptr() as usize;
assert_eq!(ptr_addr % 64, 0, "Buffer should be 64-byte aligned");
}
#[test]
fn test_read_write() {
let mut memfd = MemfdBuffer::new("rw_test", 256).expect("Failed to create memfd");
let data = b"Hello, memfd!";
memfd.as_mut_slice()[..data.len()].copy_from_slice(data);
assert_eq!(&memfd.as_slice()[..data.len()], data);
}
#[test]
fn test_reopen_same_process() {
let memfd = MemfdBuffer::new("reopen_test", 512).expect("Failed to create memfd");
let fd = memfd.fd();
let pid = std::process::id();
let data = b"Test data for reopen";
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), memfd.ptr, data.len());
}
let reopened = MemfdBuffer::reopen(pid, fd, 512).expect("Failed to reopen memfd");
assert_eq!(&reopened.as_slice()[..data.len()], data);
}
#[test]
fn test_zero_size_rejected() {
let result = MemfdBuffer::new("zero_test", 0);
assert!(result.is_err());
}
#[test]
fn test_with_shared_buffer() {
let memfd = MemfdBuffer::new("shared_test", 256).expect("Failed to create memfd");
let shared = SharedBuffer::from_owner(memfd);
assert_eq!(shared.len(), 256);
}
#[test]
fn test_memfd_fd_extraction() {
let memfd = MemfdBuffer::new("fd_test", 256).expect("Failed to create memfd");
let expected_fd = memfd.fd();
let shared = SharedBuffer::from_memfd_owner(memfd);
assert_eq!(shared.len(), 256);
let extracted_fd = shared.memfd_fd();
assert_eq!(extracted_fd, Some(expected_fd));
}
#[test]
fn test_memfd_fd_extraction_from_buffer() {
use crate::structs::buffer::Buffer;
let buffer: Buffer<i64> =
Buffer::from_memfd("buffer_fd_test", 100).expect("Failed to create memfd buffer");
let fd = buffer.memfd_fd();
assert!(fd.is_some());
assert!(fd.unwrap() >= 0);
}
#[test]
fn test_cloned_buffer_preserves_fd() {
let memfd = MemfdBuffer::new("clone_test", 256).expect("Failed to create memfd");
let expected_fd = memfd.fd();
let shared = SharedBuffer::from_memfd_owner(memfd);
let cloned = shared.clone();
assert_eq!(shared.memfd_fd(), Some(expected_fd));
assert_eq!(cloned.memfd_fd(), Some(expected_fd));
}
#[test]
fn test_non_memfd_buffer_returns_none() {
let shared = SharedBuffer::from_vec(vec![1, 2, 3, 4]);
assert_eq!(shared.memfd_fd(), None);
}
}