use std::{fmt, num::NonZeroU32};
use ring::pbkdf2;
use zeroize::Zeroize;
use crate::{
aes::{self, AesMasterKey},
rng::Crng,
};
/// The key-derivation algorithm used to turn a password into an AES key:
/// PBKDF2 with HMAC-SHA256.
static PBKDF2_ALGORITHM: pbkdf2::Algorithm = pbkdf2::PBKDF2_HMAC_SHA256;
/// The number of PBKDF2 iterations applied when deriving the AES key.
const PBKDF2_ITERATIONS: NonZeroU32 = NonZeroU32::new(600_000).unwrap();
/// The length in bytes of the derived key buffer; set equal to the SHA-256
/// digest output length.
const AES_KEY_LEN: usize = ring::digest::SHA256_OUTPUT_LEN;
/// The minimum number of `char`s a password must contain
/// (enforced by `validate_password_len`).
pub const MIN_PASSWORD_LENGTH: usize = 12;
/// The maximum number of `char`s a password may contain
/// (enforced by `validate_password_len`).
pub const MAX_PASSWORD_LENGTH: usize = 512;
// Sanity check that the two length bounds are ordered consistently
// (presumably evaluated at compile time, going by the macro name).
lexe_std::const_assert!(MIN_PASSWORD_LENGTH < MAX_PASSWORD_LENGTH);
/// Errors returned by the password-based `encrypt` / `decrypt` functions and
/// by `validate_password_len`.
#[derive(Clone, Debug)]
pub enum Error {
/// The password has fewer than [`MIN_PASSWORD_LENGTH`] characters.
PasswordTooShort,
/// The password has more than [`MAX_PASSWORD_LENGTH`] characters.
PasswordTooLong,
/// The underlying AES decryption step reported an error.
AesDecrypt(aes::DecryptError),
}
// Marks `Error` as a standard error type; all trait methods keep their
// default implementations.
impl std::error::Error for Error {}
/// User-facing messages for each [`Error`] variant.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The wrapped AES error formats itself.
            Self::AesDecrypt(err) => err.fmt(f),
            Self::PasswordTooLong => f.write_fmt(format_args!(
                "Password cannot have more than {MAX_PASSWORD_LENGTH} characters"
            )),
            Self::PasswordTooShort => f.write_fmt(format_args!(
                "Password must have at least {MIN_PASSWORD_LENGTH} characters"
            )),
        }
    }
}
impl From<aes::DecryptError> for Error {
fn from(err: aes::DecryptError) -> Self {
Self::AesDecrypt(err)
}
}
/// Encrypts `data` under an AES key derived from `password` and `salt`.
///
/// The password length is validated first. The key is then derived via
/// `derive_aes_key`, and the salt is passed to the AES layer as the `aad`
/// argument. `rng` is forwarded to the AES encryption call.
///
/// # Errors
///
/// Returns [`Error::PasswordTooShort`] or [`Error::PasswordTooLong`] if the
/// password fails `validate_password_len`.
pub fn encrypt(
    rng: &mut impl Crng,
    password: &str,
    salt: &[u8; 32],
    data: &[u8],
) -> Result<Vec<u8>, Error> {
    validate_password_len(password)?;

    let key = derive_aes_key(password, salt);
    let aad = &[salt.as_slice()];
    // Callback that appends the plaintext to the buffer handed over by the
    // AES layer.
    let append_plaintext = |out: &mut Vec<u8>| out.extend_from_slice(data);

    Ok(key.encrypt(rng, aad, Some(data.len()), &append_plaintext))
}
pub fn decrypt(
password: &str,
salt: &[u8; 32],
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, Error> {
validate_password_len(password)?;
let aes_key = derive_aes_key(password, salt);
let aad = &[salt.as_slice()];
let data = aes_key.decrypt(aad, ciphertext)?;
Ok(data)
}
pub fn validate_password_len(password: &str) -> Result<(), Error> {
let password_length = password.chars().count();
if password_length < MIN_PASSWORD_LENGTH {
return Err(Error::PasswordTooShort);
}
if password_length > MAX_PASSWORD_LENGTH {
return Err(Error::PasswordTooLong);
}
Ok(())
}
fn derive_aes_key(password: &str, salt: &[u8; 32]) -> AesMasterKey {
let mut aes_key_buf = [0u8; AES_KEY_LEN];
pbkdf2::derive(
PBKDF2_ALGORITHM,
PBKDF2_ITERATIONS,
salt,
password.as_bytes(),
&mut aes_key_buf,
);
let aes_key = AesMasterKey::new(&aes_key_buf);
aes_key_buf.zeroize();
aes_key
}
// Unit tests for password-based encryption and decryption.
#[cfg(test)]
mod test {
use lexe_hex::hex;
use proptest::{
arbitrary::any, proptest, strategy::Strategy, test_runner::Config,
};
use super::*;
use crate::rng::FastRng;
// Property test: data encrypted with a valid password and salt decrypts
// back to the original bytes when the same password and salt are supplied.
#[test]
fn encryption_roundtrip() {
// Only a handful of cases are run (presumably because each case performs
// two full PBKDF2 derivations -- confirm).
let config = Config::with_cases(4);
// NOTE(review): this range is half-open, so passwords of exactly
// MAX_PASSWORD_LENGTH chars are never generated by this test.
let password_length_range = MIN_PASSWORD_LENGTH..MAX_PASSWORD_LENGTH;
// Strategy producing strings of arbitrary `char`s with a valid length.
let any_valid_password =
proptest::collection::vec(any::<char>(), password_length_range)
.prop_map(String::from_iter);
proptest!(config, |(
mut rng in any::<FastRng>(),
password in any_valid_password,
salt in any::<[u8; 32]>(),
data1 in any::<Vec<u8>>(),
)| {
let ciphertext =
encrypt(&mut rng, &password, &salt, &data1).unwrap();
let data2 = decrypt(&password, &salt, ciphertext).unwrap();
assert_eq!(data1, data2);
})
}
// Regression test: fixed hex-encoded ciphertexts must keep decrypting to
// the expected plaintext under the given password and salt.
#[test]
fn decryption_compatibility() {
// One compatibility fixture.
struct TestCase {
// Password used to derive the key.
password: String,
// Salt used to derive the key.
salt: [u8; 32],
// Expected plaintext.
data1: &'static [u8],
// Hex-encoded ciphertext to decrypt. When `None`, a ciphertext is
// generated and printed instead (see the `None` arm below).
maybe_ciphertext: Option<&'static str>,
}
// Empty plaintext with an all-zero salt.
let case0 = TestCase {
password: "medium-length!123123".to_owned(),
salt: [0u8; 32],
data1: b"",
maybe_ciphertext: Some(
"00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b6181e4156d15d513cb9cee00739a226466e",
),
};
// A 12-character password, i.e. exactly MIN_PASSWORD_LENGTH.
let case1 = TestCase {
password: "passwordword".to_owned(),
salt: [69; 32],
data1: b"*jaw drops* awooga! hummina hummina bazooing!",
maybe_ciphertext: Some(
"00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b6180c0d3cd90616335f13f5de7c9df0a1d89a7aec282b8083089c2360962e22db1a57685e82aea236c053b88495021767e0c17e05b3f72a86cfbbffc3724a",
),
};
// A 512-character password (exactly MAX_PASSWORD_LENGTH) built from the
// code points 0..512, so it includes multi-byte characters.
let password = (0u32..512)
.map(|i| char::from_u32(i).unwrap())
.collect::<String>();
let case2 = TestCase {
password,
salt: [69; 32],
data1: b"*jaw drops* awooga! hummina hummina bazooing!",
maybe_ciphertext: Some(
"00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b618cf7a8ff3ea628ed33fb32428930340557454454258dedc67c9a3a5e350c2408ad82e6a8ac02779fd9df3f513364b6351301271cfd2c515fdca0cd15de0",
),
};
for (i, case) in [case0, case1, case2].into_iter().enumerate() {
let TestCase {
password,
salt,
data1,
maybe_ciphertext,
} = case;
match maybe_ciphertext {
// `cipherhext` is the hex-encoded ciphertext.
Some(cipherhext) => {
println!("Testing case {i}");
let ciphertext = hex::decode(cipherhext).unwrap();
let data2 = decrypt(&password, &salt, ciphertext).unwrap();
assert_eq!(data1, data2.as_slice());
}
// No ciphertext recorded yet: encrypt with a fixed-seed RNG and
// print the hex so it can be pasted into the fixture.
None => {
let mut rng = FastRng::from_u64(20231016);
let ciphertext =
encrypt(&mut rng, &password, &salt, data1).unwrap();
let cipherhext = hex::display(&ciphertext);
println!("Case {i} ciphertext: {cipherhext}");
}
}
}
}
}