use std::cmp::{min, Ordering};
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut};
// Enables `tracer!`/`t!` diagnostics in this module.
const TRACE: bool = false;
/// Heap-allocated byte buffer intended for secret material.
///
/// The field is a raw pointer to the buffer; in this file it is produced by
/// leaking a `Box<[u8]>` (see the `From<Box<[u8]>>` impl). On drop, the
/// bytes are overwritten with `memsec::memzero` and the allocation is then
/// reclaimed with `Box::from_raw`.
pub struct Protected(*mut [u8]);
// SAFETY: `Protected` exclusively owns its allocation (like a `Box<[u8]>`),
// so transferring it to another thread is sound.
unsafe impl Send for Protected {}
// SAFETY: `&Protected` only exposes `&[u8]`; mutation requires
// `&mut Protected`, so shared references may cross threads.
unsafe impl Sync for Protected {}
impl Clone for Protected {
fn clone(&self) -> Self {
let mut p = Vec::with_capacity(self.len());
p.extend_from_slice(self);
p.into_boxed_slice().into()
}
}
impl PartialEq for Protected {
    /// Byte-wise equality, computed with `secure_cmp` (built on
    /// `memsec::memcmp`) rather than ordinary slice comparison.
    fn eq(&self, other: &Self) -> bool {
        matches!(secure_cmp(self, other), Ordering::Equal)
    }
}
// Byte-wise equality is reflexive, so `Eq` holds.
impl Eq for Protected {}
impl Hash for Protected {
    /// Hashes exactly like the underlying `[u8]` slice, so buffers that
    /// compare equal also hash equal.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let bytes: &[u8] = self;
        Hash::hash(bytes, state);
    }
}
impl Protected {
pub fn new(size: usize) -> Protected {
vec![0; size].into_boxed_slice().into()
}
pub(crate) fn expose_into_unprotected_vec(self) -> Vec<u8> {
let mut p = Vec::with_capacity(self.len());
p.extend_from_slice(&self);
p
}
}
impl Deref for Protected {
    type Target = [u8];

    /// Borrows the protected bytes.
    fn deref(&self) -> &[u8] {
        // SAFETY: `self.0` points to the live allocation owned by this value
        // (freed only in `Drop`); the borrow is tied to `&self`.
        unsafe { &*self.0 }
    }
}

impl AsRef<[u8]> for Protected {
    /// Same as `Deref::deref`.
    fn as_ref(&self) -> &[u8] {
        self
    }
}

impl AsMut<[u8]> for Protected {
    /// Same as `DerefMut::deref_mut`.
    fn as_mut(&mut self) -> &mut [u8] {
        self
    }
}

