use crate::boxed::Box;
use crate::traits::*;
use std::fmt::{self, Debug, Formatter};
use std::ops::{Deref, DerefMut};
/// A heap-allocated container holding a single secret value of type `T`.
///
/// The value is stored in the crate-local `Box` and is only reachable
/// through the guards returned by [`SecretBox::borrow`] ([`Ref`]) and
/// [`SecretBox::borrow_mut`] ([`RefMut`]).
#[derive(Clone, Eq)]
pub struct SecretBox<T: Bytes> {
    /// Backing storage. The constructors in this file request a length of 1,
    /// and the guard constructors check that length before handing out access.
    boxed: Box<T>,
}
/// Shared, read-only guard over the contents of a [`SecretBox`].
///
/// Created by [`SecretBox::borrow`]. Constructing a guard calls `unlock` on
/// the box, and dropping it calls `lock`.
pub struct Ref<'a, T: Bytes> {
    /// The box whose contents this guard exposes.
    boxed: &'a Box<T>,
}
/// Exclusive, mutable guard over the contents of a [`SecretBox`].
///
/// Created by [`SecretBox::borrow_mut`]. Constructing a guard calls
/// `unlock_mut` on the box, and dropping it calls `lock`.
pub struct RefMut<'a, T: Bytes> {
    /// The box whose contents this guard exposes mutably.
    boxed: &'a mut Box<T>,
}
impl<T: Bytes> SecretBox<T> {
    /// Allocates a new box and lets `f` initialize the value in place.
    ///
    /// `f` receives a mutable reference to the value inside the freshly
    /// created crate-local `Box` (allocated with a length of 1).
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(&mut T),
    {
        let boxed = Box::new(1, |inner| f(inner.as_mut()));
        Self { boxed }
    }

    /// Fallible variant of [`SecretBox::new`].
    ///
    /// The `Ok` payload produced by `f` is discarded; only the box is returned.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`, unchanged.
    pub fn try_new<U, E, F>(f: F) -> Result<Self, E>
    where
        F: FnOnce(&mut T) -> Result<U, E>,
    {
        let boxed = Box::try_new(1, |inner| f(inner.as_mut()))?;
        Ok(Self { boxed })
    }

    /// Size in bytes of the stored value, as reported by the inner box.
    pub fn size(&self) -> usize {
        self.boxed.size()
    }

    /// Returns a read-only guard over the secret.
    ///
    /// The guard unlocks the box on creation and re-locks it on drop.
    /// NOTE(review): `Ref::new` also checks the box length with `proven!`,
    /// which is defined elsewhere; whether it is active in all build profiles
    /// is not visible here.
    pub fn borrow(&self) -> Ref<'_, T> {
        Ref::new(&self.boxed)
    }

    /// Returns an exclusive, mutable guard over the secret.
    ///
    /// The guard unlocks the box for writing on creation and re-locks it on
    /// drop.
    pub fn borrow_mut(&mut self) -> RefMut<'_, T> {
        RefMut::new(&mut self.boxed)
    }
}
impl<T: Bytes + Randomizable> SecretBox<T> {
    /// Creates a box whose single value is filled by the crate-local
    /// `Box::random`.
    pub fn random() -> Self {
        let boxed = Box::random(1);
        Self { boxed }
    }
}
impl<T: Bytes + Zeroable> SecretBox<T> {
    /// Creates a box whose single value is zero-initialized via the
    /// crate-local `Box::zero`.
    pub fn zero() -> Self {
        let boxed = Box::zero(1);
        Self { boxed }
    }
}
impl<T: Bytes + Zeroable> From<&mut T> for SecretBox<T> {
    /// Builds a box from `data` using the crate-local `&mut T -> Box<T>`
    /// conversion.
    ///
    /// NOTE(review): the `Zeroable` bound and `&mut` parameter suggest the
    /// source is wiped after being copied; confirm in `Box`'s conversion.
    fn from(data: &mut T) -> Self {
        let boxed: Box<T> = data.into();
        Self { boxed }
    }
}
impl<T: Bytes> Debug for SecretBox<T> {
    /// Delegates to the inner box's `Debug` impl, which prints a redacted
    /// placeholder such as `{ 8 bytes redacted }` instead of the contents.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.boxed, f)
    }
}
impl<T: Bytes + ConstantEq> PartialEq for SecretBox<T> {
    /// Compares two boxes by delegating to the inner boxes' equality.
    ///
    /// NOTE(review): the `ConstantEq` bound suggests the inner comparison is
    /// constant-time; confirm in the `boxed` module.
    fn eq(&self, rhs: &Self) -> bool {
        self.boxed == rhs.boxed
    }
}
impl<'a, T: Bytes> Ref<'a, T> {
    /// Unlocks `boxed` for reading and wraps it in a guard.
    ///
    /// `proven!` checks that the box holds exactly one element before access
    /// is granted. The message refers to the zero-length case, but the check
    /// fires for any length other than 1.
    fn new(boxed: &'a Box<T>) -> Self {
        proven!(
            boxed.len() == 1,
            "secrets: attempted to dereference a box with zero length"
        );

        let unlocked = boxed.unlock();
        Self { boxed: unlocked }
    }
}
impl<T: Bytes> Clone for Ref<'_, T> {
    /// Creates another read guard over the same box.
    ///
    /// The clone performs its own `unlock`, matching the `lock` that its own
    /// `Drop` will perform.
    fn clone(&self) -> Self {
        let boxed = self.boxed.unlock();
        Self { boxed }
    }
}
impl<T: Bytes> Drop for Ref<'_, T> {
    /// Calls `lock` on the box, pairing with the `unlock` done when this
    /// guard was created or cloned.
    fn drop(&mut self) {
        self.boxed.lock();
    }
}
impl<T: Bytes> Deref for Ref<'_, T> {
    type Target = T;

    /// Exposes the secret as a shared reference borrowed from the guard.
    fn deref(&self) -> &T {
        self.boxed.as_ref()
    }
}
impl<T: Bytes> Debug for Ref<'_, T> {
    /// Delegates to the box's redacting `Debug` impl; the secret itself is
    /// never printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Debug::fmt(self.boxed, f)
    }
}
impl<T: Bytes> PartialEq for Ref<'_, T> {
    /// Compares the two borrowed secrets with `ConstantEq::constant_eq`.
    fn eq(&self, rhs: &Self) -> bool {
        // Deref both guards down to the secret values before comparing.
        let lhs: &T = self;
        let rhs: &T = rhs;
        lhs.constant_eq(rhs)
    }
}

