use core::fmt;
use std::borrow::{Borrow, BorrowMut};
use std::mem::MaybeUninit;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use crate::secure_utils::memlock;
/// Heap-allocated holder for a `Copy` secret.
///
/// `new` locks the allocation with `memlock::mlock`. On drop, the bytes are zeroized and the
/// memory is unlocked and freed. `Debug` and `Display` never print the contents.
pub struct SecureBox<T: Copy> {
    /// The boxed secret. `new` stores `Some`; it is only taken (set to `None`) inside `Drop`.
    content: Option<Box<T>>,
}
impl<T: Copy> SecureBox<T> {
    /// Takes ownership of an already-boxed value and locks its memory with `memlock::mlock`.
    ///
    /// The value was already in the heap allocation before this call, so it sat in unlocked
    /// memory briefly before being locked.
    #[must_use]
    pub fn new(mut cont: Box<T>) -> Self {
        let target = &raw mut *cont;
        memlock::mlock(target, 1);
        Self { content: Some(cont) }
    }

    /// Returns a shared reference to the secret.
    ///
    /// # Panics
    ///
    /// Panics if the content has already been taken, which only happens inside `Drop`.
    #[must_use]
    pub fn unsecure(&self) -> &T {
        match self.content.as_deref() {
            Some(inner) => inner,
            None => panic!("SecureBox content accessed after drop"),
        }
    }

