use std::convert::TryFrom;
use crate::base::{PROTOCOL_NAME, PROTOCOL_NAME_V3};
use crate::connect_flags::ConnectFlags;
use crate::utils::validate_client_id;
use crate::{
validate_keep_alive, BinaryData, ByteArray, DecodeError, DecodePacket, EncodeError,
EncodePacket, FixedHeader, KeepAlive, Packet, PacketType, ProtocolLevel, PubTopic, QoS,
StringData, VarIntError,
};
/// In-memory representation of an MQTT CONNECT packet.
///
/// Instances are built with [`ConnectPacket::new`] / [`ConnectPacket::new_v3`]
/// or produced by decoding a byte buffer. The will, username and password
/// fields are only written to (and read from) the wire when the matching bit
/// in `connect_flags` is set.
// NOTE(review): the lint is presumably silenced because the type name repeats
// the name of the module it lives in; confirm against the module layout.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectPacket {
/// Protocol name string; the constructors and the protocol-level setter store
/// either `PROTOCOL_NAME` or `PROTOCOL_NAME_V3` here.
protocol_name: StringData,
/// Protocol revision. Decoding and the protocol-level setter accept only
/// `V3` and `V4`; `V5` is rejected by both.
protocol_level: ProtocolLevel,
/// Flag set that decides which of the optional payload fields below take
/// part in encoding and decoding.
connect_flags: ConnectFlags,
/// Keep-alive value; both constructors initialise it to 60.
keep_alive: KeepAlive,
/// Client identifier, checked with `validate_client_id` whenever it is set
/// or decoded.
client_id: StringData,
/// Will topic, `None` when no topic is stored. Encoding asserts that this is
/// `Some` whenever the will flag is set.
will_topic: Option<PubTopic>,
/// Will payload; only encoded when the will flag is set.
will_message: BinaryData,
/// User name; only encoded when the username flag is set.
username: StringData,
/// Password bytes; only encoded when the password flag is set.
password: BinaryData,
}
impl ConnectPacket {
pub fn new(client_id: &str) -> Result<Self, EncodeError> {
let protocol_name = StringData::from(PROTOCOL_NAME)?;
validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
let client_id = StringData::from(client_id)?;
Ok(Self {
protocol_name,
keep_alive: KeepAlive::new(60),
client_id,
..Self::default()
})
}
pub fn new_v3(client_id: &str) -> Result<Self, EncodeError> {
let protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
let protocol_level = ProtocolLevel::V3;
validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
let client_id = StringData::from(client_id)?;
Ok(Self {
protocol_name,
protocol_level,
keep_alive: KeepAlive::new(60),
client_id,
..Self::default()
})
}
pub fn set_protcol_level(&mut self, level: ProtocolLevel) -> Result<(), EncodeError> {
match level {
ProtocolLevel::V3 => {
self.protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
}
ProtocolLevel::V4 => {
self.protocol_name = StringData::from(PROTOCOL_NAME)?;
}
ProtocolLevel::V5 => {
return Err(EncodeError::InvalidPacketLevel);
}
}
self.protocol_level = level;
Ok(())
}
#[must_use]
#[inline]
pub const fn protocol_level(&self) -> ProtocolLevel {
self.protocol_level
}
pub fn set_connect_flags(&mut self, flags: ConnectFlags) -> &Self {
self.connect_flags = flags;
self
}
#[must_use]
#[inline]
pub const fn connect_flags(&self) -> &ConnectFlags {
&self.connect_flags
}
pub fn set_keep_alive(&mut self, keep_alive: u16) -> &mut Self {
self.keep_alive = KeepAlive::new(keep_alive);
self
}
#[must_use]
#[inline]
pub const fn keep_alive(&self) -> u16 {
self.keep_alive.value()
}
pub fn set_client_id(&mut self, client_id: &str) -> Result<&mut Self, EncodeError> {
validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
self.client_id = StringData::from(client_id)?;
Ok(self)
}
#[must_use]
pub fn client_id(&self) -> &str {
self.client_id.as_ref()
}
pub fn set_username(&mut self, username: &str) -> Result<&mut Self, EncodeError> {
self.username = StringData::from(username)?;
Ok(self)
}
#[must_use]
pub fn username(&self) -> &str {
self.username.as_ref()
}
pub fn set_password(&mut self, password: &[u8]) -> Result<&mut Self, EncodeError> {
self.password = BinaryData::from_slice(password)?;
Ok(self)
}
#[must_use]
pub fn password(&self) -> &[u8] {
self.password.as_ref()
}
pub fn set_will_topic(&mut self, topic: &str) -> Result<&mut Self, EncodeError> {
if topic.is_empty() {
self.will_topic = None;
} else {
self.will_topic = Some(PubTopic::new(topic)?);
}
Ok(self)
}
#[must_use]
pub fn will_topic(&self) -> Option<&str> {
self.will_topic.as_ref().map(AsRef::as_ref)
}
pub fn set_will_message(&mut self, message: &[u8]) -> Result<&mut Self, EncodeError> {
self.will_message = BinaryData::from_slice(message)?;
Ok(self)
}
#[must_use]
pub fn will_message(&self) -> &[u8] {
self.will_message.as_ref()
}
fn get_fixed_header(&self) -> Result<FixedHeader, VarIntError> {
let mut remaining_length = self.protocol_name.bytes()
+ ProtocolLevel::bytes()
+ ConnectFlags::bytes()
+ KeepAlive::bytes()
+ self.client_id.bytes();
if self.connect_flags.will() {
assert!(self.will_topic.is_some());
if let Some(will_topic) = &self.will_topic {
remaining_length += will_topic.bytes();
}
remaining_length += self.will_message.bytes();
}
if self.connect_flags.has_username() {
remaining_length += self.username.bytes();
}
if self.connect_flags.has_password() {
remaining_length += self.password.bytes();
}
FixedHeader::new(PacketType::Connect, remaining_length)
}
}
impl EncodePacket for ConnectPacket {
    /// Append the wire form of this packet to `v` and return the number of
    /// bytes written.
    ///
    /// Fields are written in this order: fixed header, protocol name,
    /// protocol level, connect flags, keep-alive, client id, then the will
    /// topic and message, user name and password, each only if its connect
    /// flag is set.
    ///
    /// On error, bytes already appended are left in `v`.
    ///
    /// # Panics
    ///
    /// Panics if the will flag is set while no will topic is stored.
    fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
        let start = v.len();
        let flags = &self.connect_flags;

