use core::alloc::Layout;
use core::mem::transmute;
use core::ptr::{self, NonNull};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::{fmt, slice};
#[cfg(all(windows, not(miri)))]
use core::mem::MaybeUninit;
use flex_alloc::alloc::{AllocError, Allocator, AllocatorDefault, AllocatorZeroizes};
use flex_alloc::StorageError;
use zeroize::Zeroize;
#[cfg(all(unix, not(miri)))]
use libc::{free, mlock, mprotect, munlock, posix_memalign};
#[cfg(all(windows, not(miri)))]
use windows_sys::Win32::System::{Memory, SystemInformation};
pub const UNINIT_ALLOC_BYTE: u8 = 0xdb;
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MemoryError;
impl fmt::Display for MemoryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Memory error")
}
}
impl std::error::Error for MemoryError {}
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProtectionMode {
NoAccess,
ReadOnly,
#[default]
ReadWrite,
}
impl ProtectionMode {
#[cfg(all(unix, not(miri)))]
pub(crate) const fn as_native(self) -> i32 {
match self {
Self::NoAccess => libc::PROT_NONE,
Self::ReadOnly => libc::PROT_READ,
Self::ReadWrite => libc::PROT_READ | libc::PROT_WRITE,
}
}
#[cfg(all(windows, not(miri)))]
pub(crate) const fn as_native(self) -> u32 {
match self {
Self::NoAccess => windows_sys::Win32::System::Memory::PAGE_NOACCESS,
Self::ReadOnly => windows_sys::Win32::System::Memory::PAGE_READONLY,
Self::ReadWrite => windows_sys::Win32::System::Memory::PAGE_READWRITE,
}
}
}
pub fn default_page_size() -> usize {
static CACHE: AtomicUsize = AtomicUsize::new(0);
let mut size = CACHE.load(Ordering::Relaxed);
if size == 0 {
#[cfg(miri)]
{
size = 4096;
}
#[cfg(all(target_os = "macos", not(miri)))]
{
size = unsafe { libc::vm_page_size };
}
#[cfg(all(unix, not(target_os = "macos"), not(miri)))]
{
size = unsafe { libc::sysconf(libc::_SC_PAGE_SIZE) } as usize;
}
#[cfg(all(windows, not(miri)))]
{
let mut sysinfo = MaybeUninit::<SystemInformation::SYSTEM_INFO>::uninit();
unsafe { SystemInformation::GetSystemInfo(sysinfo.as_mut_ptr()) };
size = unsafe { sysinfo.assume_init_ref() }.dwPageSize as usize;
}
debug_assert_ne!(size, 0);
debug_assert_eq!(size % size_of::<*const ()>(), 0);
CACHE.store(size, Ordering::Relaxed);
}
size
}
pub fn alloc_pages(len: usize) -> Result<NonNull<[u8]>, AllocError> {
let page_size = default_page_size();
let alloc_len = page_rounded_length(len, page_size);
#[cfg(miri)]
{
let addr =
unsafe { std::alloc::alloc(Layout::from_size_align_unchecked(alloc_len, page_size)) };
let range = ptr::slice_from_raw_parts_mut(addr, alloc_len);
NonNull::new(range).ok_or_else(|| AllocError)
}
#[cfg(all(unix, not(miri)))]
{
let mut addr = ptr::null_mut();
let ret = unsafe { posix_memalign(&mut addr, page_size, alloc_len) };
if ret == 0 {
let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
Ok(NonNull::new(range).expect("null allocation result"))
} else {
Err(AllocError)
}
}
#[cfg(all(windows, not(miri)))]
{
let addr = unsafe {
Memory::VirtualAlloc(
ptr::null_mut(),
alloc_len,
Memory::MEM_COMMIT | Memory::MEM_RESERVE,
Memory::PAGE_READWRITE,
)
};
let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
NonNull::new(range).ok_or_else(|| AllocError)
}
}
pub fn dealloc_pages(addr: *mut u8, len: usize) {
#[cfg(miri)]
{
let page_size = default_page_size();
let alloc_len = page_rounded_length(len, page_size);
unsafe {
std::alloc::dealloc(
addr,
Layout::from_size_align_unchecked(alloc_len, page_size),
)
};
return;
}
#[cfg(all(unix, not(miri)))]
{
let _ = len;
unsafe { free(addr.cast()) };
}
#[cfg(all(windows, not(miri)))]
{
let _ = len;
unsafe { Memory::VirtualFree(addr.cast(), 0, Memory::MEM_RELEASE) };
}
}
pub fn lock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
#[cfg(miri)]
{
_ = (addr, len);
Ok(())
}
#[cfg(all(unix, not(miri)))]
{
#[cfg(target_os = "linux")]
madvise(addr, len, libc::MADV_DONTDUMP)?;
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
madvise(addr, len, libc::MADV_NOCORE)?;
let res = unsafe { mlock(addr.cast(), len) };
if res == 0 {
Ok(())
} else {
Err(MemoryError)
}
}
#[cfg(all(windows, not(miri)))]
{
let res = unsafe { Memory::VirtualLock(addr.cast(), len) };
if res != 0 {
Ok(())
} else {
Err(MemoryError)
}
}
}
pub fn unlock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
#[cfg(miri)]
{
_ = (addr, len);
Ok(())
}
#[cfg(all(unix, not(miri)))]
{
#[cfg(target_os = "linux")]
madvise(addr, len, libc::MADV_DODUMP)?;
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
madvise(addr, len, libc::MADV_CORE)?;
let res = unsafe { munlock(addr.cast(), len) };
if res == 0 {
Ok(())
} else {
Err(MemoryError)
}
}
#[cfg(all(windows, not(miri)))]
{
let res = unsafe { Memory::VirtualUnlock(addr.cast(), len) };
if res != 0 {
Ok(())
} else {
Err(MemoryError)
}
}
}
pub fn set_page_protection(
addr: *mut u8,
len: usize,
mode: ProtectionMode,
) -> Result<(), MemoryError> {
#[cfg(miri)]
{
_ = (addr, len, mode);
Ok(())
}
#[cfg(all(unix, not(miri)))]
{
let res = unsafe { mprotect(addr.cast(), len, mode.as_native()) };
if res == 0 {
Ok(())
} else {
Err(MemoryError)
}
}
#[cfg(all(windows, not(miri)))]
{
let mut prev_mode = MaybeUninit::<u32>::uninit();
let res = unsafe {
Memory::VirtualProtect(addr.cast(), len, mode.as_native(), prev_mode.as_mut_ptr())
};
if res != 0 {
Ok(())
} else {
Err(MemoryError)
}
}
}
#[cfg(unix)]
#[allow(unused)]
#[inline]
fn madvise(addr: *mut u8, len: usize, advice: i32) -> Result<(), MemoryError> {
{
let res = unsafe { libc::madvise(addr.cast(), len, advice) };
if res == 0 {
Ok(())
} else {
Err(MemoryError)
}
}
}
#[inline(always)]
pub fn page_rounded_length(len: usize, page_size: usize) -> usize {
len + ((page_size - (len & (page_size - 1))) % page_size)
}
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct SecureAlloc;
impl SecureAlloc {
pub(crate) fn set_page_protection(
&self,
ptr: *mut u8,
len: usize,
mode: ProtectionMode,
) -> Result<(), StorageError> {
if len != 0 {
let alloc_len = page_rounded_length(len, default_page_size());
set_page_protection(ptr, alloc_len, mode).map_err(|_| StorageError::ProtectionError)
} else {
Ok(())
}
}
}
unsafe impl Allocator for SecureAlloc {
#[inline]
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
debug_assert!(
layout.align() <= default_page_size(),
"alignment cannot exceed page size"
);
let layout_len = layout.size();
if layout_len == 0 {
#[allow(clippy::useless_transmute)]
let head = unsafe { NonNull::new_unchecked(transmute(layout.align())) };
Ok(NonNull::slice_from_raw_parts(head, 0))
} else {
let alloc = alloc_pages(layout_len).map_err(|_| AllocError)?;
let alloc_len = alloc.len();
unsafe { ptr::write_bytes(alloc.as_ptr().cast::<u8>(), UNINIT_ALLOC_BYTE, alloc_len) };
lock_pages(alloc.as_ptr().cast(), alloc_len).map_err(|_| AllocError)?;
Ok(alloc)
}
}
#[inline]
unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
let len = layout.size();
if len > 0 {
let alloc_len = page_rounded_length(len, default_page_size());
let mem = unsafe { slice::from_raw_parts_mut(ptr.as_ptr(), alloc_len) };
mem.zeroize();
unlock_pages(ptr.as_ptr().cast(), alloc_len).ok();
dealloc_pages(ptr.as_ptr(), alloc_len);
}
}
}
impl AllocatorDefault for SecureAlloc {
const DEFAULT: Self = Self;
}
impl AllocatorZeroizes for SecureAlloc {}
#[cfg(test)]
mod tests {
use core::alloc::Layout;
use flex_alloc::alloc::Allocator;
use crate::{alloc::UNINIT_ALLOC_BYTE, vec::SecureVec};
use super::SecureAlloc;
#[test]
fn check_extra_capacity() {
let vec = SecureVec::<usize>::with_capacity(1);
assert!(vec.capacity() > 1);
}
#[test]
fn check_uninit() {
let layout = Layout::new::<usize>();
let buf = SecureAlloc.allocate(layout).expect("allocation error");
#[allow(clippy::len_zero)]
{
assert!(buf.len() != 0 && unsafe { buf.as_ref() }[..4] == [UNINIT_ALLOC_BYTE; 4]);
}
unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
}
}