use std::{fmt, num::NonZeroU32};
use rand::{self, CryptoRng, RngCore};
use ring::{aead, aead::BoundKey, digest, error::Unspecified, pbkdf2};
use crate::internal::{take_lock, IronOxideErr};
use std::{convert::TryFrom, ops::DerefMut, sync::Mutex};
// Symmetric-crypto sizing. Every length below is a byte count.

/// Length of an AES-256 key.
const AES_KEY_LEN: usize = 32;
/// Length of the authentication tag AES-GCM appends to each ciphertext.
const AES_GCM_TAG_LEN: usize = 16;
/// Length of an AES-GCM IV (nonce).
const AES_IV_LEN: usize = 12;
/// Length of an encrypted AES key together with its trailing GCM tag.
const ENCRYPTED_KEY_AND_GCM_TAG_LEN: usize = AES_GCM_TAG_LEN + AES_KEY_LEN;
/// Length of the salt used when stretching a password with PBKDF2.
const PBKDF2_SALT_LEN: usize = 32;
/// Number of PBKDF2 rounds applied when deriving a key from a password.
// SAFETY: 250_000 is a non-zero literal, so `new_unchecked`'s precondition is satisfied.
const PBKDF2_ITERATIONS: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(250_000) };
/// A user master key encrypted under a key derived from a password.
///
/// Stores everything needed to re-derive the AES key and decrypt: the PBKDF2 salt, the
/// AES-GCM IV, and the encrypted key bytes followed by their GCM tag.
pub struct EncryptedMasterKey {
    /// Salt fed to PBKDF2 when deriving the AES key from the password.
    pbkdf2_salt: [u8; PBKDF2_SALT_LEN],
    /// IV (nonce) used for the AES-GCM encryption of the master key.
    aes_iv: [u8; AES_IV_LEN],
    /// Encrypted master key with the GCM authentication tag appended.
    encrypted_key: [u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN],
}
impl fmt::Debug for EncryptedMasterKey {
    /// Prints every field as a list of bytes under the struct's own name.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Bind each field as a sized `&[u8]` so it can be handed to `field` as `&dyn Debug`;
        // the 48-byte key array is always shown through a slice.
        let salt: &[u8] = &self.pbkdf2_salt;
        let iv: &[u8] = &self.aes_iv;
        let key: &[u8] = &self.encrypted_key;
        let mut builder = formatter.debug_struct(stringify!(EncryptedMasterKey));
        builder.field("pbkdf2_salt", &salt);
        builder.field("aes_iv", &iv);
        builder.field("encrypted_key", &key);
        builder.finish()
    }
}
impl EncryptedMasterKey {
    /// Length of the serialized form: salt, then IV, then encrypted key with its GCM tag.
    pub const SIZE_BYTES: usize = PBKDF2_SALT_LEN + AES_IV_LEN + ENCRYPTED_KEY_AND_GCM_TAG_LEN;

    /// Assembles an `EncryptedMasterKey` from its raw parts without any validation.
    pub fn new(
        pbkdf2_salt: [u8; PBKDF2_SALT_LEN],
        aes_iv: [u8; AES_IV_LEN],
        encrypted_key: [u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN],
    ) -> EncryptedMasterKey {
        EncryptedMasterKey {
            pbkdf2_salt,
            aes_iv,
            encrypted_key,
        }
    }

    /// Parses the layout produced by [`EncryptedMasterKey::bytes`].
    ///
    /// # Errors
    /// Returns `IronOxideErr::WrongSizeError(Some(actual), Some(SIZE_BYTES))` when `bytes` is
    /// not exactly [`EncryptedMasterKey::SIZE_BYTES`] long.
    pub fn new_from_slice(bytes: &[u8]) -> Result<EncryptedMasterKey, IronOxideErr> {
        if bytes.len() != EncryptedMasterKey::SIZE_BYTES {
            return Err(IronOxideErr::WrongSizeError(
                Some(bytes.len()),
                Some(EncryptedMasterKey::SIZE_BYTES),
            ));
        }
        // The length check above guarantees every split and copy below is exact.
        let (salt_bytes, rest) = bytes.split_at(PBKDF2_SALT_LEN);
        let (iv_bytes, key_bytes) = rest.split_at(AES_IV_LEN);
        let mut pbkdf2_salt = [0u8; PBKDF2_SALT_LEN];
        let mut aes_iv = [0u8; AES_IV_LEN];
        let mut encrypted_key = [0u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN];
        pbkdf2_salt.copy_from_slice(salt_bytes);
        aes_iv.copy_from_slice(iv_bytes);
        encrypted_key.copy_from_slice(key_bytes);
        Ok(EncryptedMasterKey::new(pbkdf2_salt, aes_iv, encrypted_key))
    }

