1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::PublishFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
/// An MQTT PUBLISH packet.
///
/// The same type is used for MQTT v5.0 (`protocol_version == 5`, see `new`) and
/// v3.1.1 (`protocol_version == 4`, see `new_v311`); the version controls whether
/// the property block is encoded and decoded.
#[derive(Debug, Clone)]
pub struct PublishPacket {
    /// Topic the message is published to; written first in the variable header.
    pub topic_name: String,
    /// Packet identifier. `None` for QoS 0 as built by the constructors; it must be
    /// `Some` for QoS 1 and 2 or `encode_body` returns an error.
    pub packet_id: Option<u16>,
    /// Application payload, written verbatim after the variable header.
    pub payload: Vec<u8>,
    /// Delivery QoS, carried in the fixed-header flags.
    pub qos: QoS,
    /// RETAIN flag from the fixed header.
    pub retain: bool,
    /// DUP flag from the fixed header.
    pub dup: bool,
    /// MQTT v5.0 properties; only encoded/decoded when `protocol_version == 5`.
    pub properties: Properties,
    /// Protocol level: 5 for MQTT v5.0, 4 for MQTT v3.1.1 (as set by the constructors).
    pub protocol_version: u8,
}
30
31impl PublishPacket {
32 #[must_use]
34 pub fn new(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
35 let packet_id = if qos == QoS::AtMostOnce {
36 None
37 } else {
38 Some(0)
39 };
40
41 Self {
42 topic_name: topic_name.into(),
43 packet_id,
44 payload: payload.into(),
45 qos,
46 retain: false,
47 dup: false,
48 properties: Properties::default(),
49 protocol_version: 5,
50 }
51 }
52
53 #[must_use]
55 pub fn new_v311(topic_name: impl Into<String>, payload: impl Into<Vec<u8>>, qos: QoS) -> Self {
56 let packet_id = if qos == QoS::AtMostOnce {
57 None
58 } else {
59 Some(0)
60 };
61
62 Self {
63 topic_name: topic_name.into(),
64 packet_id,
65 payload: payload.into(),
66 qos,
67 retain: false,
68 dup: false,
69 properties: Properties::default(),
70 protocol_version: 4,
71 }
72 }
73
74 #[must_use]
76 pub fn with_packet_id(mut self, id: u16) -> Self {
77 if self.qos != QoS::AtMostOnce {
78 self.packet_id = Some(id);
79 }
80 self
81 }
82
83 #[must_use]
85 pub fn with_retain(mut self, retain: bool) -> Self {
86 self.retain = retain;
87 self
88 }
89
90 #[must_use]
92 pub fn with_dup(mut self, dup: bool) -> Self {
93 self.dup = dup;
94 self
95 }
96
97 #[must_use]
99 pub fn with_payload_format_indicator(mut self, is_utf8: bool) -> Self {
100 self.properties.set_payload_format_indicator(is_utf8);
101 self
102 }
103
104 #[must_use]
106 pub fn with_message_expiry_interval(mut self, seconds: u32) -> Self {
107 self.properties.set_message_expiry_interval(seconds);
108 self
109 }
110
111 #[must_use]
113 pub fn with_topic_alias(mut self, alias: u16) -> Self {
114 self.properties.set_topic_alias(alias);
115 self
116 }
117
118 #[must_use]
120 pub fn with_response_topic(mut self, topic: String) -> Self {
121 self.properties.set_response_topic(topic);
122 self
123 }
124
125 #[must_use]
127 pub fn with_correlation_data(mut self, data: Vec<u8>) -> Self {
128 self.properties.set_correlation_data(data.into());
129 self
130 }
131
132 #[must_use]
134 pub fn with_user_property(mut self, key: String, value: String) -> Self {
135 self.properties.add_user_property(key, value);
136 self
137 }
138
139 #[must_use]
141 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
142 self.properties.set_subscription_identifier(id);
143 self
144 }
145
146 #[must_use]
148 pub fn with_content_type(mut self, content_type: String) -> Self {
149 self.properties.set_content_type(content_type);
150 self
151 }
152
153 #[must_use]
154 pub fn topic_alias(&self) -> Option<u16> {
156 self.properties
157 .get(PropertyId::TopicAlias)
158 .and_then(|prop| {
159 if let PropertyValue::TwoByteInteger(alias) = prop {
160 Some(*alias)
161 } else {
162 None
163 }
164 })
165 }
166
167 #[must_use]
168 pub fn message_expiry_interval(&self) -> Option<u32> {
170 self.properties
171 .get(PropertyId::MessageExpiryInterval)
172 .and_then(|prop| {
173 if let PropertyValue::FourByteInteger(interval) = prop {
174 Some(*interval)
175 } else {
176 None
177 }
178 })
179 }
180}
181
182impl MqttPacket for PublishPacket {
183 fn packet_type(&self) -> PacketType {
184 PacketType::Publish
185 }
186
187 fn flags(&self) -> u8 {
188 let mut flags = 0u8;
189
190 if self.dup {
191 flags |= PublishFlags::Dup as u8;
192 }
193
194 flags = PublishFlags::with_qos(flags, self.qos as u8);
195
196 if self.retain {
197 flags |= PublishFlags::Retain as u8;
198 }
199
200 flags
201 }
202
203 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
204 encode_string(buf, &self.topic_name)?;
205
206 if self.qos != QoS::AtMostOnce {
207 let packet_id = self.packet_id.ok_or_else(|| {
208 MqttError::MalformedPacket("Packet ID required for QoS > 0".to_string())
209 })?;
210 buf.put_u16(packet_id);
211 }
212
213 if self.protocol_version == 5 {
214 self.properties.encode(buf)?;
215 }
216
217 buf.put_slice(&self.payload);
218
219 Ok(())
220 }
221
222 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
223 Self::decode_body_with_version(buf, fixed_header, 5)
224 }
225}
226
227impl PublishPacket {
228 pub fn decode_body_with_version<B: Buf>(
234 buf: &mut B,
235 fixed_header: &FixedHeader,
236 protocol_version: u8,
237 ) -> Result<Self> {
238 ProtocolVersion::try_from(protocol_version)
239 .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
240
241 let flags = PublishFlags::decompose(fixed_header.flags);
242 let dup = flags.contains(&PublishFlags::Dup);
243 let qos_val = PublishFlags::extract_qos(fixed_header.flags);
244 let retain = flags.contains(&PublishFlags::Retain);
245
246 let qos = match qos_val {
247 0 => QoS::AtMostOnce,
248 1 => QoS::AtLeastOnce,
249 2 => QoS::ExactlyOnce,
250 _ => {
251 return Err(MqttError::InvalidQoS(qos_val));
252 }
253 };
254
255 let topic_name = decode_string(buf)?;
256
257 let packet_id = if qos == QoS::AtMostOnce {
258 None
259 } else {
260 if buf.remaining() < 2 {
261 return Err(MqttError::MalformedPacket(
262 "Missing packet identifier".to_string(),
263 ));
264 }
265 Some(buf.get_u16())
266 };
267
268 let properties = if protocol_version == 5 {
269 Properties::decode(buf)?
270 } else {
271 Properties::default()
272 };
273
274 let payload = buf.copy_to_bytes(buf.remaining()).to_vec();
275
276 Ok(Self {
277 topic_name,
278 packet_id,
279 payload,
280 qos,
281 retain,
282 dup,
283 properties,
284 protocol_version,
285 })
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `packet`, checks the fixed-header packet type, and decodes the
    /// body back using `protocol_version`.
    fn round_trip(packet: &PublishPacket, protocol_version: u8) -> PublishPacket {
        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);

        PublishPacket::decode_body_with_version(&mut buf, &fixed_header, protocol_version)
            .unwrap()
    }

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello, MQTT!", QoS::AtMostOnce);

        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(packet.payload, b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", b"Hello", QoS::AtLeastOnce).with_packet_id(123);

        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_with_packet_id_ignored_for_qos0() {
        let packet = PublishPacket::new("test/topic", b"Hello", QoS::AtMostOnce).with_packet_id(5);
        assert!(packet.packet_id.is_none());
    }

    #[test]
    fn test_publish_constructor_versions() {
        let v5 = PublishPacket::new("t", b"x", QoS::AtLeastOnce);
        assert_eq!(v5.protocol_version, 5);
        assert_eq!(v5.packet_id, Some(0));

        let v311 = PublishPacket::new_v311("t", b"x", QoS::AtLeastOnce);
        assert_eq!(v311.protocol_version, 4);
        assert_eq!(v311.packet_id, Some(0));
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", b"data", QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());

        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_publish_accessors() {
        let plain = PublishPacket::new("t", b"x", QoS::AtMostOnce);
        assert_eq!(plain.topic_alias(), None);
        assert_eq!(plain.message_expiry_interval(), None);

        let packet = plain
            .with_topic_alias(7)
            .with_message_expiry_interval(60);
        assert_eq!(packet.topic_alias(), Some(7));
        assert_eq!(packet.message_expiry_interval(), Some(60));
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", b"data", QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", b"data", QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03);

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C);

        let packet = PublishPacket::new("topic", b"data", QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D);
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet =
            PublishPacket::new("sensor/temperature", b"23.5", QoS::AtMostOnce).with_retain(true);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & PublishFlags::Retain as u8,
            PublishFlags::Retain as u8
        );

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(decoded.payload, b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet =
            PublishPacket::new("test/qos1", b"QoS 1 message", QoS::AtLeastOnce).with_packet_id(456);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(decoded.payload, b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", b"data", QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());

        let decoded = round_trip(&packet, 5);

        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));
        assert_eq!(decoded.payload, b"data");

        // Fail loudly if the property is missing OR stored as an unexpected
        // variant, instead of silently skipping the value check.
        assert_eq!(decoded.message_expiry_interval(), Some(7200));
        assert!(matches!(
            decoded.properties.get(PropertyId::ContentType),
            Some(PropertyValue::Utf8String(val)) if val == "application/json"
        ));
    }

    #[test]
    fn test_publish_flags_round_trip() {
        let packet = PublishPacket::new("flags/topic", b"x", QoS::ExactlyOnce)
            .with_packet_id(7)
            .with_dup(true)
            .with_retain(true);
        let decoded = round_trip(&packet, 5);
        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert!(decoded.dup);
        assert!(decoded.retain);

        let packet = PublishPacket::new("flags/topic", b"x", QoS::AtLeastOnce).with_packet_id(8);
        let decoded = round_trip(&packet, 5);
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert!(!decoded.dup);
        assert!(!decoded.retain);
    }

    #[test]
    fn test_publish_empty_payload_round_trip() {
        let packet = PublishPacket::new("empty/payload", Vec::<u8>::new(), QoS::AtMostOnce);
        let decoded = round_trip(&packet, 5);
        assert_eq!(decoded.topic_name, "empty/payload");
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_publish_encode_decode_v311() {
        let packet = PublishPacket::new_v311("legacy/topic", b"v3 payload", QoS::AtLeastOnce)
            .with_packet_id(42)
            .with_message_expiry_interval(30);

        let decoded = round_trip(&packet, 4);

        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.topic_name, "legacy/topic");
        assert_eq!(decoded.packet_id, Some(42));
        assert_eq!(decoded.payload, b"v3 payload");
        // v3.1.1 has no property block, so the property is not transmitted.
        assert!(!decoded
            .properties
            .contains(PropertyId::MessageExpiryInterval));
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap());

        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0);
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(_))));
    }

    #[test]
    fn test_publish_unsupported_protocol_version() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x00, u32::try_from(buf.len()).unwrap());

        let result = PublishPacket::decode_body_with_version(&mut buf, &fixed_header, 99);
        assert!(matches!(result, Err(MqttError::UnsupportedProtocolVersion)));
    }
}