use crate::error::{Result, ShroudError};
use crate::policy::Policy;
use crate::sys::{self, MemoryRegion};
/// An owned memory region obtained from [`sys::allocate_aligned`] whose access
/// protection can be toggled through the `make_*` methods.
///
/// Direct slice access ([`as_slice`](Self::as_slice) /
/// [`as_mut_slice`](Self::as_mut_slice)) is only allowed while the region is
/// not protected; both methods panic otherwise.
pub struct ProtectedAlloc {
    region: MemoryRegion,
}

impl ProtectedAlloc {
    /// Allocates `size` bytes with a one-byte alignment requirement.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`sys::allocate_aligned`].
    pub fn new(size: usize, policy: Policy) -> Result<Self> {
        Self::new_aligned(size, 1, policy)
    }

    /// Allocates `size` bytes aligned to `alignment` under the given `policy`.
    ///
    /// This function performs no validation of its own; any checking of
    /// `size` or `alignment` is left to [`sys::allocate_aligned`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`sys::allocate_aligned`].
    pub fn new_aligned(size: usize, alignment: usize, policy: Policy) -> Result<Self> {
        let region = sys::allocate_aligned(size, alignment, policy)?;
        Ok(Self { region })
    }

    /// Returns the length of the region in bytes.
    #[inline]
    pub fn len(&self) -> usize {
        self.region.len()
    }

