#![allow(dead_code)]
use crate::tls::Error;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
/// Size in bytes of a handshake fragment header:
/// msg_type (1) + length (3) + message_seq (2) + fragment_offset (3) + fragment_length (3).
pub(crate) const HEADER_LEN: usize = 12;

/// One handshake fragment as parsed from the wire, borrowing its body from the
/// input buffer.
///
/// The struct is only a view over borrowed bytes, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub(crate) struct HandshakeFragment<'a> {
    /// Handshake message type byte.
    pub(crate) msg_type: u8,
    /// Length of the complete (reassembled) message body.
    pub(crate) total_length: u32,
    /// Sequence number of the message this fragment belongs to.
    pub(crate) message_seq: u16,
    /// Byte offset of `fragment` within the complete message body.
    pub(crate) fragment_offset: u32,
    /// The fragment's body bytes.
    pub(crate) fragment: &'a [u8],
    /// Total number of input bytes this fragment occupied (header + body).
    pub(crate) len: usize,
}
pub(crate) fn read_fragment(buf: &[u8]) -> Result<HandshakeFragment<'_>, Error> {
if buf.len() < HEADER_LEN {
return Err(Error::Decode);
}
let msg_type = buf[0];
let total_length = u24_be(&buf[1..4]);
let message_seq = u16::from_be_bytes([buf[4], buf[5]]);
let fragment_offset = u24_be(&buf[6..9]);
let fragment_length = u24_be(&buf[9..12]);
let body_len = fragment_length as usize;
let end_in_buf = HEADER_LEN.checked_add(body_len).ok_or(Error::Decode)?;
if end_in_buf > buf.len() {
return Err(Error::Decode);
}
let end = (fragment_offset as u64) + (fragment_length as u64);
if end > total_length as u64 {
return Err(Error::Decode);
}
Ok(HandshakeFragment {
msg_type,
total_length,
message_seq,
fragment_offset,
fragment: &buf[HEADER_LEN..end_in_buf],
len: end_in_buf,
})
}
/// Serializes a complete handshake message body into `out` as one or more
/// fragments, each preceded by a fragment header.
///
/// `max_fragment_size` caps the body bytes per fragment; `0` (or any value at
/// least as large as the body) emits the message as a single fragment. An empty
/// body still produces exactly one header-only fragment.
///
/// The body length must fit in the 24-bit length field; this is only checked
/// in debug builds.
pub(crate) fn write_message(
    out: &mut Vec<u8>,
    msg_type: u8,
    message_seq: u16,
    full_message_body: &[u8],
    max_fragment_size: usize,
) {
    let total = full_message_body.len();
    debug_assert!(
        total <= 0xFF_FFFF,
        "handshake message exceeds the 24-bit length field",
    );

    if total == 0 {
        write_fragment_header(out, msg_type, 0, message_seq, 0, 0);
        return;
    }

    // Here `total > 0`, so the chosen size is always non-zero as `chunks` requires.
    let chunk_size = match max_fragment_size {
        0 => total,
        limit => limit.min(total),
    };

    for (index, piece) in full_message_body.chunks(chunk_size).enumerate() {
        let offset = index * chunk_size;
        write_fragment_header(
            out,
            msg_type,
            total as u32,
            message_seq,
            offset as u32,
            piece.len() as u32,
        );
        out.extend_from_slice(piece);
    }
}
/// A handshake message whose fragments are still being collected.
#[derive(Debug)]
struct PartialMessage {
    /// Message type shared by every fragment of this message.
    msg_type: u8,
    /// Declared length of the complete message body.
    total_length: u32,
    /// Body bytes, written in place as fragments arrive.
    buf: Vec<u8>,
    /// `received[i]` is true once byte `i` of `buf` has been filled.
    received: Vec<bool>,
    /// Number of `true` entries in `received`.
    received_count: u32,
}

/// Largest message body (in bytes) the reassembler is willing to buffer.
const MAX_MESSAGE_LEN: u32 = 256 * 1024;
/// Maximum number of messages buffered concurrently.
const MAX_IN_PROGRESS: usize = 8;

/// Reassembles fragmented handshake messages and releases them in
/// `message_seq` order.
#[derive(Debug)]
pub(crate) struct Reassembler {
    /// Sequence number of the next message to deliver.
    expected_msg_seq: u16,
    /// Partially or fully received messages that have not yet been delivered,
    /// keyed by `message_seq`.
    in_progress: BTreeMap<u16, PartialMessage>,
}
impl Reassembler {
    /// Creates an empty reassembler that expects `message_seq` 0 next.
    pub(crate) fn new() -> Self {
        Self {
            expected_msg_seq: 0,
            in_progress: BTreeMap::new(),
        }
    }