impl DerefMut for Protected {
    /// Mutably borrows the protected bytes.
    fn deref_mut(&mut self) -> &mut [u8] {
        // SAFETY: as in `deref`; `&mut self` guarantees exclusive access.
        unsafe { &mut *self.0 }
    }
}
impl From<Vec<u8>> for Protected {
    /// Copies `v`'s contents into a new `Protected` buffer, then securely
    /// zeroes `v`'s entire allocation, including spare capacity beyond its
    /// length, before it is freed.
    fn from(mut v: Vec<u8>) -> Self {
        let p = Protected::from(&v[..]);
        let capacity = v.capacity();
        // SAFETY: `as_mut_ptr` points to an allocation valid for writes of
        // `capacity` bytes (or is dangling with `capacity == 0`, in which
        // case nothing is written). Writing through the raw pointer does not
        // require the spare capacity to be initialized. Calling
        // `set_len(capacity)` first would, since `set_len` demands that all
        // elements up to the new length already be initialized.
        unsafe {
            memsec::memzero(v.as_mut_ptr(), capacity);
        }
        p
    }
}
/// Runs `fun`, then passes its result through `zero_stack::<N, _>`, which
/// zeroes an `N`-byte array in its own stack frame, and returns it.
///
/// NOTE(review): presumably intended so that `zero_stack`'s frame overlaps
/// the stack region `fun` just used; that overlap is not guaranteed.
// May be unused depending on which callers are compiled in.
#[allow(dead_code)]
#[inline(never)]
pub(crate) fn zero_stack_after<const N: usize, T>(fun: impl FnOnce() -> T) -> T {
    let result = fun();
    zero_stack::<N, T>(result)
}
/// Zeroes an `N`-byte array placed in this function's own stack frame and
/// returns `v` unchanged.
///
/// NOTE(review): presumably meant to overwrite stack memory left behind by
/// an earlier, already-returned call (see `zero_stack_after`). Whether the
/// array overlaps that memory depends on the compiler's frame layout.
// May be unused depending on which callers are compiled in.
#[allow(dead_code)]
// Kept out-of-line so the array lives in a dedicated stack frame.
#[inline(never)]
pub(crate) fn zero_stack<const N: usize, T>(v: T) -> T {
tracer!(TRACE, "zero_stack");
// NOTE(review): the 0xff fill is overwritten just below; presumably it only
// forces the array to be materialized — confirm.
let mut a = [0xffu8; N];
// Trace the address range about to be zeroed; `offset(N)` is the
// one-past-the-end pointer, which is permitted.
t!("zeroing {:?}..{:?}", a.as_ptr(), unsafe { a.as_ptr().offset(N as _) });
// SAFETY: `a` is a live, writable local array of exactly `a.len()` bytes.
unsafe {
memsec::memzero(a.as_mut_ptr(), a.len());
}
// Hint to the optimizer that the zeroed array is observed.
std::hint::black_box(a);
v
}
/// Copies `min(from.len(), to.len())` bytes from the start of `from` to the
/// start of `to`. Any remaining bytes of `to` are left untouched.
pub(crate) fn careful_memcpy(from: &[u8], to: &mut [u8]) {
    for (dst, src) in to.iter_mut().zip(from) {
        *dst = *src;
    }
}
impl From<Box<[u8]>> for Protected {
    /// Takes ownership of `v`'s allocation without copying it. The bytes
    /// are zeroed and the allocation freed by `Protected`'s `Drop`.
    fn from(v: Box<[u8]>) -> Self {
        let raw: *mut [u8] = Box::into_raw(v);
        Protected(raw)
    }
}
impl From<&[u8]> for Protected {
    /// Copies `v` into a freshly allocated, zero-initialized protected
    /// buffer of the same length.
    fn from(v: &[u8]) -> Self {
        let mut buf = Protected::new(v.len());
        careful_memcpy(v, buf.as_mut());
        buf
    }
}
impl<const N: usize> From<[u8; N]> for Protected {
    /// Copies the array into a protected buffer, then zeroes the array
    /// argument.
    ///
    /// `[u8; N]` is `Copy`, so only this function's copy is zeroed; any
    /// copies the caller still holds are unaffected.
    fn from(mut v: [u8; N]) -> Self {
        let protected = Protected::from(&v[..]);
        // SAFETY: `v` is a live local array of exactly `N` bytes.
        unsafe {
            memsec::memzero(v.as_mut_ptr(), N);
        }
        protected
    }
}
impl Drop for Protected {
    /// Overwrites the buffer with zeros, then frees the allocation.
    fn drop(&mut self) {
        let bytes: &mut [u8] = self.as_mut();
        let n = bytes.len();
        // SAFETY: `bytes` is a valid, exclusively borrowed slice of `n` bytes.
        unsafe { memsec::memzero(bytes.as_mut_ptr(), n) };
        // SAFETY: `self.0` owns a leaked `Box<[u8]>` allocation and is
        // reclaimed exactly once, here.
        unsafe { drop(Box::from_raw(self.0)) };
    }
}
impl fmt::Debug for Protected {
    /// In debug builds, formats the buffer's contents like a `[u8]` slice;
    /// in release builds, writes only `[<Redacted>]`.
    ///
    /// The contents are formatted via `self.as_ref()`: formatting the raw
    /// pointer field `self.0` would print only a heap address.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if cfg!(debug_assertions) {
            fmt::Debug::fmt(self.as_ref(), f)
        } else {
            f.write_str("[<Redacted>]")
        }
    }
}
/// A `Protected` value kept encrypted while at rest in memory.
///
/// It is sealed with an AEAD under a key derived from the per-value random
/// `salt` and a process-wide prekey (see `has_access_to_prekey`). Use
/// `Encrypted::map` to access the plaintext temporarily.
#[derive(Clone, Debug)]
pub struct Encrypted {
// Sealed bytes: plaintext length plus the AEAD's `digest_size()` bytes.
// Holds the plaintext itself if `DANGER_DISABLE_ENCRYPTED_MEMORY` is set.
ciphertext: Protected,
// Random per-value salt mixed into the sealing key.
salt: [u8; 32],
// Plaintext length, used to size the buffer in `map`.
plaintext_len: usize,
}
// NOTE(review): project macro; presumably a compile-time check that
// `Encrypted` is `Send + Sync`.
assert_send_and_sync!(Encrypted);
impl PartialEq for Encrypted {
fn eq(&self, other: &Self) -> bool {
self.map(|a| other.map(|b| a == b))
}
}
impl Eq for Encrypted {}
impl Hash for Encrypted {
fn hash<H: Hasher>(&self, state: &mut H) {
self.map(|k| Hash::hash(k, state));
}
}
/// When `true`, `Encrypted` stores its plaintext unencrypted. As the name
/// warns, this disables the in-memory protection entirely.
const DANGER_DISABLE_ENCRYPTED_MEMORY: bool = false;
/// Number of random pages that make up the prekey.
const ENCRYPTED_MEMORY_PREKEY_PAGES: usize = 4;
/// Size in bytes of each prekey page.
const ENCRYPTED_MEMORY_PAGE_SIZE: usize = 4096;
/// Sealing and unsealing of `Encrypted` values.
///
/// The process-wide prekey is private to this module.
mod has_access_to_prekey {
    use crate::Result;
    use crate::types::{AEADAlgorithm, HashAlgorithm, SymmetricAlgorithm};
    use crate::crypto::{aead, SessionKey};
    use super::*;

