use crate::alloc::ProtectedAlloc;
use crate::error::{Result, ShroudError};
use crate::policy::Policy;
use crate::traits::{
Expose, ExposeGuard, ExposeGuardMut, ExposeGuarded, ExposeGuardedMut, ExposeMut,
};
use core::fmt;
use core::mem;
use zeroize::Zeroize;
/// Container that keeps a single `T` inside a [`ProtectedAlloc`] instead of
/// storing it inline.
///
/// The `T: Zeroize` bound is required on the type itself: the `Drop` impl for
/// `Shroud` calls `zeroize` on the stored value, and Rust requires a `Drop`
/// impl to carry exactly the bounds declared on the struct.
pub struct Shroud<T: Zeroize> {
    /// Backing storage. The constructors write the `T` at the start of the
    /// byte slice this allocation hands out.
    alloc: ProtectedAlloc,
    /// Policy the allocation was created with; consulted by the guarded
    /// exposure impls to decide whether access has to be toggled.
    // NOTE(review): this field is read via `self.policy.protection_enabled()`
    // in this file, so the `dead_code` allow looks unnecessary — confirm
    // (e.g. against feature-gated builds) before removing it.
    #[allow(dead_code)]
    policy: Policy,
    /// Ties the type parameter to the struct; no `T` is stored inline.
    _marker: core::marker::PhantomData<T>,
}
impl<T: Zeroize> Shroud<T> {
pub fn new(value: T) -> Result<Self> {
Self::new_with_policy(value, Policy::default())
}
pub fn new_with_policy(value: T, policy: Policy) -> Result<Self> {
let size = mem::size_of::<T>();
let align = mem::align_of::<T>();
if size == 0 {
return Err(ShroudError::AllocationFailed(
"cannot shroud zero-sized types".to_string(),
));
}
let mut alloc = ProtectedAlloc::new_aligned(size, align, policy)?;
unsafe {
let ptr = alloc.as_mut_slice().as_mut_ptr() as *mut T;
ptr.write(value);
}
Ok(Self {
alloc,
policy,
_marker: core::marker::PhantomData,
})
}
pub fn new_with<F>(f: F) -> Result<Self>
where
F: FnOnce() -> T,
{
Self::new_with_policy_and_init(Policy::default(), f)
}
pub fn new_with_policy_and_init<F>(policy: Policy, f: F) -> Result<Self>
where
F: FnOnce() -> T,
{
let size = mem::size_of::<T>();
let align = mem::align_of::<T>();
if size == 0 {
return Err(ShroudError::AllocationFailed(
"cannot shroud zero-sized types".to_string(),
));
}
let mut alloc = ProtectedAlloc::new_aligned(size, align, policy)?;
let value = f();
unsafe {
let ptr = alloc.as_mut_slice().as_mut_ptr() as *mut T;
ptr.write(value);
}
Ok(Self {
alloc,
policy,
_marker: core::marker::PhantomData,
})
}
#[inline]
pub const fn size(&self) -> usize {
mem::size_of::<T>()
}
pub fn ct_eq(&self, other: &Self) -> subtle::Choice {
use subtle::ConstantTimeEq;
self.alloc.as_slice().ct_eq(other.alloc.as_slice())
}
}
impl<T: Zeroize> Expose for Shroud<T> {
    type Target = T;

