use super::schedule::{HashAlg, Secret, traffic_key_iv};
use super::suite::AeadAlg;
use crate::cipher::{Aes128, Aes256, ChaCha20Poly1305, Gcm};
use crate::ct::{Choice, ConditionallySelectable, ConstantTimeEq};
use crate::tls::{ContentType, Error};
use alloc::vec::Vec;
/// A keyed AEAD cipher, with one variant per algorithm this module can instantiate.
///
/// Every variant is driven through the same `encrypt`/`decrypt` calls with a 12-byte
/// nonce and a 16-byte tag (see the `impl Aead` block).
pub(crate) enum Aead {
    /// AES-128 wrapped in GCM.
    Aes128(Gcm<Aes128>),
    /// AES-256 wrapped in GCM.
    Aes256(Gcm<Aes256>),
    /// ChaCha20-Poly1305.
    ChaCha20Poly1305(ChaCha20Poly1305),
}
impl Aead {
    /// Encrypts `buf` in place under `nonce`, authenticating `aad` as well, and returns
    /// the 16-byte authentication tag produced by the selected cipher.
    pub(crate) fn encrypt(&self, nonce: &[u8; 12], aad: &[u8], buf: &mut [u8]) -> [u8; 16] {
        match self {
            Self::Aes128(gcm) => gcm.encrypt(nonce, aad, buf),
            Self::Aes256(gcm) => gcm.encrypt(nonce, aad, buf),
            Self::ChaCha20Poly1305(chacha) => chacha.encrypt(nonce, aad, buf),
        }
    }

    /// Decrypts `buf` in place under `nonce`, checking `tag` over `aad` and the ciphertext.
    ///
    /// Returns `true` exactly when the underlying cipher's `decrypt` reports `Ok`.
    /// NOTE(review): what `buf` holds after a failed check depends on the cipher
    /// implementation; callers should treat it as garbage on `false`.
    pub(crate) fn decrypt(
        &self,
        nonce: &[u8; 12],
        aad: &[u8],
        buf: &mut [u8],
        tag: &[u8; 16],
    ) -> bool {
        let verified = match self {
            Self::Aes128(gcm) => gcm.decrypt(nonce, aad, buf, tag),
            Self::Aes256(gcm) => gcm.decrypt(nonce, aad, buf, tag),
            Self::ChaCha20Poly1305(chacha) => chacha.decrypt(nonce, aad, buf, tag),
        };
        verified.is_ok()
    }

