use core::any::type_name;
use core::fmt;
use core::mem::MaybeUninit;
use core::ops;
use core::ptr;
use rand_core::RngCore;
use zeroize::{DefaultIsZeroes, Zeroize};
use crate::boxed::SecureBox;
use crate::bytes::FillBytes;
/// Closure-scoped access to the value held by a protected wrapper.
///
/// The target is only handed out wrapped in a [`SecureRef`], and only to a
/// caller-supplied closure.
pub trait ExposeProtected {
    /// The type of value that is exposed to the closures.
    type Target: ?Sized;

    /// Calls for shared access: `f` receives a `SecureRef` around `&Self::Target`.
    fn expose_read<F: FnOnce(SecureRef<&Self::Target>)>(&self, f: F);

    /// Calls for exclusive access: `f` receives a `SecureRef` around
    /// `&mut Self::Target`.
    fn expose_write<F: FnOnce(SecureRef<&mut Self::Target>)>(&mut self, f: F);

    /// Consumes the wrapper and yields a [`SecureBox`] of the target.
    fn unprotect(self) -> SecureBox<Self::Target>;
}
/// Constructors for wrappers that can be built from a [`SecureBox`] of their
/// target type.
///
/// Every constructor obtains a `SecureBox`, optionally lets a closure adjust
/// the contents through a [`SecureRef`], and converts the box into `Self`
/// through the `From<SecureBox<Self::Target>>` bound.
pub trait ProtectedInit: ExposeProtected + From<SecureBox<Self::Target>> + Sized {
    /// Builds an instance from a box created with `new_uninit`, leaving it to
    /// `f` to fill in the contents.
    fn init<F>(f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut Self::Target>),
        Self::Target: Copy + FillBytes,
    {
        let pending = SecureBox::<Self::Target>::new_uninit();
        // NOTE(review): the box is treated as initialized before `f` has
        // written anything; soundness depends on what `new_uninit` leaves in
        // the allocation — confirm against `SecureBox`.
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Builds an instance from a default-valued box, then lets `f` modify it.
    fn init_default<F>(f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut Self::Target>),
        Self::Target: Default,
    {
        let mut storage = SecureBox::<Self::Target>::default();
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Builds an instance whose box is first filled via `fill_random(rng)`,
    /// then lets `f` modify it.
    fn init_random<F>(rng: impl RngCore, f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut Self::Target>),
        Self::Target: Copy + FillBytes,
    {
        let mut pending = SecureBox::<Self::Target>::new_uninit();
        pending.fill_random(rng);
        // The box was populated by `fill_random` just above.
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Copies `*from` into a new box, zeroizes `from`, then lets `f` modify
    /// the boxed copy.
    fn init_take<F>(from: &mut Self::Target, f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut Self::Target>),
        Self::Target: DefaultIsZeroes,
    {
        let slot = SecureBox::new_uninit();
        let mut storage = SecureBox::write(slot, *from);
        // Wipe the source only after its value has been written to the box.
        from.zeroize();
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Builds an instance from the value returned by `f`.
    ///
    /// The box is obtained before `f` is called.
    #[inline(always)]
    fn init_with<F>(f: F) -> Self
    where
        F: FnOnce() -> Self::Target,
        Self::Target: Sized,
    {
        let slot = SecureBox::new_uninit();
        let storage = SecureBox::write(slot, f());
        Self::from(storage)
    }

    /// Fallible variant of [`init_with`](Self::init_with): an `Err` from `f`
    /// is returned unchanged.
    ///
    /// The box is obtained before `f` is called.
    #[inline(always)]
    fn try_init_with<F, E>(f: F) -> Result<Self, E>
    where
        F: FnOnce() -> Result<Self::Target, E>,
        Self::Target: Sized,
    {
        let slot = SecureBox::new_uninit();
        let storage = SecureBox::write(slot, f()?);
        Ok(Self::from(storage))
    }

    /// Shorthand for [`init_random`](Self::init_random) with a no-op closure.
    fn random(rng: impl RngCore) -> Self
    where
        Self::Target: Copy + FillBytes,
    {
        Self::init_random(rng, |_| {})
    }

    /// Shorthand for [`init_take`](Self::init_take) with a no-op closure.
    fn take(from: &mut Self::Target) -> Self
    where
        Self::Target: DefaultIsZeroes,
    {
        Self::init_take(from, |_| {})
    }
}
/// Blanket implementation: any sized-target wrapper that exposes `T` and can
/// be built from a `SecureBox<T>` gets the `ProtectedInit` constructors.
impl<T, W> ProtectedInit for W
where
    W: ExposeProtected<Target = T>,
    W: From<SecureBox<T>>,
{
}
/// Constructors for wrappers whose target is a slice and that can be built
/// from a `SecureBox<[Self::Item]>`.
pub trait ProtectedInitSlice:
    ExposeProtected<Target = [Self::Item]> + From<SecureBox<[Self::Item]>> + Sized
{
    /// Element type of the protected slice.
    type Item;

    /// Builds an instance of `len` elements from a box created with
    /// `new_uninit_slice`, leaving it to `f` to fill in the contents.
    fn init_slice<F>(len: usize, f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut [Self::Item]>),
        Self::Item: Copy + FillBytes,
    {
        let pending = SecureBox::<[Self::Item]>::new_uninit_slice(len);
        // NOTE(review): the slice is treated as initialized before `f` has
        // written anything; soundness depends on what `new_uninit_slice`
        // leaves in the allocation — confirm against `SecureBox`.
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Builds an instance of `len` default-valued elements, then lets `f`
    /// modify them.
    fn init_default_slice<F>(len: usize, f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut [Self::Item]>),
        Self::Item: Default,
    {
        let mut pending = SecureBox::<[Self::Item]>::new_uninit_slice(len);
        pending
            .as_mut()
            .fill_with(|| MaybeUninit::new(<Self::Item as Default>::default()));
        // Every slot was assigned by `fill_with` just above.
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Builds an instance of `len` elements filled via `fill_random(rng)`,
    /// then lets `f` modify them.
    fn init_random_slice<F>(len: usize, rng: impl RngCore, f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut [Self::Item]>),
        Self::Item: Copy + FillBytes,
    {
        let mut pending = SecureBox::<[Self::Item]>::new_uninit_slice(len);
        pending.fill_random(rng);
        // The slice was populated by `fill_random` just above.
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Copies the elements of `from` into a new box of the same length,
    /// zeroizes `from`, then lets `f` modify the boxed copy.
    fn init_take_slice<F>(from: &mut [Self::Item], f: F) -> Self
    where
        F: FnOnce(SecureRef<&mut [Self::Item]>),
        Self::Item: DefaultIsZeroes,
    {
        let count = from.len();
        let mut pending = SecureBox::<[Self::Item]>::new_uninit_slice(count);
        let src = from.as_ptr();
        let dst = pending.as_mut().as_mut_ptr().cast::<Self::Item>();
        // `src` covers `count` elements of `from`; `dst` points into the box
        // that was just created with the same length.
        unsafe { ptr::copy_nonoverlapping(src, dst, count) };
        // Wipe the source only after the copy has been made.
        from.zeroize();
        let mut storage = unsafe { pending.assume_init() };
        f(SecureRef::new_mut(storage.as_mut()));
        Self::from(storage)
    }

    /// Shorthand for [`init_random_slice`](Self::init_random_slice) with a
    /// no-op closure.
    fn random_slice(len: usize, rng: impl RngCore) -> Self
    where
        Self::Item: Copy + FillBytes,
    {
        Self::init_random_slice(len, rng, |_| {})
    }

    /// Shorthand for [`init_take_slice`](Self::init_take_slice) with a no-op
    /// closure.
    fn take_slice(from: &mut [Self::Item]) -> Self
    where
        Self::Item: DefaultIsZeroes,
    {
        Self::init_take_slice(from, |_| {})
    }
}
/// Blanket implementation: any wrapper that exposes `[T]` and can be built
/// from a `SecureBox<[T]>` gets the slice constructors, with `Item = T`.
impl<T, W> ProtectedInitSlice for W
where
    W: ExposeProtected<Target = [T]>,
    W: From<SecureBox<[T]>>,
{
    type Item = T;
}
/// Newtype wrapper around a value of type `T` (in this file, a shared or
/// exclusive reference).
///
/// The single field is private.
pub struct SecureRef<T>(T)
where
    T: ?Sized;
impl<'a, T: ?Sized> SecureRef<&'a T> {
pub(crate) fn new(inner: &'a T) -> Self {
Self(inner)
}
}
impl<'a, T: ?Sized> SecureRef<&'a mut T> {
pub(crate) fn new_mut(inner: &'a mut T) -> Self {
Self(inner)
}
}
impl<'a, T> SecureRef<&'a mut MaybeUninit<T>> {
    /// Converts a reference to a possibly-uninitialized slot into a reference
    /// to the contained `T`.
    ///
    /// # Safety
    ///
    /// The referenced `MaybeUninit<T>` must already hold a fully initialized
    /// `T`; this forwards to [`MaybeUninit::assume_init_mut`] and inherits its
    /// requirements.
    #[inline]
    pub unsafe fn assume_init(self) -> SecureRef<&'a mut T> {
        SecureRef(self.0.assume_init_mut())
    }

    /// Stores `value` in the slot and returns a reference to the now
    /// initialized `T`.
    ///
    /// Any previous contents of the slot are overwritten without being
    /// dropped, as with [`MaybeUninit::write`].
    #[inline(always)]
    pub fn write(slf: Self, value: T) -> SecureRef<&'a mut T> {
        // `MaybeUninit::write` already hands back `&mut T`, so no `unsafe`
        // re-borrow of the slot is needed.
        SecureRef(slf.0.write(value))
    }
}
impl<T: ?Sized> AsRef<T> for SecureRef<&'_ T> {
#[inline]
fn as_ref(&self) -> &T {
self.0
}
}
/// Shared access to the referent of a wrapped `&mut T`.
impl<T> AsRef<T> for SecureRef<&'_ mut T>
where
    T: ?Sized,
{
    #[inline]
    fn as_ref(&self) -> &T {
        &*self.0
    }
}
/// Exclusive access to the referent of a wrapped `&mut T`.
impl<T> AsMut<T> for SecureRef<&'_ mut T>
where
    T: ?Sized,
{
    #[inline]
    fn as_mut(&mut self) -> &mut T {
        &mut *self.0
    }
}
/// Debug output that names the wrapper and its type parameter only; the
/// wrapped value itself is never formatted.
impl<T: ?Sized> fmt::Debug for SecureRef<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report the actual type name (`SecureRef`) so the output matches
        // the type that appears in the code.
        write!(f, "SecureRef<{}>", type_name::<T>())
    }
}
impl<T: ?Sized> ops::Deref for SecureRef<&'_ T> {
type Target = T;
#[inline]
fn deref(&self) -> &T {
self.0
}
}
/// Dereferences a wrapped `&mut T` to a shared view of its referent.
impl<T> ops::Deref for SecureRef<&'_ mut T>
where
    T: ?Sized,
{
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// Mutably dereferences a wrapped `&mut T` to its referent.
impl<T> ops::DerefMut for SecureRef<&'_ mut T>
where
    T: ?Sized,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}