use super::{MemoryRegion, Protection};
use crate::error::{Result, ShroudError};
use crate::policy::Policy;
pub fn page_size() -> usize {
unsafe { libc::sysconf(libc::_SC_PAGESIZE) as usize }
}
fn round_up_to_page(size: usize, page_size: usize) -> usize {
(size + page_size - 1) & !(page_size - 1)
}
#[allow(dead_code)]
pub fn allocate(size: usize, policy: Policy) -> Result<MemoryRegion> {
allocate_aligned(size, 1, policy)
}
pub fn allocate_aligned(size: usize, alignment: usize, policy: Policy) -> Result<MemoryRegion> {
let page_sz = page_size();
debug_assert!(
alignment.is_power_of_two(),
"alignment must be a power of 2"
);
debug_assert!(
alignment <= page_sz,
"alignment ({}) cannot exceed page size ({})",
alignment,
page_sz
);
if size == 0 {
return Ok(MemoryRegion {
ptr: core::ptr::NonNull::dangling().as_ptr(),
len: 0,
alloc_ptr: core::ptr::NonNull::dangling().as_ptr(),
alloc_len: 0,
alloc_align: alignment,
is_locked: false,
has_guard_pages: false,
is_protected: core::sync::atomic::AtomicBool::new(false),
});
}
let data_pages = round_up_to_page(size, page_sz);
let use_guard_pages = cfg!(feature = "guard-pages") && policy.protection_enabled();
let guard_size = if use_guard_pages { page_sz } else { 0 };
let total_size = guard_size
.checked_add(data_pages)
.and_then(|s| s.checked_add(guard_size))
.ok_or_else(|| ShroudError::AllocationFailed("size calculation overflow".to_string()))?;
let alloc_ptr = unsafe {
libc::mmap(
core::ptr::null_mut(),
total_size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
-1,
0,
)
};
if alloc_ptr == libc::MAP_FAILED {
return Err(ShroudError::AllocationFailed(format!(
"mmap failed: {}",
std::io::Error::last_os_error()
)));
}
let alloc_ptr = alloc_ptr as *mut u8;
let data_ptr = unsafe { alloc_ptr.add(guard_size) };
debug_assert!(
data_ptr as usize % alignment == 0,
"data pointer {:p} is not aligned to {} bytes",
data_ptr,
alignment
);
let has_guard_pages = if use_guard_pages {
let leading_guard_ok = set_protection(alloc_ptr, page_sz, Protection::None).is_ok();
let trailing_guard_ptr = unsafe { alloc_ptr.add(guard_size + data_pages) };
let trailing_guard_ok =
set_protection(trailing_guard_ptr, page_sz, Protection::None).is_ok();
if policy.is_strict() && (!leading_guard_ok || !trailing_guard_ok) {
unsafe { libc::munmap(alloc_ptr as *mut libc::c_void, total_size) };
return Err(ShroudError::ProtectFailed(
"failed to set up guard pages".to_string(),
));
}
leading_guard_ok && trailing_guard_ok
} else {
false
};
let use_mlock = cfg!(feature = "mlock") && policy.protection_enabled();
let is_locked = if use_mlock {
let result = unsafe { libc::mlock(data_ptr as *const libc::c_void, data_pages) };
if result == 0 {
true
} else if policy.is_strict() {
unsafe { libc::munmap(alloc_ptr as *mut libc::c_void, total_size) };
return Err(ShroudError::LockFailed(format!(
"mlock failed: {}",
std::io::Error::last_os_error()
)));
} else {
false
}
} else {
false
};
#[cfg(target_os = "linux")]
{
const MADV_DONTDUMP: libc::c_int = 16;
let result =
unsafe { libc::madvise(data_ptr as *mut libc::c_void, data_pages, MADV_DONTDUMP) };
if result != 0 && policy.is_strict() {
if is_locked {
unsafe { libc::munlock(data_ptr as *const libc::c_void, data_pages) };
}
unsafe { libc::munmap(alloc_ptr as *mut libc::c_void, total_size) };
return Err(ShroudError::ProtectFailed(format!(
"MADV_DONTDUMP failed: {}",
std::io::Error::last_os_error()
)));
}
}
Ok(MemoryRegion {
ptr: data_ptr,
len: size,
alloc_ptr,
alloc_len: total_size,
alloc_align: alignment,
is_locked,
has_guard_pages,
is_protected: core::sync::atomic::AtomicBool::new(false),
})
}
fn set_protection(ptr: *mut u8, len: usize, protection: Protection) -> Result<()> {
let prot = match protection {
Protection::None => libc::PROT_NONE,
Protection::Read => libc::PROT_READ,
Protection::ReadWrite => libc::PROT_READ | libc::PROT_WRITE,
};
let result = unsafe { libc::mprotect(ptr as *mut libc::c_void, len, prot) };
if result == 0 {
Ok(())
} else {
Err(ShroudError::ProtectFailed(format!(
"mprotect failed: {}",
std::io::Error::last_os_error()
)))
}
}
pub fn protect(ptr: *mut u8, len: usize, protection: Protection) -> Result<()> {
if len == 0 {
return Ok(());
}
let page_sz = page_size();
let len_rounded = round_up_to_page(len, page_sz);
set_protection(ptr, len_rounded, protection)
}
pub fn unlock(ptr: *mut u8, len: usize) -> Result<()> {
if len == 0 {
return Ok(());
}
let result = unsafe { libc::munlock(ptr as *const libc::c_void, len) };
if result == 0 {
Ok(())
} else {
Err(ShroudError::UnlockFailed(format!(
"munlock failed: {}",
std::io::Error::last_os_error()
)))
}
}
pub fn deallocate(ptr: *mut u8, len: usize, _alignment: usize) -> Result<()> {
if len == 0 {
return Ok(());
}
let result = unsafe { libc::munmap(ptr as *mut libc::c_void, len) };
if result == 0 {
Ok(())
} else {
Err(ShroudError::DeallocationFailed(format!(
"munmap failed: {}",
std::io::Error::last_os_error()
)))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_page_size() {
let ps = page_size();
assert!(ps >= 4096);
assert!(ps.is_power_of_two());
}
#[test]
fn test_round_up_to_page() {
let ps = 4096;
assert_eq!(round_up_to_page(0, ps), 0);
assert_eq!(round_up_to_page(1, ps), 4096);
assert_eq!(round_up_to_page(4096, ps), 4096);
assert_eq!(round_up_to_page(4097, ps), 8192);
}
#[test]
fn test_allocate_zero_size() {
let region = allocate(0, Policy::BestEffort).unwrap();
assert!(region.is_empty());
}
#[test]
fn test_allocate_and_write() {
let mut region = allocate(100, Policy::BestEffort).unwrap();
assert_eq!(region.len(), 100);
unsafe {
let slice = region.as_mut_slice();
slice[0] = 42;
slice[99] = 255;
}
unsafe {
let slice = region.as_slice();
assert_eq!(slice[0], 42);
assert_eq!(slice[99], 255);
}
}
#[cfg(target_os = "linux")]
#[test]
fn test_mlock_verification() {
use std::fs;
fn get_vmlck() -> Option<usize> {
let status = fs::read_to_string("/proc/self/status").ok()?;
for line in status.lines() {
if line.starts_with("VmLck:") {
let parts: Vec<&str> = line.split_whitespace().collect();
return parts.get(1)?.parse().ok();
}
}
None
}
let initial_vmlck = get_vmlck().unwrap_or(0);
let result = allocate(8192, Policy::BestEffort);
if let Ok(region) = result {
if region.is_locked {
let new_vmlck = get_vmlck().unwrap_or(0);
assert!(
new_vmlck >= initial_vmlck,
"VmLck should not decrease after mlock"
);
}
}
}
}