use std::fs::{File, OpenOptions};
use std::io;
use std::os::windows::ffi::OsStrExt;
use std::os::windows::io::{AsRawHandle, FromRawHandle};
use std::path::{Path, PathBuf};
use std::vec::Vec;
use windows_sys::Win32::Foundation::{
CloseHandle, GENERIC_READ, GENERIC_WRITE, HANDLE, INVALID_HANDLE_VALUE,
};
use windows_sys::Win32::Storage::FileSystem::{
CREATE_ALWAYS, CreateFileW, FILE_ATTRIBUTE_NORMAL, FILE_FLAG_DELETE_ON_CLOSE, FILE_SHARE_READ,
FILE_SHARE_WRITE,
};
use windows_sys::Win32::System::Memory::{
CreateFileMappingW, FILE_MAP_ALL_ACCESS, MEMORY_MAPPED_VIEW_ADDRESS, MapViewOfFile,
PAGE_READWRITE, UnmapViewOfFile,
};
use crate::Region;
/// Selects who is responsible for deleting the backing file of a region
/// produced by [`MmapRegion::create`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FileCleanup {
    /// The region owns the file and removes it by path when it is dropped.
    Manual,
    /// The file is opened with `FILE_FLAG_DELETE_ON_CLOSE`, so the OS removes
    /// it once the last handle to it is closed.
    Auto,
}
/// A read/write memory-mapped view over an entire segment ("SHM") file.
///
/// Obtain one with [`MmapRegion::create`] (creates/truncates the file) or
/// [`MmapRegion::attach`] (maps an existing file at its current length).
pub struct MmapRegion {
    /// Base address of the view returned by `MapViewOfFile`.
    ptr: *mut u8,
    /// Length of the mapped view in bytes.
    len: usize,
    /// Backing file; kept open for as long as the mapping exists.
    file: File,
    /// File-mapping object from `CreateFileMappingW`; closed in `Drop`.
    mapping_handle: HANDLE,
    /// Path of the backing file, used to delete it when `owns_file` is set.
    path: PathBuf,
    /// When true, `Drop` removes the file at `path`.
    owns_file: bool,
}
impl MmapRegion {
pub fn create(path: &Path, size: usize, cleanup: FileCleanup) -> io::Result<Self> {
if size == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"size must be > 0",
));
}
let file = if cleanup == FileCleanup::Auto {
let path_wide: Vec<u16> = path
.as_os_str()
.encode_wide()
.chain(std::iter::once(0))
.collect();
let handle = unsafe {
CreateFileW(
path_wide.as_ptr(),
GENERIC_READ | GENERIC_WRITE,
FILE_SHARE_READ | FILE_SHARE_WRITE,
std::ptr::null(),
CREATE_ALWAYS,
FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE,
std::ptr::null_mut(),
)
};
if handle == INVALID_HANDLE_VALUE {
let err = io::Error::last_os_error();
let msg = std::format!("Failed to create SHM file at {}: {}", path.display(), err);
return Err(io::Error::new(err.kind(), msg));
}
unsafe { File::from_raw_handle(handle as _) }
} else {
OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true)
.open(path)
.map_err(|e| {
let msg =
std::format!("Failed to create SHM file at {}: {}", path.display(), e);
io::Error::new(e.kind(), msg)
})?
};
file.set_len(size as u64)?;
let file_handle = file.as_raw_handle() as HANDLE;
let mapping_handle = unsafe {
CreateFileMappingW(
file_handle,
std::ptr::null(), PAGE_READWRITE, (size >> 32) as u32, size as u32, std::ptr::null(), )
};
if mapping_handle.is_null() {
return Err(io::Error::last_os_error());
}
let ptr = unsafe {
MapViewOfFile(
mapping_handle,
FILE_MAP_ALL_ACCESS,
0, 0, size, )
};
if ptr.Value.is_null() {
unsafe { CloseHandle(mapping_handle) };
return Err(io::Error::last_os_error());
}
Ok(Self {
ptr: ptr.Value as *mut u8,
len: size,
file,
mapping_handle,
path: path.to_path_buf(),
owns_file: cleanup == FileCleanup::Manual,
})
}
pub fn attach(path: &Path) -> io::Result<Self> {
let file = OpenOptions::new()
.read(true)
.write(true)
.open(path)
.map_err(|e| {
let msg = std::format!("Failed to open SHM file at {}: {}", path.display(), e);
io::Error::new(e.kind(), msg)
})?;
let metadata = file.metadata()?;
let size = metadata.len() as usize;
if size == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"segment file is empty",
));
}
let file_handle = file.as_raw_handle() as HANDLE;
let mapping_handle = unsafe {
CreateFileMappingW(
file_handle,
std::ptr::null(),
PAGE_READWRITE,
(size >> 32) as u32,
size as u32,
std::ptr::null(),
)
};
if mapping_handle.is_null() {
return Err(io::Error::last_os_error());
}
let ptr = unsafe { MapViewOfFile(mapping_handle, FILE_MAP_ALL_ACCESS, 0, 0, size) };
if ptr.Value.is_null() {
unsafe { CloseHandle(mapping_handle) };
return Err(io::Error::last_os_error());
}
Ok(Self {
ptr: ptr.Value as *mut u8,
len: size,
file,
mapping_handle,
path: path.to_path_buf(),
owns_file: false, })
}
#[inline]
pub fn region(&self) -> Region {
unsafe { Region::from_raw(self.ptr, self.len) }
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn path(&self) -> &Path {
&self.path
}
pub fn take_ownership(&mut self) {
self.owns_file = true;
}
pub fn release_ownership(&mut self) {
self.owns_file = false;
}
pub fn resize(&mut self, new_size: usize) -> io::Result<()> {
if new_size < self.len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"shrinking is not supported",
));
}
if new_size == self.len {
return Ok(()); }
self.file.set_len(new_size as u64)?;
let file_handle = self.file.as_raw_handle() as HANDLE;
let new_mapping = unsafe {
CreateFileMappingW(
file_handle,
std::ptr::null(),
PAGE_READWRITE,
(new_size >> 32) as u32,
new_size as u32,
std::ptr::null(),
)
};
if new_mapping.is_null() {
return Err(io::Error::last_os_error());
}
let new_view = unsafe { MapViewOfFile(new_mapping, FILE_MAP_ALL_ACCESS, 0, 0, new_size) };
if new_view.Value.is_null() {
unsafe { CloseHandle(new_mapping) };
return Err(io::Error::last_os_error());
}
unsafe {
UnmapViewOfFile(MEMORY_MAPPED_VIEW_ADDRESS {
Value: self.ptr as *mut _,
});
CloseHandle(self.mapping_handle);
}
self.ptr = new_view.Value as *mut u8;
self.len = new_size;
self.mapping_handle = new_mapping;
Ok(())
}
pub fn check_and_remap(&mut self) -> io::Result<bool> {
let file_size = self.file.metadata()?.len() as usize;
if file_size > self.len {
self.resize(file_size)?;
Ok(true)
} else {
Ok(false)
}
}
}
impl Drop for MmapRegion {
    /// Releases the view and the mapping object, then deletes the backing file
    /// by path if this region owns it.
    ///
    /// Best-effort: results of the Win32 calls and of the removal are ignored,
    /// since `drop` cannot report them. Runs before the fields are dropped, so
    /// `self.file` is still open when the removal is attempted.
    fn drop(&mut self) {
        let view = MEMORY_MAPPED_VIEW_ADDRESS {
            Value: self.ptr.cast(),
        };
        // SAFETY: `view` is the base address of the live view owned by `self`;
        // it is not used again after this point.
        unsafe { UnmapViewOfFile(view) };
        // SAFETY: `mapping_handle` is the mapping object owned by `self`,
        // closed exactly once here.
        unsafe { CloseHandle(self.mapping_handle) };
        if !self.owns_file {
            return;
        }
        let _ = std::fs::remove_file(&self.path);
    }
}
// SAFETY: `MmapRegion` holds a raw view pointer and a Win32 `HANDLE`, so it is
// not automatically `Send`/`Sync`. Every change to those fields goes through
// `&mut self` (`resize`, `check_and_remap`) or `Drop`, and the handles are
// presumably valid from any thread of the process.
// NOTE(review): `region(&self)` exposes the shared bytes via `Region`; whether
// concurrent access through it is sound depends on `Region`'s contract — confirm.
unsafe impl Send for MmapRegion {}
unsafe impl Sync for MmapRegion {}
#[cfg(test)]
mod tests {
    // Exercises create/attach, cleanup ownership, shared visibility of writes,
    // growth via resize/check_and_remap, and the error paths of both constructors.
    use super::*;
    #[test]
    fn test_create_and_attach() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.shm");
        let region1 = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(region1.len(), 4096);
        assert!(path.exists());
        let data = region1.region();
        unsafe {
            std::ptr::write(data.as_ptr(), 0x42);
            std::ptr::write(data.as_ptr().add(1), 0x43);
        }
        let region2 = MmapRegion::attach(&path).unwrap();
        assert_eq!(region2.len(), 4096);
        let data2 = region2.region();
        unsafe {
            assert_eq!(std::ptr::read(data2.as_ptr()), 0x42);
            assert_eq!(std::ptr::read(data2.as_ptr().add(1)), 0x43);
        }
    }
    #[test]
    fn test_cleanup_on_drop() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cleanup.shm");
        {
            let _region = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
            assert!(path.exists());
        }
        assert!(!path.exists());
    }
    #[test]
    fn test_attached_does_not_cleanup() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("attached.shm");
        let owner = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
        {
            let _attached = MmapRegion::attach(&path).unwrap();
            assert!(path.exists());
        }
        assert!(path.exists());
        drop(owner);
        assert!(!path.exists());
    }
    #[test]
    fn test_shared_writes() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shared.shm");
        let region1 = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let region2 = MmapRegion::attach(&path).unwrap();
        let data2 = region2.region();
        unsafe {
            std::ptr::write(data2.as_ptr().add(100), 0xAB);
        }
        let data1 = region1.region();
        unsafe {
            assert_eq!(std::ptr::read(data1.as_ptr().add(100)), 0xAB);
        }
    }
    #[test]
    fn test_zero_size_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("zero.shm");
        let result = MmapRegion::create(&path, 0, FileCleanup::Manual);
        assert_eq!(
            result.err().map(|e| e.kind()),
            Some(io::ErrorKind::InvalidInput)
        );
        // The size check runs before anything touches the filesystem.
        assert!(!path.exists());
    }
    #[test]
    fn test_create_truncates_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("existing.shm");
        std::fs::write(&path, [1u8; 8192]).unwrap();
        let region = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(region.len(), 4096);
        assert_eq!(std::fs::metadata(&path).unwrap().len(), 4096);
        unsafe {
            assert_eq!(std::ptr::read(region.region().as_ptr()), 0);
        }
    }
    #[test]
    fn test_attach_missing_file_fails() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.shm");
        let err = MmapRegion::attach(&path)
            .err()
            .expect("attaching a missing file must fail");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert!(!path.exists());
    }
    #[test]
    fn test_attach_empty_file_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("empty.shm");
        File::create(&path).unwrap();
        let err = MmapRegion::attach(&path)
            .err()
            .expect("attaching an empty file must fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
    #[test]
    fn test_resize_grows_region() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("resize.shm");
        let mut region = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(region.len(), 4096);
        unsafe {
            std::ptr::write(region.region().as_ptr(), 0xAB);
        }
        region.resize(8192).unwrap();
        assert_eq!(region.len(), 8192);
        unsafe {
            assert_eq!(std::ptr::read(region.region().as_ptr()), 0xAB);
        }
        unsafe {
            std::ptr::write(region.region().as_ptr().add(5000), 0xCD);
            assert_eq!(std::ptr::read(region.region().as_ptr().add(5000)), 0xCD);
        }
    }
    #[test]
    fn test_resize_shrink_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shrink.shm");
        let mut region = MmapRegion::create(&path, 8192, FileCleanup::Manual).unwrap();
        let result = region.resize(4096);
        assert!(result.is_err());
    }
    #[test]
    fn test_check_and_remap() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("remap.shm");
        let mut owner = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let mut guest = MmapRegion::attach(&path).unwrap();
        assert_eq!(guest.len(), 4096);
        owner.resize(8192).unwrap();
        let remapped = guest.check_and_remap().unwrap();
        assert!(remapped);
        assert_eq!(guest.len(), 8192);
        let remapped2 = guest.check_and_remap().unwrap();
        assert!(!remapped2);
    }
    #[test]
    fn test_resize_preserves_shared_data() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shared_resize.shm");
        let mut owner = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let mut guest = MmapRegion::attach(&path).unwrap();
        unsafe {
            std::ptr::write(owner.region().as_ptr().add(100), 0x42);
        }
        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(100)), 0x42);
        }
        owner.resize(8192).unwrap();
        guest.check_and_remap().unwrap();
        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(100)), 0x42);
        }
        unsafe {
            std::ptr::write(owner.region().as_ptr().add(5000), 0x99);
        }
        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(5000)), 0x99);
        }
    }
}