Skip to main content

oximedia_codec/
packet_splitter.rs

//! Packet splitting and fragment reassembly for codec bitstreams.
//!
//! Provides utilities to split oversized NAL units / codec packets into
//! MTU-bounded fragments and to reassemble them back into the original payload.
//!
//! The fragmentation scheme is transport-agnostic: each fragment carries a
//! small header so that a receiver can reconstruct the original packet from
//! an unordered set of fragments.
9
10use std::fmt;
11
// ---------------------------------------------------------------------------
// Error types
// ---------------------------------------------------------------------------

/// Errors that can occur during packet splitting, stream decoding or reassembly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitterError {
    /// A size limit is too small to make progress.
    ///
    /// Raised by [`SplitterConfig::new`] when the packet size cannot hold the
    /// fragment header plus at least one payload byte, and by the NAL helpers
    /// when asked to cut pieces of zero bytes.
    MaxSizeTooSmall {
        /// The requested maximum size.
        max_size: usize,
        /// The smallest size that would have been accepted.
        min_required: usize,
    },
    /// A fragment header (or the framing around it) is malformed: too short,
    /// truncated, or carrying invalid field values.
    MalformedFragmentHeader {
        /// Where the problem was detected: a byte position for byte-level
        /// parsing, or the index of the offending fragment when validating a
        /// slice of [`Fragment`]s.
        offset: usize,
    },
    /// One or more fragments are missing; reassembly cannot complete.
    MissingFragments {
        /// Total expected fragments.
        total: u16,
        /// Number of distinct fragment indices received.
        received: usize,
    },
    /// Fragments disagree on `packet_id` / `total_fragments`, or a fragment
    /// index is not smaller than its declared total.
    InconsistentFragments,
    /// The input packet (or fragment list) is empty.
    EmptyPacket,
    /// The fragment count exceeds the safety limit; carries the offending count.
    TooManyFragments(u16),
}

impl fmt::Display for SplitterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `min_required` is not always the header size (it is header + 1 for
            // packet configs and 1 for NAL splitting), so do not call it that.
            Self::MaxSizeTooSmall {
                max_size,
                min_required,
            } => write!(
                f,
                "max packet size {max_size} is too small; at least {min_required} bytes are required"
            ),
            Self::MalformedFragmentHeader { offset } => {
                write!(f, "malformed fragment header at offset {offset}")
            }
            Self::MissingFragments { total, received } => {
                write!(
                    f,
                    "missing fragments: expected {total}, received {received}"
                )
            }
            Self::InconsistentFragments => write!(f, "inconsistent fragment indices"),
            Self::EmptyPacket => write!(f, "packet is empty"),
            Self::TooManyFragments(n) => write!(f, "too many fragments: {n}"),
        }
    }
}

impl std::error::Error for SplitterError {}

/// Result type for packet splitter operations.
pub type SplitterResult<T> = Result<T, SplitterError>;
78
// ---------------------------------------------------------------------------
// Fragment header layout
//
// Every fragment begins with a fixed 6-byte header made of three big-endian
// u16 fields, immediately followed by the fragment payload:
//
//   bytes 0..2  packet_id        — which original packet this fragment belongs to
//   bytes 2..4  fragment_index   — zero-based position of this fragment
//   bytes 4..6  total_fragments  — how many fragments the original packet has
// ---------------------------------------------------------------------------

/// Number of bytes occupied by the fragment header (three big-endian `u16` fields).
pub const FRAGMENT_HEADER_SIZE: usize = 3 * std::mem::size_of::<u16>();

