use alloc::vec::Vec;
use core::fmt;
use derive_builder::Builder;
#[cfg(feature = "std")]
use std::io::IoSlice;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use getset::{CopyGetters, Getters};
use crate::mqtt::packet::json_bin_encode::escape_binary_json_string;
use crate::mqtt::packet::mqtt_binary::MqttBinary;
use crate::mqtt::packet::mqtt_string::MqttString;
use crate::mqtt::packet::packet_type::{FixedHeader, PacketType};
use crate::mqtt::packet::qos::Qos;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
use crate::mqtt::result_code::MqttError;
use core::convert::TryInto;
/// CONNECT control packet for protocol level 4 (MQTT v3.1.1).
///
/// The packet is stored as its raw wire-format pieces, so that `to_buffers` and
/// `to_continuous_buffer` can emit it without re-encoding. Instances come from
/// `Connect::parse` or from `ConnectBuilder::build`; every field is `#[builder(private)]`
/// and is populated through the builder's hand-written setters. `Debug` is implemented
/// separately and renders the same JSON as `Display`. The derived `PartialEq` compares
/// the raw buffers, including optional fields whose flag is not set.
#[derive(PartialEq, Eq, Builder, Clone, Getters, CopyGetters)]
#[builder(no_std, derive(Debug), pattern = "owned", setter(into), build_fn(skip))]
pub struct Connect {
/// First fixed-header byte; always set to `FixedHeader::Connect as u8`.
#[builder(private)]
fixed_header: [u8; 1],
/// Remaining Length: number of bytes in the variable header plus payload.
#[builder(private)]
remaining_length: VariableByteInteger,
/// Protocol Name as on the wire: 2-byte length prefix `0x00 0x04` followed by `"MQTT"`.
#[builder(private)]
protocol_name: [u8; 6],
/// Protocol Level byte; `parse` only accepts `0x04` and `build` writes `0x04`.
#[builder(private)]
protocol_version_buf: [u8; 1],
/// Connect Flags byte. Bit 1 is Clean Session, bit 2 is Will, bits 3-4 are Will QoS,
/// bit 5 is Will Retain, bit 6 is Password and bit 7 is User Name.
#[builder(private)]
connect_flags_buf: [u8; 1],
/// Keep Alive in seconds, stored big-endian.
#[builder(private)]
keep_alive_buf: [u8; 2],
/// Client Identifier.
#[builder(private)]
client_id_buf: MqttString,
/// Will Topic; only encoded and exposed when the Will flag is set, otherwise a default value.
#[builder(private)]
will_topic_buf: MqttString,
/// Will Message payload; only encoded and exposed when the Will flag is set.
#[builder(private)]
will_payload_buf: MqttBinary,
/// User Name; only encoded and exposed when the User Name flag is set.
#[builder(private)]
user_name_buf: MqttString,
/// Password; only encoded and exposed when the Password flag is set.
#[builder(private)]
password_buf: MqttBinary,
}
impl Connect {
pub fn builder() -> ConnectBuilder {
ConnectBuilder::default()
}
pub fn packet_type() -> PacketType {
PacketType::Connect
}
pub fn protocol_name(&self) -> &str {
"MQTT"
}
pub fn protocol_version(&self) -> u8 {
self.protocol_version_buf[0]
}
pub fn clean_session(&self) -> bool {
(self.connect_flags_buf[0] & 0b0000_0010) != 0
}
pub fn clean_start(&self) -> bool {
self.clean_session()
}
pub fn will_flag(&self) -> bool {
(self.connect_flags_buf[0] & 0b0000_0100) != 0
}
pub fn will_qos(&self) -> Qos {
let qos_bits = (self.connect_flags_buf[0] >> 3) & 0x03;
Qos::try_from(qos_bits).unwrap_or(Qos::AtMostOnce)
}
pub fn will_retain(&self) -> bool {
(self.connect_flags_buf[0] & 0b0010_0000) != 0
}
pub fn password_flag(&self) -> bool {
(self.connect_flags_buf[0] & 0b0100_0000) != 0
}
pub fn user_name_flag(&self) -> bool {
(self.connect_flags_buf[0] & 0b1000_0000) != 0
}
pub fn keep_alive(&self) -> u16 {
u16::from_be_bytes(self.keep_alive_buf)
}
pub fn client_id(&self) -> &str {
self.client_id_buf.as_str()
}
pub fn will_topic(&self) -> Option<&str> {
if self.will_flag() {
Some(self.will_topic_buf.as_str())
} else {
None
}
}
pub fn will_payload(&self) -> Option<&[u8]> {
if self.will_flag() {
Some(self.will_payload_buf.as_slice())
} else {
None
}
}
pub fn user_name(&self) -> Option<&str> {
if self.user_name_flag() {
Some(self.user_name_buf.as_str())
} else {
None
}
}
pub fn password(&self) -> Option<&[u8]> {
if self.password_flag() {
Some(self.password_buf.as_slice())
} else {
None
}
}
pub fn size(&self) -> usize {
1 + self.remaining_length.size() + self.remaining_length.to_u32() as usize
}
#[cfg(feature = "std")]
pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
let mut bufs = Vec::new();
bufs.push(IoSlice::new(&self.fixed_header));
bufs.push(IoSlice::new(self.remaining_length.as_bytes()));
bufs.push(IoSlice::new(&self.protocol_name));
bufs.push(IoSlice::new(&self.protocol_version_buf));
bufs.push(IoSlice::new(&self.connect_flags_buf));
bufs.push(IoSlice::new(&self.keep_alive_buf));
bufs.extend(self.client_id_buf.to_buffers());
if self.will_flag() {
bufs.extend(self.will_topic_buf.to_buffers());
bufs.extend(self.will_payload_buf.to_buffers());
}
if self.user_name_flag() {
bufs.extend(self.user_name_buf.to_buffers());
}
if self.password_flag() {
bufs.extend(self.password_buf.to_buffers());
}
bufs
}
pub fn to_continuous_buffer(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&self.fixed_header);
buf.extend_from_slice(self.remaining_length.as_bytes());
buf.extend_from_slice(&self.protocol_name);
buf.extend_from_slice(&self.protocol_version_buf);
buf.extend_from_slice(&self.connect_flags_buf);
buf.extend_from_slice(&self.keep_alive_buf);
buf.append(&mut self.client_id_buf.to_continuous_buffer());
if self.will_flag() {
buf.append(&mut self.will_topic_buf.to_continuous_buffer());
buf.append(&mut self.will_payload_buf.to_continuous_buffer());
}
if self.user_name_flag() {
buf.append(&mut self.user_name_buf.to_continuous_buffer());
}
if self.password_flag() {
buf.append(&mut self.password_buf.to_continuous_buffer());
}
buf
}
pub fn parse(data: &[u8]) -> Result<(Self, usize), MqttError> {
let mut cursor = 0;
if data.len() < cursor + 6 {
return Err(MqttError::MalformedPacket);
}
let protocol_name = [
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
data[cursor + 4],
data[cursor + 5],
];
cursor += 6;
if &protocol_name[2..] != b"MQTT" || protocol_name[0] != 0x00 || protocol_name[1] != 0x04 {
return Err(MqttError::ProtocolError);
}
if data.len() < cursor + 1 {
return Err(MqttError::MalformedPacket);
}
let protocol_version = data[cursor];
let protocol_version_buf = [protocol_version];
cursor += 1;
if protocol_version != 0x04 {
return Err(MqttError::UnsupportedProtocolVersion);
}
if data.len() < cursor + 1 {
return Err(MqttError::MalformedPacket);
}
let connect_flags = data[cursor];
let connect_flags_buf = [connect_flags];
cursor += 1;
if data.len() < cursor + 2 {
return Err(MqttError::MalformedPacket);
}
let keep_alive_buf = [data[cursor], data[cursor + 1]];
cursor += 2;
let (client_id_buf, consumed) =
MqttString::decode(&data[cursor..]).map_err(|_| MqttError::ClientIdentifierNotValid)?;
cursor += consumed;
let will_flag = (connect_flags & 0b0000_0100) != 0;
let mut will_topic_buf = MqttString::default();
let mut will_payload_buf = MqttBinary::default();
if will_flag {
let (w_topic, consumed) = MqttString::decode(&data[cursor..])?;
cursor += consumed;
will_topic_buf = w_topic;
let (w_payload, consumed) = MqttBinary::decode(&data[cursor..])?;
cursor += consumed;
will_payload_buf = w_payload;
}
let user_name_flag = (connect_flags & 0b1000_0000) != 0;
let mut user_name_buf = MqttString::default();
if user_name_flag {
let (uname, consumed) = MqttString::decode(&data[cursor..])
.map_err(|_| MqttError::BadUserNameOrPassword)?;
cursor += consumed;
user_name_buf = uname;
}
let password_flag = (connect_flags & 0b0100_0000) != 0;
let mut password_buf = MqttBinary::default();
if password_flag {
let (pwd, consumed) = MqttBinary::decode(&data[cursor..])
.map_err(|_| MqttError::BadUserNameOrPassword)?;
cursor += consumed;
password_buf = pwd;
}
if password_flag && !user_name_flag {
return Err(MqttError::ProtocolError);
}
let connect = Connect {
fixed_header: [FixedHeader::Connect as u8],
remaining_length: VariableByteInteger::from_u32(cursor as u32).unwrap(),
protocol_name,
protocol_version_buf,
connect_flags_buf,
keep_alive_buf,
client_id_buf,
will_topic_buf,
will_payload_buf,
user_name_buf,
password_buf,
};
Ok((connect, cursor))
}
}
impl ConnectBuilder {
    /// Connect flags assumed when none have been set yet: only Clean Session (bit 1).
    const DEFAULT_CONNECT_FLAGS: u8 = 0b0000_0010;