impl<T: Bytes> PartialEq<RefMut<'_, T>> for Ref<'_, T> {
    /// Compares a read guard with a write guard using
    /// `ConstantEq::constant_eq` on the underlying values.
    fn eq(&self, rhs: &RefMut<'_, T>) -> bool {
        let lhs: &T = self;
        let rhs: &T = rhs;
        lhs.constant_eq(rhs)
    }
}

impl<T: Bytes> Eq for Ref<'_, T> {}
impl<'a, T: Bytes> RefMut<'a, T> {
    /// Unlocks `boxed` for writing and wraps it in a guard.
    ///
    /// `proven!` checks that the box holds exactly one element before access
    /// is granted. The message refers to the zero-length case, but the check
    /// fires for any length other than 1.
    fn new(boxed: &'a mut Box<T>) -> Self {
        proven!(
            boxed.len() == 1,
            "secrets: attempted to dereference a box with zero length"
        );

        let unlocked = boxed.unlock_mut();
        Self { boxed: unlocked }
    }
}
impl<T: Bytes> Drop for RefMut<'_, T> {
    /// Calls `lock` on the box, pairing with the `unlock_mut` done when this
    /// guard was created.
    fn drop(&mut self) {
        self.boxed.lock();
    }
}
impl<T: Bytes> Deref for RefMut<'_, T> {
    type Target = T;

    /// Exposes the secret as a shared reference borrowed from the guard.
    fn deref(&self) -> &T {
        self.boxed.as_ref()
    }
}
impl<T: Bytes> DerefMut for RefMut<'_, T> {
    /// Exposes the secret as a mutable reference borrowed from the guard.
    fn deref_mut(&mut self) -> &mut T {
        self.boxed.as_mut()
    }
}
impl<T: Bytes> Debug for RefMut<'_, T> {
    /// Delegates to the box's redacting `Debug` impl; the secret itself is
    /// never printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Debug::fmt(&*self.boxed, f)
    }
}
impl<T: Bytes> PartialEq for RefMut<'_, T> {
    /// Compares the two borrowed secrets with `ConstantEq::constant_eq`.
    fn eq(&self, rhs: &Self) -> bool {
        // Deref both guards down to the secret values before comparing.
        let lhs: &T = self;
        let rhs: &T = rhs;
        lhs.constant_eq(rhs)
    }
}

impl<T: Bytes> PartialEq<Ref<'_, T>> for RefMut<'_, T> {
    /// Compares a write guard with a read guard using
    /// `ConstantEq::constant_eq` on the underlying values.
    fn eq(&self, rhs: &Ref<'_, T>) -> bool {
        let lhs: &T = self;
        let rhs: &T = rhs;
        lhs.constant_eq(rhs)
    }
}

