use std::fmt;
use lexe_std::array;
use ref_cast::RefCast;
use ring::{
aead::{self, BoundKey},
hkdf,
};
use serde_core::ser::{Serialize, SerializeStruct, Serializer};
use crate::rng::{Crng, RngExt};
// Ciphertext layout produced by `AesMasterKey::encrypt`:
//
//     version (VERSION_LEN) || key_id (KEY_ID_LEN) || ciphertext || tag (TAG_LEN)

/// Length of the leading format-version byte.
const VERSION_LEN: usize = 1;
/// Length of the random per-message key id that follows the version byte.
const KEY_ID_LEN: usize = 32;
/// Length of the AES-GCM authentication tag appended after the ciphertext.
const TAG_LEN: usize = 16;

/// Returns the total size of an encrypted buffer whose plaintext is
/// `plaintext_len` bytes long.
///
/// The ciphertext is sealed in place, so it is exactly as long as the
/// plaintext. Only the fixed header (version + key id) and the trailing tag
/// are added on top.
pub const fn encrypted_len(plaintext_len: usize) -> usize {
    // Fixed per-message overhead, independent of the plaintext.
    const OVERHEAD: usize = VERSION_LEN + KEY_ID_LEN + TAG_LEN;
    OVERHEAD + plaintext_len
}
/// Master key for this module's authenticated encryption.
///
/// Wraps an HKDF-SHA256 pseudorandom key produced by `AesMasterKey::new`.
/// For every message, a separate AES-256-GCM key is HKDF-expanded from it,
/// keyed by a random per-message `KeyId`.
///
/// `Debug` is implemented by hand and prints only `AesMasterKey(..)`, so the
/// key material is never formatted.
pub struct AesMasterKey(hkdf::Prk);
/// 32-byte random identifier chosen for each encrypted message.
///
/// It serves three roles:
/// - it is written in cleartext in the ciphertext header;
/// - it is used as the HKDF `info` when deriving that message's AES key;
/// - it is bound into the AAD.
///
/// Its size must stay equal to `KEY_ID_LEN`, which `decrypt` uses to split
/// the header.
// `repr(transparent)` is what allows a borrowed `&[u8; 32]` to be
// reinterpreted as `&KeyId` (see `KeyId::from_ref`).
#[derive(RefCast)]
#[repr(transparent)]
struct KeyId([u8; 32]);
/// Associated data authenticated (but not encrypted) alongside each message.
///
/// It is BCS-serialized by `Aad::serialize` in field order: `version`, then
/// `key_id`, then the caller-supplied `aad` parts.
struct Aad<'data, 'aad> {
    /// Ciphertext format version. `encrypt` always writes 0, and `decrypt`
    /// rejects anything else.
    version: u8,
    /// Per-message key id; the same bytes also appear in the cleartext
    /// header.
    key_id: &'data KeyId,
    /// Caller-supplied context. The list and each part are length-prefixed
    /// when BCS-encoded.
    aad: &'aad [&'aad [u8]],
}
/// Single-use AES-256-GCM sealing key; consumed by `encrypt_in_place`.
struct EncryptKey(aead::SealingKey<ZeroNonce>);
/// Single-use AES-256-GCM opening key; consumed by `decrypt_in_place`.
struct DecryptKey(aead::OpeningKey<ZeroNonce>);
/// Nonce sequence that yields a single all-zero nonce and then panics.
///
/// An all-zero nonce is acceptable here only because each sealing key is
/// derived from a freshly random key id and is consumed after one use.
struct ZeroNonce(Option<aead::Nonce>);
/// Error returned when `AesMasterKey::decrypt` rejects its input.
///
/// The same unit value is returned for every failure cause:
/// - the input is too short;
/// - the version byte is unknown;
/// - AES-GCM authentication fails.
///
/// Callers therefore cannot (and should not) tell these cases apart.
#[derive(Debug, Clone)]
pub struct DecryptError;

impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "decrypt error: ciphertext or metadata may be corrupted")
    }
}

