// mp4_atom/moof/traf/trun.rs

1use crate::*;
2
// Version/flags header declaration for the `trun` box. Presumably the macro
// generates `TrunExt` and `TrunVersion` (both used by the `AtomExt` impl);
// TODO confirm against the `ext!` macro definition.
// Each flag is given as a bit index within the box's flags field; the hex
// mask (1 << index) is noted next to it.
ext! {
    name: Trun,
    versions: [0, 1],
    flags: {
        // 0x000001: a run-level data_offset follows sample_count.
        data_offset = 0,
        // 0x000004: first_sample_flags follows data_offset and applies to sample 0.
        first_sample_flags = 2,
        // 0x000100: every sample record carries a duration.
        sample_duration = 8,
        // 0x000200: every sample record carries a size.
        sample_size = 9,
        // 0x000400: every sample record carries flags.
        sample_flags = 10,
        // 0x000800: every sample record carries a composition time offset.
        sample_cts = 11,
    }
}
15
16#[derive(Debug, Clone, PartialEq, Eq, Default)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18pub struct Trun {
19    pub data_offset: Option<i32>,
20    pub entries: Vec<TrunEntry>,
21}
22
/// One sample record inside a `trun` box.
///
/// A field is `Some` only when the corresponding per-sample value was
/// actually stored in the trun; `None` means the sample has no explicit value
/// for it. Before using a decoded entry, fill each `None` from the matching
/// tfhd default (`default_sample_duration`, `default_sample_size`,
/// `default_sample_flags`).
///
/// When encoding, a field that at least one entry sets is written for every
/// entry, and entries holding `None` for it are written as `0` rather than
/// silently dropping the field. The one exception is flags present only on
/// the first entry: those go out once as `first_sample_flags` and the other
/// entries stay absent.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TrunEntry {
    /// Per-sample duration (`sample_duration`).
    pub duration: Option<u32>,
    /// Per-sample size (`sample_size`).
    pub size: Option<u32>,
    /// Per-sample flags (`sample_flags`), or `first_sample_flags` for the
    /// first entry of a run that uses that field.
    pub flags: Option<u32>,
    /// Signed composition time offset (`sample_composition_time_offset`).
    pub cts: Option<i32>,
}
38
39impl AtomExt for Trun {
40    const KIND_EXT: FourCC = FourCC::new(b"trun");
41
42    type Ext = TrunExt;
43
44    fn decode_body_ext<B: Buf>(buf: &mut B, ext: TrunExt) -> Result<Self> {
45        let sample_count = u32::decode(buf)?;
46        let data_offset = match ext.data_offset {
47            true => i32::decode(buf)?.into(),
48            false => None,
49        };
50
51        let mut first_sample_flags = match ext.first_sample_flags {
52            true => u32::decode(buf)?.into(),
53            false => None,
54        };
55
56        // Avoid a memory exhaustion attack.
57        // If none of the flags are set, then the trun entry has zero size, then we'll allocate `sample_count` entries.
58        // Rather than make the API worse, we just limit the number of (useless?) identical entries to 4096.
59        if !(ext.sample_duration
60            || ext.sample_size
61            || ext.sample_flags
62            || ext.sample_cts
63            || sample_count <= 4096)
64        {
65            return Err(Error::OutOfMemory);
66        }
67
68        let mut entries = Vec::with_capacity(sample_count.min(4096) as _);
69
70        for _ in 0..sample_count {
71            let duration = match ext.sample_duration {
72                true => u32::decode(buf)?.into(),
73                false => None,
74            };
75            let size = match ext.sample_size {
76                true => u32::decode(buf)?.into(),
77                false => None,
78            };
79            let sample_flags = match first_sample_flags.take() {
80                Some(flags) => Some(flags),
81                None => match ext.sample_flags {
82                    true => u32::decode(buf)?.into(),
83                    false => None,
84                },
85            };
86            let cts = match ext.sample_cts {
87                true => i32::decode(buf)?.into(),
88                false => None,
89            };
90
91            entries.push(TrunEntry {
92                duration,
93                size,
94                flags: sample_flags,
95                cts,
96            });
97        }
98
99        Ok(Trun {
100            data_offset,
101            entries,
102        })
103    }
104
105    fn encode_body_ext<B: BufMut>(&self, buf: &mut B) -> Result<TrunExt> {
106        let any_flags = self.entries.iter().any(|s| s.flags.is_some());
107        let first_only_flags = any_flags
108            && self.entries.first().is_some_and(|s| s.flags.is_some())
109            && self.entries.iter().skip(1).all(|s| s.flags.is_none());
110
111        // Use per-sample flags when any entry has flags and it's not the first-only pattern.
112        // None entries are backfilled with 0 to avoid silently dropping flags.
113        let sample_flags = any_flags && !first_only_flags;
114
115        let ext = TrunExt {
116            version: TrunVersion::V1,
117            data_offset: self.data_offset.is_some(),
118            first_sample_flags: first_only_flags,
119
120            // Use the field if any entry has it set. None entries are backfilled
121            // with 0 during encoding to avoid silently dropping fields on
122            // decode→encode roundtrips (entries that inherited defaults from tfhd
123            // have None after decode).
124            sample_duration: self.entries.iter().any(|s| s.duration.is_some()),
125            sample_size: self.entries.iter().any(|s| s.size.is_some()),
126            sample_flags,
127            sample_cts: self.entries.iter().any(|s| s.cts.is_some()),
128        };
129
130        (self.entries.len() as u32).encode(buf)?;
131
132        self.data_offset.encode(buf)?;
133        if ext.first_sample_flags {
134            self.entries[0].flags.unwrap().encode(buf)?;
135        }
136
137        for entry in &self.entries {
138            if ext.sample_duration {
139                Some(Some(entry.duration.unwrap_or(0))).encode(buf)?;
140            }
141            if ext.sample_size {
142                Some(Some(entry.size.unwrap_or(0))).encode(buf)?;
143            }
144            if ext.sample_flags {
145                Some(Some(entry.flags.unwrap_or(0))).encode(buf)?;
146            }
147            if ext.sample_cts {
148                Some(Some(entry.cts.unwrap_or(0))).encode(buf)?;
149            }
150        }
151
152        Ok(ext)
153    }
154}
155
156#[cfg(test)]
157mod test {
158    use super::*;
159
160    /// Verify that first_sample_flags survives encode→decode roundtrip.
161    ///
162    /// ffmpeg commonly writes trun boxes where only the first entry has flags
163    /// (via first_sample_flags) and the rest inherit default_sample_flags from
164    /// tfhd. After decode, entry[0].flags = Some(keyframe), entries[1..N].flags = None.
165    /// The encoder must preserve this by emitting first_sample_flags.
166    #[test]
167    fn first_sample_flags_roundtrip() {
168        let trun = Trun {
169            data_offset: Some(100),
170            entries: vec![
171                TrunEntry {
172                    duration: Some(512),
173                    size: Some(1000),
174                    flags: Some(0x02000000), // keyframe (sample_depends_on=2)
175                    cts: None,
176                },
177                TrunEntry {
178                    duration: Some(512),
179                    size: Some(200),
180                    flags: None, // inherits default_sample_flags from tfhd
181                    cts: None,
182                },
183                TrunEntry {
184                    duration: Some(512),
185                    size: Some(200),
186                    flags: None,
187                    cts: None,
188                },
189            ],
190        };
191
192        let mut buf = Vec::new();
193        trun.encode(&mut buf).expect("encode");
194
195        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
196
197        // entry[0] must have the keyframe flags from first_sample_flags
198        assert_eq!(decoded.entries[0].flags, Some(0x02000000));
199        // entries[1..N] must have None (they use default_sample_flags from tfhd)
200        assert_eq!(decoded.entries[1].flags, None);
201        assert_eq!(decoded.entries[2].flags, None);
202        assert_eq!(decoded.data_offset, Some(100));
203        assert_eq!(decoded.entries.len(), 3);
204    }
205
206    /// When multiple entries have explicit flags (not just the first),
207    /// the encoder must use per-sample flags, not first_sample_flags.
208    #[test]
209    fn mixed_flags_uses_per_sample() {
210        let trun = Trun {
211            data_offset: Some(100),
212            entries: vec![
213                TrunEntry {
214                    duration: Some(512),
215                    size: Some(1000),
216                    flags: Some(0x02000000), // keyframe
217                    cts: None,
218                },
219                TrunEntry {
220                    duration: Some(512),
221                    size: Some(200),
222                    flags: Some(0x01010000), // non-keyframe (explicit)
223                    cts: None,
224                },
225                TrunEntry {
226                    duration: Some(512),
227                    size: Some(200),
228                    flags: None, // no flags
229                    cts: None,
230                },
231            ],
232        };
233
234        let mut buf = Vec::new();
235        trun.encode(&mut buf).expect("encode");
236
237        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
238
239        // Mixed Some/None: encoder backfills None with 0 and emits per-sample flags.
240        assert_eq!(decoded.entries[0].flags, Some(0x02000000));
241        assert_eq!(decoded.entries[1].flags, Some(0x01010000));
242        assert_eq!(decoded.entries[2].flags, Some(0)); // was None, backfilled to 0
243    }
244
245    /// When all entries have explicit flags, per-sample flags are used.
246    #[test]
247    fn all_flags_roundtrip() {
248        let trun = Trun {
249            data_offset: Some(100),
250            entries: vec![
251                TrunEntry {
252                    duration: Some(512),
253                    size: Some(1000),
254                    flags: Some(0x02000000),
255                    cts: None,
256                },
257                TrunEntry {
258                    duration: Some(512),
259                    size: Some(200),
260                    flags: Some(0x01010000),
261                    cts: None,
262                },
263            ],
264        };
265
266        let mut buf = Vec::new();
267        trun.encode(&mut buf).expect("encode");
268
269        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
270
271        assert_eq!(decoded.entries[0].flags, Some(0x02000000));
272        assert_eq!(decoded.entries[1].flags, Some(0x01010000));
273    }
274
275    /// Entries with None duration (inherited from tfhd default_sample_duration)
276    /// must not cause the duration field to be dropped entirely on re-encode.
277    #[test]
278    fn duration_backfill_roundtrip() {
279        let trun = Trun {
280            data_offset: Some(100),
281            entries: vec![
282                TrunEntry {
283                    duration: Some(512),
284                    size: Some(1000),
285                    flags: Some(0x02000000),
286                    cts: None,
287                },
288                TrunEntry {
289                    duration: None, // inherited from tfhd
290                    size: Some(200),
291                    flags: None,
292                    cts: None,
293                },
294            ],
295        };
296
297        let mut buf = Vec::new();
298        trun.encode(&mut buf).expect("encode");
299
300        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
301
302        assert_eq!(decoded.entries[0].duration, Some(512));
303        // None backfilled to 0 during encode
304        assert_eq!(decoded.entries[1].duration, Some(0));
305    }
306
307    /// Entries with None size (inherited from tfhd default_sample_size)
308    /// must not cause the size field to be dropped entirely on re-encode.
309    #[test]
310    fn size_backfill_roundtrip() {
311        let trun = Trun {
312            data_offset: Some(100),
313            entries: vec![
314                TrunEntry {
315                    duration: Some(512),
316                    size: Some(1000),
317                    flags: None,
318                    cts: None,
319                },
320                TrunEntry {
321                    duration: Some(512),
322                    size: None, // inherited from tfhd
323                    flags: None,
324                    cts: None,
325                },
326            ],
327        };
328
329        let mut buf = Vec::new();
330        trun.encode(&mut buf).expect("encode");
331
332        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
333
334        assert_eq!(decoded.entries[0].size, Some(1000));
335        // None backfilled to 0 during encode
336        assert_eq!(decoded.entries[1].size, Some(0));
337    }
338
339    /// When all entries have None for a field, the flag is not set and
340    /// the field is omitted entirely (no unnecessary bytes written).
341    #[test]
342    fn all_none_fields_omitted() {
343        let trun = Trun {
344            data_offset: Some(100),
345            entries: vec![
346                TrunEntry {
347                    duration: None,
348                    size: None,
349                    flags: None,
350                    cts: None,
351                },
352                TrunEntry {
353                    duration: None,
354                    size: None,
355                    flags: None,
356                    cts: None,
357                },
358            ],
359        };
360
361        let mut buf = Vec::new();
362        trun.encode(&mut buf).expect("encode");
363
364        let decoded = Trun::decode(&mut &buf[..]).expect("decode");
365
366        assert_eq!(decoded.entries[0].duration, None);
367        assert_eq!(decoded.entries[0].size, None);
368        assert_eq!(decoded.entries[0].flags, None);
369        assert_eq!(decoded.entries[0].cts, None);
370    }
371}