use crate::alloc::ProtectedAlloc;
use crate::error::Result;
use crate::policy::Policy;
use crate::traits::{
Expose, ExposeGuard, ExposeGuardMut, ExposeGuarded, ExposeGuardedMut, ExposeMut,
};
use core::fmt;
/// A fixed-size buffer of `N` secret bytes held in a [`ProtectedAlloc`].
///
/// Its `Debug` output never prints the contents, and equality is decided with
/// a constant-time comparison.
pub struct ShroudedArray<const N: usize> {
    // Backing storage; every constructor allocates it with exactly `N` bytes.
    alloc: ProtectedAlloc,
    // Policy the allocation was created with. The guarded accessors consult
    // `protection_enabled()` on it, and `try_clone` reuses it for the copy.
    policy: Policy,
}
impl<const N: usize> ShroudedArray<N> {
    /// Allocates a new `N`-byte protected array using [`Policy::default()`].
    ///
    /// # Errors
    /// Returns any error produced by [`ProtectedAlloc::new`].
    pub fn new() -> Result<Self> {
        Self::new_with_policy(Policy::default())
    }

    /// Allocates a new `N`-byte protected array governed by `policy`.
    ///
    /// # Errors
    /// Returns any error produced by [`ProtectedAlloc::new`].
    pub fn new_with_policy(policy: Policy) -> Result<Self> {
        let alloc = ProtectedAlloc::new(N, policy)?;
        Ok(Self { alloc, policy })
    }

    /// Copies `source` into protected memory using [`Policy::default()`].
    ///
    /// The by-value copy received by this function is zeroized on every exit
    /// path, success or error. Because `[u8; N]` is `Copy`, the caller's own
    /// array is *not* cleared; callers holding a secret must scrub it
    /// themselves, or prefer [`Self::new_with`] to avoid a stack copy entirely.
    ///
    /// # Errors
    /// Returns any error from allocating or filling the protected region.
    pub fn from_array(mut source: [u8; N]) -> Result<Self> {
        // Scrub this frame's copy in place. Forwarding `source` by value to
        // `from_array_with_policy` would leave an unzeroized copy behind here.
        Self::from_array_ref_with_policy(&mut source, Policy::default())
    }

    /// Like [`Self::from_array`], but with an explicit `policy`.
    ///
    /// The local copy of `source` is zeroized on every exit path, including
    /// allocation failure.
    ///
    /// # Errors
    /// Returns any error from allocating or filling the protected region.
    pub fn from_array_with_policy(mut source: [u8; N], policy: Policy) -> Result<Self> {
        Self::from_array_ref_with_policy(&mut source, policy)
    }

    /// Shared body of the `from_array*` constructors.
    ///
    /// Allocates the region, then moves `source` into it. `source` is scrubbed
    /// on every exit path.
    fn from_array_ref_with_policy(source: &mut [u8; N], policy: Policy) -> Result<Self> {
        let mut alloc = ProtectedAlloc::new(N, policy).map_err(|err| {
            // Nothing was copied yet, but the secret must not outlive the error.
            Self::zeroize_in_place(&mut *source);
            err
        })?;
        alloc
            .write_and_zeroize_source(&mut *source)
            .map_err(|err| {
                // NOTE(review): unclear whether `write_and_zeroize_source`
                // clears `source` on its own failure path. Scrub again
                // defensively; re-zeroing an already-zero buffer is harmless.
                Self::zeroize_in_place(&mut *source);
                err
            })?;
        Ok(Self { alloc, policy })
    }

    /// Overwrites `buf` with zeros using volatile stores, so the optimizer
    /// cannot discard them as dead writes. A compiler fence then keeps later
    /// code from being reordered ahead of the wipe.
    fn zeroize_in_place(buf: &mut [u8; N]) {
        for byte in buf.iter_mut() {
            // SAFETY: `byte` comes from a live `&mut u8`, so the pointer is
            // non-null, aligned, and valid for a one-byte write.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }

    /// Allocates with [`Policy::default()`] and lets `f` initialize the bytes
    /// directly inside the protected region.
    ///
    /// # Errors
    /// Returns any error produced by [`ProtectedAlloc::new`].
    pub fn new_with<F>(f: F) -> Result<Self>
    where
        F: FnOnce(&mut [u8; N]),
    {
        Self::new_with_policy_and_init(Policy::default(), f)
    }

    /// Allocates with `policy` and lets `f` initialize the bytes in place.
    ///
    /// `f` receives a mutable view of the allocation itself. Secrets written
    /// by `f` therefore never pass through an intermediate array owned by this
    /// function.
    ///
    /// # Errors
    /// Returns any error produced by [`ProtectedAlloc::new`].
    ///
    /// # Panics
    /// Panics with "allocation size mismatch" if the allocation's slice is
    /// not exactly `N` bytes long.
    pub fn new_with_policy_and_init<F>(policy: Policy, f: F) -> Result<Self>
    where
        F: FnOnce(&mut [u8; N]),
    {
        let mut alloc = ProtectedAlloc::new(N, policy)?;
        let slice = alloc.as_mut_slice();
        let array_ref: &mut [u8; N] = slice.try_into().expect("allocation size mismatch");
        f(array_ref);
        Ok(Self { alloc, policy })
    }

    /// Number of bytes held, which is always `N`.
    #[inline]
    pub const fn len(&self) -> usize {
        N
    }

    /// `true` only for the zero-length instantiation `ShroudedArray<0>`.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        N == 0
    }

    /// Copies the contents into a fresh allocation that uses the same policy.
    ///
    /// # Errors
    /// - `ShroudError::RegionLocked` if `self.alloc.is_protected()` reports
    ///   the region as protected, because its bytes cannot be read here.
    /// - Any error produced by [`ProtectedAlloc::new`] for the new region.
    pub fn try_clone(&self) -> Result<Self> {
        if self.alloc.is_protected() {
            return Err(crate::error::ShroudError::RegionLocked);
        }
        let mut alloc = ProtectedAlloc::new(N, self.policy)?;
        alloc.as_mut_slice().copy_from_slice(self.alloc.as_slice());
        Ok(Self {
            alloc,
            policy: self.policy,
        })
    }
}
impl<const N: usize> Expose for ShroudedArray<N> {
    type Target = [u8; N];

