Skip to main content

mp4_edit/atom/leaf/stsd/
extension.rs

1use std::fmt;
2
3pub use audio_specific_config::AudioSpecificConfig;
4
5use crate::{
6    atom::util::{DebugList, DebugUpperHex},
7    FourCC,
8};
9
10pub mod audio_specific_config;
11
12#[derive(Clone, PartialEq)]
13pub enum StsdExtension {
14    Esds(EsdsExtension),
15    Btrt(BtrtExtension),
16    Unknown { fourcc: FourCC, data: Vec<u8> },
17}
18
19impl fmt::Debug for StsdExtension {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            StsdExtension::Btrt(btrt) => fmt::Debug::fmt(btrt, f),
23            StsdExtension::Esds(esds) => fmt::Debug::fmt(esds, f),
24            StsdExtension::Unknown { fourcc, data } => f
25                .debug_struct("Unknown")
26                .field("fourcc", &fourcc)
27                .field("data", &DebugList::new(data.iter().map(DebugUpperHex), 10))
28                .finish(),
29        }
30    }
31}
32
/// A descriptor type that is identified in the encoded stream by a one-byte tag.
///
/// The parser matches `TAG` before reading a descriptor, and the serializer writes it back.
trait Descriptor {
    /// Tag byte that introduces this descriptor.
    const TAG: u8;
}
36
/// Payload of an `esds` (elementary stream descriptor) extension box.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct EsdsExtension {
    /// Full-box version byte.
    pub version: u8,
    /// Full-box flags (three bytes).
    pub flags: [u8; 3],
    /// The ES_Descriptor carried by this box.
    pub es_descriptor: EsDescriptor,
}

impl EsdsExtension {
    /// Box type (fourcc) that identifies an `esds` extension.
    const TYPE: FourCC = FourCC::new(b"esds");
}
47
/// ES_Descriptor (tag `0x03`) describing one elementary stream.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct EsDescriptor {
    /// Elementary stream ID (16-bit).
    pub es_id: u16,
    /// ID of the stream this one depends on; its presence is signalled by the first flag bit.
    pub depends_on_es_id: Option<u16>,
    /// URL, encoded via `pascal_string`; its presence is signalled by the second flag bit.
    pub url: Option<String>,
    /// OCR stream ID; its presence is signalled by the third flag bit.
    pub ocr_es_id: Option<u16>,
    /// Stream priority, encoded as a 5-bit field in the flag byte.
    pub stream_priority: u8,
    /// Nested DecoderConfigDescriptor, if one was present.
    pub decoder_config_descriptor: Option<DecoderConfigDescriptor>,
    /// Nested SLConfigDescriptor, if one was present.
    pub sl_config_descriptor: Option<SlConfigDescriptor>,
}

impl Descriptor for EsDescriptor {
    const TAG: u8 = 0x03;
}
62
/// DecoderConfigDescriptor (tag `0x04`) describing how to decode an elementary stream.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct DecoderConfigDescriptor {
    /// Object type indication byte.
    pub object_type_indication: u8,
    /// Stream type, encoded as a 6-bit field; the value `5` selects AudioSpecificConfig parsing.
    pub stream_type: u8,
    /// Upstream flag bit.
    pub upstream: bool,
    /// Decoding buffer size, encoded as a 24-bit big-endian field.
    pub buffer_size_db: u32,
    /// Maximum bitrate, encoded as a big-endian u32.
    pub max_bitrate: u32,
    /// Average bitrate, encoded as a big-endian u32.
    pub avg_bitrate: u32,
    /// Nested DecoderSpecificInfo descriptor, if one was present.
    pub decoder_specific_info: Option<DecoderSpecificInfo>,
}

impl Descriptor for DecoderConfigDescriptor {
    const TAG: u8 = 0x04;
}
77
/// Payload of a DecoderSpecificInfo descriptor (tag `0x05`).
#[derive(Debug, Clone, PartialEq)]
pub enum DecoderSpecificInfo {
    /// Decoded AudioSpecificConfig (attempted when the stream type is `5`).
    Audio(AudioSpecificConfig),
    /// Payload bytes kept as-is without decoding.
    Unknown(Vec<u8>),
}

impl Descriptor for DecoderSpecificInfo {
    const TAG: u8 = 0x05;
}
87
/// SLConfigDescriptor (tag `0x06`); only its single `predefined` byte is modelled.
#[derive(Debug, Clone, PartialEq)]
pub struct SlConfigDescriptor {
    /// The `predefined` byte.
    pub predefined: u8,
}

