use std::fs::{File, OpenOptions};
use std::io;
use std::os::unix::fs::PermissionsExt;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::path::{Path, PathBuf};
use crate::Region;
/// Selects what `MmapRegion::create` does with the backing file once it has been mapped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileCleanup {
/// Leave the file on disk; the created region is flagged as its owner and
/// removes it when the region is dropped.
Manual,
/// Unlink the file right after mapping it; the created region is not flagged
/// as owner, so dropping it does not touch the path again.
Auto,
}
/// A file-backed memory mapping created with `mmap` (`MAP_SHARED`, read/write).
///
/// The mapping is released with `munmap` when the value is dropped; if
/// `owns_file` is set, the file at `path` is removed as well.
pub struct MmapRegion {
/// Base address of the current mapping, as returned by `mmap`.
ptr: *mut u8,
/// Length in bytes of the current mapping.
len: usize,
/// Handle to the mapped file. It backs `as_raw_fd` and is used again when
/// the mapping is resized or re-checked against the file length.
#[allow(dead_code)]
file: File,
/// Path the region was created from or attached to; empty for regions
/// built from a bare descriptor via `attach_fd`.
path: PathBuf,
/// When `true`, dropping the region removes the file at `path`.
owns_file: bool,
}
impl MmapRegion {
    /// Creates (or truncates) the file at `path`, sizes it to `size` bytes and
    /// maps it shared, readable and writable.
    ///
    /// With [`FileCleanup::Manual`] the file stays on disk and is removed when
    /// the returned region is dropped. With [`FileCleanup::Auto`] the file is
    /// unlinked immediately after mapping and the region does not own a path.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `size` is zero, and otherwise any I/O error
    /// from opening, chmod-ing, sizing, mapping or (for `Auto`) unlinking the
    /// file. The file is not removed when a later step fails.
    pub fn create(path: &Path, size: usize, cleanup: FileCleanup) -> io::Result<Self> {
        if size == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "size must be > 0",
            ));
        }
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)
            .map_err(|e| {
                let msg = std::format!("Failed to create SHM file at {}: {}", path.display(), e);
                io::Error::new(e.kind(), msg)
            })?;
        // Set the mode explicitly so it does not depend on the process umask.
        file.set_permissions(std::fs::Permissions::from_mode(0o666))?;
        file.set_len(size as u64)?;
        let ptr = Self::map_shared(&file, size)?;

        // Build the value *before* the fallible unlink below: if that step
        // fails, dropping `region` unmaps the memory instead of leaking it.
        let mut region = Self {
            ptr,
            len: size,
            file,
            path: path.to_path_buf(),
            owns_file: false,
        };
        match cleanup {
            FileCleanup::Auto => std::fs::remove_file(&region.path)?,
            FileCleanup::Manual => region.owns_file = true,
        }
        Ok(region)
    }

    /// Opens the existing file at `path` read/write and maps its whole current
    /// length. The returned region does not own the file.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the file is empty or too large to address,
    /// and otherwise any I/O error from opening, stat-ing or mapping it.
    pub fn attach(path: &Path) -> io::Result<Self> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(|e| {
                let msg = std::format!("Failed to open SHM file at {}: {}", path.display(), e);
                io::Error::new(e.kind(), msg)
            })?;
        let size = Self::file_len(&file)?;
        if size == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "segment file is empty",
            ));
        }
        let ptr = Self::map_shared(&file, size)?;
        Ok(Self {
            ptr,
            len: size,
            file,
            path: path.to_path_buf(),
            owns_file: false,
        })
    }

    /// Maps `size` bytes of an already-open descriptor and takes ownership of
    /// it. The resulting region has an empty path and does not own a file.
    ///
    /// `size` is not checked against the length of the underlying object;
    /// the caller is responsible for passing a length the object can back.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `size` is zero, or the OS error from
    /// `mmap`. The descriptor is closed on failure.
    pub fn attach_fd(fd: OwnedFd, size: usize) -> io::Result<Self> {
        if size == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "size must be > 0",
            ));
        }
        // SAFETY: `into_raw_fd` hands over sole ownership of a descriptor that
        // was open inside `fd`; wrapping it in `File` keeps exactly one owner.
        let file = unsafe { File::from_raw_fd(fd.into_raw_fd()) };
        let ptr = Self::map_shared(&file, size)?;
        Ok(Self {
            ptr,
            len: size,
            file,
            path: PathBuf::new(),
            owns_file: false,
        })
    }

    /// Returns the raw descriptor of the mapped file. It stays owned by
    /// `self` and is closed when the region is dropped.
    pub fn as_raw_fd(&self) -> RawFd {
        self.file.as_raw_fd()
    }

    /// Returns a [`Region`] built from the current base address and length.
    ///
    /// `resize` and `check_and_remap` replace the mapping, so a `Region`
    /// obtained before such a call was built from the old address.
    #[inline]
    pub fn region(&self) -> Region {
        // SAFETY: `ptr` and `len` describe the mapping currently held by
        // `self`. NOTE(review): assumes `Region::from_raw` only requires a
        // pointer valid for `len` bytes - confirm against its contract.
        unsafe { Region::from_raw(self.ptr, self.len) }
    }

    /// Length of the current mapping in bytes.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the mapping has zero length. Every constructor rejects a zero
    /// size, so this is `false` for regions built through this type.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Path recorded when the region was created or attached; empty for
    /// regions built with `attach_fd`.
    #[inline]
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Marks this region as owner, so dropping it removes the file at
    /// `path()`. Removal errors at drop time are ignored.
    pub fn take_ownership(&mut self) {
        self.owns_file = true;
    }

    /// Clears the owner flag, so dropping this region leaves the file alone.
    pub fn release_ownership(&mut self) {
        self.owns_file = false;
    }

    /// Grows the backing file to at least `new_size` bytes and replaces the
    /// mapping with one of `new_size` bytes.
    ///
    /// The file is only ever extended: if it is already `new_size` bytes or
    /// longer (for example because a peer grew it first), its length is left
    /// untouched so no other mapping loses its backing pages.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `new_size` is smaller than the current
    /// length, or any I/O error from stat-ing, extending or mapping the file.
    /// On error the existing mapping is left in place.
    pub fn resize(&mut self, new_size: usize) -> io::Result<()> {
        if new_size < self.len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "shrinking is not supported",
            ));
        }
        if new_size == self.len {
            return Ok(());
        }
        // `self.len` may be stale with respect to the file, so compare against
        // the file's real length before calling `set_len`, which truncates.
        if self.file.metadata()?.len() < new_size as u64 {
            self.file.set_len(new_size as u64)?;
        }
        self.remap(new_size)
    }

    /// Remaps the region when the backing file has become longer than the
    /// current mapping. Returns `Ok(true)` if a remap happened.
    ///
    /// This never changes the file's length; it only follows growth performed
    /// elsewhere.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from stat-ing or mapping the file, or
    /// `InvalidData` when the file is too large to address.
    pub fn check_and_remap(&mut self) -> io::Result<bool> {
        let file_size = Self::file_len(&self.file)?;
        if file_size <= self.len {
            return Ok(false);
        }
        self.remap(file_size)?;
        Ok(true)
    }

    /// Maps the first `len` bytes of `file` shared, readable and writable.
    fn map_shared(file: &File, len: usize) -> io::Result<*mut u8> {
        // SAFETY: a null hint lets the kernel choose the address, so no
        // existing mapping is replaced; `file` is borrowed and therefore open
        // for the duration of the call.
        let addr = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                len,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_SHARED,
                file.as_raw_fd(),
                0,
            )
        };
        if addr == libc::MAP_FAILED {
            Err(io::Error::last_os_error())
        } else {
            Ok(addr as *mut u8)
        }
    }

    /// Returns the current length of `file`, rejecting lengths that do not fit
    /// in `usize` instead of silently truncating them.
    fn file_len(file: &File) -> io::Result<usize> {
        let raw = file.metadata()?.len();
        <usize as std::convert::TryFrom<u64>>::try_from(raw).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "segment file is too large to map",
            )
        })
    }

    /// Replaces the current mapping with a fresh one of `new_len` bytes
    /// without touching the file's length.
    fn remap(&mut self, new_len: usize) -> io::Result<()> {
        // Map first so a failure leaves the old mapping fully usable.
        let new_ptr = Self::map_shared(&self.file, new_len)?;
        // SAFETY: `ptr`/`len` describe the mapping this value currently holds;
        // it is released exactly once here and the fields are overwritten
        // immediately afterwards.
        unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.len) };
        self.ptr = new_ptr;
        self.len = new_len;
        Ok(())
    }
}
/// Releases the mapping and, for owning regions, deletes the backing file.
impl Drop for MmapRegion {
    fn drop(&mut self) {
        let base = self.ptr as *mut libc::c_void;
        let span = self.len;
        // SAFETY: `base` and `span` are the address and length stored for the
        // mapping this value holds. The return code is deliberately ignored,
        // since `drop` has no way to report it.
        let _ = unsafe { libc::munmap(base, span) };

        if !self.owns_file {
            return;
        }
        // Best effort: a failed removal is ignored.
        let _ = std::fs::remove_file(&self.path);
    }
}
// The raw `ptr` field is what stops the compiler from deriving these automatically.
// NOTE(review): these impls add no synchronisation for the mapped bytes themselves;
// presumably callers coordinate concurrent access to the memory - confirm.
unsafe impl Send for MmapRegion {}
unsafe impl Sync for MmapRegion {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh temporary directory and a path for `name` inside it.
    /// The directory guard must be kept alive for as long as the path is used.
    fn scratch(name: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(name);
        (dir, path)
    }

