use bitflags::bitflags;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use super::{
error::{DeserializeError, SerializeError},
mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite},
read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType,
ProtocolVersion, QoS, WireLength,
};
/// A CONNECT packet: its variable header and payload (the fixed header is not
/// part of this type's encoding).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Connect {
    /// Protocol level written right after the protocol name `"MQTT"`.
    pub protocol_version: ProtocolVersion,
    /// Encoded as the `CLEAN_START` connect flag.
    pub clean_session: bool,
    /// Will message. When `Some`, the `WILL_FLAG`, `WILL_RETAIN` and will-QoS
    /// connect flags are derived from it and it is written after the client id.
    pub last_will: Option<LastWill>,
    /// When `Some`, sets the `USERNAME` flag and is written after the will.
    pub username: Option<String>,
    /// When `Some`, sets the `PASSWORD` flag and is written after the username.
    pub password: Option<String>,
    /// Keep alive value, written as a big-endian `u16`.
    pub keep_alive: u16,
    /// Properties written at the end of the variable header.
    pub connect_properties: ConnectProperties,
    /// Client identifier; the first field of the payload.
    pub client_id: String,
}
impl Default for Connect {
fn default() -> Self {
Self {
protocol_version: ProtocolVersion::V5,
clean_session: true,
last_will: None,
username: None,
password: None,
keep_alive: 60,
connect_properties: ConnectProperties::default(),
client_id: "MQRSTT".to_string(),
}
}
}
impl VariableHeaderRead for Connect {
    /// Parses a CONNECT packet from the bytes that follow its fixed header,
    /// starting at the protocol name.
    ///
    /// The first two arguments are ignored. Bytes left over after the last
    /// field announced by the connect flags are not inspected.
    ///
    /// Wire order: protocol name, protocol version, connect flags, keep alive,
    /// properties, client id, then (as selected by the connect flags) the will,
    /// the username and the password.
    ///
    /// # Errors
    ///
    /// * `DeserializeError::MalformedPacketWithInfo` if the protocol name is
    ///   not `"MQTT"`.
    /// * The error converted from a `String` if the connect flags byte has a
    ///   bit set that has no `ConnectFlags` constant (the reserved bit 0).
    /// * Any error returned by the field readers, e.g. for a truncated `buf`.
    fn read(_: u8, _: usize, mut buf: Bytes) -> Result<Self, DeserializeError> {
        if String::read(&mut buf)? != "MQTT" {
            return Err(DeserializeError::MalformedPacketWithInfo(
                "Protocol not MQTT".to_string(),
            ));
        }
        let protocol_version = ProtocolVersion::read(&mut buf)?;

        // Read fixed-width fields through the `MqttRead` impls (as
        // `LastWillProperties::read` does) instead of `Buf::get_u8`/`get_u16`,
        // which panic when the peer sends a truncated packet.
        let connect_flags_byte = u8::read(&mut buf)?;
        // `from_bits` returns `None` if a bit without a named flag is set.
        let connect_flags = ConnectFlags::from_bits(connect_flags_byte)
            .ok_or_else(|| "Can't read ConnectFlags".to_string())?;
        let clean_session = connect_flags.contains(ConnectFlags::CLEAN_START);
        let keep_alive = u16::read(&mut buf)?;
        let connect_properties = ConnectProperties::read(&mut buf)?;

        // Payload.
        let client_id = String::read(&mut buf)?;
        let last_will = if connect_flags.contains(ConnectFlags::WILL_FLAG) {
            let retain = connect_flags.contains(ConnectFlags::WILL_RETAIN);
            let qos = QoS::try_from(connect_flags)?;
            Some(LastWill::read(qos, retain, &mut buf)?)
        } else {
            None
        };
        let username = if connect_flags.contains(ConnectFlags::USERNAME) {
            Some(String::read(&mut buf)?)
        } else {
            None
        };
        let password = if connect_flags.contains(ConnectFlags::PASSWORD) {
            Some(String::read(&mut buf)?)
        } else {
            None
        };

        Ok(Connect {
            protocol_version,
            clean_session,
            last_will,
            username,
            password,
            keep_alive,
            connect_properties,
            client_id,
        })
    }
}
impl VariableHeaderWrite for Connect {
    /// Serializes the CONNECT variable header and payload (no fixed header).
    ///
    /// Variable header: protocol name `"MQTT"`, protocol version, connect
    /// flags, keep alive, properties. Payload: client id, then the will,
    /// username and password when present. The connect flags are derived from
    /// the optional fields, so they always match what the payload contains.
    fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
        let mut flags = ConnectFlags::empty();
        flags.set(ConnectFlags::CLEAN_START, self.clean_session);
        if let Some(will) = self.last_will.as_ref() {
            flags.insert(ConnectFlags::WILL_FLAG | ConnectFlags::from(will.qos));
            flags.set(ConnectFlags::WILL_RETAIN, will.retain);
        }
        flags.set(ConnectFlags::USERNAME, self.username.is_some());
        flags.set(ConnectFlags::PASSWORD, self.password.is_some());

