1use std::borrow::Cow;
4
5use num_enum::{IntoPrimitive, TryFromPrimitive};
6
7use crate::{
8 coding::*,
9 ietf::{GroupOrder, Location, Message, Parameters, RequestId, Version},
10 Path,
11};
12
13use super::namespace::{decode_namespace, encode_namespace};
14
15#[derive(Clone, Copy, Debug, TryFromPrimitive, IntoPrimitive)]
16#[repr(u64)]
17pub enum FilterType {
18 NextGroup = 0x01,
19 LargestObject = 0x2,
20 AbsoluteStart = 0x3,
21 AbsoluteRange = 0x4,
22}
23
24impl<V> Encode<V> for FilterType {
25 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) {
26 u64::from(*self).encode(w, version);
27 }
28}
29
30impl<V> Decode<V> for FilterType {
31 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
32 Self::try_from(u64::decode(r, version)?).map_err(|_| DecodeError::InvalidValue)
33 }
34}
35
36#[derive(Clone, Debug)]
39pub struct Subscribe<'a> {
40 pub request_id: RequestId,
41 pub track_namespace: Path<'a>,
42 pub track_name: Cow<'a, str>,
43 pub subscriber_priority: u8,
44 pub group_order: GroupOrder,
45 pub filter_type: FilterType,
46}
47
48impl<'a> Message for Subscribe<'a> {
49 const ID: u64 = 0x03;
50
51 fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
52 let request_id = RequestId::decode(r, version)?;
53
54 let track_namespace = decode_namespace(r, version)?;
56
57 let track_name = Cow::<str>::decode(r, version)?;
58 let subscriber_priority = u8::decode(r, version)?;
59
60 let group_order = GroupOrder::decode(r, version)?;
61
62 let forward = bool::decode(r, version)?;
63 if !forward {
64 return Err(DecodeError::Unsupported);
65 }
66
67 let filter_type = FilterType::decode(r, version)?;
68 match filter_type {
69 FilterType::AbsoluteStart => {
70 let _start = Location::decode(r, version)?;
71 }
72 FilterType::AbsoluteRange => {
73 let _start = Location::decode(r, version)?;
74 let _end_group = u64::decode(r, version)?;
75 }
76 FilterType::NextGroup | FilterType::LargestObject => {}
77 };
78
79 let _params = Parameters::decode(r, version)?;
81
82 Ok(Self {
83 request_id,
84 track_namespace,
85 track_name,
86 subscriber_priority,
87 group_order,
88 filter_type,
89 })
90 }
91
92 fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
93 self.request_id.encode(w, version);
94 encode_namespace(w, &self.track_namespace, version);
95 self.track_name.encode(w, version);
96 self.subscriber_priority.encode(w, version);
97 GroupOrder::Descending.encode(w, version);
98 true.encode(w, version); assert!(
101 !matches!(self.filter_type, FilterType::AbsoluteStart | FilterType::AbsoluteRange),
102 "Absolute subscribe not supported"
103 );
104
105 self.filter_type.encode(w, version);
106 0u8.encode(w, version); }
108}
109
110#[derive(Clone, Debug)]
112pub struct SubscribeOk {
113 pub request_id: RequestId,
114 pub track_alias: u64,
115}
116
117impl Message for SubscribeOk {
118 const ID: u64 = 0x04;
119
120 fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
121 self.request_id.encode(w, version);
122 self.track_alias.encode(w, version);
123 0u64.encode(w, version); GroupOrder::Descending.encode(w, version);
125 false.encode(w, version); 0u8.encode(w, version); }
128
129 fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
130 let request_id = RequestId::decode(r, version)?;
131 let track_alias = u64::decode(r, version)?;
132
133 let expires = u64::decode(r, version)?;
134 if expires != 0 {
135 return Err(DecodeError::Unsupported);
136 }
137
138 let _group_order = u8::decode(r, version)?;
140
141 if bool::decode(r, version)? {
143 let _group = u64::decode(r, version)?;
144 let _object = u64::decode(r, version)?;
145 }
146
147 let _params = Parameters::decode(r, version)?;
149
150 Ok(Self {
151 request_id,
152 track_alias,
153 })
154 }
155}
156
157#[derive(Clone, Debug)]
159pub struct SubscribeError<'a> {
160 pub request_id: RequestId,
161 pub error_code: u64,
162 pub reason_phrase: Cow<'a, str>,
163}
164
165impl<'a> Message for SubscribeError<'a> {
166 const ID: u64 = 0x05;
167
168 fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
169 self.request_id.encode(w, version);
170 self.error_code.encode(w, version);
171 self.reason_phrase.encode(w, version);
172 }
173 fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
174 let request_id = RequestId::decode(r, version)?;
175 let error_code = u64::decode(r, version)?;
176 let reason_phrase = Cow::<str>::decode(r, version)?;
177
178 Ok(Self {
179 request_id,
180 error_code,
181 reason_phrase,
182 })
183 }
184}
185
186#[derive(Clone, Debug)]
188pub struct Unsubscribe {
189 pub request_id: RequestId,
190}
191
192impl Message for Unsubscribe {
193 const ID: u64 = 0x0a;
194
195 fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
196 self.request_id.encode(w, version);
197 }
198
199 fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
200 let request_id = RequestId::decode(r, version)?;
201 Ok(Self { request_id })
202 }
203}
204
205#[derive(Debug)]
218pub struct SubscribeUpdate {
219 pub request_id: RequestId,
220 pub subscription_request_id: RequestId,
221 pub start_location: Location,
222 pub end_group: u64,
223 pub subscriber_priority: u8,
224 pub forward: bool,
225 }
227
228impl Message for SubscribeUpdate {
229 const ID: u64 = 0x02;
230
231 fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
232 self.request_id.encode(w, version);
233 self.subscription_request_id.encode(w, version);
234 self.start_location.encode(w, version);
235 self.end_group.encode(w, version);
236 self.subscriber_priority.encode(w, version);
237 self.forward.encode(w, version);
238 0u8.encode(w, version); }
240
241 fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
242 let request_id = RequestId::decode(r, version)?;
243 let subscription_request_id = RequestId::decode(r, version)?;
244 let start_location = Location::decode(r, version)?;
245 let end_group = u64::decode(r, version)?;
246 let subscriber_priority = u8::decode(r, version)?;
247 let forward = bool::decode(r, version)?;
248 let _parameters = Parameters::decode(r, version)?;
249
250 Ok(Self {
251 request_id,
252 subscription_request_id,
253 start_location,
254 end_group,
255 subscriber_priority,
256 forward,
257 })
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    fn encode_message<M: Message>(msg: &M) -> Vec<u8> {
        let mut buf = BytesMut::new();
        msg.encode_msg(&mut buf, Version::Draft14);
        buf.to_vec()
    }

    fn decode_message<M: Message>(bytes: &[u8]) -> Result<M, DecodeError> {
        let mut buf = bytes::Bytes::from(bytes.to_vec());
        M::decode_msg(&mut buf, Version::Draft14)
    }

    /// A minimal valid SUBSCRIBE, also used as a template for the negative tests.
    fn sample_subscribe() -> Subscribe<'static> {
        Subscribe {
            request_id: RequestId(1),
            track_namespace: Path::new("test"),
            track_name: "video".into(),
            subscriber_priority: 128,
            group_order: GroupOrder::Descending,
            filter_type: FilterType::LargestObject,
        }
    }

    #[test]
    fn test_subscribe_round_trip() {
        let encoded = encode_message(&sample_subscribe());
        let decoded: Subscribe = decode_message(&encoded).unwrap();

        assert_eq!(decoded.request_id, RequestId(1));
        assert_eq!(decoded.track_namespace.as_str(), "test");
        assert_eq!(decoded.track_name, "video");
        assert_eq!(decoded.subscriber_priority, 128);
        assert!(matches!(decoded.filter_type, FilterType::LargestObject));
    }

    #[test]
    fn test_subscribe_nested_namespace() {
        let msg = Subscribe {
            request_id: RequestId(100),
            track_namespace: Path::new("conference/room123"),
            track_name: "audio".into(),
            subscriber_priority: 255,
            group_order: GroupOrder::Descending,
            filter_type: FilterType::LargestObject,
        };

        let encoded = encode_message(&msg);
        let decoded: Subscribe = decode_message(&encoded).unwrap();

        assert_eq!(decoded.track_namespace.as_str(), "conference/room123");
    }

    #[test]
    #[should_panic(expected = "Absolute subscribe not supported")]
    fn test_subscribe_encode_rejects_absolute_filter() {
        let msg = Subscribe {
            filter_type: FilterType::AbsoluteStart,
            ..sample_subscribe()
        };
        let _ = encode_message(&msg);
    }

    #[test]
    fn test_subscribe_ok() {
        let msg = SubscribeOk {
            request_id: RequestId(42),
            track_alias: 42,
        };

        let encoded = encode_message(&msg);
        let decoded: SubscribeOk = decode_message(&encoded).unwrap();

        assert_eq!(decoded.request_id, RequestId(42));
        assert_eq!(decoded.track_alias, 42);
    }

    #[test]
    fn test_subscribe_ok_skips_largest_location() {
        #[rustfmt::skip]
        let bytes = vec![
            0x01, // request_id
            0x05, // track_alias
            0x00, // expires
            0x02, // group order
            0x01, // content exists
            0x03, // largest group
            0x04, // largest object
            0x00, // number of parameters
        ];

        let decoded: SubscribeOk = decode_message(&bytes).unwrap();
        assert_eq!(decoded.request_id, RequestId(1));
        assert_eq!(decoded.track_alias, 5);
    }

    #[test]
    fn test_subscribe_error() {
        let msg = SubscribeError {
            request_id: RequestId(123),
            error_code: 500,
            reason_phrase: "Not found".into(),
        };

        let encoded = encode_message(&msg);
        let decoded: SubscribeError = decode_message(&encoded).unwrap();

        assert_eq!(decoded.request_id, RequestId(123));
        assert_eq!(decoded.error_code, 500);
        assert_eq!(decoded.reason_phrase, "Not found");
    }

    #[test]
    fn test_unsubscribe() {
        let msg = Unsubscribe {
            request_id: RequestId(999),
        };

        let encoded = encode_message(&msg);
        let decoded: Unsubscribe = decode_message(&encoded).unwrap();

        assert_eq!(decoded.request_id, RequestId(999));
    }

    #[test]
    fn test_subscribe_rejects_invalid_filter_type() {
        // Start from a real encoding so every other field is valid for the current
        // layout; the body ends with [filter_type, number_of_parameters].
        let mut bytes = encode_message(&sample_subscribe());
        let filter_index = bytes.len() - 2;
        assert_eq!(bytes[filter_index], 0x02, "expected LargestObject filter byte");
        assert_eq!(bytes[filter_index + 1], 0x00, "expected zero parameters");

        bytes[filter_index] = 0x05; // not a defined filter type

        let result: Result<Subscribe, _> = decode_message(&bytes);
        assert!(matches!(result, Err(DecodeError::InvalidValue)));
    }

    #[test]
    fn test_subscribe_rejects_forward_false() {
        // Layout tail: [forward, filter_type, number_of_parameters].
        let mut bytes = encode_message(&sample_subscribe());
        let forward_index = bytes.len() - 3;
        bytes[forward_index] = 0x00;

        let result: Result<Subscribe, _> = decode_message(&bytes);
        assert!(matches!(result, Err(DecodeError::Unsupported)));
    }

    #[test]
    fn test_subscribe_ok_rejects_non_zero_expires() {
        #[rustfmt::skip]
        let invalid_bytes = vec![
            0x01, // request_id
            0x05, // track_alias
            0x02, // expires (non-zero)
            0x00, // group order
            0x00, // content exists
        ];

        let result: Result<SubscribeOk, _> = decode_message(&invalid_bytes);
        assert!(matches!(result, Err(DecodeError::Unsupported)));
    }
}