Skip to main content

moq_transport/data/
header.rs

1use crate::coding::{Decode, DecodeError, Encode, EncodeError};
2use crate::data::{FetchHeader, SubgroupHeader};
3use std::fmt;
4
/// Stream header types.
///
/// This is the value `StreamHeader::decode` reads first from a data stream; it decides
/// whether a `SubgroupHeader` or a `FetchHeader` follows. Each discriminant is the
/// numeric wire value.
///
/// Within the subgroup range (`0x10`–`0x1d`):
/// - the `...Ext` variants (odd values) are the ones `has_extension_headers` reports;
/// - the `SubgroupId...` variants (`0x14`, `0x15`, `0x1c`, `0x1d`) are the ones
///   `has_subgroup_id` reports;
/// - the `...EndOfGroup` variants occupy `0x18`–`0x1d`. NOTE(review): presumably they
///   flag that the subgroup closes its group — confirm against the MoQ Transport draft.
#[repr(u64)]
#[derive(Copy, Debug, Clone, Eq, PartialEq)]
pub enum StreamHeaderType {
    /// Stream whose header is decoded as a `FetchHeader`.
    Fetch = 0x05,

    // Subgroup stream headers without the end-of-group marker.
    SubgroupZeroId = 0x10,
    SubgroupZeroIdExt = 0x11,
    SubgroupFirstObjectId = 0x12,
    SubgroupFirstObjectIdExt = 0x13,
    SubgroupId = 0x14,
    SubgroupIdExt = 0x15,

