#![allow(dead_code)]
use crate::tls::Error;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
/// Size in bytes of a handshake fragment header, in wire order:
/// `msg_type` (1) + `total_length` (3) + `message_seq` (2)
/// + `fragment_offset` (3) + `fragment_length` (3).
pub(crate) const HEADER_LEN: usize = 1 + 3 + 2 + 3 + 3;
/// NOTE(review): not referenced in this file; presumably a bound on handshake
/// `message_seq` values enforced elsewhere in the crate — confirm at call sites.
pub(crate) const MAX_HS_MSG_SEQ: u16 = 8;
/// One parsed handshake fragment, borrowing its body bytes from the input buffer.
///
/// Values produced by `read_fragment` satisfy
/// `fragment_offset + fragment.len() <= total_length`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct HandshakeFragment<'a> {
    /// Handshake message type (first header byte).
    pub(crate) msg_type: u8,
    /// Length of the complete, unfragmented message body (24 bits on the wire).
    pub(crate) total_length: u32,
    /// Handshake message sequence number.
    pub(crate) message_seq: u16,
    /// Offset of this fragment's bytes within the complete message body.
    pub(crate) fragment_offset: u32,
    /// This fragment's body bytes.
    pub(crate) fragment: &'a [u8],
    /// Number of input bytes this fragment occupied (header plus body).
    pub(crate) len: usize,
}
pub(crate) fn read_fragment(buf: &[u8]) -> Result<HandshakeFragment<'_>, Error> {
if buf.len() < HEADER_LEN {
return Err(Error::Decode);
}
let msg_type = buf[0];
let total_length = u24_be(&buf[1..4]);
let message_seq = u16::from_be_bytes([buf[4], buf[5]]);
let fragment_offset = u24_be(&buf[6..9]);
let fragment_length = u24_be(&buf[9..12]);
let body_len = fragment_length as usize;
let end_in_buf = HEADER_LEN.checked_add(body_len).ok_or(Error::Decode)?;
if end_in_buf > buf.len() {
return Err(Error::Decode);
}
let end = (fragment_offset as u64) + (fragment_length as u64);
if end > total_length as u64 {
return Err(Error::Decode);
}
Ok(HandshakeFragment {
msg_type,
total_length,
message_seq,
fragment_offset,
fragment: &buf[HEADER_LEN..end_in_buf],
len: end_in_buf,
})
}
/// Serializes one handshake message as fragments appended to `out`.
///
/// Every fragment repeats the full `total_length` and `message_seq` and carries
/// its own offset and length. A `max_fragment_size` of 0, or one at least as
/// large as the body, yields a single fragment. An empty body yields one
/// header-only fragment.
///
/// The body must fit the 24-bit length field; this is only checked by a
/// `debug_assert!`.
pub(crate) fn write_message(
    out: &mut Vec<u8>,
    msg_type: u8,
    message_seq: u16,
    full_message_body: &[u8],
    max_fragment_size: usize,
) {
    let total = full_message_body.len();
    debug_assert!(
        total <= 0xFF_FFFF,
        "handshake message exceeds the 24-bit length field",
    );
    let wire_total = total as u32;

    if full_message_body.is_empty() {
        write_fragment_header(out, msg_type, wire_total, message_seq, 0, 0);
        return;
    }

    // `total` is non-zero past this point, so the chunk size is always >= 1.
    let chunk_size = match max_fragment_size {
        0 => total,
        limit => limit.min(total),
    };

    let mut offset = 0usize;
    for piece in full_message_body.chunks(chunk_size) {
        write_fragment_header(
            out,
            msg_type,
            wire_total,
            message_seq,
            offset as u32,
            piece.len() as u32,
        );
        out.extend_from_slice(piece);
        offset += piece.len();
    }
}
/// A handshake message being reassembled from fragments.
struct PartialMessage {
    /// Handshake type that every fragment of this message must carry.
    msg_type: u8,
    /// Declared body length of the message.
    total_length: u32,
    /// Number of distinct byte positions received so far.
    received_count: u32,
    /// Reassembly buffer for the message body.
    buf: Vec<u8>,
    /// Received-bytes bitmap: bit `i % 64` of word `i / 64` is set once byte
    /// `i` has been stored.
    received: Vec<u64>,
}

