use std::{fmt, num::NonZeroU32};
use aws_lc_rs::{aead, aead::BoundKey, digest, error::Unspecified, pbkdf2};
use rand::CryptoRng;
use crate::internal::{IronOxideErr, take_lock};
use std::{convert::TryFrom, ops::DerefMut, sync::Mutex};
/// Number of PBKDF2-HMAC-SHA256 rounds used to stretch a user's password.
const PBKDF2_ITERATIONS: NonZeroU32 = NonZeroU32::new(250_000).unwrap();
/// Length in bytes of the random salt fed to PBKDF2.
const PBKDF2_SALT_LEN: usize = 32;
/// Length in bytes of the AES-GCM authentication tag appended to every ciphertext.
pub(crate) const AES_GCM_TAG_LEN: usize = 16;
/// Length in bytes of the AES-GCM IV (nonce).
pub(crate) const AES_IV_LEN: usize = 12;
/// Length in bytes of an AES-256 key.
pub(crate) const AES_KEY_LEN: usize = 32;
/// AES block size in bytes.
pub(crate) const AES_BLOCK_SIZE: usize = 16;
/// Size of an encrypted master key: the 32-byte key followed by its GCM tag.
const ENCRYPTED_KEY_AND_GCM_TAG_LEN: usize = AES_KEY_LEN + AES_GCM_TAG_LEN;
/// A user's master key encrypted with AES-256-GCM under a key derived from a password.
///
/// Carries everything needed to re-derive that key and decrypt: the PBKDF2 salt, the
/// AES-GCM IV, and the ciphertext with its appended GCM tag.
pub struct EncryptedMasterKey {
    /// Salt used when deriving the wrapping key from the password.
    pbkdf2_salt: [u8; PBKDF2_SALT_LEN],
    /// IV used for the AES-GCM encryption of the master key.
    aes_iv: [u8; AES_IV_LEN],
    /// Encrypted master-key bytes followed by the GCM tag.
    encrypted_key: [u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN],
}

impl fmt::Debug for EncryptedMasterKey {
    /// Renders the salt, IV and encrypted key as a struct named `EncryptedMasterKey`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct(stringify!(EncryptedMasterKey));
        out.field("pbkdf2_salt", &self.pbkdf2_salt);
        out.field("aes_iv", &self.aes_iv);
        out.field("encrypted_key", &self.encrypted_key.as_slice());
        out.finish()
    }
}
impl EncryptedMasterKey {
    /// Length of the serialized form: salt, then IV, then encrypted key with GCM tag.
    pub const SIZE_BYTES: usize = PBKDF2_SALT_LEN + AES_IV_LEN + ENCRYPTED_KEY_AND_GCM_TAG_LEN;

    /// Builds an `EncryptedMasterKey` from its already-separated parts.
    pub fn new(
        pbkdf2_salt: [u8; PBKDF2_SALT_LEN],
        aes_iv: [u8; AES_IV_LEN],
        encrypted_key: [u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN],
    ) -> EncryptedMasterKey {
        EncryptedMasterKey {
            pbkdf2_salt,
            aes_iv,
            encrypted_key,
        }
    }

    /// Parses the layout produced by [`EncryptedMasterKey::bytes`].
    ///
    /// # Errors
    /// Returns `IronOxideErr::WrongSizeError(Some(actual_len), Some(SIZE_BYTES))` when `bytes`
    /// is not exactly `SIZE_BYTES` long.
    pub fn new_from_slice(bytes: &[u8]) -> Result<EncryptedMasterKey, IronOxideErr> {
        if bytes.len() != EncryptedMasterKey::SIZE_BYTES {
            return Err(IronOxideErr::WrongSizeError(
                Some(bytes.len()),
                Some(EncryptedMasterKey::SIZE_BYTES),
            ));
        }
        // Length was checked above, so each split yields exactly the field width.
        let (salt_bytes, rest) = bytes.split_at(PBKDF2_SALT_LEN);
        let (iv_bytes, key_bytes) = rest.split_at(AES_IV_LEN);
        let mut pbkdf2_salt = [0u8; PBKDF2_SALT_LEN];
        let mut aes_iv = [0u8; AES_IV_LEN];
        let mut encrypted_key = [0u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN];
        pbkdf2_salt.copy_from_slice(salt_bytes);
        aes_iv.copy_from_slice(iv_bytes);
        encrypted_key.copy_from_slice(key_bytes);
        Ok(EncryptedMasterKey::new(pbkdf2_salt, aes_iv, encrypted_key))
    }