/// Upper bound on how many fragments a single packet may be split into, and on
/// the `total_fragments` value accepted during reassembly.
const MAX_FRAGMENTS: u16 = 1 << 12;
96
97/// A single fragment of a split packet.
98#[derive(Debug, Clone, PartialEq, Eq)]
99pub struct Fragment {
100    /// Identifier of the source packet (shared by all fragments of the same packet).
101    pub packet_id: u16,
102    /// Zero-based index of this fragment.
103    pub fragment_index: u16,
104    /// Total number of fragments that make up the original packet.
105    pub total_fragments: u16,
106    /// Fragment payload bytes (does NOT include the header).
107    pub payload: Vec<u8>,
108}
109
110impl Fragment {
111    /// Serialise this fragment (header + payload) into a byte vector.
112    pub fn to_bytes(&self) -> Vec<u8> {
113        let mut out = Vec::with_capacity(FRAGMENT_HEADER_SIZE + self.payload.len());
114        out.extend_from_slice(&self.packet_id.to_be_bytes());
115        out.extend_from_slice(&self.fragment_index.to_be_bytes());
116        out.extend_from_slice(&self.total_fragments.to_be_bytes());
117        out.extend_from_slice(&self.payload);
118        out
119    }
120
121    /// Deserialise a fragment from raw bytes (including header).
122    pub fn from_bytes(data: &[u8]) -> SplitterResult<Self> {
123        if data.len() < FRAGMENT_HEADER_SIZE {
124            return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
125        }
126        let packet_id = u16::from_be_bytes([data[0], data[1]]);
127        let fragment_index = u16::from_be_bytes([data[2], data[3]]);
128        let total_fragments = u16::from_be_bytes([data[4], data[5]]);
129        if total_fragments == 0 {
130            return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
131        }
132        if fragment_index >= total_fragments {
133            return Err(SplitterError::InconsistentFragments);
134        }
135        Ok(Self {
136            packet_id,
137            fragment_index,
138            total_fragments,
139            payload: data[FRAGMENT_HEADER_SIZE..].to_vec(),
140        })
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Packet splitter
146// ---------------------------------------------------------------------------
147
148/// Configuration for the packet splitter.
149#[derive(Debug, Clone)]
150pub struct SplitterConfig {
151    /// Maximum number of bytes per output fragment (including the fragment header).
152    pub max_packet_size: usize,
153}
154
155impl SplitterConfig {
156    /// Create a new splitter configuration.
157    pub fn new(max_packet_size: usize) -> SplitterResult<Self> {
158        if max_packet_size <= FRAGMENT_HEADER_SIZE {
159            return Err(SplitterError::MaxSizeTooSmall {
160                max_size: max_packet_size,
161                min_required: FRAGMENT_HEADER_SIZE + 1,
162            });
163        }
164        Ok(Self { max_packet_size })
165    }
166
167    /// Maximum payload bytes per fragment (after accounting for the header).
168    pub fn max_payload_per_fragment(&self) -> usize {
169        self.max_packet_size - FRAGMENT_HEADER_SIZE
170    }
171}
172
173/// Split a packet into fragments that each fit within `config.max_packet_size`.
174///
175/// If the packet already fits in a single fragment it is still wrapped in a
176/// single-element `Vec<Fragment>` for uniform handling.
177///
178/// # Parameters
179///
180/// - `packet_id`: Caller-supplied identifier that groups all fragments of a packet.
181/// - `data`: The raw packet payload (e.g., one or more NAL units in AnnexB/AVCC).
182/// - `config`: Splitter configuration.
183pub fn split_packet(
184    packet_id: u16,
185    data: &[u8],
186    config: &SplitterConfig,
187) -> SplitterResult<Vec<Fragment>> {
188    if data.is_empty() {
189        return Err(SplitterError::EmptyPacket);
190    }
191
192    let max_payload = config.max_payload_per_fragment();
193    // Integer ceiling division.
194    let total_fragments = (data.len() + max_payload - 1) / max_payload;
195
196    if total_fragments > MAX_FRAGMENTS as usize {
197        return Err(SplitterError::TooManyFragments(total_fragments as u16));
198    }
199
200    let total_u16 = total_fragments as u16;
201    let mut fragments = Vec::with_capacity(total_fragments);
202
203    for (idx, chunk) in data.chunks(max_payload).enumerate() {
204        fragments.push(Fragment {
205            packet_id,
206            fragment_index: idx as u16,
207            total_fragments: total_u16,
208            payload: chunk.to_vec(),
209        });
210    }
211
212    Ok(fragments)
213}
214
215// ---------------------------------------------------------------------------
216// Fragment reassembly
217// ---------------------------------------------------------------------------
218
219/// Reassemble a complete packet from an unordered collection of fragments.
220///
221/// All fragments must share the same `packet_id` and agree on `total_fragments`.
222/// The function tolerates duplicates (last-write-wins by fragment index).
223pub fn reassemble_fragments(fragments: &[Fragment]) -> SplitterResult<Vec<u8>> {
224    if fragments.is_empty() {
225        return Err(SplitterError::EmptyPacket);
226    }
227
228    let total = fragments[0].total_fragments;
229    let packet_id = fragments[0].packet_id;
230
231    if total == 0 {
232        return Err(SplitterError::MalformedFragmentHeader { offset: 0 });
233    }
234    if total > MAX_FRAGMENTS {
235        return Err(SplitterError::TooManyFragments(total));
236    }
237
238    // Validate consistency across all fragments.
239    for (i, frag) in fragments.iter().enumerate() {
240        if frag.packet_id != packet_id || frag.total_fragments != total {
241            return Err(SplitterError::InconsistentFragments);
242        }
243        if frag.fragment_index >= total {
244            return Err(SplitterError::MalformedFragmentHeader { offset: i });
245        }
246    }
247
248    // Build a slot array; duplicate fragments overwrite earlier entries.
249    let mut slots: Vec<Option<&[u8]>> = vec![None; total as usize];
250    for frag in fragments {
251        slots[frag.fragment_index as usize] = Some(&frag.payload);
252    }
253
254    // Check completeness.
255    let received = slots.iter().filter(|s| s.is_some()).count();
256    if received < total as usize {
257        return Err(SplitterError::MissingFragments { total, received });
258    }
259
260    let total_bytes: usize = slots.iter().filter_map(|s| *s).map(|s| s.len()).sum();
261    let mut out = Vec::with_capacity(total_bytes);
262    for slot in slots {
263        // Safety: we checked completeness above.
264        if let Some(payload) = slot {
265            out.extend_from_slice(payload);
266        }
267    }
268
269    Ok(out)
270}
271
272// ---------------------------------------------------------------------------
273// NAL unit size enforcement
274// ---------------------------------------------------------------------------
275
276/// Split a single large NAL unit so that each piece does not exceed `max_nal_size`.
277///
278/// This is a byte-level split that does **not** attempt to find valid EBSP
279/// boundaries; use it only when the codec or transport allows arbitrary slicing
280/// (e.g., RTP packetization with FU-A style fragmentation).
281///
282/// The returned slices borrow from the input slice.
283pub fn split_nal_unit(nal: &[u8], max_nal_size: usize) -> SplitterResult<Vec<&[u8]>> {
284    if nal.is_empty() {
285        return Err(SplitterError::EmptyPacket);
286    }
287    if max_nal_size == 0 {
288        return Err(SplitterError::MaxSizeTooSmall {
289            max_size: 0,
290            min_required: 1,
291        });
292    }
293    Ok(nal.chunks(max_nal_size).collect())
294}
295
296/// Enforce a maximum payload size across a list of AnnexB NAL units.
297///
298/// Each NAL unit that fits within `max_size` is kept as-is.  Larger NAL units
299/// are split by [`split_nal_unit`].  The result is a flat list of byte slices,
300/// each guaranteed to be ≤ `max_size` bytes.
301pub fn enforce_max_nal_size<'a>(
302    nals: &[&'a [u8]],
303    max_size: usize,
304) -> SplitterResult<Vec<&'a [u8]>> {
305    if max_size == 0 {
306        return Err(SplitterError::MaxSizeTooSmall {
307            max_size: 0,
308            min_required: 1,
309        });
310    }
311    let mut result = Vec::new();
312    for &nal in nals {
313        if nal.len() <= max_size {
314            result.push(nal);
315        } else {
316            let pieces = split_nal_unit(nal, max_size)?;
317            result.extend(pieces);
318        }
319    }
320    Ok(result)
321}
322
323// ---------------------------------------------------------------------------
324// Serialise / deserialise helpers for raw fragment bytes
325// ---------------------------------------------------------------------------
326
327/// Encode a list of fragments to a single flat byte buffer.
328///
329/// Each fragment is prefixed with a 2-byte big-endian length field so that the
330/// buffer can be decoded without out-of-band information.
331pub fn encode_fragment_stream(fragments: &[Fragment]) -> Vec<u8> {
332    let total_bytes: usize = fragments
333        .iter()
334        .map(|f| 2 + FRAGMENT_HEADER_SIZE + f.payload.len())
335        .sum();
336    let mut out = Vec::with_capacity(total_bytes);
337    for frag in fragments {
338        let frag_bytes = frag.to_bytes();
339        let frag_len = frag_bytes.len() as u16;
340        out.extend_from_slice(&frag_len.to_be_bytes());
341        out.extend_from_slice(&frag_bytes);
342    }
343    out
344}
345
346/// Decode a flat fragment stream previously produced by [`encode_fragment_stream`].
347pub fn decode_fragment_stream(data: &[u8]) -> SplitterResult<Vec<Fragment>> {
348    let mut fragments = Vec::new();
349    let mut offset = 0usize;
350    let len = data.len();
351
352    while offset < len {
353        if offset + 2 > len {
354            return Err(SplitterError::MalformedFragmentHeader { offset });
355        }
356        let frag_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
357        offset += 2;
358        if offset + frag_len > len {
359            return Err(SplitterError::MalformedFragmentHeader { offset });
360        }
361        let frag = Fragment::from_bytes(&data[offset..offset + frag_len])?;
362        fragments.push(frag);
363        offset += frag_len;
364    }
365
366    Ok(fragments)
367}
368
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a validated config; panics if `max` is too small (test helper).
    fn make_config(max: usize) -> SplitterConfig {
        SplitterConfig::new(max).unwrap()
    }

    #[test]
    fn test_split_single_fragment() {
        let data = b"hello world";
        let cfg = make_config(64);
        let frags = split_packet(1, data, &cfg).unwrap();
        assert_eq!(frags.len(), 1);
        assert_eq!(frags[0].packet_id, 1);
        assert_eq!(frags[0].fragment_index, 0);
        assert_eq!(frags[0].total_fragments, 1);
        assert_eq!(frags[0].payload, data);
    }

    #[test]
    fn test_split_multiple_fragments() {
        // max_packet_size = FRAGMENT_HEADER_SIZE + 4 => payload_per_frag = 4
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 4);
        let data: Vec<u8> = (0..10).collect();
        let frags = split_packet(42, &data, &cfg).unwrap();
        assert_eq!(frags.len(), 3); // 4 + 4 + 2
        assert!(frags.iter().all(|f| f.packet_id == 42));
        assert!(frags.iter().all(|f| f.total_fragments == 3));
        for (i, f) in frags.iter().enumerate() {
            assert_eq!(f.fragment_index, i as u16);
        }
    }

    #[test]
    fn test_split_too_many_fragments_error() {
        // One payload byte per fragment, one byte more than the cap allows.
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 1);
        let data = vec![0u8; MAX_FRAGMENTS as usize + 1];
        assert_eq!(
            split_packet(0, &data, &cfg).unwrap_err(),
            SplitterError::TooManyFragments(MAX_FRAGMENTS + 1)
        );
    }

    #[test]
    fn test_reassemble_ordered() {
        let data: Vec<u8> = (0u8..100).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 10);
        let frags = split_packet(7, &data, &cfg).unwrap();
        let reassembled = reassemble_fragments(&frags).unwrap();
        assert_eq!(reassembled, data);
    }

    #[test]
    fn test_reassemble_unordered() {
        let data: Vec<u8> = (0u8..30).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 10);
        let mut frags = split_packet(99, &data, &cfg).unwrap();
        // Reverse order.
        frags.reverse();
        let reassembled = reassemble_fragments(&frags).unwrap();
        assert_eq!(reassembled, data);
    }

    #[test]
    fn test_reassemble_tolerates_duplicates() {
        let data: Vec<u8> = (0u8..20).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 5);
        let mut frags = split_packet(11, &data, &cfg).unwrap();
        // Re-deliver the first fragment at the end (e.g. a retransmission).
        let first = frags[0].clone();
        frags.push(first);
        assert_eq!(reassemble_fragments(&frags).unwrap(), data);
    }

    #[test]
    fn test_reassemble_missing_fragment_error() {
        let data: Vec<u8> = (0u8..20).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 5);
        let frags = split_packet(1, &data, &cfg).unwrap();
        // Drop second fragment.
        let partial: Vec<Fragment> = frags
            .into_iter()
            .filter(|f| f.fragment_index != 1)
            .collect();
        let err = reassemble_fragments(&partial).unwrap_err();
        assert!(matches!(err, SplitterError::MissingFragments { .. }));
    }

    #[test]
    fn test_reassemble_index_out_of_range_error() {
        let frag = Fragment {
            packet_id: 1,
            fragment_index: 3,
            total_fragments: 2,
            payload: vec![0x01],
        };
        let err = reassemble_fragments(&[frag]).unwrap_err();
        assert_eq!(err, SplitterError::MalformedFragmentHeader { offset: 0 });
    }

    #[test]
    fn test_reassemble_empty_input_error() {
        assert_eq!(
            reassemble_fragments(&[]).unwrap_err(),
            SplitterError::EmptyPacket
        );
    }

    #[test]
    fn test_fragment_serialise_deserialise() {
        let frag = Fragment {
            packet_id: 5,
            fragment_index: 0,
            total_fragments: 1,
            payload: vec![0xDE, 0xAD, 0xBE, 0xEF],
        };
        let bytes = frag.to_bytes();
        let decoded = Fragment::from_bytes(&bytes).unwrap();
        assert_eq!(decoded, frag);
    }

    #[test]
    fn test_fragment_from_bytes_rejects_bad_headers() {
        // Shorter than the header.
        assert!(matches!(
            Fragment::from_bytes(&[0, 1, 0, 0, 0]),
            Err(SplitterError::MalformedFragmentHeader { .. })
        ));
        // total_fragments == 0.
        assert!(matches!(
            Fragment::from_bytes(&[0, 1, 0, 0, 0, 0]),
            Err(SplitterError::MalformedFragmentHeader { .. })
        ));
        // fragment_index >= total_fragments.
        assert_eq!(
            Fragment::from_bytes(&[0, 1, 0, 2, 0, 2]),
            Err(SplitterError::InconsistentFragments)
        );
    }

    #[test]
    fn test_encode_decode_fragment_stream() {
        let data: Vec<u8> = (0u8..50).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 10);
        let frags = split_packet(3, &data, &cfg).unwrap();
        let stream = encode_fragment_stream(&frags);
        let decoded_frags = decode_fragment_stream(&stream).unwrap();
        let reassembled = reassemble_fragments(&decoded_frags).unwrap();
        assert_eq!(reassembled, data);
    }

    #[test]
    fn test_decode_empty_stream() {
        assert!(decode_fragment_stream(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_decode_truncated_stream_error() {
        let data: Vec<u8> = (0u8..30).collect();
        let cfg = make_config(FRAGMENT_HEADER_SIZE + 10);
        let frags = split_packet(4, &data, &cfg).unwrap();
        let stream = encode_fragment_stream(&frags);

        // Cut inside the last fragment's body.
        let err = decode_fragment_stream(&stream[..stream.len() - 1]).unwrap_err();
        assert!(matches!(err, SplitterError::MalformedFragmentHeader { .. }));

        // Dangling byte where a length prefix should start.
        let mut dangling = stream.clone();
        dangling.push(0);
        let err = decode_fragment_stream(&dangling).unwrap_err();
        assert!(matches!(err, SplitterError::MalformedFragmentHeader { .. }));
    }

    #[test]
    fn test_split_nal_unit() {
        let nal = [0xAAu8; 100];
        let pieces = split_nal_unit(&nal, 30).unwrap();
        // 100 / 30 = 3 full + 1 partial
        assert_eq!(pieces.len(), 4);
        assert_eq!(pieces[0].len(), 30);
        assert_eq!(pieces[3].len(), 10);
    }

    #[test]
    fn test_split_nal_unit_errors() {
        assert_eq!(
            split_nal_unit(&[], 4).unwrap_err(),
            SplitterError::EmptyPacket
        );
        assert!(matches!(
            split_nal_unit(&[1, 2, 3], 0),
            Err(SplitterError::MaxSizeTooSmall { .. })
        ));
    }

    #[test]
    fn test_enforce_max_nal_size() {
        let small = [0x01u8; 10];
        let large = [0x02u8; 50];
        let nals: Vec<&[u8]> = vec![&small[..], &large[..]];
        let out = enforce_max_nal_size(&nals, 20).unwrap();
        // small fits as-is; large splits into ceil(50 / 20) = 3 pieces
        assert_eq!(out.len(), 4);
        assert_eq!(out[0].len(), 10);
        assert!(out[1..].iter().all(|s| s.len() <= 20));
    }

    #[test]
    fn test_enforce_max_nal_size_passes_empty_nal_through() {
        let empty: [u8; 0] = [];
        let nals: Vec<&[u8]> = vec![&empty[..]];
        let out = enforce_max_nal_size(&nals, 4).unwrap();
        assert_eq!(out.len(), 1);
        assert!(out[0].is_empty());
    }

    #[test]
    fn test_config_too_small_error() {
        let err = SplitterConfig::new(FRAGMENT_HEADER_SIZE).unwrap_err();
        assert!(matches!(err, SplitterError::MaxSizeTooSmall { .. }));
    }

    #[test]
    fn test_empty_packet_split_error() {
        let cfg = make_config(64);
        let err = split_packet(0, &[], &cfg).unwrap_err();
        assert_eq!(err, SplitterError::EmptyPacket);
    }

    #[test]
    fn test_inconsistent_fragments_error() {
        let frag_a = Fragment {
            packet_id: 1,
            fragment_index: 0,
            total_fragments: 2,
            payload: vec![0x01],
        };
        // Different packet_id
        let frag_b = Fragment {
            packet_id: 2,
            fragment_index: 1,
            total_fragments: 2,
            payload: vec![0x02],
        };
        let err = reassemble_fragments(&[frag_a, frag_b]).unwrap_err();
        assert_eq!(err, SplitterError::InconsistentFragments);
    }
}