        // Variable header.
        "MQTT".write(buf)?;
        self.protocol_version.write(buf)?;
        buf.put_u8(flags.bits());
        buf.put_u16(self.keep_alive);
        self.connect_properties.write(buf)?;

        // Payload.
        self.client_id.write(buf)?;
        if let Some(will) = self.last_will.as_ref() {
            will.write(buf)?;
        }
        // Username first, then password (each only if present).
        for credential in self.username.iter().chain(self.password.iter()) {
            credential.write(buf)?;
        }
        Ok(())
    }
}
impl WireLength for Connect {
    /// Sums the encoded lengths of the variable header (including the
    /// properties' variable-byte-integer length prefix) and the payload
    /// fields. The fixed header is not included.
    fn wire_len(&self) -> usize {
        let properties_len = self.connect_properties.wire_len();
        // Protocol name, protocol version (1), connect flags (1), keep alive (2).
        let header = "MQTT".wire_len()
            + 1
            + 1
            + 2
            + variable_integer_len(properties_len)
            + properties_len;

        let will = self.last_will.as_ref().map_or(0, |w| w.wire_len());
        let credentials: usize = self
            .username
            .iter()
            .chain(self.password.iter())
            .map(|s| s.wire_len())
            .sum();

        header + self.client_id.wire_len() + will + credentials
    }
}
bitflags! {
    /// The connect flags byte of the CONNECT variable header.
    ///
    /// Bit 0 has no constant, so `ConnectFlags::from_bits` rejects a byte with
    /// that bit set.
    pub struct ConnectFlags: u8 {
        /// Mirrors `Connect::clean_session`.
        const CLEAN_START = 0b00000010;
        /// Set when a will (`Connect::last_will`) is present in the payload.
        const WILL_FLAG = 0b00000100;
        /// Will QoS bit produced for `QoS::AtLeastOnce` by `From<QoS>`.
        const WILL_QOS1 = 0b00001000;
        /// Will QoS bit produced for `QoS::ExactlyOnce` by `From<QoS>`.
        const WILL_QOS2 = 0b00010000;
        /// Mirrors `LastWill::retain`.
        const WILL_RETAIN = 0b00100000;
        /// Set when a password is present in the payload.
        const PASSWORD = 0b01000000;
        /// Set when a username is present in the payload.
        const USERNAME = 0b10000000;
    }
}
impl From<QoS> for ConnectFlags {
    /// Encodes a will QoS as connect-flag bits: QoS 0 sets no bit, QoS 1 sets
    /// `WILL_QOS1` and QoS 2 sets `WILL_QOS2`.
    fn from(qos: QoS) -> Self {
        match qos {
            QoS::AtLeastOnce => Self::WILL_QOS1,
            QoS::ExactlyOnce => Self::WILL_QOS2,
            QoS::AtMostOnce => Self::empty(),
        }
    }
}
/// Properties carried at the end of the CONNECT variable header.
///
/// Each `Option` field is written (preceded by its property identifier) only
/// when `Some`. `user_properties` is written in order and may repeat keys.
/// `authentication_data` is written only when non-empty, and then requires
/// `authentication_method` to be set.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConnectProperties {
    /// `SessionExpiryInterval` property (4-byte integer).
    pub session_expiry_interval: Option<u32>,
    /// `ReceiveMaximum` property (2-byte integer).
    pub receive_maximum: Option<u16>,
    /// `MaximumPacketSize` property (4-byte integer).
    pub maximum_packet_size: Option<u32>,
    /// `TopicAliasMaximum` property (2-byte integer).
    pub topic_alias_maximum: Option<u16>,
    /// `RequestResponseInformation` property (1 byte).
    pub request_response_information: Option<u8>,
    /// `RequestProblemInformation` property (1 byte).
    pub request_problem_information: Option<u8>,
    /// `UserProperty` key/value pairs, in wire order.
    pub user_properties: Vec<(String, String)>,
    /// `AuthenticationMethod` property.
    pub authentication_method: Option<String>,
    /// `AuthenticationData` property; an empty value means "absent".
    pub authentication_data: Bytes,
}
impl MqttWrite for ConnectProperties {
    /// Writes the CONNECT properties: a variable byte integer holding
    /// `self.wire_len()`, then every property that is set, each preceded by
    /// its property identifier.
    ///
    /// # Errors
    ///
    /// Returns `SerializeError::AuthDataWithoutAuthMethod` if
    /// `authentication_data` is non-empty while `authentication_method` is
    /// `None`. This is checked before anything is written, so `buf` is left
    /// untouched in that case. Other errors come from the field writers.
    fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
        // Validate up front: failing halfway through would leave a partially
        // written (and wrongly length-prefixed) property block in `buf`.
        if !self.authentication_data.is_empty() && self.authentication_method.is_none() {
            return Err(SerializeError::AuthDataWithoutAuthMethod);
        }