    /// Serializes as `pbkdf2_salt || aes_iv || encrypted_key` in a fixed-size array.
    pub fn bytes(&self) -> [u8; EncryptedMasterKey::SIZE_BYTES] {
        let mut dest = [0u8; EncryptedMasterKey::SIZE_BYTES];
        // Copy each field straight into its fixed position. This avoids building an
        // intermediate heap Vec and the runtime length check that copying it back required.
        let (salt_out, rest) = dest.split_at_mut(PBKDF2_SALT_LEN);
        let (iv_out, key_out) = rest.split_at_mut(AES_IV_LEN);
        salt_out.copy_from_slice(&self.pbkdf2_salt);
        iv_out.copy_from_slice(&self.aes_iv);
        key_out.copy_from_slice(&self.encrypted_key);
        dest
    }
}
/// An AES-GCM ciphertext paired with the IV it was produced under.
#[derive(Clone, Debug)]
pub struct AesEncryptedValue {
    /// IV used for the AES-GCM operation.
    aes_iv: [u8; AES_IV_LEN],
    /// Encrypted bytes followed by the GCM authentication tag.
    ciphertext: Vec<u8>,
}

impl AesEncryptedValue {
    /// Serializes as the IV immediately followed by the ciphertext (including its tag).
    pub fn bytes(&self) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(self.aes_iv.len() + self.ciphertext.len());
        serialized.extend_from_slice(&self.aes_iv);
        serialized.extend_from_slice(&self.ciphertext);
        serialized
    }
}
impl TryFrom<&[u8]> for AesEncryptedValue {
    type Error = IronOxideErr;

