use std::convert::TryFrom;
use super::property::check_property_type_list;
use super::{Properties, PropertyType};
use crate::base::PROTOCOL_NAME;
use crate::connect_flags::ConnectFlags;
use crate::utils::validate_client_id;
use crate::{
validate_keep_alive, BinaryData, ByteArray, DecodeError, DecodePacket, EncodeError,
EncodePacket, FixedHeader, KeepAlive, Packet, PacketType, ProtocolLevel, PubTopic, QoS,
StringData, VarIntError,
};
/// An MQTT v5 CONNECT packet.
///
/// Fields are declared in the same order `decode` reads them from the wire:
/// the variable header (protocol name, level, flags, keep-alive, properties)
/// followed by the payload (client id, will properties/topic/message,
/// username, password).
// The type name repeats its module's name (presumably a `connect` module);
// the pedantic clippy lint is silenced to keep the conventional `*Packet` name.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectPacket {
// Protocol name; `new` sets it to `PROTOCOL_NAME` and `decode` rejects any other value.
protocol_name: StringData,
protocol_level: ProtocolLevel,
// Flags that decide which optional payload fields (will, username, password) are present.
connect_flags: ConnectFlags,
// Keep-alive interval; `new` defaults it to 60 (presumably seconds — confirm in `KeepAlive`).
keep_alive: KeepAlive,
// CONNECT properties; `decode` only accepts types listed in `CONNECT_PROPERTIES`.
properties: Properties,
client_id: StringData,
// Will properties; only read when the will flag is set, restricted to `CONNECT_WILL_PROPERTIES`.
will_properties: Properties,
// `None` when no will topic is configured (`set_will_topic("")` also clears it).
will_topic: Option<PubTopic>,
will_message: BinaryData,
// Only written/read when the has-username connect flag is set.
username: StringData,
// Only written/read when the has-password connect flag is set.
password: BinaryData,
}
/// Property types permitted in the properties section of a CONNECT packet.
///
/// Mirrors MQTT 5.0 §3.1.2.11. Any other property type makes
/// `ConnectPacket::decode` fail with `DecodeError::InvalidPropertyType`.
pub const CONNECT_PROPERTIES: &[PropertyType] = &[
    PropertyType::SessionExpiryInterval,
    PropertyType::ReceiveMaximum,
    PropertyType::MaximumPacketSize,
    PropertyType::TopicAliasMaximum,
    // Request Response Information (0x19) is a valid CONNECT property; it was
    // previously missing, so spec-compliant clients sending it were rejected.
    PropertyType::RequestResponseInformation,
    PropertyType::RequestProblemInformation,
    PropertyType::UserProperty,
    PropertyType::AuthenticationMethod,
    PropertyType::AuthenticationData,
];
/// Property types permitted in the will-properties section of a CONNECT payload.
///
/// Checked by `ConnectPacket::decode`; any other property type fails with
/// `DecodeError::InvalidPropertyType`.
pub const CONNECT_WILL_PROPERTIES: &[PropertyType] = &[
PropertyType::WillDelayInterval,
PropertyType::PayloadFormatIndicator,
PropertyType::MessageExpiryInterval,
PropertyType::ContentType,
PropertyType::ResponseTopic,
PropertyType::CorrelationData,
PropertyType::UserProperty,
];
impl ConnectPacket {
    /// Creates a CONNECT packet for `client_id`.
    ///
    /// The protocol name is set to [`PROTOCOL_NAME`] and keep-alive to 60;
    /// every other field takes its `Default` value.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::InvalidClientId`] if `validate_client_id` rejects
    /// `client_id`, or the error from [`StringData::from`] if it cannot be stored.
    pub fn new(client_id: &str) -> Result<Self, EncodeError> {
        let protocol_name = StringData::from(PROTOCOL_NAME)?;
        validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
        let client_id = StringData::from(client_id)?;
        Ok(Self {
            protocol_name,
            keep_alive: KeepAlive::new(60),
            client_id,
            ..Self::default()
        })
    }

    /// Sets the protocol level.
    pub fn set_protocol_level(&mut self, level: ProtocolLevel) -> &mut Self {
        self.protocol_level = level;
        self
    }

    /// Misspelled alias of [`Self::set_protocol_level`], kept so existing callers compile.
    pub fn set_protcol_level(&mut self, level: ProtocolLevel) -> &mut Self {
        self.set_protocol_level(level)
    }