        write_variable_integer(buf, self.wire_len())?;
        if let Some(session_expiry_interval) = self.session_expiry_interval {
            PropertyType::SessionExpiryInterval.write(buf)?;
            buf.put_u32(session_expiry_interval);
        }
        if let Some(receive_maximum) = self.receive_maximum {
            PropertyType::ReceiveMaximum.write(buf)?;
            buf.put_u16(receive_maximum);
        }
        if let Some(maximum_packet_size) = self.maximum_packet_size {
            PropertyType::MaximumPacketSize.write(buf)?;
            buf.put_u32(maximum_packet_size);
        }
        if let Some(topic_alias_maximum) = self.topic_alias_maximum {
            PropertyType::TopicAliasMaximum.write(buf)?;
            buf.put_u16(topic_alias_maximum);
        }
        if let Some(request_response_information) = self.request_response_information {
            PropertyType::RequestResponseInformation.write(buf)?;
            buf.put_u8(request_response_information);
        }
        if let Some(request_problem_information) = self.request_problem_information {
            PropertyType::RequestProblemInformation.write(buf)?;
            buf.put_u8(request_problem_information);
        }
        for (key, value) in &self.user_properties {
            PropertyType::UserProperty.write(buf)?;
            key.write(buf)?;
            value.write(buf)?;
        }
        if let Some(authentication_method) = &self.authentication_method {
            PropertyType::AuthenticationMethod.write(buf)?;
            authentication_method.write(buf)?;
        }
        // The check above guarantees an authentication method exists here.
        if !self.authentication_data.is_empty() {
            PropertyType::AuthenticationData.write(buf)?;
            self.authentication_data.write(buf)?;
        }
        Ok(())
    }
}
impl MqttRead for ConnectProperties {
    /// Reads the CONNECT properties: a variable byte integer length followed
    /// by that many bytes of `(identifier, value)` properties.
    ///
    /// # Errors
    ///
    /// * `DeserializeError::InsufficientData` if `buf` is shorter than the
    ///   declared property length.
    /// * `DeserializeError::DuplicateProperty` if a single-occurrence property
    ///   appears twice.
    /// * `DeserializeError::UnexpectedProperty` for a property not handled for
    ///   CONNECT.
    /// * `DeserializeError::MalformedPacketWithInfo` if authentication data is
    ///   present without an authentication method.
    /// * Any error from the value readers, e.g. for a truncated property value.
    fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
        let (len, _) = read_variable_integer(buf)?;
        let mut properties = Self::default();
        if buf.len() < len {
            return Err(DeserializeError::InsufficientData(
                "ConnectProperties".to_string(),
                buf.len(),
                len,
            ));
        }