    /// Borrows the stored value directly from the allocation's byte slice.
    ///
    /// No access-mode change is requested here; the allocation is read as is.
    #[inline]
    fn expose(&self) -> &T {
        let stored = self.alloc.as_slice().as_ptr().cast::<T>();
        // SAFETY: the constructors wrote a valid `T` at the start of this
        // allocation, which was requested with `T`'s size and alignment, and
        // the returned borrow cannot outlive `&self`.
        unsafe { &*stored }
    }
}
impl<T: Zeroize> ExposeMut for Shroud<T> {
    /// Mutably borrows the stored value directly from the allocation's bytes.
    ///
    /// No access-mode change is requested here; the allocation is used as is.
    #[inline]
    fn expose_mut(&mut self) -> &mut T {
        let stored = self.alloc.as_mut_slice().as_mut_ptr().cast::<T>();
        // SAFETY: the constructors wrote a valid `T` at the start of this
        // allocation, which was requested with `T`'s size and alignment;
        // `&mut self` guarantees the borrow is exclusive.
        unsafe { &mut *stored }
    }
}
impl<T: Zeroize> ExposeGuarded for Shroud<T> {
    /// Returns a guard giving shared access to the stored value.
    ///
    /// When the policy reports protection as disabled, the guard is built
    /// with `ExposeGuard::unguarded` and no access-mode change is requested.
    /// Otherwise the allocation is made readable first and the guard is given
    /// a cleanup closure that calls `make_inaccessible`, discarding its result.
    ///
    /// NOTE(review): the cleanup closure flips the whole allocation back to
    /// inaccessible; if two guards from the same `Shroud` can overlap, the
    /// first one to finish would affect the other — confirm how
    /// `ExposeGuard` / `ProtectedAlloc` handle that.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ProtectedAlloc::make_readable`.
    fn expose_guarded(&self) -> Result<ExposeGuard<'_, T>> {
        if !self.policy.protection_enabled() {
            // SAFETY: the constructors wrote a valid `T` at the start of the
            // allocation; the borrow is tied to `&self` by the return type.
            let plain = unsafe { &*self.alloc.as_slice().as_ptr().cast::<T>() };
            return Ok(ExposeGuard::unguarded(plain));
        }

        self.alloc.make_readable()?;
        let storage = &self.alloc;
        // SAFETY: as above; the allocation has just been made readable.
        let exposed = unsafe { &*storage.as_slice().as_ptr().cast::<T>() };
        Ok(ExposeGuard::new(exposed, move || {
            // Best effort: the closure has no way to report a failure.
            let _ = storage.make_inaccessible();
        }))
    }
}
impl<T: Zeroize> ExposeGuardedMut for Shroud<T> {
    /// Returns a guard giving exclusive access to the stored value.
    ///
    /// When the policy reports protection as disabled, the guard is built
    /// with `ExposeGuardMut::unguarded` and no access-mode change is
    /// requested. Otherwise the allocation is made writable first and the
    /// guard is given a cleanup closure that calls `make_inaccessible`,
    /// discarding its result.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ProtectedAlloc::make_writable`.
    fn expose_guarded_mut(&mut self) -> Result<ExposeGuardMut<'_, T>> {
        if !self.policy.protection_enabled() {
            // SAFETY: the constructors wrote a valid `T` at the start of the
            // allocation; `&mut self` makes the borrow exclusive.
            let value_ref = unsafe { &mut *(self.alloc.as_mut_slice().as_mut_ptr() as *mut T) };
            return Ok(ExposeGuardMut::unguarded(value_ref));
        }

        self.alloc.make_writable()?;

        // Order matters: obtain the value pointer (which needs a temporary
        // `&mut self.alloc`) BEFORE taking the raw pointer to the allocation
        // handle. Creating the handle pointer first and then reborrowing
        // `self.alloc` mutably would invalidate that pointer under Rust's
        // aliasing model (Stacked Borrows) before the closure dereferences it.
        let value_ptr = self.alloc.as_mut_slice().as_mut_ptr() as *mut T;
        let alloc_ptr = &self.alloc as *const ProtectedAlloc;

        // SAFETY: `value_ptr` points at the valid `T` written by the
        // constructors; the returned lifetime is tied to `&mut self`, so no
        // other access to this `Shroud` can happen while the guard lives.
        let value_ref = unsafe { &mut *value_ptr };
        Ok(ExposeGuardMut::new(value_ref, move || {
            // SAFETY: the guard borrows the `Shroud` mutably for its whole
            // lifetime, so `self.alloc` is still alive and has not moved when
            // this closure runs. Only a shared method is called through it.
            unsafe {
                // Best effort: the closure has no way to report a failure.
                let _ = (*alloc_ptr).make_inaccessible();
            }
        }))
    }
}
impl<T: Zeroize> Drop for Shroud<T> {
    /// Wipes the stored value and then runs its destructor.
    ///
    /// The value was placed with a raw `write`, so nothing else will ever drop
    /// it: without the explicit `drop_in_place` a `T` that owns resources
    /// (heap buffers, handles, ...) would leak them. `zeroize` runs first so
    /// the destructor only ever sees wiped contents.
    ///
    /// NOTE(review): this accesses the allocation without changing its access
    /// mode. If a guarded exposure can leave the memory inaccessible, this
    /// write would need `make_writable` first — confirm against
    /// `ProtectedAlloc`'s contract.
    fn drop(&mut self) {
        let stored = self.alloc.as_mut_slice().as_mut_ptr() as *mut T;
        // SAFETY: the constructors wrote a valid `T` at this address and it is
        // never moved out, so it is still live here. `zeroize` takes
        // `&mut self` and must leave a valid value behind, which makes the
        // following `drop_in_place` sound. `drop` runs at most once and the
        // slot is not touched again afterwards.
        unsafe {
            (*stored).zeroize();
            core::ptr::drop_in_place(stored);
        }
    }
}
impl<T: Zeroize> fmt::Debug for Shroud<T> {
    /// Formats the container without touching the stored value: only the type
    /// name, the size of `T`, and a fixed placeholder are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = core::any::type_name::<T>();
        let byte_len = mem::size_of::<T>();