    /// Returns the protocol level.
    #[must_use]
    pub const fn protocol_level(&self) -> ProtocolLevel {
        self.protocol_level
    }

    /// Replaces all connect flags at once.
    ///
    /// Returns `&mut Self` like every other setter so calls can be chained.
    pub fn set_connect_flags(&mut self, flags: ConnectFlags) -> &mut Self {
        self.connect_flags = flags;
        self
    }

    /// Returns the connect flags.
    #[must_use]
    #[inline]
    pub const fn connect_flags(&self) -> &ConnectFlags {
        &self.connect_flags
    }

    /// Sets the keep-alive interval.
    pub fn set_keep_alive(&mut self, keep_alive: u16) -> &mut Self {
        self.keep_alive = KeepAlive::new(keep_alive);
        self
    }

    /// Returns the keep-alive interval.
    #[must_use]
    pub const fn keep_alive(&self) -> u16 {
        self.keep_alive.value()
    }

    /// Sets the will-retain connect flag.
    pub fn set_will_retain(&mut self, will_retain: bool) -> &mut Self {
        self.connect_flags.set_will_retain(will_retain);
        self
    }

    /// Returns the will-retain connect flag.
    #[must_use]
    pub const fn will_retain(&self) -> bool {
        self.connect_flags.will_retain()
    }

    /// Sets the will QoS connect flag.
    pub fn set_will_qos(&mut self, qos: QoS) -> &mut Self {
        self.connect_flags.set_will_qos(qos);
        self
    }

    /// Returns the will QoS connect flag.
    #[must_use]
    pub const fn will_qos(&self) -> QoS {
        self.connect_flags.will_qos()
    }

    /// Sets the will connect flag.
    ///
    /// When set, a will topic must also be configured via [`Self::set_will_topic`]
    /// before encoding.
    pub fn set_will(&mut self, will: bool) -> &mut Self {
        self.connect_flags.set_will(will);
        self
    }

    /// Returns the will connect flag.
    #[must_use]
    pub const fn will(&self) -> bool {
        self.connect_flags.will()
    }

    /// Sets the clean-session (v5: clean start) connect flag.
    pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self {
        self.connect_flags.set_clean_session(clean_session);
        self
    }

    /// Returns the clean-session connect flag.
    #[must_use]
    pub const fn clean_session(&self) -> bool {
        self.connect_flags.clean_session()
    }

    /// Mutable access to the CONNECT properties.
    pub fn properties_mut(&mut self) -> &mut Properties {
        &mut self.properties
    }

    /// Returns the CONNECT properties.
    #[must_use]
    pub const fn properties(&self) -> &Properties {
        &self.properties
    }

    /// Sets the client identifier.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::InvalidClientId`] if `validate_client_id` rejects
    /// `client_id`, or the error from [`StringData::from`] if it cannot be stored.
    pub fn set_client_id(&mut self, client_id: &str) -> Result<&mut Self, EncodeError> {
        validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
        self.client_id = StringData::from(client_id)?;
        Ok(self)
    }

    /// Returns the client identifier.
    #[must_use]
    pub fn client_id(&self) -> &str {
        self.client_id.as_ref()
    }

    /// Sets or clears the username and updates the has-username flag to match.
    ///
    /// # Errors
    ///
    /// Returns an error if `username` cannot be stored as [`StringData`]. The error
    /// type is `DecodeError` (not `EncodeError` like the other setters) for
    /// backward compatibility.
    pub fn set_username(&mut self, username: Option<&str>) -> Result<&mut Self, DecodeError> {
        if let Some(username) = username {
            self.username = StringData::from(username)?;
            self.connect_flags.set_has_username(true);
        } else {
            self.connect_flags.set_has_username(false);
            self.username = StringData::new();
        }
        Ok(self)
    }

    /// Returns the username (empty if none was set).
    #[must_use]
    pub fn username(&self) -> &str {
        self.username.as_ref()
    }

    /// Sets or clears the password and updates the has-password flag to match.
    ///
    /// # Errors
    ///
    /// Returns an error if `password` cannot be stored as [`BinaryData`].
    pub fn set_password(&mut self, password: Option<&[u8]>) -> Result<&mut Self, EncodeError> {
        if let Some(password) = password {
            self.connect_flags.set_has_password(true);
            self.password = BinaryData::from_slice(password)?;
        } else {
            self.connect_flags.set_has_password(false);
            self.password.clear();
        }
        Ok(self)
    }

