1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::QoS;
7use bytes::{Buf, BufMut};
8
9#[derive(Debug, Clone)]
11pub struct PublishPacket {
12 pub topic_name: String,
14 pub packet_id: Option<u16>,
16 pub payload: Vec<u8>,
18 pub qos: QoS,
20 pub retain: bool,
22 pub dup: bool,
24 pub properties: Properties,
26}
27
28impl PublishPacket {
29 #[must_use]
31 pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
32 let packet_id = if qos == QoS::AtMostOnce {
33 None
34 } else {
35 Some(0) };
37
38 Self {
39 topic_name: topic_name.into(),
40 packet_id,
41 payload: payload.into(),
42 qos,
43 retain: false,
44 dup: false,
45 properties: Properties::default(),
46 }
47 }
48
49 #[must_use]
51 pub fn with_packet_id(mut self, id: u16) -> Self {
52 if self.qos != QoS::AtMostOnce {
53 self.packet_id = Some(id);
54 }
55 self
56 }
57
58 #[must_use]
60 pub fn with_retain(mut self, retain: bool) -> Self {
61 self.retain = retain;
62 self
63 }
64
65 #[must_use]
67 pub fn with_dup(mut self, dup: bool) -> Self {
68 self.dup = dup;
69 self
70 }
71
72 #[must_use]
74 pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
75 self.properties.set_payload_format_indicator(is_utf8);
76 self
77 }
78
79 #[must_use]
81 pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
82 self.properties.set_message_expiry_interval(seconds);
83 self
84 }
85
86 #[must_use]
88 pub fn with_topic_alias(mut self, alias: u16) -> Self {
89 self.properties.set_topic_alias(alias);
90 self
91 }
92
93 #[must_use]
95 pub fn with_response_topic(mut self, topic: String) -> Self {
96 self.properties.set_response_topic(topic);
97 self
98 }
99
100 #[must_use]
102 pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
103 self.properties.set_correlation_data(data.into());
104 self
105 }
106
107 #[must_use]
109 pub fn with_user_property(mut self, key: String, value: String) -> Self {
110 self.properties.add_user_property(key, value);
111 self
112 }
113
114 #[must_use]
116 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
117 self.properties.set_subscription_identifier(id);
118 self
119 }
120
121 #[must_use]
123 pub fn with_content_type(mut self, content_type: String) -> Self {
124 self.properties.set_content_type(content_type);
125 self
126 }
127
128 #[must_use]
129 pub fn topic_alias(&self) -> Option<u16> {
131 self.properties
132 .get(PropertyId::TopicAlias)
133 .and_then(|prop| {
134 if let PropertyValue::TwoByteInteger(alias) = prop {
135 Some(*alias)
136 } else {
137 None
138 }
139 })
140 }
141
142 #[must_use]
143 pub fn message_expiry_interval(&self) -> Option<u32> {
145 self.properties
146 .get(PropertyId::MessageExpiryInterval)
147 .and_then(|prop| {
148 if let PropertyValue::FourByteInteger(interval) = prop {
149 Some(*interval)
150 } else {
151 None
152 }
153 })
154 }
155}
156
157impl MqttPacket for PublishPacket {
158 fn packet_type(&self) -> PacketType {
159 PacketType::Publish
160 }
161
162 fn flags(&self) -> u8 {
163 let mut flags = 0u8;
164
165 if self.dup {
166 flags |= PublishFlags::Dup as u8;
167 }
168
169 flags = PublishFlags::with_qos(flags, self.qos as u8);
170
171 if self.retain {
172 flags |= PublishFlags::Retain as u8;
173 }
174
175 flags
176 }
177
178 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
179 encode_string(buf, &self.topic_name)?;
181
182 if self.qos != QoS::AtMostOnce {
184 let packet_id = self.packet_id.ok_or_else(|| {
185 MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
186 })?;
187 buf.put_u16(packet_id);
188 }
189
190 self.properties.encode(buf)?;
193
194 buf.put_slice(&self.payload);
196
197 Ok(())
198 }
199
200 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
201 let flags = PublishFlags::decompose(fixed_header.flags);
203 let dup = flags.contains(&PublishFlags::Dup);
204 let qos_val = PublishFlags::extract_qos(fixed_header.flags);
205 let retain = flags.contains(&PublishFlags::Retain);
206
207 let qos = match qos_val {
208 0 => QoS::AtMostOnce,
209 1 => QoS::AtLeastOnce,
210 2 => QoS::ExactlyOnce,
211 _ => {
212 return Err(MqttError::InvalidQoS(qos_val));
213 }
214 };
215
216 let topic_name = decode_string(buf)?;
218
219 let packet_id = if qos == QoS::AtMostOnce {
221 None
222 } else {
223 if buf.remaining() < 2 {
224 return Err(MqttError::MalformedPacket(
225 "Missing packet identifier".to_string(),
226 ));
227 }
228 Some(buf.get_u16())
229 };
230
231 let properties = if buf.has_remaining() {
234 match Properties::decode(buf) {
235 Ok(props) => props,
236 Err(_) => {
237 return Err(MqttError::MalformedPacket(
240 "Failed to decode PUBLISH properties".to_string(),
241 ));
242 }
243 }
244 } else {
245 Properties::default()
247 };
248
249 let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
251
252 Ok(Self {
253 topic_name,
254 packet_id,
255 payload,
256 qos,
257 retain,
258 dup,
259 properties,
260 })
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `packet` in full, then decodes the fixed header and the body again.
    fn round_trip(packet: &PublishPacket) -> (FixedHeader, PublishPacket) {
        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        (fixed_header, decoded)
    }

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        // QoS 1 with RETAIN.
        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03);

        // QoS 2 with DUP.
        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C);

        // QoS 2 with both DUP and RETAIN.
        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D);
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let (fixed_header, decoded) = round_trip(&packet);

        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & PublishFlags::Retain as u8,
            PublishFlags::Retain as u8
        );

        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let (_, decoded) = round_trip(&packet);

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let (_, decoded) = round_trip(&packet);

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));

        // Checked unconditionally: a property that comes back as the wrong variant
        // must fail the test instead of silently skipping the value comparison.
        assert_eq!(decoded.message_expiry_interval(), Some(7200));

        match decoded.properties.get(PropertyId::ContentType) {
            Some(PropertyValue::Utf8String(val)) => assert_eq!(val, "application/json"),
            Some(_) => panic!("content type decoded as an unexpected property variant"),
            None => panic!("content type property missing after round trip"),
        }
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        // Flags 0x02 select QoS 1, yet the body stops right after the topic name.
        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap());

        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        // Flags 0x06 set both QoS bits, which is not a valid QoS level.
        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0);

        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(_))));
    }
}