    /// Splits `bytes` into a leading IV and the remaining ciphertext.
    ///
    /// # Errors
    /// Returns `IronOxideErr::AesEncryptedDocSizeError` when `bytes` is shorter than an IV plus
    /// a GCM tag.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let minimum_len = AES_IV_LEN + AES_GCM_TAG_LEN;
        if bytes.len() >= minimum_len {
            let (iv_part, ciphertext_part) = bytes.split_at(AES_IV_LEN);
            let mut aes_iv = [0u8; AES_IV_LEN];
            aes_iv.copy_from_slice(iv_part);
            Ok(AesEncryptedValue {
                aes_iv,
                ciphertext: ciphertext_part.to_vec(),
            })
        } else {
            Err(IronOxideErr::AesEncryptedDocSizeError)
        }
    }
}
impl From<aws_lc_rs::error::Unspecified> for IronOxideErr {
    /// Wraps an opaque `aws_lc_rs` failure as `IronOxideErr::AesError`.
    fn from(err: Unspecified) -> Self {
        Self::AesError(err)
    }
}
/// Stretches `password` into an AES-256 key using PBKDF2-HMAC-SHA256.
///
/// Runs `PBKDF2_ITERATIONS` rounds over the given `salt`. Decrypting a master key relies on
/// the same password and salt producing the same key again.
fn derive_key_from_password(password: &str, salt: [u8; PBKDF2_SALT_LEN]) -> [u8; AES_KEY_LEN] {
    // Output buffer is sized to the SHA-256 digest; the return type forces this to equal AES_KEY_LEN.
    let mut key_out = [0u8; digest::SHA256_OUTPUT_LEN];
    let secret = password.as_bytes();
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        PBKDF2_ITERATIONS,
        &salt,
        secret,
        &mut key_out,
    );
    key_out
}
pub fn encrypt_user_master_key<R: CryptoRng>(
rng: &Mutex<R>,
password: &str,
user_master_key: &[u8; 32],
) -> Result<EncryptedMasterKey, Unspecified> {
let mut salt = [0u8; PBKDF2_SALT_LEN];
take_lock(rng).deref_mut().fill_bytes(&mut salt);
let derived_key = derive_key_from_password(password, salt);
let encrypted_key = encrypt(rng, user_master_key.to_vec(), derived_key)?;
let mut master_key_ciphertext = [0u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN];
master_key_ciphertext[..].copy_from_slice(&encrypted_key.ciphertext[..]);
Ok(EncryptedMasterKey {
pbkdf2_salt: salt,
aes_iv: encrypted_key.aes_iv,
encrypted_key: master_key_ciphertext,
})
}
/// Recovers the 32-byte user master key from `encrypted_master_key` using `password`.
///
/// Re-derives the wrapping key from the password and the stored salt, then opens the stored
/// ciphertext with AES-256-GCM.
///
/// # Errors
/// Returns `Unspecified` if AES-GCM authentication fails (for example, a wrong password).
pub fn decrypt_user_master_key(
    password: &str,
    encrypted_master_key: &EncryptedMasterKey,
) -> Result<[u8; 32], Unspecified> {
    let wrapping_key = derive_key_from_password(password, encrypted_master_key.pbkdf2_salt);
    let mut sealed = AesEncryptedValue {
        aes_iv: encrypted_master_key.aes_iv,
        ciphertext: encrypted_master_key.encrypted_key.to_vec(),
    };
    let opened = decrypt(&mut sealed, wrapping_key)?;
    let mut master_key = [0u8; 32];
    master_key.copy_from_slice(opened);
    Ok(master_key)
}
/// An `aead::NonceSequence` that yields exactly one nonce.
///
/// The first `advance` returns the IV supplied at construction. Every later call fails with
/// `Unspecified`, so a key bound to this sequence can perform at most one operation.
struct SingleUseNonceGenerator {
    /// The IV not yet handed out; `None` once it has been used.
    iv: Option<[u8; aead::NONCE_LEN]>,
}

impl SingleUseNonceGenerator {
    /// Creates a generator that will hand out `iv` a single time.
    fn new(iv: [u8; aead::NONCE_LEN]) -> SingleUseNonceGenerator {
        let pending = Some(iv);
        SingleUseNonceGenerator { iv: pending }
    }
}

