use std::slice::Iter;
use crate::router::Ack;
use super::*;
use bytes::{Buf, BufMut, Bytes, BytesMut};
mod connack;
mod connect;
mod disconnect;
mod ping;
mod puback;
mod pubcomp;
mod publish;
mod pubrec;
mod pubrel;
mod suback;
mod subscribe;
mod unsuback;
mod unsubscribe;
/// MQTT control packet types.
///
/// Each discriminant equals the value found in the upper four bits of the
/// first fixed-header byte (see `FixedHeader::packet_type`). Value 15 has no
/// variant here, and neither does 0.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
/// MQTT v5 property identifiers.
///
/// Discriminants are the identifier bytes as they appear on the wire; they are
/// written in hex to match the MQTT v5 property table. `property` maps a raw
/// byte back to a variant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PropertyType {
    PayloadFormatIndicator = 0x01,
    MessageExpiryInterval = 0x02,
    ContentType = 0x03,
    ResponseTopic = 0x08,
    CorrelationData = 0x09,
    SubscriptionIdentifier = 0x0B,
    SessionExpiryInterval = 0x11,
    AssignedClientIdentifier = 0x12,
    ServerKeepAlive = 0x13,
    AuthenticationMethod = 0x15,
    AuthenticationData = 0x16,
    RequestProblemInformation = 0x17,
    WillDelayInterval = 0x18,
    RequestResponseInformation = 0x19,
    ResponseInformation = 0x1A,
    ServerReference = 0x1C,
    ReasonString = 0x1F,
    ReceiveMaximum = 0x21,
    TopicAliasMaximum = 0x22,
    TopicAlias = 0x23,
    MaximumQos = 0x24,
    RetainAvailable = 0x25,
    UserProperty = 0x26,
    MaximumPacketSize = 0x27,
    WildcardSubscriptionAvailable = 0x28,
    SubscriptionIdentifierAvailable = 0x29,
    SharedSubscriptionAvailable = 0x2A,
}
/// Decoded MQTT fixed header: the first byte plus the length bookkeeping
/// needed to find where the frame ends.
///
/// Field order matters: the derived `PartialOrd` compares fields in
/// declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// First header byte; the packet type lives in the upper nibble
    /// (the lower nibble presumably carries packet-specific flags).
    byte1: u8,
    /// Bytes taken by the fixed header itself: `byte1` plus the encoded
    /// remaining-length field.
    fixed_header_len: usize,
    /// Decoded remaining length: how many bytes follow the fixed header.
    remaining_len: usize,
}
impl FixedHeader {
pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
FixedHeader {
byte1,
fixed_header_len: remaining_len_len + 1,
remaining_len,
}
}
pub fn packet_type(&self) -> Result<PacketType, Error> {
let num = self.byte1 >> 4;
match num {
1 => Ok(PacketType::Connect),
2 => Ok(PacketType::ConnAck),
3 => Ok(PacketType::Publish),
4 => Ok(PacketType::PubAck),
5 => Ok(PacketType::PubRec),
6 => Ok(PacketType::PubRel),
7 => Ok(PacketType::PubComp),
8 => Ok(PacketType::Subscribe),
9 => Ok(PacketType::SubAck),
10 => Ok(PacketType::Unsubscribe),
11 => Ok(PacketType::UnsubAck),
12 => Ok(PacketType::PingReq),
13 => Ok(PacketType::PingResp),
14 => Ok(PacketType::Disconnect),
_ => Err(Error::InvalidPacketType(num)),
}
}
pub fn frame_length(&self) -> usize {
self.fixed_header_len + self.remaining_len
}
}
fn property(num: u8) -> Result<PropertyType, Error> {
let property = match num {
1 => PropertyType::PayloadFormatIndicator,
2 => PropertyType::MessageExpiryInterval,
3 => PropertyType::ContentType,
8 => PropertyType::ResponseTopic,
9 => PropertyType::CorrelationData,
11 => PropertyType::SubscriptionIdentifier,
17 => PropertyType::SessionExpiryInterval,
18 => PropertyType::AssignedClientIdentifier,
19 => PropertyType::ServerKeepAlive,
21 => PropertyType::AuthenticationMethod,
22 => PropertyType::AuthenticationData,
23 => PropertyType::RequestProblemInformation,
24 => PropertyType::WillDelayInterval,
25 => PropertyType::RequestResponseInformation,
26 => PropertyType::ResponseInformation,
28 => PropertyType::ServerReference,
31 => PropertyType::ReasonString,
33 => PropertyType::ReceiveMaximum,
34 => PropertyType::TopicAliasMaximum,
35 => PropertyType::TopicAlias,
36 => PropertyType::MaximumQos,
37 => PropertyType::RetainAvailable,
38 => PropertyType::UserProperty,
39 => PropertyType::MaximumPacketSize,
40 => PropertyType::WildcardSubscriptionAvailable,
41 => PropertyType::SubscriptionIdentifierAvailable,
42 => PropertyType::SharedSubscriptionAvailable,
num => return Err(Error::InvalidPropertyType(num)),
};
Ok(property)
}
pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
let stream_len = stream.len();
let fixed_header = parse_fixed_header(stream)?;
if fixed_header.remaining_len > max_packet_size {
return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
}
let frame_length = fixed_header.frame_length();
if stream_len < frame_length {
return Err(Error::InsufficientBytes(frame_length - stream_len));
}
Ok(fixed_header)
}
fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
let stream_len = stream.len();
if stream_len < 2 {
return Err(Error::InsufficientBytes(2 - stream_len));
}
let byte1 = stream.next().unwrap();
let (len_len, len) = length(stream)?;
Ok(FixedHeader::new(*byte1, len_len, len))
}
fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
let mut len: usize = 0;
let mut len_len = 0;
let mut done = false;
let mut shift = 0;
for byte in stream {
len_len += 1;
let byte = *byte as usize;
len += (byte & 0x7F) << shift;
done = (byte & 0x80) == 0;
if done {
break;
}
shift += 7;
if shift > 21 {
return Err(Error::MalformedRemainingLength);
}
}
if !done {
return Err(Error::InsufficientBytes(1));
}
Ok((len_len, len))
}
fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
let len = read_u16(stream)? as usize;
if len > stream.len() {
return Err(Error::BoundaryCrossed(len));
}
Ok(stream.split_to(len))
}
fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
let s = read_mqtt_bytes(stream)?;
match String::from_utf8(s.to_vec()) {
Ok(v) => Ok(v),
Err(_e) => Err(Error::TopicNotUtf8),
}
}
/// Appends `bytes` to `stream` as MQTT binary data: a big-endian `u16` length
/// prefix followed by the raw bytes.
///
/// NOTE(review): the length is narrowed with `as u16`. An input longer than
/// 65_535 bytes gets a truncated prefix while every byte is still appended,
/// which produces a corrupt frame. Callers must keep inputs within that bound.
fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    let prefix = bytes.len() as u16;
    stream.put_u16(prefix);
    stream.put_slice(bytes);
}
/// Appends `string` as an MQTT string: its UTF-8 bytes, length-prefixed by
/// `write_mqtt_bytes`.
fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    let utf8 = string.as_bytes();
    write_mqtt_bytes(stream, utf8)
}
/// Encodes `len` as an MQTT variable-byte integer and appends it to `stream`.
///
/// Emits 7 bits per byte, least-significant group first, and sets 0x80 on
/// every byte except the last. Returns the number of bytes written (1 to 4).
///
/// # Errors
/// `Error::PayloadTooLong` when `len` exceeds 268_435_455 (2^28 - 1), the
/// largest value four 7-bit groups can hold. Nothing is written in that case.
fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
    const MAX_ENCODABLE: usize = 268_435_455;
    if len > MAX_ENCODABLE {
        return Err(Error::PayloadTooLong);
    }

    let mut rest = len;
    let mut written = 0;
    loop {
        let low_bits = (rest & 0x7F) as u8;
        rest >>= 7;
        let encoded = if rest == 0 { low_bits } else { low_bits | 0x80 };
        stream.put_u8(encoded);
        written += 1;
        if rest == 0 {
            return Ok(written);
        }
    }
}
/// Number of bytes `write_remaining_length` uses to encode `len`.
///
/// Values at or above 2_097_152 report 4. The encoder's upper limit is not
/// checked here.
fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
if stream.len() < 2 {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u16())
}
fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
if stream.is_empty() {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u8())
}
fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
if stream.len() < 4 {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u32())
}
/// MQTT v5 codec. It is a unit struct and carries no state; all framing logic
/// lives in its `Protocol` implementation.
#[derive(Clone, Debug)]
pub struct V5;
impl Protocol for V5 {
fn read_mut(&mut self, stream: &mut BytesMut, max_size: usize) -> Result<Packet, Error> {
let fixed_header = check(stream.iter(), max_size)?;
let packet = stream.split_to(fixed_header.frame_length());
let packet_type = fixed_header.packet_type()?;
if fixed_header.remaining_len == 0 {
return match packet_type {
PacketType::PingReq => Ok(Packet::PingReq(PingReq)),
PacketType::PingResp => Ok(Packet::PingResp(PingResp)),
PacketType::Disconnect => Ok(Packet::Disconnect(
Disconnect {
reason_code: DisconnectReasonCode::NormalDisconnection,
},
None,
)),
_ => Err(Error::PayloadRequired),
};
}
let packet = packet.freeze();
let packet = match packet_type {
PacketType::Connect => {
let (connect, properties, will, willproperties, login) =
connect::read(fixed_header, packet)?;
Packet::Connect(connect, properties, will, willproperties, login)
}
PacketType::Publish => {
let (publish, properties) = publish::read(fixed_header, packet)?;
Packet::Publish(publish, properties)
}
PacketType::PubAck => {
let (puback, properties) = puback::read(fixed_header, packet)?;
Packet::PubAck(puback, properties)
}
PacketType::Subscribe => {
let (subscribe, properties) = subscribe::read(fixed_header, packet)?;
Packet::Subscribe(subscribe, properties)
}
PacketType::SubAck => {
let (suback, properties) = suback::read(fixed_header, packet)?;
Packet::SubAck(suback, properties)
}
PacketType::Unsubscribe => {
let (unsubscribe, properties) = unsubscribe::read(fixed_header, packet)?;
Packet::Unsubscribe(unsubscribe, properties)
}
PacketType::PingReq => Packet::PingReq(PingReq),
PacketType::PingResp => Packet::PingResp(PingResp),
PacketType::Disconnect => {
let (disconnect, properties) = disconnect::read(fixed_header, packet)?;
Packet::Disconnect(disconnect, properties)
}
PacketType::PubRec => {
let (pubrec, properties) = pubrec::read(fixed_header, packet)?;
Packet::PubRec(pubrec, properties)
}
PacketType::PubRel => {
let (pubrel, properties) = pubrel::read(fixed_header, packet)?;
Packet::PubRel(pubrel, properties)
}
PacketType::PubComp => {
let (pubcomp, properties) = pubcomp::read(fixed_header, packet)?;
Packet::PubComp(pubcomp, properties)
}
_ => unreachable!(),
};
Ok(packet)
}
fn write(&self, packet: Packet, buffer: &mut BytesMut) -> Result<usize, Error> {
let size = match packet {
Packet::Connect(
connect,
connect_properties,
last_will,
last_will_properties,
login,
) => connect::write(
&connect,
&connect_properties,
&last_will,
&last_will_properties,
&login,
buffer,
)?,
Packet::ConnAck(connack, properties) => connack::write(&connack, &properties, buffer)?,
Packet::Publish(publish, properties) => publish::write(&publish, &properties, buffer)?,
Packet::PubAck(puback, properties) => puback::write(&puback, &properties, buffer)?,
Packet::Subscribe(subscribe, properties) => {
subscribe::write(&subscribe, &properties, buffer)?
}
Packet::SubAck(suback, properties) => suback::write(&suback, &properties, buffer)?,
Packet::PubRec(pubrec, properties) => pubrec::write(&pubrec, &properties, buffer)?,
Packet::PubRel(pubrel, properties) => pubrel::write(&pubrel, &properties, buffer)?,
Packet::PubComp(pubcomp, properties) => pubcomp::write(&pubcomp, &properties, buffer)?,
Packet::Unsubscribe(unsubscribe, properties) => {
unsubscribe::write(&unsubscribe, &properties, buffer)?
}
Packet::UnsubAck(unsuback, properties) => {
unsuback::write(&unsuback, &properties, buffer)?
}
Packet::Disconnect(disconnect, properties) => {
disconnect::write(&disconnect, &properties, buffer)?
}
Packet::PingReq(pingreq) => ping::pingreq::write(buffer)?,
Packet::PingResp(pingresp) => ping::pingresp::write(buffer)?,
_ => unreachable!(),
};
Ok(size)
}
}