use error::*;
use std::mem;
use {lock, page, protect, query_range, Protection, Region};
/// Per-region bookkeeping kept by a `View`.
#[derive(Clone, Copy, Debug)]
struct RegionMeta {
    /// The region (trimmed to the view's bounds) with its current protection.
    region: Region,
    /// The protection the region had before the most recent `set_prot` call.
    previous: Protection,
    /// The protection the region had when the view was created.
    initial: Protection,
}
/// The protection to apply when calling `View::set_prot`.
#[derive(Clone, Copy, Debug)]
pub enum Access {
    /// Apply this exact protection to every region of the view.
    Type(Protection),
    /// Give each region back the protection it had before the last change.
    /// Current and previous protections are swapped, so applying this twice
    /// in a row toggles between them.
    Previous,
    /// Give each region back the protection it had when the view was created.
    Initial,
}
/// Allows a plain `Protection` wherever an `Access` is accepted,
/// e.g. `view.set_prot(Protection::Read)`.
impl From<Protection> for Access {
    fn from(prot: Protection) -> Self {
        Access::Type(prot)
    }
}
/// A page-aligned address range whose protection can be inspected and
/// changed as a unit, while remembering each underlying region's state.
#[derive(Clone, Debug)]
pub struct View {
    /// Covered regions, kept in the order `query_range` reported them.
    regions: Vec<RegionMeta>,
}
impl View {
    /// Creates a view over every page touched by `[address, address + size)`.
    ///
    /// The regions returned by `query_range` are trimmed so that the view
    /// starts at `page::floor(address)` and ends at `page::ceil(address + size)`.
    /// Each region's current protection is recorded as both its "initial"
    /// and its "previous" protection.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `query_range`.
    pub fn new(address: *const u8, size: usize) -> Result<Self> {
        let mut regions = query_range(address, size)?;
        let lower = page::floor(address as usize);
        let upper = page::ceil(address as usize + size);

        // The first region may start before the requested range: trim its head.
        if let Some(region) = regions.first_mut() {
            let delta = lower - region.base as usize;
            region.base = (region.base as usize + delta) as *mut u8;
            region.size -= delta;
        }

        // The last region may extend past the requested range: trim its tail.
        if let Some(region) = regions.last_mut() {
            let delta = region.upper() - upper;
            region.size -= delta;
        }

        Ok(View {
            regions: regions
                .iter()
                .map(|region| RegionMeta {
                    region: *region,
                    previous: region.protection,
                    initial: region.protection,
                })
                .collect::<Vec<_>>(),
        })
    }

    /// Returns the protection shared by all regions of the view, or `None`
    /// if the regions do not all have the same protection.
    pub fn get_prot(&self) -> Option<Protection> {
        // The union of all protections equals each individual protection
        // only when they are all identical.
        let prot = self
            .regions
            .iter()
            .fold(Protection::None, |prot, meta| prot | meta.region.protection);
        if self
            .regions
            .iter()
            .all(|meta| meta.region.protection == prot)
        {
            Some(prot)
        } else {
            None
        }
    }

    /// Changes the protection of the view's regions.
    ///
    /// - `Access::Type(p)`: applies `p` to the whole range in one `protect` call.
    /// - `Access::Previous`: restores each region's previous protection. The
    ///   current and previous values are swapped, so repeating this toggles.
    /// - `Access::Initial`: restores each region's protection from creation time.
    ///
    /// In every case, the protection in effect before the call becomes the
    /// new "previous" protection.
    ///
    /// # Errors
    ///
    /// Propagates errors from `protect`. For `Previous` and `Initial`, regions
    /// are updated one by one. If a call fails partway, the regions already
    /// processed keep their new protection.
    ///
    /// # Safety
    ///
    /// Changing page protection affects every access to this memory,
    /// including accesses through existing pointers and references. The
    /// caller must ensure no such access is invalidated while the new
    /// protection is in effect.
    pub unsafe fn set_prot<A: Into<Access>>(&mut self, access: A) -> Result<()> {
        match access.into() {
            Access::Type(protection) => {
                protect(self.as_ptr(), self.len(), protection)?;
                for meta in &mut self.regions {
                    meta.previous = meta.region.protection;
                    meta.region.protection = protection;
                }
            },
            Access::Previous => {
                for meta in &mut self.regions {
                    protect(meta.region.base, meta.region.size, meta.previous)?;
                    mem::swap(&mut meta.region.protection, &mut meta.previous);
                }
            },
            Access::Initial => {
                for meta in &mut self.regions {
                    protect(meta.region.base, meta.region.size, meta.initial)?;
                    meta.previous = meta.region.protection;
                    meta.region.protection = meta.initial;
                }
            },
        }
        Ok(())
    }

    /// Runs `callback` with the view temporarily set to `prot`, then restores
    /// the previous protection and returns the callback's result.
    ///
    /// If `callback` panics, the previous protection is still restored
    /// (best-effort) while unwinding, so the pages are not left with the
    /// temporary protection.
    ///
    /// # Errors
    ///
    /// Returns an error if applying `prot` fails (the callback is then not
    /// run), or if restoring the previous protection fails after the callback
    /// returned (its result is then discarded).
    ///
    /// # Safety
    ///
    /// Same contract as `set_prot`.
    pub unsafe fn exec_with_prot<Ret, T: FnOnce() -> Ret>(
        &mut self,
        prot: Protection,
        callback: T,
    ) -> Result<Ret> {
        // Restores the previous protection if it is dropped while still armed,
        // i.e. when `callback` unwinds.
        struct RestoreGuard<'a> {
            view: &'a mut View,
            armed: bool,
        }

