use crate::{os, util, Protection, QueryIter, Region, Result};
/// Changes the memory protection of the pages covering an address range.
///
/// The range `[address, address + size)` is first widened to page boundaries
/// with `util::round_to_page_boundaries`; the resulting span is then handed
/// to the OS-specific `os::protect` implementation.
///
/// # Errors
///
/// Returns the error produced by the page rounding step (the tests in this
/// module show that a null address with a zero size is rejected), or the
/// error returned by the underlying OS call.
///
/// # Safety
///
/// Altering the protection of memory that other code relies on can make
/// otherwise valid accesses fault, or make invalid accesses succeed. The
/// caller must ensure the new protection is compatible with every user of
/// the affected pages.
#[inline]
pub unsafe fn protect<T>(address: *const T, size: usize, protection: Protection) -> Result<()> {
    let (page_start, page_span) = util::round_to_page_boundaries(address, size)?;
    os::protect(page_start.cast(), page_span, protection)
}
/// Temporarily changes the memory protection of the pages covering an
/// address range.
///
/// The range is widened to page boundaries, the regions currently covering
/// it are recorded, and the requested protection is applied. The returned
/// [`ProtectGuard`] re-applies the recorded protection of every region when
/// it is dropped.
///
/// # Errors
///
/// Propagates any error from rounding the range to page boundaries, from
/// querying the regions that cover it, or from applying the new protection.
/// If applying the protection fails, no guard is created.
///
/// # Safety
///
/// The same requirements as [`protect`] apply, both for the protection
/// applied here and for the protection restored when the guard is dropped.
#[allow(clippy::missing_inline_in_public_items)]
pub unsafe fn protect_with_handle<T>(
    address: *const T,
    size: usize,
    protection: Protection,
) -> Result<ProtectGuard> {
    let (address, size) = util::round_to_page_boundaries(address, size)?;

    // Record the regions (and thereby their current protection) before
    // anything is modified, so they can be restored later.
    let mut regions = QueryIter::new(address, size)?.collect::<Result<Vec<_>>>()?;

    // Apply the requested protection.
    protect(address, size, protection)?;

    if let Some(region) = regions.first_mut() {
        // Clip the first region so it starts at the protected range. The
        // distance to clip must be measured *before* `base` is moved: once
        // `base` equals `address`, a range derived from `base` would report
        // a distance of zero and the region would keep its full length.
        let leading = (address as usize).saturating_sub(region.as_range().start);
        region.base = address.cast();
        region.size -= leading;
    }

    if let Some(region) = regions.last_mut() {
        // Clip the last region so it ends where the protected range ends.
        // Saturating: a region that already ends early is left untouched.
        let protect_end = address as usize + size;
        region.size -= region.as_range().end.saturating_sub(protect_end);
    }

    Ok(ProtectGuard::new(regions))
}
/// A guard returned by `protect_with_handle`.
///
/// It stores the regions recorded before the protection was changed; its
/// `Drop` implementation re-applies the protection of each stored region.
/// Marked `#[must_use]` because discarding the guard immediately drops it.
#[must_use]
pub struct ProtectGuard {
// Regions (base, size and protection) to re-apply on drop.
regions: Vec<Region>,
}
impl ProtectGuard {
#[inline(always)]
fn new(regions: Vec<Region>) -> Self {
Self { regions }
}
}
impl Drop for ProtectGuard {
#[inline]
fn drop(&mut self) {
let result = self
.regions
.iter()
.try_for_each(|region| unsafe { protect(region.base, region.size, region.protection) });
debug_assert!(result.is_ok(), "restoring region protection: {:?}", result);
}
}
// Manually opt the guard into `Send` and `Sync`. Its only state is the list
// of recorded regions, and the only operation defined on it in this file is
// `drop`.
// NOTE(review): `Region::base` is assigned from a raw pointer in
// `protect_with_handle`, which is presumably why these impls are written by
// hand; confirm that restoring protection from another thread is sound on
// every supported OS.
unsafe impl Send for ProtectGuard {}
unsafe impl Sync for ProtectGuard {}
// Tests for `protect` and `protect_with_handle`. Most of them build a mapping
// with `alloc_pages` (presumably one page per listed protection; verify
// against the helper) and inspect the result with `query` / `query_range`.
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::util::alloc_pages;
use crate::{page, query, query_range};
// A null address with a zero size must be rejected with an error.
#[test]
fn protect_null_fails() {
assert!(unsafe { protect(std::ptr::null::<()>(), 0, Protection::NONE) }.is_err());
}
// Makes the page containing `address` read/write/execute and then writes a
// single byte (0x90) through the pointer.
// NOTE(review): `&mut protect_can_alter_text_segments` borrows a temporary of
// the function *item* type rather than taking the function's code address, so
// `address` may not point into the text segment at all; confirm the intent.
#[test]
#[cfg(not(target_os = "openbsd"))]
fn protect_can_alter_text_segments() {
#[allow(clippy::ptr_as_ptr)]
let address = &mut protect_can_alter_text_segments as *mut _ as *mut u8;
unsafe {
protect(address, 1, Protection::READ_WRITE_EXECUTE).unwrap();
*address = 0x90;
}
}
// A two-byte range that starts on the last byte of one page must change the
// protection of that page and of the following one.
#[test]
fn protect_updates_both_pages_for_straddling_range() -> Result<()> {
let pz = page::size();
// Neighbouring pages get different protections, so the two pages under test
// differ from each other and from their outer neighbours.
let map = alloc_pages(&[
Protection::READ,
Protection::READ_EXECUTE,
Protection::READ_WRITE,
Protection::READ,
]);
// Start of the second page, and the address of its last byte.
let exec_page = unsafe { map.as_ptr().add(pz) };
let exec_page_end = unsafe { exec_page.add(pz - 1) };
// Two bytes from the last byte of the second page reach into the third page.
unsafe {
protect(exec_page_end, 2, Protection::NONE)?;
}
// Query the second and third pages.
let result = query_range(exec_page, pz * 2)?.collect::<Result<Vec<_>>>()?;
// Either one or two regions are accepted (presumably because some OSs merge
// adjacent pages with equal protection), but together they must span exactly
// two pages.
assert!(matches!(result.len(), 1 | 2));
assert_eq!(result.iter().map(Region::len).sum::<usize>(), pz * 2);
assert_eq!(result[0].protection(), Protection::NONE);
Ok(())
}
// The first byte of the range is included and the byte one past its end is
// not: a one-byte range on a page's last byte affects only that page, while
// one byte more than a page reaches the next page.
#[test]
fn protect_has_inclusive_lower_and_exclusive_upper_bound() -> Result<()> {
let map = alloc_pages(&[
Protection::READ_WRITE,
Protection::READ,
Protection::READ_WRITE,
Protection::READ,
]);
let second_page = unsafe { map.as_ptr().add(page::size()) };
// Protect exactly one byte: the last byte of the second page.
unsafe {
let second_page_end = second_page.offset(page::size() as isize - 1);
protect(second_page_end, 1, Protection::NONE)?;
}
// Only the second page may have changed; its neighbours keep READ_WRITE.
let regions = query_range(map.as_ptr(), page::size() * 3)?.collect::<Result<Vec<_>>>()?;
assert_eq!(regions.len(), 3);
assert_eq!(regions[0].protection(), Protection::READ_WRITE);
assert_eq!(regions[1].protection(), Protection::NONE);
assert_eq!(regions[2].protection(), Protection::READ_WRITE);
// Protect from the start of the second page through one byte past its end.
unsafe {
protect(second_page, page::size() + 1, Protection::READ_EXECUTE)?;
}
// The region after the first page must now be READ_EXECUTE and at least one
// page long; the exact region count is not pinned down.
let regions = query_range(map.as_ptr(), page::size() * 3)?.collect::<Result<Vec<_>>>()?;
assert!(regions.len() >= 2);
assert_eq!(regions[0].protection(), Protection::READ_WRITE);
assert_eq!(regions[1].protection(), Protection::READ_EXECUTE);
assert!(regions[1].len() >= page::size());
Ok(())
}
// The new protection is visible while the guard is alive and the previous
// protection is back once the guard has been dropped.
#[test]
fn protect_with_handle_resets_protection() -> Result<()> {
let map = alloc_pages(&[Protection::READ]);
unsafe {
// `_handle` lives until the end of this block.
let _handle = protect_with_handle(map.as_ptr(), page::size(), Protection::READ_WRITE)?;
assert_eq!(query(map.as_ptr())?.protection(), Protection::READ_WRITE);
};
// The guard has been dropped here; the page must be READ again.
assert_eq!(query(map.as_ptr())?.protection(), Protection::READ);
Ok(())
}
// A guard covering three pages with different protections must restore each
// page to its own previous protection and leave the outer pages untouched.
#[test]
fn protect_with_handle_only_alters_protection_of_affected_pages() -> Result<()> {
let pages = [
Protection::READ_WRITE,
Protection::READ,
Protection::READ_WRITE,
Protection::READ_EXECUTE,
Protection::NONE,
];
let map = alloc_pages(&pages);
// The guarded range is the second, third and fourth page.
let second_page = unsafe { map.as_ptr().add(page::size()) };
let region_size = page::size() * 3;
unsafe {
let _handle = protect_with_handle(second_page, region_size, Protection::NONE)?;
// While the guard is alive, the region at the second page is NONE and starts
// exactly at the second page.
let region = query(second_page)?;
assert_eq!(region.protection(), Protection::NONE);
assert_eq!(region.as_ptr(), second_page);
}
// After the guard is dropped, all five pages must report their initial
// protection again, as five separate regions.
let regions =
query_range(map.as_ptr(), page::size() * pages.len())?.collect::<Result<Vec<_>>>()?;
assert_eq!(regions.len(), 5);
assert_eq!(regions[0].as_ptr(), map.as_ptr());
for i in 0..pages.len() {
assert_eq!(regions[i].protection(), pages[i]);
}
Ok(())
}
}