use ascon_aead::aead::generic_array::sequence::Split;
use ascon_aead::aead::{AeadInPlace, KeyInit};
use ascon_aead::{Ascon128, Key, Nonce};
use ascon_hash::{AsconHash, Digest};
use rand_chacha::{
rand_core::{RngCore, SeedableRng},
ChaCha20Rng,
};
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::slice;
use zeroize::Zeroize;
use crate::pagevec::PageVec;
use crate::utils;
unsafe fn derive_secrets(region_ptr: *const u8) -> (Key<Ascon128>, Nonce<Ascon128>) {
let s1 = slice::from_raw_parts(region_ptr, *utils::PAGE_SIZE);
let region_p3 = region_ptr.add(*utils::PAGE_SIZE * 2);
let s2 = slice::from_raw_parts(region_p3, *utils::PAGE_SIZE);
let hash = <AsconHash as Digest>::new()
.chain_update(s1)
.chain_update(s2)
.finalize();
hash.split()
}
/// Unlocks, zeroes and frees a whole three-page region, in that order.
///
/// NOTE(review): this is a safe fn that dereferences a raw pointer; callers are expected to
/// pass the base of a live, writable region allocated with `3 * PAGE_SIZE` bytes.
fn cleanup(region_ptr: *mut u8) {
    let region_len = *utils::PAGE_SIZE * 3;
    unsafe {
        utils::munlock(region_ptr, region_len);
        utils::memzero(region_ptr, region_len);
        utils::free(region_ptr, region_len);
    }
}
/// An encrypted container for a single `T`.
///
/// The value lives in the middle page of a three-page allocation. The outer pages hold random
/// bytes from which the encryption key and nonce are derived. While sealed, the whole region
/// is mapped with no access rights.
#[derive(Debug)]
pub struct CryptBox<T> {
    /// Ciphertext (the encrypted value followed by its authentication tag) in the middle page.
    data: PageVec,
    /// Records that the box logically owns a `T` without storing one inline.
    _marker: std::marker::PhantomData<T>,
}
impl<T> CryptBox<T> {
#[must_use]
#[inline]
pub fn new(val: T) -> Self {
Self::try_new(val).expect("CryptBox memory allocation failed!")
}
#[must_use]
#[inline]
pub fn construct<F: Fn() -> T>(f: F) -> Self {
Self::try_construct(f).expect("CryptBox memory allocation failed!")
}
#[must_use]
pub fn try_new(val: T) -> Option<CryptBox<T>> {
if size_of::<T>() + 16 > *utils::PAGE_SIZE {
return None;
}
let region_ptr = utils::alloc(*utils::PAGE_SIZE * 3)?;
let (mut data, mut key, mut nonce) = unsafe {
let region_slice =
slice::from_raw_parts_mut(region_ptr.as_ptr(), *utils::PAGE_SIZE * 3);
let mut rng = ChaCha20Rng::from_os_rng();
rng.fill_bytes(region_slice);
let (key, nonce) = derive_secrets(region_ptr.as_ptr());
let region_data = region_ptr.add(*utils::PAGE_SIZE).as_ptr() as *mut T;
region_data.write_unaligned(val);
(
PageVec {
ptr: region_data as *mut u8,
len: size_of::<T>(),
},
key,
nonce,
)
};
let cipher = Ascon128::new(&key);
match cipher.encrypt_in_place(&nonce, b"", &mut data) {
Ok(_) => (),
Err(_) => {
cleanup(region_ptr.as_ptr());
return None;
}
}
unsafe {
utils::mlock(region_ptr.as_ptr(), *utils::PAGE_SIZE * 3);
utils::mprotect(
region_ptr.as_ptr(),
*utils::PAGE_SIZE * 3,
utils::Prot::NoAccess,
);
}
key.zeroize();
nonce.zeroize();
Some(Self {
data,
_marker: std::marker::PhantomData,
})
}
#[must_use]
pub fn try_construct<F: Fn() -> T>(f: F) -> Option<CryptBox<T>> {
if size_of::<T>() + 16 > *utils::PAGE_SIZE {
return None;
}
let region_ptr = utils::alloc(*utils::PAGE_SIZE * 3)?;
let (mut data, mut key, mut nonce) = unsafe {
let region_slice =
slice::from_raw_parts_mut(region_ptr.as_ptr(), *utils::PAGE_SIZE * 3);
let mut rng = ChaCha20Rng::from_os_rng();
rng.fill_bytes(region_slice);
let (key, nonce) = derive_secrets(region_ptr.as_ptr());
let region_data = region_ptr.add(*utils::PAGE_SIZE).as_ptr() as *mut T;
*region_data = f();
(
PageVec {
ptr: region_data as *mut u8,
len: size_of::<T>(),
},
key,
nonce,
)
};
let cipher = Ascon128::new(&key);
match cipher.encrypt_in_place(&nonce, b"", &mut data) {
Ok(_) => (),
Err(_) => {
cleanup(region_ptr.as_ptr());
return None;
}
}
unsafe {
utils::mlock(region_ptr.as_ptr(), *utils::PAGE_SIZE * 3);
utils::mprotect(
region_ptr.as_ptr(),
*utils::PAGE_SIZE * 3,
utils::Prot::NoAccess,
);
}
key.zeroize();
nonce.zeroize();
Some(Self {
data,
_marker: std::marker::PhantomData,
})
}
#[must_use]
pub fn decrypt(self) -> PlainBox<T> {
let mut s = ManuallyDrop::new(self);
unsafe {
let region_ptr = s.data.ptr.sub(*utils::PAGE_SIZE);
utils::mprotect(region_ptr, *utils::PAGE_SIZE * 3, utils::Prot::ReadWrite);
let (mut key, mut nonce) = derive_secrets(region_ptr);
let cipher = Ascon128::new(&key);
match cipher.decrypt_in_place(&nonce, b"", &mut s.data) {
Ok(_) => (),
Err(_) => {
cleanup(region_ptr);
panic!("CryptBox decryption failure!");
}
}
key.zeroize();
nonce.zeroize();
let region_data = s.data.ptr as *mut T;
PlainBox { ptr: region_data }
}
}
#[must_use]
pub fn encrypt(s: PlainBox<T>) -> CryptBox<T> {
let s = ManuallyDrop::new(s);
let mut data = PageVec {
ptr: s.ptr as *mut u8,
len: size_of::<T>(),
};
unsafe {
let region_ptr = data.ptr.sub(*utils::PAGE_SIZE);
let (mut key, mut nonce) = derive_secrets(region_ptr);
let cipher = Ascon128::new(&key);
match cipher.encrypt_in_place(&nonce, b"", &mut data) {
Ok(_) => (),
Err(_) => {
cleanup(region_ptr);
panic!("CryptBox decryption failure!");
}
}
utils::mprotect(region_ptr, *utils::PAGE_SIZE * 3, utils::Prot::NoAccess);
key.zeroize();
nonce.zeroize();
}
Self {
data,
_marker: std::marker::PhantomData,
}
}
}
impl<T> Drop for CryptBox<T> {
fn drop(&mut self) {
unsafe {
let region = self.data.ptr.sub(*utils::PAGE_SIZE);
utils::mprotect(region, *utils::PAGE_SIZE * 3, utils::Prot::ReadWrite);
utils::munlock(region, *utils::PAGE_SIZE * 3);
utils::memzero(region, *utils::PAGE_SIZE * 3);
utils::free(region, *utils::PAGE_SIZE * 3);
}
}
}
/// A decrypted, readable and writable view of a [`CryptBox`]'s value.
///
/// Produced by [`CryptBox::decrypt`]; pass it to [`CryptBox::encrypt`] to seal it again.
#[derive(Debug)]
pub struct PlainBox<T> {
    /// Points at the value in the middle page of the three-page region.
    ptr: *mut T,
}
impl<T> Drop for PlainBox<T> {
fn drop(&mut self) {
unsafe {
let region = self.ptr.byte_sub(*utils::PAGE_SIZE) as *mut u8;
utils::munlock(region, *utils::PAGE_SIZE * 3);
utils::memzero(region, *utils::PAGE_SIZE * 3);
utils::free(region, *utils::PAGE_SIZE * 3);
}
}
}
impl<T> Deref for PlainBox<T> {
    type Target = T;