        impl<'a> Drop for RestoreGuard<'a> {
            fn drop(&mut self) {
                if self.armed {
                    // An error cannot be reported from a destructor during
                    // unwinding, so restoration here is best-effort.
                    let _ = unsafe { self.view.set_prot(Access::Previous) };
                }
            }
        }

        self.set_prot(prot)?;
        let mut guard = RestoreGuard {
            view: self,
            armed: true,
        };
        let result = callback();
        // Normal return: disarm the guard and restore explicitly so that any
        // error is reported to the caller.
        guard.armed = false;
        guard.view.set_prot(Access::Previous)?;
        Ok(result)
    }

    /// Calls `lock` over the view's full address range and returns its guard.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `lock`.
    pub fn lock(&mut self) -> Result<::LockGuard> {
        lock(self.as_ptr(), self.len())
    }

    /// Returns a pointer to the start of the view.
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn as_ptr(&self) -> *const u8 {
        self.regions.first().unwrap().region.base
    }

    /// Returns a mutable pointer to the start of the view.
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.regions.first().unwrap().region.base as *mut _
    }

    /// Returns the lower bound of the view (the first region's `lower()`).
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn lower(&self) -> usize {
        self.regions.first().unwrap().region.lower()
    }

    /// Returns the exclusive upper bound of the view (the last region's `upper()`).
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn upper(&self) -> usize {
        self.regions.last().unwrap().region.upper()
    }

    /// Returns the size of the view in bytes.
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn len(&self) -> usize {
        self.upper() - self.lower()
    }

    /// Returns `true` if the view spans zero bytes.
    ///
    /// # Panics
    ///
    /// Panics if the view contains no regions.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use Protection;
    use tests::alloc_pages;

    /// A view is widened to whole pages around the requested byte range.
    #[test]
    fn view_check_size() {
        let pz = page::size();
        let map = alloc_pages(&[Protection::Read, Protection::Read, Protection::Read]);

        // Exactly the middle page.
        let base = unsafe { map.as_ptr().offset(pz as isize) };
        let view = View::new(base, pz).unwrap();
        assert_eq!(view.as_ptr(), base);
        assert_eq!(view.len(), pz);

        // Two bytes straddling the boundary between the first and second page.
        let base = unsafe { map.as_ptr().offset(pz as isize - 1) };
        let view = View::new(base, 2).unwrap();
        assert_eq!(view.as_ptr(), map.as_ptr());
        assert_eq!(view.len(), pz * 2);
    }

    /// The callback runs with the temporary protection, its result is
    /// returned, and the original protection is restored afterwards.
    #[test]
    fn view_exec_prot() {
        let pz = page::size();
        let mut map = alloc_pages(&[Protection::Read]);
        let mut view = View::new(map.as_ptr(), pz).unwrap();
        unsafe {
            let val = view
                .exec_with_prot(Protection::ReadWrite, || {
                    *map.as_mut_ptr() = 0x10;
                    1337
                })
                .unwrap();
            assert_eq!(val, 1337);
            // The write made under ReadWrite must be visible (pages are Read again).
            assert_eq!(*map.as_ptr(), 0x10);
        }
        let region = ::query(view.as_ptr()).unwrap();
        assert_eq!(region.protection, Protection::Read);
    }

    /// `Access::Previous` undoes the most recent protection change.
    #[test]
    fn view_prot_prev() {
        let pz = page::size();
        let map = alloc_pages(&[Protection::Read]);
        let mut view = View::new(map.as_ptr(), pz).unwrap();
        unsafe {
            view.set_prot(Protection::ReadWrite).unwrap();
            view.set_prot(Access::Previous).unwrap();
        }
        let region = ::query(view.as_ptr()).unwrap();
        assert_eq!(region.protection, Protection::Read);
    }

    /// `Access::Initial` restores the creation-time protection after several changes.
    #[test]
    fn view_prot_initial() {
        let pz = page::size();
        let map = alloc_pages(&[Protection::Read]);
        let mut view = View::new(map.as_ptr(), pz).unwrap();
        unsafe {
            view.set_prot(Protection::ReadWrite).unwrap();
            view.set_prot(Protection::ReadWriteExecute).unwrap();
            view.set_prot(Access::Initial).unwrap();
        }
        let region = ::query(view.as_ptr()).unwrap();
        assert_eq!(region.protection, Protection::Read);
    }

    /// `get_prot` is `None` for mixed protections and `Some` once they are unified.
    #[test]
    fn view_get_prot() {
        let pz = page::size();
        let map = alloc_pages(&[Protection::Read, Protection::ReadWrite]);
        let mut view = View::new(map.as_ptr(), pz * 2).unwrap();
        assert_eq!(view.len(), pz * 2);
        assert_eq!(view.get_prot(), None);
        unsafe { view.set_prot(Protection::Read).unwrap() };
        assert_eq!(view.get_prot(), Some(Protection::Read));
    }
}