    /// Serializes as `pbkdf2_salt || aes_iv || encrypted_key`, the layout
    /// [`EncryptedMasterKey::new_from_slice`] accepts.
    pub fn bytes(&self) -> [u8; EncryptedMasterKey::SIZE_BYTES] {
        let mut dest = [0u8; EncryptedMasterKey::SIZE_BYTES];
        // Write each field straight into its region of the output array rather than
        // concatenating into a temporary heap Vec and copying that again.
        let (salt_dest, rest) = dest.split_at_mut(PBKDF2_SALT_LEN);
        let (iv_dest, key_dest) = rest.split_at_mut(AES_IV_LEN);
        salt_dest.copy_from_slice(&self.pbkdf2_salt);
        iv_dest.copy_from_slice(&self.aes_iv);
        key_dest.copy_from_slice(&self.encrypted_key);
        dest
    }
}
/// Result of an AES-256-GCM encryption: the IV that was used plus the ciphertext.
#[derive(Clone, Debug)]
pub struct AesEncryptedValue {
    /// IV (nonce) used for this encryption.
    aes_iv: [u8; AES_IV_LEN],
    /// Encrypted bytes with the GCM authentication tag appended.
    ciphertext: Vec<u8>,
}
impl AesEncryptedValue {
    /// Serializes as the IV immediately followed by the ciphertext. This is the layout the
    /// `TryFrom<&[u8]>` impl parses.
    pub fn bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.aes_iv.len() + self.ciphertext.len());
        out.extend_from_slice(&self.aes_iv);
        out.extend_from_slice(&self.ciphertext);
        out
    }
}
impl TryFrom<&[u8]> for AesEncryptedValue {
    type Error = IronOxideErr;