    /// Returns the connect flags configured so far, or the default flags if none were set.
    fn current_flags(&self) -> u8 {
        self.connect_flags_buf
            .unwrap_or([Self::DEFAULT_CONNECT_FLAGS])[0]
    }

    /// Sets the Client Identifier.
    ///
    /// # Errors
    ///
    /// Returns the error produced while converting `id` into an [`MqttString`].
    pub fn client_id<T>(mut self, id: T) -> Result<Self, MqttError>
    where
        T: TryInto<MqttString, Error = MqttError>,
    {
        self.client_id_buf = Some(id.try_into()?);
        Ok(self)
    }

    /// Sets or clears the Clean Session flag (connect flags bit 1).
    pub fn clean_session(mut self, clean: bool) -> Self {
        let mut flags = self.current_flags();
        if clean {
            flags |= 0b0000_0010;
        } else {
            flags &= !0b0000_0010;
        }
        self.connect_flags_buf = Some([flags]);
        self
    }

    /// Alias for [`ConnectBuilder::clean_session`].
    pub fn clean_start(self, clean: bool) -> Self {
        self.clean_session(clean)
    }

    /// Configures the Will message.
    ///
    /// Sets the Will flag and replaces any previously configured Will topic, payload,
    /// QoS and Retain values.
    ///
    /// # Errors
    ///
    /// Returns the error produced while converting `topic` or `payload`.
    pub fn will_message<T, B>(
        mut self,
        topic: T,
        payload: B,
        qos: Qos,
        retain: bool,
    ) -> Result<Self, MqttError>
    where
        T: TryInto<MqttString, Error = MqttError>,
        B: TryInto<MqttBinary, Error = MqttError>,
    {
        let will_topic = topic.try_into()?;
        let will_payload = payload.try_into()?;
        self.will_topic_buf = Some(will_topic);
        self.will_payload_buf = Some(will_payload);

        let mut flags = self.current_flags();
        // Clear any earlier Will QoS (bits 3-4) and Will Retain (bit 5) first. ORing on top
        // of old bits could otherwise yield the invalid QoS 3 or a stale Retain flag.
        flags &= !0b0011_1000;
        flags |= 0b0000_0100; // Will flag
        flags |= (qos as u8) << 3; // Will QoS
        if retain {
            flags |= 0b0010_0000; // Will Retain
        }
        self.connect_flags_buf = Some([flags]);
        Ok(self)
    }