        // Fixed header and the mandatory part of the packet.
        self.get_fixed_header()?.encode(v)?;
        self.protocol_name.encode(v)?;
        self.protocol_level.encode(v)?;
        flags.encode(v)?;
        self.keep_alive.encode(v)?;
        self.client_id.encode(v)?;

        // Optional payload, gated by the connect flags.
        if flags.will() {
            assert!(self.will_topic.is_some());
            if let Some(topic) = self.will_topic.as_ref() {
                topic.encode(v)?;
            }
            self.will_message.encode(v)?;
        }
        if flags.has_username() {
            self.username.encode(v)?;
        }
        if flags.has_password() {
            self.password.encode(v)?;
        }

        Ok(v.len() - start)
    }
}
impl DecodePacket for ConnectPacket {
    /// Parse a CONNECT packet from `ba`.
    ///
    /// Fields are read in wire order: fixed header, protocol name, protocol
    /// level, connect flags, keep-alive, client id, then the will topic and
    /// message, user name and password, each only if its connect flag is set.
    ///
    /// # Errors
    ///
    /// * [`DecodeError::InvalidPacketType`] if the fixed header does not
    ///   describe a CONNECT packet.
    /// * [`DecodeError::InvalidProtocolName`] if the protocol name does not
    ///   belong to the decoded protocol level.
    /// * [`DecodeError::InvalidProtocolLevel`] for [`ProtocolLevel::V5`].
    /// * [`DecodeError::InvalidConnectFlags`] if will QoS or will retain is
    ///   set without the will flag, or the password flag is set without the
    ///   username flag.
    /// * [`DecodeError::InvalidClientId`] if the client id cannot be read, or
    ///   is empty while the clean-session flag is not set.
    /// * Any error returned by the field decoders, `validate_keep_alive` or
    ///   `validate_client_id`.
    fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
        let header = FixedHeader::decode(ba)?;
        if header.packet_type() != PacketType::Connect {
            return Err(DecodeError::InvalidPacketType);
        }

