1use crate::MqttString;
2
3use super::{
4 len_len, length, property, qos, read_mqtt_string, read_u16, read_u8, vec, write_mqtt_string,
5 write_remaining_length, BufMut, BytesMut, Debug, Error, FixedHeader, PropertyType, QoS,
6};
7use bytes::{Buf, Bytes};
8
9#[derive(Clone, Debug, PartialEq, Eq, Default)]
11pub struct Subscribe {
12 pub pkid: u16,
13 pub filters: Vec<Filter>,
14 pub properties: Option<SubscribeProperties>,
15}
16
17impl Subscribe {
18 #[must_use]
19 pub fn new(filter: Filter, properties: Option<SubscribeProperties>) -> Self {
20 Self {
21 filters: vec![filter],
22 properties,
23 ..Default::default()
24 }
25 }
26
27 pub fn new_many<F>(filters: F, properties: Option<SubscribeProperties>) -> Self
28 where
29 F: IntoIterator<Item = Filter>,
30 {
31 Self {
32 filters: filters.into_iter().collect(),
33 properties,
34 ..Default::default()
35 }
36 }
37
38 #[must_use]
39 pub fn size(&self) -> usize {
40 let len = self.len();
41 let remaining_len_size = len_len(len);
42
43 1 + remaining_len_size + len
44 }
45
46 fn len(&self) -> usize {
47 let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len());
48
49 if let Some(p) = &self.properties {
50 let properties_len = p.len();
51 let properties_len_len = len_len(properties_len);
52 len += properties_len_len + properties_len;
53 } else {
54 len += 1;
56 }
57
58 len
59 }
60
61 pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Subscribe, Error> {
62 let variable_header_index = fixed_header.fixed_header_len;
63 bytes.advance(variable_header_index);
64
65 let pkid = read_u16(&mut bytes)?;
66 let properties = SubscribeProperties::read(&mut bytes)?;
67
68 let filters = Filter::read(&mut bytes)?;
70
71 match filters.len() {
72 0 => Err(Error::EmptySubscription),
73 _ => Ok(Subscribe {
74 pkid,
75 filters,
76 properties,
77 }),
78 }
79 }
80
81 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
82 buffer.put_u8(0x82);
84
85 let remaining_len = self.len();
87 let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?;
88
89 buffer.put_u16(self.pkid);
91
92 if let Some(p) = &self.properties {
93 p.write(buffer)?;
94 } else {
95 write_remaining_length(buffer, 0)?;
96 }
97
98 for f in &self.filters {
100 f.write(buffer)?;
101 }
102
103 Ok(1 + remaining_len_bytes + remaining_len)
104 }
105}
106
107#[derive(Clone, Debug, PartialEq, Eq, Default)]
109pub struct Filter {
110 pub path: MqttString,
111 pub qos: QoS,
112 pub nolocal: bool,
113 pub preserve_retain: bool,
114 pub retain_forward_rule: RetainForwardRule,
115}
116
117impl Filter {
118 pub fn new<T: Into<MqttString>>(topic: T, qos: QoS) -> Self {
119 Self {
120 path: topic.into(),
121 qos,
122 ..Default::default()
123 }
124 }
125
126 fn len(&self) -> usize {
127 2 + self.path.len() + 1
129 }
130
131 pub fn read(bytes: &mut Bytes) -> Result<Vec<Filter>, Error> {
132 let mut filters = Vec::new();
134
135 while bytes.has_remaining() {
136 let path = read_mqtt_string(bytes)?;
137 let options = read_u8(bytes)?;
138 let requested_qos = options & 0b0000_0011;
139
140 let nolocal = options >> 2 & 0b0000_0001;
141 let nolocal = nolocal != 0;
142
143 let preserve_retain = options >> 3 & 0b0000_0001;
144 let preserve_retain = preserve_retain != 0;
145
146 let retain_forward_rule = (options >> 4) & 0b0000_0011;
147 let retain_forward_rule = match retain_forward_rule {
148 0 => RetainForwardRule::OnEverySubscribe,
149 1 => RetainForwardRule::OnNewSubscribe,
150 2 => RetainForwardRule::Never,
151 r => return Err(Error::InvalidRetainForwardRule(r)),
152 };
153
154 filters.push(Filter {
155 path,
156 qos: qos(requested_qos).ok_or(Error::InvalidQoS(requested_qos))?,
157 nolocal,
158 preserve_retain,
159 retain_forward_rule,
160 });
161 }
162
163 Ok(filters)
164 }
165
166 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
167 let mut options = 0;
168 options |= self.qos as u8;
169
170 if self.nolocal {
171 options |= 0b0000_0100;
172 }
173
174 if self.preserve_retain {
175 options |= 0b0000_1000;
176 }
177
178 options |= match self.retain_forward_rule {
179 RetainForwardRule::OnEverySubscribe => 0b0000_0000,
180 RetainForwardRule::OnNewSubscribe => 0b0001_0000,
181 RetainForwardRule::Never => 0b0010_0000,
182 };
183
184 write_mqtt_string(buffer, &self.path)?;
185 buffer.put_u8(options);
186 Ok(())
187 }
188}
189
/// Retain handling subscription option, carried in bits 4-5 of a filter's
/// options byte (MQTT 5 "Retain Handling").
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum RetainForwardRule {
    /// Wire value `0`: send retained messages at the time of every subscribe.
    #[default]
    OnEverySubscribe,
    /// Wire value `1`: send retained messages only if the subscription did
    /// not already exist.
    OnNewSubscribe,
    /// Wire value `2`: do not send retained messages at subscribe time.
    Never,
}
202
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct SubscribeProperties {
205 pub id: Option<usize>,
206 pub user_properties: Vec<(MqttString, MqttString)>,
207}
208
209impl SubscribeProperties {
210 fn len(&self) -> usize {
211 let mut len = 0;
212
213 if let Some(id) = &self.id {
214 len += 1 + len_len(*id);
215 }
216
217 for (key, value) in &self.user_properties {
218 len += 1 + 2 + key.len() + 2 + value.len();
219 }
220
221 len
222 }
223
224 pub fn read(bytes: &mut Bytes) -> Result<Option<SubscribeProperties>, Error> {
225 let mut id = None;
226 let mut user_properties = Vec::new();
227
228 let (properties_len_len, properties_len) = length(bytes.iter())?;
229 bytes.advance(properties_len_len);
230
231 if properties_len == 0 {
232 return Ok(None);
233 }
234
235 let mut cursor = 0;
236 while cursor < properties_len {
238 let prop = read_u8(bytes)?;
239 cursor += 1;
240
241 match property(prop)? {
242 PropertyType::SubscriptionIdentifier => {
243 let (id_len, sub_id) = length(bytes.iter())?;
244 cursor += id_len;
245 bytes.advance(id_len);
246 id = Some(sub_id);
247 }
248 PropertyType::UserProperty => {
249 let key = read_mqtt_string(bytes)?;
250 let value = read_mqtt_string(bytes)?;
251 cursor += 2 + key.len() + 2 + value.len();
252 user_properties.push((key, value));
253 }
254 _ => return Err(Error::InvalidPropertyType(prop)),
255 }
256 }
257
258 if cursor > properties_len {
259 return Err(Error::MalformedPacket);
260 }
261
262 Ok(Some(SubscribeProperties {
263 id,
264 user_properties,
265 }))
266 }
267
268 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
269 let len = self.len();
270 write_remaining_length(buffer, len)?;
271
272 if let Some(id) = &self.id {
273 buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
274 write_remaining_length(buffer, *id)?;
275 }
276
277 for (key, value) in &self.user_properties {
278 buffer.put_u8(PropertyType::UserProperty as u8);
279 write_mqtt_string(buffer, key)?;
280 write_mqtt_string(buffer, value)?;
281 }
282
283 Ok(())
284 }
285}
286
#[cfg(test)]
mod test {
    use crate::test::read_write_packets;
    use crate::Packet;

    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    #[test]
    fn length_calculation() {
        let mut dummy_bytes = BytesMut::new();
        let subscribe_props = SubscribeProperties {
            id: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
        };

        let subscribe_pkt = Subscribe::new(
            Filter::new("hello/world", QoS::AtMostOnce),
            Some(subscribe_props),
        );

        let size_from_size = subscribe_pkt.size();
        let size_from_write = subscribe_pkt.write(&mut dummy_bytes).unwrap();
        let size_from_bytes = dummy_bytes.len();

        assert_eq!(size_from_write, size_from_bytes);
        assert_eq!(size_from_size, size_from_bytes);
    }

