use std::fmt;
/// Errors reported while splitting, serialising, decoding, or reassembling
/// packets and NAL units.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitterError {
    /// A size limit is below the smallest value the operation accepts.
    MaxSizeTooSmall {
        /// The limit that was supplied.
        max_size: usize,
        /// The smallest limit that would have been accepted.
        min_required: usize,
    },
    /// A fragment header was truncated or carried invalid values.
    MalformedFragmentHeader {
        /// Location reported by the function that detected the problem
        /// (a byte offset when decoding a stream, the position of the
        /// fragment in the input slice when reassembling).
        offset: usize,
    },
    /// Reassembly was attempted without every fragment of the packet.
    MissingFragments {
        /// Number of fragments the packet consists of.
        total: u16,
        /// Number of distinct fragment indices that were present.
        received: usize,
    },
    /// Fragments disagree on packet id or fragment count, or a fragment
    /// index is not below the declared total.
    InconsistentFragments,
    /// No data (or no fragments) was supplied.
    EmptyPacket,
    /// The fragment count exceeds the supported limit; carries the count as
    /// far as it is representable in a `u16`.
    TooManyFragments(u16),
}

impl fmt::Display for SplitterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `min_required` is not a header size: callers pass "header size
            // plus one" or plain `1`, so the message only states what the two
            // fields actually mean.
            Self::MaxSizeTooSmall {
                max_size,
                min_required,
            } => write!(
                f,
                "max size {max_size} is smaller than the minimum required size {min_required}"
            ),
            Self::MalformedFragmentHeader { offset } => {
                write!(f, "malformed fragment header at offset {offset}")
            }
            Self::MissingFragments { total, received } => write!(
                f,
                "missing fragments: expected {total}, received {received}"
            ),
            Self::InconsistentFragments => write!(f, "inconsistent fragment indices"),
            Self::EmptyPacket => write!(f, "packet is empty"),
            Self::TooManyFragments(n) => write!(f, "too many fragments: {n}"),
        }
    }
}

impl std::error::Error for SplitterError {}
/// Convenience alias for results whose error type is [`SplitterError`].
pub type SplitterResult<T> = Result<T, SplitterError>;
/// Size in bytes of a serialised fragment header: three big-endian `u16`
/// fields (packet id, fragment index, total fragment count).
pub const FRAGMENT_HEADER_SIZE: usize = 6;
/// Largest number of fragments accepted for a single packet, both when
/// splitting and when reassembling.
const MAX_FRAGMENTS: u16 = 4096;
/// One piece of a packet that has been split for transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fragment {
/// Identifier shared by every fragment of the same packet.
pub packet_id: u16,
/// Zero-based position of this fragment within its packet.
pub fragment_index: u16,
/// Number of fragments the packet was split into.
pub total_fragments: u16,
/// Bytes of the original packet carried by this fragment.
pub payload: Vec<u8>,
}
impl Fragment {
    /// Serialises the fragment: the three header fields as big-endian `u16`
    /// values (packet id, fragment index, total fragments) followed by the
    /// payload bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let header_fields = [self.packet_id, self.fragment_index, self.total_fragments];
        let mut bytes = Vec::with_capacity(FRAGMENT_HEADER_SIZE + self.payload.len());
        for field in header_fields.iter() {
            bytes.extend_from_slice(&field.to_be_bytes());
        }
        bytes.extend_from_slice(&self.payload);
        bytes
    }

    /// Parses a fragment from `data`, which must start with a header as
    /// written by [`Fragment::to_bytes`]; everything after the header becomes
    /// the payload.
    ///
    /// # Errors
    ///
    /// * [`SplitterError::MalformedFragmentHeader`] (offset 0) when `data` is
    ///   shorter than the header or the header declares zero fragments.
    /// * [`SplitterError::InconsistentFragments`] when the fragment index is
    ///   not below the declared total.
    pub fn from_bytes(data: &[u8]) -> SplitterResult<Self> {
        if data.len() < FRAGMENT_HEADER_SIZE {
            return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
        }
        let (header, payload) = data.split_at(FRAGMENT_HEADER_SIZE);
        // Reads the big-endian u16 that starts at byte `at` of the header.
        let field = |at: usize| u16::from_be_bytes([header[at], header[at + 1]]);
        let (packet_id, fragment_index, total_fragments) = (field(0), field(2), field(4));

        match total_fragments {
            0 => Err(SplitterError::MalformedFragmentHeader { offset: 0 }),
            total if fragment_index >= total => Err(SplitterError::InconsistentFragments),
            _ => Ok(Self {
                packet_id,
                fragment_index,
                total_fragments,
                payload: payload.to_vec(),
            }),
        }
    }
}
/// Settings that control how packets are split into fragments.
#[derive(Debug, Clone)]
pub struct SplitterConfig {
/// Maximum size in bytes of one serialised fragment, header included.
pub max_packet_size: usize,
}
impl SplitterConfig {
    /// Creates a configuration whose fragments are at most `max_packet_size`
    /// bytes long, header included.
    ///
    /// # Errors
    ///
    /// Returns [`SplitterError::MaxSizeTooSmall`] when `max_packet_size`
    /// leaves no room for at least one payload byte after the header.
    pub fn new(max_packet_size: usize) -> SplitterResult<Self> {
        if max_packet_size > FRAGMENT_HEADER_SIZE {
            Ok(Self { max_packet_size })
        } else {
            Err(SplitterError::MaxSizeTooSmall {
                max_size: max_packet_size,
                min_required: FRAGMENT_HEADER_SIZE + 1,
            })
        }
    }

