mqtt_bytes_v5/
subscribe.rs

1use crate::MqttString;
2
3use super::{
4    len_len, length, property, qos, read_mqtt_string, read_u16, read_u8, vec, write_mqtt_string,
5    write_remaining_length, BufMut, BytesMut, Debug, Error, FixedHeader, PropertyType, QoS,
6};
7use bytes::{Buf, Bytes};
8
9/// Subscription packet
10#[derive(Clone, Debug, PartialEq, Eq, Default)]
11pub struct Subscribe {
12    pub pkid: u16,
13    pub filters: Vec<Filter>,
14    pub properties: Option<SubscribeProperties>,
15}
16
17impl Subscribe {
18    #[must_use]
19    pub fn new(filter: Filter, properties: Option<SubscribeProperties>) -> Self {
20        Self {
21            filters: vec![filter],
22            properties,
23            ..Default::default()
24        }
25    }
26
27    pub fn new_many<F>(filters: F, properties: Option<SubscribeProperties>) -> Self
28    where
29        F: IntoIterator<Item = Filter>,
30    {
31        Self {
32            filters: filters.into_iter().collect(),
33            properties,
34            ..Default::default()
35        }
36    }
37
38    #[must_use]
39    pub fn size(&self) -> usize {
40        let len = self.len();
41        let remaining_len_size = len_len(len);
42
43        1 + remaining_len_size + len
44    }
45
46    fn len(&self) -> usize {
47        let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len());
48
49        if let Some(p) = &self.properties {
50            let properties_len = p.len();
51            let properties_len_len = len_len(properties_len);
52            len += properties_len_len + properties_len;
53        } else {
54            // just 1 byte representing 0 len
55            len += 1;
56        }
57
58        len
59    }
60
61    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Subscribe, Error> {
62        let variable_header_index = fixed_header.fixed_header_len;
63        bytes.advance(variable_header_index);
64
65        let pkid = read_u16(&mut bytes)?;
66        let properties = SubscribeProperties::read(&mut bytes)?;
67
68        // variable header size = 2 (packet identifier)
69        let filters = Filter::read(&mut bytes)?;
70
71        match filters.len() {
72            0 => Err(Error::EmptySubscription),
73            _ => Ok(Subscribe {
74                pkid,
75                filters,
76                properties,
77            }),
78        }
79    }
80
81    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
82        // write packet type
83        buffer.put_u8(0x82);
84
85        // write remaining length
86        let remaining_len = self.len();
87        let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?;
88
89        // write packet id
90        buffer.put_u16(self.pkid);
91
92        if let Some(p) = &self.properties {
93            p.write(buffer)?;
94        } else {
95            write_remaining_length(buffer, 0)?;
96        }
97
98        // write filters
99        for f in &self.filters {
100            f.write(buffer)?;
101        }
102
103        Ok(1 + remaining_len_bytes + remaining_len)
104    }
105}
106
107///  Subscription filter
108#[derive(Clone, Debug, PartialEq, Eq, Default)]
109pub struct Filter {
110    pub path: MqttString,
111    pub qos: QoS,
112    pub nolocal: bool,
113    pub preserve_retain: bool,
114    pub retain_forward_rule: RetainForwardRule,
115}
116
117impl Filter {
118    pub fn new<T: Into<MqttString>>(topic: T, qos: QoS) -> Self {
119        Self {
120            path: topic.into(),
121            qos,
122            ..Default::default()
123        }
124    }
125
126    fn len(&self) -> usize {
127        // filter len + filter + options
128        2 + self.path.len() + 1
129    }
130
131    pub fn read(bytes: &mut Bytes) -> Result<Vec<Filter>, Error> {
132        // variable header size = 2 (packet identifier)
133        let mut filters = Vec::new();
134
135        while bytes.has_remaining() {
136            let path = read_mqtt_string(bytes)?;
137            let options = read_u8(bytes)?;
138            let requested_qos = options & 0b0000_0011;
139
140            let nolocal = options >> 2 & 0b0000_0001;
141            let nolocal = nolocal != 0;
142
143            let preserve_retain = options >> 3 & 0b0000_0001;
144            let preserve_retain = preserve_retain != 0;
145
146            let retain_forward_rule = (options >> 4) & 0b0000_0011;
147            let retain_forward_rule = match retain_forward_rule {
148                0 => RetainForwardRule::OnEverySubscribe,
149                1 => RetainForwardRule::OnNewSubscribe,
150                2 => RetainForwardRule::Never,
151                r => return Err(Error::InvalidRetainForwardRule(r)),
152            };
153
154            filters.push(Filter {
155                path,
156                qos: qos(requested_qos).ok_or(Error::InvalidQoS(requested_qos))?,
157                nolocal,
158                preserve_retain,
159                retain_forward_rule,
160            });
161        }
162
163        Ok(filters)
164    }
165
166    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
167        let mut options = 0;
168        options |= self.qos as u8;
169
170        if self.nolocal {
171            options |= 0b0000_0100;
172        }
173
174        if self.preserve_retain {
175            options |= 0b0000_1000;
176        }
177
178        options |= match self.retain_forward_rule {
179            RetainForwardRule::OnEverySubscribe => 0b0000_0000,
180            RetainForwardRule::OnNewSubscribe => 0b0001_0000,
181            RetainForwardRule::Never => 0b0010_0000,
182        };
183
184        write_mqtt_string(buffer, &self.path)?;
185        buffer.put_u8(options);
186        Ok(())
187    }
188}
189
/// Retain Handling subscription option, stored in bits 4-5 of a filter's options byte.
///
/// The wire value `3` has no variant and is rejected when decoding.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum RetainForwardRule {
    /// Wire value `0`: forward retained messages on every subscribe (the default).
    #[default]
    OnEverySubscribe,
    /// Wire value `1`: forward retained messages only for a new subscription.
    OnNewSubscribe,
    /// Wire value `2`: never forward retained messages for this subscription.
    Never,
}
202
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct SubscribeProperties {
205    pub id: Option<usize>,
206    pub user_properties: Vec<(MqttString, MqttString)>,
207}
208
209impl SubscribeProperties {
210    fn len(&self) -> usize {
211        let mut len = 0;
212
213        if let Some(id) = &self.id {
214            len += 1 + len_len(*id);
215        }
216
217        for (key, value) in &self.user_properties {
218            len += 1 + 2 + key.len() + 2 + value.len();
219        }
220
221        len
222    }
223
224    pub fn read(bytes: &mut Bytes) -> Result<Option<SubscribeProperties>, Error> {
225        let mut id = None;
226        let mut user_properties = Vec::new();
227
228        let (properties_len_len, properties_len) = length(bytes.iter())?;
229        bytes.advance(properties_len_len);
230
231        if properties_len == 0 {
232            return Ok(None);
233        }
234
235        let mut cursor = 0;
236        // read until cursor reaches property length. properties_len = 0 will skip this loop
237        while cursor < properties_len {
238            let prop = read_u8(bytes)?;
239            cursor += 1;
240
241            match property(prop)? {
242                PropertyType::SubscriptionIdentifier => {
243                    let (id_len, sub_id) = length(bytes.iter())?;
244                    cursor += id_len;
245                    bytes.advance(id_len);
246                    id = Some(sub_id);
247                }
248                PropertyType::UserProperty => {
249                    let key = read_mqtt_string(bytes)?;
250                    let value = read_mqtt_string(bytes)?;
251                    cursor += 2 + key.len() + 2 + value.len();
252                    user_properties.push((key, value));
253                }
254                _ => return Err(Error::InvalidPropertyType(prop)),
255            }
256        }
257
258        if cursor > properties_len {
259            return Err(Error::MalformedPacket);
260        }
261
262        Ok(Some(SubscribeProperties {
263            id,
264            user_properties,
265        }))
266    }
267
268    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
269        let len = self.len();
270        write_remaining_length(buffer, len)?;
271
272        if let Some(id) = &self.id {
273            buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
274            write_remaining_length(buffer, *id)?;
275        }
276
277        for (key, value) in &self.user_properties {
278            buffer.put_u8(PropertyType::UserProperty as u8);
279            write_mqtt_string(buffer, key)?;
280            write_mqtt_string(buffer, value)?;
281        }
282
283        Ok(())
284    }
285}
286
#[cfg(test)]
mod test {
    use crate::test::read_write_packets;
    use crate::Packet;

    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// A `hello/world` filter with the given options and the default retain rule.
    fn hello_world(qos: QoS, nolocal: bool, preserve_retain: bool) -> Filter {
        Filter {
            path: "hello/world".into(),
            qos,
            nolocal,
            preserve_retain,
            retain_forward_rule: RetainForwardRule::OnEverySubscribe,
        }
    }

