pub(crate) mod key;
use aes::{
Aes256,
cipher::{
BlockDecryptMut, BlockEncryptMut, KeyIvInit,
block_padding::{Pkcs7, UnpadError},
},
};
use hmac::{Hmac, Mac as MacT, digest::MacError};
use key::CipherKeys;
use sha2::Sha256;
use thiserror::Error;
/// HMAC-SHA-256, used to authenticate ciphertexts produced by [`Cipher`].
pub(crate) type HmacSha256 = Hmac<Sha256>;
/// AES-256 in CBC mode, encryption direction.
pub(crate) type Aes256CbcEnc = cbc::Encryptor<Aes256>;
/// AES-256 in CBC mode, decryption direction.
pub(crate) type Aes256CbcDec = cbc::Decryptor<Aes256>;
/// A full-length HMAC-SHA-256 authentication tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mac(pub(crate) [u8; Self::LENGTH]);

impl Mac {
    /// Length in bytes of an untruncated tag.
    pub const LENGTH: usize = 32;
    /// Number of leading tag bytes kept when the tag is truncated.
    pub const TRUNCATED_LEN: usize = 8;

    /// Returns a copy of the first [`Mac::TRUNCATED_LEN`] bytes of the tag.
    pub fn truncate(&self) -> [u8; Self::TRUNCATED_LEN] {
        core::array::from_fn(|index| self.0[index])
    }

    /// Borrows the complete tag as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }
}
/// A MAC as carried alongside a message: either a complete [`Mac`] or only
/// its first [`Mac::TRUNCATED_LEN`] bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum MessageMac {
    /// The leading [`Mac::TRUNCATED_LEN`] bytes of a tag.
    Truncated([u8; Mac::TRUNCATED_LEN]),
    /// A complete, untruncated tag.
    Full(Mac),
}

impl MessageMac {
    /// Borrows the stored MAC bytes, whichever form they are kept in.
    pub fn as_bytes(&self) -> &[u8] {
        match self {
            Self::Full(full) => full.as_bytes(),
            Self::Truncated(prefix) => &prefix[..],
        }
    }
}

impl From<Mac> for MessageMac {
    /// Wraps a complete tag.
    fn from(full: Mac) -> Self {
        MessageMac::Full(full)
    }
}

