use {os, page, query_range, Error, Region, Result};
pub unsafe fn protect(address: *const u8, size: usize, protection: Protection) -> Result<()> {
if address.is_null() {
Err(Error::NullAddress)?;
}
if size == 0 {
Err(Error::EmptyRange)?;
}
os::set_protection(
page::floor(address as usize) as *const u8,
page::size_from_range(address, size),
protection,
)
}
/// Temporarily changes the memory protection of one or more pages.
///
/// Snapshots the current protection of every region overlapping
/// `[address, address + size)`, applies `protection`, and returns a guard
/// that restores the snapshotted protections when dropped.
///
/// # Safety
///
/// Same hazards as [`protect`]: the caller must ensure the affected pages
/// are safe to reprotect (and to restore later).
pub unsafe fn protect_with_handle(
    address: *const u8,
    size: usize,
    protection: Protection,
) -> Result<ProtectGuard> {
    // Snapshot the existing protections *before* altering anything.
    // NOTE(review): the trimming below assumes `query_range` returns regions
    // sorted by address and covering the queried range — confirm its contract.
    let mut regions = query_range(address, size)?;

    protect(address, size, protection)?;

    // Page-aligned bounds of the range this call actually touched.
    let lower = page::floor(address as usize);
    let upper = page::ceil(address as usize + size);

    if let Some(ref mut region) = regions.first_mut() {
        // Trim the leading region so the guard only restores pages we changed.
        let delta = lower - region.base as usize;
        region.base = (region.base as usize + delta) as *mut u8;
        region.size -= delta;
    }

    if let Some(ref mut region) = regions.last_mut() {
        // Likewise trim any tail of the last region beyond the affected pages.
        // (If there is only one region, it has already been trimmed above.)
        let delta = region.upper() - upper;
        region.size -= delta;
    }

    Ok(ProtectGuard::new(regions))
}
/// An RAII guard returned by [`protect_with_handle`]: when dropped, it
/// re-applies the memory protections that were in effect before the call.
#[must_use]
pub struct ProtectGuard {
    // Snapshot of the affected regions (base, size, previous protection).
    regions: Vec<Region>,
}
impl ProtectGuard {
fn new(regions: Vec<Region>) -> Self {
ProtectGuard { regions }
}
pub unsafe fn release(self) {
::std::mem::forget(self);
}
}
impl Drop for ProtectGuard {
    fn drop(&mut self) {
        // Re-apply the protection each region had before the guard was created;
        // `try_for_each` stops at the first failure.
        let restored = unsafe {
            self
                .regions
                .iter()
                .try_for_each(|snapshot| protect(snapshot.base, snapshot.size, snapshot.protection))
        };
        // A failed restoration is surfaced in debug builds only — panicking
        // inside `drop` can abort the process, so release builds stay silent.
        debug_assert!(restored.is_ok(), "restoring region protection");
    }
}
// SAFETY: these impls are needed because `Region` stores a raw pointer
// (`*mut u8`), which makes the guard `!Send`/`!Sync` by default. The guard
// only reads its region snapshot, and only in `drop`.
// NOTE(review): this asserts that restoring page protections from a thread
// other than the one that created the guard is sound on all supported
// platforms — confirm.
unsafe impl Send for ProtectGuard {}
unsafe impl Sync for ProtectGuard {}
bitflags! {
    /// Memory page protection flags.
    ///
    /// Values are portable abstractions over the OS-specific protection
    /// constants; `os::set_protection` translates them for the platform.
    pub struct Protection: usize {
        /// No access permitted.
        const None = 0;
        /// Read access.
        const Read = (1 << 1);
        /// Write access.
        const Write = (1 << 2);
        /// Execute access.
        const Execute = (1 << 3);
        /// Read and execute access.
        const ReadExecute = (Self::Read.bits | Self::Execute.bits);
        /// Read and write access.
        const ReadWrite = (Self::Read.bits | Self::Write.bits);
        /// Read, write and execute access.
        const ReadWriteExecute = (Self::Read.bits | Self::Write.bits | Self::Execute.bits);
        /// Write and execute access.
        const WriteExecute = (Self::Write.bits | Self::Execute.bits);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tests::alloc_pages;

    #[test]
    fn protect_null() {
        // A null address must be rejected. Note: size is also 0 here, so this
        // only asserts *an* error — it cannot distinguish `NullAddress` from
        // `EmptyRange`.
        assert!(unsafe { protect(::std::ptr::null(), 0, Protection::None) }.is_err());
    }

    #[test]
    fn protect_code() {
        // Make this test function's own code page writable, then patch a byte.
        // NOTE(review): 0x90 is an x86 NOP — presumably this test is only run
        // on x86/x86_64 targets; confirm against CI configuration.
        let address = &mut protect_code as *mut _ as *mut u8;
        unsafe {
            protect(address, 0x10, Protection::ReadWriteExecute).unwrap();
            *address = 0x90;
        }
    }

    #[test]
    fn protect_alloc() {
        // Upgrading a read-only page to read-write must make it writable.
        let mut map = alloc_pages(&[Protection::Read]);
        unsafe {
            protect(map.as_ptr(), page::size(), Protection::ReadWrite).unwrap();
            *map.as_mut_ptr() = 0x1;
        }
    }

    #[test]
    fn protect_overlap() {
        let pz = page::size();

        // Layout: a guard page on each side of two differently-protected pages.
        let prots = [
            Protection::Read,      // leading guard
            Protection::ReadExecute,
            Protection::ReadWrite,
            Protection::Read,      // trailing guard
        ];

        let map = alloc_pages(&prots);
        let base_exec = unsafe { map.as_ptr().offset(pz as isize) };
        // A 2-byte range straddling the boundary between the two middle pages.
        let straddle = unsafe { base_exec.offset(pz as isize - 1) };

        // Protecting the straddling range must reprotect *both* middle pages,
        // merging them into one RWX region; the guard pages stay untouched.
        unsafe { protect(straddle, 2, Protection::ReadWriteExecute).unwrap() };

        let result = query_range(base_exec, pz * 2).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].protection, Protection::ReadWriteExecute);
        assert_eq!(result[0].size, pz * 2);
    }

    #[test]
    fn protect_handle() {
        let map = alloc_pages(&[Protection::Read]);
        unsafe {
            // While the guard is alive the page is read-write…
            let _handle = protect_with_handle(map.as_ptr(), page::size(), Protection::ReadWrite).unwrap();
            assert_eq!(
                ::query(map.as_ptr()).unwrap().protection,
                Protection::ReadWrite
            );
        };
        // …and dropping the guard restores the original read-only protection.
        assert_eq!(::query(map.as_ptr()).unwrap().protection, Protection::Read);
    }
}