    /// Properties holding the shared test user property and the given identifier.
    fn user_props(id: Option<usize>) -> SubscribeProperties {
        SubscribeProperties {
            id,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
        }
    }

    /// Wraps the parts of a SUBSCRIBE packet into a [`Packet`].
    fn subscribe_packet(
        pkid: u16,
        filters: Vec<Filter>,
        properties: Option<SubscribeProperties>,
    ) -> Packet {
        Packet::Subscribe(Subscribe {
            pkid,
            filters,
            properties,
        })
    }

    #[test]
    fn length_calculation() {
        // The user property is intended to pad the packet past ~128 bytes so the
        // remaining-length field needs two bytes.
        let packet = Subscribe::new(
            Filter::new("hello/world", QoS::AtMostOnce),
            Some(user_props(None)),
        );

        let predicted = packet.size();
        let mut encoded = BytesMut::new();
        let reported = packet.write(&mut encoded).unwrap();

        assert_eq!(reported, encoded.len());
        assert_eq!(predicted, encoded.len());
    }

    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    fn write_read_provider() -> Vec<Packet> {
        vec![
            subscribe_packet(0, vec![hello_world(QoS::AtLeastOnce, false, false)], None),
            subscribe_packet(0, vec![hello_world(QoS::ExactlyOnce, false, false)], None),
            subscribe_packet(42, vec![hello_world(QoS::AtMostOnce, false, false)], None),
            subscribe_packet(
                42,
                vec![hello_world(QoS::AtMostOnce, false, false)],
                Some(user_props(None)),
            ),
            subscribe_packet(
                42,
                vec![
                    hello_world(QoS::AtMostOnce, true, false),
                    hello_world(QoS::AtMostOnce, false, true),
                ],
                Some(user_props(Some(1))),
            ),
            subscribe_packet(
                42,
                vec![hello_world(QoS::AtMostOnce, true, false)],
                Some(SubscribeProperties {
                    id: Some(100_000_000),
                    user_properties: vec![("f".into(), String::new())],
                }),
            ),
            subscribe_packet(
                42,
                vec![hello_world(QoS::AtMostOnce, true, false)],
                Some(SubscribeProperties {
                    id: Some(100_000_000),
                    user_properties: vec![],
                }),
            ),
        ]
    }
}