use crate::Notification;
use std::{collections::VecDeque, result::Result, time::Instant};
use rumq_core::mqtt4::*;
/// Connection lifecycle as tracked by `MqttState`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MqttConnectionStatus {
    /// Entered by `handle_outgoing_connect`; waiting for the CONNACK.
    Handshake,
    /// An accepted CONNACK arrived while in `Handshake`.
    Connected,
    /// NOTE(review): not assigned anywhere in this file; presumably set by the
    /// event loop during a graceful shutdown — confirm.
    Disconnecting,
    /// Initial state, and the state after a rejected or unexpected CONNACK.
    Disconnected,
}
#[derive(Debug, thiserror::Error)]
pub enum StateError {
#[error("Connect return code `{0:?}`")]
Connect(ConnectReturnCode),
#[error("Invalid state for a given operation")]
InvalidState,
#[error("Received a packet (ack) which isn't asked for")]
Unsolicited,
#[error("Last pingreq isn't acked")]
AwaitPingResp,
#[error("Received a wrong packet while waiting for another packet")]
WrongPacket,
}
/// Protocol bookkeeping for a single MQTT connection: connection status,
/// keep-alive tracking, packet-id allocation and the in-flight QoS 1/2 queues.
#[derive(Clone, Debug)]
pub struct MqttState {
    /// Current connection lifecycle state.
    pub connection_status: MqttConnectionStatus,
    /// Set when a PINGREQ is produced, cleared when the PINGRESP arrives.
    pub await_pingresp: bool,
    /// When the last packet was fed to `handle_incoming_packet`.
    pub last_incoming: Instant,
    /// When the last packet was successfully processed by `handle_outgoing_packet`.
    pub last_outgoing: Instant,
    /// Most recently allocated packet id; allocation wraps from 65535 back to 1.
    pub last_pkid: PacketIdentifier,
    /// Sent QoS 1/2 publishes still waiting for their PUBACK / PUBREC.
    pub outgoing_pub: VecDeque<Publish>,
    /// Packet ids of QoS 2 publishes whose PUBREL was issued, waiting for PUBCOMP.
    pub outgoing_rel: VecDeque<PacketIdentifier>,
    /// Packet ids of received QoS 2 publishes waiting for the broker's PUBREL.
    pub incoming_pub: VecDeque<PacketIdentifier>,
}
impl MqttState {
pub fn new() -> Self {
MqttState {
connection_status: MqttConnectionStatus::Disconnected,
await_pingresp: false,
last_incoming: Instant::now(),
last_outgoing: Instant::now(),
last_pkid: PacketIdentifier(0),
outgoing_pub: VecDeque::new(),
outgoing_rel: VecDeque::new(),
incoming_pub: VecDeque::new(),
}
}
pub(crate) fn handle_outgoing_packet(&mut self, packet: Packet) -> Result<(Option<Notification>, Option<Packet>), StateError> {
let out = match packet {
Packet::Publish(publish) => self.handle_outgoing_publish(publish)?,
Packet::Subscribe(subscribe) => self.handle_outgoing_subscribe(subscribe)?,
Packet::Pingreq => self.handle_outgoing_ping()?,
_ => unimplemented!(),
};
self.last_outgoing = Instant::now();
let request = Some(out);
let notification = None;
Ok((notification, request))
}
pub(crate) fn handle_incoming_packet(&mut self, packet: Packet) -> Result<(Option<Notification>, Option<Packet>), StateError> {
let out = match packet {
Packet::Pingresp => self.handle_incoming_pingresp(),
Packet::Publish(publish) => self.handle_incoming_publish(publish.clone()),
Packet::Suback(suback) => self.handle_incoming_suback(suback),
Packet::Unsuback(pkid) => self.handle_incoming_unsuback(pkid),
Packet::Puback(pkid) => self.handle_incoming_puback(pkid),
Packet::Pubrec(pkid) => self.handle_incoming_pubrec(pkid),
Packet::Pubrel(pkid) => self.handle_incoming_pubrel(pkid),
Packet::Pubcomp(pkid) => self.handle_incoming_pubcomp(pkid),
_ => {
error!("Invalid incoming paket = {:?}", packet);
Ok((None, None))
}
};
self.last_incoming = Instant::now();
out
}
fn handle_outgoing_publish(&mut self, publish: Publish) -> Result<Packet, StateError> {
let publish = match publish.qos {
QoS::AtMostOnce => publish,
QoS::AtLeastOnce | QoS::ExactlyOnce => self.add_packet_id_and_save(publish),
};
debug!(
"Publish. Topic = {:?}, Pkid = {:?}, Payload Size = {:?}",
publish.topic_name,
publish.pkid,
publish.payload.len()
);
Ok(Packet::Publish(publish))
}
fn handle_incoming_puback(&mut self, pkid: PacketIdentifier) -> Result<(Option<Notification>, Option<Packet>), StateError> {
match self.outgoing_pub.iter().position(|x| x.pkid == Some(pkid)) {
Some(index) => {
let _publish = self.outgoing_pub.remove(index).expect("Wrong index");
let request = None;
let notification = Some(Notification::Puback(pkid));
Ok((notification, request))
}
None => {
error!("Unsolicited puback packet: {:?}", pkid);
Err(StateError::Unsolicited)
}
}
}
fn handle_incoming_suback(&mut self, suback: Suback) -> Result<(Option<Notification>, Option<Packet>), StateError> {
let request = None;
let notification = Some(Notification::Suback(suback));
Ok((notification, request))
}
fn handle_incoming_unsuback(&mut self, pkid: PacketIdentifier) -> Result<(Option<Notification>, Option<Packet>), StateError> {
let request = None;
let notification = Some(Notification::Unsuback(pkid));
Ok((notification, request))
}
fn handle_incoming_pubrec(&mut self, pkid: PacketIdentifier) -> Result<(Option<Notification>, Option<Packet>), StateError> {
match self.outgoing_pub.iter().position(|x| x.pkid == Some(pkid)) {
Some(index) => {
let _ = self.outgoing_pub.remove(index);
self.outgoing_rel.push_back(pkid);
let reply = Some(Packet::Pubrel(pkid));
let notification = Some(Notification::Pubrec(pkid));
Ok((notification, reply))
}
None => {
error!("Unsolicited pubrec packet: {:?}", pkid);
Err(StateError::Unsolicited)
}
}
}
fn handle_incoming_publish(&mut self, publish: Publish) -> Result<(Option<Notification>, Option<Packet>), StateError> {
let qos = publish.qos;
match qos {
QoS::AtMostOnce => {
let notification = Notification::Publish(publish);
Ok((Some(notification), None))
}
QoS::AtLeastOnce => {
let pkid = publish.pkid.unwrap();
let request = Packet::Puback(pkid);
let notification = Notification::Publish(publish);
Ok((Some(notification), Some(request)))
}
QoS::ExactlyOnce => {
let pkid = publish.pkid.unwrap();
let reply = Packet::Pubrec(pkid);
let notification = Notification::Publish(publish);
self.incoming_pub.push_back(pkid);
Ok((Some(notification), Some(reply)))
}
}
}
fn handle_incoming_pubrel(&mut self, pkid: PacketIdentifier) -> Result<(Option<Notification>, Option<Packet>), StateError> {
match self.incoming_pub.iter().position(|x| *x == pkid) {
Some(index) => {
let _ = self.incoming_pub.remove(index);
let reply = Packet::Pubcomp(pkid);
Ok((None, Some(reply)))
}
None => {
error!("Unsolicited pubrel packet: {:?}", pkid);
Err(StateError::Unsolicited)
}
}
}
fn handle_incoming_pubcomp(&mut self, pkid: PacketIdentifier) -> Result<(Option<Notification>, Option<Packet>), StateError> {
match self.outgoing_rel.iter().position(|x| *x == pkid) {
Some(index) => {
self.outgoing_rel.remove(index).expect("Wrong index");
let notification = Some(Notification::Pubcomp(pkid));
let reply = None;
Ok((notification, reply))
}
_ => {
error!("Unsolicited pubcomp packet: {:?}", pkid);
Err(StateError::Unsolicited)
}
}
}
fn handle_outgoing_ping(&mut self) -> Result<Packet, StateError> {
let elapsed_in = self.last_incoming.elapsed();
let elapsed_out = self.last_outgoing.elapsed();
if self.await_pingresp {
error!("Error awaiting for last ping response");
return Err(StateError::AwaitPingResp);
}
self.await_pingresp = true;
debug!(
"Pingreq,
last incoming packet before {} millisecs,
last outgoing request before {} millisecs",
elapsed_in.as_millis(),
elapsed_out.as_millis()
);
Ok(Packet::Pingreq)
}
fn handle_incoming_pingresp(&mut self) -> Result<(Option<Notification>, Option<Packet>), StateError> {
self.await_pingresp = false;
Ok((None, None))
}
fn handle_outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<Packet, StateError> {
let pkid = self.next_pkid();
subscription.pkid = pkid;
debug!("Subscribe. Topics = {:?}, Pkid = {:?}", subscription.topics, subscription.pkid);
Ok(Packet::Subscribe(subscription))
}
pub fn handle_outgoing_connect(&mut self) -> Result<(), StateError> {
self.connection_status = MqttConnectionStatus::Handshake;
Ok(())
}
pub fn handle_incoming_connack(&mut self, packet: Packet) -> Result<(), StateError> {
let connack = match packet {
Packet::Connack(connack) => connack,
packet => {
error!("Invalid packet. Expecting connack. Received = {:?}", packet);
self.connection_status = MqttConnectionStatus::Disconnected;
return Err(StateError::WrongPacket);
}
};
match connack.code {
ConnectReturnCode::Accepted if self.connection_status == MqttConnectionStatus::Handshake => {
self.connection_status = MqttConnectionStatus::Connected;
Ok(())
}
ConnectReturnCode::Accepted if self.connection_status != MqttConnectionStatus::Handshake => {
error!(
"Invalid state. Expected = {:?}, Current = {:?}",
MqttConnectionStatus::Handshake,
self.connection_status
);
self.connection_status = MqttConnectionStatus::Disconnected;
Err(StateError::InvalidState)
}
code => {
error!("Connection failed. Connection error = {:?}", code);
self.connection_status = MqttConnectionStatus::Disconnected;
Err(StateError::Connect(code))
}
}
}
fn add_packet_id_and_save(&mut self, mut publish: Publish) -> Publish {
let publish = match publish.pkid {
Some(PacketIdentifier(0)) | None => {
let pkid = self.next_pkid();
publish.set_pkid(pkid);
publish
}
_ => publish,
};
self.outgoing_pub.push_back(publish.clone());
publish
}
fn next_pkid(&mut self) -> PacketIdentifier {
let PacketIdentifier(mut pkid) = self.last_pkid;
if pkid == 65_535 {
pkid = 0;
}
self.last_pkid = PacketIdentifier(pkid + 1);
self.last_pkid
}
}
#[cfg(test)]
mod test {
    use super::{MqttConnectionStatus, MqttState, Packet, StateError};
    use crate::Notification;
    use rumq_core::mqtt4::*;

