use aes::Aes256;
use cbc::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use hmac::{Hmac, Mac};
use rand_core::RngCore;
use sha2::Sha512;
use subtle::ConstantTimeEq;
use crate::error::CryptoError;
/// AES-256 in CBC mode, decryption direction (used by `decrypt`).
type Aes256CbcDec = cbc::Decryptor<Aes256>;
/// AES-256 in CBC mode, encryption direction (used by `encrypt`).
type Aes256CbcEnc = cbc::Encryptor<Aes256>;
/// Length in bytes of the content-encryption key.
///
/// The first 32 bytes are the HMAC-SHA-512 key and the last 32 bytes are the
/// AES-256 key, so the total is 64.
pub const CEK_SIZE: usize = 32 + 32;
/// Length in bytes of the CBC initialization vector (one AES block).
pub const IV_SIZE: usize = 16;
/// Length in bytes of the authentication tag (HMAC-SHA-512 output truncated to 32 bytes).
pub const TAG_SIZE: usize = 32;
pub fn generate_cek() -> [u8; CEK_SIZE] {
let mut cek = [0u8; CEK_SIZE];
rand_core::OsRng.fill_bytes(&mut cek);
cek
}
pub fn generate_iv() -> [u8; IV_SIZE] {
let mut iv = [0u8; IV_SIZE];
rand_core::OsRng.fill_bytes(&mut iv);
iv
}
/// Computes the authentication tag shared by [`encrypt`] and [`decrypt`].
///
/// The tag is the first [`TAG_SIZE`] bytes of HMAC-SHA-512 (keyed with `mac_key`)
/// over `aad || iv || ciphertext || AL`. `AL` is the AAD length in bits, encoded
/// as a 64-bit big-endian integer. Both directions call this one function, so the
/// MAC input cannot drift between them.
///
/// # Errors
/// Returns [`CryptoError::ContentEncryption`] if the HMAC cannot be keyed.
fn compute_tag(
    mac_key: &[u8],
    aad: &[u8],
    iv: &[u8; IV_SIZE],
    ciphertext: &[u8],
) -> Result<[u8; TAG_SIZE], CryptoError> {
    let aad_len_bits = (aad.len() as u64) * 8;
    let mut hmac = <Hmac<Sha512>>::new_from_slice(mac_key)
        .map_err(|e| CryptoError::ContentEncryption(format!("HMAC init failed: {e}")))?;
    hmac.update(aad);
    hmac.update(iv);
    hmac.update(ciphertext);
    hmac.update(&aad_len_bits.to_be_bytes());
    let full_tag = hmac.finalize().into_bytes();
    let mut tag = [0u8; TAG_SIZE];
    tag.copy_from_slice(&full_tag[..TAG_SIZE]);
    Ok(tag)
}

/// Encrypts `plaintext` with AES-256-CBC (PKCS#7 padding) and authenticates it with
/// a truncated HMAC-SHA-512 tag over `aad`, `iv` and the ciphertext.
///
/// `cek` is split in half: the first 32 bytes key the HMAC, the last 32 bytes key AES.
///
/// Returns `(ciphertext, tag)`. The ciphertext is always a non-zero multiple of
/// 16 bytes, because PKCS#7 adds 1..=16 bytes of padding.
///
/// # Errors
/// Returns [`CryptoError::ContentEncryption`] if the cipher or HMAC cannot be
/// initialised.
pub fn encrypt(
    plaintext: &[u8],
    cek: &[u8; CEK_SIZE],
    iv: &[u8; IV_SIZE],
    aad: &[u8],
) -> Result<(Vec<u8>, [u8; TAG_SIZE]), CryptoError> {
    let (mac_key, enc_key) = cek.split_at(CEK_SIZE / 2);
    let encryptor = Aes256CbcEnc::new_from_slices(enc_key, iv)
        .map_err(|e| CryptoError::ContentEncryption(format!("AES-CBC init failed: {e}")))?;
    // Let the cipher crate apply PKCS#7 directly, instead of building a padded copy
    // of the plaintext first.
    let ciphertext =
        encryptor.encrypt_padded_vec_mut::<cbc::cipher::block_padding::Pkcs7>(plaintext);
    let tag = compute_tag(mac_key, aad, iv, &ciphertext)?;
    Ok((ciphertext, tag))
}

/// Verifies `tag` and, if it is valid, decrypts `ciphertext` produced by [`encrypt`].
///
/// The tag is checked in constant time *before* any decryption or padding work.
/// Forged or modified input is therefore rejected without reaching the padding
/// checks, which are not constant time.
///
/// # Errors
/// Returns [`CryptoError::ContentEncryption`] when:
/// - the tag does not match;
/// - the ciphertext length is not a whole number of AES blocks;
/// - the decrypted data is empty;
/// - the PKCS#7 padding is malformed.
pub fn decrypt(
    ciphertext: &[u8],
    cek: &[u8; CEK_SIZE],
    iv: &[u8; IV_SIZE],
    aad: &[u8],
    tag: &[u8; TAG_SIZE],
) -> Result<Vec<u8>, CryptoError> {
    let (mac_key, enc_key) = cek.split_at(CEK_SIZE / 2);

    // Encrypt-then-MAC: authenticate first so nothing below runs on forged input.
    let expected = compute_tag(mac_key, aad, iv, ciphertext)?;
    if !bool::from(expected[..].ct_eq(&tag[..])) {
        return Err(CryptoError::ContentEncryption(
            "authentication tag mismatch".into(),
        ));
    }

    let decryptor = Aes256CbcDec::new_from_slices(enc_key, iv)
        .map_err(|e| CryptoError::ContentEncryption(format!("AES-CBC init failed: {e}")))?;
    // Decrypt raw blocks straight from the borrowed slice. PKCS#7 is removed
    // manually below so the padding errors keep their specific messages.
    let mut decrypted = decryptor
        .decrypt_padded_vec_mut::<cbc::cipher::block_padding::NoPadding>(ciphertext)
        .map_err(|e| CryptoError::ContentEncryption(format!("AES-CBC decrypt failed: {e}")))?;

    let last = *decrypted.last().ok_or_else(|| {
        CryptoError::ContentEncryption("decrypted data is empty".into())
    })?;
    let pad_len = usize::from(last);
    // A valid PKCS#7 pad is 1..=16 bytes (the AES block size) and fits in the buffer.
    if pad_len == 0 || pad_len > 16 || pad_len > decrypted.len() {
        return Err(CryptoError::ContentEncryption(
            "invalid PKCS7 padding".into(),
        ));
    }
    let plaintext_len = decrypted.len() - pad_len;
    if decrypted[plaintext_len..].iter().any(|&b| b != last) {
        return Err(CryptoError::ContentEncryption(
            "invalid PKCS7 padding".into(),
        ));
    }

    // Strip the padding in place rather than copying the plaintext into a new Vec.
    decrypted.truncate(plaintext_len);
    Ok(decrypted)
}