impl aead::NonceSequence for SingleUseNonceGenerator {
    fn advance(&mut self) -> Result<aead::Nonce, Unspecified> {
        match self.iv.take() {
            Some(iv) => Ok(aead::Nonce::assume_unique_for_key(iv)),
            None => Err(Unspecified),
        }
    }
}
/// Encrypts `plaintext` with AES-256-GCM under `key`, using a fresh random IV from `rng`.
///
/// The plaintext buffer is encrypted in place and the GCM tag is appended to it. No additional
/// authenticated data is used.
///
/// # Errors
/// Returns `Unspecified` if the key cannot be constructed or sealing fails.
pub fn encrypt<R: CryptoRng>(
    rng: &Mutex<R>,
    mut plaintext: Vec<u8>,
    key: [u8; AES_KEY_LEN],
) -> Result<AesEncryptedValue, Unspecified> {
    let mut aes_iv = [0u8; aead::NONCE_LEN];
    take_lock(rng).deref_mut().fill_bytes(&mut aes_iv);
    let unbound = aead::UnboundKey::new(&aead::AES_256_GCM, &key[..])?;
    let mut sealing_key = aead::SealingKey::new(unbound, SingleUseNonceGenerator::new(aes_iv));
    sealing_key.seal_in_place_append_tag(aead::Aad::empty(), &mut plaintext)?;
    Ok(AesEncryptedValue {
        aes_iv,
        ciphertext: plaintext,
    })
}
/// Async wrapper around [`encrypt`] that converts its error into [`IronOxideErr`].
///
/// Encryption runs synchronously the first time the returned future is polled. Nothing here
/// yields to the executor.
///
/// # Errors
/// Returns `IronOxideErr::AesError` (via the `From<Unspecified>` impl) if encryption fails.
pub async fn encrypt_async<R: CryptoRng>(
    rng: &Mutex<R>,
    plaintext: Vec<u8>,
    key: [u8; AES_KEY_LEN],
) -> Result<AesEncryptedValue, IronOxideErr> {
    // An `async fn` body is already deferred until polled, so the previous extra
    // `async { .. }.await` wrapper added nothing.
    encrypt(rng, plaintext, key).map_err(IronOxideErr::from)
}
/// Decrypts and authenticates `encrypted_doc` in place with AES-256-GCM under `key`.
///
/// On success, returns the plaintext as a sub-slice of `encrypted_doc`'s ciphertext buffer,
/// excluding the GCM tag.
///
/// # Errors
/// Returns `Unspecified` if the key cannot be constructed or authentication fails.
pub fn decrypt(
    encrypted_doc: &mut AesEncryptedValue,
    key: [u8; AES_KEY_LEN],
) -> Result<&mut [u8], Unspecified> {
    let nonce_source = SingleUseNonceGenerator::new(encrypted_doc.aes_iv);
    let unbound = aead::UnboundKey::new(&aead::AES_256_GCM, &key)?;
    let mut opening_key = aead::OpeningKey::new(unbound, nonce_source);
    opening_key.open_in_place(aead::Aad::empty(), &mut encrypted_doc.ciphertext)
}
// Unit and property tests for the password-based master-key wrapping and the AES-GCM
// encrypt/decrypt helpers defined in the parent module.
#[cfg(test)]
mod tests {
    use super::*;
    // Shared test helpers. `test_rng()` is passed where `&Mutex<R>` is expected, so it
    // presumably returns a Mutex-wrapped RNG — confirm in crypto::streaming::tests.
    use crate::crypto::streaming::tests::{generate_test_key, test_rng};
    use proptest::prelude::*;
    use rand::Rng;
    use std::{convert::TryInto, sync::Arc};

