Skip to main content

rumqttc/mqttbytes/v5/
subscribe.rs

1use super::{
2    BufMut, BytesMut, Error, FixedHeader, PropertyType, QoS, len_len, length, property, qos,
3    read_mqtt_string, read_u8, read_u16, write_mqtt_string, write_remaining_length,
4};
5use bytes::{Buf, Bytes};
6
7/// Subscription packet
8#[derive(Clone, Debug, PartialEq, Eq, Default)]
9pub struct Subscribe {
10    pub pkid: u16,
11    pub filters: Vec<Filter>,
12    pub properties: Option<SubscribeProperties>,
13}
14
15impl Subscribe {
16    #[must_use]
17    pub fn new(filter: Filter, properties: Option<SubscribeProperties>) -> Self {
18        Self {
19            filters: vec![filter],
20            properties,
21            ..Default::default()
22        }
23    }
24
25    pub fn new_many<F>(filters: F, properties: Option<SubscribeProperties>) -> Self
26    where
27        F: IntoIterator<Item = Filter>,
28    {
29        Self {
30            filters: filters.into_iter().collect(),
31            properties,
32            ..Default::default()
33        }
34    }
35
36    #[must_use]
37    pub fn size(&self) -> usize {
38        let len = self.len();
39        let remaining_len_size = len_len(len);
40
41        1 + remaining_len_size + len
42    }
43
44    fn len(&self) -> usize {
45        let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len());
46
47        if let Some(p) = &self.properties {
48            let properties_len = p.len();
49            let properties_len_len = len_len(properties_len);
50            len += properties_len_len + properties_len;
51        } else {
52            // just 1 byte representing 0 len
53            len += 1;
54        }
55
56        len
57    }
58
59    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
60        let variable_header_index = fixed_header.header_len;
61        bytes.advance(variable_header_index);
62
63        let pkid = read_u16(&mut bytes)?;
64        let properties = SubscribeProperties::read(&mut bytes)?;
65
66        // variable header size = 2 (packet identifier)
67        let filters = Filter::read(&mut bytes)?;
68
69        match filters.len() {
70            0 => Err(Error::EmptySubscription),
71            _ => Ok(Self {
72                pkid,
73                filters,
74                properties,
75            }),
76        }
77    }
78
79    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
80        // write packet type
81        buffer.put_u8(0x82);
82
83        // write remaining length
84        let remaining_len = self.len();
85        let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?;
86
87        // write packet id
88        buffer.put_u16(self.pkid);
89
90        if let Some(p) = &self.properties {
91            p.write(buffer)?;
92        } else {
93            write_remaining_length(buffer, 0)?;
94        }
95
96        // write filters
97        for f in &self.filters {
98            f.write(buffer);
99        }
100
101        Ok(1 + remaining_len_bytes + remaining_len)
102    }
103}
104
105///  Subscription filter
106#[derive(Clone, Debug, PartialEq, Eq, Default)]
107pub struct Filter {
108    pub path: String,
109    pub qos: QoS,
110    pub nolocal: bool,
111    pub preserve_retain: bool,
112    pub retain_forward_rule: RetainForwardRule,
113}
114
115impl Filter {
116    pub fn new<T: Into<String>>(topic: T, qos: QoS) -> Self {
117        Self {
118            path: topic.into(),
119            qos,
120            ..Default::default()
121        }
122    }
123
124    const fn len(&self) -> usize {
125        // filter len + filter + options
126        2 + self.path.len() + 1
127    }
128
129    pub fn read(bytes: &mut Bytes) -> Result<Vec<Self>, Error> {
130        // variable header size = 2 (packet identifier)
131        let mut filters = Vec::new();
132
133        while bytes.has_remaining() {
134            let path = read_mqtt_string(bytes)?;
135            let options = read_u8(bytes)?;
136            let requested_qos = options & 0b0000_0011;
137
138            let nolocal = (options >> 2) & 0b0000_0001;
139            let nolocal = nolocal != 0;
140
141            let preserve_retain = (options >> 3) & 0b0000_0001;
142            let preserve_retain = preserve_retain != 0;
143
144            let retain_forward_rule = (options >> 4) & 0b0000_0011;
145            let retain_forward_rule = match retain_forward_rule {
146                0 => RetainForwardRule::OnEverySubscribe,
147                1 => RetainForwardRule::OnNewSubscribe,
148                2 => RetainForwardRule::Never,
149                r => return Err(Error::InvalidRetainForwardRule(r)),
150            };
151
152            filters.push(Self {
153                path,
154                qos: qos(requested_qos).ok_or(Error::InvalidQoS(requested_qos))?,
155                nolocal,
156                preserve_retain,
157                retain_forward_rule,
158            });
159        }
160
161        Ok(filters)
162    }
163
164    pub fn write(&self, buffer: &mut BytesMut) {
165        let mut options = 0;
166        options |= self.qos as u8;
167
168        if self.nolocal {
169            options |= 0b0000_0100;
170        }
171
172        if self.preserve_retain {
173            options |= 0b0000_1000;
174        }
175
176        options |= match self.retain_forward_rule {
177            RetainForwardRule::OnEverySubscribe => 0b0000_0000,
178            RetainForwardRule::OnNewSubscribe => 0b0001_0000,
179            RetainForwardRule::Never => 0b0010_0000,
180        };
181
182        write_mqtt_string(buffer, self.path.as_str());
183        buffer.put_u8(options);
184    }
185}
186
187#[derive(Debug, Clone, PartialEq, Eq, Default)]
188pub enum RetainForwardRule {
189    #[default]
190    OnEverySubscribe,
191    OnNewSubscribe,
192    Never,
193}
194
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct SubscribeProperties {
197    pub id: Option<usize>,
198    pub user_properties: Vec<(String, String)>,
199}
200
201impl SubscribeProperties {
202    fn len(&self) -> usize {
203        let mut len = 0;
204
205        if let Some(id) = &self.id {
206            len += 1 + len_len(*id);
207        }
208
209        for (key, value) in &self.user_properties {
210            len += 1 + 2 + key.len() + 2 + value.len();
211        }
212
213        len
214    }
215
216    pub fn read(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
217        let mut id = None;
218        let mut user_properties = Vec::new();
219
220        let (properties_len_len, properties_len) = length(bytes.iter())?;
221        bytes.advance(properties_len_len);
222
223        if properties_len == 0 {
224            return Ok(None);
225        }
226
227        let mut cursor = 0;
228        // read until cursor reaches property length. properties_len = 0 will skip this loop
229        while cursor < properties_len {
230            let prop = read_u8(bytes)?;
231            cursor += 1;
232
233            match property(prop)? {
234                PropertyType::SubscriptionIdentifier => {
235                    let (id_len, sub_id) = length(bytes.iter())?;
236                    if sub_id == 0 {
237                        return Err(Error::ProtocolError);
238                    }
239                    cursor += id_len;
240                    bytes.advance(id_len);
241                    id = Some(sub_id);
242                }
243                PropertyType::UserProperty => {
244                    let key = read_mqtt_string(bytes)?;
245                    let value = read_mqtt_string(bytes)?;
246                    cursor += 2 + key.len() + 2 + value.len();
247                    user_properties.push((key, value));
248                }
249                _ => return Err(Error::InvalidPropertyType(prop)),
250            }
251        }
252
253        Ok(Some(Self {
254            id,
255            user_properties,
256        }))
257    }
258
259    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
260        let len = self.len();
261        write_remaining_length(buffer, len)?;
262
263        if let Some(id) = &self.id {
264            if *id == 0 {
265                return Err(Error::ProtocolError);
266            }
267            buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
268            write_remaining_length(buffer, *id)?;
269        }
270
271        for (key, value) in &self.user_properties {
272            buffer.put_u8(PropertyType::UserProperty as u8);
273            write_mqtt_string(buffer, key);
274            write_mqtt_string(buffer, value);
275        }
276
277        Ok(())
278    }
279}
280
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::{Bytes, BytesMut};
    use pretty_assertions::assert_eq;

    /// `size()`, the value returned by `write()`, and the bytes actually
    /// appended must all agree.
    #[test]
    fn length_calculation() {
        // The user property is there to pad the packet past ~128 bytes, so the
        // remaining-length varint needs two bytes.
        let padding = SubscribeProperties {
            id: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
        };
        let packet = Subscribe::new(Filter::new("hello/world", QoS::AtMostOnce), Some(padding));

        let predicted = packet.size();
        let mut encoded = BytesMut::new();
        let reported = packet.write(&mut encoded).unwrap();
        let actual = encoded.len();

        assert_eq!(reported, actual);
        assert_eq!(predicted, actual);
    }

    /// A Subscription Identifier of zero is rejected on decode.
    #[test]
    fn read_rejects_subscription_identifier_zero() {
        // property length 2 | SubscriptionIdentifier (0x0B) | varint 0
        let mut wire = Bytes::from_static(&[0x02, 0x0B, 0x00]);

        let outcome = SubscribeProperties::read(&mut wire);
        assert!(matches!(outcome, Err(Error::ProtocolError)));
    }

    /// A Subscription Identifier of zero is rejected on encode.
    #[test]
    fn write_rejects_subscription_identifier_zero() {
        let invalid = SubscribeProperties {
            id: Some(0),
            user_properties: Vec::new(),
        };
        let mut sink = BytesMut::new();

        let outcome = invalid.write(&mut sink);
        assert!(matches!(outcome, Err(Error::ProtocolError)));
    }

    /// Both property kinds in a single block are decoded.
    #[test]
    fn read_subscription_identifier_and_user_property_parses_both() {
        let mut wire = Bytes::from_static(&[
            0x09, // properties length
            0x0B, // SubscriptionIdentifier property
            0x01, // varint value = 1
            0x26, // UserProperty property
            0x00, 0x01, b'k', // key
            0x00, 0x01, b'v', // value
        ]);

        let decoded = match SubscribeProperties::read(&mut wire).unwrap() {
            Some(found) => found,
            None => panic!("properties should be present"),
        };

        let expected_pairs = vec![(String::from("k"), String::from("v"))];
        assert_eq!(decoded.id, Some(1));
        assert_eq!(decoded.user_properties, expected_pairs);
    }
}