    /// Returns the `message_seq` of the next message that will be delivered.
    #[allow(dead_code)]
    pub(crate) fn expected_msg_seq(&self) -> u16 {
        self.expected_msg_seq
    }

    /// Feeds one parsed fragment into the reassembler.
    ///
    /// Returns `Some((msg_type, body))` when, after this fragment, the message
    /// with `message_seq == expected_msg_seq()` is complete; the expected
    /// sequence number then advances by one. Messages that complete ahead of
    /// the expected one stay buffered; drain them with `pop_ready`.
    ///
    /// The fragment is dropped (returning `None`) when it:
    /// - belongs to an already-delivered message (`message_seq` below the
    ///   expected one, e.g. a retransmission);
    /// - is `MAX_IN_PROGRESS` or more messages ahead of the expected one;
    /// - declares a `total_length` above `MAX_MESSAGE_LEN`;
    /// - does not fit inside its declared `total_length`;
    /// - would start a new message while `MAX_IN_PROGRESS` are already buffered;
    /// - disagrees in `msg_type` or `total_length` with earlier fragments of
    ///   the same message.
    pub(crate) fn feed(&mut self, frag: HandshakeFragment<'_>) -> Option<(u8, Vec<u8>)> {
        if frag.message_seq < self.expected_msg_seq {
            return None;
        }
        // Only buffer a bounded window ahead of the expected message. Otherwise
        // a peer (or anyone injecting unauthenticated handshake records) could
        // fill every slot with far-future sequence numbers, and the genuinely
        // expected message would then be rejected forever by the capacity check.
        if usize::from(frag.message_seq - self.expected_msg_seq) >= MAX_IN_PROGRESS {
            return None;
        }
        if frag.total_length > MAX_MESSAGE_LEN {
            return None;
        }
        // Validate the byte range before touching any state, so a bad fragment
        // neither allocates a slot nor leaves a message partially updated.
        let start = frag.fragment_offset as usize;
        let end = start.checked_add(frag.fragment.len())?;
        if end > frag.total_length as usize {
            return None;
        }
        // Defense in depth: the window above already keeps the map this small.
        if !self.in_progress.contains_key(&frag.message_seq)
            && self.in_progress.len() >= MAX_IN_PROGRESS
        {
            return None;
        }

        let total_length = frag.total_length;
        let entry = self
            .in_progress
            .entry(frag.message_seq)
            .or_insert_with(|| PartialMessage {
                msg_type: frag.msg_type,
                total_length,
                buf: vec_zeroed(total_length as usize),
                received: vec_false(total_length as usize),
                received_count: 0,
            });
        if entry.msg_type != frag.msg_type || entry.total_length != total_length {
            return None;
        }

        let bytes = &mut entry.buf[start..end];
        let seen = &mut entry.received[start..end];
        for ((slot, flag), &byte) in bytes.iter_mut().zip(seen.iter_mut()).zip(frag.fragment) {
            // The first copy of each byte wins, so duplicates can neither
            // rewrite data nor inflate the received count.
            if !*flag {
                *flag = true;
                *slot = byte;
                entry.received_count += 1;
            }
        }

        if frag.message_seq == self.expected_msg_seq {
            // Delivers only if the expected message is now complete.
            return self.pop_ready();
        }
        None
    }

