#![allow(dead_code)]
use crate::cipher::{Aes128, Aes256, BlockCipher};
use crate::tls::Error;
use alloc::vec::Vec;
/// Value of the three most significant bits (`001`) that marks the first byte
/// of a unified record header.
const UNIFIED_HDR_PREFIX: u8 = 0b001 << 5;
/// Selects the three prefix bits of the first header byte.
const UNIFIED_HDR_PREFIX_MASK: u8 = 0b111 << 5;
/// Connection-ID flag (bit 4). Both decoders in this file reject records that set it.
const FLAG_CID: u8 = 1 << 4;
/// Sequence-number width flag (bit 3): set for a 2-byte field, clear for 1 byte.
const FLAG_SEQ_16: u8 = 1 << 3;
/// Length flag (bit 2): set when a 2-byte big-endian length field is present.
const FLAG_LENGTH: u8 = 1 << 2;
/// Mask for the two least significant bits, which carry the low bits of the epoch.
const FLAG_EPOCH_LO2: u8 = 0b11;
/// Largest value representable in 48 bits; the upper bound for a record sequence number.
const SEQ_MASK_48: u64 = u64::MAX >> 16;
/// Fields parsed from a unified record header by `decode_record`.
///
/// Derives `PartialEq`/`Eq` so two parsed headers can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct UnifiedHeader {
    /// Set to `true` by `decode_record` for every header it accepts.
    pub(crate) is_ciphertext: bool,
    /// Two least significant bits of the first header byte (low bits of the epoch).
    pub(crate) epoch_low2: u8,
    /// Truncated sequence number after the sequence-number mask has been removed.
    /// Only the low 8 bits are populated when `seq_is_16bit` is `false`.
    pub(crate) seq_low: u16,
    /// `true` when the on-wire sequence-number field is two bytes wide, `false` for one byte.
    pub(crate) seq_is_16bit: bool,
    /// `true` when the header carries an explicit 2-byte length field.
    pub(crate) has_length: bool,
    /// State of the connection-ID flag. `decode_record` rejects records with the
    /// flag set, so this is `false` in every header it returns.
    pub(crate) has_cid: bool,
    /// Total header size in bytes: 1 flag byte + sequence-number bytes + length bytes.
    pub(crate) header_len: usize,
}
/// Rebuilds a full record sequence number from its truncated on-wire form.
///
/// `seq_low` holds the low 16 bits (`seq_is_16bit == true`) or the low 8 bits
/// (`seq_is_16bit == false`; any higher bits of `seq_low` are ignored) of the
/// sender's sequence number. `expected_seq` is the value the receiver expects next.
///
/// The received low bits are grafted onto the high bits of `expected_seq`; that
/// value and its neighbours one window (2^16 or 2^8) below and above are then
/// compared, and the one closest to `expected_seq` is returned. Neighbours that
/// would underflow, overflow, or exceed 48 bits are not considered. The grafted
/// value itself is the fallback and is not range-checked.
pub(crate) fn reconstruct_seq(seq_low: u16, seq_is_16bit: bool, expected_seq: u64) -> u64 {
    let width: u32 = if seq_is_16bit { 16 } else { 8 };
    let window = 1u64 << width;
    let low_mask = window - 1;

    // Received low bits combined with the expected high bits.
    let anchor = (expected_seq & !low_mask) | (u64::from(seq_low) & low_mask);

    // The anchor is the starting choice; a neighbour replaces it only when it
    // is strictly closer, so ties keep the earlier choice.
    let neighbours = [anchor.checked_sub(window), anchor.checked_add(window)];
    let start = (anchor, abs_diff(anchor, expected_seq));
    let (closest, _) = neighbours
        .iter()
        .flatten()
        .copied()
        .filter(|candidate| *candidate <= SEQ_MASK_48)
        .fold(start, |(best, best_dist), candidate| {
            let dist = abs_diff(candidate, expected_seq);
            if dist < best_dist {
                (candidate, dist)
            } else {
                (best, best_dist)
            }
        });
    closest
}
/// Returns the absolute distance `|a - b|` between two unsigned values
/// without overflow.
#[inline]
fn abs_diff(a: u64, b: u64) -> u64 {
    u64::abs_diff(a, b)
}
/// Appends one record with a unified header to `out`.
///
/// Layout written, in order:
/// 1. one flag byte: the fixed prefix, the 16-bit-sequence flag, the length
///    flag, and the two low bits of `epoch`;
/// 2. the low 16 bits (`seq_is_16bit`) or low 8 bits of `seq`, XORed with `sn_mask`;
/// 3. unless `omit_length`, the payload length as a big-endian `u16`;
/// 4. `encrypted_payload` verbatim.
///
/// The connection-ID flag is never set. Existing contents of `out` are kept.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics when
/// * `seq` does not fit in 48 bits,
/// * `sn_mask.len()` is not 2 (`seq_is_16bit`) or 1 (otherwise), or
/// * a length field is requested and `encrypted_payload` is longer than
///   `u16::MAX` bytes (the field would otherwise be silently truncated).
///
/// In all builds, indexing panics if `sn_mask` is shorter than the
/// sequence-number field. Callers must uphold the preconditions above:
/// release builds do not check them and would emit a malformed record.
pub(crate) fn encode_record(
    out: &mut Vec<u8>,
    epoch: u16,
    seq: u64,
    seq_is_16bit: bool,
    omit_length: bool,
    encrypted_payload: &[u8],
    sn_mask: &[u8],
) {
    debug_assert!(seq <= SEQ_MASK_48, "DTLS seq must fit in 48 bits");
    let expected_mask_len = if seq_is_16bit { 2 } else { 1 };
    debug_assert_eq!(
        sn_mask.len(),
        expected_mask_len,
        "sn_mask length must match seq_is_16bit",
    );
    // The explicit length field is only 16 bits wide; a longer payload would
    // make the declared length disagree with the bytes that follow.
    debug_assert!(
        omit_length || encrypted_payload.len() <= usize::from(u16::MAX),
        "encrypted payload must fit the 16-bit length field",
    );

    let mut first = UNIFIED_HDR_PREFIX;
    if seq_is_16bit {
        first |= FLAG_SEQ_16;
    }
    if !omit_length {
        first |= FLAG_LENGTH;
    }
    // Only the two low epoch bits travel in the header.
    first |= (epoch as u8) & FLAG_EPOCH_LO2;
    out.push(first);

    // Intentional truncation: only the low bits of the sequence number are sent.
    if seq_is_16bit {
        let seq_bytes = (seq as u16).to_be_bytes();
        out.push(seq_bytes[0] ^ sn_mask[0]);
        out.push(seq_bytes[1] ^ sn_mask[1]);
    } else {
        out.push((seq as u8) ^ sn_mask[0]);
    }

    if !omit_length {
        // Lossless under the precondition asserted above.
        out.extend_from_slice(&(encrypted_payload.len() as u16).to_be_bytes());
    }
    out.extend_from_slice(encrypted_payload);
}
pub(crate) fn decode_record<'a>(
buf: &'a [u8],
sn_mask: &[u8],
) -> Result<(UnifiedHeader, &'a [u8]), Error> {
if buf.is_empty() {
return Err(Error::Decode);
}
let first = buf[0];
if (first & UNIFIED_HDR_PREFIX_MASK) != UNIFIED_HDR_PREFIX {
return Err(Error::Decode);
}
let has_cid = (first & FLAG_CID) != 0;
if has_cid {
return Err(Error::IllegalParameter);
}
let seq_is_16bit = (first & FLAG_SEQ_16) != 0;
let has_length = (first & FLAG_LENGTH) != 0;
let epoch_low2 = first & FLAG_EPOCH_LO2;
let seq_bytes = if seq_is_16bit { 2usize } else { 1usize };
if sn_mask.len() != seq_bytes {
return Err(Error::Decode);
}
let len_bytes = if has_length { 2usize } else { 0usize };
let header_len = 1 + seq_bytes + len_bytes;
if buf.len() < header_len {
return Err(Error::Decode);
}
let seq_low = if seq_is_16bit {
let hi = buf[1] ^ sn_mask[0];
let lo = buf[2] ^ sn_mask[1];
((hi as u16) << 8) | (lo as u16)
} else {
(buf[1] ^ sn_mask[0]) as u16
};
let body_start = header_len;
let body_end = if has_length {
let off = 1 + seq_bytes;
let len = u16::from_be_bytes([buf[off], buf[off + 1]]) as usize;
let end = body_start + len;
if end > buf.len() {
return Err(Error::Decode);
}
end
} else {
buf.len()
};
Ok((
UnifiedHeader {
is_ciphertext: true,
epoch_low2,
seq_low,
seq_is_16bit,
has_length,
has_cid,
header_len,
},
&buf[body_start..body_end],
))
}
pub(crate) fn peek_header_layout(buf: &[u8]) -> Result<(usize, usize), Error> {
if buf.is_empty() {
return Err(Error::Decode);
}
let first = buf[0];
if (first & UNIFIED_HDR_PREFIX_MASK) != UNIFIED_HDR_PREFIX {
return Err(Error::Decode);
}
if (first & FLAG_CID) != 0 {
return Err(Error::IllegalParameter);
}
let seq_is_16bit = (first & FLAG_SEQ_16) != 0;
let has_length = (first & FLAG_LENGTH) != 0;
let seq_bytes = if seq_is_16bit { 2 } else { 1 };
let len_bytes = if has_length { 2 } else { 0 };
let header_len = 1 + seq_bytes + len_bytes;
if buf.len() < header_len {
return Err(Error::Decode);
}
let body_len = if has_length {
let off = 1 + seq_bytes;
let len = u16::from_be_bytes([buf[off], buf[off + 1]]) as usize;
if header_len + len > buf.len() {
return Err(Error::Decode);
}
len
} else {
buf.len() - header_len
};
Ok((header_len, body_len))
}
/// Derives the two-byte sequence-number mask with `Aes128` keyed by `sn_key`.
///
/// The mask is computed from the leading bytes of `ciphertext`; see
/// `sn_mask_block` for how short inputs are handled.
pub(crate) fn sn_mask_aes128(sn_key: &[u8; 16], ciphertext: &[u8]) -> [u8; 2] {
    sn_mask_block(&Aes128::new(sn_key), ciphertext)
}
/// Derives the two-byte sequence-number mask with `Aes256` keyed by `sn_key`.
///
/// The mask is computed from the leading bytes of `ciphertext`; see
/// `sn_mask_block` for how short inputs are handled.
pub(crate) fn sn_mask_aes256(sn_key: &[u8; 32], ciphertext: &[u8]) -> [u8; 2] {
    sn_mask_block(&Aes256::new(sn_key), ciphertext)
}
/// Encrypts a single 16-byte block sampled from `ciphertext` with `cipher`
/// and returns the first two bytes of the result.
///
/// The sample is the first 16 bytes of `ciphertext`. When fewer than 16 bytes
/// are available, the remainder of the block is left as zeros.
#[inline]
fn sn_mask_block<C: BlockCipher>(cipher: &C, ciphertext: &[u8]) -> [u8; 2] {
    let mut sample = [0u8; 16];
    // `zip` stops at the shorter side, so at most 16 bytes are copied and a
    // short input leaves the tail zeroed.
    for (dst, src) in sample.iter_mut().zip(ciphertext) {
        *dst = *src;
    }
    cipher.encrypt_block(&mut sample);
    let [first, second, ..] = sample;
    [first, second]
}
// Unit tests for the unified-header codec: exact wire bytes produced by
// `encode_record`, round-trips through `decode_record`, rejection of malformed
// input, sequence-number reconstruction, and mask derivation vectors.
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
// 32-byte stand-in for an encrypted payload: the bytes 0x00..=0x1F.
fn dummy_ct() -> Vec<u8> {
(0u8..32).collect()
}
// 16-bit sequence number with an explicit length field: pins the flag byte,
// the sequence-number bytes, the length bytes, and the decoded header fields.
#[test]
fn header_roundtrip_16bit_seq_with_length() {
let mut out = Vec::new();
let mask = [0u8; 2];
let ct = dummy_ct();
encode_record(&mut out, 1, 42, true, false, &ct, &mask);
assert_eq!(out[0], 0b0010_1101);
assert_eq!(&out[1..3], &[0x00, 0x2A]);
assert_eq!(&out[3..5], &[0x00, 0x20]);
assert_eq!(&out[5..], ct.as_slice());
let (hdr, body) = decode_record(&out, &mask).unwrap();
assert!(hdr.is_ciphertext);
assert_eq!(hdr.epoch_low2, 0b01);
assert!(hdr.seq_is_16bit);
assert_eq!(hdr.seq_low, 42);
assert!(hdr.has_length);
assert!(!hdr.has_cid);
assert_eq!(hdr.header_len, 1 + 2 + 2);
assert_eq!(body, ct.as_slice());
}
// 8-bit sequence number with an explicit length field; uses a one-byte mask.
#[test]
fn header_roundtrip_8bit_seq() {
let mut out = Vec::new();
let mask = [0u8; 1];
let ct = dummy_ct();
encode_record(&mut out, 2, 0x0055, false, false, &ct, &mask);
assert_eq!(out[0], 0b0010_0110);
assert_eq!(out[1], 0x55);
assert_eq!(&out[2..4], &[0x00, 0x20]);
assert_eq!(&out[4..], ct.as_slice());
let (hdr, body) = decode_record(&out, &mask).unwrap();
assert_eq!(hdr.epoch_low2, 0b10);
assert!(!hdr.seq_is_16bit);
assert_eq!(hdr.seq_low, 0x55);
assert!(hdr.has_length);
assert_eq!(hdr.header_len, 1 + 1 + 2);
assert_eq!(body, ct.as_slice());
}
// Length field omitted: the body must extend to the end of the buffer.
#[test]
fn header_roundtrip_length_omitted() {
let mut out = Vec::new();
let mask = [0u8; 2];
let ct = dummy_ct();
encode_record(&mut out, 0, 7, true, true, &ct, &mask);
assert_eq!(out[0], 0b0010_1000);
assert_eq!(&out[1..3], &[0x00, 0x07]);
assert_eq!(&out[3..], ct.as_slice());
assert_eq!(out.len(), 1 + 2 + ct.len());
let (hdr, body) = decode_record(&out, &mask).unwrap();
assert!(!hdr.has_length);
assert!(hdr.seq_is_16bit);
assert_eq!(hdr.seq_low, 7);
assert_eq!(hdr.header_len, 1 + 2);
assert_eq!(body, ct.as_slice());
}
// A first byte with the connection-ID flag set is rejected by both
// `decode_record` and `peek_header_layout` with `IllegalParameter`.
#[test]
fn cid_bit_rejected() {
let bad = vec![0b0011_0101u8, 0, 0, 0, 0];
match decode_record(&bad, &[0u8; 2]) {
Err(Error::IllegalParameter) => {}
other => panic!("expected IllegalParameter, got {other:?}"),
}
match peek_header_layout(&bad) {
Err(Error::IllegalParameter) => {}
other => panic!("expected IllegalParameter, got {other:?}"),
}
}
// A first byte whose top three bits are not `001` is a decode error.
#[test]
fn non_dtls13_prefix_rejected() {
let bad = vec![0b1010_0101u8, 0, 0, 0, 0];
match decode_record(&bad, &[0u8; 2]) {
Err(Error::Decode) => {}
other => panic!("expected Decode, got {other:?}"),
}
}
// Dropping the last payload byte makes the declared length overrun the buffer.
#[test]
fn truncated_buffer_rejected() {
let mut out = Vec::new();
let mask = [0u8; 2];
let ct = dummy_ct();
encode_record(&mut out, 0, 1, true, false, &ct, &mask);
out.pop();
match decode_record(&out, &mask) {
Err(Error::Decode) => {}
other => panic!("expected Decode, got {other:?}"),
}
}
// Received low bits land in the same 16-bit window as the expected value.
#[test]
fn reconstruct_seq_simple() {
let got = reconstruct_seq(0x0145, true, 300);
assert_eq!(got, 0x0145);
}
// Expected value just crossed a window boundary; the closest match is in
// the previous window.
#[test]
fn reconstruct_seq_wraparound() {
let got = reconstruct_seq(0xFFFF, true, 0x10000);
assert_eq!(got, 0x0FFFF);
}
// Expected value is near the top of a window; the closest match is in the
// next window.
#[test]
fn reconstruct_seq_forward_wrap() {
let got = reconstruct_seq(0x0001, true, 0xFFFE);
assert_eq!(got, 0x10001);
}
// Same reconstruction with an 8-bit window.
#[test]
fn reconstruct_seq_8bit() {
let got = reconstruct_seq(0x05, false, 0x200);
assert_eq!(got, 0x205);
}
// Known answer: all-zero 16-byte key over an all-zero 16-byte sample.
#[test]
fn sn_mask_aes128_known_vector() {
let key = [0u8; 16];
let ct = [0u8; 16];
let mask = sn_mask_aes128(&key, &ct);
assert_eq!(mask, [0x66, 0xe9]);
}
// A 4-byte all-zero input must give the same mask as a 16-byte all-zero
// input, i.e. short input is padded with zeros.
#[test]
fn sn_mask_aes128_short_ciphertext_zero_padded() {
let key = [0u8; 16];
let short = [0u8; 4];
let mask_short = sn_mask_aes128(&key, &short);
let zero = [0u8; 16];
let mask_zero = sn_mask_aes128(&key, &zero);
assert_eq!(mask_short, mask_zero);
}
// Known answer: all-zero 32-byte key over an all-zero 16-byte sample.
#[test]
fn sn_mask_aes256_known_vector() {
let key = [0u8; 32];
let ct = [0u8; 16];
let mask = sn_mask_aes256(&key, &ct);
assert_eq!(mask, [0xdc, 0x95]);
}
// Non-zero mask: the on-wire sequence bytes are XOR-masked (0x12^0xAA,
// 0x34^0x55) and decoding with the same mask recovers the original value.
#[test]
fn encode_decode_with_real_mask() {
let mut out = Vec::new();
let mask = [0xAA, 0x55];
let ct = dummy_ct();
encode_record(&mut out, 3, 0x1234, true, false, &ct, &mask);
assert_eq!(&out[1..3], &[0xB8, 0x61]);
let (hdr, body) = decode_record(&out, &mask).unwrap();
assert_eq!(hdr.seq_low, 0x1234);
assert_eq!(body, ct.as_slice());
}
// `peek_header_layout` reports the header and body sizes for records
// encoded with and without the length field.
#[test]
fn peek_header_layout_matches_decode() {
let mut out = Vec::new();
let mask = [0u8; 2];
let ct = dummy_ct();
encode_record(&mut out, 0, 9, true, false, &ct, &mask);
let (hdr_len, body_len) = peek_header_layout(&out).unwrap();
assert_eq!(hdr_len, 5);
assert_eq!(body_len, ct.len());
let mut out2 = Vec::new();
encode_record(&mut out2, 0, 9, true, true, &ct, &mask);
let (hdr_len2, body_len2) = peek_header_layout(&out2).unwrap();
assert_eq!(hdr_len2, 3);
assert_eq!(body_len2, ct.len());
}
// The 48-bit bound is a `debug_assert!`, so this test only exists in builds
// with debug assertions enabled.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "DTLS seq must fit in 48 bits")]
fn encode_panics_on_oversized_seq() {
let mut out = Vec::new();
let mask = [0u8; 2];
encode_record(&mut out, 0, 1u64 << 48, true, false, b"", &mask);
}
}