use alloc::boxed::Box;
use alloc::vec::Vec;
use core::fmt;
use core::fmt::{Debug, Formatter};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::time::Duration;
use aws_lc_rs::cipher::{
AES_256, AES_256_KEY_LEN, AES_CBC_IV_LEN, DecryptionContext, PaddedBlockDecryptingKey,
PaddedBlockEncryptingKey, UnboundCipherKey,
};
use aws_lc_rs::rand::{SecureRandom, SystemRandom};
use aws_lc_rs::{hmac, iv};
use rustls::crypto::{GetRandomFailed, TicketProducer};
use rustls::error::Error;
use super::unspecified_err;
/// Session-ticket producer that encrypts with AES-256-CBC (PKCS#7 padding)
/// and authenticates with HMAC-SHA256.
///
/// Tickets are laid out as `key_name || IV || encrypted_state || HMAC tag`.
pub(super) struct Rfc5077Ticketer {
/// AES-256-CBC/PKCS#7 key used by `encrypt`.
aes_encrypt_key: PaddedBlockEncryptingKey,
/// Decrypting counterpart of `aes_encrypt_key`, built from the same raw key bytes.
aes_decrypt_key: PaddedBlockDecryptingKey,
/// HMAC-SHA256 key that authenticates the key name, IV, length and encrypted state.
hmac_key: hmac::Key,
/// Random 16-byte identifier written at the front of every ticket.
key_name: [u8; 16],
/// Length of the longest ticket `encrypt` has produced so far; `decrypt`
/// rejects anything longer before doing any cryptographic work.
maximum_ciphertext_len: AtomicUsize,
}
impl Rfc5077Ticketer {
    /// Builds a ticketer with freshly generated secrets.
    ///
    /// A random AES-256 key (shared by the encrypting and decrypting halves), a
    /// random HMAC-SHA256 key and a random 16-byte key name are drawn from
    /// `SystemRandom`. The result is returned as a boxed `TicketProducer`.
    ///
    /// # Errors
    ///
    /// Returns the error converted from `GetRandomFailed` when the random source
    /// cannot fill a buffer, or whatever `unspecified_err` produces when one of
    /// the keys cannot be constructed.
    #[expect(clippy::new_ret_no_self)]
    pub(super) fn new() -> Result<Box<dyn TicketProducer>, Error> {
        let rng = SystemRandom::new();

        let mut raw_aes_key = [0u8; AES_256_KEY_LEN];
        rng.fill(&mut raw_aes_key)
            .map_err(|_| GetRandomFailed)?;

        // Each padded key consumes its own `UnboundCipherKey`, so one is built
        // per direction from the same raw bytes.
        let unbound_aes_key =
            || UnboundCipherKey::new(&AES_256, &raw_aes_key[..]).map_err(unspecified_err);
        let aes_encrypt_key =
            PaddedBlockEncryptingKey::cbc_pkcs7(unbound_aes_key()?).map_err(unspecified_err)?;
        let aes_decrypt_key =
            PaddedBlockDecryptingKey::cbc_pkcs7(unbound_aes_key()?).map_err(unspecified_err)?;

        let hmac_key = hmac::Key::generate(hmac::HMAC_SHA256, &rng).map_err(unspecified_err)?;

        let mut key_name = [0u8; 16];
        rng.fill(&mut key_name)
            .map_err(|_| GetRandomFailed)?;

        let ticketer = Self {
            aes_encrypt_key,
            aes_decrypt_key,
            hmac_key,
            key_name,
            // Nothing has been issued yet, so every ticket is rejected until the
            // first successful `encrypt`.
            maximum_ciphertext_len: AtomicUsize::new(0),
        };
        Ok(Box::new(ticketer))
    }
}
impl TicketProducer for Rfc5077Ticketer {
    /// Encrypts `message` and returns the ticket
    /// `key_name || IV || encrypted_state || HMAC tag`.
    ///
    /// The tag is computed over
    /// `key_name || IV || u16 big-endian length of encrypted_state || encrypted_state`.
    ///
    /// Returns `None` if encryption fails, if the IV cannot be read back from the
    /// context returned by the cipher, or if the encrypted state is too long for
    /// its length to fit in a `u16`.
    fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
        // The buffer is encrypted in place; the returned context carries the IV
        // that `decrypt` will need.
        let mut encrypted_state = Vec::from(message);
        let dec_ctx = self
            .aes_encrypt_key
            .encrypt(&mut encrypted_state)
            .ok()?;
        let iv: &[u8] = (&dec_ctx).try_into().ok()?;

        // Assemble the data covered by the HMAC tag.
        let mut hmac_data =
            Vec::with_capacity(self.key_name.len() + iv.len() + 2 + encrypted_state.len());
        hmac_data.extend(&self.key_name);
        hmac_data.extend(iv);
        hmac_data.extend(
            u16::try_from(encrypted_state.len())
                .ok()?
                .to_be_bytes(),
        );
        hmac_data.extend(&encrypted_state);
        let tag = hmac::sign(&self.hmac_key, &hmac_data);
        let tag = tag.as_ref();