        // The protocol name must be the one that belongs to the level.
        let protocol_name = StringData::decode(ba)?;
        let protocol_level = ProtocolLevel::try_from(ba.read_byte()?)?;
        let name_matches_level = match protocol_level {
            ProtocolLevel::V3 => protocol_name.as_ref() == PROTOCOL_NAME_V3,
            ProtocolLevel::V4 => protocol_name.as_ref() == PROTOCOL_NAME,
            ProtocolLevel::V5 => return Err(DecodeError::InvalidProtocolLevel),
        };
        if !name_matches_level {
            return Err(DecodeError::InvalidProtocolName);
        }

        // Will QoS / retain need the will flag; a password needs a user name.
        let connect_flags = ConnectFlags::decode(ba)?;
        let stray_will_bits = !connect_flags.will()
            && (connect_flags.will_qos() != QoS::AtMostOnce || connect_flags.will_retain());
        let password_without_username =
            connect_flags.has_password() && !connect_flags.has_username();
        if stray_will_bits || password_without_username {
            return Err(DecodeError::InvalidConnectFlags);
        }

        let keep_alive = KeepAlive::decode(ba)?;
        validate_keep_alive(keep_alive)?;

        // An empty client id is only acceptable together with clean session.
        let client_id = StringData::decode(ba).map_err(|_err| DecodeError::InvalidClientId)?;
        if client_id.is_empty() && !connect_flags.clean_session() {
            return Err(DecodeError::InvalidClientId);
        }
        validate_client_id(client_id.as_ref())?;

        // Optional payload: each field is on the wire only if its flag is set.
        let (will_topic, will_message) = if connect_flags.will() {
            let topic = PubTopic::decode(ba)?;
            let message = BinaryData::decode(ba)?;
            (Some(topic), message)
        } else {
            (None, BinaryData::new())
        };
        let username = if connect_flags.has_username() {
            StringData::decode(ba)?
        } else {
            StringData::new()
        };
        let password = if connect_flags.has_password() {
            BinaryData::decode(ba)?
        } else {
            BinaryData::new()
        };

        Ok(Self {
            protocol_name,
            protocol_level,
            connect_flags,
            keep_alive,
            client_id,
            will_topic,
            will_message,
            username,
            password,
        })
    }
}
impl Packet for ConnectPacket {
    /// Always reports [`PacketType::Connect`].
    fn packet_type(&self) -> PacketType {
        PacketType::Connect
    }

    /// Total encoded size of the packet: the size of the fixed header plus
    /// the remaining length it announces.
    ///
    /// # Errors
    ///
    /// Returns the error from building the fixed header.
    ///
    /// # Panics
    ///
    /// Panics if the will flag is set while no will topic is stored.
    fn bytes(&self) -> Result<usize, VarIntError> {
        self.get_fixed_header()
            .map(|header| header.bytes() + header.remaining_length())
    }
}
#[cfg(test)]
mod tests {
    use super::{ByteArray, ConnectPacket, DecodePacket};

    /// Decodes a complete CONNECT frame and checks the recovered client id.
    #[test]
    fn test_decode() {
        // Fixed header (type byte 16, remaining length 20), then the
        // length-prefixed protocol name.
        let mut buf: Vec<u8> = vec![16, 20, 0, 4];
        buf.extend_from_slice(b"MQTT");
        // Protocol level 4, connect flags 2, keep-alive 60, client id length 8.
        buf.extend_from_slice(&[4, 2, 0, 60, 0, 8]);
        buf.extend_from_slice(b"wvPTXcCw");

        let mut ba = ByteArray::new(&buf);
        let packet = ConnectPacket::decode(&mut ba).expect("CONNECT frame should decode");
        assert_eq!(packet.client_id(), "wvPTXcCw");
    }
}