use crate::{
codec, Authentication, ClientID, Error, PropertiesDecoder, Property, QoS, Result as SageResult,
DEFAULT_PAYLOAD_FORMAT_INDICATOR, DEFAULT_RECEIVE_MAXIMUM, DEFAULT_REQUEST_PROBLEM_INFORMATION,
DEFAULT_REQUEST_RESPONSE_INFORMATION, DEFAULT_SESSION_EXPIRY_INTERVAL,
DEFAULT_TOPIC_ALIAS_MAXIMUM, DEFAULT_WILL_DELAY_INTERVAL,
};
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use std::convert::TryInto;
use std::marker::Unpin;
/// Last-will message description carried inside a `Connect` packet.
///
/// `qos` and `retain` travel in the connect flags byte; every other field is
/// emitted in the will section of the payload by `Connect::write`.
#[derive(Clone, Debug, PartialEq)]
pub struct Will {
    /// Quality of service placed in the will QoS bits of the connect flags.
    pub qos: QoS,

    /// Value placed in the will retain bit of the connect flags.
    pub retain: bool,

    /// Encoded as the `WillDelayInterval` property.
    pub delay_interval: u32,

    /// Encoded as the `PayloadFormatIndicator` property.
    pub payload_format_indicator: bool,

    /// Encoded as the `MessageExpiryInterval` property, only when `Some`.
    pub message_expiry_interval: Option<u32>,

    /// Encoded as the `ContentType` property.
    pub content_type: String,

    /// Encoded as the `ResponseTopic` property, only when `Some`.
    pub response_topic: Option<String>,

    /// Encoded as the `CorrelationData` property, only when `Some`.
    pub correlation_data: Option<Vec<u8>>,

    /// Key/value pairs, each encoded as one `UserProperty`.
    pub user_properties: Vec<(String, String)>,

    /// Will topic. An empty topic is rejected with `Error::ProtocolError`
    /// by both `Connect::write` and `Connect::read`.
    pub topic: String,

    /// Will payload, written as binary data right after the topic.
    pub message: Vec<u8>,
}
impl Default for Will {
    /// Builds a will with `QoS::AtMostOnce`, no retain, the crate-level
    /// default delay interval and payload format indicator, and every
    /// optional or collection field left empty.
    fn default() -> Self {
        Self {
            // Carried in the connect flags.
            qos: QoS::AtMostOnce,
            retain: false,

            // Carried as will properties.
            delay_interval: DEFAULT_WILL_DELAY_INTERVAL,
            payload_format_indicator: DEFAULT_PAYLOAD_FORMAT_INDICATOR,
            message_expiry_interval: None,
            content_type: String::new(),
            response_topic: None,
            correlation_data: None,
            user_properties: Vec::new(),

            // Topic and payload start out empty.
            topic: String::new(),
            message: Vec::new(),
        }
    }
}
/// In-memory representation of an MQTT v5 `CONNECT` packet body
/// (protocol name `"MQTT"`, protocol version `0x05`).
#[derive(Clone, Debug, PartialEq)]
pub struct Connect {
    /// Clean start bit of the connect flags.
    pub clean_start: bool,

    /// Optional user name. When `Some`, the user-name flag is set and the
    /// string is written after the will section.
    pub user_name: Option<String>,

    /// Optional password. When `Some`, the password flag is set and the
    /// bytes are written last, as binary data.
    pub password: Option<Vec<u8>>,

    /// Keep alive value, written as a two-byte integer after the flags.
    pub keep_alive: u16,

    /// Encoded as the `SessionExpiryInterval` property.
    pub session_expiry_interval: u32,

    /// Encoded as the `ReceiveMaximum` property.
    pub receive_maximum: u16,

    /// Encoded as the `MaximumPacketSize` property, only when `Some`.
    pub maximum_packet_size: Option<u32>,

    /// Encoded as the `TopicAliasMaximum` property.
    pub topic_alias_maximum: u16,

    /// Encoded as the `RequestResponseInformation` property.
    pub request_response_information: bool,

    /// Encoded as the `RequestProblemInformation` property.
    pub request_problem_information: bool,

    /// Key/value pairs, each encoded as one `UserProperty`.
    pub user_properties: Vec<(String, String)>,

    /// Optional authentication method and data. On decoding it is built from
    /// the `AuthenticationMethod` and `AuthenticationData` properties.
    pub authentication: Option<Authentication>,

