use super::{
err::{Error as MMFError, MMFResult},
states::MMFLock,
};
use fixedstr::ztr64;
use microseh::try_seh;
use windows::{
core::Error as WErr,
Win32::{
Foundation::HANDLE,
System::Memory::{UnmapViewOfFile, MEMORY_MAPPED_VIEW_ADDRESS},
},
};
use std::cell::Cell;
#[cfg(feature = "impl_mmf")]
use std::{fmt, num::NonZeroUsize, ops::Deref};
#[cfg(feature = "impl_mmf")]
use windows::{
core::{w as widestring, PCWSTR},
Win32::{
Foundation::{CloseHandle, GetLastError, SetLastError, INVALID_HANDLE_VALUE, WIN32_ERROR},
Security::{
AdjustTokenPrivileges, LookupPrivilegeValueW, SE_PRIVILEGE_ENABLED, TOKEN_ADJUST_PRIVILEGES,
TOKEN_PRIVILEGES, TOKEN_QUERY,
},
System::{
Memory::{
CreateFileMappingW, GetLargePageMinimum, MapViewOfFile, OpenFileMappingW, FILE_MAP_ALL_ACCESS,
FILE_MAP_LARGE_PAGES, PAGE_EXECUTE_READWRITE, SEC_COMMIT, SEC_LARGE_PAGES,
},
SystemInformation::{GetSystemInfo, SYSTEM_INFO},
Threading::{GetCurrentProcess, OpenProcessToken},
},
},
};
#[cfg(feature = "impl_mmf")]
use windows_ext::ext::QWordExt;
/// Name prefix for mappings in the session-local kernel object namespace (`Local\`).
pub const LOCAL_NAMESPACE: ztr64 = ztr64::const_make("Local\\");
/// Name prefix for mappings in the global kernel object namespace (`Global\`).
pub const GLOBAL_NAMESPACE: ztr64 = ztr64::const_make("Global\\");
/// Selects which kernel object namespace prefix is prepended to an MMF's name.
///
/// The explicit `u8` discriminants are load-bearing: `TryFrom<u8>` maps raw values
/// back onto these variants, so do not reorder or renumber them.
#[cfg(feature = "namespaces")]
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Namespace {
    /// Prepend [`LOCAL_NAMESPACE`] (`Local\`).
    LOCAL = 0,
    /// Prepend [`GLOBAL_NAMESPACE`] (`Global\`).
    GLOBAL = 1,
    /// Prepend nothing; the provided name is used as-is.
    CUSTOM = 2,
}
#[cfg(feature = "namespaces")]
impl TryFrom<u8> for Namespace {
    type Error = ();