impl std::error::Error for DecryptError {}
// Redacted on purpose: the wrapped PRK is secret and must never end up in
// logs or panic messages.
impl fmt::Debug for AesMasterKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "AesMasterKey(..)")
    }
}
impl AesMasterKey {
const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::AesMasterKey");
pub fn new(root_seed_derived_secret: &[u8; 32]) -> Self {
Self(
hkdf::Salt::new(hkdf::HKDF_SHA256, &Self::HKDF_SALT)
.extract(root_seed_derived_secret),
)
}
fn derive_unbound_key(&self, key_id: &KeyId) -> aead::UnboundKey {
aead::UnboundKey::from(
self.0
.expand(&[key_id.as_slice()], &aead::AES_256_GCM)
.expect("This should never fail"),
)
}
fn derive_encrypt_key(&self, key_id: &KeyId) -> EncryptKey {
let nonce = ZeroNonce::new();
let key = aead::SealingKey::new(self.derive_unbound_key(key_id), nonce);
EncryptKey(key)
}
fn derive_decrypt_key(&self, key_id: &KeyId) -> DecryptKey {
let nonce = ZeroNonce::new();
let key = aead::OpeningKey::new(self.derive_unbound_key(key_id), nonce);
DecryptKey(key)
}
pub fn encrypt<R: Crng>(
&self,
rng: &mut R,
aad: &[&[u8]],
data_size_hint: Option<usize>,
write_data_cb: &dyn Fn(&mut Vec<u8>),
) -> Vec<u8> {
let version = 0;
let key_id = KeyId::from_rng(rng);
let aad = Aad {
version,
key_id: &key_id,
aad,
}
.serialize();
let approx_encrypted_len = encrypted_len(data_size_hint.unwrap_or(0));
let mut data = Vec::with_capacity(approx_encrypted_len);
data.push(version);
data.extend_from_slice(key_id.as_slice());
let plaintext_offset = data.len();
write_data_cb(&mut data);
self.derive_encrypt_key(&key_id).encrypt_in_place(
aad.as_slice(),
&mut data,
plaintext_offset,
);
data
}
pub fn decrypt(
&self,
aad: &[&[u8]],
mut data: Vec<u8>,
) -> Result<Vec<u8>, DecryptError> {
const MIN_DATA_LEN: usize = encrypted_len(0 );
if data.len() < MIN_DATA_LEN {
return Err(DecryptError);
}
let (version, key_id) = {
let (version, data) = data
.split_first_chunk::<VERSION_LEN>()
.expect("data.len() checked above");
let (key_id, _) = data
.split_first_chunk::<KEY_ID_LEN>()
.expect("data.len() checked above");
(version[0], key_id)
};
if version != 0 {
return Err(DecryptError);
}
let key_id = KeyId::from_ref(key_id);
let decrypt_key = self.derive_decrypt_key(key_id);
let aad = Aad {
version,
key_id,
aad,
}
.serialize();
let ciphertext_and_tag_offset = VERSION_LEN + KEY_ID_LEN;
decrypt_key.decrypt_in_place(
&aad,
&mut data,
ciphertext_and_tag_offset,
)?;
Ok(data)
}
}
impl EncryptKey {
    /// Encrypts `data[plaintext_offset..]` in place and appends the GCM tag.
    ///
    /// Bytes before `plaintext_offset` are left untouched. The key is
    /// consumed, so it can never seal a second message.
    ///
    /// # Panics
    ///
    /// Panics if `plaintext_offset > data.len()`, or if `ring` refuses to
    /// seal an input of this length.
    fn encrypt_in_place(
        mut self,
        aad: &[u8],
        data: &mut Vec<u8>,
        plaintext_offset: usize,
    ) {
        assert!(plaintext_offset <= data.len());
        let plaintext = &mut data[plaintext_offset..];
        let tag = self
            .0
            .seal_in_place_separate_tag(aead::Aad::from(aad), plaintext)
            .expect(
                "Cannot encrypt more than ~4 GiB at once (should never happen)",
            );
        data.extend_from_slice(tag.as_ref());
    }
}
impl DecryptKey {
    /// Authenticates and decrypts `data[ciphertext_and_tag_offset..]` in
    /// place.
    ///
    /// On success, `data` is truncated so that it holds only the plaintext.
    /// `ring`'s `open_within` writes the plaintext to the front of the
    /// buffer. The key is consumed.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError`] if authentication fails.
    fn decrypt_in_place(
        mut self,
        aad: &[u8],
        data: &mut Vec<u8>,
        ciphertext_and_tag_offset: usize,
    ) -> Result<(), DecryptError> {
        let plaintext_len = match self.0.open_within(
            aead::Aad::from(aad),
            data,
            ciphertext_and_tag_offset..,
        ) {
            Ok(plaintext) => plaintext.len(),
            Err(_) => return Err(DecryptError),
        };
        data.truncate(plaintext_len);
        Ok(())
    }
}
impl KeyId {
    /// Draws a fresh random key id from a cryptographic RNG.
    fn from_rng<R: Crng>(rng: &mut R) -> Self {
        let bytes: [u8; 32] = rng.gen_bytes();
        Self(bytes)
    }