    #[test]
    fn invalid_subscription_options_are_rejected() {
        // Encodes a single filter, then ORs `extra_bits` into its trailing options byte.
        fn encoded_with(extra_bits: u8) -> Bytes {
            let mut buf = BytesMut::new();
            Filter::new("a", QoS::AtMostOnce).write(&mut buf).unwrap();
            let last = buf.len() - 1;
            buf[last] |= extra_bits;
            buf.freeze()
        }

        // Retain handling (bits 4-5) value 3 is undefined.
        assert!(matches!(
            Filter::read(&mut encoded_with(0b0011_0000)),
            Err(Error::InvalidRetainForwardRule(3))
        ));
        // QoS (bits 0-1) value 3 is undefined.
        assert!(matches!(
            Filter::read(&mut encoded_with(0b0000_0011)),
            Err(Error::InvalidQoS(3))
        ));
    }

    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    fn write_read_provider() -> Vec<Packet> {
        vec![
            Packet::Subscribe(Subscribe {
                pkid: 0,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::AtLeastOnce,
                    nolocal: false,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: None,
            }),
            Packet::Subscribe(Subscribe {
                pkid: 0,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::ExactlyOnce,
                    nolocal: false,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: None,
            }),
            Packet::Subscribe(Subscribe {
                pkid: 42,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::AtMostOnce,
                    nolocal: false,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: None,
            }),
            Packet::Subscribe(Subscribe {
                pkid: 42,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::AtMostOnce,
                    nolocal: false,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: Some(SubscribeProperties {
                    id: None,
                    user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
                }),
            }),
            Packet::Subscribe(Subscribe {
                pkid: 42,
                filters: vec![
                    Filter {
                        path: "hello/world".into(),
                        qos: QoS::AtMostOnce,
                        nolocal: true,
                        preserve_retain: false,
                        retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                    },
                    Filter {
                        path: "hello/world".into(),
                        qos: QoS::AtMostOnce,
                        nolocal: false,
                        preserve_retain: true,
                        retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                    },
                ],
                properties: Some(SubscribeProperties {
                    id: Some(1),
                    user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
                }),
            }),
            Packet::Subscribe(Subscribe {
                pkid: 42,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::AtMostOnce,
                    nolocal: true,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: Some(SubscribeProperties {
                    id: Some(100_000_000),
                    user_properties: vec![("f".into(), String::new())],
                }),
            }),
            Packet::Subscribe(Subscribe {
                pkid: 42,
                filters: vec![Filter {
                    path: "hello/world".into(),
                    qos: QoS::AtMostOnce,
                    nolocal: true,
                    preserve_retain: false,
                    retain_forward_rule: RetainForwardRule::OnEverySubscribe,
                }],
                properties: Some(SubscribeProperties {
                    id: Some(100_000_000),
                    user_properties: vec![],
                }),
            }),
            // Exercises the remaining retain-handling values and every option bit together.
            Packet::Subscribe(Subscribe {
                pkid: 7,
                filters: vec![
                    Filter {
                        path: "a/+/c".into(),
                        qos: QoS::ExactlyOnce,
                        nolocal: true,
                        preserve_retain: true,
                        retain_forward_rule: RetainForwardRule::OnNewSubscribe,
                    },
                    Filter {
                        path: "a/#".into(),
                        qos: QoS::AtLeastOnce,
                        nolocal: false,
                        preserve_retain: true,
                        retain_forward_rule: RetainForwardRule::Never,
                    },
                ],
                properties: None,
            }),
        ]
    }
}