1use crate::error::{MqttError, Result};
2use crate::packet::{AckPacketHeader, FixedHeader, MqttPacket, PacketType};
3use crate::protocol::v5::properties::Properties;
4use crate::types::ReasonCode;
5use bytes::{Buf, BufMut};
6
/// An MQTT PUBACK packet.
///
/// Build one with [`PubAckPacket::new`] or [`PubAckPacket::new_with_reason`] and the
/// `with_*` builders, or obtain one by decoding a PUBACK body.
#[derive(Debug, Clone)]
pub struct PubAckPacket {
    /// Packet identifier; written as a big-endian `u16` at the start of the packet body.
    pub packet_id: u16,
    /// Outcome reported by this acknowledgement. Decoding rejects codes outside the PUBACK
    /// allow-list, and a body that omits the reason byte decodes as `ReasonCode::Success`.
    pub reason_code: ReasonCode,
    /// Properties (e.g. reason string, user properties) sent after the reason code.
    /// When this is empty and the reason is `Success`, the encoder omits both.
    pub properties: Properties,
}
17
18impl PubAckPacket {
19 #[must_use]
21 pub fn new(packet_id: u16) -> Self {
22 Self {
23 packet_id,
24 reason_code: ReasonCode::Success,
25 properties: Properties::default(),
26 }
27 }
28
29 #[must_use]
31 pub fn new_with_reason(packet_id: u16, reason_code: ReasonCode) -> Self {
32 Self {
33 packet_id,
34 reason_code,
35 properties: Properties::default(),
36 }
37 }
38
39 #[must_use]
41 pub fn with_reason_string(mut self, reason: String) -> Self {
42 self.properties.set_reason_string(reason);
43 self
44 }
45
46 #[must_use]
48 pub fn with_user_property(mut self, key: String, value: String) -> Self {
49 self.properties.add_user_property(key, value);
50 self
51 }
52
53 fn is_valid_puback_reason_code(code: ReasonCode) -> bool {
55 matches!(
56 code,
57 ReasonCode::Success
58 | ReasonCode::NoMatchingSubscribers
59 | ReasonCode::UnspecifiedError
60 | ReasonCode::ImplementationSpecificError
61 | ReasonCode::NotAuthorized
62 | ReasonCode::TopicNameInvalid
63 | ReasonCode::PacketIdentifierInUse
64 | ReasonCode::QuotaExceeded
65 | ReasonCode::PayloadFormatInvalid
66 )
67 }
68
69 #[must_use]
71 pub fn create_header(&self) -> AckPacketHeader {
72 AckPacketHeader::create(self.packet_id, self.reason_code)
73 }
74
75 pub fn from_header(header: AckPacketHeader, properties: Properties) -> Result<Self> {
81 let reason_code = header.get_reason_code().ok_or_else(|| {
82 MqttError::MalformedPacket(format!(
83 "Invalid PUBACK reason code: 0x{:02X}",
84 header.reason_code
85 ))
86 })?;
87
88 if !Self::is_valid_puback_reason_code(reason_code) {
89 return Err(MqttError::MalformedPacket(format!(
90 "Invalid PUBACK reason code: {reason_code:?}"
91 )));
92 }
93
94 Ok(Self {
95 packet_id: header.packet_id,
96 reason_code,
97 properties,
98 })
99 }
100}
101
102impl MqttPacket for PubAckPacket {
103 fn packet_type(&self) -> PacketType {
104 PacketType::PubAck
105 }
106
107 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
108 buf.put_u16(self.packet_id);
110
111 if self.reason_code != ReasonCode::Success || !self.properties.is_empty() {
114 buf.put_u8(u8::from(self.reason_code));
115 self.properties.encode(buf)?;
116 }
117
118 Ok(())
119 }
120
121 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
122 tracing::trace!(
123 fixed_header_remaining = fixed_header.remaining_length,
124 buf_remaining = buf.remaining(),
125 "PUBACK decode started"
126 );
127
128 if buf.remaining() < 2 {
130 return Err(MqttError::MalformedPacket(
131 "PUBACK missing packet identifier".to_string(),
132 ));
133 }
134 let packet_id = buf.get_u16();
135
136 let (reason_code, properties) = if buf.has_remaining() {
139 let reason_byte = buf.get_u8();
141 let code = ReasonCode::from_u8(reason_byte).ok_or_else(|| {
142 MqttError::MalformedPacket(format!(
143 "Invalid PUBACK reason code: {reason_byte} (0x{reason_byte:02X})"
144 ))
145 })?;
146
147 if !Self::is_valid_puback_reason_code(code) {
148 return Err(MqttError::MalformedPacket(format!(
149 "Invalid PUBACK reason code: {code:?}"
150 )));
151 }
152
153 let props = if buf.has_remaining() {
155 Properties::decode(buf)?
156 } else {
157 Properties::default()
158 };
159
160 (code, props)
161 } else {
162 (ReasonCode::Success, Properties::default())
164 };
165
166 Ok(Self {
167 packet_id,
168 reason_code,
169 properties,
170 })
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bytes::BytesMut;

    // Tests for the `bebytes`-backed `AckPacketHeader` and its conversion into a PUBACK.
    // The parent module is already `#[cfg(test)]`, so no extra gate is needed here.
    mod bebytes_tests {
        use super::*;
        use bebytes::BeBytes;
        use proptest::prelude::*;

        #[test]
        fn test_ack_header_creation() {
            let header = AckPacketHeader::create(123, ReasonCode::Success);
            assert_eq!(header.packet_id, 123);
            assert_eq!(header.reason_code, 0x00);
            assert_eq!(header.get_reason_code(), Some(ReasonCode::Success));
        }

        #[test]
        fn test_ack_header_round_trip() {
            let header = AckPacketHeader::create(456, ReasonCode::QuotaExceeded);
            let bytes = header.to_be_bytes();
            // Two bytes of packet identifier plus one reason-code byte.
            assert_eq!(bytes.len(), 3);

            let (decoded, consumed) = AckPacketHeader::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 3);
            assert_eq!(decoded, header);
            assert_eq!(decoded.packet_id, 456);
            assert_eq!(decoded.get_reason_code(), Some(ReasonCode::QuotaExceeded));
        }

        #[test]
        fn test_puback_from_header() {
            let header = AckPacketHeader::create(789, ReasonCode::NoMatchingSubscribers);
            let properties = Properties::default();

            let packet = PubAckPacket::from_header(header, properties).unwrap();
            assert_eq!(packet.packet_id, 789);
            assert_eq!(packet.reason_code, ReasonCode::NoMatchingSubscribers);
        }

        #[test]
        fn test_puback_from_header_rejects_non_puback_reason() {
            // 0x81 is outside the PUBACK allow-list, whether or not ReasonCode recognises it.
            let header = AckPacketHeader {
                packet_id: 1,
                reason_code: 0x81,
            };
            let result = PubAckPacket::from_header(header, Properties::default());
            assert!(result.is_err());
        }

        proptest! {
            #[test]
            fn prop_ack_header_round_trip(
                packet_id in any::<u16>(),
                reason_code in 0u8..=255u8
            ) {
                let header = AckPacketHeader {
                    packet_id,
                    reason_code,
                };

                let bytes = header.to_be_bytes();
                let (decoded, consumed) = AckPacketHeader::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 3);
                prop_assert_eq!(decoded, header);
                prop_assert_eq!(decoded.packet_id, packet_id);
                prop_assert_eq!(decoded.reason_code, reason_code);
            }
        }
    }

    #[test]
    fn test_puback_basic() {
        let packet = PubAckPacket::new(123);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.reason_code, ReasonCode::Success);
        assert!(packet.properties.is_empty());
    }

    #[test]
    fn test_puback_with_reason() {
        let packet = PubAckPacket::new_with_reason(456, ReasonCode::NoMatchingSubscribers)
            .with_reason_string("No subscribers for topic".to_string());

        assert_eq!(packet.packet_id, 456);
        assert_eq!(packet.reason_code, ReasonCode::NoMatchingSubscribers);
        assert!(packet.properties.contains(PropertyId::ReasonString));
    }

    #[test]
    fn test_puback_encode_decode_minimal() {
        let packet = PubAckPacket::new(789);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::PubAck);

        let decoded = PubAckPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.reason_code, ReasonCode::Success);
    }

    #[test]
    fn test_puback_encode_decode_with_reason() {
        let packet = PubAckPacket::new_with_reason(999, ReasonCode::QuotaExceeded)
            .with_user_property("quota".to_string(), "exceeded".to_string());

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PubAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.packet_id, 999);
        assert_eq!(decoded.reason_code, ReasonCode::QuotaExceeded);
        assert!(decoded.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_puback_encode_decode_reason_without_properties() {
        // A non-success reason with no properties must still survive a round trip.
        let packet = PubAckPacket::new_with_reason(42, ReasonCode::NoMatchingSubscribers);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PubAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.packet_id, 42);
        assert_eq!(decoded.reason_code, ReasonCode::NoMatchingSubscribers);
        assert!(decoded.properties.is_empty());
    }

    #[test]
    fn test_puback_v311_style() {
        // An MQTT 3.1.1-style PUBACK carries only the packet identifier.
        let mut buf = BytesMut::new();
        buf.put_u16(1234);

        let fixed_header = FixedHeader::new(PacketType::PubAck, 0, 2);
        let decoded = PubAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.packet_id, 1234);
        assert_eq!(decoded.reason_code, ReasonCode::Success);
        assert!(decoded.properties.is_empty());
    }

    #[test]
    fn test_puback_reason_code_without_property_length() {
        // Remaining length 3: a reason code is present but the property length is omitted.
        let mut buf = BytesMut::new();
        buf.put_u16(7);
        buf.put_u8(u8::from(ReasonCode::NoMatchingSubscribers));

        let fixed_header = FixedHeader::new(PacketType::PubAck, 0, 3);
        let decoded = PubAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.packet_id, 7);
        assert_eq!(decoded.reason_code, ReasonCode::NoMatchingSubscribers);
        assert!(decoded.properties.is_empty());
    }

    #[test]
    fn test_puback_invalid_reason_code() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);
        buf.put_u8(0xFF);

        let fixed_header = FixedHeader::new(PacketType::PubAck, 0, 3);
        let result = PubAckPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_puback_rejects_reason_code_not_allowed_for_puback() {
        // 0x81 is outside the PUBACK allow-list, so decoding must fail either way.
        let mut buf = BytesMut::new();
        buf.put_u16(123);
        buf.put_u8(0x81);

        let fixed_header = FixedHeader::new(PacketType::PubAck, 0, 3);
        let result = PubAckPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_puback_missing_packet_id() {
        // One byte is not enough for the two-byte packet identifier.
        let mut buf = BytesMut::new();
        buf.put_u8(0);

        let fixed_header = FixedHeader::new(PacketType::PubAck, 0, 1);
        let result = PubAckPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }
}