1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
/// An MQTT PUBLISH packet: fixed-header flags, variable header and payload.
///
/// The same type serves MQTT v5.0 and v3.1.1; `protocol_version` decides
/// whether `properties` are written/read on the wire.
#[derive(Debug, Clone)]
pub struct PublishPacket {
    /// Topic name written as a length-prefixed string at the start of the
    /// variable header.
    pub topic_name: String,
    /// Packet identifier. `None` for QoS 0. Encoding a QoS 1/2 packet whose
    /// identifier is `None` fails with `MqttError::MalformedPacket`.
    pub packet_id: Option<u16>,
    /// Application message: every byte that follows the variable header.
    pub payload: Vec<u8>,
    /// Delivery QoS, carried in the fixed-header flags.
    pub qos: QoS,
    /// RETAIN flag, carried in the fixed-header flags.
    pub retain: bool,
    /// DUP flag, carried in the fixed-header flags.
    pub dup: bool,
    /// MQTT v5 properties; only encoded/decoded when `protocol_version == 5`.
    pub properties: Properties,
    /// Protocol level: `5` when built with `new`, `4` when built with `new_v311`.
    pub protocol_version: u8,
}
30
31impl PublishPacket {
32 #[must_use]
34 pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
35 let packet_id = if qos == QoS::AtMostOnce {
36 None
37 } else {
38 Some(0)
39 };
40
41 Self {
42 topic_name: topic_name.into(),
43 packet_id,
44 payload: payload.into(),
45 qos,
46 retain: false,
47 dup: false,
48 properties: Properties::default(),
49 protocol_version: 5,
50 }
51 }
52
53 #[must_use]
55 pub fn new_v311(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
56 let packet_id = if qos == QoS::AtMostOnce {
57 None
58 } else {
59 Some(0)
60 };
61
62 Self {
63 topic_name: topic_name.into(),
64 packet_id,
65 payload: payload.into(),
66 qos,
67 retain: false,
68 dup: false,
69 properties: Properties::default(),
70 protocol_version: 4,
71 }
72 }
73
74 #[must_use]
76 pub fn with_packet_id(mut self, id: u16) -> Self {
77 if self.qos != QoS::AtMostOnce {
78 self.packet_id = Some(id);
79 }
80 self
81 }
82
83 #[must_use]
85 pub fn with_retain(mut self, retain: bool) -> Self {
86 self.retain = retain;
87 self
88 }
89
90 #[must_use]
92 pub fn with_dup(mut self, dup: bool) -> Self {
93 self.dup = dup;
94 self
95 }
96
97 #[must_use]
99 pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
100 self.properties.set_payload_format_indicator(is_utf8);
101 self
102 }
103
104 #[must_use]
106 pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
107 self.properties.set_message_expiry_interval(seconds);
108 self
109 }
110
111 #[must_use]
113 pub fn with_topic_alias(mut self, alias: u16) -> Self {
114 self.properties.set_topic_alias(alias);
115 self
116 }
117
118 #[must_use]
120 pub fn with_response_topic(mut self, topic: String) -> Self {
121 self.properties.set_response_topic(topic);
122 self
123 }
124
125 #[must_use]
127 pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
128 self.properties.set_correlation_data(data.into());
129 self
130 }
131
132 #[must_use]
134 pub fn with_user_property(mut self, key: String, value: String) -> Self {
135 self.properties.add_user_property(key, value);
136 self
137 }
138
139 #[must_use]
141 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
142 self.properties.set_subscription_identifier(id);
143 self
144 }
145
146 #[must_use]
148 pub fn with_content_type(mut self, content_type: String) -> Self {
149 self.properties.set_content_type(content_type);
150 self
151 }
152
153 #[must_use]
154 pub fn topic_alias(&self) -> Option<u16> {
156 self.properties
157 .get(PropertyId::TopicAlias)
158 .and_then(|prop| {
159 if let PropertyValue::TwoByteInteger(alias) = prop {
160 Some(*alias)
161 } else {
162 None
163 }
164 })
165 }
166
167 #[must_use]
168 pub fn message_expiry_interval(&self) -> Option<u32> {
170 self.properties
171 .get(PropertyId::MessageExpiryInterval)
172 .and_then(|prop| {
173 if let PropertyValue::FourByteInteger(interval) = prop {
174 Some(*interval)
175 } else {
176 None
177 }
178 })
179 }
180
181 #[must_use]
182 pub fn body_encoded_size(&self) -> usize {
183 let mut size = 2 + self.topic_name.len();
184
185 if self.qos != QoS::AtMostOnce {
186 size += 2;
187 }
188
189 if self.protocol_version == 5 {
190 size += self.properties.encoded_len();
191 }
192
193 size += self.payload.len();
194 size
195 }
196
197 pub fn encode_body_direct<B: BufMut>(&self, buf: &mut B) -> Result<()> {
200 encode_string(buf, &self.topic_name)?;
201
202 if self.qos != QoS::AtMostOnce {
203 let packet_id = self.packet_id.ok_or_else(|| {
204 MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
205 })?;
206 buf.put_u16(packet_id);
207 }
208
209 if self.protocol_version == 5 {
210 self.properties.encode_direct(buf)?;
211 }
212
213 buf.put_slice(&self.payload);
214
215 Ok(())
216 }
217}
218
219impl MqttPacket for PublishPacket {
220 fn packet_type(&self) -> PacketType {
221 PacketType::Publish
222 }
223
224 fn flags(&self) -> u8 {
225 let mut flags = 0u8;
226
227 if self.dup {
228 flags |= PublishFlags::Dup as u8;
229 }
230
231 flags = PublishFlags::with_qos(flags, self.qos as u8);
232
233 if self.retain {
234 flags |= PublishFlags::Retain as u8;
235 }
236
237 flags
238 }
239
240 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
241 encode_string(buf, &self.topic_name)?;
242
243 if self.qos != QoS::AtMostOnce {
244 let packet_id = self.packet_id.ok_or_else(|| {
245 MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
246 })?;
247 buf.put_u16(packet_id);
248 }
249
250 if self.protocol_version == 5 {
251 self.properties.encode(buf)?;
252 }
253
254 buf.put_slice(&self.payload);
255
256 Ok(())
257 }
258
259 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
260 Self::decode_body_with_version(buf, fixed_header, 5)
261 }
262}
263
264impl PublishPacket {
265 pub fn decode_body_with_version<B: Buf>(
271 buf: &mut B,
272 fixed_header: &FixedHeader,
273 protocol_version: u8,
274 ) -> Result<Self> {
275 ProtocolVersion::try_from(protocol_version)
276 .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
277
278 let flags = PublishFlags::decompose(fixed_header.flags);
279 let dup = flags.contains(&PublishFlags::Dup);
280 let qos_val = PublishFlags::extract_qos(fixed_header.flags);
281 let retain = flags.contains(&PublishFlags::Retain);
282
283 let qos = match qos_val {
284 0 => QoS::AtMostOnce,
285 1 => QoS::AtLeastOnce,
286 2 => QoS::ExactlyOnce,
287 _ => {
288 return Err(MqttError::InvalidQoS(qos_val));
289 }
290 };
291
292 let topic_name = decode_string(buf)?;
293
294 let packet_id = if qos == QoS::AtMostOnce {
295 None
296 } else {
297 if buf.remaining() < 2 {
298 return Err(MqttError::MalformedPacket(
299 "Missing packet identifier".to_string(),
300 ));
301 }
302 Some(buf.get_u16())
303 };
304
305 let properties = if protocol_version == 5 {
306 Properties::decode(buf)?
307 } else {
308 Properties::default()
309 };
310
311 let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
312
313 Ok(Self {
314 topic_name,
315 packet_id,
316 payload,
317 qos,
318 retain,
319 dup,
320 properties,
321 protocol_version,
322 })
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_with_packet_id_ignored_for_qos0() {
        // QoS 0 packets never carry a packet identifier.
        let packet = PublishPacket::new("test/topic", b"Hello", QoS::AtMostOnce).with_packet_id(42);

        assert!(packet.packet_id.is_none());
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_property_accessors() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce)
            .with_topic_alias(3)
            .with_message_expiry_interval(120);
        assert_eq!(packet.topic_alias(), Some(3));
        assert_eq!(packet.message_expiry_interval(), Some(120));

        let bare = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(bare.topic_alias(), None);
        assert_eq!(bare.message_expiry_interval(), None);
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03);

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C);

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D);
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & PublishFlags::Retain as u8,
            PublishFlags::Retain as u8
        );

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));
        assert_eq!(decoded.payload, b"data");

        // Assert the exact variant and value. The old `if let` form silently
        // skipped the check when the variant differed.
        assert_eq!(decoded.message_expiry_interval(), Some(7200));
        assert!(matches!(
            decoded.properties.get(PropertyId::MessageExpiryInterval),
            Some(PropertyValue::FourByteInteger(7200))
        ));
        assert!(matches!(
            decoded.properties.get(PropertyId::ContentType),
            Some(PropertyValue::Utf8String(val)) if val == "application/json"
        ));
    }

    #[test]
    fn test_publish_encode_decode_v311_omits_properties() {
        // v3.1.1 packets never put properties on the wire.
        let packet = PublishPacket::new_v311("legacy/topic", b"v311", QoS::AtLeastOnce)
            .with_packet_id(7)
            .with_message_expiry_interval(60);
        assert_eq!(packet.protocol_version, 4);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded =
            PublishPacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();

        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.topic_name, "legacy/topic");
        assert_eq!(decoded.packet_id, Some(7));
        assert_eq!(decoded.payload, b"v311");
        assert_eq!(decoded.message_expiry_interval(), None);
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        // QoS 1 flags, but the body ends right after the topic.
        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap());
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        // 0x06 sets both QoS bits, i.e. the reserved QoS value 3.
        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0);
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }
}
479}