impl PartialMessage {
    /// Number of byte positions tracked by each bitmap word.
    const WORD_BITS: usize = u64::BITS as usize;

    /// Maps a byte index to its bitmap word index and single-bit mask.
    #[inline]
    fn locate(byte_idx: usize) -> (usize, u64) {
        (byte_idx / Self::WORD_BITS, 1u64 << (byte_idx % Self::WORD_BITS))
    }

    /// Reports whether `byte_idx` has been received.
    ///
    /// Indices beyond the bitmap report `false`.
    #[inline]
    fn is_received(&self, byte_idx: usize) -> bool {
        let (word, mask) = Self::locate(byte_idx);
        self.received.get(word).is_some_and(|w| w & mask != 0)
    }

    /// Marks `byte_idx` as received.
    ///
    /// Returns `true` only when the bit was previously clear. Already-set bits
    /// and indices beyond the bitmap return `false` and change nothing.
    #[inline]
    fn set_received(&mut self, byte_idx: usize) -> bool {
        let (word, mask) = Self::locate(byte_idx);
        match self.received.get_mut(word) {
            Some(w) if *w & mask == 0 => {
                *w |= mask;
                true
            }
            _ => false,
        }
    }
}
/// Default cap on a single message's `total_length` (256 KiB), used by `Reassembler::new`.
const MAX_MESSAGE_LEN: u32 = 256 << 10;
/// Default cap on the number of simultaneously buffered messages, used by `Reassembler::new`.
const MAX_IN_PROGRESS: usize = 8;
/// Reorders and reassembles handshake fragments into complete messages,
/// delivering them strictly in `message_seq` order.
///
/// Memory is bounded by `max_message_len` (largest accepted `total_length`)
/// and `max_in_progress` (number of distinct `message_seq` values buffered at
/// once).
pub(crate) struct Reassembler {
    /// Sequence number of the next message to hand to the caller.
    expected_msg_seq: u16,
    /// Messages not yet delivered, keyed by `message_seq`. They may be
    /// partially received, or complete but waiting for earlier messages.
    in_progress: BTreeMap<u16, PartialMessage>,
    /// Largest `total_length` that will be buffered.
    max_message_len: u32,
    /// Maximum number of distinct messages buffered simultaneously.
    max_in_progress: usize,
}

impl Reassembler {
    /// Creates a reassembler with the default limits
    /// (`MAX_MESSAGE_LEN`, `MAX_IN_PROGRESS`).
    pub(crate) fn new() -> Self {
        Self::with_limits(MAX_MESSAGE_LEN, MAX_IN_PROGRESS)
    }

    /// Creates a reassembler with explicit buffering limits that expects
    /// `message_seq` 0 first.
    pub(crate) fn with_limits(max_message_len: u32, max_in_progress: usize) -> Self {
        Self {
            expected_msg_seq: 0,
            in_progress: BTreeMap::new(),
            max_message_len,
            max_in_progress,
        }
    }

    /// Returns the `message_seq` of the next message to be delivered.
    pub(crate) fn expected_msg_seq(&self) -> u16 {
        self.expected_msg_seq
    }

    /// Lowers the expected `message_seq` to `to`. Requests that would move it
    /// forward are ignored.
    pub(crate) fn rewind_expected_msg_seq(&mut self, to: u16) {
        if to <= self.expected_msg_seq {
            self.expected_msg_seq = to;
        }
    }