    /// Returns the password (empty if none was set).
    #[must_use]
    pub fn password(&self) -> &[u8] {
        self.password.as_ref()
    }

    /// Mutable access to the will properties.
    pub fn will_properties_mut(&mut self) -> &mut Properties {
        &mut self.will_properties
    }

    /// Returns the will properties.
    #[must_use]
    pub const fn will_properties(&self) -> &Properties {
        &self.will_properties
    }

    /// Sets the will topic; an empty string clears it.
    ///
    /// # Errors
    ///
    /// Returns an error if `topic` is not a valid publish topic.
    pub fn set_will_topic(&mut self, topic: &str) -> Result<&mut Self, EncodeError> {
        self.will_topic = if topic.is_empty() {
            None
        } else {
            Some(PubTopic::new(topic)?)
        };
        Ok(self)
    }

    /// Returns the will topic, if one is set.
    #[must_use]
    pub fn will_topic(&self) -> Option<&str> {
        self.will_topic.as_ref().map(AsRef::as_ref)
    }

    /// Sets the will message payload.
    ///
    /// # Errors
    ///
    /// Returns an error if `message` cannot be stored as [`BinaryData`].
    pub fn set_will_message(&mut self, message: &[u8]) -> Result<&mut Self, EncodeError> {
        self.will_message = BinaryData::from_slice(message)?;
        Ok(self)
    }

    /// Returns the will message payload.
    #[must_use]
    pub fn will_message(&self) -> &[u8] {
        self.will_message.as_ref()
    }

    /// Builds the fixed header, computing the remaining length of exactly the
    /// fields that [`EncodePacket::encode`] writes.
    ///
    /// # Panics
    ///
    /// Panics if the will flag is set but no will topic is configured.
    fn get_fixed_header(&self) -> Result<FixedHeader, VarIntError> {
        // Variable header. MQTT v5 places the properties right after keep-alive;
        // they were previously omitted, producing a wrong remaining length.
        let mut remaining_length = self.protocol_name.bytes()
            + ProtocolLevel::bytes()
            + ConnectFlags::bytes()
            + KeepAlive::bytes()
            + self.properties.bytes()
            + self.client_id.bytes();

        if self.connect_flags.will() {
            assert!(
                self.will_topic.is_some(),
                "will flag is set but no will topic is configured"
            );
            // Will properties precede the will topic in the v5 payload.
            remaining_length += self.will_properties.bytes();
            if let Some(will_topic) = &self.will_topic {
                remaining_length += will_topic.bytes();
            }
            remaining_length += self.will_message.bytes();
        }
        if self.connect_flags.has_username() {
            remaining_length += self.username.bytes();
        }
        if self.connect_flags.has_password() {
            remaining_length += self.password.bytes();
        }
        FixedHeader::new(PacketType::Connect, remaining_length)
    }
}

impl EncodePacket for ConnectPacket {
    /// Appends the encoded CONNECT packet to `v` and returns the number of bytes written.
    ///
    /// The layout matches what `ConnectPacket::decode` expects. It includes the v5
    /// CONNECT properties and, when the will flag is set, the will properties.
    ///
    /// # Errors
    ///
    /// Propagates any error from encoding the fixed header or an individual field.
    ///
    /// # Panics
    ///
    /// Panics if the will flag is set but no will topic is configured.
    fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
        let old_len = v.len();
        let fixed_header = self.get_fixed_header()?;
        fixed_header.encode(v)?;

        // Variable header.
        self.protocol_name.encode(v)?;
        self.protocol_level.encode(v)?;
        self.connect_flags.encode(v)?;
        self.keep_alive.encode(v)?;
        self.properties.encode(v)?;