    /// Returns a mutable reference to the secret.
    ///
    /// # Panics
    ///
    /// Panics if the content has already been taken, which only happens inside `Drop`.
    pub fn unsecure_mut(&mut self) -> &mut T {
        match self.content.as_deref_mut() {
            Some(inner) => inner,
            None => panic!("SecureBox content accessed after drop"),
        }
    }
}
impl<T: Copy> Clone for SecureBox<T> {
    /// Produces an independent `SecureBox` holding a bitwise copy of the secret.
    ///
    /// The value is copied directly from the source allocation into a freshly allocated,
    /// uninitialized box. Writing `Box::new(*value)` instead would move the secret through a
    /// stack temporary that is never zeroized. The new allocation is then locked by `new`.
    fn clone(&self) -> Self {
        let mut fresh = Box::<T>::new_uninit();
        // SAFETY: the source is a valid, aligned `&T` and `fresh` is a distinct allocation
        // sized and aligned for exactly one `T`, so the regions cannot overlap. `T: Copy`,
        // so a bitwise copy is a valid duplicate. For a zero-sized `T` this copies nothing.
        unsafe {
            std::ptr::copy_nonoverlapping(
                std::ptr::from_ref(self.unsecure()),
                fresh.as_mut_ptr(),
                1,
            );
        }
        // SAFETY: the copy above fully initialized the single `T` held by `fresh`.
        let fresh = unsafe { fresh.assume_init() };
        Self::new(fresh)
    }
}
impl<T> ConstantTimeEq for SecureBox<T>
where
    T: Copy + ConstantTimeEq,
{
    /// Compares two boxed secrets by delegating to the inner type's constant-time equality.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let (lhs, rhs) = (self.unsecure(), other.unsecure());
        <T as ConstantTimeEq>::ct_eq(lhs, rhs)
    }
}
impl<T: Copy + ConstantTimeEq> PartialEq for SecureBox<T> {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl<T: Copy + ConstantTimeEq> Eq for SecureBox<T> {}
impl<T, U> std::ops::Index<U> for SecureBox<T>
where
    T: Copy + std::ops::Index<U>,
{
    type Output = <T as std::ops::Index<U>>::Output;

    /// Forwards indexing to the wrapped value.
    fn index(&self, index: U) -> &Self::Output {
        let inner = self.unsecure();
        <T as std::ops::Index<U>>::index(inner, index)
    }
}
impl<T: Copy> Borrow<T> for SecureBox<T> {
    /// Borrows the wrapped secret (same as `unsecure`).
    fn borrow(&self) -> &T {
        Self::unsecure(self)
    }
}
impl<T: Copy> BorrowMut<T> for SecureBox<T> {
    /// Mutably borrows the wrapped secret (same as `unsecure_mut`).
    fn borrow_mut(&mut self) -> &mut T {
        Self::unsecure_mut(self)
    }
}
impl<T> Drop for SecureBox<T>
where
T: Copy,
{
fn drop(&mut self) {
let ptr = Box::into_raw(self.content.take().expect("SecureBox dropped twice"));
unsafe {
std::slice::from_raw_parts_mut::<MaybeUninit<u8>>(
ptr.cast::<MaybeUninit<u8>>(),
std::mem::size_of::<T>(),
)
.zeroize();
}
memlock::munlock(ptr, 1);
if std::mem::size_of::<T>() != 0 {
unsafe { std::alloc::dealloc(ptr.cast::<u8>(), std::alloc::Layout::new::<T>()) };
}
}
}
impl<T: Copy> fmt::Debug for SecureBox<T> {
    /// Prints `SecureBox { .. }` and deliberately omits the secret contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SecureBox");
        builder.finish_non_exhaustive()
    }
}
impl<T> fmt::Display for SecureBox<T>
where
    T: Copy,
{
    /// Writes the fixed placeholder `***SECRET***` so the contents never reach user output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` already returns `fmt::Result`; mapping its error to `fmt::Error` was a no-op.
        f.write_str("***SECRET***")
    }
}
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;
    use zeroize::Zeroize;
    use super::SecureBox;
    use crate::test_utils::{PRIVATE_KEY_1, PRIVATE_KEY_2, Packed, Padded};

    /// Overwrites every byte of the boxed value, padding included, with zero.
    ///
    /// # Safety
    ///
    /// The all-zero bit pattern must be a valid value of `T`. Otherwise, any later read of the
    /// box (including `unsecure` and the final zeroize in `Drop`) observes an invalid `T`.
    unsafe fn zero_out_secure_box<T>(secure_box: &mut SecureBox<T>)
    where
        T: Copy,
    {
        // SAFETY: `unsecure_mut` yields a valid, aligned `&mut T`, so its `size_of::<T>()`
        // bytes are writable, and `MaybeUninit<u8>` makes no initialization claim about them.
        // Validity of the resulting all-zero `T` is the caller's obligation (see `# Safety`).
        unsafe {
            std::slice::from_raw_parts_mut::<MaybeUninit<u8>>(
                std::ptr::from_mut::<T>(secure_box.unsecure_mut()).cast::<MaybeUninit<u8>>(),
                std::mem::size_of::<T>(),
            )
            .zeroize();
        }
    }

    #[test]
    fn test_secure_box() {
        let key_1 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        let key_2 = SecureBox::new(Box::new(PRIVATE_KEY_2));
        let key_3 = SecureBox::new(Box::new(PRIVATE_KEY_1));
        assert_eq!(key_1, key_1);
        assert_ne!(key_1, key_2);
        assert_ne!(key_2, key_3);
        assert_eq!(key_1, key_3);
        let mut final_key = key_1.clone();
        // SAFETY: the assertion below treats field `.0` as an integer array, for which all-zero
        // is valid. This assumes the key type has no other fields where zero is invalid
        // (TODO confirm against `test_utils`).
        unsafe {
            zero_out_secure_box(&mut final_key);
        }
        assert_eq!(final_key.unsecure().0, [0; 32]);
        // A clone must own its own allocation: wiping it must leave the source untouched.
        assert_eq!(key_1, key_3, "zeroing a clone must not affect the original");
    }

    #[test]
    fn test_repr_c_with_padding() {
        assert_eq!(std::mem::size_of::<Padded>(), 4);
        let sec_a = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        let sec_b = SecureBox::new(Box::new(Padded { x: 1, y: 2 }));
        assert_eq!(sec_a, sec_b);
        let sec_c = SecureBox::new(Box::new(Padded { x: 1, y: 3 }));
        assert_ne!(sec_a, sec_c);
        let sec_d = SecureBox::new(Box::new(Padded { x: 2, y: 2 }));
        assert_ne!(sec_a, sec_d);
    }

    #[test]
    fn test_repr_c_packed() {
        assert_eq!(std::mem::size_of::<Packed>(), 3);
        let sec_a = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_b = SecureBox::new(Box::new(Packed { x: 42, y: 1000 }));
        let sec_c = SecureBox::new(Box::new(Packed { x: 42, y: 1001 }));
        let sec_d = SecureBox::new(Box::new(Packed { x: 43, y: 1000 }));
        assert_eq!(sec_a, sec_b);
        assert_ne!(sec_a, sec_c);
        assert_ne!(sec_a, sec_d);
    }
}