    /// Feeds one fragment into the reassembler.
    ///
    /// Returns `Some((msg_type, body))` if the fragment belongs to the expected
    /// message and that message is now complete. The message is removed and
    /// the expected sequence number advances by one. Later messages that are
    /// already complete stay buffered; drain them with
    /// [`Reassembler::pop_ready`].
    ///
    /// Returns `None` and drops the fragment when it:
    /// - belongs to an already-delivered message;
    /// - claims a `total_length` above the limit;
    /// - lies outside its claimed `total_length`;
    /// - disagrees with earlier fragments on `msg_type` or `total_length`;
    /// - contradicts bytes already received; or
    /// - cannot be buffered because the in-progress limit is reached by
    ///   messages at least as near to the expected sequence.
    ///
    /// A fragment that is buffered but does not yet complete the expected
    /// message also yields `None`.
    pub(crate) fn feed(&mut self, frag: HandshakeFragment<'_>) -> Option<(u8, Vec<u8>)> {
        let seq = frag.message_seq;
        let total_length = frag.total_length;

        // Retransmission of a message that has already been delivered.
        if seq < self.expected_msg_seq {
            return None;
        }
        if total_length > self.max_message_len {
            return None;
        }

        // Validate the fragment's range before touching the map, so a bogus
        // fragment can neither allocate a buffer nor evict another message.
        let off = frag.fragment_offset as usize;
        let in_bounds = off
            .checked_add(frag.fragment.len())
            .is_some_and(|end| end <= total_length as usize);
        if !in_bounds {
            return None;
        }

        if !self.in_progress.contains_key(&seq) && !self.make_room_for(seq) {
            return None;
        }

        let entry = self
            .in_progress
            .entry(seq)
            .or_insert_with(|| PartialMessage {
                msg_type: frag.msg_type,
                total_length,
                buf: vec_zeroed(total_length as usize),
                received: vec_bitmap_words(total_length as usize),
                received_count: 0,
            });
        if entry.msg_type != frag.msg_type || entry.total_length != total_length {
            return None;
        }

        // Reject the whole fragment if it contradicts any byte already stored,
        // so a conflicting overlap cannot overwrite earlier data.
        let conflicts = frag.fragment.iter().enumerate().any(|(i, &b)| {
            let idx = off + i;
            entry.is_received(idx) && entry.buf[idx] != b
        });
        if conflicts {
            return None;
        }

        for (i, &b) in frag.fragment.iter().enumerate() {
            let idx = off + i;
            if entry.set_received(idx) {
                entry.buf[idx] = b;
                entry.received_count += 1;
            }
        }

        // Only the expected message may be delivered. A completed future
        // message waits in `in_progress` for `pop_ready`.
        if seq != self.expected_msg_seq {
            return None;
        }
        self.pop_ready()
    }

    /// Ensures `in_progress` has a free slot for the not-yet-buffered message
    /// `seq`. Returns `false` if the message cannot be admitted.
    ///
    /// Buffered keys are never below `expected_msg_seq`. Messages are delivered
    /// in order, so a nearer message is always needed before a further one.
    /// When the table is full, the furthest-ahead entry is evicted if it lies
    /// beyond `seq`. Without this, a table filled with future (possibly
    /// spoofed) messages would lock out the expected message permanently.
    /// An evicted message has to be received again, e.g. via a retransmission
    /// by the peer.
    fn make_room_for(&mut self, seq: u16) -> bool {
        if self.in_progress.len() < self.max_in_progress {
            return true;
        }
        let Some((&furthest, _)) = self.in_progress.last_key_value() else {
            return false;
        };
        if furthest <= seq {
            return false;
        }
        self.in_progress.pop_last();
        true
    }