    /// Reinterprets a borrowed 32-byte array as a `&KeyId` without copying.
    ///
    /// This is a `const fn`, so it can be used in const contexts.
    #[inline]
    const fn from_ref(arr: &[u8; 32]) -> &Self {
        lexe_std::const_utils::const_ref_cast(arr)
    }

    /// Returns the raw key id bytes.
    #[inline]
    fn as_slice(&self) -> &[u8] {
        &self.0
    }
}
impl Serialize for KeyId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.0.serialize(serializer)
}
}
impl Aad<'_, '_> {
    /// BCS-encodes `self` into an exactly-sized buffer, for use as AEAD
    /// associated data.
    fn serialize(&self) -> Vec<u8> {
        const INFALLIBLE: &str = "Serializing the AAD should never fail";
        let exact_len = bcs::serialized_size(self).expect(INFALLIBLE);
        let mut buf = Vec::with_capacity(exact_len);
        bcs::serialize_into(&mut buf, self).expect(INFALLIBLE);
        buf
    }
}
impl Serialize for Aad<'_, '_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut fields = serializer.serialize_struct("Aad", 3)?;
fields.serialize_field("version", &self.version)?;
fields.serialize_field("key_id", self.key_id)?;
fields.serialize_field("aad", self.aad)?;
fields.end()
}
}
impl ZeroNonce {
    /// Creates a sequence holding exactly one all-zero 96-bit nonce.
    fn new() -> Self {
        let nonce = aead::Nonce::assume_unique_for_key([0u8; 12]);
        Self(Some(nonce))
    }
}
impl aead::NonceSequence for ZeroNonce {
    /// Hands out the single stored nonce.
    ///
    /// A second call panics instead of ever reusing a nonce.
    fn advance(&mut self) -> Result<aead::Nonce, ring::error::Unspecified> {
        let nonce = self.0.take().expect(
            "We somehow encrypted / decrypted more than once with the same key",
        );
        Ok(nonce)
    }
}
/// Test helper that deterministically derives an [`AesMasterKey`] from `rng`.
///
/// The steps are:
/// 1. draw a 32-byte seed from `rng`;
/// 2. HKDF-SHA256 extract it (salt `LEXE-REALM::RootSeed`, padded);
/// 3. expand with info `vfs master key` into 32 bytes;
/// 4. feed those bytes to [`AesMasterKey::new`].
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn derive_key(rng: &mut crate::rng::FastRng) -> AesMasterKey {
    /// HKDF output-length marker requesting exactly 32 bytes.
    struct ThirtyTwoBytes;

    impl hkdf::KeyType for ThirtyTwoBytes {
        fn len(&self) -> usize {
            32
        }
    }

    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::RootSeed");

    let seed: [u8; 32] = rng.gen_bytes();
    let salt = hkdf::Salt::new(hkdf::HKDF_SHA256, &HKDF_SALT);
    let prk = salt.extract(&seed);
    let info: [&[u8]; 1] = [b"vfs master key"];
    let okm = prk.expand(&info, ThirtyTwoBytes).unwrap();

    let mut key_seed = [0u8; 32];
    okm.fill(&mut key_seed).unwrap();
    AesMasterKey::new(&key_seed)
}
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::rng::FastRng;

    /// Arbitrary master keys, each derived via `derive_key` from an
    /// arbitrary `FastRng`.
    impl Arbitrary for AesMasterKey {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
            let rngs = any::<FastRng>();
            rngs.prop_map(|mut rng| derive_key(&mut rng)).boxed()
        }
    }
}
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, collection::vec, prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::rng::FastRng;

    /// Pins the BCS encoding of [`Aad`] so the authenticated bytes cannot
    /// drift silently.
    #[test]
    fn test_aad_compat() {
        // No caller AAD: version || key_id || empty-list length (0).
        let empty = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x69; 32]),
            aad: &[],
        }
        .serialize();
        let empty_expected = hex::decode(
            "00\
             6969696969696969696969696969696969696969696969696969696969696969\
             00",
        )
        .unwrap();
        assert_eq!(&empty, &empty_expected);

        // Two AAD parts: list length, then each part length-prefixed.
        let two_parts = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x42; 32]),
            aad: &[b"aaaaaaaa".as_slice(), b"0123456789".as_slice()],
        }
        .serialize();
        let two_parts_expected = hex::decode(
            "00\
             4242424242424242424242424242424242424242424242424242424242424242\
             02\
             08\
             6161616161616161\
             0a\
             30313233343536373839",
        )
        .unwrap();
        assert_eq!(&two_parts, &two_parts_expected);
    }

    /// Known-answer vectors: previously produced ciphertexts must keep
    /// decrypting to the same plaintext.
    #[test]
    fn test_decrypt_compat() {
        let mut rng = FastRng::from_u64(123);
        let master_key = derive_key(&mut rng);

        // Empty plaintext, no AAD: header immediately followed by the tag.
        let ciphertext = hex::decode(
            "00\
             b0abd2beab31c1d925c5d8059cf90068eece2c41a3a6e4454d84e36ad6858a01\
             \
             0e2d1f6d16e9bb5738de28b4f180f07f",
        )
        .unwrap();
        let recovered = master_key.decrypt(&[], ciphertext).unwrap();
        assert_eq!(recovered.as_slice(), b"");

        // Non-empty plaintext bound to a single AAD part.
        let context = b"my context".as_slice();
        let message = b"my cool message".as_slice();
        let ciphertext = hex::decode(
            "00\
             c87fea5c4db8c16d3dae5a6ead5ee5985fa7c38721b9624e37772adea6a48aae\
             22f52c6f08440092338d16e3402eaf\
             c3972d357e56dad4cc42c6a80da4ac35",
        )
        .unwrap();
        let recovered = master_key.decrypt(&[context], ciphertext).unwrap();
        assert_eq!(recovered.as_slice(), message);
    }

    /// `decrypt(encrypt(x)) == x`, and two encryptions of the same input
    /// differ because each call draws a fresh key id.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=256),
        )| {
            let master_key = derive_key(&mut rng);
            let aad_parts: Vec<&[u8]> =
                aad.iter().map(Vec::as_slice).collect();
            let write_plaintext =
                |out: &mut Vec<u8>| out.extend_from_slice(&plaintext);

            let ciphertext = master_key.encrypt(
                &mut rng,
                &aad_parts,
                Some(plaintext.len()),
                &write_plaintext,
            );
            let recovered =
                master_key.decrypt(&aad_parts, ciphertext.clone()).unwrap();
            prop_assert_eq!(&plaintext, &recovered);

            let ciphertext2 =
                master_key.encrypt(&mut rng, &aad_parts, None, &write_plaintext);
            prop_assert!(ciphertext != ciphertext2);
        });
    }
}