        // Assemble the ticket itself; note that the length prefix is only part
        // of the authenticated data, not of the ticket.
        let mut ciphertext =
            Vec::with_capacity(self.key_name.len() + iv.len() + encrypted_state.len() + tag.len());
        ciphertext.extend(self.key_name);
        ciphertext.extend(iv);
        ciphertext.extend(encrypted_state);
        ciphertext.extend(tag);

        // Remember the longest ticket issued so far so that `decrypt` can reject
        // anything bigger without doing any work.
        self.maximum_ciphertext_len
            .fetch_max(ciphertext.len(), Ordering::SeqCst);
        Some(ciphertext)
    }

    /// Authenticates and decrypts a ticket produced by `encrypt`.
    ///
    /// Returns `None` when the ticket is longer than any ticket this instance
    /// has issued, when it is too short to contain a key name, IV and HMAC tag,
    /// when the tag does not verify, or when AES decryption fails.
    fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
        if ciphertext.len()
            > self
                .maximum_ciphertext_len
                .load(Ordering::SeqCst)
        {
            return None;
        }

        // Peel the fixed-size fields off the front of the ticket.
        let (alleged_key_name, ciphertext) = ciphertext.split_at_checked(self.key_name.len())?;
        let (iv, ciphertext) = ciphertext.split_at_checked(AES_CBC_IV_LEN)?;

        // What remains is the encrypted state followed by the tag. A remainder
        // shorter than the tag must be rejected here: an unchecked subtraction
        // would underflow and panic when overflow checks are enabled.
        let tag_len = self
            .hmac_key
            .algorithm()
            .digest_algorithm()
            .output_len();
        let enc_state_len = ciphertext.len().checked_sub(tag_len)?;
        let (enc_state, mac) = ciphertext.split_at_checked(enc_state_len)?;

        // Rebuild the authenticated data exactly as `encrypt` did and verify the
        // tag before touching the ciphertext.
        let mut hmac_data =
            Vec::with_capacity(alleged_key_name.len() + iv.len() + 2 + enc_state.len());
        hmac_data.extend(alleged_key_name);
        hmac_data.extend(iv);
        hmac_data.extend(
            u16::try_from(enc_state.len())
                .ok()?
                .to_be_bytes(),
        );
        hmac_data.extend(enc_state);
        hmac::verify(&self.hmac_key, &hmac_data, mac).ok()?;

        // The tag is valid: decrypt a copy of the encrypted state in place.
        let iv = iv::FixedLength::try_from(iv).ok()?;
        let dec_context = DecryptionContext::Iv128(iv);
        let mut out = Vec::from(enc_state);
        let plaintext = self
            .aes_decrypt_key
            .decrypt(&mut out, dec_context)
            .ok()?;
        Some(plaintext.into())
    }

    /// Always returns `Duration::ZERO`.
    ///
    /// NOTE(review): presumably the real ticket lifetime is decided by whatever
    /// wraps this producer; confirm against the caller.
    fn lifetime(&self) -> Duration {
        Duration::ZERO
    }
}
impl Debug for Rfc5077Ticketer {
    /// Prints only the type name followed by `{ .. }`; no field (and therefore no
    /// key material) is ever written to the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut opaque = f.debug_struct("Rfc5077Ticketer");
        opaque.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use rustls::crypto::TicketerFactory;

    use crate::AwsLcRs;

    /// A ticket decrypts back to the message it was created from.
    #[test]
    fn basic_pairwise_test() {
        let t = AwsLcRs.ticketer().unwrap();
        let cipher = t.encrypt(b"hello world").unwrap();
        let plain = t.decrypt(&cipher).unwrap();
        assert_eq!(plain, b"hello world");
    }

    /// Nothing can be decrypted before the first ticket has been issued.
    #[test]
    fn refuses_decrypt_before_encrypt() {
        let t = AwsLcRs.ticketer().unwrap();
        assert_eq!(t.decrypt(b"hello"), None);
    }

    /// Input longer than the longest issued ticket is rejected.
    #[test]
    fn refuses_decrypt_larger_than_largest_encryption() {
        let t = AwsLcRs.ticketer().unwrap();
        let mut cipher = t.encrypt(b"hello world").unwrap();
        assert_eq!(t.decrypt(&cipher), Some(b"hello world".to_vec()));

        cipher.push(0);
        assert_eq!(t.decrypt(&cipher), None);
    }

    /// Flipping a bit anywhere in a valid ticket must make decryption fail.
    #[test]
    fn refuses_decrypt_of_tampered_ticket() {
        let t = AwsLcRs.ticketer().unwrap();
        let cipher = t.encrypt(b"hello world").unwrap();
        assert_eq!(t.decrypt(&cipher), Some(b"hello world".to_vec()));

        for index in 0..cipher.len() {
            let mut tampered = cipher.clone();
            tampered[index] ^= 0x01;
            assert_eq!(
                t.decrypt(&tampered),
                None,
                "modification of byte {index} was not detected"
            );
        }
    }

    /// The raw ticketer has an opaque `Debug` output and reports a zero lifetime.
    #[test]
    fn rfc5077ticketer_is_debug_and_producestickets() {
        use alloc::format;

        use super::*;

        let t = Rfc5077Ticketer::new().unwrap();
        assert_eq!(format!("{t:?}"), "Rfc5077Ticketer { .. }");
        assert_eq!(t.lifetime(), Duration::ZERO);
    }
}