use std::fmt;
use std::ops::{ Deref, DerefMut };
use std::ptr::copy;
use std::mem::size_of;
use std::cell::Cell;
use memsec::{
memzero,
malloc, free,
Prot, mprotect
};
/// Holder of a single `T` kept behind a raw pointer obtained from `memsec::malloc`.
///
/// Access goes through the guards returned by `read` and `write`; the
/// protection of the region is switched with `memsec::mprotect` as guards are
/// created and dropped.
pub struct SecKey<T: Sized> {
// Raw pointer to the stored value; assigned in `new` and passed to `free` on drop.
ptr: *mut T,
// Number of guards currently outstanding; a `Cell` so `read(&self)` can update it.
count: Cell<usize>
}
impl<T> SecKey<T> where T: Sized {
    /// Copies `*t` into a region obtained from `memsec::malloc` and then sets
    /// that region to `Prot::NoAccess`.
    ///
    /// Returns `None` when `malloc` yields no allocation.
    ///
    /// The value is duplicated bit-for-bit: the caller's `*t` is left untouched
    /// and is still owned (and eventually dropped) by the caller.
    pub fn new(t: &T) -> Option<SecKey<T>> {
        let allocation: Option<*mut T> = unsafe { malloc(size_of::<T>()) };
        allocation.map(|memptr| {
            // SAFETY: `memptr` comes from a successful `malloc` of
            // `size_of::<T>()` bytes and `t` is a valid reference, so one `T`
            // can be copied across. Presumably the fresh allocation is writable
            // until the `mprotect` call below — confirm against memsec docs.
            unsafe {
                copy(t, memptr, 1);
                mprotect(memptr, Prot::NoAccess);
            }
            SecKey {
                ptr: memptr,
                count: Cell::new(0)
            }
        })
    }

    /// Registers one more reader; the first outstanding guard switches the
    /// region to `Prot::ReadOnly`.
    fn read_unlock(&self) {
        let outstanding = self.count.get();
        self.count.set(outstanding + 1);
        if outstanding == 0 {
            unsafe { mprotect(self.ptr, Prot::ReadOnly) };
        }
    }

    /// Registers the writer and switches the region to `Prot::ReadWrite`.
    ///
    /// The protection is changed unconditionally. `write` holds `&mut self`,
    /// so no other guard can still be in use; a non-zero count at this point
    /// can only be left over from a guard that was leaked (for example with
    /// `mem::forget`). In that case the region may still be read-only, and
    /// skipping the call would hand out a `&mut T` to memory that cannot be
    /// written.
    fn write_unlock(&self) {
        self.count.set(self.count.get() + 1);
        unsafe { mprotect(self.ptr, Prot::ReadWrite) };
    }

    /// Unregisters one guard; when the last one goes away the region is set
    /// back to `Prot::NoAccess`.
    fn lock(&self) {
        let outstanding = self.count.get();
        self.count.set(outstanding - 1);
        if outstanding == 1 {
            unsafe { mprotect(self.ptr, Prot::NoAccess) };
        }
    }

    /// Opens the value for reading.
    ///
    /// The returned guard dereferences to `&T` and calls `lock` when dropped.
    #[inline]
    pub fn read(&self) -> SecReadGuard<T> {
        self.read_unlock();
        SecReadGuard(self)
    }

    /// Opens the value for reading and writing.
    ///
    /// The returned guard dereferences to `&T` / `&mut T` and calls `lock`
    /// when dropped.
    #[inline]
    pub fn write(&mut self) -> SecWriteGuard<T> {
        self.write_unlock();
        SecWriteGuard(self)
    }
}
impl<T> From<T> for SecKey<T> {
    /// Moves `t` into a new `SecKey` and wipes the bytes of the original.
    ///
    /// # Panics
    ///
    /// Panics when `SecKey::new` returns `None`.
    fn from(mut t: T) -> SecKey<T> {
        let output = SecKey::new(&t).expect("SecKey::from: secure allocation failed");
        // Wipe the local copy in place so its bytes do not linger here.
        unsafe { memzero(&mut t, size_of::<T>()) };
        // The bytes of `t` now live in `output` and the local copy is all
        // zeroes. Letting `t` drop normally would run `T`'s destructor on a
        // zeroed value (undefined behaviour for `Box`, `Vec`, `String`, ...),
        // so its destructor is suppressed.
        ::std::mem::forget(t);
        output
    }
}
impl<T> fmt::Debug for SecKey<T> {
    /// Emits a fixed placeholder; the stored value is never read or printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("** sec key **")
    }
}
impl<T> Drop for SecKey<T> {
    /// Hands the allocation back through `memsec::free`.
    ///
    /// Only `free` is called here; the destructor of the stored `T` is not
    /// invoked by this impl.
    fn drop(&mut self) {
        let raw = self.ptr;
        // SAFETY: `raw` is the pointer stored by `SecKey::new`, which obtained
        // it from `malloc`; this is the only place in this file that frees it.
        unsafe {
            free(raw);
        }
    }
}
/// Read guard produced by `SecKey::read`.
///
/// Dereferences to `&T`; dropping it calls `SecKey::lock` on the borrowed key.
pub struct SecReadGuard<'a, T: Sized + 'a>(&'a SecKey<T>);
impl<'a, T: Sized + 'a> Deref for SecReadGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.0.ptr }
}
}
impl<'a, T: Sized + 'a> Drop for SecReadGuard<'a, T> {
fn drop(&mut self) {
self.0.lock();
}
}
/// Write guard produced by `SecKey::write`.
///
/// Dereferences to `&T` and `&mut T`; dropping it calls `SecKey::lock` on the
/// borrowed key.
pub struct SecWriteGuard<'a, T: Sized + 'a>(&'a mut SecKey<T>);
impl<'a, T: Sized + 'a> Deref for SecWriteGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.0.ptr }
}
}
impl<'a, T: Sized + 'a> DerefMut for SecWriteGuard<'a, T> {
    /// Borrows the value stored behind the key's raw pointer for writing.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.0.ptr;
        // SAFETY: `raw` was filled with a `T` by `SecKey::new`; the guard holds
        // the only `&mut SecKey`, and `&mut self` makes this borrow exclusive.
        // Presumably the region is writable for as long as the guard lives.
        unsafe { &mut *raw }
    }
}
impl<'a, T: Sized + 'a> Drop for SecWriteGuard<'a, T> {
fn drop(&mut self) {
self.0.lock();
}
}