1use std::convert::{TryFrom, TryInto};
2
3use bytes::{BufMut, Bytes, BytesMut};
4
5use crate::MqttString;
6
7use super::{
8    len_len, length, read_mqtt_string, read_u32, read_u8, write_mqtt_string,
9    write_remaining_length, Buf, Debug, Error, FixedHeader, PacketType,
10};
11
12use super::{property, PropertyType};
13
/// Reason code carried in the variable header of an MQTT 5 DISCONNECT packet.
///
/// The enum is `#[repr(u8)]` with explicit discriminants: each value is the byte that
/// appears on the wire. Encoding casts the variant with `as u8`; decoding goes through the
/// `TryFrom<u8>` impl for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReasonCode {
    // Codes below 0x80.
    /// 0x00. Also the code assumed when a received DISCONNECT has a remaining length of 0.
    NormalDisconnection = 0x00,
    /// 0x04. Per the MQTT 5 spec: disconnect normally but still have the Will Message
    /// published.
    DisconnectWithWillMessage = 0x04,
    // Codes 0x80 and above; per the MQTT 5 spec these all indicate an error.
    UnspecifiedError = 0x80,
    MalformedPacket = 0x81,
    ProtocolError = 0x82,
    ImplementationSpecificError = 0x83,
    NotAuthorized = 0x87,
    ServerBusy = 0x89,
    ServerShuttingDown = 0x8B,
    KeepAliveTimeout = 0x8D,
    SessionTakenOver = 0x8E,
    TopicFilterInvalid = 0x8F,
    TopicNameInvalid = 0x90,
    ReceiveMaximumExceeded = 0x93,
    TopicAliasInvalid = 0x94,
    PacketTooLarge = 0x95,
    MessageRateTooHigh = 0x96,
    QuotaExceeded = 0x97,
    AdministrativeAction = 0x98,
    PayloadFormatInvalid = 0x99,
    RetainNotSupported = 0x9A,
    QoSNotSupported = 0x9B,
    /// 0x9C. Per the MQTT 5 spec, typically paired with the Server Reference property.
    UseAnotherServer = 0x9C,
    /// 0x9D. Per the MQTT 5 spec, typically paired with the Server Reference property.
    ServerMoved = 0x9D,
    SharedSubscriptionNotSupported = 0x9E,
    ConnectionRateExceeded = 0x9F,
    MaximumConnectTime = 0xA0,
    SubscriptionIdentifiersNotSupported = 0xA1,
    WildcardSubscriptionsNotSupported = 0xA2,
}
76
77impl TryFrom<u8> for DisconnectReasonCode {
78    type Error = Error;
79
80    fn try_from(value: u8) -> Result<Self, Self::Error> {
81        let rc = match value {
82            0x00 => Self::NormalDisconnection,
83            0x04 => Self::DisconnectWithWillMessage,
84            0x80 => Self::UnspecifiedError,
85            0x81 => Self::MalformedPacket,
86            0x82 => Self::ProtocolError,
87            0x83 => Self::ImplementationSpecificError,
88            0x87 => Self::NotAuthorized,
89            0x89 => Self::ServerBusy,
90            0x8B => Self::ServerShuttingDown,
91            0x8D => Self::KeepAliveTimeout,
92            0x8E => Self::SessionTakenOver,
93            0x8F => Self::TopicFilterInvalid,
94            0x90 => Self::TopicNameInvalid,
95            0x93 => Self::ReceiveMaximumExceeded,
96            0x94 => Self::TopicAliasInvalid,
97            0x95 => Self::PacketTooLarge,
98            0x96 => Self::MessageRateTooHigh,
99            0x97 => Self::QuotaExceeded,
100            0x98 => Self::AdministrativeAction,
101            0x99 => Self::PayloadFormatInvalid,
102            0x9A => Self::RetainNotSupported,
103            0x9B => Self::QoSNotSupported,
104            0x9C => Self::UseAnotherServer,
105            0x9D => Self::ServerMoved,
106            0x9E => Self::SharedSubscriptionNotSupported,
107            0x9F => Self::ConnectionRateExceeded,
108            0xA0 => Self::MaximumConnectTime,
109            0xA1 => Self::SubscriptionIdentifiersNotSupported,
110            0xA2 => Self::WildcardSubscriptionsNotSupported,
111            other => return Err(Error::InvalidConnectReturnCode(other)),
112        };
113
114        Ok(rc)
115    }
116}
117
/// Optional properties of an MQTT 5 DISCONNECT packet.
///
/// Only the four property types below are modelled; `DisconnectProperties::extract`
/// rejects any other property identifier with `Error::InvalidPropertyType`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DisconnectProperties {
    /// Session Expiry Interval (identifier 0x11), encoded as a 4-byte integer.
    /// The MQTT 5 spec defines it in seconds.
    pub session_expiry_interval: Option<u32>,

    /// Reason String (identifier 0x1F), a length-prefixed string; per the MQTT 5 spec
    /// it is diagnostic text.
    pub reason_string: Option<MqttString>,

    /// User Property key/value pairs (identifier 0x26), kept in the order they were
    /// decoded; the same key may appear more than once.
    pub user_properties: Vec<(MqttString, MqttString)>,

    /// Server Reference (identifier 0x1C), a length-prefixed string; per the MQTT 5
    /// spec it names another server the client may use.
    pub server_reference: Option<MqttString>,
}
132
/// An MQTT 5 DISCONNECT packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Disconnect {
    /// Why the connection is being closed. Written as a single byte after the
    /// remaining-length field, except when `write` emits the 2-byte short form.
    pub reason_code: DisconnectReasonCode,

    /// Optional properties. Decoding yields `None` both when no property block is
    /// read and when the property block has length 0.
    pub properties: Option<DisconnectProperties>,
}
141
142impl DisconnectProperties {
143    fn len(&self) -> usize {
144        let mut length = 0;
145
146        if self.session_expiry_interval.is_some() {
147            length += 1 + 4;
148        }
149
150        if let Some(reason) = &self.reason_string {
151            length += 1 + 2 + reason.len();
152        }
153
154        for (key, value) in &self.user_properties {
155            length += 1 + 2 + key.len() + 2 + value.len();
156        }
157
158        if let Some(server_reference) = &self.server_reference {
159            length += 1 + 2 + server_reference.len();
160        }
161
162        length
163    }
164
165    pub fn extract(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
166        let (properties_len_len, properties_len) = length(bytes.iter())?;
167
168        bytes.advance(properties_len_len);
169
170        if properties_len == 0 {
171            return Ok(None);
172        }
173
174        let mut session_expiry_interval = None;
175        let mut reason_string = None;
176        let mut user_properties = Vec::new();
177        let mut server_reference = None;
178
179        let mut cursor = 0;
180
181        while cursor < properties_len {
183            let prop = read_u8(bytes)?;
184            cursor += 1;
185
186            match property(prop)? {
187                PropertyType::SessionExpiryInterval => {
188                    session_expiry_interval = Some(read_u32(bytes)?);
189                    cursor += 4;
190                }
191                PropertyType::ReasonString => {
192                    let reason = read_mqtt_string(bytes)?;
193                    cursor += 2 + reason.len();
194                    reason_string = Some(reason);
195                }
196                PropertyType::UserProperty => {
197                    let key = read_mqtt_string(bytes)?;
198                    let value = read_mqtt_string(bytes)?;
199                    cursor += 2 + key.len() + 2 + value.len();
200                    user_properties.push((key, value));
201                }
202                PropertyType::ServerReference => {
203                    let reference = read_mqtt_string(bytes)?;
204                    cursor += 2 + reference.len();
205                    server_reference = Some(reference);
206                }
207                _ => return Err(Error::InvalidPropertyType(prop)),
208            }
209        }
210
211        let properties = Self {
212            session_expiry_interval,
213            reason_string,
214            user_properties,
215            server_reference,
216        };
217
218        Ok(Some(properties))
219    }
220
221    fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
222        let length = self.len();
223        write_remaining_length(buffer, length)?;
224
225        if let Some(session_expiry_interval) = self.session_expiry_interval {
226            buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
227            buffer.put_u32(session_expiry_interval);
228        }
229
230        if let Some(reason) = &self.reason_string {
231            buffer.put_u8(PropertyType::ReasonString as u8);
232            write_mqtt_string(buffer, reason)?;
233        }
234
235        for (key, value) in &self.user_properties {
236            buffer.put_u8(PropertyType::UserProperty as u8);
237            write_mqtt_string(buffer, key)?;
238            write_mqtt_string(buffer, value)?;
239        }
240
241        if let Some(reference) = &self.server_reference {
242            buffer.put_u8(PropertyType::ServerReference as u8);
243            write_mqtt_string(buffer, reference)?;
244        }
245
246        Ok(())
247    }
248}
249
250impl Disconnect {
251    #[must_use]
252    pub fn new(reason: DisconnectReasonCode) -> Self {
253        Self {
254            reason_code: reason,
255            properties: None,
256        }
257    }
258
259    fn len(&self) -> usize {
260        if self.reason_code == DisconnectReasonCode::NormalDisconnection
261            && self.properties.is_none()
262        {
263            return 2; }
265
266        let mut length = 0;
267
268        if let Some(properties) = &self.properties {
269            length += 1; let properties_len = properties.len();
272            let properties_len_len = len_len(properties_len);
273            length += properties_len_len + properties_len;
274        } else {
275            length += 1;
276        }
277
278        length
279    }
280
281    #[must_use]
282    pub fn size(&self) -> usize {
283        let len = self.len();
284        if len == 2 {
285            return len;
286        }
287
288        let remaining_len_size = len_len(len);
289
290        1 + remaining_len_size + len
291    }
292
293    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
294        let packet_type = fixed_header.byte1 >> 4;
295        let flags = fixed_header.byte1 & 0b0000_1111;
296
297        bytes.advance(fixed_header.fixed_header_len);
298
299        if packet_type != PacketType::Disconnect as u8 {
300            return Err(Error::InvalidPacketType(packet_type));
301        };
302
303        if flags != 0x00 {
304            return Err(Error::MalformedPacket);
305        };
306
307        if fixed_header.remaining_len == 0 {
308            return Ok(Self::new(DisconnectReasonCode::NormalDisconnection));
309        }
310
311        let reason_code = read_u8(&mut bytes)?;
312
313        let disconnect = Self {
314            reason_code: reason_code.try_into()?,
315            properties: DisconnectProperties::extract(&mut bytes)?,
316        };
317
318        Ok(disconnect)
319    }
320
321    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
322        buffer.put_u8(0xE0);
323
324        let length = self.len();
325
326        if length == 2 {
327            buffer.put_u8(0x00);
328            return Ok(length);
329        }
330
331        let len_len = write_remaining_length(buffer, length)?;
332
333        buffer.put_u8(self.reason_code as u8);
334
335        if let Some(properties) = &self.properties {
336            properties.write(buffer)?;
337        } else {
338            write_remaining_length(buffer, 0)?;
339        }
340
341        Ok(1 + len_len + length)
342    }
343}
344
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::{Disconnect, DisconnectProperties, DisconnectReasonCode};
    use crate::parse_fixed_header;
    use crate::test::read_write_packets;
    use crate::Packet;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// Wire bytes of the minimal (short-form) normal disconnect.
    const SAMPLE_BYTES1: [u8; 2] = [
        0xE0, // DISCONNECT, flags 0
        0x00, // remaining length 0: reason code and properties omitted
    ];

    /// Decodes one DISCONNECT frame from `raw` the way the stream reader does: parse the
    /// fixed header, split off exactly one frame, then hand it to `Disconnect::read`.
    fn decode(raw: &[u8]) -> Disconnect {
        let mut buffer = BytesMut::new();
        buffer.extend_from_slice(raw);

        let header = parse_fixed_header(buffer.iter()).unwrap();
        let frame = buffer.split_to(header.frame_length()).freeze();
        Disconnect::read(header, frame).unwrap()
    }

    /// Encodes `packet` into a fresh buffer.
    fn encode(packet: &Disconnect) -> BytesMut {
        let mut buffer = BytesMut::new();
        packet.write(&mut buffer).unwrap();
        buffer
    }

    #[test]
    fn disconnect1_parsing_works() {
        let expected = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
        assert_eq!(decode(&SAMPLE_BYTES1), expected);
    }

    #[test]
    fn disconnect1_encoding_works() {
        let disconnect = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
        let encoded = encode(&disconnect);
        assert_eq!(&encoded[..], &SAMPLE_BYTES1);
    }

    /// A DISCONNECT carrying every supported property.
    fn sample2() -> Disconnect {
        Disconnect {
            reason_code: DisconnectReasonCode::UnspecifiedError,
            properties: Some(DisconnectProperties {
                session_expiry_interval: Some(1234),
                reason_string: Some("test".into()),
                user_properties: vec![("test".into(), "test".into())],
                server_reference: Some("test".into()),
            }),
        }
    }

    /// Wire encoding of [`sample2`].
    fn sample_bytes2() -> Vec<u8> {
        vec![
            0xE0, // DISCONNECT, flags 0
            0x22, // remaining length = 34
            0x80, // reason code: unspecified error
            0x20, // property length = 32
            // session expiry interval = 1234
            0x11, 0x00, 0x00, 0x04, 0xd2,
            // reason string = "test"
            0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74,
            // user property: key "test", value "test"
            0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74,
            0x00, 0x04, 0x74, 0x65, 0x73, 0x74,
            // server reference = "test"
            0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74,
        ]
    }

    #[test]
    fn disconnect2_parsing_works() {
        let packet_bytes = sample_bytes2();
        assert_eq!(decode(&packet_bytes), sample2());
    }

    #[test]
    fn disconnect2_encoding_works() {
        let encoded = encode(&sample2());
        let expected = sample_bytes2();
        assert_eq!(&encoded[..], &expected);
    }

    /// `size()`, the count returned by `write()`, and the bytes actually written must agree.
    #[test]
    fn length_calculation() {
        let properties = DisconnectProperties {
            session_expiry_interval: None,
            reason_string: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            server_reference: None,
        };
        let packet = Disconnect {
            properties: Some(properties),
            ..Disconnect::new(DisconnectReasonCode::NormalDisconnection)
        };

        let mut written = BytesMut::new();
        let reported_size = packet.size();
        let reported_written = packet.write(&mut written).unwrap();

        assert_eq!(reported_written, written.len());
        assert_eq!(reported_size, written.len());
    }

    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    /// Packets that must survive an encode/decode round trip unchanged.
    fn write_read_provider() -> Vec<Packet> {
        vec![
            Packet::Disconnect(Disconnect::new(DisconnectReasonCode::NormalDisconnection)),
            Packet::Disconnect(sample2()),
        ]
    }
}