    /// Builds the cipher for `alg` from the leading bytes of `key`.
    ///
    /// AES-128-GCM takes the first 16 bytes; AES-256-GCM and ChaCha20-Poly1305 take the
    /// first 32. Any trailing bytes are ignored.
    ///
    /// # Panics
    /// Panics if `key` is shorter than the length `alg` requires.
    // The lint is silenced so this constructor may exist without a caller.
    #[allow(dead_code)]
    pub(crate) fn from_key(alg: AeadAlg, key: &[u8]) -> Self {
        /// Copies the first `N` bytes of `key` into an owned array (panics if too short).
        fn take_key<const N: usize>(key: &[u8]) -> [u8; N] {
            let mut fixed = [0u8; N];
            fixed.copy_from_slice(&key[..N]);
            fixed
        }

        match alg {
            AeadAlg::Aes128Gcm => Self::Aes128(Gcm::new(Aes128::new(&take_key::<16>(key)))),
            AeadAlg::Aes256Gcm => Self::Aes256(Gcm::new(Aes256::new(&take_key::<32>(key)))),
            AeadAlg::ChaCha20Poly1305 => {
                Self::ChaCha20Poly1305(ChaCha20Poly1305::new(&take_key::<32>(key)))
            }
        }
    }
}
/// Number of records (2^23) a `RecordCrypter` will protect under one traffic key.
///
/// Once the sequence counter reaches this value, `RecordCrypter::next_nonce` returns
/// `Error::TooManyRecords` instead of producing another nonce. The bound sits well below
/// the RFC 8446 §5.5 guidance of 2^24.5 full-size records for AES-GCM.
const MAX_RECORDS_PER_KEY: u64 = 0x0080_0000;
/// Record protection state for one traffic key: the keyed cipher, the static IV, and the
/// running record sequence number.
pub(crate) struct RecordCrypter {
    /// Cipher keyed with the traffic key.
    aead: Aead,
    /// Static IV; each record's nonce is this IV with the sequence number XORed into
    /// its last eight bytes.
    iv: [u8; 12],
    /// Sequence number the next `encrypt`/`decrypt` call will consume.
    seq: u64,
}
impl RecordCrypter {
pub(crate) fn new(hash: HashAlg, alg: AeadAlg, key_len: usize, secret: &Secret) -> Self {
let (key, iv) = traffic_key_iv(hash, secret, key_len);
let aead = match alg {
AeadAlg::Aes128Gcm => {
let mut k = [0u8; 16];
k.copy_from_slice(&key[..16]);
Aead::Aes128(Gcm::new(Aes128::new(&k)))
}
AeadAlg::Aes256Gcm => {
let mut k = [0u8; 32];
k.copy_from_slice(&key[..32]);
Aead::Aes256(Gcm::new(Aes256::new(&k)))
}
AeadAlg::ChaCha20Poly1305 => {
let mut k = [0u8; 32];
k.copy_from_slice(&key[..32]);
Aead::ChaCha20Poly1305(ChaCha20Poly1305::new(&k))
}
};
RecordCrypter { aead, iv, seq: 0 }
}
fn next_nonce(&mut self) -> Result<[u8; 12], Error> {
if self.seq >= MAX_RECORDS_PER_KEY {
return Err(Error::TooManyRecords);
}
let mut nonce = self.iv;
let seq = self.seq.to_be_bytes();
for i in 0..8 {
nonce[4 + i] ^= seq[i];
}
self.seq += 1;
Ok(nonce)
}
pub(crate) fn encrypt(
&mut self,
content_type: ContentType,
content: &[u8],
) -> Result<Vec<u8>, Error> {
if content.len() > (1usize << 14) {
return Err(Error::RecordOverflow);
}
let fragment_len = content.len() + 1 + 16; let mut header = [0u8; 5];
header[0] = ContentType::ApplicationData.as_u8();
header[1] = 0x03;
header[2] = 0x03;
header[3..5].copy_from_slice(&(fragment_len as u16).to_be_bytes());
let mut inner = Vec::with_capacity(content.len() + 1);
inner.extend_from_slice(content);
inner.push(content_type.as_u8());
let nonce = self.next_nonce()?;
let tag = self.aead.encrypt(&nonce, &header, &mut inner);
let mut out = Vec::with_capacity(5 + fragment_len);
out.extend_from_slice(&header);
out.extend_from_slice(&inner);
out.extend_from_slice(&tag);
Ok(out)
}
fn nonce_for(&self, seq: u64) -> [u8; 12] {
let mut nonce = self.iv;
let s = seq.to_be_bytes();
for i in 0..8 {
nonce[4 + i] ^= s[i];
}
nonce
}
pub(crate) fn encrypt_raw(
&mut self,
seq: u64,
aad: &[u8],
buf: &mut [u8],
) -> Result<[u8; 16], Error> {
let nonce = self.nonce_for(seq);
Ok(self.aead.encrypt(&nonce, aad, buf))
}
pub(crate) fn decrypt_raw(
&mut self,
seq: u64,
aad: &[u8],
buf: &mut [u8],
tag: &[u8; 16],
) -> Result<(), Error> {
let nonce = self.nonce_for(seq);
if !self.aead.decrypt(&nonce, aad, buf, tag) {
return Err(Error::BadRecordMac);
}
Ok(())
}
pub(crate) fn decrypt(
&mut self,
header: &[u8; 5],
fragment: &[u8],
) -> Result<(ContentType, Vec<u8>), Error> {
if fragment.len() < 16 {
return Err(Error::Decode);
}
let (ct, tag_bytes) = fragment.split_at(fragment.len() - 16);
let mut tag = [0u8; 16];
tag.copy_from_slice(tag_bytes);
let mut buf = ct.to_vec();
let nonce = self.next_nonce()?;
if !self.aead.decrypt(&nonce, header, &mut buf, &tag) {
return Err(Error::BadRecordMac);
}
let (content_type_byte, end) = ct_find_last_nonzero(&buf)?;
let content_type = ContentType::from_u8(content_type_byte);
buf.truncate(end);
if buf.len() > (1usize << 14) {
return Err(Error::RecordOverflow);
}
Ok((content_type, buf))
}
}
fn ct_find_last_nonzero(buf: &[u8]) -> Result<(u8, usize), Error> {
if buf.is_empty() {
return Err(Error::PeerMisbehaved);
}
let mut found_any = Choice::from(0);
let mut cur_byte: u8 = 0;
let mut cur_end: usize = 0;
for (i, &b) in buf.iter().enumerate() {
let nonzero = !b.ct_eq(&0u8);
cur_byte = u8::conditional_select(&b, &cur_byte, nonzero);
cur_end = usize::conditional_select(&(i + 1), &cur_end, nonzero);
found_any |= nonzero;
}
if !bool::from(found_any) {
return Err(Error::PeerMisbehaved);
}
Ok((cur_byte, cur_end - 1))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::from_hex_vec;

    /// Server handshake traffic secret that keys the RFC 8448 fixtures below.
    fn server_hs_secret() -> Secret {
        let raw = from_hex_vec(
            "b67b7d690cc16c4e75e54213cb2d37b4e9c912bcded9105d42befd59d391ad38",
        );
        Secret::new(&raw)
    }

    /// Plaintext of the RFC 8448 server flight fixture.
    fn flight_payload() -> alloc::vec::Vec<u8> {
        from_hex_vec(include_str!(
            "../../../testdata/rfc8448_server_flight_payload.hex"
        ))
    }

    /// Protected record bytes of the same flight, starting with the 5-byte header.
    fn flight_record() -> alloc::vec::Vec<u8> {
        from_hex_vec(include_str!(
            "../../../testdata/rfc8448_server_flight_record.hex"
        ))
    }

    /// AES-128-GCM / SHA-256 crypter keyed from `server_hs_secret`, at sequence 0.
    fn fresh_crypter() -> RecordCrypter {
        RecordCrypter::new(HashAlg::Sha256, AeadAlg::Aes128Gcm, 16, &server_hs_secret())
    }

    /// Splits serialized record bytes into the 5-byte header and the remaining fragment.
    fn split_record(record: &[u8]) -> ([u8; 5], &[u8]) {
        let (head, body) = record.split_at(5);
        let mut header = [0u8; 5];
        header.copy_from_slice(head);
        (header, body)
    }

    #[test]
    fn rfc8448_server_flight_encrypt() {
        let mut crypter = fresh_crypter();
        let sealed = crypter
            .encrypt(ContentType::Handshake, &flight_payload())
            .unwrap();
        assert_eq!(sealed, flight_record());
    }

    #[test]
    fn rfc8448_server_flight_decrypt() {
        let record = flight_record();
        let (header, fragment) = split_record(&record);
        let (content_type, content) = fresh_crypter().decrypt(&header, fragment).unwrap();
        assert_eq!(content_type, ContentType::Handshake);
        assert_eq!(content, flight_payload());
    }

    #[test]
    fn tampered_tag_is_rejected() {
        let mut record = flight_record();
        *record.last_mut().unwrap() ^= 0x01;
        let (header, fragment) = split_record(&record);
        let outcome = fresh_crypter().decrypt(&header, fragment);
        assert!(matches!(outcome, Err(Error::BadRecordMac)));
    }

    #[test]
    fn ct_padding_strip_no_padding() {
        let found = super::ct_find_last_nonzero(&[0xAA, 0xBB, 0xCC, 22u8]).ok();
        assert_eq!(found, Some((22, 3)));
    }

    #[test]
    fn ct_padding_strip_with_padding() {
        let mut buf = alloc::vec![0x11, 0x22, 23u8];
        buf.resize(buf.len() + 10, 0);
        assert_eq!(super::ct_find_last_nonzero(&buf).ok(), Some((23, 2)));
    }

    #[test]
    fn ct_padding_strip_all_zero_signals_error() {
        let outcome = super::ct_find_last_nonzero(&[0u8; 32]);
        assert!(matches!(outcome, Err(Error::PeerMisbehaved)));
    }

    #[test]
    fn ct_padding_strip_empty_signals_error() {
        let outcome = super::ct_find_last_nonzero(&[]);
        assert!(matches!(outcome, Err(Error::PeerMisbehaved)));
    }

    #[test]
    fn ct_padding_strip_zero_byte_in_content_still_finds_last_nonzero() {
        let found = super::ct_find_last_nonzero(&[0xAA, 0u8, 0xBB, 0u8, 23u8, 0u8, 0u8]).ok();
        assert_eq!(found, Some((23, 4)));
    }
}