use std::convert::TryFrom;
use std::fmt;
use crate::{
ByteArray, DecodeError, DecodePacket, EncodeError, EncodePacket, ProtocolLevel, QoS, VarInt,
VarIntError,
};
/// Common interface of packet values.
///
/// Implementors must be `Send` and implement `fmt::Debug`.
pub trait Packet: Send + fmt::Debug {
/// Returns the [`PacketType`] of this packet.
fn packet_type(&self) -> PacketType;
/// Returns a byte count for this packet, or a [`VarIntError`] on failure.
///
/// NOTE(review): presumably the total encoded size, failing when a length
/// cannot be represented as a `VarInt` — confirm against the implementors.
fn bytes(&self) -> Result<usize, VarIntError>;
}
/// Packet kind carried in the first byte of a fixed header.
///
/// Converting a value to `u8` puts the numeric type code in the upper four
/// bits and the flag bits in the lower four bits. `TryFrom<u8>` performs the
/// reverse conversion and rejects bytes whose flag bits do not match the
/// value expected for the type.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PacketType {
/// Type code 1, no flag bits. This is the `Default` variant.
#[default]
Connect,
/// Type code 2, no flag bits.
ConnectAck,
/// Type code 3. The flag bits hold `dup` (bit 3), `qos` (bits 2-1) and
/// `retain` (bit 0).
Publish { dup: bool, qos: QoS, retain: bool },
/// Type code 4, no flag bits.
PublishAck,
/// Type code 5, no flag bits.
PublishReceived,
/// Type code 6, flag bits fixed to `0b0010`.
PublishRelease,
/// Type code 7, no flag bits.
PublishComplete,
/// Type code 8, flag bits fixed to `0b0010`.
Subscribe,
/// Type code 9, no flag bits.
SubscribeAck,
/// Type code 10, flag bits fixed to `0b0010`.
Unsubscribe,
/// Type code 11, no flag bits.
UnsubscribeAck,
/// Type code 12, no flag bits.
PingRequest,
/// Type code 13, no flag bits.
PingResponse,
/// Type code 14, no flag bits.
Disconnect,
/// Type code 15, no flag bits. `FixedHeader::is_valid_header` accepts this
/// type only for `ProtocolLevel::V5`.
Auth,
}
impl PacketType {
    /// Number of bytes a packet type occupies once encoded.
    ///
    /// A packet type is converted into exactly one `u8`, so this is the size
    /// of a `u8`.
    #[must_use]
    pub const fn bytes() -> usize {
        std::mem::size_of::<u8>()
    }
}
impl From<PacketType> for u8 {
    /// Packs a packet type into the first byte of a fixed header.
    ///
    /// The upper four bits carry the numeric type code and the lower four
    /// bits carry the flags. Only `Publish` has variable flags (`dup`, `qos`,
    /// `retain`); `PublishRelease`, `Subscribe` and `Unsubscribe` always emit
    /// `0b0010`, and every other type emits zero.
    fn from(packet_type: PacketType) -> Self {
        // Flag nibble for packet types that carry no flags.
        const NO_FLAGS: u8 = 0b0000_0000;
        // Fixed flag nibble of PublishRelease, Subscribe and Unsubscribe.
        const FIXED_FLAGS: u8 = 0b0000_0010;

        // The match is exhaustive on purpose: a new variant must pick its
        // code and flags here instead of silently falling into a default.
        let (type_code, flags): (u8, u8) = match packet_type {
            PacketType::Connect => (1, NO_FLAGS),
            PacketType::ConnectAck => (2, NO_FLAGS),
            PacketType::Publish { dup, qos, retain } => {
                let qos_level = match qos {
                    QoS::AtMostOnce => 0b00,
                    QoS::AtLeastOnce => 0b01,
                    QoS::ExactOnce => 0b10,
                };
                // Layout of the nibble: dup (bit 3), qos (bits 2-1), retain (bit 0).
                (
                    3,
                    (u8::from(dup) << 3) | (qos_level << 1) | u8::from(retain),
                )
            }
            PacketType::PublishAck => (4, NO_FLAGS),
            PacketType::PublishReceived => (5, NO_FLAGS),
            PacketType::PublishRelease => (6, FIXED_FLAGS),
            PacketType::PublishComplete => (7, NO_FLAGS),
            PacketType::Subscribe => (8, FIXED_FLAGS),
            PacketType::SubscribeAck => (9, NO_FLAGS),
            PacketType::Unsubscribe => (10, FIXED_FLAGS),
            PacketType::UnsubscribeAck => (11, NO_FLAGS),
            PacketType::PingRequest => (12, NO_FLAGS),
            PacketType::PingResponse => (13, NO_FLAGS),
            PacketType::Disconnect => (14, NO_FLAGS),
            PacketType::Auth => (15, NO_FLAGS),
        };

        (type_code << 4) | flags
    }
}
impl TryFrom<u8> for PacketType {
type Error = DecodeError;
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
fn try_from(v: u8) -> Result<Self, Self::Error> {
let type_bits = (v & 0b1111_0000) >> 4;
let flag = v & 0b0000_1111;
match type_bits {
1 => {
if flag == 0b0000_0000 {
Ok(Self::Connect)
} else {
log::error!("header: Got packet flag in Connect: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
2 => {
if flag == 0b0000_0000 {
Ok(Self::ConnectAck)
} else {
log::error!("header: Got packet flag in ConnectAck: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
3 => {
let dup = (flag & 0b0000_1000) == 0b0000_1000;
let retain = (flag & 0b0000_0001) == 0b0000_0001;
let qos = match flag & 0b0000_0110 {
0b0000_0000 => QoS::AtMostOnce,
0b0000_0010 => QoS::AtLeastOnce,
0b0000_0100 => QoS::ExactOnce,
_ => return Err(DecodeError::InvalidPacketFlags),
};
Ok(Self::Publish { dup, retain, qos })
}
4 => {
if flag == 0b0000_0000 {
Ok(Self::PublishAck)
} else {
log::error!("header: Got packet flag in PublishAck: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
5 => {
if flag == 0b0000_0000 {
Ok(Self::PublishReceived)
} else {
log::error!("header: Got packet flag in PublishReceived: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
6 => {
if flag == 0b0000_0010 {
Ok(Self::PublishRelease)
} else {
log::error!("header: Got packet flag in PublishRelease: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
7 => {
if flag == 0b0000_0000 {
Ok(Self::PublishComplete)
} else {
log::error!("header: Got packet flag in PublishComplete: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
8 => {
if flag == 0b0000_0010 {
Ok(Self::Subscribe)
} else {
log::error!("header: Got packet flag in Subscribe: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
9 => {
if flag == 0b0000_0000 {
Ok(Self::SubscribeAck)
} else {
log::error!("header: Got packet flag in Subscribe: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
10 => {
if flag == 0b0000_0010 {
Ok(Self::Unsubscribe)
} else {
log::error!("header: Got packet flag in Unsubscribe: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
11 => {
if flag == 0b0000_0000 {
Ok(Self::UnsubscribeAck)
} else {
log::error!("header: Got packet flag in UnsubscribeAck: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
12 => {
if flag == 0b0000_0000 {
Ok(Self::PingRequest)
} else {
log::error!("header: Got packet flag in PingRequest: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
13 => {
if flag == 0b0000_0000 {
Ok(Self::PingResponse)
} else {
log::error!("header: Got packet flag in PingResponse: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
14 => {
if flag == 0b0000_0000 {
Ok(Self::Disconnect)
} else {
log::error!("header: Got packet flag in Disconnect: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
15 => {
if flag == 0b0000_0000 {
Ok(Self::Auth)
} else {
log::error!("header: Got packet flag in Auth: {:#b}", flag);
Err(DecodeError::InvalidPacketFlags)
}
}
t => {
log::error!("Invlaid type_bits: {:#b}", t);
Err(DecodeError::InvalidPacketType)
}
}
}
}
/// Fixed header of a packet: its type plus a variable-length "remaining length".
///
/// It is encoded as the packet type byte followed by the `VarInt` encoding of
/// the remaining length, and decoded in the same order.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct FixedHeader {
// Packet type, including the `Publish` flags when applicable.
packet_type: PacketType,
// NOTE(review): presumably the number of bytes that follow the fixed header
// — confirm against the callers.
remaining_length: VarInt,
}
impl FixedHeader {
    /// Builds a fixed header for `packet_type` with the given remaining length.
    ///
    /// # Errors
    ///
    /// Returns the `VarIntError` raised by `VarInt::from` when
    /// `remaining_length` is rejected.
    pub fn new(packet_type: PacketType, remaining_length: usize) -> Result<Self, VarIntError> {
        let header = Self {
            packet_type,
            remaining_length: VarInt::from(remaining_length)?,
        };
        Ok(header)
    }

    /// Returns the packet type stored in this header.
    #[must_use]
    pub const fn packet_type(&self) -> PacketType {
        self.packet_type
    }

    /// Returns the remaining length as a plain integer.
    #[must_use]
    pub const fn remaining_length(&self) -> usize {
        self.remaining_length.value()
    }

    /// Returns the size of this header in bytes: the packet type byte plus
    /// the bytes reported by the remaining-length `VarInt`.
    #[must_use]
    pub const fn bytes(&self) -> usize {
        let type_size = PacketType::bytes();
        let length_size = self.remaining_length.bytes();
        type_size + length_size
    }

    /// Reports whether this header is acceptable for `protocol_level`.
    ///
    /// Only `Auth` is restricted: it is accepted solely for
    /// `ProtocolLevel::V5`. Every other packet type is accepted for any level.
    #[must_use]
    pub fn is_valid_header(&self, protocol_level: ProtocolLevel) -> bool {
        self.packet_type != PacketType::Auth || protocol_level == ProtocolLevel::V5
    }
}
impl DecodePacket for FixedHeader {
    /// Reads a fixed header from `ba`.
    ///
    /// One byte is consumed and parsed as the packet type, then the remaining
    /// length is decoded as a `VarInt`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from reading the byte, parsing the packet type,
    /// or decoding the `VarInt`.
    fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
        // The type byte must be consumed before the length that follows it.
        let packet_type = PacketType::try_from(ba.read_byte()?)?;
        let remaining_length = VarInt::decode(ba)?;
        Ok(Self {
            packet_type,
            remaining_length,
        })
    }
}
impl EncodePacket for FixedHeader {
    /// Appends the fixed header to `v` and returns the number of bytes appended.
    ///
    /// The packet type byte is written first, followed by the `VarInt`
    /// encoding of the remaining length.
    ///
    /// # Errors
    ///
    /// Propagates the error from encoding the remaining length. The packet
    /// type byte has already been pushed to `v` at that point.
    fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
        // Measure the buffer instead of recomputing the size, so the returned
        // count always equals what was actually written.
        let start = v.len();
        v.push(u8::from(self.packet_type));
        self.remaining_length.encode(v)?;
        Ok(v.len() - start)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding writes the packet type byte followed by the remaining length.
    #[test]
    fn test_encode() {
        let mut buf = Vec::new();
        let fixed_header = FixedHeader::new(PacketType::PingResponse, 0)
            .expect("zero is a valid remaining length");
        let written = fixed_header
            .encode(&mut buf)
            .expect("encoding a fixed header should succeed");
        assert_eq!(written, 2);
        // PingResponse has type code 13 and no flags; the length 0 is one zero byte.
        assert_eq!(buf, [0xd0, 0x00]);
    }

    /// Decoding reads the packet type and remaining length and ignores the payload.
    #[test]
    fn test_decode() {
        let buf = vec![
            0x30, 0x13, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
            0x2c, 0x20, 0x77, 0x6f, 0x72,
        ];
        let mut ba = ByteArray::new(&buf);
        let fixed_header =
            FixedHeader::decode(&mut ba).expect("buffer starts with a valid fixed header");
        assert_eq!(
            fixed_header.packet_type(),
            PacketType::Publish {
                dup: false,
                qos: QoS::AtMostOnce,
                retain: false
            }
        );
        assert_eq!(fixed_header.remaining_length(), 19);
    }

    /// Every packet type survives a conversion to `u8` and back.
    #[test]
    fn test_packet_type_round_trip() {
        let packet_types = [
            PacketType::Connect,
            PacketType::ConnectAck,
            PacketType::Publish {
                dup: false,
                qos: QoS::AtMostOnce,
                retain: false,
            },
            PacketType::Publish {
                dup: true,
                qos: QoS::AtLeastOnce,
                retain: true,
            },
            PacketType::Publish {
                dup: false,
                qos: QoS::ExactOnce,
                retain: true,
            },
            PacketType::PublishAck,
            PacketType::PublishReceived,
            PacketType::PublishRelease,
            PacketType::PublishComplete,
            PacketType::Subscribe,
            PacketType::SubscribeAck,
            PacketType::Unsubscribe,
            PacketType::UnsubscribeAck,
            PacketType::PingRequest,
            PacketType::PingResponse,
            PacketType::Disconnect,
            PacketType::Auth,
        ];
        for &packet_type in &packet_types {
            let byte = u8::from(packet_type);
            assert!(
                matches!(PacketType::try_from(byte), Ok(decoded) if decoded == packet_type),
                "round trip failed for {:?} (byte {:#010b})",
                packet_type,
                byte
            );
        }
    }

    /// Reserved type bits and unexpected flag bits are rejected.
    #[test]
    fn test_packet_type_rejects_invalid_byte() {
        // Type bits 0 do not name any packet type.
        assert!(matches!(
            PacketType::try_from(0x00_u8),
            Err(DecodeError::InvalidPacketType)
        ));
        // Subscribe requires the flag bits 0b0010.
        assert!(matches!(
            PacketType::try_from(0x80_u8),
            Err(DecodeError::InvalidPacketFlags)
        ));
        // SubscribeAck requires the flag bits to be zero.
        assert!(matches!(
            PacketType::try_from(0x92_u8),
            Err(DecodeError::InvalidPacketFlags)
        ));
        // Publish with both QoS bits set.
        assert!(matches!(
            PacketType::try_from(0x36_u8),
            Err(DecodeError::InvalidPacketFlags)
        ));
    }
}