    /// Returns the process-wide prekey: `ENCRYPTED_MEMORY_PREKEY_PAGES`
    /// pages of `ENCRYPTED_MEMORY_PAGE_SIZE` random bytes each.
    ///
    /// The prekey is generated on first use and cached for the rest of the
    /// process. A generation failure is cached too, and every later call
    /// returns an error carrying the same message.
    fn prekey() -> Result<&'static Box<[Box<[u8]>]>> {
        use std::sync::OnceLock;

        static PREKEY: OnceLock<Result<Box<[Box<[u8]>]>>>
            = OnceLock::new();

        PREKEY.get_or_init(|| -> Result<Box<[Box<[u8]>]>> {
            let mut pages = Vec::new();
            for _ in 0..ENCRYPTED_MEMORY_PREKEY_PAGES {
                let mut page = vec![0; ENCRYPTED_MEMORY_PAGE_SIZE];
                crate::crypto::random(&mut page)?;
                pages.push(page.into());
            }
            Ok(pages.into())
        }).as_ref().map_err(|e| anyhow::anyhow!("{}", e))
    }

    /// Hash used to derive sealing keys from the salt and the prekey.
    const HASH_ALGO: HashAlgorithm = HashAlgorithm::SHA256;
    /// Cipher used with the AEAD to seal `Encrypted` values.
    const SYMMETRIC_ALGO: SymmetricAlgorithm = SymmetricAlgorithm::AES256;

    impl Encrypted {
        /// Derives the sealing key for one value as the `HASH_ALGO` digest
        /// of `salt` followed by every prekey page.
        ///
        /// # Errors
        ///
        /// Fails if the prekey could not be generated or if finalizing the
        /// digest fails.
        fn sealing_key(salt: &[u8; 32]) -> Result<SessionKey> {
            let mut ctx = HASH_ALGO.context()
                .expect("Mandatory algorithm unsupported")
                .for_digest();
            ctx.update(salt);
            prekey()?
                .iter().for_each(|page| ctx.update(page));
            let mut sk: SessionKey = Protected::new(256/8).into();
            // Propagate digest failures. Ignoring them would silently hand
            // back `sk` still holding its initial zero bytes.
            ctx.digest(&mut sk)?;
            Ok(sk)
        }

        /// Returns the AEAD nonce: the first `nonce_size()` bytes of an
        /// all-zero buffer.
        ///
        /// A constant nonce relies on each `Encrypted::new` call sealing
        /// under a key derived from a fresh random salt.
        fn nonce(aead_algo: AEADAlgorithm) -> &'static [u8] {
            const NONCE_STORE: [u8; aead::MAX_NONCE_LEN] =
                [0u8; aead::MAX_NONCE_LEN];
            let nonce_len = aead_algo.nonce_size()
                .expect("Mandatory algorithm unsupported");
            debug_assert!(nonce_len >= 8 && nonce_len <= aead::MAX_NONCE_LEN);
            &NONCE_STORE[..nonce_len]
        }

        /// Encrypts `p` into a new `Encrypted` value.
        ///
        /// Each call draws a fresh 32-byte random salt, so every value is
        /// sealed under its own derived key.
        ///
        /// # Errors
        ///
        /// Fails if random number generation, key derivation, or sealing
        /// fails.
        pub fn new(p: Protected) -> Result<Self> {
            if DANGER_DISABLE_ENCRYPTED_MEMORY {
                return Ok(Encrypted {
                    plaintext_len: p.len(),
                    ciphertext: p,
                    salt: Default::default(),
                });
            }

            let aead_algo = AEADAlgorithm::default();
            let mut salt = [0; 32];
            crate::crypto::random(&mut salt)?;
            // Room for the plaintext plus the AEAD's `digest_size()` bytes.
            let mut ciphertext = Protected::new(
                p.len() + aead_algo.digest_size().expect("supported"));
            aead_algo.context(SYMMETRIC_ALGO,
                              &Self::sealing_key(&salt)?,
                              &[],
                              Self::nonce(aead_algo))?
                .for_encryption()?
                .encrypt_seal(&mut ciphertext, &p)?;

            Ok(Encrypted {
                plaintext_len: p.len(),
                ciphertext,
                salt,
            })
        }

        /// Decrypts into a temporary `Protected` buffer, calls `fun` on it
        /// and returns `fun`'s result. The temporary plaintext is dropped,
        /// and therefore zeroed, before this returns.
        ///
        /// # Panics
        ///
        /// Panics if key derivation or AEAD setup fails, or if the
        /// ciphertext does not decrypt and verify ("Encrypted memory
        /// modified or corrupted").
        pub fn map<F, T>(&self, mut fun: F) -> T
            where F: FnMut(&Protected) -> T
        {
            if DANGER_DISABLE_ENCRYPTED_MEMORY {
                return fun(&self.ciphertext);
            }

            let aead_algo = AEADAlgorithm::default();
            let mut plaintext = Protected::new(self.plaintext_len);
            let r = aead_algo.context(SYMMETRIC_ALGO,
                                      &Self::sealing_key(&self.salt).unwrap(),
                                      &[],
                                      Self::nonce(aead_algo)).unwrap()
                .for_decryption().unwrap()
                .decrypt_verify(&mut plaintext, &self.ciphertext);

            if r.is_err() {
                // Drop explicitly so the (possibly partially written) buffer
                // is zeroed even if panics abort rather than unwind.
                drop(plaintext);
                panic!("Encrypted memory modified or corrupted");
            }

            fun(&plaintext)
        }
    }
}
/// Compares two byte slices, first by length and then by content.
///
/// The content comparison (`memsec::memcmp` over the common prefix) is
/// always performed, even when the lengths differ. The resulting order is
/// not lexicographic: a shorter slice always sorts first.
pub fn secure_cmp(a: &[u8], b: &[u8]) -> Ordering {
    let by_len = a.len().cmp(&b.len());
    let common = min(a.len(), b.len());
    // SAFETY: both pointers are valid for reads of `common` bytes.
    let raw = unsafe { memsec::memcmp(a.as_ptr(), b.as_ptr(), common) };
    let by_content = raw.cmp(&0);
    by_len.then(by_content)
}