mqtt5_protocol/packet/
publish.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::QoS;
7use bytes::{Buf, BufMut};
8
9/// MQTT PUBLISH packet
10#[derive(Debug, Clone)]
11pub struct PublishPacket {
12    /// Topic name
13    pub topic_name: String,
14    /// Packet identifier (required for `QoS` > 0)
15    pub packet_id: Option<u16>,
16    /// Message payload
17    pub payload: Vec<u8>,
18    /// Quality of Service level
19    pub qos: QoS,
20    /// Retain flag
21    pub retain: bool,
22    /// Duplicate delivery flag
23    pub dup: bool,
24    /// PUBLISH properties (v5.0 only)
25    pub properties: Properties,
26}
27
28impl PublishPacket {
29    /// Creates a new PUBLISH packet
30    #[must_use]
31    pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
32        let packet_id = if qos == QoS::AtMostOnce {
33            None
34        } else {
35            Some(0) // Will be assigned by the client
36        };
37
38        Self {
39            topic_name: topic_name.into(),
40            packet_id,
41            payload: payload.into(),
42            qos,
43            retain: false,
44            dup: false,
45            properties: Properties::default(),
46        }
47    }
48
49    /// Sets the packet identifier
50    #[must_use]
51    pub fn with_packet_id(mut self, id: u16) -> Self {
52        if self.qos != QoS::AtMostOnce {
53            self.packet_id = Some(id);
54        }
55        self
56    }
57
58    /// Sets the retain flag
59    #[must_use]
60    pub fn with_retain(mut self, retain: bool) -> Self {
61        self.retain = retain;
62        self
63    }
64
65    /// Sets the duplicate flag
66    #[must_use]
67    pub fn with_dup(mut self, dup: bool) -> Self {
68        self.dup = dup;
69        self
70    }
71
72    /// Sets the payload format indicator
73    #[must_use]
74    pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
75        self.properties.set_payload_format_indicator(is_utf8);
76        self
77    }
78
79    /// Sets the message expiry interval
80    #[must_use]
81    pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
82        self.properties.set_message_expiry_interval(seconds);
83        self
84    }
85
86    /// Sets the topic alias
87    #[must_use]
88    pub fn with_topic_alias(mut self, alias: u16) -> Self {
89        self.properties.set_topic_alias(alias);
90        self
91    }
92
93    /// Sets the response topic
94    #[must_use]
95    pub fn with_response_topic(mut self, topic: String) -> Self {
96        self.properties.set_response_topic(topic);
97        self
98    }
99
100    /// Sets the correlation data
101    #[must_use]
102    pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
103        self.properties.set_correlation_data(data.into());
104        self
105    }
106
107    /// Adds a user property
108    #[must_use]
109    pub fn with_user_property(mut self, key: String, value: String) -> Self {
110        self.properties.add_user_property(key, value);
111        self
112    }
113
114    /// Adds a subscription identifier
115    #[must_use]
116    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
117        self.properties.set_subscription_identifier(id);
118        self
119    }
120
121    /// Sets the content type
122    #[must_use]
123    pub fn with_content_type(mut self, content_type: String) -> Self {
124        self.properties.set_content_type(content_type);
125        self
126    }
127
128    #[must_use]
129    /// Gets the topic alias from properties
130    pub fn topic_alias(&self) -> Option<u16> {
131        self.properties
132            .get(PropertyId::TopicAlias)
133            .and_then(|prop| {
134                if let PropertyValue::TwoByteInteger(alias) = prop {
135                    Some(*alias)
136                } else {
137                    None
138                }
139            })
140    }
141
142    #[must_use]
143    /// Gets the message expiry interval from properties
144    pub fn message_expiry_interval(&self) -> Option<u32> {
145        self.properties
146            .get(PropertyId::MessageExpiryInterval)
147            .and_then(|prop| {
148                if let PropertyValue::FourByteInteger(interval) = prop {
149                    Some(*interval)
150                } else {
151                    None
152                }
153            })
154    }
155}
156
157impl MqttPacket for PublishPacket {
158    fn packet_type(&self) -> PacketType {
159        PacketType::Publish
160    }
161
162    fn flags(&self) -> u8 {
163        let mut flags = 0u8;
164
165        if self.dup {
166            flags |= PublishFlags::Dup as u8;
167        }
168
169        flags = PublishFlags::with_qos(flags, self.qos as u8);
170
171        if self.retain {
172            flags |= PublishFlags::Retain as u8;
173        }
174
175        flags
176    }
177
178    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
179        // Variable header
180        encode_string(buf, &self.topic_name)?;
181
182        // Packet identifier (only for QoS > 0)
183        if self.qos != QoS::AtMostOnce {
184            let packet_id = self.packet_id.ok_or_else(|| {
185                MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
186            })?;
187            buf.put_u16(packet_id);
188        }
189
190        // Properties (v5.0 - always encode for v5.0)
191        // For v3.1.1 compatibility, we'd skip this
192        self.properties.encode(buf)?;
193
194        // Payload
195        buf.put_slice(&self.payload);
196
197        Ok(())
198    }
199
200    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
201        // Parse flags using BeBytes decomposition
202        let flags = PublishFlags::decompose(fixed_header.flags);
203        let dup = flags.contains(&PublishFlags::Dup);
204        let qos_val = PublishFlags::extract_qos(fixed_header.flags);
205        let retain = flags.contains(&PublishFlags::Retain);
206
207        let qos = match qos_val {
208            0 => QoS::AtMostOnce,
209            1 => QoS::AtLeastOnce,
210            2 => QoS::ExactlyOnce,
211            _ => {
212                return Err(MqttError::InvalidQoS(qos_val));
213            }
214        };
215
216        // Variable header
217        let topic_name = decode_string(buf)?;
218
219        // Packet identifier (only for QoS > 0)
220        let packet_id = if qos == QoS::AtMostOnce {
221            None
222        } else {
223            if buf.remaining() < 2 {
224                return Err(MqttError::MalformedPacket(
225                    "Missing packet identifier".to_string(),
226                ));
227            }
228            Some(buf.get_u16())
229        };
230
231        // Properties (v5.0)
232        // Always try to decode properties first
233        let properties = if buf.has_remaining() {
234            match Properties::decode(buf) {
235                Ok(props) => props,
236                Err(_) => {
237                    // If properties decode fails, it might be v3.1.1 format
238                    // In v3.1.1, there are no properties
239                    return Err(MqttError::MalformedPacket(
240                        "Failed to decode PUBLISH properties".to_string(),
241                    ));
242                }
243            }
244        } else {
245            // No properties means empty properties in v5.0
246            Properties::default()
247        };
248
249        // Payload - all remaining bytes
250        let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
251
252        Ok(Self {
253            topic_name,
254            packet_id,
255            payload,
256            qos,
257            retain,
258            dup,
259            properties,
260        })
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_with_packet_id_ignored_for_qos0() {
        // QoS 0 packets never carry an identifier, so the builder must leave it unset.
        let packet = PublishPacket::new("test/topic", b"x", QoS::AtMostOnce).with_packet_id(42);

        assert!(packet.packet_id.is_none());
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03); // QoS 1 + Retain

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C); // DUP + QoS 2

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D); // DUP + QoS 2 + Retain
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & PublishFlags::Retain as u8,
            PublishFlags::Retain as u8
        ); // Retain flag

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_empty_payload() {
        let packet = PublishPacket::new("empty/payload", Vec::<u8>::new(), QoS::AtLeastOnce)
            .with_packet_id(1);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "empty/payload");
        assert_eq!(decoded.packet_id, Some(1));
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));
        assert_eq!(decoded.payload, b"data");

        // Strict checks: a property decoded as the wrong variant must fail the test.
        assert_eq!(decoded.message_expiry_interval(), Some(7200));
        match decoded.properties.get(PropertyId::ContentType) {
            Some(PropertyValue::Utf8String(val)) => assert_eq!(val, "application/json"),
            _ => panic!("ContentType property missing or not a UTF-8 string"),
        }
    }

    #[test]
    fn test_publish_encode_decode_topic_alias() {
        let packet = PublishPacket::new("alias/topic", b"payload", QoS::AtMostOnce)
            .with_topic_alias(7);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_alias(), Some(7));
        assert_eq!(decoded.payload, b"payload");
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // No packet ID for QoS > 0 - buffer ends here

        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap()); // QoS 1
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0); // Invalid QoS 3
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(_))));
    }
}