pub(crate) mod connack;
pub(crate) mod connect;
mod disconnect;
mod ping;
mod pubcommon;
mod publish;
mod suback;
mod subscribe;
mod unsuback;
mod unsubscribe;
pub(crate) use crate::protocol::packet::subscribe::RetainForwardRule;
pub use crate::protocol::packet::{
pubcommon::{PubAck, PubComp, PubRec, PubRel},
publish::Publish,
suback::{SubAck, SubscribeReasonCode},
subscribe::Subscribe,
unsuback::{UnsubAck, UnsubAckReason},
unsubscribe::Unsubscribe
};
use crate::protocol::{
FixedHeader, PacketParseError, PacketType, Protocol
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
pub use connack::*;
pub use connect::*;
pub use disconnect::*;
pub use ping::*;
use log::{debug, error};
use std::slice::Iter;
/// Failure modes reported while decoding the fixed header of a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FixedHeaderError {
    /// The input ended before the header was complete. The payload is the
    /// number of additional bytes the parser asked for.
    InsufficientBytes(usize),
    /// The remaining-length field still had its continuation bit set on the
    /// last byte the decoder is willing to read.
    MalformedRemainingLength
}

impl FixedHeaderError {
    /// Reports whether the buffered input should be thrown away.
    ///
    /// Only [`FixedHeaderError::MalformedRemainingLength`] answers `true`;
    /// a short read answers `false` so the caller can keep what it has and
    /// wait for more bytes.
    pub fn to_discard(&self) -> bool {
        match self {
            Self::MalformedRemainingLength => true,
            Self::InsufficientBytes(_) => false
        }
    }
}
/// One MQTT control packet, tagged by its kind.
///
/// Every variant except [`Packet::PingResp`] wraps the struct holding that
/// packet's contents; `PingResp` carries no data.
#[derive(Clone, Debug)]
pub enum Packet {
    Connect(Connect),
    ConnAck(ConnAck),
    Publish(Publish),
    PubAck(PubAck),
    PingResp,
    Subscribe(Subscribe),
    SubAck(SubAck),
    PubRec(PubRec),
    PubRel(PubRel),
    PubComp(PubComp),
    Unsubscribe(Unsubscribe),
    UnsubAck(UnsubAck),
    Disconnect(Disconnect)
}