    /// Borrows the decrypted value stored in the region's middle page.
    fn deref(&self) -> &T {
        let value_ptr = self.ptr as *const T;
        unsafe { &*value_ptr }
    }
}
impl<T> DerefMut for PlainBox<T> {
    /// Mutably borrows the decrypted value stored in the region's middle page.
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr = self.ptr;
        unsafe { &mut *value_ptr }
    }
}
impl<T> AsRef<T> for PlainBox<T>
where
    <PlainBox<T> as Deref>::Target: AsRef<T>,
{
    /// Forwards to the inner value's own `AsRef<T>` implementation.
    fn as_ref(&self) -> &T {
        let inner: &T = &**self;
        <T as AsRef<T>>::as_ref(inner)
    }
}
impl<T> AsMut<T> for PlainBox<T>
where
    <PlainBox<T> as Deref>::Target: AsMut<T>,
{
    /// Forwards to the inner value's own `AsMut<T>` implementation.
    fn as_mut(&mut self) -> &mut T {
        let inner: &mut T = &mut **self;
        <T as AsMut<T>>::as_mut(inner)
    }
}
impl<T> std::borrow::Borrow<T> for PlainBox<T> {
fn borrow(&self) -> &T {
&**self
}
}
impl<T> std::borrow::BorrowMut<T> for PlainBox<T> {
fn borrow_mut(&mut self) -> &mut T {
&mut **self
}
}