mp4_atom/moov/trak/mdia/minf/stbl/
sgpd.rs

1use crate::*;
2
/// SampleGroupDescriptionBox, ISO/IEC 14496-12:2024 Sect 8.9.3
///
/// Holds the description entries for one `grouping_type`. Which optional header fields are
/// present depends on the box version; see `decode_body_ext` / `encode_body_ext`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sgpd {
    /// Grouping type of this box; passed to `AnySampleGroupEntry::decode` to interpret entries.
    pub grouping_type: FourCC,
    /// Present only for version >= 1. `Some(0)` means every entry carries its own
    /// `description_length`.
    pub default_length: Option<u32>,
    /// Present only for version >= 2.
    pub default_group_description_index: Option<u32>,
    /// The `static_group_description` flag from the box header.
    pub static_group_description: bool,
    /// The `static_mapping` flag from the box header.
    pub static_mapping: bool,
    /// True when the box was decoded as version 3; setting it makes encoding pick version 3.
    pub essential: bool,
    /// Description entries, in the order they appear in the box.
    pub entries: Vec<SgpdEntry>,
}
15
/// One description entry of an [`Sgpd`] box.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SgpdEntry {
    /// Entry length: the per-entry value read from the box when `default_length` is `Some(0)`,
    /// otherwise a copy of the box's `default_length` (`None` for version 0).
    pub description_length: Option<u32>,
    /// The entry payload, interpreted according to the box's `grouping_type`.
    pub entry: AnySampleGroupEntry,
}
22
// Declares the full-box header of `sgpd`: versions 0 through 3 plus two named flags.
// Judging by its use in this file, this generates `SgpdVersion` (variants `V0`..`V3`) and
// `SgpdExt` (fields `version`, `static_group_description`, `static_mapping`).
// NOTE(review): the numbers after the flag names are presumably flag bit positions —
// confirm against the `ext!` macro definition.
ext!(
    name: Sgpd,
    versions: [0, 1, 2, 3],
    flags: {
        static_group_description = 0,
        static_mapping = 1,
    }
);
31
32impl PartialOrd for SgpdVersion {
33    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
34        match (self, other) {
35            (SgpdVersion::V0, SgpdVersion::V0) => Some(std::cmp::Ordering::Equal),
36            (SgpdVersion::V0, _) => Some(std::cmp::Ordering::Less),
37            (SgpdVersion::V1, SgpdVersion::V0) => Some(std::cmp::Ordering::Greater),
38            (SgpdVersion::V1, SgpdVersion::V1) => Some(std::cmp::Ordering::Equal),
39            (SgpdVersion::V1, _) => Some(std::cmp::Ordering::Less),
40            (SgpdVersion::V2, SgpdVersion::V2) => Some(std::cmp::Ordering::Equal),
41            (SgpdVersion::V2, SgpdVersion::V3) => Some(std::cmp::Ordering::Less),
42            (SgpdVersion::V2, _) => Some(std::cmp::Ordering::Greater),
43            (SgpdVersion::V3, SgpdVersion::V3) => Some(std::cmp::Ordering::Equal),
44            (SgpdVersion::V3, _) => Some(std::cmp::Ordering::Greater),
45        }
46    }
47}
48
49impl AtomExt for Sgpd {
50    type Ext = SgpdExt;
51
52    const KIND_EXT: FourCC = FourCC::new(b"sgpd");
53
54    fn decode_body_ext<B: Buf>(buf: &mut B, ext: Self::Ext) -> Result<Self> {
55        let grouping_type = FourCC::decode(buf)?;
56        let default_length = if ext.version >= SgpdVersion::V1 {
57            Some(u32::decode(buf)?)
58        } else {
59            None
60        };
61        let default_group_description_index = if ext.version >= SgpdVersion::V2 {
62            Some(u32::decode(buf)?)
63        } else {
64            None
65        };
66        let entry_count = u32::decode(buf)?;
67        let mut entries = Vec::with_capacity((entry_count as usize).min(1024));
68        for _ in 0..entry_count {
69            // Spec states: if version>=1 && default_length==0
70            // But, default_length.is_some(), if and only if version>=1, so fine to just check for
71            // `Some(0)`.
72            let description_length = if default_length == Some(0) {
73                Some(u32::decode(buf)?)
74            } else {
75                default_length
76            };
77            let entry = AnySampleGroupEntry::decode(grouping_type, buf)?;
78            entries.push(SgpdEntry {
79                description_length,
80                entry,
81            });
82        }
83        let static_group_description = ext.static_group_description;
84        let static_mapping = ext.static_mapping;
85        let essential = ext.version == SgpdVersion::V3;
86        Ok(Self {
87            grouping_type,
88            default_length,
89            default_group_description_index,
90            static_group_description,
91            static_mapping,
92            essential,
93            entries,
94        })
95    }
96
97    fn encode_body_ext<B: BufMut>(&self, buf: &mut B) -> Result<Self::Ext> {
98        let version = if self.essential {
99            SgpdVersion::V3
100        } else if self.default_group_description_index.is_some() {
101            SgpdVersion::V2
102        } else if self.default_length.is_some() {
103            SgpdVersion::V1
104        } else {
105            SgpdVersion::V0
106        };
107        let ext = SgpdExt {
108            version,
109            static_group_description: self.static_group_description,
110            static_mapping: self.static_mapping,
111        };
112        self.grouping_type.encode(buf)?;
113        if let Some(default_length) = self.default_length {
114            default_length.encode(buf)?;
115        }
116        if let Some(default_group_description_index) = self.default_group_description_index {
117            default_group_description_index.encode(buf)?;
118        }
119        (self.entries.len() as u32).encode(buf)?;
120        for entry in &self.entries {
121            if self.default_length == Some(0) {
122                if let Some(description_length) = entry.description_length {
123                    description_length.encode(buf)?
124                }
125            }
126            entry.entry.encode(buf)?
127        }
128        Ok(ext)
129    }
130}
131
/// Payload of a single sample group description entry.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AnySampleGroupEntry {
    /// An entry kept as raw bytes, tagged with the grouping type it was decoded under.
    UnknownGroupingType(FourCC, Vec<u8>),
}
137
138impl AnySampleGroupEntry {
139    fn decode<B: Buf>(grouping_type: FourCC, buf: &mut B) -> Result<Self> {
140        Ok(Self::UnknownGroupingType(grouping_type, Vec::decode(buf)?))
141    }
142
143    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
144        match self {
145            Self::UnknownGroupingType(_, bytes) => bytes.encode(buf),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    // The `sgpd` atom bytes, cut out of the MPEG file-format conformance sample:
    // https://mpeggroup.github.io/FileFormatConformance/files/published/isobmff/a9-aac-samplegroups-edit.mp4
    //
    // Layout: size 0x1A, "sgpd", version 1, flags 0, grouping_type "roll",
    // default_length 2, entry_count 1, then the single 2-byte entry.
    const SIMPLE_SGPD: &[u8] = &[
        0x00, 0x00, 0x00, 0x1A, 0x73, 0x67, 0x70, 0x64, 0x01, 0x00, 0x00, 0x00, 0x72, 0x6F, 0x6C,
        0x6C, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0xFF, 0xFF,
    ];

    /// Byte offset of the only entry's payload inside `SIMPLE_SGPD`.
    const ENTRY_OFFSET: usize = 24;

    /// The structured value that `SIMPLE_SGPD` encodes.
    fn simple_sgpd() -> Sgpd {
        let roll = FourCC::from(b"roll");
        let payload = SIMPLE_SGPD[ENTRY_OFFSET..].to_vec();
        Sgpd {
            grouping_type: roll,
            default_length: Some(2),
            default_group_description_index: None,
            static_group_description: false,
            static_mapping: false,
            essential: false,
            entries: vec![SgpdEntry {
                description_length: Some(2),
                entry: AnySampleGroupEntry::UnknownGroupingType(roll, payload),
            }],
        }
    }

    #[test]
    fn sgpd_decodes_from_bytes_correctly() {
        let mut cursor = Cursor::new(SIMPLE_SGPD);
        let decoded = Sgpd::decode(&mut cursor).expect("sgpd should decode successfully");
        assert_eq!(decoded, simple_sgpd());
    }

    #[test]
    fn sgpd_encodes_from_type_correctly() {
        let mut encoded = Vec::new();
        simple_sgpd()
            .encode(&mut encoded)
            .expect("encode should be successful");
        assert_eq!(SIMPLE_SGPD, &encoded);
    }
}