    /// Sets the User Name and the User Name flag (connect flags bit 7).
    ///
    /// # Errors
    ///
    /// Returns the error produced while converting `name` into an [`MqttString`].
    pub fn user_name<T>(mut self, name: T) -> Result<Self, MqttError>
    where
        T: TryInto<MqttString, Error = MqttError>,
    {
        self.user_name_buf = Some(name.try_into()?);
        let flags = self.current_flags() | 0b1000_0000;
        self.connect_flags_buf = Some([flags]);
        Ok(self)
    }

    /// Sets the Password and the Password flag (connect flags bit 6).
    ///
    /// # Errors
    ///
    /// Returns the error produced while converting `pwd` into an [`MqttBinary`].
    pub fn password<B>(mut self, pwd: B) -> Result<Self, MqttError>
    where
        B: TryInto<MqttBinary, Error = MqttError>,
    {
        self.password_buf = Some(pwd.try_into()?);
        let flags = self.current_flags() | 0b0100_0000;
        self.connect_flags_buf = Some([flags]);
        Ok(self)
    }

    /// Sets the Keep Alive interval in seconds.
    pub fn keep_alive(mut self, seconds: u16) -> Self {
        self.keep_alive_buf = Some(seconds.to_be_bytes());
        self
    }

    /// Checks flag and field consistency before building.
    ///
    /// Returns `ProtocolError` when a password is set without a user name. Returns
    /// `MalformedPacket` when the Will flag is set but the Will topic or payload is missing.
    fn validate(&self) -> Result<(), MqttError> {
        let flags = self.current_flags();
        if (flags & 0b0100_0000) != 0 && (flags & 0b1000_0000) == 0 {
            return Err(MqttError::ProtocolError);
        }
        let will_flag = (flags & 0b0000_0100) != 0;
        if will_flag && (self.will_topic_buf.is_none() || self.will_payload_buf.is_none()) {
            return Err(MqttError::MalformedPacket);
        }
        Ok(())
    }

