1use std::convert::{TryFrom, TryInto};
2
3use bytes::{BufMut, Bytes, BytesMut};
4
5use crate::MqttString;
6
7use super::{
8 len_len, length, read_mqtt_string, read_u32, read_u8, write_mqtt_string,
9 write_remaining_length, Buf, Debug, Error, FixedHeader, PacketType,
10};
11
12use super::{property, PropertyType};
13
/// Reason code carried by an MQTT 5 DISCONNECT packet.
///
/// The explicit discriminants together with `#[repr(u8)]` make `code as u8` the exact
/// byte written on the wire when encoding; decoding goes through `TryFrom<u8>`.
/// Per the MQTT 5 specification, values below 0x80 are non-error codes and values of
/// 0x80 and above indicate an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReasonCode {
    /// Close the connection normally; do not send the Will Message.
    NormalDisconnection = 0x00,
    /// The client wants to disconnect but requires the server to publish its Will Message.
    DisconnectWithWillMessage = 0x04,
    /// The sender does not wish to reveal the reason, or no other code applies.
    UnspecifiedError = 0x80,
    /// The received packet does not conform to the specification.
    MalformedPacket = 0x81,
    /// An unexpected or out-of-order packet was received.
    ProtocolError = 0x82,
    /// The packet is valid but cannot be processed by this implementation.
    ImplementationSpecificError = 0x83,
    /// The request is not authorized.
    NotAuthorized = 0x87,
    /// The server is busy and cannot continue processing requests from this client.
    ServerBusy = 0x89,
    /// The server is shutting down.
    ServerShuttingDown = 0x8B,
    /// No packet was received within the keep-alive deadline.
    KeepAliveTimeout = 0x8D,
    /// Another connection using the same client identifier has connected.
    SessionTakenOver = 0x8E,
    /// The topic filter is well formed but not accepted by the server.
    TopicFilterInvalid = 0x8F,
    /// The topic name is well formed but not accepted.
    TopicNameInvalid = 0x90,
    /// More unacknowledged publications were received than Receive Maximum allows.
    ReceiveMaximumExceeded = 0x93,
    /// A PUBLISH used a topic alias above the advertised Topic Alias Maximum.
    TopicAliasInvalid = 0x94,
    /// The packet exceeds the Maximum Packet Size.
    PacketTooLarge = 0x95,
    /// The received data rate is too high.
    MessageRateTooHigh = 0x96,
    /// An implementation or administratively imposed limit has been exceeded.
    QuotaExceeded = 0x97,
    /// The connection is closed due to an administrative action.
    AdministrativeAction = 0x98,
    /// The payload does not match the Payload Format Indicator.
    PayloadFormatInvalid = 0x99,
    /// The server does not support retained messages.
    RetainNotSupported = 0x9A,
    /// The requested QoS exceeds the Maximum QoS in the CONNACK.
    QoSNotSupported = 0x9B,
    /// The client should temporarily use another server.
    UseAnotherServer = 0x9C,
    /// The server has moved; the client should permanently change server.
    ServerMoved = 0x9D,
    /// The server does not support shared subscriptions.
    SharedSubscriptionNotSupported = 0x9E,
    /// The connection rate is too high.
    ConnectionRateExceeded = 0x9F,
    /// The maximum connection time for this connection has been exceeded.
    MaximumConnectTime = 0xA0,
    /// The server does not support subscription identifiers.
    SubscriptionIdentifiersNotSupported = 0xA1,
    /// The server does not support wildcard subscriptions.
    WildcardSubscriptionsNotSupported = 0xA2,
}
76
77impl TryFrom<u8> for DisconnectReasonCode {
78 type Error = Error;
79
80 fn try_from(value: u8) -> Result<Self, Self::Error> {
81 let rc = match value {
82 0x00 => Self::NormalDisconnection,
83 0x04 => Self::DisconnectWithWillMessage,
84 0x80 => Self::UnspecifiedError,
85 0x81 => Self::MalformedPacket,
86 0x82 => Self::ProtocolError,
87 0x83 => Self::ImplementationSpecificError,
88 0x87 => Self::NotAuthorized,
89 0x89 => Self::ServerBusy,
90 0x8B => Self::ServerShuttingDown,
91 0x8D => Self::KeepAliveTimeout,
92 0x8E => Self::SessionTakenOver,
93 0x8F => Self::TopicFilterInvalid,
94 0x90 => Self::TopicNameInvalid,
95 0x93 => Self::ReceiveMaximumExceeded,
96 0x94 => Self::TopicAliasInvalid,
97 0x95 => Self::PacketTooLarge,
98 0x96 => Self::MessageRateTooHigh,
99 0x97 => Self::QuotaExceeded,
100 0x98 => Self::AdministrativeAction,
101 0x99 => Self::PayloadFormatInvalid,
102 0x9A => Self::RetainNotSupported,
103 0x9B => Self::QoSNotSupported,
104 0x9C => Self::UseAnotherServer,
105 0x9D => Self::ServerMoved,
106 0x9E => Self::SharedSubscriptionNotSupported,
107 0x9F => Self::ConnectionRateExceeded,
108 0xA0 => Self::MaximumConnectTime,
109 0xA1 => Self::SubscriptionIdentifiersNotSupported,
110 0xA2 => Self::WildcardSubscriptionsNotSupported,
111 other => return Err(Error::InvalidConnectReturnCode(other)),
112 };
113
114 Ok(rc)
115 }
116}
117
/// Optional properties of an MQTT 5 DISCONNECT packet.
///
/// `DisconnectProperties::extract` accepts only the four property types represented
/// here. Any other known property identifier is rejected with
/// `Error::InvalidPropertyType`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DisconnectProperties {
    /// Session Expiry Interval property, encoded as a 4-byte big-endian integer
    /// (seconds, per the MQTT 5 specification).
    pub session_expiry_interval: Option<u32>,

    /// Reason String property: a length-prefixed string describing the disconnect.
    pub reason_string: Option<MqttString>,

    /// User Property key/value pairs, kept in wire order; the property may repeat.
    pub user_properties: Vec<(MqttString, MqttString)>,

    /// Server Reference property: a length-prefixed string.
    pub server_reference: Option<MqttString>,
}
132
/// An MQTT 5 DISCONNECT packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Disconnect {
    /// Why the connection is being closed. A packet whose remaining length is 0
    /// decodes as `NormalDisconnection`.
    pub reason_code: DisconnectReasonCode,

    /// Optional properties. Decoding yields `None` when the encoded property length
    /// is 0.
    pub properties: Option<DisconnectProperties>,
}
141
142impl DisconnectProperties {
143 fn len(&self) -> usize {
144 let mut length = 0;
145
146 if self.session_expiry_interval.is_some() {
147 length += 1 + 4;
148 }
149
150 if let Some(reason) = &self.reason_string {
151 length += 1 + 2 + reason.len();
152 }
153
154 for (key, value) in &self.user_properties {
155 length += 1 + 2 + key.len() + 2 + value.len();
156 }
157
158 if let Some(server_reference) = &self.server_reference {
159 length += 1 + 2 + server_reference.len();
160 }
161
162 length
163 }
164
165 pub fn extract(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
166 let (properties_len_len, properties_len) = length(bytes.iter())?;
167
168 bytes.advance(properties_len_len);
169
170 if properties_len == 0 {
171 return Ok(None);
172 }
173
174 let mut session_expiry_interval = None;
175 let mut reason_string = None;
176 let mut user_properties = Vec::new();
177 let mut server_reference = None;
178
179 let mut cursor = 0;
180
181 while cursor < properties_len {
183 let prop = read_u8(bytes)?;
184 cursor += 1;
185
186 match property(prop)? {
187 PropertyType::SessionExpiryInterval => {
188 session_expiry_interval = Some(read_u32(bytes)?);
189 cursor += 4;
190 }
191 PropertyType::ReasonString => {
192 let reason = read_mqtt_string(bytes)?;
193 cursor += 2 + reason.len();
194 reason_string = Some(reason);
195 }
196 PropertyType::UserProperty => {
197 let key = read_mqtt_string(bytes)?;
198 let value = read_mqtt_string(bytes)?;
199 cursor += 2 + key.len() + 2 + value.len();
200 user_properties.push((key, value));
201 }
202 PropertyType::ServerReference => {
203 let reference = read_mqtt_string(bytes)?;
204 cursor += 2 + reference.len();
205 server_reference = Some(reference);
206 }
207 _ => return Err(Error::InvalidPropertyType(prop)),
208 }
209 }
210
211 let properties = Self {
212 session_expiry_interval,
213 reason_string,
214 user_properties,
215 server_reference,
216 };
217
218 Ok(Some(properties))
219 }
220
221 fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
222 let length = self.len();
223 write_remaining_length(buffer, length)?;
224
225 if let Some(session_expiry_interval) = self.session_expiry_interval {
226 buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
227 buffer.put_u32(session_expiry_interval);
228 }
229
230 if let Some(reason) = &self.reason_string {
231 buffer.put_u8(PropertyType::ReasonString as u8);
232 write_mqtt_string(buffer, reason)?;
233 }
234
235 for (key, value) in &self.user_properties {
236 buffer.put_u8(PropertyType::UserProperty as u8);
237 write_mqtt_string(buffer, key)?;
238 write_mqtt_string(buffer, value)?;
239 }
240
241 if let Some(reference) = &self.server_reference {
242 buffer.put_u8(PropertyType::ServerReference as u8);
243 write_mqtt_string(buffer, reference)?;
244 }
245
246 Ok(())
247 }
248}
249
250impl Disconnect {
251 #[must_use]
252 pub fn new(reason: DisconnectReasonCode) -> Self {
253 Self {
254 reason_code: reason,
255 properties: None,
256 }
257 }
258
259 fn len(&self) -> usize {
260 if self.reason_code == DisconnectReasonCode::NormalDisconnection
261 && self.properties.is_none()
262 {
263 return 2; }
265
266 let mut length = 0;
267
268 if let Some(properties) = &self.properties {
269 length += 1; let properties_len = properties.len();
272 let properties_len_len = len_len(properties_len);
273 length += properties_len_len + properties_len;
274 } else {
275 length += 1;
276 }
277
278 length
279 }
280
281 #[must_use]
282 pub fn size(&self) -> usize {
283 let len = self.len();
284 if len == 2 {
285 return len;
286 }
287
288 let remaining_len_size = len_len(len);
289
290 1 + remaining_len_size + len
291 }
292
293 pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
294 let packet_type = fixed_header.byte1 >> 4;
295 let flags = fixed_header.byte1 & 0b0000_1111;
296
297 bytes.advance(fixed_header.fixed_header_len);
298
299 if packet_type != PacketType::Disconnect as u8 {
300 return Err(Error::InvalidPacketType(packet_type));
301 };
302
303 if flags != 0x00 {
304 return Err(Error::MalformedPacket);
305 };
306
307 if fixed_header.remaining_len == 0 {
308 return Ok(Self::new(DisconnectReasonCode::NormalDisconnection));
309 }
310
311 let reason_code = read_u8(&mut bytes)?;
312
313 let disconnect = Self {
314 reason_code: reason_code.try_into()?,
315 properties: DisconnectProperties::extract(&mut bytes)?,
316 };
317
318 Ok(disconnect)
319 }
320
321 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
322 buffer.put_u8(0xE0);
323
324 let length = self.len();
325
326 if length == 2 {
327 buffer.put_u8(0x00);
328 return Ok(length);
329 }
330
331 let len_len = write_remaining_length(buffer, length)?;
332
333 buffer.put_u8(self.reason_code as u8);
334
335 if let Some(properties) = &self.properties {
336 properties.write(buffer)?;
337 } else {
338 write_remaining_length(buffer, 0)?;
339 }
340
341 Ok(1 + len_len + length)
342 }
343}
344
345#[cfg(test)]
346mod test {
347 use super::{Disconnect, DisconnectProperties, DisconnectReasonCode};
348 use crate::parse_fixed_header;
349 use crate::test::read_write_packets;
350 use crate::Packet;
351 use bytes::BytesMut;
352
353 #[test]
354 fn disconnect1_parsing_works() {
355 let mut buffer = bytes::BytesMut::new();
356 let packet_bytes = [
357 0xE0, 0x00, ];
360 let expected = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
361
362 buffer.extend_from_slice(&packet_bytes[..]);
363
364 let fixed_header = parse_fixed_header(buffer.iter()).unwrap();
365 let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze();
366 let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap();
367
368 assert_eq!(disconnect, expected);
369 }
370
371 #[test]
372 fn disconnect1_encoding_works() {
373 let mut buffer = BytesMut::new();
374 let disconnect = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
375 let expected = [
376 0xE0, 0x00, ];
379
380 disconnect.write(&mut buffer).unwrap();
381
382 assert_eq!(&buffer[..], &expected);
383 }
384
385 fn sample2() -> Disconnect {
386 let properties = DisconnectProperties {
387 session_expiry_interval: Some(1234),
389 reason_string: Some("test".into()),
390 user_properties: vec![("test".into(), "test".into())],
391 server_reference: Some("test".into()),
392 };
393
394 Disconnect {
395 reason_code: DisconnectReasonCode::UnspecifiedError,
396 properties: Some(properties),
397 }
398 }
399
400 fn sample_bytes2() -> Vec<u8> {
401 vec![
402 0xE0, 0x22, 0x80, 0x20, 0x11, 0x00, 0x00, 0x04, 0xd2, 0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73,
409 0x74, 0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, ]
412 }
413
414 #[test]
415 fn disconnect2_parsing_works() {
416 let mut buffer = bytes::BytesMut::new();
417 let packet_bytes = sample_bytes2();
418 let expected = sample2();
419
420 buffer.extend_from_slice(&packet_bytes[..]);
421
422 let fixed_header = parse_fixed_header(buffer.iter()).unwrap();
423 let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze();
424 let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap();
425
426 assert_eq!(disconnect, expected);
427 }
428
429 #[test]
430 fn disconnect2_encoding_works() {
431 let mut buffer = BytesMut::new();
432
433 let disconnect = sample2();
434 let expected = sample_bytes2();
435
436 disconnect.write(&mut buffer).unwrap();
437
438 assert_eq!(&buffer[..], &expected);
439 }
440
441 use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
443 use pretty_assertions::assert_eq;
445
446 #[test]
447 fn length_calculation() {
448 let mut dummy_bytes = BytesMut::new();
449 let disconn_props = DisconnectProperties {
452 session_expiry_interval: None,
453 reason_string: None,
454 user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
455 server_reference: None,
456 };
457
458 let mut disconn_pkt = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
459 disconn_pkt.properties = Some(disconn_props);
460
461 let size_from_size = disconn_pkt.size();
462 let size_from_write = disconn_pkt.write(&mut dummy_bytes).unwrap();
463 let size_from_bytes = dummy_bytes.len();
464
465 assert_eq!(size_from_write, size_from_bytes);
466 assert_eq!(size_from_size, size_from_bytes);
467 }
468
469 #[test]
470 fn test_write_read() {
471 read_write_packets(write_read_provider());
472 }
473
474 fn write_read_provider() -> Vec<Packet> {
475 vec![
476 Packet::Disconnect(Disconnect::new(DisconnectReasonCode::NormalDisconnection)),
477 Packet::Disconnect(Disconnect {
478 reason_code: DisconnectReasonCode::UnspecifiedError,
479 properties: Some(DisconnectProperties {
480 session_expiry_interval: Some(1234),
481 reason_string: Some("test".into()),
482 user_properties: vec![("test".into(), "test".into())],
483 server_reference: Some("test".into()),
484 }),
485 }),
486 ]
487 }
488}