impl From<[u8; Mac::TRUNCATED_LEN]> for MessageMac {
    /// Wraps an already-truncated tag.
    fn from(prefix: [u8; Mac::TRUNCATED_LEN]) -> Self {
        MessageMac::Truncated(prefix)
    }
}
/// Errors produced while authenticating and decrypting ciphertext, for
/// example by `Cipher::decrypt_pickle`.
#[derive(Debug, Error)]
pub enum DecryptionError {
    /// AES-CBC decryption failed with an [`UnpadError`] while removing the
    /// PKCS#7 padding.
    #[error("Failed decrypting, invalid padding")]
    InvalidPadding(#[from] UnpadError),
    /// The MAC attached to the ciphertext did not match the one computed over
    /// it.
    #[error("The MAC of the ciphertext didn't pass validation {0}")]
    Mac(#[from] MacError),
    /// The input was too short to hold any ciphertext in front of the MAC.
    #[error("The ciphertext didn't contain a valid MAC")]
    MacMissing,
}
/// AES-256-CBC encryption with HMAC-SHA-256 authentication over a derived key
/// set.
///
/// The AES key, the IV and the MAC key all come from [`CipherKeys`]. This type
/// only combines them into encryption, decryption and MAC operations.
pub struct Cipher {
    /// Key material supplying the AES key, the IV and the HMAC key.
    keys: CipherKeys,
}

impl Cipher {
    /// Creates a cipher from a 32-byte secret, deriving the key set with
    /// [`CipherKeys::new`].
    pub fn new(key: &[u8; 32]) -> Self {
        let keys = CipherKeys::new(key);

        Self { keys }
    }

    /// Creates a cipher from 128 bytes of key material, deriving the key set
    /// with [`CipherKeys::new_megolm`].
    pub fn new_megolm(key: &[u8; 128]) -> Self {
        // Forward the borrow as-is. Destructuring it with a `&key` parameter
        // pattern would copy all 128 bytes of secret material into a plain
        // stack array that nothing clears when it goes out of scope.
        let keys = CipherKeys::new_megolm(key);

        Self { keys }
    }

    /// Creates a cipher from a byte-slice key, deriving the key set with
    /// [`CipherKeys::new_pickle`].
    pub fn new_pickle(key: &[u8]) -> Self {
        let keys = CipherKeys::new_pickle(key);

        Self { keys }
    }

    /// Returns a fresh HMAC-SHA-256 instance keyed with the MAC key.
    fn get_hmac(&self) -> HmacSha256 {
        // HMAC accepts keys of any length, so this `expect` cannot fire.
        #[allow(clippy::expect_used)]
        HmacSha256::new_from_slice(self.keys.mac_key())
            .expect("We should be able to create a HmacSha256 from a 32 byte key")
    }

    /// Encrypts `plaintext` with AES-256-CBC and PKCS#7 padding.
    ///
    /// The output is not authenticated. Use [`Cipher::mac`] or
    /// [`Cipher::encrypt_pickle`] to attach a MAC.
    ///
    /// NOTE(review): the IV is read from the stored key set via
    /// `self.keys.iv()`, not generated per call. Every encryption made with
    /// one `Cipher` therefore uses the same IV as long as that accessor returns
    /// a fixed value. Confirm callers use each key set for a single message.
    pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
        let cipher = Aes256CbcEnc::new(self.keys.aes_key(), self.keys.iv());
        cipher.encrypt_padded_vec_mut::<Pkcs7>(plaintext)
    }

    /// Computes the full [`Mac::LENGTH`]-byte HMAC-SHA-256 of `message` under
    /// the MAC key.
    pub fn mac(&self, message: &[u8]) -> Mac {
        let mut hmac = self.get_hmac();
        hmac.update(message);

        let mac_bytes = hmac.finalize().into_bytes();
        let mut mac = [0u8; Mac::LENGTH];
        mac.copy_from_slice(&mac_bytes);

        Mac(mac)
    }

    /// Decrypts AES-256-CBC `ciphertext` and strips its PKCS#7 padding.
    ///
    /// This performs no authentication. Verify a MAC first, as
    /// [`Cipher::decrypt_pickle`] does.
    ///
    /// # Errors
    ///
    /// Returns [`UnpadError`] if the data cannot be decrypted and unpadded,
    /// e.g. when the recovered plaintext does not end in valid PKCS#7 padding.
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, UnpadError> {
        let cipher = Aes256CbcDec::new(self.keys.aes_key(), self.keys.iv());
        cipher.decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
    }

    /// Checks, in constant time, that `tag` is the full HMAC-SHA-256 of
    /// `message`.
    ///
    /// # Errors
    ///
    /// Returns [`MacError`] if the tag does not match.
    #[cfg(all(not(fuzzing), feature = "experimental-session-config"))]
    pub fn verify_mac(&self, message: &[u8], tag: &Mac) -> Result<(), MacError> {
        let mut hmac = self.get_hmac();
        hmac.update(message);
        hmac.verify_slice(tag.as_bytes())
    }

    /// Checks, in constant time, that `tag` equals the leading `tag.len()`
    /// bytes of the HMAC-SHA-256 of `message`.
    ///
    /// # Errors
    ///
    /// Returns [`MacError`] if the tag does not match, or if it is empty or
    /// longer than a full MAC.
    #[cfg(not(fuzzing))]
    pub fn verify_truncated_mac(&self, message: &[u8], tag: &[u8]) -> Result<(), MacError> {
        let mut hmac = self.get_hmac();
        hmac.update(message);
        hmac.verify_truncated_left(tag)
    }

    /// Fuzzing stub: accepts every MAC, so fuzz inputs are not rejected at the
    /// MAC check.
    ///
    /// NOTE(review): unlike the non-fuzzing version, this stub is not gated on
    /// the `experimental-session-config` feature.
    #[cfg(fuzzing)]
    pub fn verify_mac(&self, _: &[u8], _: &Mac) -> Result<(), MacError> {
        Ok(())
    }

    /// Fuzzing stub: accepts every truncated MAC, so fuzz inputs are not
    /// rejected at the MAC check.
    #[cfg(fuzzing)]
    pub fn verify_truncated_mac(&self, _: &[u8], _: &[u8]) -> Result<(), MacError> {
        Ok(())
    }

    /// Encrypts `plaintext` and appends the first [`Mac::TRUNCATED_LEN`] bytes
    /// of the HMAC computed over the ciphertext (encrypt-then-MAC).
    pub fn encrypt_pickle(&self, plaintext: &[u8]) -> Vec<u8> {
        let mut ciphertext = self.encrypt(plaintext);
        let mac = self.mac(&ciphertext);
        ciphertext.extend(mac.truncate());

        ciphertext
    }

    /// Reverses [`Cipher::encrypt_pickle`]. It splits off the trailing
    /// truncated MAC, verifies it over the remaining ciphertext, and only then
    /// decrypts.
    ///
    /// # Errors
    ///
    /// * [`DecryptionError::MacMissing`] if the input is not longer than
    ///   [`Mac::TRUNCATED_LEN`] bytes.
    /// * [`DecryptionError::Mac`] if the MAC does not verify.
    /// * [`DecryptionError::InvalidPadding`] if decryption/unpadding fails.
    pub fn decrypt_pickle(&self, ciphertext: &[u8]) -> Result<Vec<u8>, DecryptionError> {
        // Require at least one byte of ciphertext in front of the MAC.
        if ciphertext.len() <= Mac::TRUNCATED_LEN {
            return Err(DecryptionError::MacMissing);
        }

        let (ciphertext, mac) = ciphertext.split_at(ciphertext.len() - Mac::TRUNCATED_LEN);

        // Authenticate before decrypting, so tampered input never reaches the
        // padding check.
        self.verify_truncated_mac(ciphertext, mac)?;

        Ok(self.decrypt(ciphertext)?)
    }
}
#[cfg(test)]
mod test {
    use assert_matches2::assert_matches;

