moq_lite/ietf/
subscribe.rs

1//! IETF moq-transport-14 subscribe messages
2
3use std::borrow::Cow;
4
5use num_enum::{IntoPrimitive, TryFromPrimitive};
6
7use crate::{
8	coding::*,
9	ietf::{GroupOrder, Location, Message, Parameters, RequestId, Version},
10	Path,
11};
12
13use super::namespace::{decode_namespace, encode_namespace};
14
15#[derive(Clone, Copy, Debug, TryFromPrimitive, IntoPrimitive)]
16#[repr(u64)]
17pub enum FilterType {
18	NextGroup = 0x01,
19	LargestObject = 0x2,
20	AbsoluteStart = 0x3,
21	AbsoluteRange = 0x4,
22}
23
24impl<V> Encode<V> for FilterType {
25	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) {
26		u64::from(*self).encode(w, version);
27	}
28}
29
30impl<V> Decode<V> for FilterType {
31	fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
32		Self::try_from(u64::decode(r, version)?).map_err(|_| DecodeError::InvalidValue)
33	}
34}
35
/// Subscribe message (0x03)
/// Sent by the subscriber to request all future objects for the given track.
///
/// Only the fields below are kept after decoding; the forward flag, any
/// absolute filter bounds and the parameters are read from the wire and then
/// discarded.
#[derive(Clone, Debug)]
pub struct Subscribe<'a> {
	/// Identifier of this request.
	pub request_id: RequestId,
	/// Namespace of the requested track.
	pub track_namespace: Path<'a>,
	/// Name of the requested track.
	pub track_name: Cow<'a, str>,
	/// Priority byte supplied by the subscriber.
	pub subscriber_priority: u8,
	/// Group order carried in the message.
	pub group_order: GroupOrder,
	/// Filter type carried in the message.
	pub filter_type: FilterType,
}
47
48impl<'a> Message for Subscribe<'a> {
49	const ID: u64 = 0x03;
50
51	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
52		let request_id = RequestId::decode(r, version)?;
53
54		// Decode namespace (tuple of strings)
55		let track_namespace = decode_namespace(r, version)?;
56
57		let track_name = Cow::<str>::decode(r, version)?;
58		let subscriber_priority = u8::decode(r, version)?;
59
60		let group_order = GroupOrder::decode(r, version)?;
61
62		let forward = bool::decode(r, version)?;
63		if !forward {
64			return Err(DecodeError::Unsupported);
65		}
66
67		let filter_type = FilterType::decode(r, version)?;
68		match filter_type {
69			FilterType::AbsoluteStart => {
70				let _start = Location::decode(r, version)?;
71			}
72			FilterType::AbsoluteRange => {
73				let _start = Location::decode(r, version)?;
74				let _end_group = u64::decode(r, version)?;
75			}
76			FilterType::NextGroup | FilterType::LargestObject => {}
77		};
78
79		// Ignore parameters, who cares.
80		let _params = Parameters::decode(r, version)?;
81
82		Ok(Self {
83			request_id,
84			track_namespace,
85			track_name,
86			subscriber_priority,
87			group_order,
88			filter_type,
89		})
90	}
91
92	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
93		self.request_id.encode(w, version);
94		encode_namespace(w, &self.track_namespace, version);
95		self.track_name.encode(w, version);
96		self.subscriber_priority.encode(w, version);
97		GroupOrder::Descending.encode(w, version);
98		true.encode(w, version); // forward
99
100		assert!(
101			!matches!(self.filter_type, FilterType::AbsoluteStart | FilterType::AbsoluteRange),
102			"Absolute subscribe not supported"
103		);
104
105		self.filter_type.encode(w, version);
106		0u8.encode(w, version); // no parameters
107	}
108}
109
/// SubscribeOk message (0x04)
///
/// Only the request id and track alias are stored; the remaining wire fields
/// (expires, group order, content-exists and parameters) are not kept.
#[derive(Clone, Debug)]
pub struct SubscribeOk {
	/// Identifier of the request being answered.
	pub request_id: RequestId,
	/// Track alias carried in the message.
	pub track_alias: u64,
}
116
117impl Message for SubscribeOk {
118	const ID: u64 = 0x04;
119
120	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
121		self.request_id.encode(w, version);
122		self.track_alias.encode(w, version);
123		0u64.encode(w, version); // expires = 0
124		GroupOrder::Descending.encode(w, version);
125		false.encode(w, version); // no content
126		0u8.encode(w, version); // no parameters
127	}
128
129	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
130		let request_id = RequestId::decode(r, version)?;
131		let track_alias = u64::decode(r, version)?;
132
133		let expires = u64::decode(r, version)?;
134		if expires != 0 {
135			return Err(DecodeError::Unsupported);
136		}
137
138		// Ignore group order, who cares.
139		let _group_order = u8::decode(r, version)?;
140
141		// TODO: We don't support largest group/object yet
142		if bool::decode(r, version)? {
143			let _group = u64::decode(r, version)?;
144			let _object = u64::decode(r, version)?;
145		}
146
147		// Ignore parameters, who cares.
148		let _params = Parameters::decode(r, version)?;
149
150		Ok(Self {
151			request_id,
152			track_alias,
153		})
154	}
155}
156
/// SubscribeError message (0x05)
#[derive(Clone, Debug)]
pub struct SubscribeError<'a> {
	/// Identifier of the request that failed.
	pub request_id: RequestId,
	/// Numeric error code carried in the message.
	pub error_code: u64,
	/// Text carried alongside the error code.
	pub reason_phrase: Cow<'a, str>,
}
164
165impl<'a> Message for SubscribeError<'a> {
166	const ID: u64 = 0x05;
167
168	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
169		self.request_id.encode(w, version);
170		self.error_code.encode(w, version);
171		self.reason_phrase.encode(w, version);
172	}
173	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
174		let request_id = RequestId::decode(r, version)?;
175		let error_code = u64::decode(r, version)?;
176		let reason_phrase = Cow::<str>::decode(r, version)?;
177
178		Ok(Self {
179			request_id,
180			error_code,
181			reason_phrase,
182		})
183	}
184}
185
/// Unsubscribe message (0x0a)
#[derive(Clone, Debug)]
pub struct Unsubscribe {
	/// Request id carried in the message; it is the only field on the wire.
	pub request_id: RequestId,
}
191
192impl Message for Unsubscribe {
193	const ID: u64 = 0x0a;
194
195	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
196		self.request_id.encode(w, version);
197	}
198
199	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
200		let request_id = RequestId::decode(r, version)?;
201		Ok(Self { request_id })
202	}
203}
204
/// SubscribeUpdate message (0x02)
///
/// Wire layout:
///
/// ```text
/// Type (i) = 0x2,
/// Length (16),
/// Request ID (i),
/// Subscription Request ID (i),
/// Start Location (Location),
/// End Group (i),
/// Subscriber Priority (8),
/// Forward (8),
/// Number of Parameters (i),
/// Parameters (..) ...
/// ```
///
/// NOTE(review): the Type and Length fields are not handled by this type's
/// `encode_msg`/`decode_msg`; presumably the surrounding message framing
/// takes care of them — confirm against the `Message` trait.
#[derive(Debug)]
pub struct SubscribeUpdate {
	/// Identifier of this update request.
	pub request_id: RequestId,
	/// Second request id on the wire ("Subscription Request ID" in the layout).
	pub subscription_request_id: RequestId,
	/// Start location carried in the message.
	pub start_location: Location,
	/// End group carried in the message.
	pub end_group: u64,
	/// Priority byte supplied by the subscriber.
	pub subscriber_priority: u8,
	/// Forward flag carried in the message.
	pub forward: bool,
	// Parameters are intentionally not stored: decoding reads and drops them,
	// and encoding always writes a parameter count of zero.
}
227
228impl Message for SubscribeUpdate {
229	const ID: u64 = 0x02;
230
231	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) {
232		self.request_id.encode(w, version);
233		self.subscription_request_id.encode(w, version);
234		self.start_location.encode(w, version);
235		self.end_group.encode(w, version);
236		self.subscriber_priority.encode(w, version);
237		self.forward.encode(w, version);
238		0u8.encode(w, version); // no parameters
239	}
240
241	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
242		let request_id = RequestId::decode(r, version)?;
243		let subscription_request_id = RequestId::decode(r, version)?;
244		let start_location = Location::decode(r, version)?;
245		let end_group = u64::decode(r, version)?;
246		let subscriber_priority = u8::decode(r, version)?;
247		let forward = bool::decode(r, version)?;
248		let _parameters = Parameters::decode(r, version)?;
249
250		Ok(Self {
251			request_id,
252			subscription_request_id,
253			start_location,
254			end_group,
255			subscriber_priority,
256			forward,
257		})
258	}
259}
260
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::BytesMut;

	/// Encodes a message body using the Draft14 wire version.
	fn encode_message<M: Message>(msg: &M) -> Vec<u8> {
		let mut buf = BytesMut::new();
		msg.encode_msg(&mut buf, Version::Draft14);
		buf.to_vec()
	}

	/// Decodes a message body using the Draft14 wire version.
	fn decode_message<M: Message>(bytes: &[u8]) -> Result<M, DecodeError> {
		let mut buf = bytes::Bytes::from(bytes.to_vec());
		M::decode_msg(&mut buf, Version::Draft14)
	}

	#[test]
	fn test_subscribe_round_trip() {
		let msg = Subscribe {
			request_id: RequestId(1),
			track_namespace: Path::new("test"),
			track_name: "video".into(),
			subscriber_priority: 128,
			group_order: GroupOrder::Descending,
			filter_type: FilterType::LargestObject,
		};

		let encoded = encode_message(&msg);
		let decoded: Subscribe = decode_message(&encoded).unwrap();

		assert_eq!(decoded.request_id, RequestId(1));
		assert_eq!(decoded.track_namespace.as_str(), "test");
		assert_eq!(decoded.track_name, "video");
		assert_eq!(decoded.subscriber_priority, 128);
		assert!(matches!(decoded.group_order, GroupOrder::Descending));
		assert!(matches!(decoded.filter_type, FilterType::LargestObject));
	}

	#[test]
	fn test_subscribe_nested_namespace() {
		let msg = Subscribe {
			request_id: RequestId(100),
			track_namespace: Path::new("conference/room123"),
			track_name: "audio".into(),
			subscriber_priority: 255,
			group_order: GroupOrder::Descending,
			filter_type: FilterType::LargestObject,
		};

		let encoded = encode_message(&msg);
		let decoded: Subscribe = decode_message(&encoded).unwrap();

		assert_eq!(decoded.track_namespace.as_str(), "conference/room123");
	}

	#[test]
	fn test_subscribe_ok() {
		// Distinct values so that swapped fields would be detected.
		let msg = SubscribeOk {
			request_id: RequestId(42),
			track_alias: 7,
		};

		let encoded = encode_message(&msg);
		let decoded: SubscribeOk = decode_message(&encoded).unwrap();

		assert_eq!(decoded.request_id, RequestId(42));
		assert_eq!(decoded.track_alias, 7);
	}

	#[test]
	fn test_subscribe_error() {
		let msg = SubscribeError {
			request_id: RequestId(123),
			error_code: 500,
			reason_phrase: "Not found".into(),
		};

		let encoded = encode_message(&msg);
		let decoded: SubscribeError = decode_message(&encoded).unwrap();

		assert_eq!(decoded.request_id, RequestId(123));
		assert_eq!(decoded.error_code, 500);
		assert_eq!(decoded.reason_phrase, "Not found");
	}

	#[test]
	fn test_unsubscribe() {
		let msg = Unsubscribe {
			request_id: RequestId(999),
		};

		let encoded = encode_message(&msg);
		let decoded: Unsubscribe = decode_message(&encoded).unwrap();

		assert_eq!(decoded.request_id, RequestId(999));
	}

	#[test]
	fn test_subscribe_rejects_invalid_filter_type() {
		// Start from a valid encoding so the fixture always follows the
		// current wire layout, then corrupt only the filter type.
		let msg = Subscribe {
			request_id: RequestId(1),
			track_namespace: Path::new("test"),
			track_name: "video".into(),
			subscriber_priority: 128,
			group_order: GroupOrder::Descending,
			filter_type: FilterType::LargestObject,
		};
		let mut bytes = encode_message(&msg);

		// The body ends with the filter type followed by the parameter
		// count, one byte each for these values.
		let filter_idx = bytes.len() - 2;
		assert_eq!(bytes[filter_idx], 0x02, "expected the LargestObject filter byte");
		bytes[filter_idx] = 0x05; // not a defined FilterType

		let result: Result<Subscribe, _> = decode_message(&bytes);
		assert!(matches!(result, Err(DecodeError::InvalidValue)));
	}

	#[test]
	fn test_subscribe_ok_rejects_non_zero_expires() {
		// Start from a valid encoding so the fixture always follows the
		// current wire layout, then corrupt only the expires field.
		let msg = SubscribeOk {
			request_id: RequestId(1),
			track_alias: 2,
		};
		let mut bytes = encode_message(&msg);

		// Layout: request_id, track_alias, expires, ... with one byte each
		// for these small values, so expires sits at index 2.
		let expires_idx = 2;
		assert_eq!(bytes[expires_idx], 0x00, "expected the zero expires byte");
		bytes[expires_idx] = 0x05; // INVALID: expires = 5

		let result: Result<SubscribeOk, _> = decode_message(&bytes);
		assert!(matches!(result, Err(DecodeError::Unsupported)));
	}
}