impl Packet {
    /// Returns the [`PacketType`] tag that corresponds to this variant.
    pub fn packet_ty(&self) -> PacketType {
        // Arms follow the declaration order of the enum; the match stays
        // exhaustive so adding a variant forces an update here.
        match self {
            Self::Connect(_) => PacketType::Connect,
            Self::ConnAck(_) => PacketType::ConnAck,
            Self::Publish(_) => PacketType::Publish,
            Self::PubAck(_) => PacketType::PubAck,
            Self::PingResp => PacketType::PingResp,
            Self::Subscribe(_) => PacketType::Subscribe,
            Self::SubAck(_) => PacketType::SubAck,
            Self::PubRec(_) => PacketType::PubRec,
            Self::PubRel(_) => PacketType::PubRel,
            Self::PubComp(_) => PacketType::PubComp,
            Self::Unsubscribe(_) => PacketType::Unsubscribe,
            Self::UnsubAck(_) => PacketType::UnsubAck,
            Self::Disconnect(_) => PacketType::Disconnect
        }
    }
}
/// Tries to decode one packet from the front of `stream`.
///
/// # Returns
///
/// * `Ok(Some(packet))` - a whole frame was buffered; it has been removed
///   from `stream` and decoded.
/// * `Ok(None)` - the fixed header parsed, but the frame it announces is
///   longer than what is buffered. `stream` is left untouched.
///
/// # Errors
///
/// * The fixed-header error, converted with `into()`, when the header cannot
///   be parsed. If that error's `to_discard()` is true, everything buffered
///   in `stream` is dropped first.
/// * `PacketParseError::InvalidPacketType` for a `Disconnect` frame under
///   `Protocol::V4`, and for any packet type this function does not decode.
/// * Whatever `FixedHeader::packet_type` or the per-packet `read` functions
///   return. In all of these cases the frame has already been removed from
///   `stream`.
pub fn read_from_network(
    stream: &mut BytesMut,
    version: Protocol
) -> Result<Option<Packet>, PacketParseError> {
    let fixed_header = match parse_fixed_header_by_slice(stream.as_ref()) {
        Ok(header) => header,
        Err(err) => {
            if err.to_discard() {
                // Drop everything that is buffered.
                let _ = stream.split_to(stream.len());
            }
            return Err(err.into());
        }
    };

    // Wait for more input until the whole frame is available.
    let frame_len = fixed_header.frame_length();
    let buffered = stream.len();
    if frame_len > buffered {
        debug!("{} > {}", frame_len, buffered);
        return Ok(None);
    }

    // The frame is taken out of the buffer before the type is validated, so
    // a frame with an unusable type is consumed rather than re-parsed.
    let frame = stream.split_to(frame_len);
    let packet_type = fixed_header.packet_type()?;
    let frame = frame.freeze();

    let decoded = match packet_type {
        PacketType::ConnAck => {
            ConnAck::read(fixed_header, frame, version).map(Packet::ConnAck)?
        },
        PacketType::Publish => {
            Publish::read(fixed_header, frame, version).map(Packet::Publish)?
        },
        PacketType::PubAck => {
            PubAck::read(fixed_header, frame, version).map(Packet::PubAck)?
        },
        PacketType::PubRec => {
            PubRec::read(fixed_header, frame, version).map(Packet::PubRec)?
        },
        PacketType::PubRel => {
            PubRel::read(fixed_header, frame, version).map(Packet::PubRel)?
        },
        PacketType::PubComp => {
            PubComp::read(fixed_header, frame, version).map(Packet::PubComp)?
        },
        PacketType::SubAck => {
            SubAck::read(fixed_header, frame, version).map(Packet::SubAck)?
        },
        PacketType::UnsubAck => {
            UnsubAck::read(fixed_header, frame, version).map(Packet::UnsubAck)?
        },
        PacketType::PingResp => Packet::PingResp,
        PacketType::Disconnect => match version {
            Protocol::V5 => Disconnect::read(fixed_header, frame, version)
                .map(Packet::Disconnect)?,
            // A Disconnect frame is rejected when the session speaks V4.
            Protocol::V4 => {
                return Err(PacketParseError::InvalidPacketType(
                    packet_type as u8
                ));
            }
        },
        ty => {
            error!("{:?}", ty);
            return Err(PacketParseError::InvalidPacketType(
                packet_type as u8
            ));
        }
    };

    Ok(Some(decoded))
}
/// Parses a fixed header from the start of `stream` without consuming it.
///
/// The first byte is handed to `FixedHeader::new` unchanged; the bytes after
/// it are decoded by [`length`] to obtain the size of the length field and
/// the remaining length.
///
/// # Errors
///
/// * `FixedHeaderError::InsufficientBytes(n)` when `stream` holds fewer than
///   two bytes, where `n` is how many bytes are missing to reach two.
/// * Any error returned by [`length`].
pub fn parse_fixed_header_by_slice(
    stream: &[u8]
) -> Result<FixedHeader, FixedHeaderError> {
    // A header needs the leading byte plus at least one length byte.
    let (byte1, length_bytes) = match stream.split_first() {
        Some((first, rest)) if !rest.is_empty() => (*first, rest),
        _ => {
            return Err(FixedHeaderError::InsufficientBytes(
                2 - stream.len()
            ));
        }
    };

    let (len_len, len) = length(length_bytes.iter())?;
    Ok(FixedHeader::new(byte1, len_len, len))
}
/// Decodes a variable-length integer from `stream`.
///
/// Each byte contributes its low seven bits, least-significant group first;
/// a set high bit means another byte follows. At most four bytes are read.
///
/// # Returns
///
/// `(bytes_consumed, decoded_value)`.
///
/// # Errors
///
/// * `FixedHeaderError::MalformedRemainingLength` when the fourth byte still
///   has its continuation bit set.
/// * `FixedHeaderError::InsufficientBytes(1)` when `stream` runs out before
///   a byte without the continuation bit is seen (including an empty
///   `stream`).
pub fn length(
    stream: Iter<u8>
) -> Result<(usize, usize), FixedHeaderError> {
    // Index of the last byte the encoding may occupy (four bytes in total).
    const LAST_INDEX: usize = 3;

    let mut value: usize = 0;
    for (index, &raw) in stream.enumerate() {
        let byte = raw as usize;
        // The seven-bit groups never overlap, so OR-ing equals adding.
        value |= (byte & 0x7F) << (7 * index);

        if byte & 0x80 == 0 {
            return Ok((index + 1, value));
        }
        if index == LAST_INDEX {
            return Err(FixedHeaderError::MalformedRemainingLength);
        }
    }

    // Every byte seen so far asked for a successor that is not there.
    Err(FixedHeaderError::InsufficientBytes(1))
}
/// Reads one length-prefixed binary field from the front of `stream`.
///
/// A 16-bit length is read with [`read_u16`], then that many bytes are split
/// off `stream` and returned.
///
/// # Errors
///
/// * Any error from [`read_u16`] when the prefix itself cannot be read.
/// * `PacketParseError::BoundaryCrossed(len)` when the prefix announces more
///   bytes than `stream` still holds.
pub fn read_mqtt_bytes(
    stream: &mut Bytes
) -> Result<Bytes, PacketParseError> {
    let announced = usize::from(read_u16(stream)?);
    if stream.len() < announced {
        return Err(PacketParseError::BoundaryCrossed(announced));
    }
    let field = stream.split_to(announced);
    Ok(field)
}
/// Reads one length-prefixed field with [`read_mqtt_bytes`] and converts it
/// into an owned `String`.
///
/// # Errors
///
/// * Any error from [`read_mqtt_bytes`].
/// * `PacketParseError::TopicNotUtf8` when the field is not valid UTF-8.
///   This variant is used for every caller, whatever the field represents.
pub fn read_mqtt_string(
    stream: &mut Bytes
) -> Result<String, PacketParseError> {
    let raw = read_mqtt_bytes(stream)?;
    String::from_utf8(raw.to_vec())
        .map_err(|_| PacketParseError::TopicNotUtf8)
}
/// Appends `bytes` to `stream` as a length-prefixed field: a 16-bit length
/// followed by the bytes themselves.
///
/// NOTE(review): the length is narrowed with `as u16`, so for inputs longer
/// than `u16::MAX` bytes the written prefix wraps while all of `bytes` is
/// still appended. Callers presumably guarantee shorter inputs - confirm.
pub fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    let prefix = bytes.len() as u16;
    stream.put_u16(prefix);
    stream.extend_from_slice(bytes);
}
/// Appends the UTF-8 bytes of `string` to `stream` as a length-prefixed
/// field, delegating to [`write_mqtt_bytes`].
pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    let utf8 = string.as_bytes();
    write_mqtt_bytes(stream, utf8);
}
/// Appends `len` to `stream` as a variable-length integer and returns how
/// many bytes were written.
///
/// The value is emitted seven bits at a time, least-significant group first;
/// every byte except the last has its high bit set. A `len` of zero is
/// written as a single `0` byte.
///
/// No upper bound is enforced here, so a large enough `len` yields more than
/// four bytes, which [`length`] would then reject.
pub fn write_remaining_length(
    stream: &mut BytesMut,
    len: usize
) -> usize {
    let mut rest = len;
    let mut written = 0;
    loop {
        let low_bits = (rest % 128) as u8;
        rest /= 128;
        written += 1;

        if rest == 0 {
            // Final group: continuation bit stays clear.
            stream.put_u8(low_bits);
            break;
        }
        stream.put_u8(low_bits | 128);
    }
    written
}
/// Takes a `u16` from the front of `stream`.
///
/// # Errors
///
/// `PacketParseError::MalformedPacket` when fewer than two bytes are left.
pub fn read_u16(stream: &mut Bytes) -> Result<u16, PacketParseError> {
    if stream.len() >= std::mem::size_of::<u16>() {
        Ok(stream.get_u16())
    } else {
        Err(PacketParseError::MalformedPacket)
    }
}
/// Takes a single byte from the front of `stream`.
///
/// # Errors
///
/// `PacketParseError::MalformedPacket` when `stream` is empty.
pub fn read_u8(stream: &mut Bytes) -> Result<u8, PacketParseError> {
    if stream.len() >= std::mem::size_of::<u8>() {
        Ok(stream.get_u8())
    } else {
        Err(PacketParseError::MalformedPacket)
    }
}
/// Takes a `u32` from the front of `stream`.
///
/// # Errors
///
/// `PacketParseError::MalformedPacket` when fewer than four bytes are left.
pub fn read_u32(stream: &mut Bytes) -> Result<u32, PacketParseError> {
    if stream.len() >= std::mem::size_of::<u32>() {
        Ok(stream.get_u32())
    } else {
        Err(PacketParseError::MalformedPacket)
    }
}