    /// Maps a raw `u8` back onto a [`Namespace`], failing for values above 2.
    ///
    /// # Errors
    /// Returns `Err(())` for any value without a corresponding variant.
    fn try_from(value: u8) -> Result<Namespace, Self::Error> {
        // Safe, exhaustive mapping. This replaces an `unsafe` transmute that silently
        // depended on the enum's `repr(u8)` layout and exact variant ordering; the
        // compiler now enforces the mapping stays in sync with the enum.
        match value {
            0 => Ok(Self::LOCAL),
            1 => Ok(Self::GLOBAL),
            2 => Ok(Self::CUSTOM),
            _ => Err(()),
        }
    }
}
#[cfg(feature = "namespaces")]
impl fmt::Display for Namespace {
    /// Writes the literal namespace prefix, or a placeholder for custom namespaces.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LOCAL => write!(f, "{LOCAL_NAMESPACE}"),
            Self::GLOBAL => write!(f, "{GLOBAL_NAMESPACE}"),
            // Exhaustive match instead of a `_` catch-all: adding a new variant now
            // produces a compile error here rather than silently printing this text.
            Self::CUSTOM => write!(f, "A custom namespace was used here."),
        }
    }
}
/// Lock-guarded read/write operations over a memory-mapped file.
///
/// The `*_spin` variants retry acquiring the lock up to `tries` times instead of
/// failing immediately. The meaning of `count == 0` is implementation-defined;
/// see the implementing type.
pub trait Mmf {
    /// Reads `count` bytes into a freshly allocated buffer.
    fn read(&self, count: usize) -> MMFResult<Vec<u8>>;
    /// Reads `count` bytes into `buffer`, growing it when needed.
    fn read_to_buf(&self, buffer: &mut Vec<u8>, count: usize) -> MMFResult<()>;
    /// Reads `count` bytes into the raw pointer `buffer`.
    ///
    /// # Safety
    /// `buffer` must be valid for writes of at least `count` bytes.
    unsafe fn read_to_raw(&self, buffer: *mut u8, count: usize) -> MMFResult<()>;
    /// Returns the usable payload size in bytes.
    fn size(&self) -> usize;
    /// Writes the whole `buffer` into the mapping.
    fn write(&self, buffer: impl Deref<Target = [u8]>) -> MMFResult<()>;
    /// Like [`Mmf::read`], spinning on the lock for up to `tries` attempts.
    fn read_spin(&self, count: usize, tries: usize) -> MMFResult<Vec<u8>>;
    /// Like [`Mmf::read_to_buf`], spinning on the lock for up to `tries` attempts.
    fn read_to_buf_spin(&self, buffer: &mut Vec<u8>, count: usize, tries: usize) -> MMFResult<()>;
    /// Like [`Mmf::read_to_raw`], spinning on the lock for up to `tries` attempts.
    ///
    /// # Safety
    /// `buffer` must be valid for writes of at least `count` bytes.
    unsafe fn read_to_raw_spin(&self, buffer: *mut u8, count: usize, tries: usize) -> MMFResult<()>;
    /// Like [`Mmf::write`], spinning on the lock for up to `tries` attempts.
    fn write_spin(&self, buffer: impl Deref<Target = [u8]>, tries: usize) -> MMFResult<()>;
}
/// A named, page-file-backed memory-mapped file whose first 4 bytes hold a `LOCK`
/// state; the usable payload starts at `write_ptr` (view base + 4).
#[derive(Debug)]
pub struct MemoryMappedFile<LOCK: MMFLock> {
    /// Handle returned by `CreateFileMappingW` / `OpenFileMappingW`.
    handle: HANDLE,
    /// Full object name (namespace prefix + user name) as NUL-terminated UTF-16.
    name: Vec<u16>,
    /// Pointer into `name`'s heap buffer (valid across moves of this struct since
    /// the Vec's allocation is stable); also doubles as the "not yet dropped" flag
    /// in `Drop`, which nulls it before closing.
    p_name: PCWSTR,
    // High DWORD of the rounded mapping size; retained but currently unused.
    #[allow(dead_code)]
    size_high_order: u32,
    // Low DWORD of the rounded mapping size; retained but currently unused.
    #[allow(dead_code)]
    size_low_order: u32,
    /// Usable payload size in bytes: total mapping size minus the 4-byte lock header.
    size: usize,
    /// Lock attached to the first bytes of the mapped view.
    lock: LOCK,
    /// Keeps the mapped view alive; the view is unmapped when this drops.
    map_view: Option<MemoryMappedView>,
    /// Start of the payload region: view base + 4.
    write_ptr: *mut u8,
    /// Set by `close()`; a `Cell` so closing works through `&self`.
    closed: Cell<bool>,
    /// When true, `write`/`write_spin` refuse to run.
    readonly: bool,
}
#[cold]
pub fn w32_enable_lock_mem() -> MMFResult<()> {
_ = try_seh(|| unsafe {
let mut tk = HANDLE::default();
let mut tp = TOKEN_PRIVILEGES::default();
OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES, &mut tk)?;
LookupPrivilegeValueW(None, widestring!("SeLockMemoryPrivilege"), &mut tp.Privileges[0].Luid)?;
tp.PrivilegeCount = 1;
tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
AdjustTokenPrivileges(tk, false, Some(&mut tp), 0, None, None)
});
if unsafe { GetLastError().is_err() } {
return Err(MMFError::LargePagePermissionError);
}
Ok(())
}
#[cfg(feature = "impl_mmf")]
impl<LOCK: MMFLock> MemoryMappedFile<LOCK> {
    /// Creates a new named mapping backed by the system paging file and maps a
    /// read/write view of it.
    ///
    /// `size` is the requested payload size in bytes; 4 extra bytes are reserved in
    /// front of the payload for the lock state (`write_ptr` starts at offset 4), and
    /// the total is rounded up to a whole number of pages.
    ///
    /// `large_pages` semantics:
    /// * `Some(false)` — never use large pages.
    /// * `Some(true)` — require large pages; fails if the privilege cannot be enabled.
    /// * `None` — opportunistic: use large pages when the total needs at least one
    ///   large page, silently falling back to normal pages if the privilege is
    ///   unavailable.
    ///
    /// # Errors
    /// [`MMFError::NotEnoughMemory`] if rounding the size overflows,
    /// [`MMFError::LargePagePermissionError`] when large pages were explicitly
    /// requested but unavailable, or a converted OS error from the Win32 calls.
    pub fn new(
        size: NonZeroUsize,
        name: impl Into<ztr64>,
        namespace: Namespace,
        large_pages: Option<bool>,
    ) -> MMFResult<Self> {
        let mut sysinfo = SYSTEM_INFO::default();
        unsafe { GetSystemInfo(&mut sysinfo) };
        let pagesize = sysinfo.dwPageSize as usize;
        // Large-page minimum, clamped so it is never below the normal page size.
        let alt_pagesize = unsafe { GetLargePageMinimum() }.max(pagesize);
        let has_lp = pagesize < alt_pagesize;
        // Large pages are chosen when: not explicitly disabled, AND either the total
        // (payload + 4-byte lock header) fills at least one large page
        // (`alt_pagesize < size + 5` ⟺ `alt_pagesize <= size + 4`) or large pages
        // were explicitly requested, AND the system actually supports them.
        let (mut use_size, mut use_lp) = if !large_pages.is_some_and(|lp| lp == false)
            && (alt_pagesize < (size.get() + 5) || large_pages.is_some_and(|lp| lp))
            && has_lp
        {
            ((size.get() + 4).checked_next_multiple_of(alt_pagesize).ok_or(MMFError::NotEnoughMemory)?, true)
        } else {
            ((size.get() + 4).checked_next_multiple_of(pagesize).ok_or(MMFError::NotEnoughMemory)?, false)
        };
        if use_lp {
            if large_pages.is_some() {
                // Explicitly requested: failure to enable the privilege is fatal.
                w32_enable_lock_mem()?;
            } else if w32_enable_lock_mem().is_err() {
                // Opportunistic: fall back to normal page rounding.
                (use_size, use_lp) =
                    ((size.get() + 4).checked_next_multiple_of(pagesize).ok_or(MMFError::NotEnoughMemory)?, false);
            }
        }
        // Build the full object name (namespace prefix + user name) as a
        // NUL-terminated UTF-16 buffer; `mmf_name` borrows this buffer's heap
        // allocation, which stays valid inside the returned struct.
        let mut init_name: Vec<u16> = (match namespace {
            Namespace::GLOBAL => GLOBAL_NAMESPACE,
            Namespace::LOCAL => LOCAL_NAMESPACE,
            Namespace::CUSTOM => ztr64::new(),
        } + name.into())
        .encode_utf16()
        .collect();
        init_name.push(b'\0' as _);
        let mmf_name = PCWSTR::from_raw(init_name.as_ptr());
        // Split the rounded size into the low/high DWORDs CreateFileMappingW expects.
        let (dw_low, dw_high) = use_size.split();
        let handle = try_seh(|| unsafe {
            // Reset the thread error state; failure is detected via GetLastError.
            SetLastError(WIN32_ERROR(0));
            CreateFileMappingW(
                // INVALID_HANDLE_VALUE: back the mapping with the paging file.
                INVALID_HANDLE_VALUE,
                None,
                if use_lp {
                    PAGE_EXECUTE_READWRITE | SEC_LARGE_PAGES | SEC_COMMIT
                } else {
                    PAGE_EXECUTE_READWRITE | SEC_COMMIT
                },
                dw_high,
                dw_low,
                mmf_name,
            )
        })??;
        let map_view = try_seh(|| unsafe {
            MapViewOfFile(
                handle,
                if use_lp { FILE_MAP_ALL_ACCESS | FILE_MAP_LARGE_PAGES } else { FILE_MAP_ALL_ACCESS },
                0,
                0,
                use_size,
            )
        })?;
        // MapViewOfFile failure is surfaced through the last-error value.
        if unsafe { GetLastError() }.is_err() {
            return Err(WErr::from_win32().into());
        }
        // Zero the whole mapping before attaching the lock to it.
        let zeroing = vec![0_u8; use_size];
        unsafe { std::ptr::copy(zeroing.as_ptr(), map_view.Value.cast(), zeroing.len()) };
        // The first bytes of the view hold the lock state; initialize it in place.
        let lock = unsafe { LOCK::from_raw(map_view.Value.cast()).initialize() };
        // Payload starts after the 4-byte lock header.
        let write_ptr = unsafe { map_view.Value.cast::<u8>().add(4) };
        Ok(Self {
            handle,
            name: init_name,
            p_name: mmf_name,
            size_high_order: dw_high,
            size_low_order: dw_low,
            // Usable payload size, excluding the lock header.
            size: use_size - 4,
            map_view: Some(map_view.into()),
            lock,
            write_ptr,
            closed: Cell::new(false),
            readonly: false,
        })
    }

