rumqttc/v5/mqttbytes/v5/publish.rs

1use super::*;
2use bytes::{Buf, Bytes};
3
4/// Publish packet
5#[derive(Clone, Debug, PartialEq, Eq, Default)]
6pub struct Publish {
7    pub dup: bool,
8    pub qos: QoS,
9    pub retain: bool,
10    pub topic: Bytes,
11    pub pkid: u16,
12    pub payload: Bytes,
13    pub properties: Option<PublishProperties>,
14}
15
16impl Publish {
17    pub fn new<T: Into<String>, P: Into<Bytes>>(
18        topic: T,
19        qos: QoS,
20        payload: P,
21        properties: Option<PublishProperties>,
22    ) -> Self {
23        let topic = Bytes::copy_from_slice(topic.into().as_bytes());
24        Self {
25            qos,
26            topic,
27            payload: payload.into(),
28            properties,
29            ..Default::default()
30        }
31    }
32
33    pub fn size(&self) -> usize {
34        let len = self.len();
35        let remaining_len_size = len_len(len);
36
37        1 + remaining_len_size + len
38    }
39
40    fn len(&self) -> usize {
41        let mut len = 2 + self.topic.len();
42        if self.qos != QoS::AtMostOnce && self.pkid != 0 {
43            len += 2;
44        }
45
46        if let Some(p) = &self.properties {
47            let properties_len = p.len();
48            let properties_len_len = len_len(properties_len);
49            len += properties_len_len + properties_len;
50        } else {
51            // just 1 byte representing 0 len
52            len += 1;
53        }
54
55        len += self.payload.len();
56        len
57    }
58
59    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Publish, Error> {
60        let qos_num = (fixed_header.byte1 & 0b0110) >> 1;
61        let qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
62        let dup = (fixed_header.byte1 & 0b1000) != 0;
63        let retain = (fixed_header.byte1 & 0b0001) != 0;
64
65        let variable_header_index = fixed_header.fixed_header_len;
66        bytes.advance(variable_header_index);
67        let topic = read_mqtt_bytes(&mut bytes)?;
68
69        // Packet identifier exists where QoS > 0
70        let pkid = match qos {
71            QoS::AtMostOnce => 0,
72            QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?,
73        };
74
75        if qos != QoS::AtMostOnce && pkid == 0 {
76            return Err(Error::PacketIdZero);
77        }
78
79        let properties = PublishProperties::read(&mut bytes)?;
80        let publish = Publish {
81            dup,
82            retain,
83            qos,
84            pkid,
85            topic,
86            payload: bytes,
87            properties,
88        };
89
90        Ok(publish)
91    }
92
93    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
94        let len = self.len();
95
96        let dup = self.dup as u8;
97        let qos = self.qos as u8;
98        let retain = self.retain as u8;
99        buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3);
100
101        let count = write_remaining_length(buffer, len)?;
102        write_mqtt_bytes(buffer, &self.topic);
103
104        if self.qos != QoS::AtMostOnce {
105            let pkid = self.pkid;
106            if pkid == 0 {
107                return Err(Error::PacketIdZero);
108            }
109
110            buffer.put_u16(pkid);
111        }
112
113        if let Some(p) = &self.properties {
114            p.write(buffer)?;
115        } else {
116            write_remaining_length(buffer, 0)?;
117        }
118
119        buffer.extend_from_slice(&self.payload);
120
121        Ok(1 + count + len)
122    }
123}
124
125#[derive(Debug, Clone, PartialEq, Eq, Default)]
126pub struct PublishProperties {
127    pub payload_format_indicator: Option<u8>,
128    pub message_expiry_interval: Option<u32>,
129    pub topic_alias: Option<u16>,
130    pub response_topic: Option<String>,
131    pub correlation_data: Option<Bytes>,
132    pub user_properties: Vec<(String, String)>,
133    pub subscription_identifiers: Vec<usize>,
134    pub content_type: Option<String>,
135}
136
137impl PublishProperties {
138    fn len(&self) -> usize {
139        let mut len = 0;
140
141        if self.payload_format_indicator.is_some() {
142            len += 1 + 1;
143        }
144
145        if self.message_expiry_interval.is_some() {
146            len += 1 + 4;
147        }
148
149        if self.topic_alias.is_some() {
150            len += 1 + 2;
151        }
152
153        if let Some(topic) = &self.response_topic {
154            len += 1 + 2 + topic.len()
155        }
156
157        if let Some(data) = &self.correlation_data {
158            len += 1 + 2 + data.len()
159        }
160
161        for (key, value) in self.user_properties.iter() {
162            len += 1 + 2 + key.len() + 2 + value.len();
163        }
164
165        for id in self.subscription_identifiers.iter() {
166            len += 1 + len_len(*id);
167        }
168
169        if let Some(typ) = &self.content_type {
170            len += 1 + 2 + typ.len()
171        }
172
173        len
174    }
175
176    pub fn read(bytes: &mut Bytes) -> Result<Option<PublishProperties>, Error> {
177        let mut payload_format_indicator = None;
178        let mut message_expiry_interval = None;
179        let mut topic_alias = None;
180        let mut response_topic = None;
181        let mut correlation_data = None;
182        let mut user_properties = Vec::new();
183        let mut subscription_identifiers = Vec::new();
184        let mut content_type = None;
185
186        let (properties_len_len, properties_len) = length(bytes.iter())?;
187        bytes.advance(properties_len_len);
188        if properties_len == 0 {
189            return Ok(None);
190        }
191
192        let mut cursor = 0;
193        // read until cursor reaches property length. properties_len = 0 will skip this loop
194        while cursor < properties_len {
195            let prop = read_u8(bytes)?;
196            cursor += 1;
197
198            match property(prop)? {
199                PropertyType::PayloadFormatIndicator => {
200                    payload_format_indicator = Some(read_u8(bytes)?);
201                    cursor += 1;
202                }
203                PropertyType::MessageExpiryInterval => {
204                    message_expiry_interval = Some(read_u32(bytes)?);
205                    cursor += 4;
206                }
207                PropertyType::TopicAlias => {
208                    topic_alias = Some(read_u16(bytes)?);
209                    cursor += 2;
210                }
211                PropertyType::ResponseTopic => {
212                    let topic = read_mqtt_string(bytes)?;
213                    cursor += 2 + topic.len();
214                    response_topic = Some(topic);
215                }
216                PropertyType::CorrelationData => {
217                    let data = read_mqtt_bytes(bytes)?;
218                    cursor += 2 + data.len();
219                    correlation_data = Some(data);
220                }
221                PropertyType::UserProperty => {
222                    let key = read_mqtt_string(bytes)?;
223                    let value = read_mqtt_string(bytes)?;
224                    cursor += 2 + key.len() + 2 + value.len();
225                    user_properties.push((key, value));
226                }
227                PropertyType::SubscriptionIdentifier => {
228                    let (id_len, id) = length(bytes.iter())?;
229                    cursor += 1 + id_len;
230                    bytes.advance(id_len);
231                    subscription_identifiers.push(id);
232                }
233                PropertyType::ContentType => {
234                    let typ = read_mqtt_string(bytes)?;
235                    cursor += 2 + typ.len();
236                    content_type = Some(typ);
237                }
238                _ => return Err(Error::InvalidPropertyType(prop)),
239            }
240        }
241
242        Ok(Some(PublishProperties {
243            payload_format_indicator,
244            message_expiry_interval,
245            topic_alias,
246            response_topic,
247            correlation_data,
248            user_properties,
249            subscription_identifiers,
250            content_type,
251        }))
252    }
253
254    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
255        let len = self.len();
256        write_remaining_length(buffer, len)?;
257
258        if let Some(payload_format_indicator) = self.payload_format_indicator {
259            buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
260            buffer.put_u8(payload_format_indicator);
261        }
262
263        if let Some(message_expiry_interval) = self.message_expiry_interval {
264            buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
265            buffer.put_u32(message_expiry_interval);
266        }
267
268        if let Some(topic_alias) = self.topic_alias {
269            buffer.put_u8(PropertyType::TopicAlias as u8);
270            buffer.put_u16(topic_alias);
271        }
272
273        if let Some(topic) = &self.response_topic {
274            buffer.put_u8(PropertyType::ResponseTopic as u8);
275            write_mqtt_string(buffer, topic);
276        }
277
278        if let Some(data) = &self.correlation_data {
279            buffer.put_u8(PropertyType::CorrelationData as u8);
280            write_mqtt_bytes(buffer, data);
281        }
282
283        for (key, value) in self.user_properties.iter() {
284            buffer.put_u8(PropertyType::UserProperty as u8);
285            write_mqtt_string(buffer, key);
286            write_mqtt_string(buffer, value);
287        }
288
289        for id in self.subscription_identifiers.iter() {
290            buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
291            write_remaining_length(buffer, *id)?;
292        }
293
294        if let Some(typ) = &self.content_type {
295            buffer.put_u8(PropertyType::ContentType as u8);
296            write_mqtt_string(buffer, typ);
297        }
298
299        Ok(())
300    }
301}
302
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// `size()`, the count returned by `write()`, and the number of bytes
    /// actually appended to the buffer must all agree.
    #[test]
    fn length_calculation() {
        // The user property is meant to pad the packet so that its remaining
        // length needs a 2-byte encoding (depends on the lengths of
        // USER_PROP_KEY / USER_PROP_VAL).
        let properties = PublishProperties {
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            subscription_identifiers: vec![1],
            ..Default::default()
        };

        let packet = Publish::new("hello/world", QoS::AtMostOnce, vec![1; 10], Some(properties));

        let expected_size = packet.size();
        let mut encoded = BytesMut::new();
        let reported = packet.write(&mut encoded).unwrap();
        let actual = encoded.len();

        assert_eq!(reported, actual);
        assert_eq!(expected_size, actual);
    }
}
340}