    /// Removes and returns the expected message if it has been fully
    /// received, advancing the expected sequence number by one.
    pub(crate) fn pop_ready(&mut self) -> Option<(u8, Vec<u8>)> {
        let seq = self.expected_msg_seq;
        let entry = self.in_progress.get(&seq)?;
        if entry.received_count != entry.total_length {
            return None;
        }
        let done = self.in_progress.remove(&seq)?;
        self.expected_msg_seq = seq.wrapping_add(1);
        Some((done.msg_type, done.buf))
    }
}
/// Decodes a big-endian 24-bit unsigned integer from the first three bytes of `bytes`.
///
/// Panics if `bytes` has fewer than three elements; extra bytes are ignored.
#[inline]
fn u24_be(bytes: &[u8]) -> u32 {
    u32::from_be_bytes([0, bytes[0], bytes[1], bytes[2]])
}
/// Appends the low 24 bits of `v` to `out` in big-endian order.
/// Any bits above bit 23 are discarded.
fn put_u24(out: &mut Vec<u8>, v: u32) {
    let be = v.to_be_bytes();
    out.extend_from_slice(&be[1..]);
}
/// Appends a 12-byte handshake fragment header to `out`.
///
/// Fields are written big-endian in wire order. The three 24-bit fields
/// (`total_length`, `fragment_offset`, `fragment_length`) keep only their low
/// 24 bits, via `put_u24`.
fn write_fragment_header(
    out: &mut Vec<u8>,
    msg_type: u8,
    total_length: u32,
    message_seq: u16,
    fragment_offset: u32,
    fragment_length: u32,
) {
    out.push(msg_type);
    put_u24(out, total_length);
    out.extend(message_seq.to_be_bytes());
    for field in [fragment_offset, fragment_length] {
        put_u24(out, field);
    }
}
/// Returns a newly allocated buffer of `n` zero bytes.
fn vec_zeroed(n: usize) -> Vec<u8> {
    alloc::vec![0_u8; n]
}
/// Allocates a zeroed bitmap with one bit for each of `bits` items, rounded
/// up to whole `u64` words.
fn vec_bitmap_words(bits: usize) -> Vec<u64> {
    let words = bits.div_ceil(u64::BITS as usize);
    alloc::vec![0; words]
}
#[cfg(test)]
mod tests {
    //! Unit tests for fragment parsing/serialization and `Reassembler`.
    use super::*;

    // A body written with `max_fragment_size == 0` is a single fragment that
    // completes the expected message on its own and advances the sequence.
    #[test]
    fn single_unfragmented_message_completes_immediately() {
        let mut buf = Vec::new();
        write_message(&mut buf, 1, 0, b"hello world", 0);
        let frag = read_fragment(&buf).unwrap();
        assert_eq!(frag.msg_type, 1);
        assert_eq!(frag.total_length, 11);
        assert_eq!(frag.message_seq, 0);
        assert_eq!(frag.fragment_offset, 0);
        assert_eq!(frag.fragment, b"hello world");
        let mut r = Reassembler::new();
        let out = r.feed(frag).unwrap();
        assert_eq!(out.0, 1);
        assert_eq!(out.1, b"hello world");
        assert_eq!(r.expected_msg_seq(), 1);
    }

    // An empty body serializes as a bare header. Zero-length messages complete
    // as soon as they are the expected sequence number (seqs 0..=3 here).
    #[test]
    fn empty_body_message_completes() {
        let mut buf = Vec::new();
        write_message(&mut buf, 14, 3, b"", 0);
        assert_eq!(buf.len(), HEADER_LEN);
        let frag = read_fragment(&buf).unwrap();
        assert_eq!(frag.fragment.len(), 0);
        assert_eq!(frag.total_length, 0);
        let mut r = Reassembler::new();
        for s in 0..3 {
            let mut tmp = Vec::new();
            write_message(&mut tmp, 14, s as u16, b"", 0);
            let f = read_fragment(&tmp).unwrap();
            let res = r.feed(f).unwrap();
            assert_eq!(res.0, 14);
            assert!(res.1.is_empty());
        }
        let f3 = read_fragment(&buf).unwrap();
        let res = r.feed(f3).unwrap();
        assert_eq!(res.0, 14);
        assert!(res.1.is_empty());
    }

