Skip to main content

moq_transport/message/
subscribe.rs

1use crate::coding::{
2    Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace,
3};
4use crate::message::FilterType;
5use crate::message::GroupOrder;
6
7/// Sent by the subscriber to request all future objects for the given track.
8///
9/// Objects will use the provided ID instead of the full track name, to save bytes.
10#[derive(Clone, Debug, Eq, PartialEq)]
11pub struct Subscribe {
12    /// The subscription request ID
13    pub id: u64,
14
15    /// Track properties
16    pub track_namespace: TrackNamespace,
17    pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096)
18
19    /// Subscriber Priority
20    pub subscriber_priority: u8,
21    pub group_order: GroupOrder,
22
23    /// Forward Flag
24    pub forward: bool,
25
26    /// Filter type
27    pub filter_type: FilterType,
28
29    /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types.
30    pub start_location: Option<Location>,
31    /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type.
32    pub end_group_id: Option<u64>,
33
34    /// Optional parameters
35    pub params: KeyValuePairs,
36}
37
38impl Decode for Subscribe {
39    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
40        let id = u64::decode(r)?;
41
42        let track_namespace = TrackNamespace::decode(r)?;
43        let track_name = String::decode(r)?;
44
45        let subscriber_priority = u8::decode(r)?;
46        let group_order = GroupOrder::decode(r)?;
47
48        let forward = bool::decode(r)?;
49
50        let filter_type = FilterType::decode(r)?;
51        let start_location: Option<Location>;
52        let end_group_id: Option<u64>;
53        match filter_type {
54            FilterType::AbsoluteStart => {
55                start_location = Some(Location::decode(r)?);
56                end_group_id = None;
57            }
58            FilterType::AbsoluteRange => {
59                start_location = Some(Location::decode(r)?);
60                end_group_id = Some(u64::decode(r)?);
61            }
62            _ => {
63                start_location = None;
64                end_group_id = None;
65            }
66        }
67
68        let params = KeyValuePairs::decode(r)?;
69
70        Ok(Self {
71            id,
72            track_namespace,
73            track_name,
74            subscriber_priority,
75            group_order,
76            forward,
77            filter_type,
78            start_location,
79            end_group_id,
80            params,
81        })
82    }
83}
84
85impl Encode for Subscribe {
86    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
87        self.id.encode(w)?;
88
89        self.track_namespace.encode(w)?;
90        self.track_name.encode(w)?;
91
92        self.subscriber_priority.encode(w)?;
93        self.group_order.encode(w)?;
94
95        self.forward.encode(w)?;
96
97        self.filter_type.encode(w)?;
98        match self.filter_type {
99            FilterType::AbsoluteStart => {
100                if let Some(start) = &self.start_location {
101                    start.encode(w)?;
102                } else {
103                    return Err(EncodeError::MissingField("StartLocation".to_string()));
104                }
105                // Just ignore end_group_id if it happens to be set
106            }
107            FilterType::AbsoluteRange => {
108                if let Some(start) = &self.start_location {
109                    start.encode(w)?;
110                } else {
111                    return Err(EncodeError::MissingField("StartLocation".to_string()));
112                }
113                if let Some(end) = self.end_group_id {
114                    end.encode(w)?;
115                } else {
116                    return Err(EncodeError::MissingField("EndGroupId".to_string()));
117                }
118            }
119            _ => {}
120        }
121
122        self.params.encode(w)?;
123
124        Ok(())
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Builds a SUBSCRIBE with fixed identifying fields and the given filter settings.
    fn subscribe(
        filter_type: FilterType,
        start_location: Option<Location>,
        end_group_id: Option<u64>,
        params: KeyValuePairs,
    ) -> Subscribe {
        Subscribe {
            id: 12345,
            track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"),
            track_name: "audiotrack".to_string(),
            subscriber_priority: 127,
            group_order: GroupOrder::Publisher,
            forward: true,
            filter_type,
            start_location,
            end_group_id,
            params,
        }
    }

    #[test]
    fn encode_decode() {
        // One parameter for testing
        let mut kvps = KeyValuePairs::new();
        kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]);

        let cases = [
            subscribe(FilterType::NextGroupStart, None, None, kvps.clone()),
            subscribe(
                FilterType::AbsoluteStart,
                Some(Location::new(12345, 67890)),
                None,
                kvps.clone(),
            ),
            subscribe(
                FilterType::AbsoluteRange,
                Some(Location::new(12345, 67890)),
                Some(23456),
                kvps,
            ),
        ];

        for msg in &cases {
            let mut buf = BytesMut::new();
            msg.encode(&mut buf).unwrap();
            let decoded = Subscribe::decode(&mut buf).unwrap();
            assert_eq!(&decoded, msg);
            // Decode must consume exactly the bytes that encode produced.
            assert!(
                buf.is_empty(),
                "trailing bytes after decoding {:?}",
                msg.filter_type
            );
        }
    }

    #[test]
    fn encode_missing_fields() {
        let cases = [
            // AbsoluteStart - missing start_location
            (
                subscribe(FilterType::AbsoluteStart, None, None, KeyValuePairs::default()),
                "StartLocation",
            ),
            // AbsoluteRange - missing start_location (reported before end_group_id)
            (
                subscribe(FilterType::AbsoluteRange, None, None, KeyValuePairs::default()),
                "StartLocation",
            ),
            // AbsoluteRange - missing end_group_id
            (
                subscribe(
                    FilterType::AbsoluteRange,
                    Some(Location::new(12345, 67890)),
                    None,
                    KeyValuePairs::default(),
                ),
                "EndGroupId",
            ),
        ];

        for (msg, expected_field) in &cases {
            // Fresh buffer per case so one failure cannot affect the next.
            let mut buf = BytesMut::new();
            let err = msg.encode(&mut buf).unwrap_err();
            assert!(
                matches!(&err, EncodeError::MissingField(field) if field.as_str() == *expected_field),
                "expected MissingField({expected_field})"
            );
        }
    }
}