use super::aead::Aead;
use super::suite::AeadAlg;
use crate::tls::{ContentType, Error};
use alloc::vec::Vec;
/// Maximum number of records (2^23) that the stateful TLS paths of
/// `RecordCrypter12` will process under one key. Once the running sequence
/// counter reaches this value, `encrypt` and `decrypt` return
/// `Error::TooManyRecords` instead of touching the AEAD.
// NOTE(review): the rationale for choosing 2^23 is not visible in this file —
// presumably a conservative AEAD usage limit; confirm against the design notes.
#[allow(dead_code)]
const MAX_RECORDS_PER_KEY: u64 = 1 << 23;
/// One direction of AEAD record protection using an explicit per-record nonce.
///
/// Holds the AEAD instance, the 4-byte salt that forms the fixed prefix of
/// every 12-byte nonce, and the running record sequence number used by the
/// stateful `encrypt` / `decrypt` methods.
#[allow(dead_code)]
pub(crate) struct RecordCrypter12 {
/// AEAD primitive built from the algorithm and key passed to `new`.
aead: Aead,
/// Fixed 4-byte nonce prefix; the remaining 8 nonce bytes are per-record.
salt: [u8; 4],
/// Number of records processed so far by `encrypt` / `decrypt`; starts at 0
/// and is incremented only after a successful operation.
seq: u64,
}
impl RecordCrypter12 {
#[allow(dead_code)]
pub(crate) fn new(alg: AeadAlg, key: &[u8], salt: [u8; 4]) -> Self {
RecordCrypter12 {
aead: Aead::from_key(alg, key),
salt,
seq: 0,
}
}
#[cfg(test)]
pub(crate) fn seq(&self) -> u64 {
self.seq
}
fn aead_nonce(&self, explicit_nonce: &[u8; 8]) -> [u8; 12] {
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.salt);
nonce[4..].copy_from_slice(explicit_nonce);
nonce
}
fn aad(seq: u64, content_type: ContentType, plaintext_len: u16) -> [u8; 13] {
let mut aad = [0u8; 13];
aad[..8].copy_from_slice(&seq.to_be_bytes());
aad[8] = content_type.as_u8();
aad[9] = 0x03;
aad[10] = 0x03;
aad[11..13].copy_from_slice(&plaintext_len.to_be_bytes());
aad
}
#[allow(dead_code)]
fn aad_dtls(seq_combined: u64, content_type: ContentType, plaintext_len: u16) -> [u8; 13] {
let mut aad = [0u8; 13];
aad[..8].copy_from_slice(&seq_combined.to_be_bytes());
aad[8] = content_type.as_u8();
aad[9] = 0xfe;
aad[10] = 0xfd;
aad[11..13].copy_from_slice(&plaintext_len.to_be_bytes());
aad
}
#[allow(dead_code)]
pub(crate) fn encrypt_dtls(
&self,
seq_combined: u64,
content_type: ContentType,
payload: &[u8],
) -> Result<Vec<u8>, Error> {
if payload.len() > (1usize << 14) {
return Err(Error::RecordOverflow);
}
let explicit_nonce = seq_combined.to_be_bytes();
let nonce = self.aead_nonce(&explicit_nonce);
let aad = Self::aad_dtls(seq_combined, content_type, payload.len() as u16);
let mut buf = payload.to_vec();
let tag = self.aead.encrypt(&nonce, &aad, &mut buf);
let mut out = Vec::with_capacity(8 + buf.len() + 16);
out.extend_from_slice(&explicit_nonce);
out.extend_from_slice(&buf);
out.extend_from_slice(&tag);
Ok(out)
}
#[allow(dead_code)]
pub(crate) fn decrypt_dtls(
&self,
seq_combined: u64,
content_type: ContentType,
fragment: &[u8],
) -> Result<Vec<u8>, Error> {
if fragment.len() < 8 + 16 {
return Err(Error::Decode);
}
let mut explicit_nonce = [0u8; 8];
explicit_nonce.copy_from_slice(&fragment[..8]);
let body = &fragment[8..];
let (ct_bytes, tag_bytes) = body.split_at(body.len() - 16);
let mut tag = [0u8; 16];
tag.copy_from_slice(tag_bytes);
let plaintext_len = ct_bytes.len();
if plaintext_len > (1usize << 14) {
return Err(Error::RecordOverflow);
}
let aad = Self::aad_dtls(seq_combined, content_type, plaintext_len as u16);
let nonce = self.aead_nonce(&explicit_nonce);
let mut buf = ct_bytes.to_vec();
if !self.aead.decrypt(&nonce, &aad, &mut buf, &tag) {
return Err(Error::BadRecordMac);
}
Ok(buf)
}
#[allow(dead_code)]
pub(crate) fn encrypt(
&mut self,
content_type: ContentType,
payload: &[u8],
) -> Result<Vec<u8>, Error> {
if payload.len() > (1usize << 14) {
return Err(Error::RecordOverflow);
}
if self.seq >= MAX_RECORDS_PER_KEY {
return Err(Error::TooManyRecords);
}
let explicit_nonce = self.seq.to_be_bytes();
let nonce = self.aead_nonce(&explicit_nonce);
let aad = Self::aad(self.seq, content_type, payload.len() as u16);
let mut buf = payload.to_vec();
let tag = self.aead.encrypt(&nonce, &aad, &mut buf);
let mut out = Vec::with_capacity(8 + buf.len() + 16);
out.extend_from_slice(&explicit_nonce);
out.extend_from_slice(&buf);
out.extend_from_slice(&tag);
self.seq += 1;
Ok(out)
}
#[allow(dead_code)]
pub(crate) fn decrypt(
&mut self,
record_header: &[u8; 5],
fragment: &[u8],
) -> Result<(ContentType, Vec<u8>), Error> {
if fragment.len() < 8 + 16 {
return Err(Error::Decode);
}
if self.seq >= MAX_RECORDS_PER_KEY {
return Err(Error::TooManyRecords);
}
let mut explicit_nonce = [0u8; 8];
explicit_nonce.copy_from_slice(&fragment[..8]);
let body = &fragment[8..];
let (ct_bytes, tag_bytes) = body.split_at(body.len() - 16);
let mut tag = [0u8; 16];
tag.copy_from_slice(tag_bytes);
let plaintext_len = ct_bytes.len();
if plaintext_len > (1usize << 14) {
return Err(Error::RecordOverflow);
}
let content_type = ContentType::from_u8(record_header[0]);
let aad = Self::aad(self.seq, content_type, plaintext_len as u16);
let nonce = self.aead_nonce(&explicit_nonce);
let mut buf = ct_bytes.to_vec();
if !self.aead.decrypt(&nonce, &aad, &mut buf, &tag) {
return Err(Error::BadRecordMac);
}
self.seq += 1;
Ok((content_type, buf))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two crypters built from identical parameters: one to seal, one to open.
    fn pair(alg: AeadAlg, key: &[u8], salt: [u8; 4]) -> (RecordCrypter12, RecordCrypter12) {
        let sender = RecordCrypter12::new(alg, key, salt);
        let receiver = RecordCrypter12::new(alg, key, salt);
        (sender, receiver)
    }

    /// Five-byte record header: ApplicationData, version `03 03`, and the
    /// fragment length in big-endian.
    fn application_data_header(fragment_len: usize) -> [u8; 5] {
        let [hi, lo] = (fragment_len as u16).to_be_bytes();
        [ContentType::ApplicationData.as_u8(), 0x03, 0x03, hi, lo]
    }

    #[test]
    fn round_trip_application_data() {
        let salt = [0xa1, 0xa2, 0xa3, 0xa4];
        let payload: Vec<u8> = (0..100u8).collect();
        let cases = [
            (AeadAlg::Aes128Gcm, 16usize),
            (AeadAlg::Aes256Gcm, 32),
            (AeadAlg::ChaCha20Poly1305, 32),
        ];
        for (alg, key_len) in cases {
            let key: Vec<u8> = (0..key_len as u8).collect();
            let (mut enc, mut dec) = pair(alg, &key, salt);

            let wire = enc.encrypt(ContentType::ApplicationData, &payload).unwrap();
            // Explicit nonce (8) + ciphertext + tag (16).
            assert_eq!(wire.len(), 8 + payload.len() + 16);

            let header = application_data_header(wire.len());
            let (ct, plain) = dec.decrypt(&header, &wire).unwrap();
            assert_eq!(ct, ContentType::ApplicationData);
            assert_eq!(plain, payload);
        }
    }

    #[test]
    fn tampering_is_rejected() {
        let salt = [0xa1, 0xa2, 0xa3, 0xa4];
        let key = alloc::vec![0x33u8; 16];
        let payload = alloc::vec![0x42u8; 100];

        // Byte to corrupt and the mask to apply: the first byte of the explicit
        // nonce, a byte inside the ciphertext, and (`None`) the final tag byte.
        let flips: [(Option<usize>, u8); 3] = [(Some(0), 0x01), (Some(20), 0x80), (None, 0x01)];
        for (target, mask) in flips {
            let (mut enc, mut dec) = pair(AeadAlg::Aes128Gcm, &key, salt);
            let mut wire = enc.encrypt(ContentType::ApplicationData, &payload).unwrap();
            let idx = target.unwrap_or(wire.len() - 1);
            wire[idx] ^= mask;

            let header = application_data_header(wire.len());
            let outcome = dec.decrypt(&header, &wire);
            assert!(matches!(outcome, Err(Error::BadRecordMac)));
        }
    }

    #[test]
    fn explicit_nonce_matches_seq_counter() {
        let key = alloc::vec![0u8; 16];
        let payload = alloc::vec![0u8; 4];
        let mut enc = RecordCrypter12::new(AeadAlg::Aes128Gcm, &key, [0; 4]);

        for expected_seq in 0u64..5 {
            assert_eq!(enc.seq(), expected_seq);
            let wire = enc.encrypt(ContentType::ApplicationData, &payload).unwrap();
            // The first eight wire bytes carry the sequence number, big-endian.
            assert_eq!(&wire[..8], &expected_seq.to_be_bytes()[..]);
        }
        assert_eq!(enc.seq(), 5);
    }

    #[test]
    fn short_fragment_rejected() {
        let key = alloc::vec![0u8; 16];
        let mut dec = RecordCrypter12::new(AeadAlg::Aes128Gcm, &key, [0; 4]);

        // Sixteen bytes cannot hold the 8-byte explicit nonce plus a 16-byte tag.
        let short = [0u8; 16];
        let header = application_data_header(short.len());
        let outcome = dec.decrypt(&header, &short);
        assert!(matches!(outcome, Err(Error::Decode)));
    }
}