    // Encrypting a master key yields fields of the expected fixed sizes
    // (32-byte salt, 12-byte IV, 32-byte key + 16-byte tag).
    #[test]
    fn test_encrypt_user_master_key() {
        let user_master_key = [0u8; 32];
        let password = "MyPassword";
        let rng = rand::rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), password, &user_master_key).unwrap();
        assert_eq!(encrypted_master_key.pbkdf2_salt.len(), 32);
        assert_eq!(encrypted_master_key.aes_iv.len(), 12);
        assert_eq!(encrypted_master_key.encrypted_key.len(), 48);
    }

    // Encrypt then decrypt with the same password recovers the original master key.
    #[test]
    fn test_decrypt_user_master_key() {
        let user_master_key = [0u8; 32];
        let password = "MyPassword";
        let rng = rand::rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), password, &user_master_key).unwrap();
        let decrypted_master_key =
            decrypt_user_master_key(password, &encrypted_master_key).unwrap();
        assert_eq!(decrypted_master_key, user_master_key);
    }

    // Ciphertext length is plaintext length plus the GCM tag; the IV is 12 bytes.
    #[test]
    fn test_encrypt() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let key = generate_test_key();
        let rng = test_rng();
        let res = encrypt(&rng, plaintext.clone(), key).unwrap();
        assert_eq!(res.aes_iv.len(), 12);
        assert_eq!(
            res.ciphertext.len(),
            plaintext.len() + aead::AES_256_GCM.tag_len()
        );
    }

    // Decrypting under the same key returns the original plaintext (tag stripped).
    #[test]
    fn test_decrypt() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let key = generate_test_key();
        let rng = test_rng();
        let mut encrypted_result = encrypt(&rng, plaintext.clone(), key).unwrap();
        let decrypted_plaintext = decrypt(&mut encrypted_result, key).unwrap();
        assert_eq!(*decrypted_plaintext, plaintext[..]);
    }

    // Byte round-trip of AesEncryptedValue for a 1-byte payload and for the minimum
    // accepted length (IV + tag, i.e. an empty plaintext).
    #[test]
    fn test_roundtrip_aesencryptedvalue_zero_one_bytes() -> Result<(), IronOxideErr> {
        let encrypted_bytes = [0u8; 1 + AES_IV_LEN + AES_GCM_TAG_LEN];
        let round_tripped_aes_encrypted_value: AesEncryptedValue =
            encrypted_bytes.as_ref().try_into()?;
        assert_eq!(round_tripped_aes_encrypted_value.bytes(), encrypted_bytes);
        let encrypted_bytes2 = [0u8; AES_IV_LEN + AES_GCM_TAG_LEN];
        let round_tripped_aes_encrypted_value: AesEncryptedValue =
            encrypted_bytes2.as_ref().try_into()?;
        assert_eq!(round_tripped_aes_encrypted_value.bytes(), encrypted_bytes2);
        Ok(())
    }

    // 100 threads encrypt concurrently while sharing one Mutex-guarded RNG through an Arc;
    // every thread must complete without panicking.
    #[test]
    fn test_parallel_encrypt() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let mut key = [0u8; 32];
        let rng = test_rng();
        take_lock(&rng).deref_mut().fill_bytes(&mut key);
        let a_rng = Arc::new(rng);
        let mut threads = vec![];
        for _i in 0..100 {
            let rng_ref = a_rng.clone();
            let pt = plaintext.clone();
            threads.push(std::thread::spawn(move || {
                let _res = encrypt(&rng_ref, pt, key).unwrap();
            }));
        }
        let mut joined_count = 0;
        for t in threads {
            t.join().expect("join failed");
            joined_count += 1;
        }
        assert_eq!(joined_count, 100);
    }

    // Proptest strategy for a 48-byte array (32-byte key + 16-byte tag); proptest ships
    // `uniform12`/`uniform32` helpers but not this size.
    fn uniform48<S: Strategy>(
        strategy: S,
    ) -> proptest::array::UniformArrayStrategy<S, [S::Value; 48]> {
        proptest::array::UniformArrayStrategy::new(strategy)
    }

    proptest! {
        // encrypt/decrypt round-trips arbitrary plaintexts, including the empty one.
        #[test]
        fn prop_encrypt_decrypt_roundtrip(plaintext in prop::collection::vec(any::<u8>(), 0..10000)) {
            let key = generate_test_key();
            let rng = test_rng();
            let mut encrypted = encrypt(&rng, plaintext.clone(), key).unwrap();
            let decrypted = decrypt(&mut encrypted, key).unwrap();
            prop_assert_eq!(&plaintext[..], decrypted);
        }

        // AesEncryptedValue::bytes and TryFrom<&[u8]> are inverses for any ciphertext at
        // least one tag long.
        #[test]
        fn prop_aes_encrypted_value_roundtrip_bytes(
            iv in prop::array::uniform12(any::<u8>()),
            ciphertext in prop::collection::vec(any::<u8>(), AES_GCM_TAG_LEN..1000)
        ) {
            let value = AesEncryptedValue { aes_iv: iv, ciphertext };
            let bytes = value.bytes();
            let restored: AesEncryptedValue = bytes.as_slice().try_into().unwrap();
            prop_assert_eq!(value.aes_iv, restored.aes_iv);
            prop_assert_eq!(value.ciphertext, restored.ciphertext);
        }

        // EncryptedMasterKey::bytes and new_from_slice are inverses for arbitrary field values.
        #[test]
        fn prop_encrypted_master_key_roundtrip_bytes(
            salt in prop::array::uniform32(any::<u8>()),
            iv in prop::array::uniform12(any::<u8>()),
            encrypted_key in uniform48(any::<u8>())
        ) {
            let key = EncryptedMasterKey::new(salt, iv, encrypted_key);
            let bytes = key.bytes();
            let restored = EncryptedMasterKey::new_from_slice(&bytes).unwrap();
            prop_assert_eq!(key.pbkdf2_salt, restored.pbkdf2_salt);
            prop_assert_eq!(key.aes_iv, restored.aes_iv);
            prop_assert_eq!(key.encrypted_key, restored.encrypted_key);
        }
    }
}