    // The second half of a message arriving first is buffered; the first half
    // then completes it.
    #[test]
    fn out_of_order_fragments_reassemble() {
        let body: Vec<u8> = (b'A'..=b'T').collect();
        assert_eq!(body.len(), 20);
        let total = body.len() as u32;
        let mut frag_late = Vec::new();
        write_fragment_header(&mut frag_late, 2, total, 0, 10, 10);
        frag_late.extend_from_slice(&body[10..]);
        let mut frag_early = Vec::new();
        write_fragment_header(&mut frag_early, 2, total, 0, 0, 10);
        frag_early.extend_from_slice(&body[..10]);
        let mut r = Reassembler::new();
        assert!(r.feed(read_fragment(&frag_late).unwrap()).is_none());
        let out = r.feed(read_fragment(&frag_early).unwrap()).unwrap();
        assert_eq!(out.0, 2);
        assert_eq!(out.1, body);
    }

    // After delivery, a retransmitted copy of the same message is ignored,
    // because its seq is now below the expected one.
    #[test]
    fn duplicate_fragment_is_idempotent() {
        let body = b"some message bytes".to_vec();
        let total = body.len() as u32;
        let mut buf = Vec::new();
        write_fragment_header(&mut buf, 11, total, 0, 0, total);
        buf.extend_from_slice(&body);
        let mut r = Reassembler::new();
        let out1 = r.feed(read_fragment(&buf).unwrap()).unwrap();
        assert_eq!(out1.1, body);
        assert_eq!(r.expected_msg_seq(), 1);
        assert!(r.feed(read_fragment(&buf).unwrap()).is_none());
    }

    // Repeating an identical partial fragment must not double-count bytes or
    // complete the message early.
    #[test]
    fn duplicate_partial_fragment_no_corruption() {
        let body: Vec<u8> = (0u8..20).collect();
        let total = body.len() as u32;
        let mut early = Vec::new();
        write_fragment_header(&mut early, 5, total, 0, 0, 10);
        early.extend_from_slice(&body[..10]);
        let mut late = Vec::new();
        write_fragment_header(&mut late, 5, total, 0, 10, 10);
        late.extend_from_slice(&body[10..]);
        let mut r = Reassembler::new();
        assert!(r.feed(read_fragment(&early).unwrap()).is_none());
        assert!(r.feed(read_fragment(&early).unwrap()).is_none());
        assert!(r.feed(read_fragment(&early).unwrap()).is_none());
        let out = r.feed(read_fragment(&late).unwrap()).unwrap();
        assert_eq!(out.0, 5);
        assert_eq!(out.1, body);
    }

    // offset 8 + length 5 exceeds total_length 10, so parsing fails.
    #[test]
    fn fragment_out_of_bounds_rejected() {
        let mut buf = Vec::new();
        write_fragment_header(&mut buf, 1, 10, 0, 8, 5);
        buf.extend_from_slice(&[0; 5]);
        match read_fragment(&buf) {
            Err(Error::Decode) => {}
            Ok(_) => panic!("expected Decode error, got Ok"),
            Err(e) => panic!("expected Decode, got {e:?}"),
        }
    }

    // The header declares 5 body bytes but only 3 follow, so parsing fails.
    #[test]
    fn fragment_length_truncated_rejected() {
        let mut buf = Vec::new();
        write_fragment_header(&mut buf, 1, 10, 0, 0, 5);
        buf.extend_from_slice(&[0; 3]);
        match read_fragment(&buf) {
            Err(Error::Decode) => {}
            Ok(_) => panic!("expected Decode error, got Ok"),
            Err(e) => panic!("expected Decode, got {e:?}"),
        }
    }

