use crate::error::{MqttError, Result};
use crate::packet::{FixedHeader, MqttPacket, PacketType};
use crate::prelude::{format, String};
use crate::protocol::v5::properties::Properties;
use crate::protocol::v5::reason_codes::NORMAL_DISCONNECTION;
use crate::types::ReasonCode;
use bytes::{Buf, BufMut};
/// An MQTT v5 DISCONNECT packet.
///
/// When encoded, the variable header is omitted entirely if `reason_code` is
/// `NORMAL_DISCONNECTION` and `properties` is empty; decoding an empty body
/// yields exactly that packet.
#[derive(Debug, Clone)]
pub struct DisconnectPacket {
    /// Reason code carried in the first byte of the variable header.
    pub reason_code: ReasonCode,
    /// DISCONNECT properties (this type's builders expose session expiry
    /// interval, reason string, server reference and user properties).
    pub properties: Properties,
}
impl DisconnectPacket {
#[must_use]
pub fn new(reason_code: ReasonCode) -> Self {
Self {
reason_code,
properties: Properties::default(),
}
}
#[must_use]
pub fn normal() -> Self {
Self::new(NORMAL_DISCONNECTION)
}
#[must_use]
pub fn with_session_expiry_interval(mut self, seconds: u32) -> Self {
self.properties.set_session_expiry_interval(seconds);
self
}
#[must_use]
pub fn with_reason_string(mut self, reason: String) -> Self {
self.properties.set_reason_string(reason);
self
}
#[must_use]
pub fn with_server_reference(mut self, reference: String) -> Self {
self.properties.set_server_reference(reference);
self
}
#[must_use]
pub fn with_user_property(mut self, key: String, value: String) -> Self {
self.properties.add_user_property(key, value);
self
}
fn is_valid_disconnect_reason_code(code: ReasonCode) -> bool {
matches!(
code,
NORMAL_DISCONNECTION
| ReasonCode::DisconnectWithWillMessage
| ReasonCode::UnspecifiedError
| ReasonCode::MalformedPacket
| ReasonCode::ProtocolError
| ReasonCode::ImplementationSpecificError
| ReasonCode::NotAuthorized
| ReasonCode::ServerBusy
| ReasonCode::ServerShuttingDown
| ReasonCode::KeepAliveTimeout
| ReasonCode::SessionTakenOver
| ReasonCode::TopicFilterInvalid
| ReasonCode::TopicNameInvalid
| ReasonCode::ReceiveMaximumExceeded
| ReasonCode::TopicAliasInvalid
| ReasonCode::PacketTooLarge
| ReasonCode::MessageRateTooHigh
| ReasonCode::QuotaExceeded
| ReasonCode::AdministrativeAction
| ReasonCode::PayloadFormatInvalid
| ReasonCode::RetainNotSupported
| ReasonCode::QoSNotSupported
| ReasonCode::UseAnotherServer
| ReasonCode::ServerMoved
| ReasonCode::SharedSubscriptionsNotSupported
| ReasonCode::ConnectionRateExceeded
| ReasonCode::MaximumConnectTime
| ReasonCode::SubscriptionIdentifiersNotSupported
| ReasonCode::WildcardSubscriptionsNotSupported
)
}
}
impl MqttPacket for DisconnectPacket {
    fn packet_type(&self) -> PacketType {
        PacketType::Disconnect
    }

    /// Writes the DISCONNECT variable header into `buf`.
    ///
    /// A `NORMAL_DISCONNECTION` with no properties writes nothing at all.
    /// Otherwise the reason code byte is written, followed by the encoded
    /// properties only when there are any. Omitting the property length after a
    /// lone reason code is allowed: MQTT v5 treats a Remaining Length below 2 as
    /// a property length of 0.
    ///
    /// # Errors
    /// Propagates any error returned by `Properties::encode`.
    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
        let has_properties = !self.properties.is_empty();

        if self.reason_code == NORMAL_DISCONNECTION && !has_properties {
            // Shortest legal form: an empty body.
            return Ok(());
        }