        // Confine property parsing to the declared length; a zero length
        // yields an empty slice and the loop body never runs.
        let mut property_data = buf.split_to(len);
        while !property_data.is_empty() {
            // Values are read through the `MqttRead` impls rather than
            // `Buf::get_*`, which panic when a value is cut short.
            match PropertyType::read(&mut property_data)? {
                PropertyType::SessionExpiryInterval => {
                    if properties.session_expiry_interval.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::SessionExpiryInterval,
                        ));
                    }
                    properties.session_expiry_interval = Some(u32::read(&mut property_data)?);
                }
                PropertyType::ReceiveMaximum => {
                    if properties.receive_maximum.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::ReceiveMaximum,
                        ));
                    }
                    properties.receive_maximum = Some(u16::read(&mut property_data)?);
                }
                PropertyType::MaximumPacketSize => {
                    if properties.maximum_packet_size.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::MaximumPacketSize,
                        ));
                    }
                    properties.maximum_packet_size = Some(u32::read(&mut property_data)?);
                }
                PropertyType::TopicAliasMaximum => {
                    if properties.topic_alias_maximum.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::TopicAliasMaximum,
                        ));
                    }
                    properties.topic_alias_maximum = Some(u16::read(&mut property_data)?);
                }
                PropertyType::RequestResponseInformation => {
                    if properties.request_response_information.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::RequestResponseInformation,
                        ));
                    }
                    properties.request_response_information =
                        Some(u8::read(&mut property_data)?);
                }
                PropertyType::RequestProblemInformation => {
                    if properties.request_problem_information.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::RequestProblemInformation,
                        ));
                    }
                    properties.request_problem_information = Some(u8::read(&mut property_data)?);
                }
                // User properties may repeat.
                PropertyType::UserProperty => properties.user_properties.push((
                    String::read(&mut property_data)?,
                    String::read(&mut property_data)?,
                )),
                PropertyType::AuthenticationMethod => {
                    if properties.authentication_method.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::AuthenticationMethod,
                        ));
                    }
                    properties.authentication_method = Some(String::read(&mut property_data)?);
                }
                PropertyType::AuthenticationData => {
                    // Non-empty data means this property was already seen.
                    if !properties.authentication_data.is_empty() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::AuthenticationData,
                        ));
                    }
                    properties.authentication_data = Bytes::read(&mut property_data)?;
                }
                e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)),
            }
        }

        if !properties.authentication_data.is_empty() && properties.authentication_method.is_none()
        {
            return Err(DeserializeError::MalformedPacketWithInfo(
                "Authentication data is not empty while authentication method is not set"
                    .to_string(),
            ));
        }
        Ok(properties)
    }
}
impl WireLength for ConnectProperties {
    /// Encoded length of the properties, excluding the variable-byte-integer
    /// length prefix. Every present property counts one identifier byte plus
    /// its value; authentication data is only counted together with an
    /// authentication method.
    fn wire_len(&self) -> usize {
        // (present?, value size) for the fixed-width properties.
        let fixed_width: [(bool, usize); 6] = [
            (self.session_expiry_interval.is_some(), 4),
            (self.receive_maximum.is_some(), 2),
            (self.maximum_packet_size.is_some(), 4),
            (self.topic_alias_maximum.is_some(), 2),
            (self.request_response_information.is_some(), 1),
            (self.request_problem_information.is_some(), 1),
        ];
        let fixed: usize = fixed_width
            .iter()
            .filter(|(present, _)| *present)
            .map(|(_, size)| 1 + size)
            .sum();

        let user: usize = self
            .user_properties
            .iter()
            .map(|(key, value)| 1 + key.wire_len() + value.wire_len())
            .sum();

        let method = self
            .authentication_method
            .as_ref()
            .map_or(0, |m| 1 + m.wire_len());

        let data = if self.authentication_method.is_some() && !self.authentication_data.is_empty() {
            1 + self.authentication_data.wire_len()
        } else {
            0
        };

        fixed + user + method + data
    }
}
/// A will message sent as part of a CONNECT packet.
///
/// `qos` and `retain` are carried in the connect flags. The remaining fields
/// are written in the payload in the order properties, topic, payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LastWill {
    /// Will QoS, encoded in the connect flags.
    pub qos: QoS,
    /// Encoded as the `WILL_RETAIN` connect flag.
    pub retain: bool,
    /// Will properties, written first.
    pub last_will_properties: LastWillProperties,
    /// Will topic.
    pub topic: String,
    /// Will payload, written with a 2-byte length prefix by `Bytes`'s writer
    /// (assumed from the wire tests; confirm against `MqttWrite for Bytes`).
    pub payload: Bytes,
}
impl LastWill {
pub fn new<T: Into<String>, P: Into<Vec<u8>>>(
qos: QoS,
retain: bool,
topic: T,
payload: P,
) -> LastWill {
Self {
qos,
retain,
last_will_properties: LastWillProperties::default(),
topic: topic.into(),
payload: Bytes::from(payload.into()),
}
}
pub fn read(qos: QoS, retain: bool, buf: &mut Bytes) -> Result<Self, DeserializeError> {
let last_will_properties = LastWillProperties::read(buf)?;
let topic = String::read(buf)?;
let payload = Bytes::read(buf)?;
Ok(Self {
qos,
retain,
topic,
payload,
last_will_properties,
})
}
}
impl MqttWrite for LastWill {
fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
self.last_will_properties.write(buf)?;
self.topic.write(buf)?;
self.payload.write(buf)?;
Ok(())
}
}
impl WireLength for LastWill {
fn wire_len(&self) -> usize {
self.topic.wire_len() + self.payload.wire_len() + self.last_will_properties.wire_len()
}
}
/// Properties of a will message, written in the CONNECT payload before the
/// will topic.
///
/// Each `Option` field is written (preceded by its property identifier) only
/// when `Some`. `user_properties` is written in order and may repeat keys.
/// The fields are public, like those of `ConnectProperties`, so callers can
/// set will properties on a `LastWill`. Previously they were private and there
/// was no constructor, so only `Default` values could be sent.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LastWillProperties {
    /// `WillDelayInterval` property (4-byte integer).
    pub delay_interval: Option<u32>,
    /// `PayloadFormatIndicator` property (1 byte).
    pub payload_format_indicator: Option<u8>,
    /// `MessageExpiryInterval` property (4-byte integer).
    pub message_expiry_interval: Option<u32>,
    /// `ContentType` property.
    pub content_type: Option<String>,
    /// `ResponseTopic` property.
    pub response_topic: Option<String>,
    /// `CorrelationData` property.
    pub correlation_data: Option<Bytes>,
    /// `UserProperty` key/value pairs, in wire order.
    pub user_properties: Vec<(String, String)>,
}
impl MqttRead for LastWillProperties {
    /// Reads the will properties: a variable byte integer length followed by
    /// that many bytes of `(identifier, value)` properties.
    ///
    /// # Errors
    ///
    /// * `DeserializeError::InsufficientData` if `buf` is shorter than the
    ///   declared property length.
    /// * `DeserializeError::DuplicateProperty` if a single-occurrence property
    ///   appears twice.
    /// * `DeserializeError::UnexpectedProperty` for a property not handled for
    ///   a will.
    /// * Any error from the value readers, e.g. for a truncated property value.
    fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
        let (len, _) = read_variable_integer(buf)?;
        let mut properties = Self::default();
        if buf.len() < len {
            return Err(DeserializeError::InsufficientData(
                "LastWillProperties".to_string(),
                buf.len(),
                len,
            ));
        }