    // `len` covers only the first fragment, so the caller can slice off and
    // parse the next one from the same buffer.
    #[test]
    fn fragment_with_trailing_bytes_consumes_only_its_own() {
        let mut buf = Vec::new();
        write_fragment_header(&mut buf, 1, 4, 0, 0, 4);
        buf.extend_from_slice(b"AAAA");
        let tail_start = buf.len();
        write_fragment_header(&mut buf, 2, 6, 7, 0, 6);
        buf.extend_from_slice(b"BBBBBB");
        let f1 = read_fragment(&buf).unwrap();
        assert_eq!(f1.msg_type, 1);
        assert_eq!(f1.fragment, b"AAAA");
        assert_eq!(f1.len, tail_start);
        let f2 = read_fragment(&buf[f1.len..]).unwrap();
        assert_eq!(f2.msg_type, 2);
        assert_eq!(f2.message_seq, 7);
        assert_eq!(f2.fragment, b"BBBBBB");
    }

    // A 25-byte body with a 10-byte limit yields fragments of 10, 10 and 5
    // bytes that exactly fill the output.
    #[test]
    fn write_message_chunks_at_max_fragment_size() {
        let body: Vec<u8> = (0u8..25).collect();
        let mut out = Vec::new();
        write_message(&mut out, 7, 9, &body, 10);
        let f1 = read_fragment(&out).unwrap();
        assert_eq!(f1.message_seq, 9);
        assert_eq!(f1.fragment_offset, 0);
        assert_eq!(f1.fragment.len(), 10);
        let f2 = read_fragment(&out[f1.len..]).unwrap();
        assert_eq!(f2.fragment_offset, 10);
        assert_eq!(f2.fragment.len(), 10);
        let f3 = read_fragment(&out[f1.len + f2.len..]).unwrap();
        assert_eq!(f3.fragment_offset, 20);
        assert_eq!(f3.fragment.len(), 5);
        assert_eq!(f1.len + f2.len + f3.len, out.len());
    }

    // A 1024-byte message split into 256 four-byte fragments, fed in an
    // interleaved order, must complete exactly once with the original bytes.
    #[test]
    fn fragmented_reassembly_with_bitmap_transitions() {
        let body: Vec<u8> = (0..1024).map(|i| (i & 0xff) as u8).collect();
        let total = body.len() as u32;
        let mut frags: Vec<Vec<u8>> = (0..256)
            .map(|i| {
                let off = i * 4;
                let mut f = Vec::new();
                write_fragment_header(&mut f, 11, total, 0, off as u32, 4);
                f.extend_from_slice(&body[off..off + 4]);
                f
            })
            .collect();
        // Build a scrambled delivery order by alternately taking from the
        // front and from near the middle of the remaining list.
        let mid = frags.len() / 2;
        let mut shuffled = Vec::new();
        for i in 0..mid {
            shuffled.push(frags.swap_remove(0));
            if !frags.is_empty() {
                shuffled.push(frags.remove(mid - i - 1));
            }
        }
        shuffled.extend(frags);
        let mut r = Reassembler::new();
        let mut completion = None;
        for f in &shuffled {
            let frag = read_fragment(f).unwrap();
            if let Some(out) = r.feed(frag) {
                assert!(completion.is_none(), "completed twice!");
                completion = Some(out);
            }
        }
        let (ty, got) = completion.expect("message must complete");
        assert_eq!(ty, 11);
        assert_eq!(got, body);
        assert_eq!(r.expected_msg_seq(), 1);
    }

    // A fragment claiming a total_length above the configured limit is dropped
    // without advancing the sequence; a legitimate message still completes.
    #[test]
    fn with_limits_rejects_oversized_total_length() {
        let mut r = Reassembler::with_limits(1024, 1);
        let mut big = Vec::new();
        write_fragment_header(&mut big, 1, 64 * 1024, 0, 0, 1);
        big.push(0xab);
        assert!(r.feed(read_fragment(&big).unwrap()).is_none());
        assert_eq!(r.expected_msg_seq(), 0, "oversized claim must not advance");
        let mut ok = Vec::new();
        write_message(&mut ok, 1, 0, b"legit client hello", 0);
        let out = r.feed(read_fragment(&ok).unwrap()).unwrap();
        assert_eq!(out.1, b"legit client hello");
    }

