// mp4_atom/moov/trak/mdia/minf/stbl/sgpd.rs

1use crate::*;
2
3/// SampleGroupDescriptionBox, ISO/IEC 14496-12:2024 Sect 8.9.3
4#[derive(Debug, Clone, PartialEq, Eq)]
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6pub struct Sgpd {
7    pub grouping_type: FourCC,
8    pub default_length: Option<u32>,
9    pub default_group_description_index: Option<u32>,
10    pub static_group_description: bool,
11    pub static_mapping: bool,
12    pub essential: bool,
13    pub entries: Vec<SgpdEntry>,
14}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18pub struct SgpdEntry {
19    pub description_length: Option<u32>,
20    pub entry: AnySampleGroupEntry,
21}
22
23ext!(
24    name: Sgpd,
25    versions: [0, 1, 2, 3],
26    flags: {
27        static_group_description = 0,
28        static_mapping = 1,
29    }
30);
31
32impl PartialOrd for SgpdVersion {
33    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
34        match (self, other) {
35            (SgpdVersion::V0, SgpdVersion::V0) => Some(std::cmp::Ordering::Equal),
36            (SgpdVersion::V0, _) => Some(std::cmp::Ordering::Less),
37            (SgpdVersion::V1, SgpdVersion::V0) => Some(std::cmp::Ordering::Greater),
38            (SgpdVersion::V1, SgpdVersion::V1) => Some(std::cmp::Ordering::Equal),
39            (SgpdVersion::V1, _) => Some(std::cmp::Ordering::Less),
40            (SgpdVersion::V2, SgpdVersion::V2) => Some(std::cmp::Ordering::Equal),
41            (SgpdVersion::V2, SgpdVersion::V3) => Some(std::cmp::Ordering::Less),
42            (SgpdVersion::V2, _) => Some(std::cmp::Ordering::Greater),
43            (SgpdVersion::V3, SgpdVersion::V3) => Some(std::cmp::Ordering::Equal),
44            (SgpdVersion::V3, _) => Some(std::cmp::Ordering::Greater),
45        }
46    }
47}
48
49impl AtomExt for Sgpd {
50    type Ext = SgpdExt;
51
52    const KIND_EXT: FourCC = FourCC::new(b"sgpd");
53
54    fn decode_body_ext<B: Buf>(buf: &mut B, ext: Self::Ext) -> Result<Self> {
55        let grouping_type = FourCC::decode(buf)?;
56        let default_length = if ext.version >= SgpdVersion::V1 {
57            Some(u32::decode(buf)?)
58        } else {
59            None
60        };
61        let default_group_description_index = if ext.version >= SgpdVersion::V2 {
62            Some(u32::decode(buf)?)
63        } else {
64            None
65        };
66        let entry_count = u32::decode(buf)?;
67        let mut entries = Vec::with_capacity((entry_count as usize).min(1024));
68        for _ in 0..entry_count {
69            // Spec states: if version>=1 && default_length==0
70            // But, default_length.is_some(), if and only if version>=1, so fine to just check for
71            // `Some(0)`.
72            let description_length = if default_length == Some(0) {
73                Some(u32::decode(buf)?)
74            } else {
75                default_length
76            };
77            let entry = AnySampleGroupEntry::decode(grouping_type, buf)?;
78            entries.push(SgpdEntry {
79                description_length,
80                entry,
81            });
82        }
83        let static_group_description = ext.static_group_description;
84        let static_mapping = ext.static_mapping;
85        let essential = ext.version == SgpdVersion::V3;
86        Ok(Self {
87            grouping_type,
88            default_length,
89            default_group_description_index,
90            static_group_description,
91            static_mapping,
92            essential,
93            entries,
94        })
95    }
96
97    fn encode_body_ext<B: BufMut>(&self, buf: &mut B) -> Result<Self::Ext> {
98        let version = if self.essential {
99            SgpdVersion::V3
100        } else if self.default_group_description_index.is_some() {
101            SgpdVersion::V2
102        } else if self.default_length.is_some() {
103            SgpdVersion::V1
104        } else {
105            SgpdVersion::V0
106        };
107        let ext = SgpdExt {
108            version,
109            static_group_description: self.static_group_description,
110            static_mapping: self.static_mapping,
111        };
112        self.grouping_type.encode(buf)?;
113        if let Some(default_length) = self.default_length {
114            default_length.encode(buf)?;
115        }
116        if let Some(default_group_description_index) = self.default_group_description_index {
117            default_group_description_index.encode(buf)?;
118        }
119        (self.entries.len() as u32).encode(buf)?;
120        for entry in &self.entries {
121            if self.default_length == Some(0) {
122                if let Some(description_length) = entry.description_length {
123                    description_length.encode(buf)?
124                }
125            }
126            entry.entry.encode(buf)?
127        }
128        Ok(ext)
129    }
130}
131
132const REFS_4CC: FourCC = FourCC::new(b"refs");
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
136pub enum AnySampleGroupEntry {
137    DirectReferenceSampleList(u32, Vec<u32>),
138    UnknownGroupingType(FourCC, Vec<u8>),
139}
140
141impl AnySampleGroupEntry {
142    fn decode<B: Buf>(grouping_type: FourCC, buf: &mut B) -> Result<Self> {
143        match grouping_type {
144            REFS_4CC => {
145                let sample_id = u32::decode(buf)?;
146                let num_direct_reference_samples = u8::decode(buf)? as usize;
147                let mut direct_reference_samples =
148                    Vec::with_capacity(std::cmp::min(num_direct_reference_samples, 16));
149                for _ in 0..num_direct_reference_samples {
150                    direct_reference_samples.push(u32::decode(buf)?);
151                }
152                Ok(Self::DirectReferenceSampleList(
153                    sample_id,
154                    direct_reference_samples,
155                ))
156            }
157            _ => Ok(Self::UnknownGroupingType(grouping_type, Vec::decode(buf)?)),
158        }
159    }
160
161    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
162        match self {
163            Self::DirectReferenceSampleList(sample_id, direct_reference_samples) => {
164                sample_id.encode(buf)?;
165                let num_direct_reference_samples: u8 = direct_reference_samples
166                    .len()
167                    .try_into()
168                    .map_err(|_| Error::TooLarge(REFS_4CC))?;
169                num_direct_reference_samples.encode(buf)?;
170                for direct_reference_sample in direct_reference_samples {
171                    direct_reference_sample.encode(buf)?;
172                }
173                Ok(())
174            }
175            Self::UnknownGroupingType(_, bytes) => bytes.encode(buf),
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    // Taken from the MPEG file format conformance suite:
    // https://mpeggroup.github.io/FileFormatConformance/files/published/isobmff/a9-aac-samplegroups-edit.mp4
    //
    // Only the bytes of the sgpd atom were extracted.
    const SIMPLE_SGPD: &[u8] = &[
        0x00, 0x00, 0x00, 0x1A, 0x73, 0x67, 0x70, 0x64, 0x01, 0x00, 0x00, 0x00, 0x72, 0x6F, 0x6C,
        0x6C, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0xFF, 0xFF,
    ];

    /// The value `SIMPLE_SGPD` encodes: a version-1 `roll` group with one 2-byte entry.
    fn simple_sgpd() -> Sgpd {
        let roll = FourCC::from(b"roll");
        Sgpd {
            grouping_type: roll,
            default_length: Some(2),
            default_group_description_index: None,
            static_group_description: false,
            static_mapping: false,
            essential: false,
            entries: vec![SgpdEntry {
                description_length: Some(2),
                entry: AnySampleGroupEntry::UnknownGroupingType(roll, SIMPLE_SGPD[24..].to_vec()),
            }],
        }
    }

    #[test]
    fn sgpd_decodes_from_bytes_correctly() {
        let mut cursor = Cursor::new(SIMPLE_SGPD);
        let decoded = Sgpd::decode(&mut cursor).expect("sgpd should decode successfully");
        assert_eq!(decoded, simple_sgpd())
    }

    #[test]
    fn sgpd_encodes_from_type_correctly() {
        let mut encoded = Vec::new();
        simple_sgpd()
            .encode(&mut encoded)
            .expect("encode should be successful");
        assert_eq!(SIMPLE_SGPD, &encoded);
    }

    // From the MPEG File Format Conformance suite, heif/C041.heic
    const SGPD_ENCODED_C041: &[u8] = &[
        0x00, 0x00, 0x00, 0x2e, 0x73, 0x67, 0x70, 0x64, 0x01, 0x00, 0x00, 0x00, 0x72, 0x65, 0x66,
        0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00,
        0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01,
        0x00,
    ];

    #[test]
    fn sgpd_c041_decode() {
        // Version 1 with default_length 0: each `refs` entry carries its own length prefix.
        let expected = Sgpd {
            grouping_type: FourCC::from(b"refs"),
            default_length: Some(0),
            default_group_description_index: None,
            static_group_description: false,
            static_mapping: false,
            essential: false,
            entries: vec![
                SgpdEntry {
                    description_length: Some(9),
                    entry: AnySampleGroupEntry::DirectReferenceSampleList(0, vec![1]),
                },
                SgpdEntry {
                    description_length: Some(5),
                    entry: AnySampleGroupEntry::DirectReferenceSampleList(1, vec![]),
                },
            ],
        };

        let mut cursor = Cursor::new(SGPD_ENCODED_C041);
        let decoded = Sgpd::decode(&mut cursor).expect("sgpd should decode successfully");
        assert_eq!(decoded, expected);

        // Re-encoding must reproduce the conformance bytes exactly.
        let mut reencoded = Vec::new();
        decoded.encode(&mut reencoded).unwrap();
        assert_eq!(reencoded, SGPD_ENCODED_C041);
    }
}