    /// Builds the [`Connect`] packet, computing its Remaining Length.
    ///
    /// Unset fields default to empty values. The connect flags default to Clean Session
    /// only, and Keep Alive defaults to 0.
    ///
    /// # Errors
    ///
    /// Returns the errors described on `validate`.
    pub fn build(self) -> Result<Connect, MqttError> {
        self.validate()?;
        let protocol_name = [0x00, 0x04, b'M', b'Q', b'T', b'T'];
        let protocol_version_buf = [0x04];
        let connect_flags_buf = [self.current_flags()];
        let connect_flags = connect_flags_buf[0];
        let keep_alive_buf = self.keep_alive_buf.unwrap_or([0, 0]);
        let client_id_buf = self.client_id_buf.unwrap_or_default();
        let will_topic_buf = self.will_topic_buf.unwrap_or_default();
        let will_payload_buf = self.will_payload_buf.unwrap_or_default();
        let user_name_buf = self.user_name_buf.unwrap_or_default();
        let password_buf = self.password_buf.unwrap_or_default();

        // Protocol Name (6) + Protocol Level (1) + Connect Flags (1) + Keep Alive (2)
        // + Client Identifier, followed by the optional fields selected by the flags.
        let mut remaining = 6 + 1 + 1 + 2 + client_id_buf.size();
        if (connect_flags & 0b0000_0100) != 0 {
            remaining += will_topic_buf.size();
            remaining += will_payload_buf.size();
        }
        if (connect_flags & 0b1000_0000) != 0 {
            remaining += user_name_buf.size();
        }
        if (connect_flags & 0b0100_0000) != 0 {
            remaining += password_buf.size();
        }
        let remaining_length = VariableByteInteger::from_u32(remaining as u32)
            .expect("CONNECT remaining length fits in a Variable Byte Integer");

        Ok(Connect {
            fixed_header: [FixedHeader::Connect as u8],
            remaining_length,
            protocol_name,
            protocol_version_buf,
            connect_flags_buf,
            keep_alive_buf,
            client_id_buf,
            will_topic_buf,
            will_payload_buf,
            user_name_buf,
            password_buf,
        })
    }
}
impl Serialize for Connect {
    /// Serializes a summary of the packet as a struct named `Connect`.
    ///
    /// `type`, `client_id`, `clean_start` and `keep_alive` are always emitted. `user_name`
    /// is emitted when its flag is set. When the Password flag is set, the password is
    /// emitted masked as `"*****"`. When the Will flag is set, `will_qos`, `will_retain`,
    /// `will_topic` and `will_payload` are emitted. The payload is escaped by
    /// `escape_binary_json_string` when that returns a value, otherwise it is serialized
    /// as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Unconditional fields: type, client_id, clean_start, keep_alive. The original
        // base count of 5 overstated the length hint passed to `serialize_struct`.
        let mut field_count = 4;
        if self.user_name_flag() {
            field_count += 1;
        }
        if self.password_flag() {
            field_count += 1;
        }
        if self.will_flag() {
            // will_qos, will_retain, will_topic, will_payload
            field_count += 4;
        }
        let mut state = serializer.serialize_struct("Connect", field_count)?;
        state.serialize_field("type", "connect")?;
        state.serialize_field("client_id", &self.client_id())?;
        state.serialize_field("clean_start", &self.clean_start())?;
        state.serialize_field("keep_alive", &self.keep_alive())?;
        if self.user_name_flag() {
            state.serialize_field("user_name", &self.user_name())?;
        }
        if self.password_flag() {
            // Never expose credentials in logs or debug output.
            state.serialize_field("password", "*****")?;
        }
        if self.will_flag() {
            state.serialize_field("will_qos", &self.will_qos())?;
            state.serialize_field("will_retain", &self.will_retain())?;
            state.serialize_field("will_topic", &self.will_topic())?;
            if let Some(payload) = self.will_payload() {
                match escape_binary_json_string(payload) {
                    Some(escaped) => state.serialize_field("will_payload", &escaped)?,
                    None => state.serialize_field("will_payload", &payload)?,
                }
            }
        }
        state.end()
    }
}
impl fmt::Display for Connect {
    /// Writes the packet's `Serialize` output as compact JSON.
    ///
    /// If JSON serialization fails, writes an `{"error": "..."}` object carrying the
    /// serializer's message instead.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let json = match serde_json::to_string(self) {
            Ok(text) => text,
            Err(e) => return write!(f, "{{\"error\": \"{e}\"}}"),
        };
        f.write_str(&json)
    }
}
impl fmt::Debug for Connect {
    /// Debug output is the same JSON rendering used by `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl GenericPacketTrait for Connect {
fn size(&self) -> usize {
self.size()
}
#[cfg(feature = "std")]
fn to_buffers(&self) -> Vec<IoSlice<'_>> {
self.to_buffers()
}
fn to_continuous_buffer(&self) -> Vec<u8> {
self.to_continuous_buffer()
}
}
impl GenericPacketDisplay for Connect {
    /// Formats via this packet's `Debug` implementation.
    fn fmt_debug(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Debug>::fmt(self, f)
    }

    /// Formats via this packet's `Display` implementation.
    fn fmt_display(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Display>::fmt(self, f)
    }
}