    /// Number of payload bytes that fit into one fragment.
    ///
    /// `max_packet_size` is a public field, so a value smaller than the
    /// header can be stored without going through [`SplitterConfig::new`].
    /// The subtraction saturates so such a config yields `0` instead of
    /// panicking (debug builds) or wrapping to a huge value (release builds).
    pub fn max_payload_per_fragment(&self) -> usize {
        self.max_packet_size.saturating_sub(FRAGMENT_HEADER_SIZE)
    }
}
pub fn split_packet(
packet_id: u16,
data: &[u8],
config: &SplitterConfig,
) -> SplitterResult<Vec<Fragment>> {
if data.is_empty() {
return Err(SplitterError::EmptyPacket);
}
let max_payload = config.max_payload_per_fragment();
let total_fragments = (data.len() + max_payload - 1) / max_payload;
if total_fragments > MAX_FRAGMENTS as usize {
return Err(SplitterError::TooManyFragments(total_fragments as u16));
}
let total_u16 = total_fragments as u16;
let mut fragments = Vec::with_capacity(total_fragments);
for (idx, chunk) in data.chunks(max_payload).enumerate() {
fragments.push(Fragment {
packet_id,
fragment_index: idx as u16,
total_fragments: total_u16,
payload: chunk.to_vec(),
});
}
Ok(fragments)
}
/// Rebuilds the original packet from its fragments, which may be supplied in
/// any order.
///
/// The first fragment defines the expected packet id and fragment count. If
/// the same index occurs more than once, the payload appearing last in
/// `fragments` is used.
///
/// # Errors
///
/// * [`SplitterError::EmptyPacket`] when `fragments` is empty.
/// * [`SplitterError::MalformedFragmentHeader`] when the first fragment
///   declares zero fragments (offset 0), or when a fragment's index is not
///   below the total (offset is that fragment's position in `fragments`).
/// * [`SplitterError::TooManyFragments`] when the declared total exceeds
///   `MAX_FRAGMENTS`.
/// * [`SplitterError::InconsistentFragments`] when a fragment disagrees with
///   the first one on packet id or total.
/// * [`SplitterError::MissingFragments`] when some index has no fragment.
pub fn reassemble_fragments(fragments: &[Fragment]) -> SplitterResult<Vec<u8>> {
    let first = match fragments.first() {
        Some(fragment) => fragment,
        None => return Err(SplitterError::EmptyPacket),
    };
    let expected_id = first.packet_id;
    let expected_total = first.total_fragments;
    match expected_total {
        0 => return Err(SplitterError::MalformedFragmentHeader { offset: 0 }),
        n if n > MAX_FRAGMENTS => return Err(SplitterError::TooManyFragments(n)),
        _ => {}
    }

    // One slot per fragment index. Validation and placement share a loop:
    // the slots are local, so an early error return discards them unseen.
    let slot_count = usize::from(expected_total);
    let mut by_index: Vec<Option<&[u8]>> = vec![None; slot_count];
    for (position, fragment) in fragments.iter().enumerate() {
        let same_packet =
            fragment.packet_id == expected_id && fragment.total_fragments == expected_total;
        if !same_packet {
            return Err(SplitterError::InconsistentFragments);
        }
        if fragment.fragment_index >= expected_total {
            return Err(SplitterError::MalformedFragmentHeader { offset: position });
        }
        by_index[usize::from(fragment.fragment_index)] = Some(fragment.payload.as_slice());
    }

    let received = by_index.iter().flatten().count();
    if received < slot_count {
        return Err(SplitterError::MissingFragments {
            total: expected_total,
            received,
        });
    }

    let packet_len: usize = by_index.iter().flatten().map(|payload| payload.len()).sum();
    let mut packet = Vec::with_capacity(packet_len);
    for payload in by_index.into_iter().flatten() {
        packet.extend_from_slice(payload);
    }
    Ok(packet)
}
/// Cuts `nal` into consecutive borrowed pieces of at most `max_nal_size`
/// bytes; only the final piece may be shorter.
///
/// # Errors
///
/// * [`SplitterError::EmptyPacket`] when `nal` is empty.
/// * [`SplitterError::MaxSizeTooSmall`] when `max_nal_size` is zero.
pub fn split_nal_unit(nal: &[u8], max_nal_size: usize) -> SplitterResult<Vec<&[u8]>> {
    // The emptiness check comes first so an empty NAL is reported as such
    // even when the size limit is also invalid.
    match (nal.is_empty(), max_nal_size) {
        (true, _) => Err(SplitterError::EmptyPacket),
        (false, 0) => Err(SplitterError::MaxSizeTooSmall {
            max_size: 0,
            min_required: 1,
        }),
        (false, limit) => Ok(nal.chunks(limit).collect()),
    }
}
/// Returns the given NAL units in order, with every unit longer than
/// `max_size` replaced by its pieces from [`split_nal_unit`].
///
/// Units that already fit (including empty ones) are passed through
/// unchanged.
///
/// # Errors
///
/// Returns [`SplitterError::MaxSizeTooSmall`] when `max_size` is zero, and
/// forwards any error from [`split_nal_unit`].
pub fn enforce_max_nal_size<'a>(
    nals: &[&'a [u8]],
    max_size: usize,
) -> SplitterResult<Vec<&'a [u8]>> {
    if max_size == 0 {
        return Err(SplitterError::MaxSizeTooSmall {
            max_size: 0,
            min_required: 1,
        });
    }

    let mut bounded = Vec::new();
    for &unit in nals {
        if unit.len() > max_size {
            bounded.extend(split_nal_unit(unit, max_size)?);
        } else {
            bounded.push(unit);
        }
    }
    Ok(bounded)
}
/// Concatenates `fragments` into one byte stream in which every serialised
/// fragment is preceded by its length as a big-endian `u16`.
///
/// # Panics
///
/// Panics when a serialised fragment (header plus payload) is longer than
/// `u16::MAX` bytes. Such a length cannot be stored in the prefix; writing a
/// truncated value would silently produce a stream that cannot be decoded.
pub fn encode_fragment_stream(fragments: &[Fragment]) -> Vec<u8> {
    let total_bytes: usize = fragments
        .iter()
        .map(|f| 2 + FRAGMENT_HEADER_SIZE + f.payload.len())
        .sum();
    let mut out = Vec::with_capacity(total_bytes);
    for frag in fragments {
        let frag_bytes = frag.to_bytes();
        assert!(
            frag_bytes.len() <= usize::from(u16::MAX),
            "serialised fragment is {} bytes, which does not fit the u16 length prefix",
            frag_bytes.len()
        );
        // Lossless: the length was checked against `u16::MAX` just above.
        let frag_len = frag_bytes.len() as u16;
        out.extend_from_slice(&frag_len.to_be_bytes());
        out.extend_from_slice(&frag_bytes);
    }
    out
}
/// Parses a stream of length-prefixed fragments: each record is a big-endian
/// `u16` length followed by that many bytes of serialised fragment.
///
/// An empty `data` yields an empty vector.
///
/// # Errors
///
/// * [`SplitterError::MalformedFragmentHeader`] when a length prefix or a
///   fragment body is truncated, or when a fragment header is invalid. The
///   offset is always a byte position within `data`.
/// * Any other error from [`Fragment::from_bytes`] is forwarded unchanged.
pub fn decode_fragment_stream(data: &[u8]) -> SplitterResult<Vec<Fragment>> {
    let mut fragments = Vec::new();
    let mut offset = 0usize;
    while offset < data.len() {
        let remaining = &data[offset..];
        if remaining.len() < 2 {
            return Err(SplitterError::MalformedFragmentHeader { offset });
        }
        let frag_len = usize::from(u16::from_be_bytes([remaining[0], remaining[1]]));
        let body_start = offset + 2;
        let body = match remaining[2..].get(..frag_len) {
            Some(body) => body,
            None => return Err(SplitterError::MalformedFragmentHeader { offset: body_start }),
        };
        // `Fragment::from_bytes` reports offsets relative to the fragment it
        // was given; rebase them so the error points into the stream, like
        // the other errors raised by this function.
        let frag = Fragment::from_bytes(body).map_err(|err| match err {
            SplitterError::MalformedFragmentHeader { offset: inner } => {
                SplitterError::MalformedFragmentHeader {
                    offset: body_start + inner,
                }
            }
            other => other,
        })?;
        fragments.push(frag);
        offset = body_start + frag_len;
    }
    Ok(fragments)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a valid config; panics when `max` is rejected.
    fn make_config(max: usize) -> SplitterConfig {
        SplitterConfig::new(max).unwrap()
    }

    /// Data that fits into one fragment comes back as a single fragment.
    #[test]
    fn test_split_single_fragment() {
        let message = b"hello world";
        let fragments = split_packet(1, message, &make_config(64)).unwrap();
        assert_eq!(fragments.len(), 1);

        let only = &fragments[0];
        assert_eq!(only.packet_id, 1);
        assert_eq!(only.fragment_index, 0);
        assert_eq!(only.total_fragments, 1);
        assert_eq!(only.payload, message);
    }

    /// Ten bytes with four payload bytes per fragment give three fragments
    /// with consistent ids, totals, and ascending indices.
    #[test]
    fn test_split_multiple_fragments() {
        let bytes: Vec<u8> = (0..10).collect();
        let config = make_config(FRAGMENT_HEADER_SIZE + 4);
        let fragments = split_packet(42, &bytes, &config).unwrap();

        assert_eq!(fragments.len(), 3);
        assert!(fragments.iter().all(|frag| frag.packet_id == 42));
        assert!(fragments.iter().all(|frag| frag.total_fragments == 3));
        for (frag, expected_index) in fragments.iter().zip(0u16..) {
            assert_eq!(frag.fragment_index, expected_index);
        }
    }

    /// Splitting and reassembling in order restores the input.
    #[test]
    fn test_reassemble_ordered() {
        let original: Vec<u8> = (0u8..100).collect();
        let config = make_config(FRAGMENT_HEADER_SIZE + 10);
        let fragments = split_packet(7, &original, &config).unwrap();
        assert_eq!(reassemble_fragments(&fragments).unwrap(), original);
    }

    /// Fragment order does not matter for reassembly.
    #[test]
    fn test_reassemble_unordered() {
        let original: Vec<u8> = (0u8..30).collect();
        let config = make_config(FRAGMENT_HEADER_SIZE + 10);
        let reversed: Vec<Fragment> = split_packet(99, &original, &config)
            .unwrap()
            .into_iter()
            .rev()
            .collect();
        assert_eq!(reassemble_fragments(&reversed).unwrap(), original);
    }

    /// Dropping one fragment makes reassembly fail with `MissingFragments`.
    #[test]
    fn test_reassemble_missing_fragment_error() {
        let original: Vec<u8> = (0u8..20).collect();
        let config = make_config(FRAGMENT_HEADER_SIZE + 5);
        let mut fragments = split_packet(1, &original, &config).unwrap();
        fragments.retain(|frag| frag.fragment_index != 1);

        let outcome = reassemble_fragments(&fragments).unwrap_err();
        assert!(matches!(outcome, SplitterError::MissingFragments { .. }));
    }

    /// `to_bytes` followed by `from_bytes` is the identity.
    #[test]
    fn test_fragment_serialise_deserialise() {
        let original = Fragment {
            packet_id: 5,
            fragment_index: 0,
            total_fragments: 1,
            payload: vec![0xDE, 0xAD, 0xBE, 0xEF],
        };
        let round_tripped = Fragment::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// A packet survives split -> stream encode -> stream decode -> reassemble.
    #[test]
    fn test_encode_decode_fragment_stream() {
        let original: Vec<u8> = (0u8..50).collect();
        let config = make_config(FRAGMENT_HEADER_SIZE + 10);
        let fragments = split_packet(3, &original, &config).unwrap();

        let wire = encode_fragment_stream(&fragments);
        let decoded = decode_fragment_stream(&wire).unwrap();
        assert_eq!(reassemble_fragments(&decoded).unwrap(), original);
    }

    /// 100 bytes in pieces of 30 give three full pieces and a 10-byte tail.
    #[test]
    fn test_split_nal_unit() {
        let unit = [0xAAu8; 100];
        let pieces = split_nal_unit(&unit, 30).unwrap();
        assert_eq!(pieces.len(), 4);
        assert_eq!(pieces[0].len(), 30);
        assert_eq!(pieces[3].len(), 10);
    }

    /// Small units pass through; oversized units are replaced by their pieces.
    #[test]
    fn test_enforce_max_nal_size() {
        let small = [0x01u8; 10];
        let large = [0x02u8; 50];
        let units = [&small[..], &large[..]];

        let bounded = enforce_max_nal_size(&units, 20).unwrap();
        assert_eq!(bounded.len(), 4);
        assert_eq!(bounded[0].len(), 10);
        assert!(bounded[1..].iter().all(|piece| piece.len() <= 20));
    }

    /// A packet size equal to the header size leaves no room for payload.
    #[test]
    fn test_config_too_small_error() {
        let outcome = SplitterConfig::new(FRAGMENT_HEADER_SIZE).unwrap_err();
        assert!(matches!(outcome, SplitterError::MaxSizeTooSmall { .. }));
    }

    /// Empty input is rejected before any splitting happens.
    #[test]
    fn test_empty_packet_split_error() {
        let config = make_config(64);
        let outcome = split_packet(0, &[], &config).unwrap_err();
        assert_eq!(outcome, SplitterError::EmptyPacket);
    }

    /// Fragments from different packets cannot be reassembled together.
    #[test]
    fn test_inconsistent_fragments_error() {
        let from_packet_one = Fragment {
            packet_id: 1,
            fragment_index: 0,
            total_fragments: 2,
            payload: vec![0x01],
        };
        let from_packet_two = Fragment {
            packet_id: 2,
            fragment_index: 1,
            total_fragments: 2,
            payload: vec![0x02],
        };
        let outcome = reassemble_fragments(&[from_packet_one, from_packet_two]).unwrap_err();
        assert_eq!(outcome, SplitterError::InconsistentFragments);
    }
}