    /// Delivers the expected message if it is already fully buffered, and
    /// advances the expected sequence number. Returns `None` otherwise.
    pub(crate) fn pop_ready(&mut self) -> Option<(u8, Vec<u8>)> {
        let entry = self.in_progress.get(&self.expected_msg_seq)?;
        if entry.received_count != entry.total_length {
            return None;
        }
        let done = self.in_progress.remove(&self.expected_msg_seq)?;
        self.expected_msg_seq = self.expected_msg_seq.wrapping_add(1);
        Some((done.msg_type, done.buf))
    }
}
/// Decodes a big-endian 24-bit unsigned integer from the first three bytes of
/// `bytes`. Panics if fewer than three bytes are supplied.
#[inline]
fn u24_be(bytes: &[u8]) -> u32 {
    let (hi, mid, lo) = (bytes[0], bytes[1], bytes[2]);
    u32::from_be_bytes([0, hi, mid, lo])
}
/// Appends the low 24 bits of `v` to `out` in big-endian order; the top
/// byte of `v` is discarded.
fn put_u24(out: &mut Vec<u8>, v: u32) {
    let [_, high, middle, low] = v.to_be_bytes();
    out.extend_from_slice(&[high, middle, low]);
}
/// Appends a 12-byte handshake fragment header to `out`.
///
/// Wire layout (all big-endian): msg_type (1) | total_length (3) |
/// message_seq (2) | fragment_offset (3) | fragment_length (3).
/// Only the low 24 bits of the three 24-bit fields are written.
fn write_fragment_header(
    out: &mut Vec<u8>,
    msg_type: u8,
    total_length: u32,
    message_seq: u16,
    fragment_offset: u32,
    fragment_length: u32,
) {
    out.push(msg_type);
    // `to_be_bytes()[1..]` drops the most significant byte, leaving 24 bits.
    out.extend_from_slice(&total_length.to_be_bytes()[1..]);
    out.extend_from_slice(&message_seq.to_be_bytes());
    out.extend_from_slice(&fragment_offset.to_be_bytes()[1..]);
    out.extend_from_slice(&fragment_length.to_be_bytes()[1..]);
}
/// Returns a buffer of `n` zero bytes.
fn vec_zeroed(n: usize) -> Vec<u8> {
    filled(n, 0u8)
}

/// Returns `n` flags, all `false`.
fn vec_false(n: usize) -> Vec<bool> {
    filled(n, false)
}