    use super::{Cipher, Mac};
    use crate::cipher::DecryptionError;

    /// AES block size in bytes, used to predict padded ciphertext lengths.
    const AES_BLOCK_LEN: usize = 16;

    #[test]
    fn decrypt_pickle_mac_missing() {
        let cipher = Cipher::new(&[1u8; 32]);

        // Too short to hold even one byte of ciphertext before the MAC.
        assert_matches!(
            cipher.decrypt_pickle(&[2u8; Mac::TRUNCATED_LEN]),
            Err(DecryptionError::MacMissing)
        );

        // Long enough to split, but the MAC is wrong.
        assert_matches!(
            cipher.decrypt_pickle(&[0u8; Mac::TRUNCATED_LEN + 1]),
            Err(DecryptionError::Mac(_))
        );
    }

    #[test]
    fn decrypt_pickle_empty_input() {
        let cipher = Cipher::new(&[1u8; 32]);

        assert_matches!(cipher.decrypt_pickle(&[]), Err(DecryptionError::MacMissing));
    }

    #[test]
    fn pickle_round_trip() {
        let cipher = Cipher::new(&[1u8; 32]);
        let plaintexts: [&[u8]; 4] = [b"", b"hello", &[7u8; 16], &[3u8; 33]];

        for plaintext in plaintexts {
            let ciphertext = cipher.encrypt_pickle(plaintext);

            // PKCS#7 always adds at least one padding byte, rounding up to the
            // next whole AES block; the truncated MAC follows.
            let expected_len =
                (plaintext.len() / AES_BLOCK_LEN + 1) * AES_BLOCK_LEN + Mac::TRUNCATED_LEN;
            assert_eq!(ciphertext.len(), expected_len);

            assert_matches!(cipher.decrypt_pickle(&ciphertext), Ok(decrypted));
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    fn decrypt_pickle_rejects_tampering() {
        let cipher = Cipher::new(&[1u8; 32]);
        let ciphertext = cipher.encrypt_pickle(b"It's a secret to everybody");

        // A flipped bit in the encrypted body must fail MAC verification.
        let mut tampered = ciphertext.clone();
        tampered[0] ^= 1;
        assert_matches!(cipher.decrypt_pickle(&tampered), Err(DecryptionError::Mac(_)));

        // A flipped bit in the truncated MAC itself must also be rejected.
        let mut tampered = ciphertext.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 1;
        assert_matches!(cipher.decrypt_pickle(&tampered), Err(DecryptionError::Mac(_)));

        // A cipher built from a different key must not authenticate it.
        let other = Cipher::new(&[2u8; 32]);
        assert_matches!(other.decrypt_pickle(&ciphertext), Err(DecryptionError::Mac(_)));
    }
}