    /// Returns a shared view of the secret as a fixed-size array.
    ///
    /// This does not change page protection.
    /// NOTE(review): unlike `expose_guarded`, this does not call
    /// `make_readable`. If a dropped guard has left the region inaccessible,
    /// reading through the returned reference may fault. Confirm callers only
    /// use it on unprotected regions.
    ///
    /// # Panics
    /// Panics with "allocation size mismatch" if the backing slice is not `N` bytes.
    #[inline]
    fn expose(&self) -> &[u8; N] {
        let bytes = self.alloc.as_slice();
        <&[u8; N]>::try_from(bytes).expect("allocation size mismatch")
    }
}
impl<const N: usize> ExposeMut for ShroudedArray<N> {
    /// Returns an exclusive, writable view of the secret as a fixed-size array.
    ///
    /// Page protection is left untouched; see `expose_guarded_mut` for the
    /// variant that calls `make_writable` first.
    ///
    /// # Panics
    /// Panics with "allocation size mismatch" if the backing slice is not `N` bytes.
    #[inline]
    fn expose_mut(&mut self) -> &mut [u8; N] {
        let bytes = self.alloc.as_mut_slice();
        <&mut [u8; N]>::try_from(bytes).expect("allocation size mismatch")
    }
}
impl<const N: usize> ExposeGuarded for ShroudedArray<N> {
    /// Returns a read-only view of the secret wrapped in a guard.
    ///
    /// When the policy enables protection, the region is first made readable.
    /// The returned guard carries a callback that calls `make_inaccessible`;
    /// any error from that call is ignored. Without protection, an unguarded
    /// view is returned.
    ///
    /// NOTE(review): two guards can coexist because this takes `&self`.
    /// Dropping one re-protects the region while the other is still alive,
    /// which may fault on the next read. Confirm whether `ProtectedAlloc`
    /// refcounts this.
    ///
    /// # Errors
    /// Propagates any error from `make_readable`.
    fn expose_guarded(&self) -> Result<ExposeGuard<'_, [u8; N]>> {
        let protected = self.policy.protection_enabled();
        if protected {
            self.alloc.make_readable()?;
        }

        let view: &[u8; N] = self
            .alloc
            .as_slice()
            .try_into()
            .expect("allocation size mismatch");

        if !protected {
            return Ok(ExposeGuard::unguarded(view));
        }

        let region = &self.alloc;
        Ok(ExposeGuard::new(view, move || {
            let _ = region.make_inaccessible();
        }))
    }
}
impl<const N: usize> ExposeGuardedMut for ShroudedArray<N> {
    /// Returns a writable view of the secret wrapped in a guard.
    ///
    /// When the policy enables protection, the region is first made writable.
    /// The guard's callback calls `make_inaccessible`, and any error from that
    /// call is ignored. Without protection, an unguarded view is returned.
    ///
    /// # Errors
    /// Propagates any error from `make_writable`.
    ///
    /// # Panics
    /// Panics with "allocation size mismatch" if the backing slice is not `N` bytes.
    fn expose_guarded_mut(&mut self) -> Result<ExposeGuardMut<'_, [u8; N]>> {
        if !self.policy.protection_enabled() {
            let array_ref: &mut [u8; N] = self
                .alloc
                .as_mut_slice()
                .try_into()
                .expect("allocation size mismatch");
            return Ok(ExposeGuardMut::unguarded(array_ref));
        }

        self.alloc.make_writable()?;

        // Derive BOTH the data borrow and the drop-time handle from a single
        // raw pointer. Taking `&self.alloc as *const _` first and then
        // reborrowing `self.alloc` mutably (the previous approach) invalidates
        // that pointer under Stacked Borrows before the callback uses it.
        let alloc_ptr: *mut ProtectedAlloc = &mut self.alloc;