        // Payload.
        self.client_id.encode(v)?;
        if self.connect_flags.will() {
            assert!(
                self.will_topic.is_some(),
                "will flag is set but no will topic is configured"
            );
            self.will_properties.encode(v)?;
            if let Some(will_topic) = &self.will_topic {
                will_topic.encode(v)?;
            }
            self.will_message.encode(v)?;
        }
        if self.connect_flags.has_username() {
            self.username.encode(v)?;
        }
        if self.connect_flags.has_password() {
            self.password.encode(v)?;
        }
        Ok(v.len() - old_len)
    }
}
impl DecodePacket for ConnectPacket {
fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
let fixed_header = FixedHeader::decode(ba)?;
if fixed_header.packet_type() != PacketType::Connect {
return Err(DecodeError::InvalidPacketType);
}
let protocol_name = StringData::decode(ba)?;
if protocol_name.as_ref() != PROTOCOL_NAME {
return Err(DecodeError::InvalidProtocolName);
}
let protocol_level = ProtocolLevel::try_from(ba.read_byte()?)?;
let connect_flags = ConnectFlags::decode(ba)?;
if !connect_flags.will()
&& (connect_flags.will_qos() != QoS::AtMostOnce || connect_flags.will_retain())
{
return Err(DecodeError::InvalidConnectFlags);
}
if !connect_flags.has_username() && connect_flags.has_password() {
return Err(DecodeError::InvalidConnectFlags);
}
let keep_alive = KeepAlive::decode(ba)?;
validate_keep_alive(keep_alive)?;
let properties = Properties::decode(ba);
let properties = match properties {
Ok(properties) => properties,
Err(err) => {
log::error!("err: {:?}", err);
return Err(DecodeError::InvalidPropertyType);
}
};
if let Err(property_type) = check_property_type_list(properties.props(), CONNECT_PROPERTIES)
{
log::error!(
"v5/ConnectPacket: property type {:?} cannot be used in properties!",
property_type
);
return Err(DecodeError::InvalidPropertyType);
}
let client_id = StringData::decode(ba).map_err(|_err| DecodeError::InvalidClientId)?;
if client_id.is_empty() && !connect_flags.clean_session() {
return Err(DecodeError::InvalidClientId);
}
validate_client_id(client_id.as_ref())?;
let will_properties = if connect_flags.will() {
Properties::decode(ba)?
} else {
Properties::new()
};
if let Err(property_type) =
check_property_type_list(will_properties.props(), CONNECT_WILL_PROPERTIES)
{
log::error!(
"v5/ConnectPacket: property type {:?} cannot be used in will properties!",
property_type
);
return Err(DecodeError::InvalidPropertyType);
}
let will_topic = if connect_flags.will() {
Some(PubTopic::decode(ba)?)
} else {
None
};
let will_message = if connect_flags.will() {
BinaryData::decode(ba)?
} else {
BinaryData::new()
};
let username = if connect_flags.has_username() {
StringData::decode(ba)?
} else {
StringData::new()
};
let password = if connect_flags.has_password() {
BinaryData::decode(ba)?
} else {
BinaryData::new()
};
Ok(Self {
protocol_name,
protocol_level,
connect_flags,
keep_alive,
properties,
client_id,
will_properties,
will_topic,
will_message,
username,
password,
})
}
}
impl Packet for ConnectPacket {
    /// A CONNECT packet always reports [`PacketType::Connect`].
    fn packet_type(&self) -> PacketType {
        PacketType::Connect
    }

    /// Total encoded size in bytes: the fixed header itself plus the remaining length it announces.
    fn bytes(&self) -> Result<usize, VarIntError> {
        self.get_fixed_header()
            .map(|header| header.bytes() + header.remaining_length())
    }
}
#[cfg(test)]
mod tests {
    use super::{ByteArray, ConnectPacket, DecodePacket};

    /// Decodes a minimal v5 CONNECT (clean start, no will/username/password,
    /// empty property list) and checks the client id survives.
    #[test]
    fn test_decode() {
        let buf: Vec<u8> = vec![
            0x10, 0x15, // fixed header: CONNECT, remaining length 21
            0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, // protocol name "MQTT"
            0x05, // protocol level 5
            0x02, // connect flags: clean start
            0x00, 0x3c, // keep-alive 60
            0x00, // properties length 0
            0x00, 0x08, 0x77, 0x76, 0x50, 0x54, 0x58, 0x63, 0x43, 0x77, // client id "wvPTXcCw"
        ];
        let mut ba = ByteArray::new(&buf);
        let packet = ConnectPacket::decode(&mut ba).expect("well-formed CONNECT must decode");
        assert_eq!(packet.client_id(), "wvPTXcCw");
    }
}