    /// Opens an existing named mapping created elsewhere and maps a view of it.
    ///
    /// The sizing logic mirrors [`Self::new`] so that the computed view size matches
    /// the creator's. NOTE(review): `size` is not validated against the actual
    /// mapping — a larger value than the creator used would make later reads/writes
    /// touch memory past the real mapping; confirm callers always pass the
    /// creator's size.
    ///
    /// NOTE(review): the view is mapped with `FILE_MAP_ALL_ACCESS` even when
    /// `readonly` is true; read-only-ness is only enforced in software by
    /// `write`/`write_spin` — confirm this is intended.
    ///
    /// # Errors
    /// Same classes as [`Self::new`], plus the OS error when no mapping with this
    /// name exists.
    pub fn open(
        size: NonZeroUsize,
        name: impl Into<ztr64>,
        namespace: Namespace,
        readonly: bool,
        large_pages: Option<bool>,
    ) -> MMFResult<Self> {
        let mut sysinfo = SYSTEM_INFO::default();
        unsafe { GetSystemInfo(&mut sysinfo) };
        let pagesize = sysinfo.dwPageSize as usize;
        let alt_pagesize = unsafe { GetLargePageMinimum() }.max(pagesize);
        let has_lp = pagesize < alt_pagesize;
        // Same page-size / large-page decision as in `new`; see the comments there.
        let (mut use_size, mut use_lp) = if !large_pages.is_some_and(|lp| lp == false)
            && (alt_pagesize < (size.get() + 5) || large_pages.is_some_and(|lp| lp))
            && has_lp
        {
            ((size.get() + 4).checked_next_multiple_of(alt_pagesize).ok_or(MMFError::NotEnoughMemory)?, true)
        } else {
            ((size.get() + 4).checked_next_multiple_of(pagesize).ok_or(MMFError::NotEnoughMemory)?, false)
        };
        if use_lp {
            if large_pages.is_some() {
                w32_enable_lock_mem()?;
            } else if w32_enable_lock_mem().is_err() {
                (use_size, use_lp) =
                    ((size.get() + 4).checked_next_multiple_of(pagesize).ok_or(MMFError::NotEnoughMemory)?, false);
            }
        }
        // NUL-terminated UTF-16 object name; `mmf_name` borrows this buffer.
        let mut init_name: Vec<u16> = (match namespace {
            Namespace::GLOBAL => GLOBAL_NAMESPACE,
            Namespace::LOCAL => LOCAL_NAMESPACE,
            Namespace::CUSTOM => ztr64::new(),
        } + name.into())
        .encode_utf16()
        .collect();
        init_name.push(b'\0' as _);
        let mmf_name = PCWSTR::from_raw(init_name.as_ptr());
        let (dw_low, dw_high) = use_size.split();
        let handle = try_seh(|| unsafe {
            SetLastError(WIN32_ERROR(0));
            OpenFileMappingW(
                // OpenFileMappingW takes the raw access mask (`.0`), unlike
                // MapViewOfFile below which takes the typed flags.
                if use_lp { (FILE_MAP_ALL_ACCESS | FILE_MAP_LARGE_PAGES).0 } else { FILE_MAP_ALL_ACCESS.0 },
                false,
                mmf_name,
            )
        })??;
        let map_view = try_seh(|| unsafe {
            MapViewOfFile(
                handle,
                if use_lp { FILE_MAP_ALL_ACCESS | FILE_MAP_LARGE_PAGES } else { FILE_MAP_ALL_ACCESS },
                0,
                0,
                use_size,
            )
        })?;
        if unsafe { GetLastError() }.is_err() {
            return Err(WErr::from_win32().into());
        }
        // The creator already initialized the lock state; attach without resetting.
        let lock = unsafe { LOCK::from_existing(map_view.Value.cast()) };
        let write_ptr = unsafe { map_view.Value.cast::<u8>().add(4) };
        Ok(Self {
            handle,
            name: init_name,
            p_name: mmf_name,
            size_high_order: dw_high,
            size_low_order: dw_low,
            size: use_size - 4,
            lock,
            map_view: Some(map_view.into()),
            write_ptr,
            closed: Cell::new(false),
            readonly,
        })
    }

