Skip to main content

rumqttc/mqttbytes/v4/
connack.rs

1use super::{Error, FixedHeader, len_len, read_u8, write_remaining_length};
2use bytes::{Buf, BufMut, Bytes, BytesMut};
3
/// Return code carried in a CONNACK packet.
///
/// The enum is `#[repr(u8)]` and each discriminant is the byte that appears on
/// the wire: `ConnAck::write` emits it with `as u8`, and `connect_return` maps
/// the byte back to a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnectReturnCode {
    /// Wire value 0: the connection was accepted.
    Success = 0,
    /// Wire value 1: refused because of the protocol version.
    RefusedProtocolVersion = 1,
    /// Wire value 2: refused because of the client identifier.
    BadClientId = 2,
    /// Wire value 3: refused because the service is unavailable.
    ServiceUnavailable = 3,
    /// Wire value 4: refused because of the user name or password.
    BadUserNamePassword = 4,
    /// Wire value 5: refused because the client is not authorized.
    NotAuthorized = 5,
}
15
16/// Acknowledgement to connect packet
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct ConnAck {
19    pub session_present: bool,
20    pub code: ConnectReturnCode,
21}
22
23impl ConnAck {
24    #[must_use]
25    pub const fn new(code: ConnectReturnCode, session_present: bool) -> Self {
26        Self {
27            session_present,
28            code,
29        }
30    }
31
32    const fn len() -> usize {
33        // sesssion present + code
34
35        1 + 1
36    }
37
38    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
39        if fixed_header.remaining_len != Self::len() {
40            return Err(Error::PayloadSizeIncorrect);
41        }
42
43        let variable_header_index = fixed_header.header_len;
44        bytes.advance(variable_header_index);
45
46        let flags = read_u8(&mut bytes)?;
47        let return_code = read_u8(&mut bytes)?;
48
49        if (flags & 0xFE) != 0 {
50            return Err(Error::IncorrectPacketFormat);
51        }
52
53        let session_present = (flags & 0x01) == 1;
54        let code = connect_return(return_code)?;
55        if code != ConnectReturnCode::Success && session_present {
56            return Err(Error::IncorrectPacketFormat);
57        }
58
59        let connack = Self {
60            session_present,
61            code,
62        };
63
64        Ok(connack)
65    }
66
67    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
68        let len = Self::len();
69        buffer.put_u8(0x20);
70
71        let count = write_remaining_length(buffer, len)?;
72        buffer.put_u8(u8::from(self.session_present));
73        buffer.put_u8(self.code as u8);
74
75        Ok(1 + count + len)
76    }
77
78    #[must_use]
79    pub const fn size(&self) -> usize {
80        let len = Self::len();
81        let remaining_len_size = len_len(len);
82
83        1 + remaining_len_size + len
84    }
85}
86
87/// Connection return code type
88const fn connect_return(num: u8) -> Result<ConnectReturnCode, Error> {
89    match num {
90        0 => Ok(ConnectReturnCode::Success),
91        1 => Ok(ConnectReturnCode::RefusedProtocolVersion),
92        2 => Ok(ConnectReturnCode::BadClientId),
93        3 => Ok(ConnectReturnCode::ServiceUnavailable),
94        4 => Ok(ConnectReturnCode::BadUserNamePassword),
95        5 => Ok(ConnectReturnCode::NotAuthorized),
96        num => Err(Error::InvalidConnectReturnCode(num)),
97    }
98}
99
#[cfg(test)]
mod test {
    use super::*;
    use crate::mqttbytes::parse_fixed_header;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// Parses the fixed header of `packet`, splits off exactly one frame and
    /// decodes it with `ConnAck::read`. Bytes after the frame are ignored.
    fn parse(packet: &[u8]) -> Result<ConnAck, Error> {
        let mut stream = BytesMut::from(packet);
        let fixed_header = parse_fixed_header(stream.iter()).unwrap();
        let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze();
        ConnAck::read(fixed_header, connack_bytes)
    }

    #[test]
    fn connack_parsing_works() {
        let packetstream = &[
            0b0010_0000,
            0x02, // packet type, flags and remaining len
            0x01,
            0x00, // variable header. connack flags, connect return code
            0xDE,
            0xAD,
            0xBE,
            0xEF, // extra packets in the stream
        ];

        let connack = parse(packetstream).unwrap();

        assert_eq!(
            connack,
            ConnAck {
                session_present: true,
                code: ConnectReturnCode::Success,
            }
        );
    }

    #[test]
    fn connack_encoding_works() {
        let connack = ConnAck {
            session_present: true,
            code: ConnectReturnCode::Success,
        };

        let mut buf = BytesMut::new();
        connack.write(&mut buf).unwrap();
        assert_eq!(buf, vec![0b0010_0000, 0x02, 0x01, 0x00]);
    }

    #[test]
    fn connack_parsing_rejects_invalid_remaining_len() {
        let connack = parse(&[0b0010_0000, 0x03, 0x00, 0x00, 0x00]);
        assert!(matches!(connack, Err(Error::PayloadSizeIncorrect)));
    }

    #[test]
    fn connack_parsing_rejects_reserved_flag_bits() {
        let connack = parse(&[0b0010_0000, 0x02, 0x02, 0x00]);
        assert!(matches!(connack, Err(Error::IncorrectPacketFormat)));
    }

    #[test]
    fn connack_parsing_rejects_session_present_on_error() {
        let connack = parse(&[0b0010_0000, 0x02, 0x01, 0x01]);
        assert!(matches!(connack, Err(Error::IncorrectPacketFormat)));
    }

    #[test]
    fn connack_parsing_rejects_unknown_return_code() {
        let connack = parse(&[0b0010_0000, 0x02, 0x00, 0x06]);
        assert!(matches!(connack, Err(Error::InvalidConnectReturnCode(6))));
    }

    #[test]
    fn connack_round_trips_through_write_and_read() {
        let cases = [
            ConnAck::new(ConnectReturnCode::Success, false),
            ConnAck::new(ConnectReturnCode::Success, true),
            ConnAck::new(ConnectReturnCode::NotAuthorized, false),
        ];

        for original in &cases {
            let mut buf = BytesMut::new();
            let written = original.write(&mut buf).unwrap();
            assert_eq!(written, buf.len());
            assert_eq!(written, original.size());
            assert_eq!(&parse(&buf).unwrap(), original);
        }
    }
}