        buf.put_u8(u8::from(self.reason_code));
        if has_properties {
            self.properties.encode(buf)?;
        }
        Ok(())
    }

    /// Parses a DISCONNECT variable header from `buf`.
    ///
    /// An empty body (as sent by MQTT 3.1.1 clients, or by v5 peers using the
    /// short form) decodes to a normal disconnection. Otherwise the first byte
    /// is the reason code, and any bytes that follow are parsed as properties.
    ///
    /// # Errors
    /// Returns `MqttError::MalformedPacket` if the reason byte is not a known
    /// reason code, or is known but not permitted in DISCONNECT. Errors from
    /// `Properties::decode` are propagated.
    fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
        if buf.remaining() == 0 {
            return Ok(Self::new(NORMAL_DISCONNECTION));
        }

        let reason_byte = buf.get_u8();
        let reason_code = match ReasonCode::from_u8(reason_byte) {
            Some(code) => code,
            None => {
                return Err(MqttError::MalformedPacket(format!(
                    "Invalid DISCONNECT reason code: {reason_byte}"
                )));
            }
        };

        if !Self::is_valid_disconnect_reason_code(reason_code) {
            return Err(MqttError::MalformedPacket(format!(
                "Invalid DISCONNECT reason code: {reason_code:?}"
            )));
        }

        // A body consisting solely of the reason code carries no properties.
        let properties = if buf.has_remaining() {
            Properties::decode(buf)?
        } else {
            Properties::default()
        };

        Ok(Self {
            reason_code,
            properties,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bytes::BytesMut;

    /// Encodes `packet` into a fresh buffer, then decodes the fixed header and
    /// body back out of that same buffer.
    fn round_trip(packet: &DisconnectPacket) -> (FixedHeader, DisconnectPacket) {
        let mut wire = BytesMut::new();
        packet.encode(&mut wire).unwrap();
        let header = FixedHeader::decode(&mut wire).unwrap();
        let decoded = DisconnectPacket::decode_body(&mut wire, &header).unwrap();
        (header, decoded)
    }

    #[test]
    fn test_disconnect_normal() {
        let DisconnectPacket {
            reason_code,
            properties,
        } = DisconnectPacket::normal();
        assert_eq!(reason_code, NORMAL_DISCONNECTION);
        assert!(properties.is_empty());
    }

    #[test]
    fn test_disconnect_with_reason() {
        let built = DisconnectPacket::new(ReasonCode::ServerShuttingDown)
            .with_reason_string("Maintenance mode".to_string());
        assert_eq!(built.reason_code, ReasonCode::ServerShuttingDown);
        assert!(built.properties.contains(PropertyId::ReasonString));
    }

    #[test]
    fn test_disconnect_encode_decode_normal() {
        let (header, decoded) = round_trip(&DisconnectPacket::normal());
        assert_eq!(header.packet_type, PacketType::Disconnect);
        assert_eq!(decoded.reason_code, NORMAL_DISCONNECTION);
    }

    #[test]
    fn test_disconnect_encode_decode_with_properties() {
        let original = DisconnectPacket::new(ReasonCode::SessionTakenOver)
            .with_session_expiry_interval(0)
            .with_reason_string("Another client connected".to_string());
        let (_, decoded) = round_trip(&original);
        assert_eq!(decoded.reason_code, ReasonCode::SessionTakenOver);
        for expected in [PropertyId::SessionExpiryInterval, PropertyId::ReasonString] {
            assert!(decoded.properties.contains(expected));
        }
    }

    #[test]
    fn test_disconnect_v311_style() {
        // MQTT 3.1.1 DISCONNECT has no variable header at all.
        let header = FixedHeader::new(PacketType::Disconnect, 0, 0);
        let mut empty_body = BytesMut::new();
        let decoded = DisconnectPacket::decode_body(&mut empty_body, &header).unwrap();
        assert_eq!(decoded.reason_code, NORMAL_DISCONNECTION);
        assert!(decoded.properties.is_empty());
    }
}