        // Confine parsing to the declared length; a zero length yields an
        // empty slice and the loop body never runs.
        let mut property_data = buf.split_to(len);
        while !property_data.is_empty() {
            match PropertyType::read(&mut property_data)? {
                PropertyType::WillDelayInterval => {
                    if properties.delay_interval.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::WillDelayInterval,
                        ));
                    }
                    properties.delay_interval = Some(u32::read(&mut property_data)?);
                }
                PropertyType::PayloadFormatIndicator => {
                    // Only a value that was already set is a duplicate.
                    if properties.payload_format_indicator.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::PayloadFormatIndicator,
                        ));
                    }
                    properties.payload_format_indicator = Some(u8::read(&mut property_data)?);
                }
                PropertyType::MessageExpiryInterval => {
                    if properties.message_expiry_interval.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::MessageExpiryInterval,
                        ));
                    }
                    properties.message_expiry_interval = Some(u32::read(&mut property_data)?);
                }
                PropertyType::ContentType => {
                    if properties.content_type.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::ContentType,
                        ));
                    }
                    properties.content_type = Some(String::read(&mut property_data)?);
                }
                PropertyType::ResponseTopic => {
                    if properties.response_topic.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::ResponseTopic,
                        ));
                    }
                    properties.response_topic = Some(String::read(&mut property_data)?);
                }
                PropertyType::CorrelationData => {
                    if properties.correlation_data.is_some() {
                        return Err(DeserializeError::DuplicateProperty(
                            PropertyType::CorrelationData,
                        ));
                    }
                    properties.correlation_data = Some(Bytes::read(&mut property_data)?);
                }
                // User properties may repeat.
                PropertyType::UserProperty => properties.user_properties.push((
                    String::read(&mut property_data)?,
                    String::read(&mut property_data)?,
                )),
                e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)),
            }
        }
        Ok(properties)
    }
}
impl MqttWrite for LastWillProperties {
    /// Writes the will properties: a variable byte integer holding
    /// `self.wire_len()`, then every property that is set, each preceded by
    /// its property identifier.
    fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
        write_variable_integer(buf, self.wire_len())?;

