use crate::common::{SerdeDeserialize, SerdeSerialize};
use aes::{
cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyIvInit},
Aes256,
};
use base64::Engine;
use hmac::Hmac;
use rand::Rng;
use serde::{Deserializer, Serializer};
use std::str::FromStr;
use thiserror::Error;
/// Size in bytes of one AES block; also the length of the CBC initialization vector.
pub const AES_BLOCK_SIZE: usize = 16;
/// AES-256 in CBC mode, encrypting direction.
type CipherC = cbc::Encryptor<Aes256>;
/// AES-256 in CBC mode, decrypting direction.
type CipherD = cbc::Decryptor<Aes256>;
/// A user-supplied password from which encryption keys are derived.
///
/// The inner string is private, and the type derives no traits (in particular
/// not `Debug`), so it cannot be printed with `{:?}` by accident.
pub struct Password {
    password: String,
}

impl From<String> for Password {
    /// Wrap an owned string as a password without copying it.
    fn from(password: String) -> Self {
        Self { password }
    }
}

impl FromStr for Password {
    type Err = <String as FromStr>::Err;

    /// Copy `s` into a new password; this never fails.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::from(s.to_owned()))
    }
}
/// Serde `serialize_with` helper: writes any byte container as a base64 string
/// using the `STANDARD` engine.
fn as_base64<A: AsRef<[u8]>, S>(key: &A, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let bytes: &[u8] = key.as_ref();
    let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
    serializer.serialize_str(&encoded)
}
/// Serde `deserialize_with` helper: reads a base64 string (`STANDARD` engine)
/// and converts the decoded bytes into `X`.
///
/// Fails with a custom error if the string is not valid base64, or if `X`
/// rejects the decoded bytes (e.g. a fixed-size array of the wrong length).
fn from_base64<'de, D: Deserializer<'de>, X: TryFrom<Vec<u8>>>(des: D) -> Result<X, D::Error> {
    use serde::de::Error;
    let encoded: String = String::deserialize(des)?;
    let bytes = match base64::engine::general_purpose::STANDARD.decode(encoded) {
        Ok(bytes) => bytes,
        Err(e) => return Err(<D::Error as Error>::custom(e.to_string())),
    };
    match X::try_from(bytes) {
        Ok(value) => Ok(value),
        Err(_) => Err(<D::Error as Error>::custom("Data of incorrect length.")),
    }
}
/// Symmetric cipher recorded in [`EncryptionMetadata`].
///
/// Serialized under the name `AES-256`.
// Plain field-less tag: deriving the standard traits (as `DecryptionError`
// already does) lets callers compare, copy and debug-print it.
#[derive(SerdeSerialize, SerdeDeserialize, Clone, Copy, PartialEq, Eq, Debug)]
pub enum EncryptionMethod {
    /// AES with a 256-bit key.
    #[serde(rename = "AES-256")]
    Aes256,
}
/// Password-to-key derivation algorithm recorded in [`EncryptionMetadata`].
///
/// Serialized under the name `PBKDF2WithHmacSHA256`.
// Plain field-less tag: deriving the standard traits (as `DecryptionError`
// already does) lets callers compare, copy and debug-print it.
#[derive(SerdeSerialize, SerdeDeserialize, Clone, Copy, PartialEq, Eq, Debug)]
pub enum KeyDerivationMethod {
    /// PBKDF2 with HMAC-SHA256 as the pseudo-random function.
    #[serde(rename = "PBKDF2WithHmacSHA256")]
    Pbkdf2Sha256,
}
/// Everything besides the password that is needed to decrypt an [`EncryptedData`].
///
/// Fields are serialized in camelCase (`encryptionMethod`, `keyDerivationMethod`,
/// `iterations`, `salt`, `initializationVector`), in declaration order. The salt
/// and the IV are written as base64 strings.
#[derive(SerdeSerialize, SerdeDeserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionMetadata {
    /// Symmetric cipher used for the payload.
    encryption_method: EncryptionMethod,
    /// Algorithm used to turn the password into a key.
    key_derivation_method: KeyDerivationMethod,
    /// Key-derivation iteration count.
    iterations: u32,
    /// Salt fed to the key derivation.
    #[serde(serialize_with = "as_base64", deserialize_with = "from_base64")]
    salt: Vec<u8>,
    /// CBC initialization vector; exactly one AES block long.
    #[serde(serialize_with = "as_base64", deserialize_with = "from_base64")]
    initialization_vector: [u8; AES_BLOCK_SIZE],
}
/// Raw ciphertext bytes.
///
/// `#[serde(transparent)]` makes this serialize as a bare base64 string rather
/// than as a struct wrapping one.
// Public type: derive `Debug`/`Clone`/`PartialEq`/`Eq` so callers can inspect,
// copy and compare ciphertexts. Ciphertext bytes are not secret, so `Debug` is safe.
#[derive(SerdeSerialize, SerdeDeserialize, Clone, PartialEq, Eq, Debug)]
#[serde(transparent)]
pub struct CipherText {
    #[serde(serialize_with = "as_base64", deserialize_with = "from_base64")]
    ct: Vec<u8>,
}
/// A ciphertext bundled with the metadata required to decrypt it.
///
/// Serialized with the field names `metadata` and `cipherText`.
#[derive(SerdeSerialize, SerdeDeserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptedData {
    /// Cipher and key-derivation parameters, salt and IV.
    metadata: EncryptionMetadata,
    /// The encrypted payload.
    cipher_text: CipherText,
}
/// PBKDF2 iteration count used by [`encrypt`] for newly encrypted data.
///
/// The count is stored in the output metadata, and [`decrypt`] reads it from
/// there, so changing this value does not break decryption of existing data.
pub const NUM_ROUNDS: u32 = 100_000;
/// Encrypt `plaintext` under a key derived from `pass`.
///
/// A fresh 16-byte salt and a fresh IV are drawn from `csprng`, salt first.
/// A 256-bit key is derived with PBKDF2-HMAC-SHA256 using [`NUM_ROUNDS`]
/// iterations. The data is then encrypted with AES-256 in CBC mode with
/// PKCS#7 padding.
///
/// The returned [`EncryptedData`] records the method, iteration count, salt
/// and IV alongside the ciphertext.
///
/// NOTE(review): `R` is only bounded by `Rng`. Salt and IV come from it, so
/// callers must pass a cryptographically secure generator (e.g. one that also
/// implements `rand::CryptoRng`).
pub fn encrypt<A: AsRef<[u8]>, R: Rng>(
    pass: &Password,
    plaintext: &A,
    csprng: &mut R,
) -> EncryptedData {
    // The salt is drawn from the RNG before the IV.
    let salt: [u8; 16] = csprng.gen();
    let mut derived_key = [0u8; 32];
    pbkdf2::pbkdf2::<Hmac<sha2::Sha256>>(
        pass.password.as_bytes(),
        &salt,
        NUM_ROUNDS,
        &mut derived_key,
    );
    let iv: [u8; AES_BLOCK_SIZE] = csprng.gen();

    let ct = CipherC::new((&derived_key).into(), (&iv).into())
        .encrypt_padded_vec_mut::<Pkcs7>(plaintext.as_ref());

    EncryptedData {
        metadata: EncryptionMetadata {
            encryption_method: EncryptionMethod::Aes256,
            key_derivation_method: KeyDerivationMethod::Pbkdf2Sha256,
            iterations: NUM_ROUNDS,
            salt: salt.to_vec(),
            initialization_vector: iv,
        },
        cipher_text: CipherText { ct },
    }
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Error)]
pub enum DecryptionError {
#[error("Decryption error.")]
BlockMode,
}
pub fn decrypt(pass: &Password, et: &EncryptedData) -> Result<Vec<u8>, DecryptionError> {
let mut key = [0u8; 32];
pbkdf2::pbkdf2::<Hmac<sha2::Sha256>>(
pass.password.as_bytes(),
&et.metadata.salt,
et.metadata.iterations,
&mut key,
);
let cipher = CipherD::new((&key).into(), (&et.metadata.initialization_vector).into());
cipher
.decrypt_padded_vec_mut::<Pkcs7>(&et.cipher_text.ct)
.map_err(|_| DecryptionError::BlockMode)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draw `len` uniformly random bytes from `rng`.
    fn random_bytes<R: Rng>(rng: &mut R, len: usize) -> Vec<u8> {
        (0..len).map(|_| rng.gen::<u8>()).collect()
    }

    #[test]
    fn encrypt_decrypt_success() {
        let pass = Password::from("hello".to_string());
        let mut rng = rand::thread_rng();
        let plaintext = random_bytes(&mut rng, 1000);
        let et = encrypt(&pass, &plaintext, &mut rng);
        let decrypted = decrypt(&pass, &et);
        assert_eq!(Ok(plaintext), decrypted, "Decryption failed.");
    }

    #[test]
    fn encrypt_decrypt_empty_plaintext() {
        let pass = Password::from("hello".to_string());
        let mut rng = rand::thread_rng();
        let plaintext: Vec<u8> = Vec::new();
        let et = encrypt(&pass, &plaintext, &mut rng);
        // PKCS#7 always adds at least one byte of padding, so even an empty
        // message produces exactly one full block of ciphertext.
        assert_eq!(et.cipher_text.ct.len(), AES_BLOCK_SIZE);
        assert_eq!(Ok(plaintext), decrypt(&pass, &et));
    }

    #[test]
    fn wrong_password_does_not_recover_plaintext() {
        let mut rng = rand::thread_rng();
        let plaintext = random_bytes(&mut rng, 100);
        let et = encrypt(&Password::from("hello".to_string()), &plaintext, &mut rng);
        // Without a MAC, a wrong key is not guaranteed to produce an error (the
        // padding may happen to be valid), but it must never yield the plaintext.
        let result = decrypt(&Password::from("world".to_string()), &et);
        assert_ne!(Ok(plaintext), result);
    }
}