impl<T: Bytes> Eq for RefMut<'_, T> {}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn it_allows_custom_initialization() {
        let secret = SecretBox::<u64>::new(|s| {
            *s = 0x8f1a;
            assert_eq!(*s, 0x8f1a);
        });

        // The value written by the initializer must survive construction.
        assert_eq!(*secret.borrow(), 0x8f1a);
    }

    #[test]
    fn it_allows_failing_initialization() {
        assert!(SecretBox::<u8>::try_new(|_| Ok::<(), ()>(())).is_ok());

        // An error returned by the initializer must be propagated unchanged.
        let result = SecretBox::<u8>::try_new(|_| Err::<(), u8>(0x2a));
        assert_eq!(result.err(), Some(0x2a));
    }

    #[test]
    fn it_allows_borrowing_immutably() {
        let secret = SecretBox::<u64>::zero();
        let s = secret.borrow();

        assert_eq!(*s, 0);
    }

    #[test]
    fn it_allows_borrowing_mutably() {
        let mut secret = SecretBox::<u64>::zero();
        let mut s = secret.borrow_mut();

        *s = 0x01ab_cdef;

        assert_eq!(*s, 0x01ab_cdef);
    }

    #[test]
    fn it_allows_storing_fixed_size_arrays() {
        let secret = SecretBox::<[u8; 2]>::new(|s| {
            *s = [0xdd, 0xa1];
        });

        assert_eq!(*secret.borrow(), [0xdd, 0xa1]);
    }

    #[test]
    fn it_provides_its_size() {
        let secret = SecretBox::<[u128; 4]>::zero();
        assert_eq!(secret.size(), 64);
    }

    #[test]
    fn it_preserves_secrecy() {
        let mut secret = SecretBox::<u64>::random();

        assert_eq!(
            format!("{{ {} bytes redacted }}", 8),
            format!("{:?}", secret),
        );

        assert_eq!(
            format!("{{ {} bytes redacted }}", 8),
            format!("{:?}", secret.borrow()),
        );

        assert_eq!(
            format!("{{ {} bytes redacted }}", 8),
            format!("{:?}", secret.borrow_mut()),
        );
    }

    #[test]
    fn it_moves_safely() {
        let secret_1 = SecretBox::<u8>::zero();
        let secret_2 = secret_1;

        assert_eq!(*secret_2.borrow(), 0);
    }

    #[test]
    fn it_safely_clones_immutable_references() {
        let secret = SecretBox::<u8>::random();
        let borrow_1 = secret.borrow();
        let borrow_2 = borrow_1.clone();

        assert_eq!(borrow_1, borrow_2);
    }

    #[test]
    fn it_compares_equality() {
        let secret_1 = SecretBox::<u8>::from(&mut 0xaf);
        let secret_2 = secret_1.clone();

        assert_eq!(secret_1, secret_2);
    }

    #[test]
    fn it_compares_inequality() {
        let secret_1 = SecretBox::<[u128; 8]>::random();
        let secret_2 = SecretBox::<[u128; 8]>::random();

        assert_ne!(secret_1, secret_2);
    }

    #[test]
    fn it_compares_equality_immutably_on_refs() {
        let secret_1 = SecretBox::<u8>::from(&mut 0xaf);
        let secret_2 = secret_1.clone();

        assert_eq!(secret_1.borrow(), secret_2.borrow());
    }

    #[test]
    fn it_compares_equality_immutably_on_ref_muts() {
        let mut secret_1 = SecretBox::<u8>::from(&mut 0xaf);
        let mut secret_2 = secret_1.clone();

        assert_eq!(secret_1.borrow_mut(), secret_2.borrow_mut());
    }

    #[test]
    fn it_compares_equality_immutably_regardless_of_mut() {
        let mut secret_1 = SecretBox::<u8>::from(&mut 0xaf);
        let mut secret_2 = secret_1.clone();

        assert_eq!(secret_1.borrow_mut(), secret_2.borrow());
        assert_eq!(secret_2.borrow_mut(), secret_1.borrow());
    }
}
#[cfg(all(test, profile = "debug"))]
mod tests_proven_statements {
    #![allow(unsafe_code)]

    use super::*;

    // Both tests build an empty `Box` directly and expect dereferencing it to
    // panic with the message below.
    // NOTE(review): these exercise `Box` itself, not the length checks in
    // `Ref::new` / `RefMut::new`.

    #[test]
    #[should_panic(expected = "secrets: attempted to dereference a zero-length pointer")]
    fn it_doesnt_allow_borrowing_zero_length() {
        let empty = Box::<u8>::zero(0);
        let _ = empty.as_ref();
    }

    #[test]
    #[should_panic(expected = "secrets: attempted to dereference a zero-length pointer")]
    fn it_doesnt_allow_mutably_borrowing_zero_length() {
        let mut empty = Box::<u8>::zero(0);
        let _ = empty.as_mut();
    }
}