impl Descriptor for SlConfigDescriptor {
    const TAG: u8 = 0x06;
}
96
/// Payload of a `btrt` (bitrate) extension box: three big-endian u32 fields.
#[derive(Debug, Clone, PartialEq)]
pub struct BtrtExtension {
    /// Decoding buffer size.
    pub buffer_size_db: u32,
    /// Maximum bitrate.
    pub max_bitrate: u32,
    /// Average bitrate.
    pub avg_bitrate: u32,
}

impl BtrtExtension {
    /// Box type (fourcc) that identifies a `btrt` extension.
    const TYPE: FourCC = FourCC::new(b"btrt");
}
107
108pub(super) mod serializer {
109    use crate::{
110        atom::{
111            stsd::{
112                extension::{
113                    audio_specific_config::serializer::serialize_audio_specific_config,
114                    DecoderConfigDescriptor, Descriptor, EsDescriptor, SlConfigDescriptor,
115                },
116                BtrtExtension, DecoderSpecificInfo, EsdsExtension, StsdExtension,
117            },
118            util::serializer::{
119                be_u24, bits::Packer, pascal_string, prepend_size_exclusive,
120                prepend_size_inclusive, SizeU32, SizeU32OrU64, SizeVLQ,
121            },
122        },
123        FourCC,
124    };
125
126    pub fn serialize_stsd_extensions(extensions: Vec<StsdExtension>) -> Vec<u8> {
127        extensions
128            .into_iter()
129            .flat_map(serialize_stsd_extension)
130            .collect::<Vec<_>>()
131    }
132
133    fn serialize_stsd_extension(extension: StsdExtension) -> Vec<u8> {
134        match extension {
135            StsdExtension::Esds(esds) => {
136                serialize_box(EsdsExtension::TYPE, serialize_esds_extension(esds))
137            }
138            StsdExtension::Btrt(btrt) => {
139                serialize_box(BtrtExtension::TYPE, serialize_btrt_extension(btrt))
140            }
141            StsdExtension::Unknown { fourcc, data } => serialize_box(fourcc, data),
142        }
143    }
144
145    fn serialize_esds_extension(esds: EsdsExtension) -> Vec<u8> {
146        let mut data = Vec::new();
147        data.push(esds.version);
148        data.extend(esds.flags);
149        data.extend(serialize_es_descriptor(esds.es_descriptor));
150        data
151    }
152
153    fn serialize_es_descriptor(es_desc: EsDescriptor) -> Vec<u8> {
154        let mut data = Vec::new();
155
156        data.extend(es_desc.es_id.to_be_bytes());
157
158        let mut flags = Packer::new();
159        flags.push_bool(es_desc.depends_on_es_id.is_some());
160        flags.push_bool(es_desc.url.is_some());
161        flags.push_bool(es_desc.ocr_es_id.is_some());
162        flags.push_n::<5>(es_desc.stream_priority);
163        data.push(Vec::from(flags)[0]);
164
165        if let Some(depends_on_es_id) = es_desc.depends_on_es_id {
166            data.extend(depends_on_es_id.to_be_bytes());
167        }
168
169        if let Some(url) = es_desc.url {
170            data.extend(pascal_string(url));
171        }
172
173        if let Some(ocr_es_id) = es_desc.ocr_es_id {
174            data.extend(ocr_es_id.to_be_bytes());
175        }
176
177        if let Some(decoder_config) = es_desc.decoder_config_descriptor {
178            data.extend(serialize_decoder_config(decoder_config));
179        }
180
181        if let Some(sl_config) = es_desc.sl_config_descriptor {
182            data.extend(serialize_sl_config(sl_config));
183        }
184
185        serialize_descriptor(EsDescriptor::TAG, data)
186    }
187
188    fn serialize_decoder_config(decoder_config: DecoderConfigDescriptor) -> Vec<u8> {
189        let mut data = Vec::new();
190
191        data.push(decoder_config.object_type_indication);
192
193        let mut stream_info = Packer::new();
194        stream_info.push_n::<6>(decoder_config.stream_type);
195        stream_info.push_bool(decoder_config.upstream);
196        stream_info.push_bool(true); // reserved
197        data.push(Vec::from(stream_info)[0]);
198
199        data.extend(be_u24(decoder_config.buffer_size_db));
200
201        data.extend(decoder_config.max_bitrate.to_be_bytes());
202        data.extend(decoder_config.avg_bitrate.to_be_bytes());
203
204        if let Some(decoder_info) = decoder_config.decoder_specific_info {
205            let decoder_info_bytes = match decoder_info {
206                DecoderSpecificInfo::Audio(c) => serialize_audio_specific_config(c),
207                DecoderSpecificInfo::Unknown(c) => c,
208            };
209            data.extend(serialize_descriptor(
210                DecoderSpecificInfo::TAG,
211                decoder_info_bytes,
212            ));
213        }
214
215        serialize_descriptor(DecoderConfigDescriptor::TAG, data)
216    }
217
218    fn serialize_sl_config(sl_config: SlConfigDescriptor) -> Vec<u8> {
219        serialize_descriptor(SlConfigDescriptor::TAG, vec![sl_config.predefined])
220    }
221
222    fn serialize_btrt_extension(btrt: BtrtExtension) -> Vec<u8> {
223        let mut data = Vec::new();
224        data.extend(btrt.buffer_size_db.to_be_bytes());
225        data.extend(btrt.max_bitrate.to_be_bytes());
226        data.extend(btrt.avg_bitrate.to_be_bytes());
227        data
228    }
229
230    fn serialize_descriptor(tag: u8, descriptor_data: Vec<u8>) -> Vec<u8> {
231        let mut data = Vec::new();
232        data.push(tag);
233        data.extend(prepend_size_exclusive::<SizeVLQ<SizeU32>, _>(move || {
234            descriptor_data
235        }));
236        data
237    }
238
239    fn serialize_box(fourcc: FourCC, box_data: Vec<u8>) -> Vec<u8> {
240        prepend_size_inclusive::<SizeU32OrU64, _>(move || {
241            let mut data = Vec::new();
242            data.extend(fourcc.into_bytes());
243            data.extend(box_data);
244            data
245        })
246    }
247}
248
249pub(super) mod parser {
250    use winnow::{
251        binary::{be_u16, be_u24, be_u32, bits, length_and_then, u8},
252        combinator::{opt, repeat, seq, trace},
253        error::{ContextError, ErrMode, StrContext},
254        token::literal,
255        ModalResult, Parser,
256    };
257
258    use crate::atom::{
259        stsd::extension::audio_specific_config::parser::parse_audio_specific_config,
260        util::parser::{
261            atom_size, combinators::inclusive_length_and_then, flags3, fourcc, pascal_string,
262            rest_vec, variable_length_be_u32, version, Stream,
263        },
264    };
265
266    use super::*;
267
268    pub fn parse_stsd_extensions(input: &mut Stream<'_>) -> ModalResult<Vec<StsdExtension>> {
269        repeat(0.., parse_stsd_extension).parse_next(input)
270    }
271
272    pub fn parse_stsd_extension(input: &mut Stream<'_>) -> ModalResult<StsdExtension> {
273        inclusive_length_and_then(
274            atom_size,
275            move |input: &mut Stream<'_>| -> ModalResult<StsdExtension> {
276                let fourcc = fourcc.parse_next(input)?;
277
278                Ok(match fourcc {
279                    EsdsExtension::TYPE => {
280                        parse_esds_box.map(StsdExtension::Esds).parse_next(input)?
281                    }
282                    BtrtExtension::TYPE => {
283                        parse_btrt_box.map(StsdExtension::Btrt).parse_next(input)?
284                    }
285                    _ => StsdExtension::Unknown {
286                        fourcc,
287                        data: rest_vec.parse_next(input)?,
288                    },
289                })
290            },
291        )
292        .parse_next(input)
293    }
294
295    fn parse_esds_box(input: &mut Stream<'_>) -> ModalResult<EsdsExtension> {
296        seq!(EsdsExtension {
297            version: version,
298            flags: flags3,
299            es_descriptor: parse_es_descriptor,
300        })
301        .parse_next(input)
302    }
303
304    fn parse_es_descriptor(input: &mut Stream<'_>) -> ModalResult<EsDescriptor> {
305        parse_descriptor(move |input: &mut Stream<'_>| {
306            let es_id = be_u16.parse_next(input)?;
307
308            struct Flags {
309                stream_dependence_flag: bool,
310                url_flag: bool,
311                ocr_stream_flag: bool,
312                stream_priority: u8,
313            }
314            let Flags {
315                stream_dependence_flag,
316                url_flag,
317                ocr_stream_flag,
318                stream_priority,
319            } = bits::bits(
320                move |input: &mut (Stream<'_>, usize)| -> ModalResult<Flags> {
321                    seq!(Flags {
322                        stream_dependence_flag: bits::bool
323                            .context(StrContext::Label("stream_dependency_flag")),
324                        url_flag: bits::bool.context(StrContext::Label("url_flag")),
325                        ocr_stream_flag: bits::bool.context(StrContext::Label("ocr_stream_flag")),
326                        stream_priority: bits::take(5usize)
327                            .context(StrContext::Label("stream_priority")),
328                    })
329                    .parse_next(input)
330                },
331            )
332            .parse_next(input)?;
333
334            let depends_on_es_id = if stream_dependence_flag {
335                Some(be_u16.parse_next(input)?)
336            } else {
337                None
338            };
339
340            let url = if url_flag {
341                Some(pascal_string.parse_next(input)?)
342            } else {
343                None
344            };
345
346            let ocr_es_id = if ocr_stream_flag {
347                Some(be_u16.parse_next(input)?)
348            } else {
349                None
350            };
351
352            let decoder_config_descriptor =
353                opt(parse_decoder_config_descriptor).parse_next(input)?;
354
355            let sl_config_descriptor = opt(parse_sl_config_descriptor).parse_next(input)?;
356
357            Ok(EsDescriptor {
358                es_id,
359                depends_on_es_id,
360                url,
361                ocr_es_id,
362                stream_priority,
363                decoder_config_descriptor,
364                sl_config_descriptor,
365            })
366        })
367        .parse_next(input)
368    }
369
370    fn parse_decoder_config_descriptor(
371        input: &mut Stream<'_>,
372    ) -> ModalResult<DecoderConfigDescriptor> {
373        parse_descriptor(move |input: &mut Stream<'_>| {
374            let object_type_indication = u8.parse_next(input)?;
375
376            struct StreamInfo {
377                stream_type: u8,
378                upstream: bool,
379            }
380            let StreamInfo {
381                stream_type,
382                upstream,
383            } = bits::bits(
384                move |input: &mut (Stream<'_>, usize)| -> ModalResult<StreamInfo> {
385                    seq!(StreamInfo {
386                        stream_type: bits::take(6usize).context(StrContext::Label("stream_type")),
387                        upstream: bits::bool.context(StrContext::Label("upstream")),
388                        _: bits::bool.context(StrContext::Label("reserved")),
389                    })
390                    .parse_next(input)
391                },
392            )
393            .parse_next(input)?;
394
395            let buffer_size_db = be_u24.parse_next(input)?;
396            let max_bitrate = be_u32.parse_next(input)?;
397            let avg_bitrate = be_u32.parse_next(input)?;
398
399            // Parse DecoderSpecificInfo if present
400            let decoder_specific_info = opt(move |input: &mut Stream<'_>| {
401                parse_descriptor(match stream_type {
402                    5 => |input: &mut Stream<'_>| {
403                        parse_audio_specific_config
404                            .map(DecoderSpecificInfo::Audio)
405                            .context(StrContext::Label("audio_specific_config"))
406                            .parse_next(input)
407                    },
408                    _ => |input: &mut Stream<'_>| {
409                        rest_vec
410                            .map(DecoderSpecificInfo::Unknown)
411                            .context(StrContext::Label("unknown"))
412                            .parse_next(input)
413                    },
414                })
415                .parse_next(input)
416            })
417            .parse_next(input)?;
418
419            Ok(DecoderConfigDescriptor {
420                object_type_indication,
421                stream_type,
422                upstream,
423                buffer_size_db,
424                max_bitrate,
425                avg_bitrate,
426                decoder_specific_info,
427            })
428        })
429        .parse_next(input)
430    }
431
432    fn parse_sl_config_descriptor(input: &mut Stream<'_>) -> ModalResult<SlConfigDescriptor> {
433        trace(
434            "parse_sl_config_descriptor",
435            parse_descriptor(seq!(SlConfigDescriptor {
436                predefined: u8.context(StrContext::Label("predefined")),
437            })),
438        )
439        .parse_next(input)
440    }
441
442    fn parse_descriptor<'i, Output, ParseDescriptor>(
443        mut parser: ParseDescriptor,
444    ) -> impl Parser<Stream<'i>, Output, ErrMode<ContextError>>
445    where
446        ParseDescriptor: Parser<Stream<'i>, Output, ErrMode<ContextError>>,
447        Output: Descriptor,
448    {
449        trace("parse_descriptor", move |input: &mut Stream<'i>| {
450            literal(<Output as Descriptor>::TAG)
451                .context(StrContext::Label("tag"))
452                .parse_next(input)?;
453            length_and_then(variable_length_be_u32, parser.by_ref()).parse_next(input)
454        })
455    }
456
457    fn parse_btrt_box(input: &mut Stream<'_>) -> ModalResult<BtrtExtension> {
458        trace(
459            "parse_btrt_box",
460            seq!(BtrtExtension {
461                buffer_size_db: be_u32.context(StrContext::Label("buffer_size_db")),
462                max_bitrate: be_u32.context(StrContext::Label("max_bitrate")),
463                avg_bitrate: be_u32.context(StrContext::Label("avg_bitrate")),
464            }),
465        )
466        .parse_next(input)
467    }
468}