    /// Returns `true` if the region has zero length.
    #[inline]
    // NOTE(review): presumably silences dead-code warnings when this module is
    // not re-exported outside the crate; confirm before removing.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.region.is_empty()
    }

    /// Returns `true` if the region is currently protected.
    ///
    /// [`as_slice`](Self::as_slice) and [`as_mut_slice`](Self::as_mut_slice)
    /// panic while this returns `true`.
    #[inline]
    // NOTE(review): presumably silences dead-code warnings when this module is
    // not re-exported outside the crate; confirm before removing.
    #[allow(dead_code)]
    pub fn is_protected(&self) -> bool {
        self.region.is_protected()
    }

    /// Returns the region's bytes as a shared slice.
    ///
    /// # Panics
    ///
    /// Panics if the region is currently protected.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        assert!(
            !self.region.is_protected(),
            "cannot read protected memory - use expose_guarded() for safe access"
        );
        // SAFETY: the assertion above guarantees the region is not protected,
        // which is the precondition this type establishes before touching the
        // pages. NOTE(review): the full contract of `MemoryRegion::as_slice`
        // is defined in `sys`.
        unsafe { self.region.as_slice() }
    }

    /// Returns the region's bytes as a mutable slice.
    ///
    /// # Panics
    ///
    /// Panics if the region is currently protected.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        assert!(
            !self.region.is_protected(),
            "cannot write protected memory - use expose_guarded_mut() for safe access"
        );
        // SAFETY: the assertion above guarantees the region is not protected,
        // and `&mut self` ensures this is the only live view of it.
        // NOTE(review): the full contract of `MemoryRegion::as_mut_slice` is
        // defined in `sys`.
        unsafe { self.region.as_mut_slice() }
    }

    /// Changes the region's protection to readable via `MemoryRegion::make_readable`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying region.
    // NOTE(review): presumably silences dead-code warnings when this module is
    // not re-exported outside the crate; confirm before removing.
    #[allow(dead_code)]
    pub fn make_readable(&self) -> Result<()> {
        self.region.make_readable()
    }

    /// Changes the region's protection to writable via `MemoryRegion::make_writable`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying region.
    pub fn make_writable(&self) -> Result<()> {
        self.region.make_writable()
    }

    /// Changes the region's protection to inaccessible via
    /// `MemoryRegion::make_inaccessible`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying region.
    pub fn make_inaccessible(&self) -> Result<()> {
        self.region.make_inaccessible()
    }

    /// Copies `source` into the start of the region, then zeroizes `source`.
    ///
    /// Bytes past `source.len()` are left untouched. If the region was
    /// protected on entry, it is made writable for the copy and then made
    /// inaccessible again before returning. A previously protected region is
    /// therefore never left exposed by this call.
    ///
    /// # Errors
    ///
    /// - [`ShroudError::CapacityOverflow`] if `source` is longer than the
    ///   region. Nothing is copied and `source` is left intact.
    /// - Any error from changing the region's protection. If re-protecting
    ///   fails, the copy has already happened and `source` has already been
    ///   zeroized.
    pub fn write_and_zeroize_source(&mut self, source: &mut [u8]) -> Result<()> {
        if source.len() > self.len() {
            return Err(ShroudError::CapacityOverflow {
                requested: source.len(),
                maximum: self.len(),
            });
        }

        let was_protected = self.region.is_protected();
        if was_protected {
            self.make_writable()?;
        }

        self.as_mut_slice()[..source.len()].copy_from_slice(source);
        zeroize_slice(source);

        if was_protected {
            // Restore the caller's protection rather than leaving the freshly
            // written secret accessible. `as_slice` treats "protected" as
            // unreadable, so the matching state is inaccessible.
            // NOTE(review): confirm against `sys` that protected == inaccessible.
            self.make_inaccessible()?;
        }
        Ok(())
    }
}
/// Overwrites every byte of `data` with zero.
///
/// Each store is volatile, so the compiler cannot drop the writes as dead even
/// if `data` is never read again. The trailing `SeqCst` compiler fence stops
/// the compiler from reordering memory accesses across the end of the wipe.
#[inline]
pub(crate) fn zeroize_slice(data: &mut [u8]) {
    data.iter_mut().for_each(|slot| {
        // SAFETY: `slot` is a live, aligned, exclusive reference into `data`,
        // so a single-byte write through it is valid.
        unsafe { core::ptr::write_volatile(slot, 0u8) }
    });
    core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_allocation() {
        let alloc = ProtectedAlloc::new(100, Policy::BestEffort).unwrap();
        assert_eq!(alloc.len(), 100);
        assert!(!alloc.is_empty());
    }

    #[test]
    fn test_empty_allocation() {
        let alloc = ProtectedAlloc::new(0, Policy::BestEffort).unwrap();
        assert!(alloc.is_empty());
    }

    #[test]
    fn test_write_and_read() {
        let mut alloc = ProtectedAlloc::new(100, Policy::BestEffort).unwrap();
        {
            let data = alloc.as_mut_slice();
            data[0] = 42;
            data[99] = 255;
        }
        let data = alloc.as_slice();
        assert_eq!(data[0], 42);
        assert_eq!(data[99], 255);
    }

    #[test]
    fn test_write_and_zeroize_source() {
        let mut alloc = ProtectedAlloc::new(10, Policy::BestEffort).unwrap();
        let mut source = vec![1, 2, 3, 4, 5];
        alloc.write_and_zeroize_source(&mut source).unwrap();
        assert_eq!(&alloc.as_slice()[..5], &[1, 2, 3, 4, 5]);
        assert_eq!(source, vec![0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_write_and_zeroize_source_exact_capacity() {
        let mut alloc = ProtectedAlloc::new(4, Policy::BestEffort).unwrap();
        let mut source = vec![9, 8, 7, 6];
        alloc.write_and_zeroize_source(&mut source).unwrap();
        assert_eq!(alloc.as_slice(), &[9, 8, 7, 6]);
        assert_eq!(source, vec![0; 4]);
    }

    #[test]
    fn test_write_and_zeroize_source_rejects_oversized_source() {
        let mut alloc = ProtectedAlloc::new(10, Policy::BestEffort).unwrap();
        let mut source = vec![7u8; 11];
        let err = alloc.write_and_zeroize_source(&mut source).unwrap_err();
        assert!(matches!(
            err,
            ShroudError::CapacityOverflow {
                requested: 11,
                maximum: 10
            }
        ));
        // On overflow nothing is copied and the caller keeps its data.
        assert_eq!(source, vec![7u8; 11]);
    }
}