mqtt_bytes_v5/
publish.rs

1use crate::MqttString;
2
3use super::{
4    len_len, length, property, qos, read_mqtt_bytes, read_mqtt_string, read_u16, read_u32, read_u8,
5    write_mqtt_bytes, write_mqtt_string, write_remaining_length, BufMut, BytesMut, Debug, Error,
6    FixedHeader, PropertyType, QoS,
7};
8use bytes::{Buf, Bytes};
9
10/// Publish packet
11#[derive(Clone, Debug, PartialEq, Eq, Default)]
12pub struct Publish {
13    pub dup: bool,
14    pub qos: QoS,
15    pub retain: bool,
16    pub topic: MqttString,
17    pub pkid: u16,
18    pub payload: Bytes,
19    pub properties: Option<PublishProperties>,
20}
21
22impl Publish {
23    pub fn new<T: Into<MqttString>, P: Into<Bytes>>(
24        topic: T,
25        qos: QoS,
26        payload: P,
27        properties: Option<PublishProperties>,
28    ) -> Self {
29        Self {
30            qos,
31            topic: topic.into(),
32            payload: payload.into(),
33            properties,
34            ..Default::default()
35        }
36    }
37
38    pub fn size(&self) -> usize {
39        let len = self.len();
40        let remaining_len_size = len_len(len);
41
42        1 + remaining_len_size + len
43    }
44
45    fn len(&self) -> usize {
46        let mut len = 2 + self.topic.len();
47        if self.qos != QoS::AtMostOnce && self.pkid != 0 {
48            len += 2;
49        }
50
51        if let Some(p) = &self.properties {
52            let properties_len = p.len();
53            let properties_len_len = len_len(properties_len);
54            len += properties_len_len + properties_len;
55        } else {
56            // just 1 byte representing 0 len
57            len += 1;
58        }
59
60        len += self.payload.len();
61        len
62    }
63
64    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Publish, Error> {
65        let qos_num = (fixed_header.byte1 & 0b0110) >> 1;
66        let qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
67        let dup = (fixed_header.byte1 & 0b1000) != 0;
68        let retain = (fixed_header.byte1 & 0b0001) != 0;
69
70        let variable_header_index = fixed_header.fixed_header_len;
71        bytes.advance(variable_header_index);
72        let topic = read_mqtt_string(&mut bytes)?;
73
74        // Packet identifier exists where QoS > 0
75        let pkid = match qos {
76            QoS::AtMostOnce => 0,
77            QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?,
78        };
79
80        if qos != QoS::AtMostOnce && pkid == 0 {
81            return Err(Error::PacketIdZero);
82        }
83
84        let properties = PublishProperties::read(&mut bytes)?;
85        let publish = Publish {
86            dup,
87            retain,
88            qos,
89            pkid,
90            topic,
91            payload: bytes,
92            properties,
93        };
94
95        Ok(publish)
96    }
97
98    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
99        let len = self.len();
100
101        let dup = u8::from(self.dup);
102        let qos = self.qos as u8;
103        let retain = u8::from(self.retain);
104        buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3);
105
106        let count = write_remaining_length(buffer, len)?;
107        write_mqtt_string(buffer, &self.topic)?;
108
109        if self.qos != QoS::AtMostOnce {
110            let pkid = self.pkid;
111            if pkid == 0 {
112                return Err(Error::PacketIdZero);
113            }
114
115            buffer.put_u16(pkid);
116        }
117
118        if let Some(p) = &self.properties {
119            p.write(buffer)?;
120        } else {
121            write_remaining_length(buffer, 0)?;
122        }
123
124        buffer.extend_from_slice(&self.payload);
125
126        Ok(1 + count + len)
127    }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq, Default)]
131pub struct PublishProperties {
132    pub payload_format_indicator: Option<u8>,
133    pub message_expiry_interval: Option<u32>,
134    pub topic_alias: Option<u16>,
135    pub response_topic: Option<MqttString>,
136    pub correlation_data: Option<Bytes>,
137    pub user_properties: Vec<(MqttString, MqttString)>,
138    pub subscription_identifiers: Vec<usize>,
139    pub content_type: Option<MqttString>,
140}
141
142impl PublishProperties {
143    fn len(&self) -> usize {
144        let mut len = 0;
145
146        if self.payload_format_indicator.is_some() {
147            len += 1 + 1;
148        }
149
150        if self.message_expiry_interval.is_some() {
151            len += 1 + 4;
152        }
153
154        if self.topic_alias.is_some() {
155            len += 1 + 2;
156        }
157
158        if let Some(topic) = &self.response_topic {
159            len += 1 + 2 + topic.len();
160        }
161
162        if let Some(data) = &self.correlation_data {
163            len += 1 + 2 + data.len();
164        }
165
166        for (key, value) in &self.user_properties {
167            len += 1 + 2 + key.len() + 2 + value.len();
168        }
169
170        for id in &self.subscription_identifiers {
171            len += 1 + len_len(*id);
172        }
173
174        if let Some(typ) = &self.content_type {
175            len += 1 + 2 + typ.len();
176        }
177
178        len
179    }
180
181    pub fn read(bytes: &mut Bytes) -> Result<Option<PublishProperties>, Error> {
182        let mut payload_format_indicator = None;
183        let mut message_expiry_interval = None;
184        let mut topic_alias = None;
185        let mut response_topic = None;
186        let mut correlation_data = None;
187        let mut user_properties = Vec::new();
188        let mut subscription_identifiers = Vec::new();
189        let mut content_type = None;
190
191        let (properties_len_len, properties_len) = length(bytes.iter())?;
192        bytes.advance(properties_len_len);
193        if properties_len == 0 {
194            return Ok(None);
195        }
196
197        let mut cursor = 0;
198        // read until cursor reaches property length. properties_len = 0 will skip this loop
199        while cursor < properties_len {
200            let prop = read_u8(bytes)?;
201            cursor += 1;
202
203            match property(prop)? {
204                PropertyType::PayloadFormatIndicator => {
205                    payload_format_indicator = Some(read_u8(bytes)?);
206                    cursor += 1;
207                }
208                PropertyType::MessageExpiryInterval => {
209                    message_expiry_interval = Some(read_u32(bytes)?);
210                    cursor += 4;
211                }
212                PropertyType::TopicAlias => {
213                    topic_alias = Some(read_u16(bytes)?);
214                    cursor += 2;
215                }
216                PropertyType::ResponseTopic => {
217                    let topic = read_mqtt_string(bytes)?;
218                    cursor += 2 + topic.len();
219                    response_topic = Some(topic);
220                }
221                PropertyType::CorrelationData => {
222                    let data = read_mqtt_bytes(bytes)?;
223                    cursor += 2 + data.len();
224                    correlation_data = Some(data);
225                }
226                PropertyType::UserProperty => {
227                    let key = read_mqtt_string(bytes)?;
228                    let value = read_mqtt_string(bytes)?;
229                    cursor += 2 + key.len() + 2 + value.len();
230                    user_properties.push((key, value));
231                }
232                PropertyType::SubscriptionIdentifier => {
233                    let (id_len, id) = length(bytes.iter())?;
234                    cursor += 1 + id_len;
235                    bytes.advance(id_len);
236                    subscription_identifiers.push(id);
237                }
238                PropertyType::ContentType => {
239                    let typ = read_mqtt_string(bytes)?;
240                    cursor += 2 + typ.len();
241                    content_type = Some(typ);
242                }
243                _ => return Err(Error::InvalidPropertyType(prop)),
244            }
245        }
246
247        Ok(Some(PublishProperties {
248            payload_format_indicator,
249            message_expiry_interval,
250            topic_alias,
251            response_topic,
252            correlation_data,
253            user_properties,
254            subscription_identifiers,
255            content_type,
256        }))
257    }
258
259    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
260        let len = self.len();
261        write_remaining_length(buffer, len)?;
262
263        if let Some(payload_format_indicator) = self.payload_format_indicator {
264            buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
265            buffer.put_u8(payload_format_indicator);
266        }
267
268        if let Some(message_expiry_interval) = self.message_expiry_interval {
269            buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
270            buffer.put_u32(message_expiry_interval);
271        }
272
273        if let Some(topic_alias) = self.topic_alias {
274            buffer.put_u8(PropertyType::TopicAlias as u8);
275            buffer.put_u16(topic_alias);
276        }
277
278        if let Some(topic) = &self.response_topic {
279            buffer.put_u8(PropertyType::ResponseTopic as u8);
280            write_mqtt_string(buffer, topic)?;
281        }
282
283        if let Some(data) = &self.correlation_data {
284            buffer.put_u8(PropertyType::CorrelationData as u8);
285            write_mqtt_bytes(buffer, data)?;
286        }
287
288        for (key, value) in &self.user_properties {
289            buffer.put_u8(PropertyType::UserProperty as u8);
290            write_mqtt_string(buffer, key)?;
291            write_mqtt_string(buffer, value)?;
292        }
293
294        for id in &self.subscription_identifiers {
295            buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
296            write_remaining_length(buffer, *id)?;
297        }
298
299        if let Some(typ) = &self.content_type {
300            buffer.put_u8(PropertyType::ContentType as u8);
301            write_mqtt_string(buffer, typ)?;
302        }
303
304        Ok(())
305    }
306}
307
#[cfg(test)]
mod test {
    use crate::test::read_write_packets;
    use crate::Packet;

    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// `size()`, the count returned by `write()` and the number of bytes
    /// actually appended must all agree.
    #[test]
    fn length_calculation() {
        let mut dummy_bytes = BytesMut::new();
        // Use user_properties to pad the size to exceed ~128 bytes to make the
        // remaining_length field in the packet be 2 bytes long.
        let publish_props = PublishProperties {
            payload_format_indicator: None,
            message_expiry_interval: None,
            topic_alias: None,
            response_topic: None,
            correlation_data: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            subscription_identifiers: vec![1],
            content_type: None,
        };

        let publish_pkt = Publish::new(
            "hello/world",
            QoS::AtMostOnce,
            vec![1; 10],
            Some(publish_props),
        );

        let size_from_size = publish_pkt.size();
        let size_from_write = publish_pkt.write(&mut dummy_bytes).unwrap();
        let size_from_bytes = dummy_bytes.len();

        assert_eq!(size_from_write, size_from_bytes);
        assert_eq!(size_from_size, size_from_bytes);
    }