    /// Client identifier. `None` is written as an empty string, and an empty
    /// string is decoded back to `None`.
    pub client_id: Option<ClientID>,

    /// Optional last-will message. Its presence drives the will flag.
    pub will: Option<Will>,
}
impl Default for Connect {
    /// Builds a connect packet with no credentials, no client identifier,
    /// no will, a keep alive of `600` and the crate-level defaults for every
    /// property.
    fn default() -> Self {
        Self {
            // Flags and fixed fields.
            clean_start: false,
            keep_alive: 600,

            // Credentials and identity are all absent by default.
            user_name: None,
            password: None,
            authentication: None,
            client_id: None,
            will: None,

            // Properties fall back to the crate-level defaults.
            session_expiry_interval: DEFAULT_SESSION_EXPIRY_INTERVAL,
            receive_maximum: DEFAULT_RECEIVE_MAXIMUM,
            maximum_packet_size: None,
            topic_alias_maximum: DEFAULT_TOPIC_ALIAS_MAXIMUM,
            request_response_information: DEFAULT_REQUEST_RESPONSE_INFORMATION,
            request_problem_information: DEFAULT_REQUEST_PROBLEM_INFORMATION,
            user_properties: Vec::new(),
        }
    }
}
/// Decoded view of the single connect-flags byte.
///
/// Bit positions are the ones used by `ConnectFlags::write` and
/// `ConnectFlags::read`; bit 0 is reserved and must be zero when reading.
struct ConnectFlags {
    /// Bit 1.
    pub clean_start: bool,

    /// Bit 2: a will section follows the client identifier.
    pub will: bool,

    /// Bits 3-4.
    pub will_qos: QoS,

    /// Bit 5.
    pub will_retain: bool,

    /// Bit 7: a user name is present in the payload.
    pub user_name: bool,