    /// Splits `bytes` into a leading IV and the remaining ciphertext.
    ///
    /// # Errors
    /// Returns `IronOxideErr::AesEncryptedDocSizeError` when `bytes` cannot hold even an IV
    /// plus a GCM tag.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let minimum_len = AES_IV_LEN + AES_GCM_TAG_LEN;
        if bytes.len() >= minimum_len {
            let (iv_part, ciphertext_part) = bytes.split_at(AES_IV_LEN);
            let mut aes_iv = [0u8; AES_IV_LEN];
            aes_iv.copy_from_slice(iv_part);
            Ok(AesEncryptedValue {
                aes_iv,
                ciphertext: ciphertext_part.to_vec(),
            })
        } else {
            Err(IronOxideErr::AesEncryptedDocSizeError)
        }
    }
}
impl From<ring::error::Unspecified> for IronOxideErr {
    /// Wraps ring's opaque error in `IronOxideErr::AesError`, allowing `?` to convert it.
    fn from(err: Unspecified) -> Self {
        Self::AesError(err)
    }
}
/// Derives an AES key from `password` and `salt` with PBKDF2-HMAC-SHA256 using
/// `PBKDF2_ITERATIONS` rounds.
fn derive_key_from_password(password: &str, salt: [u8; PBKDF2_SALT_LEN]) -> [u8; AES_KEY_LEN] {
    // `pbkdf2::derive` fills the whole output buffer, so its length sets the key length.
    let mut key_out = [0u8; digest::SHA256_OUTPUT_LEN];
    let secret = password.as_bytes();
    let salt_bytes: &[u8] = &salt;
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        PBKDF2_ITERATIONS,
        salt_bytes,
        secret,
        &mut key_out,
    );
    key_out
}
/// Encrypts a 32-byte user master key under a key derived from `password`.
///
/// A random PBKDF2 salt is drawn from `rng` and used to derive an AES key. The master key is
/// then encrypted with [`encrypt`], which draws its own random IV from `rng`.
///
/// # Errors
/// Returns `Unspecified` if the AES-GCM encryption fails.
pub fn encrypt_user_master_key<R: CryptoRng + RngCore>(
    rng: &Mutex<R>,
    password: &str,
    user_master_key: &[u8; 32],
) -> Result<EncryptedMasterKey, Unspecified> {
    let mut salt = [0u8; PBKDF2_SALT_LEN];
    take_lock(rng).deref_mut().fill_bytes(&mut salt);
    let derived_key = derive_key_from_password(password, salt);
    // Borrow the key as a slice. Copying it into a temporary Vec (as `to_vec()` did) left an
    // extra heap copy of plaintext key material that was freed without being cleared.
    let encrypted_key = encrypt(rng, &user_master_key[..], derived_key)?;
    // 32 plaintext bytes seal to 32 ciphertext bytes plus the GCM tag, which is exactly
    // ENCRYPTED_KEY_AND_GCM_TAG_LEN.
    let mut master_key_ciphertext = [0u8; ENCRYPTED_KEY_AND_GCM_TAG_LEN];
    master_key_ciphertext.copy_from_slice(&encrypted_key.ciphertext);
    Ok(EncryptedMasterKey {
        pbkdf2_salt: salt,
        aes_iv: encrypted_key.aes_iv,
        encrypted_key: master_key_ciphertext,
    })
}
/// Recovers the 32-byte user master key from `encrypted_master_key` using `password`.
///
/// The AES key is re-derived from `password` and the stored PBKDF2 salt, and the stored
/// ciphertext is then opened with the stored IV.
///
/// # Errors
/// Returns `Unspecified` if AES-GCM decryption fails, e.g. because the password is wrong or
/// the data has been altered.
pub fn decrypt_user_master_key(
    password: &str,
    encrypted_master_key: &EncryptedMasterKey,
) -> Result<[u8; 32], Unspecified> {
    let key = derive_key_from_password(password, encrypted_master_key.pbkdf2_salt);
    // `decrypt` works in place, so operate on an owned copy of the stored ciphertext.
    let mut sealed = AesEncryptedValue {
        aes_iv: encrypted_master_key.aes_iv,
        ciphertext: encrypted_master_key.encrypted_key.to_vec(),
    };
    let opened = decrypt(&mut sealed, key)?;
    let mut master_key = [0u8; 32];
    master_key.copy_from_slice(opened);
    Ok(master_key)
}
/// Nonce sequence that yields exactly one caller-supplied nonce.
///
/// The first `advance` returns the stored IV. Every later call fails with `Unspecified`, so a
/// key bound to this sequence can request a nonce at most once.
struct SingleUseNonceGenerator {
    iv: Option<[u8; aead::NONCE_LEN]>,
}

impl SingleUseNonceGenerator {
    /// Wraps `iv` as the single nonce this sequence will hand out.
    fn new(iv: [u8; aead::NONCE_LEN]) -> SingleUseNonceGenerator {
        Self { iv: Some(iv) }
    }
}