    /// Convenience wrapper: [`Self::open`] with `readonly = true`.
    pub fn open_read(
        size: NonZeroUsize,
        name: &str,
        namespace: Namespace,
        large_pages: Option<bool>,
    ) -> MMFResult<Self> {
        Self::open(size, name, namespace, true, large_pages)
    }

    /// Convenience wrapper: [`Self::open`] with `readonly = false`.
    pub fn open_write(
        size: NonZeroUsize,
        name: &str,
        namespace: Namespace,
        large_pages: Option<bool>,
    ) -> MMFResult<Self> {
        Self::open(size, name, namespace, false, large_pages)
    }

    /// True when the MMF is open, its lock is initialized, and it was not opened
    /// read-only.
    pub fn is_writable(&self) -> bool {
        !self.readonly && !self.closed.get() && self.lock.initialized()
    }

    /// True when the MMF is open and its lock is initialized.
    pub fn is_readable(&self) -> bool {
        !self.closed.get() && self.lock.initialized()
    }

    /// Returns the namespace prefix (text before the first `\`), or an empty string
    /// when the name has no separator (custom namespace).
    pub fn namespace(&self) -> String {
        String::from_utf16_lossy(&self.name).split_once('\\').unwrap_or_default().0.to_owned()
    }

    /// Returns the name portion after the first `\`, or the whole name when there
    /// is no separator.
    ///
    /// NOTE(review): `self.name` includes the trailing NUL appended at construction,
    /// so the returned `String` ends with a `'\0'` character — confirm intended.
    pub fn filename(&self) -> String {
        let s = String::from_utf16_lossy(&self.name);
        s.split_once('\\').map(|s| s.1.to_owned()).unwrap_or(s)
    }