        // SAFETY: `alloc_ptr` was just derived from `&mut self.alloc`. That
        // borrow is exclusive and stays live for the `'_` lifetime of the
        // returned guard, so the pointee is valid and not otherwise accessed.
        let slice = unsafe { (*alloc_ptr).as_mut_slice() };
        let array_ref: &mut [u8; N] = slice.try_into().expect("allocation size mismatch");

        Ok(ExposeGuardMut::new(array_ref, move || {
            // SAFETY: `self` stays mutably borrowed for the guard's lifetime,
            // so `self.alloc` cannot move or drop while the guard exists.
            // NOTE(review): this assumes ExposeGuardMut runs the callback no
            // later than its own drop, and that `as_mut_slice` points into
            // separate storage rather than into `ProtectedAlloc` itself.
            // Confirm both.
            let _ = unsafe { (*alloc_ptr).make_inaccessible() };
        }))
    }
}
impl<const N: usize> fmt::Debug for ShroudedArray<N> {
    /// Prints only the length; the secret bytes are replaced by a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut builder = f.debug_struct("ShroudedArray");
        builder.field("size", &N);
        builder.field("data", &REDACTED);
        builder.finish()
    }
}
impl<const N: usize> PartialEq for ShroudedArray<N> {
    /// Compares contents with `subtle::ConstantTimeEq`, so the comparison
    /// does not short-circuit at the first differing byte.
    ///
    /// NOTE(review): reads both sides through `expose()`, which does not call
    /// `make_readable`. Comparing a region left inaccessible by a dropped
    /// guard may fault. Confirm intended usage.
    fn eq(&self, other: &Self) -> bool {
        use subtle::ConstantTimeEq;
        let lhs: &[u8] = self.expose();
        let rhs: &[u8] = other.expose();
        bool::from(lhs.ct_eq(rhs))
    }
}

// Byte-wise equality is reflexive, symmetric and transitive, so `Eq` holds.
impl<const N: usize> Eq for ShroudedArray<N> {}
impl<const N: usize> Default for ShroudedArray<N> {
    /// Same as [`ShroudedArray::new`], but panics instead of returning an error.
    ///
    /// # Panics
    /// If allocation fails. Call [`ShroudedArray::new`] to handle that case.
    fn default() -> Self {
        match Self::new() {
            Ok(array) => array,
            Err(err) => panic!("failed to allocate ShroudedArray: {err:?}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh allocation reports its length and starts zero-filled.
    #[test]
    fn test_new() {
        let zeroed: ShroudedArray<32> = ShroudedArray::new().unwrap();
        assert_eq!(zeroed.len(), 32);
        assert_eq!(zeroed.expose(), &[0u8; 32]);
    }

    /// Bytes passed by value end up in the protected region unchanged.
    #[test]
    fn test_from_array() {
        let shrouded = ShroudedArray::from_array([1u8, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        assert_eq!(shrouded.expose(), &[1, 2, 3, 4, 5, 6, 7, 8]);
    }

    /// The init closure writes straight into the allocation.
    #[test]
    fn test_new_with() {
        let counting: ShroudedArray<16> = ShroudedArray::new_with(|buf| {
            buf.iter_mut()
                .zip(0u8..)
                .for_each(|(slot, value)| *slot = value);
        })
        .unwrap();
        let expected: [u8; 16] = core::array::from_fn(|idx| idx as u8);
        assert_eq!(counting.expose(), &expected);
    }

    /// Cloning an unprotected region yields identical contents.
    #[test]
    fn test_try_clone() {
        let original: ShroudedArray<8> = ShroudedArray::new_with(|buf| buf.fill(0x42)).unwrap();
        let duplicate = original.try_clone().unwrap();
        assert_eq!(original.expose(), duplicate.expose());
    }

    /// After a guard is dropped the region is protected, so cloning is refused.
    #[test]
    fn test_try_clone_fails_on_protected_memory() {
        let secret: ShroudedArray<8> = ShroudedArray::new_with(|buf| buf.fill(0x42)).unwrap();
        drop(secret.expose_guarded().unwrap());
        let outcome = secret.try_clone();
        assert!(matches!(
            outcome,
            Err(crate::error::ShroudError::RegionLocked)
        ));
    }

    /// Debug output hides the contents but still shows the size.
    #[test]
    fn test_debug_redacted() {
        let secret: ShroudedArray<32> = ShroudedArray::new_with(|buf| buf.fill(0x42)).unwrap();
        let rendered = format!("{:?}", secret);
        assert!(rendered.contains("[REDACTED]"));
        assert!(rendered.contains("32"));
    }

    /// Writes through `expose_mut` are visible through `expose`.
    #[test]
    fn test_expose_mut() {
        let mut buffer: ShroudedArray<4> = ShroudedArray::new().unwrap();
        let view = buffer.expose_mut();
        view[0] = 99;
        assert_eq!(buffer.expose()[0], 99);
    }

    /// The zero-length instantiation is valid and empty.
    #[test]
    fn test_zero_size() {
        let nothing: ShroudedArray<0> = ShroudedArray::new().unwrap();
        assert!(nothing.is_empty());
        assert_eq!(nothing.len(), 0);
    }
}