    /// Bit 6: a password is present in the payload.
    pub password: bool,
}
impl Connect {
pub async fn write<W: AsyncWrite + Unpin>(self, writer: &mut W) -> SageResult<usize> {
let mut n_bytes = codec::write_utf8_string("MQTT", writer).await?;
n_bytes += codec::write_byte(0x05, writer).await?;
n_bytes += ConnectFlags {
clean_start: self.clean_start,
will: self.will.is_some(),
will_qos: if let Some(w) = &self.will {
w.qos
} else {
QoS::AtMostOnce
},
will_retain: if let Some(w) = &self.will {
w.retain
} else {
false
},
user_name: self.user_name.is_some(),
password: self.password.is_some(),
}
.write(writer)
.await?;
n_bytes += codec::write_two_byte_integer(self.keep_alive, writer).await?;
let mut properties = Vec::new();
n_bytes += Property::SessionExpiryInterval(self.session_expiry_interval)
.encode(&mut properties)
.await?;
n_bytes += Property::ReceiveMaximum(self.receive_maximum)
.encode(&mut properties)
.await?;
if let Some(maximum_packet_size) = self.maximum_packet_size {
n_bytes += Property::MaximumPacketSize(maximum_packet_size)
.encode(&mut properties)
.await?;
}
n_bytes += Property::TopicAliasMaximum(self.topic_alias_maximum)
.encode(&mut properties)
.await?;
n_bytes += Property::RequestResponseInformation(self.request_response_information)
.encode(&mut properties)
.await?;
n_bytes += Property::RequestProblemInformation(self.request_problem_information)
.encode(&mut properties)
.await?;
for (k, v) in self.user_properties {
n_bytes += Property::UserProperty(k, v).encode(&mut properties).await?;
}
if let Some(authentication) = self.authentication {
n_bytes += authentication.write(writer).await?;
}
n_bytes += codec::write_variable_byte_integer(properties.len() as u32, writer).await?;
writer.write_all(&properties).await?;
if let Some(client_id) = self.client_id {
if client_id.len() > 23 || client_id.chars().any(|c| c < '0' || c > 'z') {
return Err(Error::MalformedPacket);
}
n_bytes += codec::write_utf8_string(&client_id, writer).await?;
} else {
n_bytes += codec::write_utf8_string("", writer).await?;
}
if let Some(w) = self.will {
let mut properties = Vec::new();
n_bytes += Property::WillDelayInterval(w.delay_interval)
.encode(&mut properties)
.await?;
n_bytes += Property::PayloadFormatIndicator(w.payload_format_indicator)
.encode(&mut properties)
.await?;
if let Some(v) = w.message_expiry_interval {
n_bytes += Property::MessageExpiryInterval(v)
.encode(&mut properties)
.await?;
}
n_bytes += Property::ContentType(w.content_type)
.encode(&mut properties)
.await?;
if let Some(response_topic) = w.response_topic {
n_bytes += Property::ResponseTopic(response_topic)
.encode(&mut properties)
.await?;
}
if let Some(v) = w.correlation_data {
n_bytes += Property::CorrelationData(v).encode(&mut properties).await?;
}
for (k, v) in w.user_properties {
n_bytes += Property::UserProperty(k, v).encode(&mut properties).await?;
}
n_bytes += codec::write_variable_byte_integer(properties.len() as u32, writer).await?;
writer.write_all(&properties).await?;
if w.topic.is_empty() {
return Err(Error::ProtocolError);
}
n_bytes += codec::write_utf8_string(&w.topic, writer).await?;
n_bytes += codec::write_binary_data(&w.message, writer).await?;
}
if let Some(v) = self.user_name {
n_bytes += codec::write_utf8_string(&v, writer).await?;
}
if let Some(v) = self.password {
n_bytes += codec::write_binary_data(&v, writer).await?;
}
Ok(n_bytes)
}
pub async fn read<R: AsyncRead + Unpin>(reader: &mut R) -> SageResult<Self> {
let protocol_name = codec::read_utf8_string(reader).await?;
if protocol_name != "MQTT" {
return Err(Error::MalformedPacket);
}
let protocol_version = codec::read_byte(reader).await?;
if protocol_version != 0x05 {
return Err(Error::MalformedPacket);
}
let flags = ConnectFlags::read(reader).await?;
let clean_start = flags.clean_start;
let keep_alive = codec::read_two_byte_integer(reader).await?;
let mut session_expiry_interval = DEFAULT_SESSION_EXPIRY_INTERVAL;
let mut receive_maximum = DEFAULT_RECEIVE_MAXIMUM;
let mut maximum_packet_size = None;
let mut topic_alias_maximum = DEFAULT_TOPIC_ALIAS_MAXIMUM;
let mut request_response_information = DEFAULT_REQUEST_RESPONSE_INFORMATION;
let mut request_problem_information = DEFAULT_REQUEST_PROBLEM_INFORMATION;
let mut user_properties = Vec::new();
let mut authentication_method = None;
let mut authentication_data = Default::default();
let mut decoder = PropertiesDecoder::take(reader).await?;
while decoder.has_properties() {
match decoder.read().await? {
Property::SessionExpiryInterval(v) => session_expiry_interval = v,
Property::ReceiveMaximum(v) => receive_maximum = v,
Property::MaximumPacketSize(v) => maximum_packet_size = Some(v),
Property::TopicAliasMaximum(v) => topic_alias_maximum = v,
Property::RequestResponseInformation(v) => request_response_information = v,
Property::RequestProblemInformation(v) => request_problem_information = v,
Property::AuthenticationMethod(v) => authentication_method = Some(v),
Property::AuthenticationData(v) => authentication_data = v,
Property::UserProperty(k, v) => user_properties.push((k, v)),
_ => return Err(Error::ProtocolError),
};
}
let reader = decoder.into_inner();
let authentication = if let Some(method) = authentication_method {
Some(Authentication {
method,
data: authentication_data,
})
} else {
if !authentication_data.is_empty() {
return Err(Error::ProtocolError);
}
None
};
let client_id = {
let client_id = codec::read_utf8_string(reader).await?;
if client_id.is_empty() {
None
} else {
if client_id.len() > 23 || client_id.chars().any(|c| c < '0' || c > 'z') {
return Err(Error::MalformedPacket);
}
Some(client_id)
}
};
let (reader, will) = if flags.will {
let mut decoder = PropertiesDecoder::take(reader).await?;
let mut w = Will::default();
w.qos = flags.will_qos;
while decoder.has_properties() {
match decoder.read().await? {
Property::WillDelayInterval(v) => w.delay_interval = v,
Property::PayloadFormatIndicator(v) => w.payload_format_indicator = v,
Property::MessageExpiryInterval(v) => w.message_expiry_interval = Some(v),
Property::ContentType(v) => w.content_type = v,
Property::ResponseTopic(v) => w.response_topic = Some(v),
Property::CorrelationData(v) => w.correlation_data = Some(v),
Property::UserProperty(k, v) => w.user_properties.push((k, v)),
_ => return Err(Error::ProtocolError),
}
}
let reader = decoder.into_inner();
w.topic = codec::read_utf8_string(reader).await?;
if w.topic.is_empty() {
return Err(Error::ProtocolError);
}
w.message = codec::read_binary_data(reader).await?;
(reader, Some(w))
} else {
(reader, None)
};
let user_name = if flags.user_name {
Some(codec::read_utf8_string(reader).await?)
} else {
None
};
let password = if flags.password {
Some(codec::read_binary_data(reader).await?)
} else {
None
};
Ok(Connect {
clean_start,
user_name,
password,
keep_alive,
session_expiry_interval,
receive_maximum,
maximum_packet_size,
topic_alias_maximum,
request_response_information,
request_problem_information,
authentication,
user_properties,
client_id,
will,
})
}
}
impl ConnectFlags {
pub async fn write<W: AsyncWrite + Unpin>(self, writer: &mut W) -> SageResult<usize> {
let bits = ((self.user_name as u8) << 7)
| ((self.password as u8) << 6)
| ((self.will_retain as u8) << 5)
| (self.will_qos as u8) << 3
| ((self.will as u8) << 2)
| ((self.clean_start as u8) << 1);
codec::write_byte(bits, writer).await
}
pub async fn read<R: AsyncRead + Unpin>(reader: &mut R) -> SageResult<Self> {
let bits = codec::read_byte(reader).await?;
if bits & 0x01 != 0 {
Err(Error::MalformedPacket)
} else {
Ok(ConnectFlags {
user_name: (bits & 0b1000_0000) >> 7 > 0,
password: (bits & 0b0100_0000) >> 6 > 0,
will_retain: (bits & 0b0010_0000) >> 5 > 0,
will_qos: ((bits & 0b0001_1000) >> 3).try_into()?,
will: (bits & 0b0000_00100) >> 2 > 0,
clean_start: (bits & 0b0000_00010) >> 1 > 0,
})
}
}
}
// Round-trip tests for `Connect`: one fixed byte sequence and the packet it
// represents, checked in both the encoding and decoding directions.
#[cfg(test)]
mod unit_connect {
use super::*;
use async_std::io::Cursor;
// Wire form of the packet built by `decoded()`. In order:
//   0, 4, 77, 81, 84, 84      length-prefixed protocol name "MQTT"
//   5                         protocol version
//   206                       connect flags (0b1100_1110: user name,
//                             password, will QoS 1, will, clean start)
//   0, 10                     keep alive
//   5                         length of the connect properties
//   17, 0, 0, 0, 10           the only connect property in this fixture
//                             (session expiry interval of 10 in `decoded()`)
//   0, 0                      empty client identifier
//   3                         length of the will properties
//   3, 0, 0                   the only will property in this fixture
//                             (an empty content type)
//   0, 6, "CloZee"            will topic
//   0, 0                      empty will message
//   0, 6, "Willow"            user name
//   0, 5, "Jaden"             password
fn encoded() -> Vec<u8> {
vec![
0, 4, 77, 81, 84, 84, 5, 206, 0, 10, 5, 17, 0, 0, 0, 10, 0, 0, 3, 3, 0, 0, 0, 6, 67,
108, 111, 90, 101, 101, 0, 0, 0, 6, 87, 105, 108, 108, 111, 119, 0, 5, 74, 97, 100,
101, 110,
]
}
// In-memory form of the packet whose wire form is `encoded()`. Every field
// not listed here keeps its `Default` value.
fn decoded() -> Connect {
let keep_alive = 10;
let session_expiry_interval = 10;
Connect {
keep_alive,
clean_start: true,
session_expiry_interval,
user_name: Some("Willow".into()),
password: Some("Jaden".into()),
will: Some(Will {
qos: QoS::AtLeastOnce,
topic: "CloZee".into(),
..Default::default()
}),
..Default::default()
}
}
// Writing `decoded()` must produce exactly `encoded()` and report its
// length (47 bytes).
#[async_std::test]
async fn encode() {
let test_data = decoded();
let mut tested_result = Vec::new();
let n_bytes = test_data.write(&mut tested_result).await.unwrap();
assert_eq!(tested_result, encoded());
assert_eq!(n_bytes, 47);
}
// Reading `encoded()` must rebuild a packet equal to `decoded()`.
#[async_std::test]
async fn decode() {
let mut test_data = Cursor::new(encoded());
let tested_result = Connect::read(&mut test_data).await.unwrap();
assert_eq!(tested_result, decoded());
}
}