1use std::fmt;
11
/// Errors reported while splitting, serialising, parsing, or reassembling
/// packet fragments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitterError {
    /// A size limit is too small to be usable.
    MaxSizeTooSmall {
        /// The limit that was rejected.
        max_size: usize,
        /// The smallest limit that would have been accepted.
        min_required: usize,
    },

    /// Fragment data could not be parsed, or its header carried an invalid
    /// value.
    MalformedFragmentHeader {
        /// Position associated with the failure; its exact meaning depends on
        /// the function that reported the error.
        offset: usize,
    },

    /// Not every fragment of a packet was available for reassembly.
    MissingFragments {
        /// Number of fragments the packet is declared to consist of.
        total: u16,
        /// Number of distinct fragment positions that were actually filled.
        received: usize,
    },

    /// Fragments disagree with each other, or a fragment index does not fit
    /// the declared fragment count.
    InconsistentFragments,

    /// There was no data (or no fragment) to work on.
    EmptyPacket,

    /// A packet needs, or declares, more fragments than are supported; the
    /// payload is the fragment count reported by the detecting function.
    TooManyFragments(u16),
}
45
46impl fmt::Display for SplitterError {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 match self {
49 Self::MaxSizeTooSmall {
50 max_size,
51 min_required,
52 } => {
53 write!(
54 f,
55 "max packet size {max_size} is smaller than fragment header size {min_required}"
56 )
57 }
58 Self::MalformedFragmentHeader { offset } => {
59 write!(f, "malformed fragment header at offset {offset}")
60 }
61 Self::MissingFragments { total, received } => {
62 write!(
63 f,
64 "missing fragments: expected {total}, received {received}"
65 )
66 }
67 Self::InconsistentFragments => write!(f, "inconsistent fragment indices"),
68 Self::EmptyPacket => write!(f, "packet is empty"),
69 Self::TooManyFragments(n) => write!(f, "too many fragments: {n}"),
70 }
71 }
72}
73
74impl std::error::Error for SplitterError {}
75
76pub type SplitterResult<T> = Result<T, SplitterError>;
78
/// Size in bytes of a serialised fragment header: three `u16` fields
/// (packet id, fragment index, total fragment count).
pub const FRAGMENT_HEADER_SIZE: usize = 3 * std::mem::size_of::<u16>();

