use std::num::NonZeroUsize;
use bincode::{
borrow_decode_from_slice,
config,
encode_to_vec,
error::{DecodeError, EncodeError},
};
use super::{FragmentHeader, FragmentIndex, MessageId};
use crate::byte_order::{read_network_u16, write_network_u16};
/// Four-byte marker (ASCII `FRAG`) that every encoded fragment frame starts with.
pub const FRAGMENT_MAGIC: &[u8; 4] = &[b'F', b'R', b'A', b'G'];
/// Number of framing bytes that [`encode_fragment_payload`] writes in front of
/// the caller's payload: the [`FRAGMENT_MAGIC`] marker, the `u16` header-length
/// prefix, and the bincode-encoded header.
///
/// The header size is measured on a probe header built from
/// `MessageId::new(0)`, `FragmentIndex::zero()` and `false`.
///
/// NOTE(review): bincode's `config::standard()` selects variable-length integer
/// encoding. If `MessageId` / `FragmentIndex` encode as plain integers, larger
/// ids or indices yield a longer header than this probe, so the value returned
/// here would only be a lower bound — confirm before using it to size MTU-bound
/// fragments.
///
/// # Panics
///
/// Panics if the probe header fails to encode or if the computed total is zero.
/// Neither is expected for a fixed, in-memory header.
#[must_use]
pub fn fragment_overhead() -> NonZeroUsize {
    let probe = FragmentHeader::new(MessageId::new(0), FragmentIndex::zero(), false);
    let encoded_header_len = encode_to_vec(probe, config::standard())
        .map(|bytes| bytes.len())
        .unwrap_or_else(|err| {
            panic!("fragment header encoding must be infallible for constants: {err}")
        });
    let prefix_len = FRAGMENT_MAGIC.len() + std::mem::size_of::<u16>();
    let overhead = prefix_len + encoded_header_len;
    NonZeroUsize::new(overhead)
        .unwrap_or_else(|| panic!("fragment overhead must be non-zero (computed {overhead})"))
}
/// Builds a fragment frame laid out as
/// `FRAGMENT_MAGIC | header length (u16 via write_network_u16) | header | payload`,
/// where the header is serialised with bincode's `config::standard()`.
///
/// # Errors
///
/// Returns the bincode [`EncodeError`] if `header` cannot be encoded. Returns
/// `EncodeError::Other` if the encoded header is longer than `u16::MAX` bytes and
/// therefore cannot be described by the two-byte length prefix.
pub fn encode_fragment_payload(
    header: FragmentHeader,
    payload: &[u8],
) -> Result<Vec<u8>, bincode::error::EncodeError> {
    let encoded_header = encode_to_vec(header, config::standard())?;
    let Ok(encoded_header_len) = u16::try_from(encoded_header.len()) else {
        return Err(EncodeError::Other(
            "fragment header length must fit within u16::MAX",
        ));
    };
    let length_prefix = write_network_u16(encoded_header_len);
    // Joined in wire order into a single freshly allocated Vec.
    let frame_parts: [&[u8]; 4] = [FRAGMENT_MAGIC, &length_prefix, &encoded_header, payload];
    Ok(frame_parts.concat())
}
/// Splits a frame produced by [`encode_fragment_payload`] into its decoded
/// header and the payload bytes that follow it.
///
/// Returns `Ok(None)` when `payload` is not a fragment frame. That is the case
/// when it does not start with [`FRAGMENT_MAGIC`], or when it starts with the
/// marker but has no room for the two-byte header-length prefix. On success the
/// returned slice borrows from `payload`.
///
/// # Errors
///
/// * [`DecodeError::UnexpectedEnd`] if the advertised header length runs past
///   the end of `payload`. `additional` is the number of missing bytes.
/// * Any error bincode reports while decoding the header bytes.
/// * [`DecodeError::OtherString`] (`"fragment header length mismatch"`) if
///   bincode does not consume exactly the advertised header length.
pub fn decode_fragment_payload(
    payload: &[u8],
) -> Result<Option<(FragmentHeader, &[u8])>, DecodeError> {
    // Without the marker, this is an ordinary (non-fragment) payload.
    let Some(after_magic) = payload.strip_prefix(FRAGMENT_MAGIC.as_slice()) else {
        return Ok(None);
    };
    // A marker with fewer than two trailing bytes cannot carry the length
    // prefix; it is also treated as "not a fragment" rather than an error.
    let [len_hi, len_lo, after_len @ ..] = after_magic else {
        return Ok(None);
    };
    let header_len = usize::from(read_network_u16([*len_hi, *len_lo]));
    if after_len.len() < header_len {
        return Err(DecodeError::UnexpectedEnd {
            additional: header_len - after_len.len(),
        });
    }
    // Bounds checked above, so this split cannot panic.
    let (header_bytes, remainder) = after_len.split_at(header_len);
    let (header, consumed) =
        borrow_decode_from_slice::<FragmentHeader, _>(header_bytes, config::standard())?;
    // The prefix must describe the header exactly. Unconsumed bytes inside the
    // advertised span mean the frame is corrupt or was produced differently.
    if consumed != header_len {
        return Err(DecodeError::OtherString(
            "fragment header length mismatch".to_string(),
        ));
    }
    Ok(Some((header, remainder)))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip_fragment_payload() {
        let header = FragmentHeader::new(MessageId::new(9), FragmentIndex::new(2), true);
        let payload = [1_u8, 2, 3, 4];
        let encoded = encode_fragment_payload(header, &payload).expect("encode fragment");
        let decoded = decode_fragment_payload(&encoded)
            .expect("decode fragment")
            .expect("fragment marker present");
        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, payload);
    }

    /// An empty payload yields a frame that is exactly the framing overhead,
    /// and decodes back to an empty remainder.
    #[test]
    fn round_trip_empty_payload_is_exactly_overhead() {
        // Same header values that `fragment_overhead` measures.
        let header = FragmentHeader::new(MessageId::new(0), FragmentIndex::zero(), false);
        let encoded = encode_fragment_payload(header, &[]).expect("encode fragment");
        assert_eq!(encoded.len(), fragment_overhead().get());
        let (decoded_header, remainder) = decode_fragment_payload(&encoded)
            .expect("decode fragment")
            .expect("fragment marker present");
        assert_eq!(decoded_header, header);
        assert!(remainder.is_empty());
    }

    #[test]
    fn decode_returns_none_for_non_fragment_payloads() {
        let payload = [0_u8, 1, 2, 3];
        assert!(
            decode_fragment_payload(&payload)
                .expect("decode ok")
                .is_none()
        );
    }

    /// A payload long enough to hold magic + length, but with the wrong magic,
    /// is still "not a fragment" rather than an error.
    #[test]
    fn decode_returns_none_for_wrong_magic_of_full_length() {
        let payload = [b'F', b'R', b'A', b'X', 0, 0, 9, 9];
        assert!(
            decode_fragment_payload(&payload)
                .expect("decode ok")
                .is_none()
        );
    }

    #[test]
    fn decode_returns_none_when_shorter_than_prefix_and_length() {
        let payload = [b'F', b'R', b'A', b'G', 0];
        assert!(
            decode_fragment_payload(&payload)
                .expect("decode ok")
                .is_none()
        );
    }

    #[test]
    fn fragment_overhead_matches_encoded_header() {
        let header = FragmentHeader::new(MessageId::new(1), FragmentIndex::zero(), true);
        let encoded = encode_to_vec(header, config::standard()).expect("encode header");
        // Exactly u16::MAX bytes is still representable by the length prefix.
        assert!(
            encoded.len() <= usize::from(u16::MAX),
            "header must fit in u16"
        );
        let expected = FRAGMENT_MAGIC.len() + std::mem::size_of::<u16>() + encoded.len();
        assert_eq!(fragment_overhead().get(), expected);
    }

    /// Frames `header` with a length prefix and header bytes chosen by
    /// `manipulate`, decodes the result, and hands the expected error to
    /// `assert_error`.
    fn assert_fragment_decode_error<F, E>(header: FragmentHeader, manipulate: F, assert_error: E)
    where
        F: FnOnce(Vec<u8>) -> (u16, Vec<u8>),
        E: FnOnce(DecodeError),
    {
        let encoded = encode_to_vec(header, config::standard()).expect("encode header");
        let (advertised_len, header_bytes) = manipulate(encoded);
        let mut payload = Vec::new();
        payload.extend_from_slice(FRAGMENT_MAGIC);
        let advertised_len_bytes = write_network_u16(advertised_len);
        payload.extend_from_slice(&advertised_len_bytes);
        payload.extend_from_slice(&header_bytes);
        let err = decode_fragment_payload(&payload).expect_err("expected decode failure");
        assert_error(err);
    }

    #[test]
    fn decode_fragment_payload_rejects_truncated_header() {
        let header = FragmentHeader::new(MessageId::new(2), FragmentIndex::new(1), false);
        assert_fragment_decode_error(
            header,
            |encoded| {
                let advertised_len: u16 = (encoded.len() + 4)
                    .try_into()
                    .expect("encoded header length must stay within u16");
                (advertised_len, encoded)
            },
            |err| match err {
                DecodeError::UnexpectedEnd { .. } => {}
                other => panic!("expected UnexpectedEnd, got {other:?}"),
            },
        );
    }

    #[test]
    fn decode_fragment_payload_rejects_missing_header_bytes() {
        let advertised_len: u16 = 4;
        let mut payload = Vec::new();
        payload.extend_from_slice(FRAGMENT_MAGIC);
        let advertised_len_bytes = write_network_u16(advertised_len);
        payload.extend_from_slice(&advertised_len_bytes);
        let err = decode_fragment_payload(&payload).expect_err("expected decode failure");
        match err {
            DecodeError::UnexpectedEnd { additional } => assert_eq!(additional, 4),
            other => panic!("expected UnexpectedEnd, got {other:?}"),
        }
    }

    #[test]
    fn decode_fragment_payload_rejects_length_mismatch() {
        let header = FragmentHeader::new(MessageId::new(3), FragmentIndex::new(5), true);
        assert_fragment_decode_error(
            header,
            |mut encoded| {
                encoded.extend_from_slice(&[0_u8, 1]);
                let advertised_len: u16 = encoded
                    .len()
                    .try_into()
                    .expect("padded header length must fit in u16");
                (advertised_len, encoded)
            },
            |err| match err {
                DecodeError::OtherString(msg) => {
                    assert_eq!(msg, "fragment header length mismatch");
                }
                other => panic!("expected length mismatch error, got {other:?}"),
            },
        );
    }
}