    // With room for one buffered message, a second message_seq is dropped
    // while seq 0 is still incomplete; seq 0 then finishes normally.
    #[test]
    fn with_limits_caps_in_progress_messages() {
        let mut r = Reassembler::with_limits(1024, 1);
        let mut half = Vec::new();
        write_fragment_header(&mut half, 1, 20, 0, 0, 10);
        half.extend_from_slice(&[0x11; 10]);
        assert!(r.feed(read_fragment(&half).unwrap()).is_none());
        let mut other = Vec::new();
        write_message(&mut other, 1, 1, b"future", 0);
        assert!(r.feed(read_fragment(&other).unwrap()).is_none());
        let mut rest = Vec::new();
        write_fragment_header(&mut rest, 1, 20, 0, 10, 10);
        rest.extend_from_slice(&[0x22; 10]);
        let out = r.feed(read_fragment(&rest).unwrap()).unwrap();
        assert_eq!(out.1.len(), 20);
    }

    // Message 1 completed before message 0 stays buffered. Once 0 has been
    // delivered, feeding message 1 again releases it.
    #[test]
    fn out_of_order_message_sequence_buffered() {
        let mut b0 = Vec::new();
        write_message(&mut b0, 1, 0, b"zero", 0);
        let mut b1 = Vec::new();
        write_message(&mut b1, 1, 1, b"one!", 0);
        let mut r = Reassembler::new();
        assert!(r.feed(read_fragment(&b1).unwrap()).is_none());
        let out0 = r.feed(read_fragment(&b0).unwrap()).unwrap();
        assert_eq!(out0.0, 1);
        assert_eq!(out0.1, b"zero");
        let out1 = r.feed(read_fragment(&b1).unwrap()).unwrap();
        assert_eq!(out1.0, 1);
        assert_eq!(out1.1, b"one!");
    }

    // A fragment whose bytes contradict already-received bytes is dropped
    // whole, so the first-received data survives.
    #[test]
    fn conflicting_overlap_dropped_genuine_bytes_win() {
        let body: Vec<u8> = (0u8..20).collect();
        let total = body.len() as u32;
        let mut early = Vec::new();
        write_fragment_header(&mut early, 5, total, 0, 0, 10);
        early.extend_from_slice(&body[..10]);
        let mut late = Vec::new();
        write_fragment_header(&mut late, 5, total, 0, 10, 10);
        late.extend_from_slice(&body[10..]);
        let mut r = Reassembler::new();
        assert!(r.feed(read_fragment(&early).unwrap()).is_none());
        let mut spoof_early = Vec::new();
        write_fragment_header(&mut spoof_early, 5, total, 0, 0, 10);
        spoof_early.extend_from_slice(&[0xff; 10]);
        assert!(r.feed(read_fragment(&spoof_early).unwrap()).is_none());
        let out = r.feed(read_fragment(&late).unwrap()).unwrap();
        assert_eq!(
            out.1, body,
            "genuine bytes must survive a conflicting spoof"
        );
    }

    // Overlapping fragments that agree on their shared bytes are accepted and
    // contribute their new bytes.
    #[test]
    fn idempotent_duplicate_overlap_still_accepted() {
        let body: Vec<u8> = (0u8..16).collect();
        let total = body.len() as u32;
        let mut first = Vec::new();
        write_fragment_header(&mut first, 6, total, 0, 0, 8);
        first.extend_from_slice(&body[..8]);
        let mut overlap = Vec::new();
        write_fragment_header(&mut overlap, 6, total, 0, 4, 8);
        overlap.extend_from_slice(&body[4..12]);
        let mut tail = Vec::new();
        write_fragment_header(&mut tail, 6, total, 0, 12, 4);
        tail.extend_from_slice(&body[12..]);
        let mut r = Reassembler::new();
        assert!(r.feed(read_fragment(&first).unwrap()).is_none());
        assert!(r.feed(read_fragment(&overlap).unwrap()).is_none());
        let out = r.feed(read_fragment(&tail).unwrap()).unwrap();
        assert_eq!(out.1, body);
    }
}