// mqtt5_protocol/packet/publish.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
10/// MQTT PUBLISH packet
11#[derive(Debug, Clone)]
12pub struct PublishPacket {
13    /// Topic name
14    pub topic_name: String,
15    /// Packet identifier (required for `QoS` > 0)
16    pub packet_id: Option<u16>,
17    /// Message payload
18    pub payload: Vec<u8>,
19    /// Quality of Service level
20    pub qos: QoS,
21    /// Retain flag
22    pub retain: bool,
23    /// Duplicate delivery flag
24    pub dup: bool,
25    /// PUBLISH properties (v5.0 only)
26    pub properties: Properties,
27    /// Protocol version (4 = v3.1.1, 5 = v5.0)
28    pub protocol_version: u8,
29}
30
31impl PublishPacket {
32    /// Creates a new PUBLISH packet (v5.0)
33    #[must_use]
34    pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
35        let packet_id = if qos == QoS::AtMostOnce {
36            None
37        } else {
38            Some(0)
39        };
40
41        Self {
42            topic_name: topic_name.into(),
43            packet_id,
44            payload: payload.into(),
45            qos,
46            retain: false,
47            dup: false,
48            properties: Properties::default(),
49            protocol_version: 5,
50        }
51    }
52
53    /// Creates a new PUBLISH packet for v3.1.1
54    #[must_use]
55    pub fn new_v311(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
56        let packet_id = if qos == QoS::AtMostOnce {
57            None
58        } else {
59            Some(0)
60        };
61
62        Self {
63            topic_name: topic_name.into(),
64            packet_id,
65            payload: payload.into(),
66            qos,
67            retain: false,
68            dup: false,
69            properties: Properties::default(),
70            protocol_version: 4,
71        }
72    }
73
74    /// Sets the packet identifier
75    #[must_use]
76    pub fn with_packet_id(mut self, id: u16) -> Self {
77        if self.qos != QoS::AtMostOnce {
78            self.packet_id = Some(id);
79        }
80        self
81    }
82
83    /// Sets the retain flag
84    #[must_use]
85    pub fn with_retain(mut self, retain: bool) -> Self {
86        self.retain = retain;
87        self
88    }
89
90    /// Sets the duplicate flag
91    #[must_use]
92    pub fn with_dup(mut self, dup: bool) -> Self {
93        self.dup = dup;
94        self
95    }
96
97    /// Sets the payload format indicator
98    #[must_use]
99    pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
100        self.properties.set_payload_format_indicator(is_utf8);
101        self
102    }
103
104    /// Sets the message expiry interval
105    #[must_use]
106    pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
107        self.properties.set_message_expiry_interval(seconds);
108        self
109    }
110
111    /// Sets the topic alias
112    #[must_use]
113    pub fn with_topic_alias(mut self, alias: u16) -> Self {
114        self.properties.set_topic_alias(alias);
115        self
116    }
117
118    /// Sets the response topic
119    #[must_use]
120    pub fn with_response_topic(mut self, topic: String) -> Self {
121        self.properties.set_response_topic(topic);
122        self
123    }
124
125    /// Sets the correlation data
126    #[must_use]
127    pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
128        self.properties.set_correlation_data(data.into());
129        self
130    }
131
132    /// Adds a user property
133    #[must_use]
134    pub fn with_user_property(mut self, key: String, value: String) -> Self {
135        self.properties.add_user_property(key, value);
136        self
137    }
138
139    /// Adds a subscription identifier
140    #[must_use]
141    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
142        self.properties.set_subscription_identifier(id);
143        self
144    }
145
146    /// Sets the content type
147    #[must_use]
148    pub fn with_content_type(mut self, content_type: String) -> Self {
149        self.properties.set_content_type(content_type);
150        self
151    }
152
153    #[must_use]
154    /// Gets the topic alias from properties
155    pub fn topic_alias(&self) -> Option<u16> {
156        self.properties
157            .get(PropertyId::TopicAlias)
158            .and_then(|prop| {
159                if let PropertyValue::TwoByteInteger(alias) = prop {
160                    Some(*alias)
161                } else {
162                    None
163                }
164            })
165    }
166
167    #[must_use]
168    /// Gets the message expiry interval from properties
169    pub fn message_expiry_interval(&self) -> Option<u32> {
170        self.properties
171            .get(PropertyId::MessageExpiryInterval)
172            .and_then(|prop| {
173                if let PropertyValue::FourByteInteger(interval) = prop {
174                    Some(*interval)
175                } else {
176                    None
177                }
178            })
179    }
180}
181
182impl MqttPacket for PublishPacket {
183    fn packet_type(&self) -> PacketType {
184        PacketType::Publish
185    }
186
187    fn flags(&self) -> u8 {
188        let mut flags = 0u8;
189
190        if self.dup {
191            flags |= PublishFlags::Dup as u8;
192        }
193
194        flags = PublishFlags::with_qos(flags, self.qos as u8);
195
196        if self.retain {
197            flags |= PublishFlags::Retain as u8;
198        }
199
200        flags
201    }
202
203    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
204        encode_string(buf, &self.topic_name)?;
205
206        if self.qos != QoS::AtMostOnce {
207            let packet_id = self.packet_id.ok_or_else(|| {
208                MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
209            })?;
210            buf.put_u16(packet_id);
211        }
212
213        if self.protocol_version == 5 {
214            self.properties.encode(buf)?;
215        }
216
217        buf.put_slice(&self.payload);
218
219        Ok(())
220    }
221
222    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
223        Self::decode_body_with_version(buf, fixed_header, 5)
224    }
225}
226
227impl PublishPacket {
228    /// Decodes the packet body with a specific protocol version
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if decoding fails
233    pub fn decode_body_with_version<B: Buf>(
234        buf: &mut B,
235        fixed_header: &FixedHeader,
236        protocol_version: u8,
237    ) -> Result<Self> {
238        ProtocolVersion::try_from(protocol_version)
239            .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
240
241        let flags = PublishFlags::decompose(fixed_header.flags);
242        let dup = flags.contains(&PublishFlags::Dup);
243        let qos_val = PublishFlags::extract_qos(fixed_header.flags);
244        let retain = flags.contains(&PublishFlags::Retain);
245
246        let qos = match qos_val {
247            0 => QoS::AtMostOnce,
248            1 => QoS::AtLeastOnce,
249            2 => QoS::ExactlyOnce,
250            _ => {
251                return Err(MqttError::InvalidQoS(qos_val));
252            }
253        };
254
255        let topic_name = decode_string(buf)?;
256
257        let packet_id = if qos == QoS::AtMostOnce {
258            None
259        } else {
260            if buf.remaining() < 2 {
261                return Err(MqttError::MalformedPacket(
262                    "Missing packet identifier".to_string(),
263                ));
264            }
265            Some(buf.get_u16())
266        };
267
268        let properties = if protocol_version == 5 {
269            Properties::decode(buf)?
270        } else {
271            Properties::default()
272        };
273
274        let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
275
276        Ok(Self {
277            topic_name,
278            packet_id,
279            payload,
280            qos,
281            retain,
282            dup,
283            properties,
284            protocol_version,
285        })
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_publish_packet_id_ignored_for_qos0() {
        // QoS 0 packets never carry an identifier, so the builder is a no-op.
        let packet = PublishPacket::new("test/topic", b"x", QoS::AtMostOnce).with_packet_id(9);
        assert!(packet.packet_id.is_none());
    }

    #[test]
    fn test_publish_constructors_set_protocol_version() {
        let v5 = PublishPacket::new("t", b"x", QoS::AtLeastOnce);
        let v311 = PublishPacket::new_v311("t", b"x", QoS::AtLeastOnce);

        assert_eq!(v5.protocol_version, 5);
        assert_eq!(v311.protocol_version, 4);
        // Apart from the version, both constructors produce the same packet.
        assert_eq!(v311.topic_name, v5.topic_name);
        assert_eq!(v311.payload, v5.payload);
        assert_eq!(v311.packet_id, v5.packet_id);
        assert_eq!(v311.qos, v5.qos);
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_publish_property_getters() {
        let packet = PublishPacket::new("t", b"x", QoS::AtMostOnce)
            .with_topic_alias(7)
            .with_message_expiry_interval(60);
        assert_eq!(packet.topic_alias(), Some(7));
        assert_eq!(packet.message_expiry_interval(), Some(60));

        let bare = PublishPacket::new("t", b"x", QoS::AtMostOnce);
        assert_eq!(bare.topic_alias(), None);
        assert_eq!(bare.message_expiry_interval(), None);
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03); // QoS 1 + Retain

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C); // DUP + QoS 2

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D); // DUP + QoS 2 + Retain
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & crate::flags::PublishFlags::Retain as u8,
            crate::flags::PublishFlags::Retain as u8
        ); // Retain flag

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_v311() {
        let packet = PublishPacket::new_v311("legacy/topic", b"v311", QoS::AtLeastOnce)
            .with_packet_id(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded =
            PublishPacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();

        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.topic_name, "legacy/topic");
        assert_eq!(decoded.payload, b"v311");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(42));
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));

        // Strict checks: a property decoded with the wrong variant must fail the test.
        assert!(matches!(
            decoded.properties.get(PropertyId::MessageExpiryInterval),
            Some(PropertyValue::FourByteInteger(7200))
        ));
        assert!(matches!(
            decoded.properties.get(PropertyId::ContentType),
            Some(PropertyValue::Utf8String(val)) if val == "application/json"
        ));
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // No packet ID for QoS > 0 - buffer ends here

        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap()); // QoS 1
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0); // Invalid QoS 3
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(3))));
    }
}
442}