        // Fixed-width numeric properties.
        if let Some(delay) = self.delay_interval {
            PropertyType::WillDelayInterval.write(buf)?;
            buf.put_u32(delay);
        }
        if let Some(indicator) = self.payload_format_indicator {
            PropertyType::PayloadFormatIndicator.write(buf)?;
            buf.put_u8(indicator);
        }
        if let Some(expiry) = self.message_expiry_interval {
            PropertyType::MessageExpiryInterval.write(buf)?;
            buf.put_u32(expiry);
        }

        // Length-prefixed string and binary properties.
        if let Some(ref content_type) = self.content_type {
            PropertyType::ContentType.write(buf)?;
            content_type.write(buf)?;
        }
        if let Some(ref response_topic) = self.response_topic {
            PropertyType::ResponseTopic.write(buf)?;
            response_topic.write(buf)?;
        }
        if let Some(ref correlation_data) = self.correlation_data {
            PropertyType::CorrelationData.write(buf)?;
            correlation_data.write(buf)?;
        }

        // User properties, in order; nothing is written when the list is empty.
        self.user_properties.iter().try_for_each(|(key, value)| {
            PropertyType::UserProperty.write(buf)?;
            key.write(buf)?;
            value.write(buf)
        })
    }
}
impl WireLength for LastWillProperties {
    /// Encoded length of the will properties, excluding the variable-byte-
    /// integer length prefix that `MqttWrite::write` emits first.
    ///
    /// Every present property counts its one-byte property identifier plus its
    /// encoded value, matching what `write` emits. The identifier byte used to
    /// be missing for the string, binary and user properties, which produced
    /// a wrong length prefix and thus a malformed packet.
    fn wire_len(&self) -> usize {
        let mut len: usize = 0;
        len += self.delay_interval.map_or(0, |_| 1 + 4);
        len += self.payload_format_indicator.map_or(0, |_| 1 + 1);
        len += self.message_expiry_interval.map_or(0, |_| 1 + 4);
        len += self
            .content_type
            .as_ref()
            .map_or(0, |s| 1 + s.wire_len());
        len += self
            .response_topic
            .as_ref()
            .map_or(0, |s| 1 + s.wire_len());
        len += self
            .correlation_data
            .as_ref()
            .map_or(0, |b| 1 + b.wire_len());
        len += self
            .user_properties
            .iter()
            .map(|(key, value)| 1 + key.wire_len() + value.wire_len())
            .sum::<usize>();
        len
    }
}
#[cfg(test)]
mod tests {
    use crate::packets::{
        mqtt_traits::{MqttWrite, VariableHeaderRead, VariableHeaderWrite},
        QoS, WireLength,
    };

