#![allow(dead_code)]
use crate::tls::{ContentType, Error, ProtocolVersion};
use alloc::vec::Vec;
/// Largest record fragment length accepted on the wire: 2^14 + 2048 bytes,
/// the `TLSCiphertext.length` bound of RFC 5246 §6.2.3 that DTLS 1.2
/// inherits via RFC 6347 §4.1.
pub(crate) const MAX_FRAGMENT: usize = 16_384 + 2_048;
/// Bytes in a DTLS 1.x record header:
/// type (1) + version (2) + epoch (2) + sequence number (6) + length (2).
pub(crate) const HEADER_LEN: usize = 1 + 2 + 2 + 6 + 2;
/// Selects the low 48 bits of a `u64`, the width of the on-wire sequence number.
const SEQ_MASK_48: u64 = u64::MAX >> 16;
/// Per-epoch record budget (2^23) enforced by [`check_seq_cap`]. It is kept far
/// below the 48-bit wire limit. NOTE(review): the reason for exactly 2^23 is not
/// stated here (presumably an AEAD usage limit) — TODO confirm.
pub(crate) const MAX_RECORDS_PER_EPOCH: u64 = 1 << 23;
pub(crate) fn check_seq_cap(seq: u64) -> Result<(), Error> {
if seq >= MAX_RECORDS_PER_EPOCH {
return Err(Error::TooManyRecords);
}
Ok(())
}
/// One DTLS record decoded by [`read_record`], borrowing its body from the
/// caller's input buffer.
#[derive(Debug)]
pub(crate) struct ParsedDtlsRecord<'a> {
    /// Content type decoded from header byte 0.
    pub(crate) content_type: ContentType,
    /// Protocol version decoded from header bytes 1-2 (big-endian).
    pub(crate) version: ProtocolVersion,
    /// Epoch from header bytes 3-4 (big-endian).
    pub(crate) epoch: u16,
    /// 48-bit sequence number from header bytes 5-10, zero-extended to `u64`.
    pub(crate) seq: u64,
    /// Record body: exactly the number of bytes the header's length field declared.
    pub(crate) fragment: &'a [u8],
    /// Total bytes this record occupies in the input (header + fragment), so a
    /// caller can advance past it to whatever follows.
    pub(crate) len: usize,
}
/// Decodes the DTLS record at the start of `buf`.
///
/// Returns `Ok(None)` when `buf` does not yet hold a full header, or holds a
/// header but fewer body bytes than it declares; the caller should supply more
/// data. Bytes after the record are left untouched. The returned `len` says how
/// many bytes the record consumed.
///
/// # Errors
///
/// `Error::RecordOverflow` if the declared fragment length exceeds
/// [`MAX_FRAGMENT`]. This is checked as soon as the header is available,
/// before the body is required.
pub(crate) fn read_record(buf: &[u8]) -> Result<Option<ParsedDtlsRecord<'_>>, Error> {
    let Some(header) = buf.get(..HEADER_LEN) else {
        return Ok(None);
    };

    let content_type = ContentType::from_u8(header[0]);
    let version = ProtocolVersion::from_u16(u16::from_be_bytes([header[1], header[2]]));
    let epoch = u16::from_be_bytes([header[3], header[4]]);

    // The six sequence bytes become the low 48 bits of a big-endian u64.
    let mut seq_be = [0u8; 8];
    seq_be[2..].copy_from_slice(&header[5..11]);
    let seq = u64::from_be_bytes(seq_be);

    let declared_len = usize::from(u16::from_be_bytes([header[11], header[12]]));
    if declared_len > MAX_FRAGMENT {
        return Err(Error::RecordOverflow);
    }

    let end = HEADER_LEN + declared_len;
    let Some(fragment) = buf.get(HEADER_LEN..end) else {
        return Ok(None);
    };

    Ok(Some(ParsedDtlsRecord {
        content_type,
        version,
        epoch,
        seq,
        fragment,
        len: end,
    }))
}
pub(crate) fn write_record(
out: &mut Vec<u8>,
ct: ContentType,
version: ProtocolVersion,
epoch: u16,
seq: u64,
fragment: &[u8],
) {
debug_assert!(
seq <= SEQ_MASK_48,
"DTLS sequence numbers are 48-bit; caller must rekey before overflow",
);
debug_assert!(
fragment.len() <= MAX_FRAGMENT,
"DTLS record fragment exceeds RFC 6347 §4.1.1.1 maximum",
);
let seq = seq & SEQ_MASK_48;
out.push(ct.as_u8());
out.extend_from_slice(&version.as_u16().to_be_bytes());
out.extend_from_slice(&epoch.to_be_bytes());
out.push((seq >> 40) as u8);
out.push((seq >> 32) as u8);
out.push((seq >> 24) as u8);
out.push((seq >> 16) as u8);
out.push((seq >> 8) as u8);
out.push(seq as u8);
out.extend_from_slice(&(fragment.len() as u16).to_be_bytes());
out.extend_from_slice(fragment);
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Encodes a small handshake record, checks the exact wire bytes, and
    /// parses them back.
    #[test]
    fn record_roundtrip_known_header() {
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::Handshake,
            ProtocolVersion::DTLSv1_2,
            0,
            42,
            b"hi",
        );
        assert_eq!(out.len(), HEADER_LEN + 2);
        let expected: Vec<u8> = vec![
            22, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x02, b'h', b'i',
        ];
        assert_eq!(out, expected);
        let rec = read_record(&out).unwrap().unwrap();
        assert_eq!(rec.content_type, ContentType::Handshake);
        assert_eq!(rec.version, ProtocolVersion::DTLSv1_2);
        assert_eq!(rec.epoch, 0);
        assert_eq!(rec.seq, 42);
        assert_eq!(rec.fragment, b"hi");
        assert_eq!(rec.len, HEADER_LEN + 2);
    }

    /// The largest 48-bit sequence number survives encode/decode unchanged.
    #[test]
    fn record_roundtrip_full_48bit_seq() {
        let max_seq = (1u64 << 48) - 1;
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::ApplicationData,
            ProtocolVersion::DTLSv1_2,
            7,
            max_seq,
            b"x",
        );
        let rec = read_record(&out).unwrap().unwrap();
        assert_eq!(rec.content_type, ContentType::ApplicationData);
        assert_eq!(rec.epoch, 7);
        assert_eq!(rec.seq, max_seq);
        assert_eq!(rec.fragment, b"x");
    }

    /// Every strict prefix of a valid record yields `None`; the full record parses.
    #[test]
    fn partial_buffer_returns_none() {
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::Handshake,
            ProtocolVersion::DTLSv1_2,
            0,
            1,
            b"hello",
        );
        for cut in 0..out.len() {
            assert!(
                read_record(&out[..cut]).unwrap().is_none(),
                "expected None at cut={cut}",
            );
        }
        assert!(read_record(&out).unwrap().is_some());
    }

    /// A fragment of exactly `MAX_FRAGMENT` bytes is the largest legal record
    /// and must round-trip.
    #[test]
    fn max_fragment_boundary_accepted() {
        let payload = vec![0xabu8; MAX_FRAGMENT];
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::ApplicationData,
            ProtocolVersion::DTLSv1_2,
            1,
            0,
            &payload,
        );
        assert_eq!(out.len(), HEADER_LEN + MAX_FRAGMENT);
        let rec = read_record(&out).unwrap().unwrap();
        assert_eq!(rec.fragment, &payload[..]);
        assert_eq!(rec.len, out.len());
    }

    /// Two records packed back to back: `len` marks where the second begins.
    #[test]
    fn consecutive_records_are_split_by_len() {
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::Handshake,
            ProtocolVersion::DTLSv1_2,
            0,
            0,
            b"first",
        );
        write_record(
            &mut out,
            ContentType::Handshake,
            ProtocolVersion::DTLSv1_2,
            0,
            1,
            b"second!",
        );
        let first = read_record(&out).unwrap().unwrap();
        assert_eq!(first.seq, 0);
        assert_eq!(first.fragment, b"first");
        assert_eq!(first.len, HEADER_LEN + 5);

        let rest = &out[first.len..];
        let second = read_record(rest).unwrap().unwrap();
        assert_eq!(second.seq, 1);
        assert_eq!(second.fragment, b"second!");
        assert_eq!(second.len, rest.len());
    }

    /// A header declaring one byte more than `MAX_FRAGMENT` is rejected as soon
    /// as the header is readable, without waiting for the body.
    #[test]
    fn fragment_length_overflow_rejected() {
        let bad = (MAX_FRAGMENT + 1) as u16;
        // Make sure the cast did not truncate, or the test would prove nothing.
        assert_eq!(usize::from(bad), MAX_FRAGMENT + 1);

        let mut hdr = vec![22, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        hdr.extend_from_slice(&bad.to_be_bytes());
        match read_record(&hdr) {
            Err(Error::RecordOverflow) => {}
            other => panic!("expected RecordOverflow, got {:?}", other.map(|_| ())),
        }
    }

    /// The per-epoch cap sits below the 48-bit wire limit and is enforced at
    /// and above the boundary.
    #[test]
    fn seq_cap_is_below_48_bits_and_enforced() {
        const { assert!(MAX_RECORDS_PER_EPOCH < (1u64 << 48)) };
        assert!(check_seq_cap(0).is_ok());
        assert!(check_seq_cap(MAX_RECORDS_PER_EPOCH - 1).is_ok());
        assert!(matches!(
            check_seq_cap(MAX_RECORDS_PER_EPOCH),
            Err(Error::TooManyRecords)
        ));
        assert!(matches!(
            check_seq_cap(MAX_RECORDS_PER_EPOCH + 1),
            Err(Error::TooManyRecords)
        ));
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "DTLS sequence numbers are 48-bit")]
    fn write_record_panics_on_oversized_seq() {
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::Handshake,
            ProtocolVersion::DTLSv1_2,
            0,
            1u64 << 48,
            b"",
        );
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "DTLS record fragment exceeds")]
    fn write_record_panics_on_oversized_fragment() {
        let payload = vec![0u8; MAX_FRAGMENT + 1];
        let mut out = Vec::new();
        write_record(
            &mut out,
            ContentType::ApplicationData,
            ProtocolVersion::DTLSv1_2,
            0,
            0,
            &payload,
        );
    }
}