    /// Writes one byte at `offset` into the region's mapping.
    fn poke(region: &MmapRegion, offset: usize, value: u8) {
        unsafe { std::ptr::write(region.region().as_ptr().add(offset), value) }
    }

    /// Reads one byte at `offset` from the region's mapping.
    fn peek(region: &MmapRegion, offset: usize) -> u8 {
        unsafe { std::ptr::read(region.region().as_ptr().add(offset)) }
    }

    /// Bytes written through the creating region are visible after attaching.
    #[test]
    fn test_create_and_attach() {
        let (_dir, path) = scratch("test.shm");

        let creator = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(creator.len(), 4096);
        assert!(path.exists());
        poke(&creator, 0, 0x42);
        poke(&creator, 1, 0x43);

        let attached = MmapRegion::attach(&path).unwrap();
        assert_eq!(attached.len(), 4096);
        assert_eq!(peek(&attached, 0), 0x42);
        assert_eq!(peek(&attached, 1), 0x43);
    }

    /// A `Manual` region removes its file when it goes out of scope.
    #[test]
    fn test_cleanup_on_drop() {
        let (_dir, path) = scratch("cleanup.shm");
        {
            let _scoped = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
            assert!(path.exists());
        }
        assert!(!path.exists());
    }

    /// Dropping an attached region leaves the file; dropping the creator removes it.
    #[test]
    fn test_attached_does_not_cleanup() {
        let (_dir, path) = scratch("attached.shm");
        let creator = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
        {
            let _visitor = MmapRegion::attach(&path).unwrap();
            assert!(path.exists());
        }
        assert!(path.exists());
        drop(creator);
        assert!(!path.exists());
    }