impl aead::NonceSequence for SingleUseNonceGenerator {
    fn advance(&mut self) -> Result<aead::Nonce, Unspecified> {
        // `take` empties the slot, which guarantees the IV is never returned twice.
        match self.iv.take() {
            Some(iv) => Ok(aead::Nonce::assume_unique_for_key(iv)),
            None => Err(Unspecified),
        }
    }
}
/// Encrypts `plaintext` with AES-256-GCM under `key`, using a fresh random IV drawn from `rng`.
///
/// No additional authenticated data is used. The returned ciphertext is the plaintext length
/// plus the GCM tag length.
///
/// # Errors
/// Returns `Unspecified` if ring rejects the key or the seal operation fails.
pub fn encrypt<R: CryptoRng + RngCore>(
    rng: &Mutex<R>,
    plaintext: &[u8],
    key: [u8; AES_KEY_LEN],
) -> Result<AesEncryptedValue, Unspecified> {
    let algorithm = &aead::AES_256_GCM;
    let mut iv = [0u8; aead::NONCE_LEN];
    take_lock(rng).deref_mut().fill_bytes(&mut iv);
    let mut aes_key = aead::SealingKey::new(
        aead::UnboundKey::new(algorithm, &key[..])?,
        SingleUseNonceGenerator::new(iv),
    );
    // Reserve room for the tag up front. A `to_owned()` copy has exact capacity, so appending
    // the tag forced a reallocation: an extra copy, and possibly a freed, uncleared copy of
    // the plaintext left on the heap.
    let mut ciphertext = Vec::with_capacity(plaintext.len() + algorithm.tag_len());
    ciphertext.extend_from_slice(plaintext);
    aes_key.seal_in_place_append_tag(aead::Aad::empty(), &mut ciphertext)?;
    Ok(AesEncryptedValue {
        ciphertext,
        aes_iv: iv,
    })
}
/// Async form of [`encrypt`] that converts ring's error into `IronOxideErr`.
///
/// The encryption itself runs synchronously when the future is polled: it locks `rng` and
/// performs the AES operation without yielding.
///
/// # Errors
/// Returns `IronOxideErr::AesError` if [`encrypt`] fails.
pub async fn encrypt_async<R: CryptoRng + RngCore>(
    rng: &Mutex<R>,
    plaintext: &[u8],
    key: [u8; AES_KEY_LEN],
) -> Result<AesEncryptedValue, IronOxideErr> {
    // The body of an `async fn` is already a future, so the former nested
    // `async { .. }.await` added nothing.
    encrypt(rng, plaintext, key).map_err(IronOxideErr::from)
}
/// Decrypts and authenticates `encrypted_doc` in place with AES-256-GCM under `key`.
///
/// On success, returns the plaintext as a mutable sub-slice of `encrypted_doc.ciphertext`,
/// with the trailing GCM tag excluded. The ciphertext buffer is modified in place.
///
/// # Errors
/// Returns `Unspecified` if ring rejects the key or authentication fails.
pub fn decrypt(
    encrypted_doc: &mut AesEncryptedValue,
    key: [u8; AES_KEY_LEN],
) -> Result<&mut [u8], Unspecified> {
    let unbound = aead::UnboundKey::new(&aead::AES_256_GCM, &key[..])?;
    let nonces = SingleUseNonceGenerator::new(encrypted_doc.aes_iv);
    let mut opening_key = aead::OpeningKey::new(unbound, nonces);
    opening_key.open_in_place(aead::Aad::empty(), &mut encrypted_doc.ciphertext[..])
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{convert::TryInto, sync::Arc};

    #[test]
    fn test_encrypt_user_master_key() {
        let user_master_key = [0u8; 32];
        let password = "MyPassword";
        let rng = rand::thread_rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), password, &user_master_key).unwrap();
        assert_eq!(encrypted_master_key.pbkdf2_salt.len(), 32);
        assert_eq!(encrypted_master_key.aes_iv.len(), 12);
        assert_eq!(encrypted_master_key.encrypted_key.len(), 48);
    }

    #[test]
    fn test_decrypt_user_master_key() {
        let user_master_key = [0u8; 32];
        let password = "MyPassword";
        let rng = rand::thread_rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), password, &user_master_key).unwrap();
        let decrypted_master_key =
            decrypt_user_master_key(password, &encrypted_master_key).unwrap();
        assert_eq!(decrypted_master_key, user_master_key);
    }

    // A key derived from the wrong password must fail GCM authentication rather than
    // silently return garbage.
    #[test]
    fn test_decrypt_user_master_key_wrong_password_fails() {
        let user_master_key = [7u8; 32];
        let rng = rand::thread_rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), "MyPassword", &user_master_key).unwrap();
        assert!(decrypt_user_master_key("NotMyPassword", &encrypted_master_key).is_err());
    }

    #[test]
    fn test_encrypted_master_key_bytes_roundtrip() {
        let rng = rand::thread_rng();
        let encrypted_master_key =
            encrypt_user_master_key(&Mutex::new(rng), "MyPassword", &[1u8; 32]).unwrap();
        let bytes = encrypted_master_key.bytes();
        let parsed = EncryptedMasterKey::new_from_slice(&bytes[..]).unwrap();
        assert_eq!(&parsed.bytes()[..], &bytes[..]);
        assert_eq!(
            decrypt_user_master_key("MyPassword", &parsed).unwrap(),
            [1u8; 32]
        );
    }

    #[test]
    fn test_encrypted_master_key_wrong_size_rejected() {
        let too_short = [0u8; EncryptedMasterKey::SIZE_BYTES - 1];
        assert!(EncryptedMasterKey::new_from_slice(&too_short[..]).is_err());
        let too_long = vec![0u8; EncryptedMasterKey::SIZE_BYTES + 1];
        assert!(EncryptedMasterKey::new_from_slice(&too_long).is_err());
        assert!(EncryptedMasterKey::new_from_slice(&[]).is_err());
    }

    #[test]
    fn test_encrypt() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let mut key = [0u8; 32];
        let mut rng = rand::thread_rng();
        rng.fill_bytes(&mut key);
        let res = encrypt(&Mutex::new(rng), &plaintext, key).unwrap();
        assert_eq!(res.aes_iv.len(), 12);
        assert_eq!(
            res.ciphertext.len(),
            plaintext.len() + aead::AES_256_GCM.tag_len()
        );
    }

    #[test]
    fn test_decrypt() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let mut key = [0u8; 32];
        let mut rng = rand::thread_rng();
        rng.fill_bytes(&mut key);
        let mut encrypted_result = encrypt(&Mutex::new(rng), &plaintext, key).unwrap();
        let decrypted_plaintext = decrypt(&mut encrypted_result, key).unwrap();
        assert_eq!(*decrypted_plaintext, plaintext[..]);
    }

    // Flipping a single ciphertext bit must be caught by GCM authentication.
    #[test]
    fn test_decrypt_tampered_ciphertext_fails() {
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let mut key = [0u8; 32];
        let mut rng = rand::thread_rng();
        rng.fill_bytes(&mut key);
        let mut encrypted_result = encrypt(&Mutex::new(rng), &plaintext, key).unwrap();
        encrypted_result.ciphertext[0] ^= 0x01;
        assert!(decrypt(&mut encrypted_result, key).is_err());
    }

    #[test]
    fn test_roundtrip_aesencryptedvalue_zero_one_bytes() -> Result<(), IronOxideErr> {
        let encrypted_bytes = [0u8; 1 + AES_IV_LEN + AES_GCM_TAG_LEN];
        let round_tripped_aes_encrypted_value: AesEncryptedValue =
            encrypted_bytes.as_ref().try_into()?;
        assert_eq!(round_tripped_aes_encrypted_value.bytes(), encrypted_bytes);
        let encrypted_bytes2 = [0u8; AES_IV_LEN + AES_GCM_TAG_LEN];
        let round_tripped_aes_encrypted_value: AesEncryptedValue =
            encrypted_bytes2.as_ref().try_into()?;
        assert_eq!(round_tripped_aes_encrypted_value.bytes(), encrypted_bytes2);
        Ok(())
    }

    #[test]
    fn test_parallel_encrypt() {
        use rand::SeedableRng;
        let plaintext = vec![1, 2, 3, 4, 5, 6, 7];
        let mut key = [0u8; 32];
        let rng = Mutex::new(rand_chacha::ChaChaRng::from_entropy());
        take_lock(&rng).deref_mut().fill_bytes(&mut key);
        let a_rng = Arc::new(rng);
        let mut threads = vec![];
        for _i in 0..100 {
            let rng_ref = a_rng.clone();
            let pt = plaintext.clone();
            threads.push(std::thread::spawn(move || {
                let _res = encrypt(&rng_ref, &pt, key).unwrap();
            }));
        }
        let mut joined_count = 0;
        for t in threads {
            t.join().expect("join failed");
            joined_count += 1;
        }
        assert_eq!(joined_count, 100);
    }
}