/// Largest number of fragments a single packet may be split into or declare.
const MAX_FRAGMENTS: u16 = 4096;
96
/// One piece of a packet together with the bookkeeping needed to put the
/// packet back together.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fragment {
    /// Identifier shared by every fragment of the same packet.
    pub packet_id: u16,

    /// Zero-based position of this fragment within its packet.
    pub fragment_index: u16,

    /// Number of fragments the packet consists of.
    pub total_fragments: u16,

    /// The slice of packet data carried by this fragment.
    pub payload: Vec<u8>,
}
109
110impl Fragment {
111 pub fn to_bytes(&self) -> Vec<u8> {
113 let mut out = Vec::with_capacity(FRAGMENT_HEADER_SIZE + self.payload.len());
114 out.extend_from_slice(&self.packet_id.to_be_bytes());
115 out.extend_from_slice(&self.fragment_index.to_be_bytes());
116 out.extend_from_slice(&self.total_fragments.to_be_bytes());
117 out.extend_from_slice(&self.payload);
118 out
119 }
120
121 pub fn from_bytes(data: &[u8]) -> SplitterResult<Self> {
123 if data.len() < FRAGMENT_HEADER_SIZE {
124 return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
125 }
126 let packet_id = u16::from_be_bytes([data[0], data[1]]);
127 let fragment_index = u16::from_be_bytes([data[2], data[3]]);
128 let total_fragments = u16::from_be_bytes([data[4], data[5]]);
129 if total_fragments == 0 {
130 return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
131 }
132 if fragment_index >= total_fragments {
133 return Err(SplitterError::InconsistentFragments);
134 }
135 Ok(Self {
136 packet_id,
137 fragment_index,
138 total_fragments,
139 payload: data[FRAGMENT_HEADER_SIZE..].to_vec(),
140 })
141 }
142}
143
/// Settings that control how packets are split into fragments.
#[derive(Debug, Clone)]
pub struct SplitterConfig {
    /// Upper bound, in bytes, for one serialised fragment (header plus
    /// payload).
    pub max_packet_size: usize,
}
154
155impl SplitterConfig {
156 pub fn new(max_packet_size: usize) -> SplitterResult<Self> {
158 if max_packet_size <= FRAGMENT_HEADER_SIZE {
159 return Err(SplitterError::MaxSizeTooSmall {
160 max_size: max_packet_size,
161 min_required: FRAGMENT_HEADER_SIZE + 1,
162 });
163 }
164 Ok(Self { max_packet_size })
165 }
166
167 pub fn max_payload_per_fragment(&self) -> usize {
169 self.max_packet_size - FRAGMENT_HEADER_SIZE
170 }
171}
172
173pub fn split_packet(
184 packet_id: u16,
185 data: &[u8],
186 config: &SplitterConfig,
187) -> SplitterResult<Vec<Fragment>> {
188 if data.is_empty() {
189 return Err(SplitterError::EmptyPacket);
190 }
191
192 let max_payload = config.max_payload_per_fragment();
193 let total_fragments = (data.len() + max_payload - 1) / max_payload;
195
196 if total_fragments > MAX_FRAGMENTS as usize {
197 return Err(SplitterError::TooManyFragments(total_fragments as u16));
198 }
199
200 let total_u16 = total_fragments as u16;
201 let mut fragments = Vec::with_capacity(total_fragments);
202
203 for (idx, chunk) in data.chunks(max_payload).enumerate() {
204 fragments.push(Fragment {
205 packet_id,
206 fragment_index: idx as u16,
207 total_fragments: total_u16,
208 payload: chunk.to_vec(),
209 });
210 }
211
212 Ok(fragments)
213}
214
215pub fn reassemble_fragments(fragments: &[Fragment]) -> SplitterResult<Vec<u8>> {
224 if fragments.is_empty() {
225 return Err(SplitterError::EmptyPacket);
226 }
227
228 let total = fragments[0].total_fragments;
229 let packet_id = fragments[0].packet_id;
230
231 if total == 0 {
232 return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
233 }
234 if total > MAX_FRAGMENTS {
235 return Err(SplitterError::TooManyFragments(total));
236 }
237
238 for (i, frag) in fragments.iter().enumerate() {
240 if frag.packet_id != packet_id || frag.total_fragments != total {
241 return Err(SplitterError::InconsistentFragments);
242 }
243 if frag.fragment_index >= total {
244 return Err(SplitterError::MalformedFragmentHeader { offset: i });
245 }
246 }
247
248 let mut slots: Vec<Option<&[u8]>> = vec![None; total as usize];
250 for frag in fragments {
251 slots[frag.fragment_index as usize] = Some(&frag.payload);
252 }
253
254 let received = slots.iter().filter(|s| s.is_some()).count();
256 if received < total as usize {
257 return Err(SplitterError::MissingFragments { total, received });
258 }
259
260 let total_bytes: usize = slots.iter().filter_map(|s| *s).map(|s| s.len()).sum();
261 let mut out = Vec::with_capacity(total_bytes);
262 for slot in slots {
263 if let Some(payload) = slot {
265 out.extend_from_slice(payload);
266 }
267 }
268
269 Ok(out)
270}
271
272pub fn split_nal_unit(nal: &[u8], max_nal_size: usize) -> SplitterResult<Vec<&[u8]>> {
284 if nal.is_empty() {
285 return Err(SplitterError::EmptyPacket);
286 }
287 if max_nal_size == 0 {
288 return Err(SplitterError::MaxSizeTooSmall {
289 max_size: 0,
290 min_required: 1,
291 });
292 }
293 Ok(nal.chunks(max_nal_size).collect())
294}
295
296pub fn enforce_max_nal_size<'a>(
302 nals: &[&'a [u8]],
303 max_size: usize,
304) -> SplitterResult<Vec<&'a [u8]>> {
305 if max_size == 0 {
306 return Err(SplitterError::MaxSizeTooSmall {
307 max_size: 0,
308 min_required: 1,
309 });
310 }
311 let mut result = Vec::new();
312 for &nal in nals {
313 if nal.len() <= max_size {
314 result.push(nal);
315 } else {
316 let pieces = split_nal_unit(nal, max_size)?;
317 result.extend(pieces);
318 }
319 }
320 Ok(result)
321}
322
323pub fn encode_fragment_stream(fragments: &[Fragment]) -> Vec<u8> {
332 let total_bytes: usize = fragments
333 .iter()
334 .map(|f| 2 + FRAGMENT_HEADER_SIZE + f.payload.len())
335 .sum();
336 let mut out = Vec::with_capacity(total_bytes);
337 for frag in fragments {
338 let frag_bytes = frag.to_bytes();
339 let frag_len = frag_bytes.len() as u16;
340 out.extend_from_slice(&frag_len.to_be_bytes());
341 out.extend_from_slice(&frag_bytes);
342 }
343 out
344}
345
346pub fn decode_fragment_stream(data: &[u8]) -> SplitterResult<Vec<Fragment>> {
348 let mut fragments = Vec::new();
349 let mut offset = 0usize;
350 let len = data.len();
351
352 while offset < len {
353 if offset + 2 > len {
354 return Err(SplitterError::MalformedFragmentHeader { offset });
355 }
356 let frag_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
357 offset += 2;
358 if offset + frag_len > len {
359 return Err(SplitterError::MalformedFragmentHeader { offset });
360 }
361 let frag = Fragment::from_bytes(&data[offset..offset + frag_len])?;
362 fragments.push(frag);
363 offset += frag_len;
364 }
365
366 Ok(fragments)
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration for tests; panics if `max` is rejected.
    fn make_config(max: usize) -> SplitterConfig {
        SplitterConfig::new(max).unwrap()
    }

    /// Data that fits into one fragment yields exactly that one fragment.
    #[test]
    fn test_split_single_fragment() {
        let data = b"hello world";
        let frags = split_packet(1, data, &make_config(64)).unwrap();

        let expected = Fragment {
            packet_id: 1,
            fragment_index: 0,
            total_fragments: 1,
            payload: data.to_vec(),
        };
        assert_eq!(frags, vec![expected]);
    }

    /// Ten bytes with four payload bytes per fragment need three fragments.
    #[test]
    fn test_split_multiple_fragments() {
        let data: Vec<u8> = (0..10).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 4);

        let frags = split_packet(42, &data, &cfg).unwrap();

        assert_eq!(frags.len(), 3);
        for (expected_index, frag) in (0u16..).zip(&frags) {
            assert_eq!(frag.packet_id, 42);
            assert_eq!(frag.total_fragments, 3);
            assert_eq!(frag.fragment_index, expected_index);
        }
    }

    /// Fragments in their original order reassemble to the input.
    #[test]
    fn test_reassemble_ordered() {
        let data: Vec<u8> = (0u8..100).collect();
        let frags = split_packet(7, &data, &make_config(FRAGMENT_HEADER_SIZE + 10)).unwrap();

        assert_eq!(reassemble_fragments(&frags).unwrap(), data);
    }

    /// Fragment order does not matter for reassembly.
    #[test]
    fn test_reassemble_unordered() {
        let data: Vec<u8> = (0u8..30).collect();
        let reversed: Vec<Fragment> =
            split_packet(99, &data, &make_config(FRAGMENT_HEADER_SIZE + 10))
                .unwrap()
                .into_iter()
                .rev()
                .collect();

        assert_eq!(reassemble_fragments(&reversed).unwrap(), data);
    }

    /// Dropping one fragment is reported as missing fragments.
    #[test]
    fn test_reassemble_missing_fragment_error() {
        let data: Vec<u8> = (0u8..20).collect();
        let mut frags = split_packet(1, &data, &make_config(FRAGMENT_HEADER_SIZE + 5)).unwrap();
        frags.retain(|f| f.fragment_index != 1);

        let err = reassemble_fragments(&frags).unwrap_err();
        assert!(matches!(err, SplitterError::MissingFragments { .. }));
    }

    /// `to_bytes` followed by `from_bytes` returns the same fragment.
    #[test]
    fn test_fragment_serialise_deserialise() {
        let original = Fragment {
            packet_id: 5,
            fragment_index: 0,
            total_fragments: 1,
            payload: vec![0xDE, 0xAD, 0xBE, 0xEF],
        };

        let round_tripped = Fragment::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(round_tripped, original);
    }

    /// A packet survives split -> encode -> decode -> reassemble.
    #[test]
    fn test_encode_decode_fragment_stream() {
        let data: Vec<u8> = (0u8..50).collect();
        let frags = split_packet(3, &data, &make_config(FRAGMENT_HEADER_SIZE + 10)).unwrap();

        let stream = encode_fragment_stream(&frags);
        let decoded = decode_fragment_stream(&stream).unwrap();

        assert_eq!(reassemble_fragments(&decoded).unwrap(), data);
    }

    /// 100 bytes in pieces of at most 30 gives three full pieces and a rest.
    #[test]
    fn test_split_nal_unit() {
        let nal = [0xAAu8; 100];

        let pieces = split_nal_unit(&nal, 30).unwrap();

        assert_eq!(pieces.len(), 4);
        assert_eq!(pieces[0].len(), 30);
        assert_eq!(pieces[3].len(), 10);
    }

    /// Small units pass through; oversized units are cut to the limit.
    #[test]
    fn test_enforce_max_nal_size() {
        let small = [0x01u8; 10];
        let large = [0x02u8; 50];
        let nals: Vec<&[u8]> = vec![&small, &large];

        let out = enforce_max_nal_size(&nals, 20).unwrap();

        assert_eq!(out.len(), 4);
        assert_eq!(out[0].len(), 10);
        assert!(out[1..].iter().all(|piece| piece.len() <= 20));
    }

    /// A limit equal to the header size leaves no room for payload.
    #[test]
    fn test_config_too_small_error() {
        let err = SplitterConfig::new(FRAGMENT_HEADER_SIZE).unwrap_err();
        assert!(matches!(err, SplitterError::MaxSizeTooSmall { .. }));
    }

    /// Splitting an empty packet is rejected.
    #[test]
    fn test_empty_packet_split_error() {
        let result = split_packet(0, &[], &make_config(64));
        assert_eq!(result.unwrap_err(), SplitterError::EmptyPacket);
    }

    /// Fragments of different packets cannot be reassembled together.
    #[test]
    fn test_inconsistent_fragments_error() {
        let fragment_of = |packet_id: u16, fragment_index: u16, byte: u8| Fragment {
            packet_id,
            fragment_index,
            total_fragments: 2,
            payload: vec![byte],
        };
        let mixed = [fragment_of(1, 0, 0x01), fragment_of(2, 1, 0x02)];

        let err = reassemble_fragments(&mixed).unwrap_err();
        assert_eq!(err, SplitterError::InconsistentFragments);
    }
}