        let mut out = f.debug_struct("Shroud");
        out.field("type", &type_name);
        out.field("size", &byte_len);
        // The secret itself is never formatted.
        out.field("data", &"[REDACTED]");
        out.finish()
    }
}
impl<T> PartialEq for Shroud<T>
where
    T: Zeroize + PartialEq,
{
    /// Compares the two stored values with `T`'s own `PartialEq`.
    ///
    /// NOTE(review): `T::eq` is generally not constant-time; the inherent
    /// `ct_eq` on `Shroud` is the byte-wise constant-time alternative —
    /// confirm which one callers comparing secrets should be using.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.expose();
        let rhs = other.expose();
        lhs == rhs
    }
}
/// Marker impl: `Shroud<T>` equality is a full equivalence relation whenever
/// `T`'s is, because `==` on `Shroud` simply defers to `T`.
impl<T> Eq for Shroud<T> where T: Zeroize + Eq {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-size stand-in for key material used by the tests below.
    #[derive(Zeroize, PartialEq, Eq, Debug)]
    struct TestKey {
        data: [u8; 32],
    }

    /// `new` stores the value and `size` reports `size_of::<T>()`.
    #[test]
    fn test_new() {
        let shrouded = Shroud::new(TestKey { data: [0x42; 32] }).unwrap();
        let first_byte = shrouded.expose().data[0];
        assert_eq!(first_byte, 0x42);
        assert_eq!(shrouded.size(), mem::size_of::<TestKey>());
    }

    /// `new_with` stores whatever the initializer closure returns.
    #[test]
    fn test_new_with() {
        let make_key = || TestKey { data: [0x42; 32] };
        let shrouded = Shroud::new_with(make_key).unwrap();
        assert_eq!(shrouded.expose().data[0], 0x42);
    }

    /// Writes through `expose_mut` are visible through `expose`.
    #[test]
    fn test_expose_mut() {
        let zeroed = TestKey { data: [0x00; 32] };
        let mut shrouded = Shroud::new(zeroed).unwrap();
        {
            let key = shrouded.expose_mut();
            key.data[0] = 0x99;
        }
        assert_eq!(shrouded.expose().data[0], 0x99);
    }

    /// `Debug` output names the type but replaces the data with a placeholder.
    #[test]
    fn test_debug_redacted() {
        let shrouded = Shroud::new(TestKey { data: [0x42; 32] }).unwrap();
        let rendered = format!("{:?}", shrouded);
        assert!(rendered.contains("[REDACTED]"));
        assert!(rendered.contains("TestKey"));
    }

    /// `==` / `!=` on `Shroud` follow the equality of the stored values.
    #[test]
    fn test_equality() {
        let wrap = |fill: u8| Shroud::new(TestKey { data: [fill; 32] }).unwrap();
        let a = wrap(0x42);
        let b = wrap(0x42);
        let c = wrap(0x00);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    /// Plain byte arrays can be stored without a wrapper struct.
    #[test]
    fn test_primitive_array() {
        let bytes = [0x42u8; 16];
        let shrouded: Shroud<[u8; 16]> = Shroud::new(bytes).unwrap();
        assert_eq!(shrouded.expose()[0], 0x42);
    }
}