    use super::{Connect, LastWill};

    /// Variable header + payload of a v5 CONNECT carrying a will (QoS 1, not
    /// retained), a username and a password.
    const CONNECT_WITH_WILL: &[u8] = &[
        0x00, 0x04, b'M', b'Q', b'T', b'T', // protocol name
        0x05, // protocol version
        0b1100_1110, // flags: username, password, will QoS 1, will, clean start
        0x00, 0x0a, // keep alive
        0x00, // property length
        0x00, 0x04, b't', b'e', b's', b't', // client id
        0x00, // will property length
        0x00, 0x02, b'/', b'a', // will topic
        0x00, 0x0B, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', // will payload
        0x00, 0x04, b'u', b's', b'e', b'r', // username
        0x00, 0x04, b'p', b'a', b's', b's', // password
    ];

    /// Will section of a CONNECT payload: no properties, topic "/a",
    /// payload "hello world".
    const LAST_WILL: &[u8] = &[
        0x00, // will property length
        0x00, 0x02, b'/', b'a', // will topic
        0x00, 0x0B, b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', // will payload
    ];

    #[test]
    fn read_connect() {
        let mut buf = bytes::BytesMut::new();
        buf.extend_from_slice(CONNECT_WITH_WILL);
        // Bytes after the password are not part of the CONNECT; parsing must
        // still succeed.
        buf.extend_from_slice(&[0xAB, 0xCD, 0xEF]);

        let c = Connect::read(0, 0, buf.into()).unwrap();

        assert!(c.clean_session);
        assert_eq!(c.keep_alive, 10);
        assert_eq!(c.client_id, "test");
        assert_eq!(c.username.as_deref(), Some("user"));
        assert_eq!(c.password.as_deref(), Some("pass"));
        assert_eq!(
            c.last_will,
            Some(LastWill::new(QoS::AtLeastOnce, false, "/a", "hello world"))
        );
    }

    #[test]
    fn read_and_write_connect() {
        let c = Connect::read(0, 0, bytes::Bytes::from_static(CONNECT_WITH_WILL)).unwrap();

        let mut write_buf = bytes::BytesMut::new();
        c.write(&mut write_buf).unwrap();
        assert_eq!(CONNECT_WITH_WILL.to_vec(), write_buf.to_vec());
    }

    #[test]
    fn parsing_last_will() {
        let mut buf = bytes::Bytes::from_static(LAST_WILL);
        let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap();

        assert_eq!(lw, LastWill::new(QoS::AtLeastOnce, false, "/a", "hello world"));
        // The whole will section must be consumed.
        assert!(buf.is_empty());
    }

    #[test]
    fn read_and_write_connect2() {
        // Complete packet: fixed header 0x10 (CONNECT) with remaining length
        // 0x1d, then a session expiry interval property, client id "9.0.1"
        // and username "Test".
        let packet: &[u8] = &[
            0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11,
            0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54,
            0x65, 0x73, 0x74,
        ];
        // Variable header and payload: everything after the 2-byte fixed header.
        let data = &packet[2..];

        let c = Connect::read(0, 0, bytes::Bytes::copy_from_slice(data)).unwrap();

        let mut write_buf = bytes::BytesMut::new();
        c.write(&mut write_buf).unwrap();
        assert_eq!(data.to_vec(), write_buf.to_vec());
        // `wire_len` must match both the bytes written and the remaining
        // length announced in the fixed header.
        assert_eq!(c.wire_len(), write_buf.len());
        assert_eq!(c.wire_len(), usize::from(packet[1]));
    }

    #[test]
    fn parsing_and_writing_last_will() {
        let mut buf = bytes::Bytes::from_static(LAST_WILL);
        let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap();

        let mut write_buf = bytes::BytesMut::new();
        lw.write(&mut write_buf).unwrap();
        assert_eq!(LAST_WILL.to_vec(), write_buf.to_vec());
    }
}