    /// Returns the full object name, including the namespace prefix (and the
    /// trailing NUL — see [`Self::filename`]).
    pub fn fullname(&self) -> String {
        String::from_utf16_lossy(&self.name)
    }

    /// Closes the underlying mapping handle and marks the MMF as closed.
    ///
    /// The `closed` flag is set *before* the OS call so further reads/writes are
    /// refused even if closing fails. An `OS_OK` "error" (success reported through
    /// the error channel) is treated as success; real failures are logged to stderr
    /// and returned.
    pub fn close(&self) -> MMFResult<()> {
        self.closed.set(true);
        match try_seh(|| unsafe { CloseHandle(self.handle) })?.map_err(MMFError::from) {
            Err(MMFError::OS_OK(_)) | Ok(_) => Ok(()),
            err => err.map_err(|e| {
                eprintln!("Error closing MMF's handle: {:#?}", e);
                e
            }),
        }
    }
}
#[cfg(feature = "impl_mmf")]
impl<LOCK: MMFLock> Mmf for MemoryMappedFile<LOCK> {
    /// Reads `count` bytes (the whole payload when `count == 0`) into a new Vec.
    #[inline]
    fn read(&self, count: usize) -> Result<Vec<u8>, MMFError> {
        let mut buf = Vec::with_capacity(self.size);
        self.read_to_buf(&mut buf, count)?;
        Ok(buf)
    }

    /// Like [`Self::read`], spinning on the lock for up to `tries` attempts.
    fn read_spin(&self, count: usize, tries: usize) -> MMFResult<Vec<u8>> {
        let mut buf = Vec::with_capacity(self.size);
        self.read_to_buf_spin(&mut buf, count, tries)?;
        Ok(buf)
    }

