// mqtt5_protocol/packet/connack.rs

1use crate::error::{MqttError, Result};
2use crate::flags::ConnAckFlags;
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
5use crate::types::ReasonCode;
6use bytes::{Buf, BufMut};
7
8/// MQTT CONNACK packet
9#[derive(Debug, Clone)]
10pub struct ConnAckPacket {
11    /// Session present flag
12    pub session_present: bool,
13    /// Connect reason code
14    pub reason_code: ReasonCode,
15    /// CONNACK properties (v5.0 only)
16    pub properties: Properties,
17    /// Protocol version (for encoding/decoding)
18    pub protocol_version: u8,
19}
20
21impl ConnAckPacket {
22    /// Creates a new CONNACK packet
23    #[must_use]
24    pub fn new(session_present: bool, reason_code: ReasonCode) -> Self {
25        Self {
26            session_present,
27            reason_code,
28            properties: Properties::default(),
29            protocol_version: 5,
30        }
31    }
32
33    /// Creates a new v3.1.1 CONNACK packet
34    #[must_use]
35    pub fn new_v311(session_present: bool, reason_code: ReasonCode) -> Self {
36        Self {
37            session_present,
38            reason_code,
39            properties: Properties::default(),
40            protocol_version: 4,
41        }
42    }
43
44    /// Sets the session expiry interval
45    #[must_use]
46    pub fn with_session_expiry_interval(mut self, interval: u32) -> Self {
47        self.properties.set_session_expiry_interval(interval);
48        self
49    }
50
51    /// Sets the receive maximum
52    #[must_use]
53    pub fn with_receive_maximum(mut self, max: u16) -> Self {
54        self.properties.set_receive_maximum(max);
55        self
56    }
57
58    /// Sets the maximum `QoS`
59    #[must_use]
60    pub fn with_maximum_qos(mut self, qos: u8) -> Self {
61        self.properties.set_maximum_qos(qos);
62        self
63    }
64
65    /// Sets whether retain is available
66    #[must_use]
67    pub fn with_retain_available(mut self, available: bool) -> Self {
68        self.properties.set_retain_available(available);
69        self
70    }
71
72    /// Sets the maximum packet size
73    #[must_use]
74    pub fn with_maximum_packet_size(mut self, size: u32) -> Self {
75        self.properties.set_maximum_packet_size(size);
76        self
77    }
78
79    /// Sets the assigned client identifier
80    #[must_use]
81    pub fn with_assigned_client_id(mut self, id: String) -> Self {
82        self.properties.set_assigned_client_identifier(id);
83        self
84    }
85
86    /// Sets the topic alias maximum
87    #[must_use]
88    pub fn with_topic_alias_maximum(mut self, max: u16) -> Self {
89        self.properties.set_topic_alias_maximum(max);
90        self
91    }
92
93    /// Sets the reason string
94    #[must_use]
95    pub fn with_reason_string(mut self, reason: String) -> Self {
96        self.properties.set_reason_string(reason);
97        self
98    }
99
100    /// Sets whether wildcards are available
101    #[must_use]
102    pub fn with_wildcard_subscription_available(mut self, available: bool) -> Self {
103        self.properties
104            .set_wildcard_subscription_available(available);
105        self
106    }
107
108    /// Sets whether subscription identifiers are available
109    #[must_use]
110    pub fn with_subscription_identifier_available(mut self, available: bool) -> Self {
111        self.properties
112            .set_subscription_identifier_available(available);
113        self
114    }
115
116    /// Sets whether shared subscriptions are available
117    #[must_use]
118    pub fn with_shared_subscription_available(mut self, available: bool) -> Self {
119        self.properties.set_shared_subscription_available(available);
120        self
121    }
122
123    /// Sets the server keep alive
124    #[must_use]
125    pub fn with_server_keep_alive(mut self, keep_alive: u16) -> Self {
126        self.properties.set_server_keep_alive(keep_alive);
127        self
128    }
129
130    /// Sets the response information
131    #[must_use]
132    pub fn with_response_information(mut self, info: String) -> Self {
133        self.properties.set_response_information(info);
134        self
135    }
136
137    /// Sets the server reference
138    #[must_use]
139    pub fn with_server_reference(mut self, reference: String) -> Self {
140        self.properties.set_server_reference(reference);
141        self
142    }
143
144    /// Sets the authentication method
145    #[must_use]
146    pub fn with_authentication_method(mut self, method: String) -> Self {
147        self.properties.set_authentication_method(method);
148        self
149    }
150
151    /// Sets the authentication data
152    #[must_use]
153    pub fn with_authentication_data(mut self, data: Vec<u8>) -> Self {
154        self.properties.set_authentication_data(data.into());
155        self
156    }
157
158    /// Adds a user property
159    #[must_use]
160    pub fn with_user_property(mut self, key: String, value: String) -> Self {
161        self.properties.add_user_property(key, value);
162        self
163    }
164
165    #[must_use]
166    /// Gets the topic alias maximum from properties
167    pub fn topic_alias_maximum(&self) -> Option<u16> {
168        self.properties
169            .get(PropertyId::TopicAliasMaximum)
170            .and_then(|prop| {
171                if let PropertyValue::TwoByteInteger(max) = prop {
172                    Some(*max)
173                } else {
174                    None
175                }
176            })
177    }
178
179    #[must_use]
180    /// Gets the receive maximum from properties
181    pub fn receive_maximum(&self) -> Option<u16> {
182        self.properties
183            .get(PropertyId::ReceiveMaximum)
184            .and_then(|prop| {
185                if let PropertyValue::TwoByteInteger(max) = prop {
186                    Some(*max)
187                } else {
188                    None
189                }
190            })
191    }
192
193    #[must_use]
194    /// Gets the maximum packet size from properties
195    pub fn maximum_packet_size(&self) -> Option<u32> {
196        self.properties
197            .get(PropertyId::MaximumPacketSize)
198            .and_then(|prop| {
199                if let PropertyValue::FourByteInteger(max) = prop {
200                    Some(*max)
201                } else {
202                    None
203                }
204            })
205    }
206
207    /// Validates the reason code for CONNACK
208    fn is_valid_connack_reason_code(code: ReasonCode) -> bool {
209        matches!(
210            code,
211            ReasonCode::Success
212                | ReasonCode::UnspecifiedError
213                | ReasonCode::MalformedPacket
214                | ReasonCode::ProtocolError
215                | ReasonCode::ImplementationSpecificError
216                | ReasonCode::UnsupportedProtocolVersion
217                | ReasonCode::ClientIdentifierNotValid
218                | ReasonCode::BadUsernameOrPassword
219                | ReasonCode::NotAuthorized
220                | ReasonCode::ServerUnavailable
221                | ReasonCode::ServerBusy
222                | ReasonCode::Banned
223                | ReasonCode::BadAuthenticationMethod
224                | ReasonCode::TopicNameInvalid
225                | ReasonCode::PacketTooLarge
226                | ReasonCode::QuotaExceeded
227                | ReasonCode::PayloadFormatInvalid
228                | ReasonCode::RetainNotSupported
229                | ReasonCode::QoSNotSupported
230                | ReasonCode::UseAnotherServer
231                | ReasonCode::ServerMoved
232                | ReasonCode::ConnectionRateExceeded
233        )
234    }
235}
236
237impl MqttPacket for ConnAckPacket {
238    fn packet_type(&self) -> PacketType {
239        PacketType::ConnAck
240    }
241
242    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
243        // Variable header
244        let flags = if self.session_present {
245            ConnAckFlags::SessionPresent as u8
246        } else {
247            0x00
248        };
249        buf.put_u8(flags);
250
251        if self.protocol_version == 4 {
252            // v3.1.1 - Return code only
253            let return_code = match self.reason_code {
254                ReasonCode::Success => 0x00,
255                ReasonCode::UnsupportedProtocolVersion => 0x01,
256                ReasonCode::ClientIdentifierNotValid => 0x02,
257                ReasonCode::ServerUnavailable => 0x03,
258                ReasonCode::BadUsernameOrPassword => 0x04,
259                _ => u8::from(ReasonCode::NotAuthorized), // Map other codes to not authorized
260            };
261            buf.put_u8(return_code);
262        } else {
263            // v5.0 - Reason code and properties
264            buf.put_u8(u8::from(self.reason_code));
265            self.properties.encode(buf)?;
266        }
267
268        Ok(())
269    }
270
271    fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
272        // Validate reserved bits - only bit 0 (session present) is valid
273        const RESERVED_BITS_MASK: u8 = 0xFE; // All bits except bit 0
274
275        // Acknowledge flags
276        if !buf.has_remaining() {
277            return Err(MqttError::MalformedPacket(
278                "Missing acknowledge flags".to_string(),
279            ));
280        }
281        let flags = buf.get_u8();
282
283        // Use BeBytes decomposition to parse flags
284        let decomposed_flags = ConnAckFlags::decompose(flags);
285        let session_present = decomposed_flags.contains(&ConnAckFlags::SessionPresent);
286
287        if (flags & RESERVED_BITS_MASK) != 0 {
288            return Err(MqttError::MalformedPacket(
289                "Invalid acknowledge flags - reserved bits must be 0".to_string(),
290            ));
291        }
292
293        // Reason code
294        if !buf.has_remaining() {
295            return Err(MqttError::MalformedPacket(
296                "Missing reason code".to_string(),
297            ));
298        }
299        let reason_byte = buf.get_u8();
300
301        // For v3.1.1, we need to determine protocol version from the reason code
302        let (reason_code, protocol_version) = if reason_byte <= 5 && !buf.has_remaining() {
303            // Likely v3.1.1 - map return codes to reason codes
304            let code = match reason_byte {
305                0x00 => ReasonCode::Success,
306                0x01 => ReasonCode::UnsupportedProtocolVersion,
307                0x02 => ReasonCode::ClientIdentifierNotValid,
308                0x03 => ReasonCode::ServerUnavailable,
309                0x04 => ReasonCode::BadUsernameOrPassword,
310                0x05 => ReasonCode::NotAuthorized,
311                _ => unreachable!(),
312            };
313            (code, 4)
314        } else {
315            // v5.0 - decode reason code
316            let code = ReasonCode::from_u8(reason_byte).ok_or_else(|| {
317                MqttError::MalformedPacket(format!("Invalid reason code: {reason_byte}"))
318            })?;
319
320            if !Self::is_valid_connack_reason_code(code) {
321                return Err(MqttError::MalformedPacket(format!(
322                    "Invalid CONNACK reason code: {code:?}"
323                )));
324            }
325
326            (code, 5)
327        };
328
329        // Properties (v5.0 only)
330        let properties = if protocol_version == 5 && buf.has_remaining() {
331            Properties::decode(buf)?
332        } else {
333            Properties::default()
334        };
335
336        Ok(Self {
337            session_present,
338            reason_code,
339            properties,
340            protocol_version,
341        })
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_connack_basic() {
        let packet = ConnAckPacket::new(true, ReasonCode::Success);

        assert!(packet.session_present);
        assert_eq!(packet.reason_code, ReasonCode::Success);
        assert_eq!(packet.protocol_version, 5);
    }

    #[test]
    fn test_connack_with_properties() {
        let packet = ConnAckPacket::new(false, ReasonCode::Success)
            .with_session_expiry_interval(3600)
            .with_receive_maximum(100)
            .with_maximum_qos(1)
            .with_retain_available(true)
            .with_assigned_client_id("auto-123".to_string())
            .with_user_property("foo".to_string(), "bar".to_string());

        assert!(!packet.session_present);
        assert!(packet
            .properties
            .get(PropertyId::SessionExpiryInterval)
            .is_some());
        assert!(packet.properties.get(PropertyId::ReceiveMaximum).is_some());
        assert!(packet.properties.get(PropertyId::MaximumQoS).is_some());
        assert!(packet.properties.get(PropertyId::RetainAvailable).is_some());
        assert!(packet
            .properties
            .get(PropertyId::AssignedClientIdentifier)
            .is_some());
        assert!(packet.properties.get(PropertyId::UserProperty).is_some());
    }

    #[test]
    fn test_connack_property_getters() {
        let packet = ConnAckPacket::new(false, ReasonCode::Success)
            .with_topic_alias_maximum(10)
            .with_receive_maximum(20)
            .with_maximum_packet_size(1024);

        assert_eq!(packet.topic_alias_maximum(), Some(10));
        assert_eq!(packet.receive_maximum(), Some(20));
        assert_eq!(packet.maximum_packet_size(), Some(1024));

        // Absent properties read back as None.
        let bare = ConnAckPacket::new(false, ReasonCode::Success);
        assert_eq!(bare.topic_alias_maximum(), None);
        assert_eq!(bare.receive_maximum(), None);
        assert_eq!(bare.maximum_packet_size(), None);
    }

    #[test]
    fn test_connack_encode_decode_v5() {
        let packet = ConnAckPacket::new(true, ReasonCode::Success)
            .with_session_expiry_interval(7200)
            .with_receive_maximum(50);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::ConnAck);

        let decoded = ConnAckPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert!(decoded.session_present);
        assert_eq!(decoded.reason_code, ReasonCode::Success);
        assert_eq!(decoded.protocol_version, 5);

        let session_expiry = decoded.properties.get(PropertyId::SessionExpiryInterval);
        assert!(session_expiry.is_some());
        if let Some(PropertyValue::FourByteInteger(val)) = session_expiry {
            assert_eq!(*val, 7200);
        } else {
            panic!("Wrong property type");
        }

        // Every property set on the original must survive the round trip.
        assert_eq!(decoded.receive_maximum(), Some(50));
    }

    #[test]
    fn test_connack_encode_decode_v311() {
        let packet = ConnAckPacket::new_v311(false, ReasonCode::BadUsernameOrPassword);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = ConnAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert!(!decoded.session_present);
        assert_eq!(decoded.reason_code, ReasonCode::BadUsernameOrPassword);
        assert_eq!(decoded.protocol_version, 4);
    }

    #[test]
    fn test_connack_encode_decode_v311_session_present() {
        let packet = ConnAckPacket::new_v311(true, ReasonCode::Success);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = ConnAckPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert!(decoded.session_present);
        assert_eq!(decoded.reason_code, ReasonCode::Success);
        assert_eq!(decoded.protocol_version, 4);
        assert!(decoded
            .properties
            .get(PropertyId::ReceiveMaximum)
            .is_none());
    }

    #[test]
    fn test_connack_error_codes() {
        let error_codes = vec![
            ReasonCode::UnspecifiedError,
            ReasonCode::MalformedPacket,
            ReasonCode::ProtocolError,
            ReasonCode::UnsupportedProtocolVersion,
            ReasonCode::ClientIdentifierNotValid,
            ReasonCode::BadUsernameOrPassword,
            ReasonCode::NotAuthorized,
            ReasonCode::ServerUnavailable,
            ReasonCode::ServerBusy,
            ReasonCode::Banned,
        ];

        for code in error_codes {
            let packet = ConnAckPacket::new(false, code);
            let mut buf = BytesMut::new();
            packet.encode(&mut buf).unwrap();

            let fixed_header = FixedHeader::decode(&mut buf).unwrap();
            let decoded = ConnAckPacket::decode_body(&mut buf, &fixed_header).unwrap();
            assert_eq!(decoded.reason_code, code);
        }
    }

    #[test]
    fn test_connack_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u8(0xFF); // Invalid flags - reserved bits set
        buf.put_u8(0x00); // Success code

        let fixed_header = FixedHeader::new(PacketType::ConnAck, 0, 0);
        let result = ConnAckPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_connack_truncated_body() {
        let fixed_header = FixedHeader::new(PacketType::ConnAck, 0, 0);

        // No acknowledge-flags byte at all.
        let mut empty = BytesMut::new();
        assert!(matches!(
            ConnAckPacket::decode_body(&mut empty, &fixed_header),
            Err(MqttError::MalformedPacket(_))
        ));

        // Flags present but the reason code is missing.
        let mut flags_only = BytesMut::new();
        flags_only.put_u8(0x00);
        assert!(matches!(
            ConnAckPacket::decode_body(&mut flags_only, &fixed_header),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_connack_rejects_non_connack_reason_code() {
        // Trailing property-length byte forces the v5.0 path; 0x11 is not a
        // CONNACK reason code.
        let mut buf = BytesMut::new();
        buf.put_u8(0x00);
        buf.put_u8(0x11);
        buf.put_u8(0x00);

        let fixed_header = FixedHeader::new(PacketType::ConnAck, 0, 0);
        assert!(matches!(
            ConnAckPacket::decode_body(&mut buf, &fixed_header),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_connack_valid_reason_codes() {
        assert!(ConnAckPacket::is_valid_connack_reason_code(
            ReasonCode::Success
        ));
        assert!(ConnAckPacket::is_valid_connack_reason_code(
            ReasonCode::NotAuthorized
        ));
        assert!(ConnAckPacket::is_valid_connack_reason_code(
            ReasonCode::ServerBusy
        ));

        // Invalid CONNACK reason codes
        assert!(!ConnAckPacket::is_valid_connack_reason_code(
            ReasonCode::NoSubscriptionExisted
        ));
        assert!(!ConnAckPacket::is_valid_connack_reason_code(
            ReasonCode::SubscriptionIdentifiersNotSupported
        ));
    }
}