    // Subgroup stream headers with the end-of-group marker.
    SubgroupZeroIdEndOfGroup = 0x18,
    SubgroupZeroIdExtEndOfGroup = 0x19,
    SubgroupFirstObjectIdEndOfGroup = 0x1a,
    SubgroupFirstObjectIdExtEndOfGroup = 0x1b,
    SubgroupIdEndOfGroup = 0x1c,
    SubgroupIdExtEndOfGroup = 0x1d,
}
23
24impl StreamHeaderType {
25    pub fn is_subgroup(&self) -> bool {
26        let header_type = *self as u64;
27        (0x10..=0x1d).contains(&header_type)
28    }
29
30    pub fn is_fetch(&self) -> bool {
31        *self == StreamHeaderType::Fetch
32    }
33
34    pub fn has_extension_headers(&self) -> bool {
35        matches!(
36            *self,
37            StreamHeaderType::SubgroupZeroIdExt
38                | StreamHeaderType::SubgroupFirstObjectIdExt
39                | StreamHeaderType::SubgroupIdExt
40                | StreamHeaderType::SubgroupZeroIdExtEndOfGroup
41                | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup
42                | StreamHeaderType::SubgroupIdExtEndOfGroup
43                | StreamHeaderType::Fetch
44        )
45    }
46
47    pub fn has_subgroup_id(&self) -> bool {
48        matches!(
49            *self,
50            StreamHeaderType::SubgroupId
51                | StreamHeaderType::SubgroupIdExt
52                | StreamHeaderType::SubgroupIdEndOfGroup
53                | StreamHeaderType::SubgroupIdExtEndOfGroup
54        )
55    }
56}
57
58impl Encode for StreamHeaderType {
59    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
60        let val = *self as u64;
61        tracing::trace!(
62            "[ENCODE] StreamHeaderType: encoding {:?} as {:#x}",
63            self,
64            val
65        );
66        val.encode(w)?;
67        tracing::trace!("[ENCODE] StreamHeaderType: encoded successfully");
68        Ok(())
69    }
70}
71
72impl Decode for StreamHeaderType {
73    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
74        tracing::trace!(
75            "[DECODE] StreamHeaderType: starting decode, buffer_remaining={} bytes",
76            r.remaining()
77        );
78
79        let type_value = u64::decode(r)?;
80        tracing::trace!(
81            "[DECODE] StreamHeaderType: decoded type value={:#x}",
82            type_value
83        );
84
85        let header_type = match type_value {
86            0x10_u64 => Ok(Self::SubgroupZeroId),
87            0x11_u64 => Ok(Self::SubgroupZeroIdExt),
88            0x12_u64 => Ok(Self::SubgroupFirstObjectId),
89            0x13_u64 => Ok(Self::SubgroupFirstObjectIdExt),
90            0x14_u64 => Ok(Self::SubgroupId),
91            0x15_u64 => Ok(Self::SubgroupIdExt),
92            0x18_u64 => Ok(Self::SubgroupZeroIdEndOfGroup),
93            0x19_u64 => Ok(Self::SubgroupZeroIdExtEndOfGroup),
94            0x1a_u64 => Ok(Self::SubgroupFirstObjectIdEndOfGroup),
95            0x1b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroup),
96            0x1c_u64 => Ok(Self::SubgroupIdEndOfGroup),
97            0x1d_u64 => Ok(Self::SubgroupIdExtEndOfGroup),
98            0x05_u64 => Ok(Self::Fetch),
99            _ => {
100                tracing::error!(
101                    "[DECODE] StreamHeaderType: INVALID type value={:#x}",
102                    type_value
103                );
104                Err(DecodeError::InvalidHeaderType)
105            }
106        };
107
108        if let Ok(header_type_inner) = &header_type {
109            tracing::debug!(
110                "[DECODE] StreamHeaderType: {}, has_subgroup_id={}, has_extension_headers={}",
111                header_type_inner,
112                header_type_inner.has_subgroup_id(),
113                header_type_inner.has_extension_headers()
114            );
115        }
116
117        header_type
118    }
119}
120
121impl fmt::Display for StreamHeaderType {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        write!(f, "{:?} ({:#x})", self, *self as u64)
124    }
125}
126
127#[derive(Debug, Clone, Eq, PartialEq)]
128pub struct StreamHeader {
129    /// Subgroup Header Type
130    pub header_type: StreamHeaderType,
131
132    /// Subgroup Header for StreamHeaderTypes that are Subgroup header types
133    pub subgroup_header: Option<SubgroupHeader>,
134
135    /// Fetch Header for StreamHeaderTypes that are Fetch header types
136    pub fetch_header: Option<FetchHeader>,
137}
138
139impl Decode for StreamHeader {
140    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
141        tracing::trace!(
142            "[DECODE] StreamHeader: starting decode, buffer_remaining={} bytes",
143            r.remaining()
144        );
145
146        let header_type = StreamHeaderType::decode(r)?;
147        tracing::trace!(
148            "[DECODE] StreamHeader: decoded header_type={:?}",
149            header_type
150        );
151
152        let subgroup_header = match header_type.is_subgroup() {
153            true => {
154                tracing::trace!("[DECODE] StreamHeader: decoding subgroup header");
155                Some(SubgroupHeader::decode(header_type, r)?)
156            }
157            false => {
158                tracing::trace!("[DECODE] StreamHeader: no subgroup header (not a subgroup type)");
159                None
160            }
161        };
162
163        let fetch_header = match header_type.is_fetch() {
164            true => {
165                tracing::trace!("[DECODE] StreamHeader: decoding fetch header");
166                Some(FetchHeader::decode(header_type, r)?)
167            }
168            false => {
169                tracing::trace!("[DECODE] StreamHeader: no fetch header (not a fetch type)");
170                None
171            }
172        };
173
174        tracing::debug!(
175            "[DECODE] StreamHeader complete: type={:?}, has_subgroup={}, has_fetch={}, buffer_remaining={} bytes",
176            header_type,
177            subgroup_header.is_some(),
178            fetch_header.is_some(),
179            r.remaining()
180        );
181
182        Ok(Self {
183            header_type,
184            subgroup_header,
185            fetch_header,
186        })
187    }
188}
189
190impl Encode for StreamHeader {
191    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
192        tracing::trace!(
193            "[ENCODE] StreamHeader: starting encode for type={:?}, has_subgroup={}, has_fetch={}",
194            self.header_type,
195            self.subgroup_header.is_some(),
196            self.fetch_header.is_some()
197        );
198
199        // Note: we are intentionally not encoding the header_type here, it will be encoded in the
200        //       appropriate substructures.
201        //self.header_type.encode(w)?;
202        if self.header_type.is_subgroup() {
203            if let Some(subgroup_header) = &self.subgroup_header {
204                tracing::trace!("[ENCODE] StreamHeader: encoding subgroup header");
205                subgroup_header.encode(w)?;
206            } else {
207                tracing::error!(
208                    "[ENCODE] StreamHeader: MISSING subgroup header for subgroup type={:?}",
209                    self.header_type
210                );
211                return Err(EncodeError::MissingField("SubgroupHeader".to_string()));
212            }
213        } else if let Some(fetch_header) = &self.fetch_header {
214            tracing::trace!("[ENCODE] StreamHeader: encoding fetch header");
215            fetch_header.encode(w)?;
216        } else {
217            tracing::error!(
218                "[ENCODE] StreamHeader: MISSING fetch header for fetch type={:?}",
219                self.header_type
220            );
221            return Err(EncodeError::MissingField("FetchHeader".to_string()));
222        }
223
224        tracing::debug!("[ENCODE] StreamHeader complete");
225
226        Ok(())
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use bytes::BytesMut;

    /// Every `StreamHeaderType` variant, for exhaustive round-trip checks.
    const ALL_HEADER_TYPES: [StreamHeaderType; 13] = [
        StreamHeaderType::SubgroupZeroId,
        StreamHeaderType::SubgroupZeroIdExt,
        StreamHeaderType::SubgroupFirstObjectId,
        StreamHeaderType::SubgroupFirstObjectIdExt,
        StreamHeaderType::SubgroupId,
        StreamHeaderType::SubgroupIdExt,
        StreamHeaderType::SubgroupZeroIdEndOfGroup,
        StreamHeaderType::SubgroupZeroIdExtEndOfGroup,
        StreamHeaderType::SubgroupFirstObjectIdEndOfGroup,
        StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup,
        StreamHeaderType::SubgroupIdEndOfGroup,
        StreamHeaderType::SubgroupIdExtEndOfGroup,
        StreamHeaderType::Fetch,
    ];

    #[test]
    fn encode_decode_stream_header_type() {
        let mut buf = BytesMut::new();

        let ht = StreamHeaderType::Fetch;
        ht.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x05]);
        let decoded = StreamHeaderType::decode(&mut buf).unwrap();
        assert_eq!(decoded, ht);
        assert!(ht.is_fetch());
        assert!(!ht.is_subgroup());
        assert!(!ht.has_subgroup_id());

        let ht = StreamHeaderType::SubgroupZeroId;
        ht.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x10]);
        let decoded = StreamHeaderType::decode(&mut buf).unwrap();
        assert_eq!(decoded, ht);
        assert!(ht.is_subgroup());
        assert!(!ht.is_fetch());
        assert!(!ht.has_subgroup_id());
    }

    #[test]
    fn encode_decode_all_stream_header_types() {
        for &ht in ALL_HEADER_TYPES.iter() {
            let mut buf = BytesMut::new();
            ht.encode(&mut buf).unwrap();
            let decoded = StreamHeaderType::decode(&mut buf).unwrap();
            assert_eq!(decoded, ht);
            // Decoding must consume exactly what encoding wrote.
            assert!(buf.is_empty());
            // Every variant is exactly one of subgroup / fetch.
            assert_ne!(ht.is_subgroup(), ht.is_fetch());
        }
    }

    #[test]
    fn decode_bad_stream_header_type() {
        let data: Vec<u8> = vec![0x00]; // Invalid header type
        let mut buf: Bytes = data.into();
        let result = StreamHeaderType::decode(&mut buf);
        assert!(matches!(result, Err(DecodeError::InvalidHeaderType)));
    }

    #[test]
    fn decode_unassigned_value_inside_subgroup_range() {
        // 0x16 and 0x17 lie between SubgroupIdExt (0x15) and SubgroupZeroIdEndOfGroup
        // (0x18) but have no variant, so they must be rejected.
        for &value in [0x16u8, 0x17u8].iter() {
            let mut buf: Bytes = vec![value].into();
            let result = StreamHeaderType::decode(&mut buf);
            assert!(matches!(result, Err(DecodeError::InvalidHeaderType)));
        }
    }

    #[test]
    fn encode_decode_stream_header() {
        let mut buf = BytesMut::new();

        let sh = StreamHeader {
            header_type: StreamHeaderType::Fetch,
            subgroup_header: None,
            fetch_header: Some(FetchHeader {
                header_type: StreamHeaderType::Fetch,
                request_id: 10,
            }),
        };
        sh.encode(&mut buf).unwrap();
        let decoded = StreamHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded, sh);
        assert!(sh.header_type.is_fetch());
        assert!(!sh.header_type.is_subgroup());
        assert!(!sh.header_type.has_subgroup_id());

        let sh = StreamHeader {
            header_type: StreamHeaderType::SubgroupId,
            subgroup_header: Some(SubgroupHeader {
                header_type: StreamHeaderType::SubgroupId,
                track_alias: 10,
                group_id: 0,
                subgroup_id: Some(1),
                publisher_priority: 100,
            }),
            fetch_header: None,
        };
        sh.encode(&mut buf).unwrap();
        let decoded = StreamHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded, sh);
        assert!(sh.header_type.is_subgroup());
        assert!(!sh.header_type.is_fetch());
        assert!(sh.header_type.has_subgroup_id());
    }

    #[test]
    fn encode_stream_header_missing_subheader() {
        let mut buf = BytesMut::new();

        let sh = StreamHeader {
            header_type: StreamHeaderType::SubgroupZeroId,
            subgroup_header: None,
            fetch_header: None,
        };
        assert!(matches!(
            sh.encode(&mut buf),
            Err(EncodeError::MissingField(_))
        ));

        let sh = StreamHeader {
            header_type: StreamHeaderType::Fetch,
            subgroup_header: None,
            fetch_header: None,
        };
        assert!(matches!(
            sh.encode(&mut buf),
            Err(EncodeError::MissingField(_))
        ));

        // A failed encode must not leave partial output behind.
        assert!(buf.is_empty());
    }
}