    /// Reads into `buffer`, growing it when needed. `count == 0` reads the whole
    /// payload; larger counts are clamped to the payload size.
    fn read_to_buf(&self, buffer: &mut Vec<u8>, count: usize) -> MMFResult<()> {
        // Clamp to the usable payload size: `read_to_raw` never copies more than
        // `self.size` bytes, so a larger length passed to `set_len` below would
        // expose uninitialized memory.
        let to_read = if count == 0 { self.size } else { count.min(self.size) };
        let buf_cap = buffer.capacity();
        if buf_cap < to_read {
            buffer.reserve_exact(to_read - buf_cap);
        }
        unsafe {
            // Pass the resolved `to_read`, not the raw `count`: `count == 0` is the
            // "whole payload" sentinel here, but `read_to_raw` rejects a literal 0,
            // which previously made every `read(0)` fail with GeneralFailure.
            self.read_to_raw(buffer.as_mut_ptr(), to_read)?;
            // SAFETY: exactly `to_read` bytes were initialized by the copy above,
            // and at least that much capacity was reserved just before.
            buffer.set_len(to_read);
        }
        Ok(())
    }

    /// Spinning variant of [`Self::read_to_buf`]; same `count` semantics.
    fn read_to_buf_spin(&self, buffer: &mut Vec<u8>, count: usize, tries: usize) -> MMFResult<()> {
        // Same clamping rationale as in `read_to_buf`.
        let to_read = if count == 0 { self.size } else { count.min(self.size) };
        let buf_cap = buffer.capacity();
        if buf_cap < to_read {
            buffer.reserve_exact(to_read - buf_cap);
        }
        unsafe {
            self.read_to_raw_spin(buffer.as_mut_ptr(), to_read, tries)?;
            // SAFETY: exactly `to_read` bytes were initialized by the copy above.
            buffer.set_len(to_read);
        }
        Ok(())
    }

    /// Copies up to `count.min(self.size)` payload bytes into `buffer` under the
    /// read lock.
    ///
    /// # Safety
    /// `buffer` must be valid for writes of at least `count.min(self.size)` bytes.
    unsafe fn read_to_raw(&self, buffer: *mut u8, count: usize) -> Result<(), MMFError> {
        if self.closed.get() {
            Err(MMFError::MMF_NotFound)
        } else if count == 0 {
            // Zero-length raw reads are rejected; the buffered wrappers translate
            // `count == 0` into "whole payload" before calling here.
            Err(MMFError::GeneralFailure)
        } else if self.map_view.is_some() {
            if !self.lock.initialized() {
                return Err(MMFError::Uninitialized);
            }
            self.lock.lock_read()?;
            unsafe {
                self.write_ptr.copy_to(buffer, count.min(self.size));
            }
            // NOTE(review): unlock failure panics; consider propagating instead.
            self.lock.unlock_read().unwrap();
            Ok(())
        } else {
            Err(MMFError::MMF_NotFound)
        }
    }

    /// Spinning variant of [`Self::read_to_raw`].
    ///
    /// NOTE(review): unlike `read_to_raw`, this does not check `lock.initialized()`
    /// before locking — confirm `spin_and_lock_read` covers the uninitialized case.
    ///
    /// # Safety
    /// `buffer` must be valid for writes of at least `count.min(self.size)` bytes.
    unsafe fn read_to_raw_spin(&self, buffer: *mut u8, count: usize, tries: usize) -> MMFResult<()> {
        if self.closed.get() {
            Err(MMFError::MMF_NotFound)
        } else if count == 0 {
            Err(MMFError::GeneralFailure)
        } else if self.map_view.is_some() {
            self.lock.spin_and_lock_read(tries)?;
            unsafe {
                self.write_ptr.copy_to(buffer, count.min(self.size));
            }
            self.lock.unlock_read().unwrap();
            Ok(())
        } else {
            Err(MMFError::MMF_NotFound)
        }
    }

    /// Writes the entire `buffer` to the start of the payload under the write lock.
    ///
    /// # Errors
    /// `MMF_NotFound` when read-only, closed, or unmapped; `NotEnoughMemory` when
    /// the buffer exceeds the payload size; `Uninitialized` when the lock was never
    /// set up.
    fn write(&self, buffer: impl Deref<Target = [u8]>) -> MMFResult<()> {
        if self.readonly || self.closed.get() {
            return Err(MMFError::MMF_NotFound);
        }
        // `cap < buffer.len()` can only hold when the buffer exceeds the payload.
        let cap = buffer.len().min(self.size);
        if cap < buffer.len() {
            Err(MMFError::NotEnoughMemory)
        } else if !self.lock.initialized() {
            Err(MMFError::Uninitialized)
        } else if self.map_view.is_some() {
            self.lock.lock_write()?;
            let src_ptr = buffer.as_ptr();
            unsafe { src_ptr.copy_to(self.write_ptr, cap) };
            self.lock.unlock_write()
        } else {
            Err(MMFError::MMF_NotFound)
        }
    }

