mqtt5_protocol/packet/
publish.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
/// MQTT PUBLISH packet
///
/// Holds the topic, payload, delivery flags and (for v5.0) properties of a
/// single PUBLISH message, together with the protocol version that decides
/// whether the properties section is written to / read from the wire.
#[derive(Debug, Clone)]
pub struct PublishPacket {
    /// Topic name
    pub topic_name: String,
    /// Packet identifier (required for `QoS` > 0)
    ///
    /// The constructors leave this as `None` for `QoS::AtMostOnce` and as the
    /// placeholder `Some(0)` otherwise. Encoding fails when it is `None` while
    /// `qos` is above `AtMostOnce`.
    pub packet_id: Option<u16>,
    /// Message payload
    pub payload: Vec<u8>,
    /// Quality of Service level
    pub qos: QoS,
    /// Retain flag
    pub retain: bool,
    /// Duplicate delivery flag
    pub dup: bool,
    /// PUBLISH properties (v5.0 only)
    ///
    /// Only encoded and decoded when `protocol_version` is 5.
    pub properties: Properties,
    /// Protocol version (4 = v3.1.1, 5 = v5.0)
    pub protocol_version: u8,
}
30
31impl PublishPacket {
32    /// Creates a new PUBLISH packet (v5.0)
33    #[must_use]
34    pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
35        let packet_id = if qos == QoS::AtMostOnce {
36            None
37        } else {
38            Some(0)
39        };
40
41        Self {
42            topic_name: topic_name.into(),
43            packet_id,
44            payload: payload.into(),
45            qos,
46            retain: false,
47            dup: false,
48            properties: Properties::default(),
49            protocol_version: 5,
50        }
51    }
52
53    /// Creates a new PUBLISH packet for v3.1.1
54    #[must_use]
55    pub fn new_v311(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
56        let packet_id = if qos == QoS::AtMostOnce {
57            None
58        } else {
59            Some(0)
60        };
61
62        Self {
63            topic_name: topic_name.into(),
64            packet_id,
65            payload: payload.into(),
66            qos,
67            retain: false,
68            dup: false,
69            properties: Properties::default(),
70            protocol_version: 4,
71        }
72    }
73
74    /// Sets the packet identifier
75    #[must_use]
76    pub fn with_packet_id(mut self, id: u16) -> Self {
77        if self.qos != QoS::AtMostOnce {
78            self.packet_id = Some(id);
79        }
80        self
81    }
82
83    /// Sets the retain flag
84    #[must_use]
85    pub fn with_retain(mut self, retain: bool) -> Self {
86        self.retain = retain;
87        self
88    }
89
90    /// Sets the duplicate flag
91    #[must_use]
92    pub fn with_dup(mut self, dup: bool) -> Self {
93        self.dup = dup;
94        self
95    }
96
97    /// Sets the payload format indicator
98    #[must_use]
99    pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
100        self.properties.set_payload_format_indicator(is_utf8);
101        self
102    }
103
104    /// Sets the message expiry interval
105    #[must_use]
106    pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
107        self.properties.set_message_expiry_interval(seconds);
108        self
109    }
110
111    /// Sets the topic alias
112    #[must_use]
113    pub fn with_topic_alias(mut self, alias: u16) -> Self {
114        self.properties.set_topic_alias(alias);
115        self
116    }
117
118    /// Sets the response topic
119    #[must_use]
120    pub fn with_response_topic(mut self, topic: String) -> Self {
121        self.properties.set_response_topic(topic);
122        self
123    }
124
125    /// Sets the correlation data
126    #[must_use]
127    pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
128        self.properties.set_correlation_data(data.into());
129        self
130    }
131
132    /// Adds a user property
133    #[must_use]
134    pub fn with_user_property(mut self, key: String, value: String) -> Self {
135        self.properties.add_user_property(key, value);
136        self
137    }
138
139    /// Adds a subscription identifier
140    #[must_use]
141    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
142        self.properties.set_subscription_identifier(id);
143        self
144    }
145
146    /// Sets the content type
147    #[must_use]
148    pub fn with_content_type(mut self, content_type: String) -> Self {
149        self.properties.set_content_type(content_type);
150        self
151    }
152
153    #[must_use]
154    /// Gets the topic alias from properties
155    pub fn topic_alias(&self) -> Option<u16> {
156        self.properties
157            .get(PropertyId::TopicAlias)
158            .and_then(|prop| {
159                if let PropertyValue::TwoByteInteger(alias) = prop {
160                    Some(*alias)
161                } else {
162                    None
163                }
164            })
165    }
166
167    #[must_use]
168    /// Gets the message expiry interval from properties
169    pub fn message_expiry_interval(&self) -> Option<u32> {
170        self.properties
171            .get(PropertyId::MessageExpiryInterval)
172            .and_then(|prop| {
173                if let PropertyValue::FourByteInteger(interval) = prop {
174                    Some(*interval)
175                } else {
176                    None
177                }
178            })
179    }
180
181    #[must_use]
182    pub fn body_encoded_size(&self) -> usize {
183        let mut size = 2 + self.topic_name.len();
184
185        if self.qos != QoS::AtMostOnce {
186            size += 2;
187        }
188
189        if self.protocol_version == 5 {
190            size += self.properties.encoded_len();
191        }
192
193        size += self.payload.len();
194        size
195    }
196
197    /// # Errors
198    /// Returns error if encoding fails
199    pub fn encode_body_direct<B: BufMut>(&self, buf: &mut B) -> Result<()> {
200        encode_string(buf, &self.topic_name)?;
201
202        if self.qos != QoS::AtMostOnce {
203            let packet_id = self.packet_id.ok_or_else(|| {
204                MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
205            })?;
206            buf.put_u16(packet_id);
207        }
208
209        if self.protocol_version == 5 {
210            self.properties.encode_direct(buf)?;
211        }
212
213        buf.put_slice(&self.payload);
214
215        Ok(())
216    }
217}
218
219impl MqttPacket for PublishPacket {
220    fn packet_type(&self) -> PacketType {
221        PacketType::Publish
222    }
223
224    fn flags(&self) -> u8 {
225        let mut flags = 0u8;
226
227        if self.dup {
228            flags |= PublishFlags::Dup as u8;
229        }
230
231        flags = PublishFlags::with_qos(flags, self.qos as u8);
232
233        if self.retain {
234            flags |= PublishFlags::Retain as u8;
235        }
236
237        flags
238    }
239
240    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
241        encode_string(buf, &self.topic_name)?;
242
243        if self.qos != QoS::AtMostOnce {
244            let packet_id = self.packet_id.ok_or_else(|| {
245                MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
246            })?;
247            buf.put_u16(packet_id);
248        }
249
250        if self.protocol_version == 5 {
251            self.properties.encode(buf)?;
252        }
253
254        buf.put_slice(&self.payload);
255
256        Ok(())
257    }
258
259    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
260        Self::decode_body_with_version(buf, fixed_header, 5)
261    }
262}
263
264impl PublishPacket {
265    /// Decodes the packet body with a specific protocol version
266    ///
267    /// # Errors
268    ///
269    /// Returns an error if decoding fails
270    pub fn decode_body_with_version<B: Buf>(
271        buf: &mut B,
272        fixed_header: &FixedHeader,
273        protocol_version: u8,
274    ) -> Result<Self> {
275        ProtocolVersion::try_from(protocol_version)
276            .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
277
278        let flags = PublishFlags::decompose(fixed_header.flags);
279        let dup = flags.contains(&PublishFlags::Dup);
280        let qos_val = PublishFlags::extract_qos(fixed_header.flags);
281        let retain = flags.contains(&PublishFlags::Retain);
282
283        let qos = match qos_val {
284            0 => QoS::AtMostOnce,
285            1 => QoS::AtLeastOnce,
286            2 => QoS::ExactlyOnce,
287            _ => {
288                return Err(MqttError::InvalidQoS(qos_val));
289            }
290        };
291
292        let topic_name = decode_string(buf)?;
293
294        let packet_id = if qos == QoS::AtMostOnce {
295            None
296        } else {
297            if buf.remaining() < 2 {
298                return Err(MqttError::MalformedPacket(
299                    "Missing packet identifier".to_string(),
300                ));
301            }
302            Some(buf.get_u16())
303        };
304
305        let properties = if protocol_version == 5 {
306            Properties::decode(buf)?
307        } else {
308            Properties::default()
309        };
310
311        let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
312
313        Ok(Self {
314            topic_name,
315            packet_id,
316            payload,
317            qos,
318            retain,
319            dup,
320            properties,
321            protocol_version,
322        })
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// A QoS 0 packet has no packet identifier and cleared flags.
    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    /// `with_packet_id` replaces the placeholder identifier for QoS 1.
    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    /// Builder methods record the corresponding properties.
    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    /// DUP, QoS and RETAIN are combined into the expected flag nibble.
    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03); // QoS 1 + Retain

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C); // DUP + QoS 2

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D); // DUP + QoS 2 + Retain
    }

    /// A retained QoS 0 packet survives an encode/decode round trip.
    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & crate::flags::PublishFlags::Retain as u8,
            crate::flags::PublishFlags::Retain as u8
        ); // Retain flag

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    /// A QoS 1 packet keeps its packet identifier across a round trip.
    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    /// Properties survive a round trip with both the right variant and value.
    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));

        // Match on the variant explicitly so that a property decoded as the
        // wrong kind fails the test instead of skipping the value check.
        assert!(matches!(
            decoded.properties.get(PropertyId::MessageExpiryInterval),
            Some(PropertyValue::FourByteInteger(7200))
        ));

        let Some(PropertyValue::Utf8String(val)) =
            decoded.properties.get(PropertyId::ContentType)
        else {
            panic!("content type property missing or not a UTF-8 string");
        };
        assert_eq!(val, "application/json");
    }

    /// A v3.1.1 packet round-trips through the version-aware decoder.
    #[test]
    fn test_publish_encode_decode_v311() {
        let packet =
            PublishPacket::new_v311("test/v311", b"legacy", QoS::AtLeastOnce).with_packet_id(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded =
            PublishPacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();

        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.topic_name, "test/v311");
        assert_eq!(decoded.payload, b"legacy");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(42));
    }

    /// A QoS 1 body that ends right after the topic is rejected as malformed.
    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // No packet ID for QoS > 0 - buffer ends here

        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap()); // QoS 1
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    /// QoS bits `0b11` are rejected with the offending value.
    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0); // Invalid QoS 3
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(3))));
    }
}