    /// A write through the attached mapping is observed by the creating one.
    #[test]
    fn test_shared_writes() {
        let (_dir, path) = scratch("shared.shm");
        let creator = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let visitor = MmapRegion::attach(&path).unwrap();

        poke(&visitor, 100, 0xAB);
        assert_eq!(peek(&creator, 100), 0xAB);
    }

    /// The created file carries mode 0o666.
    #[test]
    fn test_permissions() {
        let (_dir, path) = scratch("perms.shm");
        let _keep = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();

        let mode_bits = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode_bits, 0o666);
    }

    /// `create` refuses a zero size.
    #[test]
    fn test_zero_size_rejected() {
        let (_dir, path) = scratch("zero.shm");
        assert!(MmapRegion::create(&path, 0, FileCleanup::Manual).is_err());
    }

    /// Growing keeps existing bytes and makes the new tail usable.
    #[test]
    fn test_resize_grows_region() {
        let (_dir, path) = scratch("resize.shm");
        let mut mapping = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(mapping.len(), 4096);
        poke(&mapping, 0, 0xAB);

        mapping.resize(8192).unwrap();
        assert_eq!(mapping.len(), 8192);
        assert_eq!(peek(&mapping, 0), 0xAB);

        poke(&mapping, 5000, 0xCD);
        assert_eq!(peek(&mapping, 5000), 0xCD);
    }

    /// `resize` refuses to shrink.
    #[test]
    fn test_resize_shrink_rejected() {
        let (_dir, path) = scratch("shrink.shm");
        let mut mapping = MmapRegion::create(&path, 8192, FileCleanup::Manual).unwrap();
        assert!(mapping.resize(4096).is_err());
    }

    /// An attached region picks up growth done by the creator, exactly once.
    #[test]
    fn test_check_and_remap() {
        let (_dir, path) = scratch("remap.shm");
        let mut creator = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let mut visitor = MmapRegion::attach(&path).unwrap();
        assert_eq!(visitor.len(), 4096);

        creator.resize(8192).unwrap();

        let first_pass = visitor.check_and_remap().unwrap();
        assert!(first_pass);
        assert_eq!(visitor.len(), 8192);

        let second_pass = visitor.check_and_remap().unwrap();
        assert!(!second_pass);
    }

    /// Shared bytes survive a grow-and-remap, and the grown tail is shared too.
    #[test]
    fn test_resize_preserves_shared_data() {
        let (_dir, path) = scratch("shared_resize.shm");
        let mut creator = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let mut visitor = MmapRegion::attach(&path).unwrap();

        poke(&creator, 100, 0x42);
        assert_eq!(peek(&visitor, 100), 0x42);

        creator.resize(8192).unwrap();
        visitor.check_and_remap().unwrap();
        assert_eq!(peek(&visitor, 100), 0x42);

        poke(&creator, 5000, 0x99);
        assert_eq!(peek(&visitor, 5000), 0x99);
    }
}