    /// Spinning variant of [`Self::write`].
    ///
    /// NOTE(review): unlike `write`, this does not check `lock.initialized()` —
    /// confirm `spin_and_lock_write` covers the uninitialized case.
    fn write_spin(&self, buffer: impl Deref<Target = [u8]>, tries: usize) -> MMFResult<()> {
        if self.readonly || self.closed.get() {
            return Err(MMFError::MMF_NotFound);
        }
        let cap = buffer.len().min(self.size);
        if cap < buffer.len() {
            Err(MMFError::NotEnoughMemory)
        } else if self.map_view.is_some() {
            self.lock.spin_and_lock_write(tries)?;
            let src_ptr = buffer.as_ptr();
            unsafe { src_ptr.copy_to(self.write_ptr, cap) };
            self.lock.unlock_write()
        } else {
            Err(MMFError::MMF_NotFound)
        }
    }

    /// Usable payload size in bytes (mapping size minus the lock header).
    fn size(&self) -> usize {
        self.size
    }
}
/// Owned view of a mapped region; the address is unmapped in `Drop`.
///
/// NOTE(review): this derives `Clone` while `Drop` calls `UnmapViewOfFile`, so a
/// clone and its original will both attempt to unmap the same address
/// (double-unmap). Consider removing `Clone`; the derive is kept here only to
/// preserve the public API.
#[derive(Debug, Clone)]
pub struct MemoryMappedView {
    /// Base address returned by `MapViewOfFile`.
    address: MEMORY_MAPPED_VIEW_ADDRESS,
}
impl From<MEMORY_MAPPED_VIEW_ADDRESS> for MemoryMappedView {
fn from(value: MEMORY_MAPPED_VIEW_ADDRESS) -> Self {
Self { address: value }
}
}
impl MemoryMappedView {
    /// Unmaps the view from the process address space.
    ///
    /// An `OS_OK` result (the OS reporting "no error" through the error channel) is
    /// treated as success; any real failure is logged to stderr and returned.
    fn unmap(&self) -> MMFResult<()> {
        let outcome = try_seh(|| unsafe { UnmapViewOfFile(self.address) })?;
        match outcome.map_err(MMFError::from) {
            Ok(()) | Err(MMFError::OS_OK(_)) => Ok(()),
            Err(e) => {
                eprintln!("Error unmapping the view of the MMF: {:#?}", e);
                Err(e)
            }
        }
    }
}
impl Drop for MemoryMappedView {
    /// Unmaps the view when it goes out of scope. Failures are deliberately
    /// discarded: `unmap` has already logged them and panicking in drop is unsafe.
    fn drop(&mut self) {
        let _ = self.unmap();
    }
}
#[cfg(feature = "impl_mmf")]
impl<LOCK: MMFLock> Drop for MemoryMappedFile<LOCK> {
    /// Closes the mapping handle on drop.
    ///
    /// `p_name` doubles as a "not yet dropped" flag: it is nulled before calling
    /// `close()` so the close runs at most once, and close errors are discarded
    /// here (they are already logged inside `close`).
    fn drop(&mut self) {
        if !self.p_name.is_null() {
            self.p_name = PCWSTR::null();
            self.close().unwrap_or(())
        }
    }
}
// SAFETY: NOTE(review) — the raw `write_ptr` and `map_view` make this type !Send/!Sync
// by default. These impls assert that moving/sharing the MMF across threads is sound
// whenever LOCK is Send + Sync, i.e. they rely on LOCK serializing all access to the
// shared view; confirm against the MMFLock contract. Also note `closed: Cell<bool>`
// carries no synchronization, so `Sync` additionally assumes races on the closed flag
// are acceptable — verify before enabling the `mmf_send` feature.
#[cfg(all(feature = "mmf_send", feature = "impl_mmf"))]
unsafe impl<LOCK: MMFLock + Send + Sync> Send for MemoryMappedFile<LOCK> {}
#[cfg(all(feature = "mmf_send", feature = "impl_mmf"))]
unsafe impl<LOCK: MMFLock + Send + Sync> Sync for MemoryMappedFile<LOCK> {}