/// Returns a vector holding `n` clones of `value`.
fn filled<T: Clone>(n: usize, value: T) -> Vec<T> {
    alloc::vec![value; n]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds raw wire bytes: a header declaring `fragment_length`, followed by
    /// `body` verbatim (which may deliberately disagree with the header).
    fn raw_fragment(
        msg_type: u8,
        total_length: u32,
        message_seq: u16,
        fragment_offset: u32,
        fragment_length: u32,
        body: &[u8],
    ) -> Vec<u8> {
        let mut wire = Vec::new();
        write_fragment_header(
            &mut wire,
            msg_type,
            total_length,
            message_seq,
            fragment_offset,
            fragment_length,
        );
        wire.extend_from_slice(body);
        wire
    }

    /// Serializes a whole message through `write_message` into a fresh buffer.
    fn encode_message(
        msg_type: u8,
        message_seq: u16,
        body: &[u8],
        max_fragment_size: usize,
    ) -> Vec<u8> {
        let mut wire = Vec::new();
        write_message(&mut wire, msg_type, message_seq, body, max_fragment_size);
        wire
    }

    /// Asserts that parsing `wire` fails specifically with `Error::Decode`.
    fn assert_decode_error(wire: &[u8]) {
        match read_fragment(wire) {
            Err(Error::Decode) => {}
            Ok(_) => panic!("expected Decode error, got Ok"),
            Err(e) => panic!("expected Decode, got {e:?}"),
        }
    }

    #[test]
    fn single_unfragmented_message_completes_immediately() {
        let wire = encode_message(1, 0, b"hello world", 0);
        let frag = read_fragment(&wire).unwrap();
        assert_eq!(
            (frag.msg_type, frag.total_length, frag.message_seq, frag.fragment_offset),
            (1, 11, 0, 0)
        );
        assert_eq!(frag.fragment, b"hello world");

        let mut reassembler = Reassembler::new();
        let (msg_type, body) = reassembler.feed(frag).unwrap();
        assert_eq!(msg_type, 1);
        assert_eq!(body, b"hello world");
        assert_eq!(reassembler.expected_msg_seq(), 1);
    }

    #[test]
    fn empty_body_message_completes() {
        let wire = encode_message(14, 3, b"", 0);
        assert_eq!(wire.len(), HEADER_LEN);
        let frag = read_fragment(&wire).unwrap();
        assert_eq!(frag.fragment.len(), 0);
        assert_eq!(frag.total_length, 0);

        let mut reassembler = Reassembler::new();
        for seq in 0u16..3 {
            let earlier = encode_message(14, seq, b"", 0);
            let (msg_type, body) = reassembler
                .feed(read_fragment(&earlier).unwrap())
                .unwrap();
            assert_eq!(msg_type, 14);
            assert!(body.is_empty());
        }
        let (msg_type, body) = reassembler.feed(read_fragment(&wire).unwrap()).unwrap();
        assert_eq!(msg_type, 14);
        assert!(body.is_empty());
    }

    #[test]
    fn out_of_order_fragments_reassemble() {
        let body: Vec<u8> = (b'A'..=b'T').collect();
        assert_eq!(body.len(), 20);
        let total = body.len() as u32;
        let late = raw_fragment(2, total, 0, 10, 10, &body[10..]);
        let early = raw_fragment(2, total, 0, 0, 10, &body[..10]);

        let mut reassembler = Reassembler::new();
        assert!(reassembler.feed(read_fragment(&late).unwrap()).is_none());
        let (msg_type, joined) = reassembler.feed(read_fragment(&early).unwrap()).unwrap();
        assert_eq!(msg_type, 2);
        assert_eq!(joined, body);
    }

    #[test]
    fn duplicate_fragment_is_idempotent() {
        let body = b"some message bytes".to_vec();
        let total = body.len() as u32;
        let wire = raw_fragment(11, total, 0, 0, total, &body);

        let mut reassembler = Reassembler::new();
        let (_, first) = reassembler.feed(read_fragment(&wire).unwrap()).unwrap();
        assert_eq!(first, body);
        assert_eq!(reassembler.expected_msg_seq(), 1);
        assert!(reassembler.feed(read_fragment(&wire).unwrap()).is_none());
    }

    #[test]
    fn duplicate_partial_fragment_no_corruption() {
        let body: Vec<u8> = (0u8..20).collect();
        let total = body.len() as u32;
        let early = raw_fragment(5, total, 0, 0, 10, &body[..10]);
        let late = raw_fragment(5, total, 0, 10, 10, &body[10..]);

        let mut reassembler = Reassembler::new();
        for _ in 0..3 {
            assert!(reassembler.feed(read_fragment(&early).unwrap()).is_none());
        }
        let (msg_type, joined) = reassembler.feed(read_fragment(&late).unwrap()).unwrap();
        assert_eq!(msg_type, 5);
        assert_eq!(joined, body);
    }

    #[test]
    fn fragment_out_of_bounds_rejected() {
        // offset 8 + length 5 overruns the declared total of 10.
        assert_decode_error(&raw_fragment(1, 10, 0, 8, 5, &[0; 5]));
    }

    #[test]
    fn fragment_length_truncated_rejected() {
        // Header promises 5 body bytes but only 3 follow.
        assert_decode_error(&raw_fragment(1, 10, 0, 0, 5, &[0; 3]));
    }

    #[test]
    fn fragment_with_trailing_bytes_consumes_only_its_own() {
        let mut wire = raw_fragment(1, 4, 0, 0, 4, b"AAAA");
        let tail_start = wire.len();
        wire.extend_from_slice(&raw_fragment(2, 6, 7, 0, 6, b"BBBBBB"));

        let first = read_fragment(&wire).unwrap();
        assert_eq!(first.msg_type, 1);
        assert_eq!(first.fragment, b"AAAA");
        assert_eq!(first.len, tail_start);

        let second = read_fragment(&wire[first.len..]).unwrap();
        assert_eq!(second.msg_type, 2);
        assert_eq!(second.message_seq, 7);
        assert_eq!(second.fragment, b"BBBBBB");
    }

    #[test]
    fn write_message_chunks_at_max_fragment_size() {
        let body: Vec<u8> = (0u8..25).collect();
        let wire = encode_message(7, 9, &body, 10);

        // (fragment_offset, fragment body length) for each emitted fragment.
        let expected = [(0u32, 10usize), (10, 10), (20, 5)];
        let mut cursor = 0;
        for (index, &(offset, size)) in expected.iter().enumerate() {
            let frag = read_fragment(&wire[cursor..]).unwrap();
            if index == 0 {
                assert_eq!(frag.message_seq, 9);
            }
            assert_eq!(frag.fragment_offset, offset);
            assert_eq!(frag.fragment.len(), size);
            cursor += frag.len;
        }
        assert_eq!(cursor, wire.len());
    }

    #[test]
    fn out_of_order_message_sequence_buffered() {
        let zero = encode_message(1, 0, b"zero", 0);
        let one = encode_message(1, 1, b"one!", 0);

        let mut reassembler = Reassembler::new();
        assert!(reassembler.feed(read_fragment(&one).unwrap()).is_none());
        assert_eq!(
            reassembler.feed(read_fragment(&zero).unwrap()).unwrap(),
            (1, b"zero".to_vec())
        );
        assert_eq!(
            reassembler.feed(read_fragment(&one).unwrap()).unwrap(),
            (1, b"one!".to_vec())
        );
    }
}