    /// QoS 1 and QoS 2 packets must carry a non-zero packet identifier.
    #[test]
    fn write_rejects_zero_pkid() {
        for qos in [QoS::AtLeastOnce, QoS::ExactlyOnce] {
            let publish = Publish::new("hello/world", qos, vec![1; 10], None);
            let mut buffer = BytesMut::new();
            let result = publish.write(&mut buffer);
            assert!(
                matches!(result, Err(Error::PacketIdZero)),
                "expected PacketIdZero for {:?}, got {:?}",
                qos,
                result
            );
        }
    }

    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    /// Packets that must survive an encode/decode round trip unchanged.
    fn write_read_provider() -> Vec<Packet> {
        vec![
            Packet::Publish(Publish::new(
                "hello/world",
                QoS::AtMostOnce,
                vec![1; 10],
                None,
            )),
            // Retained QoS 0 publish with an empty payload.
            Packet::Publish(Publish {
                retain: true,
                ..Publish::new("hello/world", QoS::AtMostOnce, Bytes::new(), None)
            }),
            Packet::Publish(Publish {
                dup: true,
                qos: QoS::ExactlyOnce,
                retain: true,
                topic: "hello/world".into(),
                pkid: 12,
                payload: vec![1; 10].into(),
                properties: None,
            }),
            Packet::Publish(Publish {
                dup: true,
                qos: QoS::AtLeastOnce,
                retain: true,
                topic: "hello/world".into(),
                pkid: 12,
                payload: vec![1; 10].into(),
                properties: None,
            }),
            Packet::Publish(Publish::new(
                "hello/world",
                QoS::AtMostOnce,
                vec![1; 10],
                Some(PublishProperties {
                    payload_format_indicator: Some(1),
                    message_expiry_interval: Some(100),
                    topic_alias: Some(10),
                    response_topic: Some("response/topic".into()),
                    correlation_data: Some(vec![1, 2, 3].into()),
                    user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
                    subscription_identifiers: vec![1],
                    content_type: Some("content/type".into()),
                }),
            )),
        ]
    }
}