    /// Builds a publish as a user would hand it to the client (no packet id).
    fn build_outgoing_publish(qos: QoS) -> Publish {
        let topic = "hello/world".to_owned();
        let payload = vec![1, 2, 3];
        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
        publish.qos = qos;
        publish
    }

    /// Builds a publish as it would arrive from the broker, with a packet id.
    fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
        let topic = "hello/world".to_owned();
        let payload = vec![1, 2, 3];
        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
        publish.pkid = Some(PacketIdentifier(pkid));
        publish.qos = qos;
        publish
    }

    fn build_mqttstate() -> MqttState {
        MqttState::new()
    }

    #[test]
    fn next_pkid_roll() {
        let mut mqtt = build_mqttstate();
        let mut pkt_id = PacketIdentifier(0);
        for _ in 0..65536 {
            pkt_id = mqtt.next_pkid();
        }
        assert_eq!(PacketIdentifier(1), pkt_id);
    }

    #[test]
    fn outgoing_publish_handle_should_set_pkid_correctly_and_add_publish_to_queue_correctly() {
        let mut mqtt = build_mqttstate();

        // QoS 0: no id, nothing queued.
        let publish = build_outgoing_publish(QoS::AtMostOnce);
        let publish_out = match mqtt.handle_outgoing_publish(publish) {
            Ok(Packet::Publish(p)) => p,
            _ => panic!("Invalid packet. Should've been a publish packet"),
        };
        assert_eq!(publish_out.pkid, None);
        assert_eq!(mqtt.outgoing_pub.len(), 0);

        // QoS 1: fresh id per publish, each queued.
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        let publish_out = match mqtt.handle_outgoing_publish(publish.clone()) {
            Ok(Packet::Publish(p)) => p,
            _ => panic!("Invalid packet. Should've been a publish packet"),
        };
        assert_eq!(publish_out.pkid, Some(PacketIdentifier(1)));
        assert_eq!(mqtt.outgoing_pub.len(), 1);

        let publish_out = match mqtt.handle_outgoing_publish(publish.clone()) {
            Ok(Packet::Publish(p)) => p,
            _ => panic!("Invalid packet. Should've been a publish packet"),
        };
        assert_eq!(publish_out.pkid, Some(PacketIdentifier(2)));
        assert_eq!(mqtt.outgoing_pub.len(), 2);

        // QoS 2: same id sequence continues, each queued.
        let publish = build_outgoing_publish(QoS::ExactlyOnce);
        let publish_out = match mqtt.handle_outgoing_publish(publish.clone()) {
            Ok(Packet::Publish(p)) => p,
            _ => panic!("Invalid packet. Should've been a publish packet"),
        };
        assert_eq!(publish_out.pkid, Some(PacketIdentifier(3)));
        assert_eq!(mqtt.outgoing_pub.len(), 3);

        let publish_out = match mqtt.handle_outgoing_publish(publish.clone()) {
            Ok(Packet::Publish(p)) => p,
            _ => panic!("Invalid packet. Should've been a publish packet"),
        };
        assert_eq!(publish_out.pkid, Some(PacketIdentifier(4)));
        assert_eq!(mqtt.outgoing_pub.len(), 4);
    }

    #[test]
    fn incoming_publish_should_be_added_to_queue_correctly() {
        let mut mqtt = build_mqttstate();
        let publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
        let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
        let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
        mqtt.handle_incoming_publish(publish1).unwrap();
        mqtt.handle_incoming_publish(publish2).unwrap();
        mqtt.handle_incoming_publish(publish3).unwrap();

        // Only the QoS 2 publish waits for a PUBREL.
        assert_eq!(mqtt.incoming_pub.len(), 1);
        assert_eq!(*mqtt.incoming_pub.front().unwrap(), PacketIdentifier(3));
    }

    #[test]
    fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
        let mut mqtt = build_mqttstate();
        let publish = build_incoming_publish(QoS::ExactlyOnce, 1);
        let (notification, request) = mqtt.handle_incoming_publish(publish).unwrap();
        match notification {
            Some(Notification::Publish(publish)) => assert_eq!(publish.pkid.unwrap(), PacketIdentifier(1)),
            _ => panic!("Invalid notification: {:?}", notification),
        }
        match request {
            Some(Packet::Pubrec(PacketIdentifier(pkid))) => assert_eq!(pkid, 1),
            _ => panic!("Invalid network request: {:?}", request),
        }
    }

    #[test]
    fn incoming_puback_should_remove_correct_publish_from_queue() {
        let mut mqtt = build_mqttstate();
        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
        mqtt.handle_outgoing_publish(publish1).unwrap();
        mqtt.handle_outgoing_publish(publish2).unwrap();

        mqtt.handle_incoming_puback(PacketIdentifier(1)).unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 1);
        assert_eq!(mqtt.outgoing_pub.front().unwrap().pkid, Some(PacketIdentifier(2)));

        mqtt.handle_incoming_puback(PacketIdentifier(2)).unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 0);
    }

    #[test]
    fn unsolicited_puback_should_error_and_leave_queue_untouched() {
        let mut mqtt = build_mqttstate();
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        mqtt.handle_outgoing_publish(publish).unwrap();

        match mqtt.handle_incoming_puback(PacketIdentifier(42)) {
            Err(StateError::Unsolicited) => (),
            other => panic!("Expected unsolicited error. Got = {:?}", other),
        }
        assert_eq!(mqtt.outgoing_pub.len(), 1);
    }

    #[test]
    fn incoming_pubrec_should_release_correct_publish_from_queue_and_add_releaseid_to_rel_queue() {
        let mut mqtt = build_mqttstate();
        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
        mqtt.handle_outgoing_publish(publish1).unwrap();
        mqtt.handle_outgoing_publish(publish2).unwrap();

        mqtt.handle_incoming_pubrec(PacketIdentifier(2)).unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 1);
        assert_eq!(mqtt.outgoing_pub.front().unwrap().pkid, Some(PacketIdentifier(1)));

        assert_eq!(mqtt.outgoing_rel.len(), 1);
        assert_eq!(*mqtt.outgoing_rel.front().unwrap(), PacketIdentifier(2));
    }

    #[test]
    fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
        let mut mqtt = build_mqttstate();
        let publish = build_outgoing_publish(QoS::ExactlyOnce);
        mqtt.handle_outgoing_publish(publish).unwrap();

        let (notification, request) = mqtt.handle_incoming_pubrec(PacketIdentifier(1)).unwrap();
        match notification {
            Some(Notification::Pubrec(PacketIdentifier(id))) => assert_eq!(id, 1),
            _ => panic!("Invalid notification"),
        }
        match request {
            Some(Packet::Pubrel(PacketIdentifier(pkid))) => assert_eq!(pkid, 1),
            _ => panic!("Invalid network request: {:?}", request),
        }
    }

    #[test]
    fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
        let mut mqtt = build_mqttstate();
        let publish = build_incoming_publish(QoS::ExactlyOnce, 1);
        mqtt.handle_incoming_publish(publish).unwrap();

        let (notification, request) = mqtt.handle_incoming_pubrel(PacketIdentifier(1)).unwrap();
        assert!(notification.is_none(), "Invalid notification: {:?}", notification);
        match request {
            Some(Packet::Pubcomp(PacketIdentifier(pkid))) => assert_eq!(pkid, 1),
            _ => panic!("Invalid network request: {:?}", request),
        }
        assert!(mqtt.incoming_pub.is_empty());
    }

    #[test]
    fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
        let mut mqtt = build_mqttstate();
        let publish = build_outgoing_publish(QoS::ExactlyOnce);
        mqtt.handle_outgoing_publish(publish).unwrap();
        mqtt.handle_incoming_pubrec(PacketIdentifier(1)).unwrap();
        assert_eq!(mqtt.outgoing_rel.len(), 1);

        let (notification, request) = mqtt.handle_incoming_pubcomp(PacketIdentifier(1)).unwrap();
        match notification {
            Some(Notification::Pubcomp(PacketIdentifier(id))) => assert_eq!(id, 1),
            _ => panic!("Invalid notification: {:?}", notification),
        }
        assert!(request.is_none(), "Unexpected network request: {:?}", request);
        // The release queue is what PUBCOMP drains; `outgoing_pub` was already
        // emptied by the PUBREC above.
        assert!(mqtt.outgoing_rel.is_empty());
        assert!(mqtt.outgoing_pub.is_empty());
    }

    #[test]
    fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
        let mut mqtt = build_mqttstate();
        mqtt.connection_status = MqttConnectionStatus::Connected;
        mqtt.handle_outgoing_ping().unwrap();

        // Unrelated traffic must not clear the pending ping.
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        mqtt.handle_outgoing_packet(Packet::Publish(publish)).unwrap();
        mqtt.handle_incoming_packet(Packet::Puback(PacketIdentifier(1))).unwrap();

        match mqtt.handle_outgoing_ping() {
            Ok(_) => panic!("Should throw pingresp await error"),
            Err(StateError::AwaitPingResp) => (),
            Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e),
        }
    }

    #[test]
    fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
        let mut mqtt = build_mqttstate();
        mqtt.connection_status = MqttConnectionStatus::Connected;
        mqtt.handle_outgoing_ping().unwrap();
        mqtt.handle_incoming_packet(Packet::Pingresp).unwrap();
        mqtt.handle_outgoing_ping().unwrap();
    }
}