use crate::notice::{
PublishNoticeTx, PublishResult, SubscribeNoticeTx, TrackedNoticeTx, UnsubscribeNoticeTx,
};
use crate::{Event, Incoming, NoticeFailureReason, Outgoing, PublishNoticeError, Request};
use crate::mqttbytes::v4::{
Packet, PubAck, PubComp, PubRec, PubRel, Publish, SubAck, Subscribe, UnsubAck, Unsubscribe,
};
use crate::mqttbytes::{self, QoS};
use fixedbitset::FixedBitSet;
use std::collections::{BTreeMap, VecDeque};
use std::{io, time::Instant};
/// Errors raised while driving the MQTT protocol state machine.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
/// An underlying I/O failure, converted from `io::Error`.
#[error("Io error: {0:?}")]
Io(#[from] io::Error),
/// Returned when no free packet identifier can be allocated for a control packet.
#[error("Invalid state for a given operation")]
InvalidState,
/// An acknowledgement referenced a packet identifier with no matching in-flight packet,
/// or an outgoing packet identifier was zero or above `max_inflight`.
#[error("Received unsolicited ack pkid: {0}")]
Unsolicited(u16),
/// A ping was requested while the previous ping request was still unanswered.
#[error("Last pingreq isn't acked")]
AwaitPingResp,
/// An incoming packet of a type this state machine does not handle was received.
#[error("Received a wrong packet while waiting for another packet")]
WrongPacket,
/// A packet-identifier collision was still pending when a second ping was requested.
#[error("Timeout while waiting to resolve collision")]
CollisionTimeout,
/// A subscribe request carried no filters.
#[error("A Subscribe packet must contain atleast one filter")]
EmptySubscription,
/// A packet (de)serialization failure, converted from `mqttbytes::Error`.
#[error("Mqtt serialization/deserialization error: {0}")]
Deserialization(#[from] mqttbytes::Error),
/// The peer closed the connection abruptly.
/// NOTE(review): not constructed in the visible part of this file — presumably raised by the
/// network/event-loop layer.
#[error("Connection closed by peer abruptly")]
ConnectionAborted,
}
/// Protocol state of one MQTT client session: in-flight publishes and releases, tracked
/// subscribe/unsubscribe requests, keep-alive bookkeeping and the queue of emitted events.
///
/// Per-packet-identifier tables are indexed directly by packet id; index 0 is never a valid id.
#[derive(Debug)]
pub struct MqttState {
/// `true` after a ping request was emitted and until the ping response is handled.
pub await_pingresp: bool,
/// Number of pings requested while a collision was pending.
pub collision_ping_count: usize,
/// Time of the last successfully handled incoming packet.
last_incoming: Instant,
/// Time of the last successfully handled outgoing request.
last_outgoing: Instant,
/// Most recently allocated packet identifier (0 after a wrap or a shrink).
pub(crate) last_pkid: u16,
/// Frontier of contiguously completed packet ids; pending publishes are exported starting
/// just after it.
pub(crate) last_puback: u16,
/// Number of outgoing QoS 1/2 exchanges currently counted as in flight.
pub(crate) inflight: u16,
/// Upper bound for publish packet identifiers and for `inflight`.
pub(crate) max_inflight: u16,
/// Outgoing publishes awaiting their first acknowledgement, indexed by packet id.
pub(crate) outgoing_pub: Vec<Option<Publish>>,
/// Completion notices for the entries of `outgoing_pub`, indexed by packet id.
pub(crate) outgoing_pub_notice: Vec<Option<PublishNoticeTx>>,
/// Packet ids completed ahead of the `last_puback` frontier.
pub(crate) outgoing_pub_ack: FixedBitSet,
/// Packet ids whose PUBREL was sent and whose PUBCOMP is still outstanding.
pub(crate) outgoing_rel: FixedBitSet,
/// Completion notices for the entries of `outgoing_rel`, indexed by packet id.
pub(crate) outgoing_rel_notice: Vec<Option<PublishNoticeTx>>,
/// Packet ids of incoming QoS 2 publishes whose PUBREL has not arrived yet.
pub(crate) incoming_pub: FixedBitSet,
/// Publish parked because its packet id is still in use.
pub collision: Option<Publish>,
/// Completion notice belonging to the parked `collision` publish.
pub(crate) collision_notice: Option<PublishNoticeTx>,
/// Subscribe requests awaiting a SUBACK, keyed by packet id (only those with a notice).
pub(crate) tracked_subscribe: BTreeMap<u16, (Subscribe, SubscribeNoticeTx)>,
/// Unsubscribe requests awaiting an UNSUBACK, keyed by packet id (only those with a notice).
pub(crate) tracked_unsubscribe: BTreeMap<u16, (Unsubscribe, UnsubscribeNoticeTx)>,
/// Incoming and outgoing events produced while handling packets.
pub events: VecDeque<Event>,
/// When `true`, incoming publishes are not acknowledged automatically.
pub manual_acks: bool,
}
/// Builder collecting the configuration needed to create an [`MqttState`].
#[derive(Debug)]
pub struct MqttStateBuilder {
/// Upper bound for publish packet identifiers / in-flight exchanges.
max_inflight: u16,
/// Whether incoming publishes must be acknowledged manually (defaults to `false`).
manual_acks: bool,
}
impl MqttStateBuilder {
    /// Creates a builder for a state limited to `max_inflight` outgoing exchanges, with
    /// automatic acknowledgements enabled.
    #[must_use]
    pub const fn new(max_inflight: u16) -> Self {
        let manual_acks = false;
        Self {
            max_inflight,
            manual_acks,
        }
    }

    /// Chooses whether incoming publishes are acknowledged manually (`true`) or
    /// automatically (`false`).
    #[must_use]
    pub const fn manual_acks(self, manual_acks: bool) -> Self {
        Self {
            max_inflight: self.max_inflight,
            manual_acks,
        }
    }

    /// Consumes the builder and produces the configured [`MqttState`].
    #[must_use]
    pub fn build(self) -> MqttState {
        let Self {
            max_inflight,
            manual_acks,
        } = self;
        MqttState::new_internal(max_inflight, manual_acks)
    }
}
impl MqttState {
/// Number of packet-id slots (not counting the unused slot 0) the tracking tables start with
/// and are truncated back to once all outbound requests are drained.
const WARM_TRACKING_SLOTS: usize = 32;
/// Capacity reserved up front for the `events` queue of a new state.
const fn initial_events_capacity() -> usize {
    const EVENTS_CAPACITY: usize = 128;
    EVENTS_CAPACITY
}
/// Table length needed to index every packet id in `1..=max_inflight` directly.
const fn outgoing_tracking_len(max_inflight: u16) -> usize {
    // One extra entry because index 0 exists in the tables but is never a valid packet id.
    1 + max_inflight as usize
}
/// Initial ("warm") table length: the full tracking length, capped at
/// `WARM_TRACKING_SLOTS + 1`.
const fn warm_tracking_len(max_inflight: u16) -> usize {
    let warm_len = Self::WARM_TRACKING_SLOTS + 1;
    let full_len = Self::outgoing_tracking_len(max_inflight);
    if warm_len <= full_len {
        warm_len
    } else {
        full_len
    }
}
/// Builds a notice table of `len` empty slots.
fn new_notice_slots_with_len(len: usize) -> Vec<Option<PublishNoticeTx>> {
    (0..len).map(|_| None).collect()
}
/// Grows every outgoing tracking table that is shorter than `target_len`; tables that are
/// already long enough are left untouched (nothing is ever shortened here).
fn ensure_outgoing_tracking_capacity(&mut self, target_len: usize) {
    // Bit sets.
    if target_len > self.outgoing_pub_ack.len() {
        self.outgoing_pub_ack.grow(target_len);
    }
    if target_len > self.outgoing_rel.len() {
        self.outgoing_rel.grow(target_len);
    }

    // Slot vectors, padded with empty slots.
    if target_len > self.outgoing_pub.len() {
        self.outgoing_pub.resize_with(target_len, || None);
    }
    if target_len > self.outgoing_pub_notice.len() {
        self.outgoing_pub_notice.resize_with(target_len, || None);
    }
    if target_len > self.outgoing_rel_notice.len() {
        self.outgoing_rel_notice.resize_with(target_len, || None);
    }
}
/// Returns `true` when nothing outbound is pending: no in-flight count, no parked
/// collision, no tracked subscribe/unsubscribe, and every tracking slot and bit is empty.
pub(crate) fn outbound_requests_drained(&self) -> bool {
    if self.inflight != 0 {
        return false;
    }
    if self.collision.is_some() || self.collision_notice.is_some() {
        return false;
    }
    if !self.tracked_subscribe.is_empty() || !self.tracked_unsubscribe.is_empty() {
        return false;
    }
    if self.outgoing_pub.iter().any(Option::is_some) {
        return false;
    }
    if self.outgoing_pub_notice.iter().any(Option::is_some) {
        return false;
    }
    if self.outgoing_rel_notice.iter().any(Option::is_some) {
        return false;
    }
    let no_ack_bits = self.outgoing_pub_ack.ones().next().is_none();
    let no_rel_bits = self.outgoing_rel.ones().next().is_none();
    no_ack_bits && no_rel_bits
}
/// Truncates the tracking tables back to the warm length, but only when they have grown
/// beyond it and no outbound request is pending. Also resets the packet-id cursor and the
/// acknowledgement frontier to 0.
fn maybe_shrink_outgoing_tracking_capacity(&mut self) {
    let warm_len = Self::warm_tracking_len(self.max_inflight);
    let grown = self.outgoing_pub.len() > warm_len;
    if !grown || !self.outbound_requests_drained() {
        return;
    }

    // Everything is drained, so the bit sets can simply be replaced by fresh, empty ones.
    self.outgoing_rel = FixedBitSet::with_capacity(warm_len);
    self.outgoing_pub_ack = FixedBitSet::with_capacity(warm_len);
    self.outgoing_rel_notice.truncate(warm_len);
    self.outgoing_pub_notice.truncate(warm_len);
    self.outgoing_pub.truncate(warm_len);

    self.last_puback = 0;
    self.last_pkid = 0;
}
/// Accepts packet ids in `1..=max_inflight`.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] for id 0 or an id above `max_inflight`.
const fn validate_outgoing_pkid_bound(&self, pkid: u16) -> Result<(), StateError> {
    if pkid != 0 && pkid <= self.max_inflight {
        Ok(())
    } else {
        Err(StateError::Unsolicited(pkid))
    }
}
/// Successor of `pkid` in the publish id cycle `1..=max_inflight`; any id at or beyond the
/// limit wraps to 1.
const fn next_publish_pkid_after(&self, pkid: u16) -> u16 {
    if pkid < self.max_inflight {
        pkid + 1
    } else {
        1
    }
}
/// Returns `true` if `pkid` belongs to an unacknowledged publish, an outstanding release,
/// or a tracked subscribe/unsubscribe request.
fn packet_identifier_in_use(&self, pkid: u16) -> bool {
    let slot = usize::from(pkid);
    if matches!(self.outgoing_pub.get(slot), Some(Some(_))) {
        return true;
    }
    if self.outgoing_rel.contains(slot) {
        return true;
    }
    self.tracked_subscribe.contains_key(&pkid) || self.tracked_unsubscribe.contains_key(&pkid)
}
/// Decides whether `publish` could be sent right now without exceeding the in-flight
/// window or reusing a busy packet id.
///
/// QoS 0 publishes are always sendable. Otherwise the in-flight window must have room and
/// no collision may be parked; a publish without an id (`pkid == 0`) needs some free id,
/// while an explicit id must be within bounds and unused.
pub(crate) fn can_send_publish(&self, publish: &Publish) -> bool {
    match publish.qos {
        QoS::AtMostOnce => true,
        _ if self.inflight >= self.max_inflight || self.collision.is_some() => false,
        _ => match publish.pkid {
            0 => self.next_publish_pkid().is_some(),
            pkid => {
                self.validate_outgoing_pkid_bound(pkid).is_ok()
                    && !self.packet_identifier_in_use(pkid)
            }
        },
    }
}
/// Returns `true` if at least one packet id in `1..=u16::MAX` is currently unused.
pub(crate) fn control_packet_identifier_available(&self) -> bool {
    let every_id_taken = (1..=u16::MAX).all(|pkid| self.packet_identifier_in_use(pkid));
    !every_id_taken
}
/// Counts the requests `clean_with_notices` would export: stored publishes, outstanding
/// releases and tracked subscribe/unsubscribe requests.
fn clean_pending_capacity(&self) -> usize {
    let publishes = self.outgoing_pub.iter().flatten().count();
    let releases = self.outgoing_rel.ones().count();
    let tracked = self.tracked_subscribe.len() + self.tracked_unsubscribe.len();
    publishes + releases + tracked
}
/// Starts building a state limited to `max_inflight` outgoing exchanges.
#[must_use]
pub const fn builder(max_inflight: u16) -> MqttStateBuilder {
    let builder = MqttStateBuilder::new(max_inflight);
    builder
}
/// Creates an empty state.
///
/// Outgoing tracking tables start at the warm length and grow on demand; the incoming
/// QoS 2 bit set covers the whole `u16` packet-id range from the start.
#[must_use]
pub(crate) fn new_internal(max_inflight: u16, manual_acks: bool) -> Self {
    let last_incoming = Instant::now();
    let last_outgoing = Instant::now();
    let slots = Self::warm_tracking_len(max_inflight);
    let incoming_ids = u16::MAX as usize + 1;

    Self {
        await_pingresp: false,
        collision_ping_count: 0,
        last_incoming,
        last_outgoing,
        last_pkid: 0,
        last_puback: 0,
        inflight: 0,
        max_inflight,
        outgoing_pub: (0..slots).map(|_| None).collect(),
        outgoing_pub_notice: Self::new_notice_slots_with_len(slots),
        outgoing_pub_ack: FixedBitSet::with_capacity(slots),
        outgoing_rel: FixedBitSet::with_capacity(slots),
        outgoing_rel_notice: Self::new_notice_slots_with_len(slots),
        incoming_pub: FixedBitSet::with_capacity(incoming_ids),
        collision: None,
        collision_notice: None,
        tracked_subscribe: BTreeMap::new(),
        tracked_unsubscribe: BTreeMap::new(),
        events: VecDeque::with_capacity(Self::initial_events_capacity()),
        manual_acks,
    }
}
/// Resets the per-connection state and hands back every pending outbound request together
/// with its completion notice (if any) so the caller can replay them.
///
/// Export order: stored publishes starting just after the `last_puback` frontier and
/// wrapping around, then outstanding releases in ascending packet-id order, then tracked
/// subscribes, then tracked unsubscribes.
///
/// A parked `collision` is not exported or cleared here, and `last_pkid` / `last_puback`
/// are only reset if the tables end up being shrunk.
pub(crate) fn clean_with_notices(&mut self) -> Vec<(Request, Option<TrackedNoticeTx>)> {
let mut pending = Vec::with_capacity(self.clean_pending_capacity());
// Split both tables at the frontier so the part after it can be visited first.
let (first_half, second_half) = self
.outgoing_pub
.split_at_mut(self.last_puback as usize + 1);
let (notice_first_half, notice_second_half) = self
.outgoing_pub_notice
.split_at_mut(self.last_puback as usize + 1);
for (publish, notice) in second_half
.iter_mut()
.zip(notice_second_half.iter_mut())
.chain(first_half.iter_mut().zip(notice_first_half.iter_mut()))
{
if let Some(publish) = publish.take() {
let request = Request::Publish(publish);
pending.push((request, notice.take().map(TrackedNoticeTx::Publish)));
} else {
// No publish in this slot: discard any notice left behind in it.
_ = notice.take();
}
}
// Outstanding releases are re-created from their packet ids.
for pkid in self.outgoing_rel.ones() {
let pkid = u16::try_from(pkid).expect("fixedbitset index always fits in u16");
let request = Request::PubRel(PubRel::new(pkid));
pending.push((
request,
self.outgoing_rel_notice[pkid as usize]
.take()
.map(TrackedNoticeTx::Publish),
));
}
self.outgoing_rel.clear();
self.outgoing_pub_ack.clear();
// Tracked requests are exported with the packet id they were keyed under.
for (pkid, (mut subscribe, notice)) in std::mem::take(&mut self.tracked_subscribe) {
subscribe.pkid = pkid;
pending.push((
Request::Subscribe(subscribe),
Some(TrackedNoticeTx::Subscribe(notice)),
));
}
for (pkid, (mut unsubscribe, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
unsubscribe.pkid = pkid;
pending.push((
Request::Unsubscribe(unsubscribe),
Some(TrackedNoticeTx::Unsubscribe(notice)),
));
}
self.incoming_pub.clear();
self.await_pingresp = false;
self.collision_ping_count = 0;
self.inflight = 0;
// Nothing was exported, so the tables may be eligible for shrinking.
if pending.is_empty() {
self.maybe_shrink_outgoing_tracking_capacity();
}
pending
}
/// Resets the per-connection state and returns the pending requests to replay; their
/// completion notices are dropped.
pub fn clean(&mut self) -> Vec<Request> {
    let pending = self.clean_with_notices();
    let mut requests = Vec::with_capacity(pending.len());
    for (request, _notice) in pending {
        requests.push(request);
    }
    requests
}
/// Number of outgoing QoS 1/2 exchanges currently counted as in flight.
pub const fn inflight(&self) -> u16 {
    let count = self.inflight;
    count
}
/// Number of subscribe requests still waiting for their SUBACK.
pub fn tracked_subscribe_len(&self) -> usize {
    let tracked = &self.tracked_subscribe;
    tracked.len()
}
/// Number of unsubscribe requests still waiting for their UNSUBACK.
pub fn tracked_unsubscribe_len(&self) -> usize {
    let tracked = &self.tracked_unsubscribe;
    tracked.len()
}
/// Returns `true` when no subscribe or unsubscribe request is being tracked.
pub fn tracked_requests_is_empty(&self) -> bool {
    let no_subscribes = self.tracked_subscribe.is_empty();
    let no_unsubscribes = self.tracked_unsubscribe.is_empty();
    no_subscribes && no_unsubscribes
}
/// Fails every tracked subscribe and unsubscribe request with the error derived from
/// `reason`, empties both maps, and returns how many requests were failed.
///
/// Subscribes are failed before unsubscribes; afterwards the tracking tables are shrunk if
/// the state is fully drained.
pub fn drain_tracked_requests_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
    let subscribes = std::mem::take(&mut self.tracked_subscribe);
    let unsubscribes = std::mem::take(&mut self.tracked_unsubscribe);
    let drained = subscribes.len() + unsubscribes.len();

    for (_request, notice) in subscribes.into_values() {
        notice.error(reason.subscribe_error());
    }
    for (_request, notice) in unsubscribes.into_values() {
        notice.error(reason.unsubscribe_error());
    }

    self.maybe_shrink_outgoing_tracking_capacity();
    drained
}
/// Fails every outstanding notice with a session-reset error: publish notices, release
/// notices, the collision notice and all tracked subscribe/unsubscribe requests. The
/// parked collision is discarded and the tables are shrunk if the state is now drained.
pub(crate) fn fail_pending_notices(&mut self) {
    let publish_slots = self.outgoing_pub_notice.iter_mut();
    let release_slots = self.outgoing_rel_notice.iter_mut();
    for tx in publish_slots.chain(release_slots).filter_map(Option::take) {
        tx.error(PublishNoticeError::SessionReset);
    }

    if let Some(tx) = self.collision_notice.take() {
        tx.error(PublishNoticeError::SessionReset);
    }

    self.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
    self.clear_collision();
    // The attempt inside the drain above ran while the collision was still parked.
    self.maybe_shrink_outgoing_tracking_capacity();
}
/// Processes an outgoing request without a completion notice and returns the packet to
/// put on the wire, if any.
///
/// # Errors
/// Propagates any [`StateError`] from the request-specific handler.
///
/// # Panics
/// Panics for request kinds the underlying dispatcher does not handle (graceful disconnect
/// requests and any kind without a dedicated arm).
pub fn handle_outgoing_packet(
    &mut self,
    request: Request,
) -> Result<Option<Packet>, StateError> {
    let (outgoing, qos0_notice) = self.handle_outgoing_packet_with_notice(request, None)?;
    if let Some(tx) = qos0_notice {
        tx.success(PublishResult::Qos0Flushed);
    }
    self.last_outgoing = Instant::now();
    Ok(outgoing)
}
pub(crate) fn handle_outgoing_packet_with_notice(
&mut self,
request: Request,
notice: Option<TrackedNoticeTx>,
) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
let result =
match request {
Request::Publish(publish) => {
let publish_notice = match notice {
Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
Some(TrackedNoticeTx::Subscribe(_) | TrackedNoticeTx::Unsubscribe(_))
| None => None,
};
self.outgoing_publish_with_notice(publish, publish_notice)?
}
Request::PubRel(pubrel) => {
let publish_notice = match notice {
Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
Some(TrackedNoticeTx::Subscribe(_) | TrackedNoticeTx::Unsubscribe(_))
| None => None,
};
self.outgoing_pubrel_with_notice(pubrel, publish_notice)?
}
Request::Subscribe(subscribe) => {
let request_notice = match notice {
Some(TrackedNoticeTx::Subscribe(notice)) => Some(notice),
Some(TrackedNoticeTx::Publish(_) | TrackedNoticeTx::Unsubscribe(_))
| None => None,
};
(self.outgoing_subscribe(subscribe, request_notice)?, None)
}
Request::Unsubscribe(unsubscribe) => {
let request_notice = match notice {
Some(TrackedNoticeTx::Unsubscribe(notice)) => Some(notice),
Some(TrackedNoticeTx::Publish(_) | TrackedNoticeTx::Subscribe(_))
| None => None,
};
(
Some(self.outgoing_unsubscribe(unsubscribe, request_notice)?),
None,
)
}
Request::PingReq(_) => (self.outgoing_ping()?, None),
Request::Disconnect(_) | Request::DisconnectWithTimeout(_, _) => {
unreachable!("graceful disconnect requests are handled by the event loop")
}
Request::DisconnectNow(_) => (Some(self.outgoing_disconnect()), None),
Request::PubAck(puback) => (Some(self.outgoing_puback(puback)), None),
Request::PubRec(pubrec) => (Some(self.outgoing_pubrec(pubrec)), None),
_ => unimplemented!(),
};
self.last_outgoing = Instant::now();
Ok(result)
}
/// Processes an incoming packet and returns the reply packet to send, if any.
///
/// The incoming event is always recorded — even when handling fails — and is placed before
/// any events the handler itself emitted. `last_incoming` is refreshed only on success.
///
/// # Errors
/// Returns [`StateError::WrongPacket`] for unsupported packet types and propagates handler
/// errors such as [`StateError::Unsolicited`].
pub fn handle_incoming_packet(
    &mut self,
    packet: Incoming,
) -> Result<Option<Packet>, StateError> {
    // Remember where the queue ended so the incoming event can precede handler events.
    let incoming_event_index = self.events.len();

    let handled = match &packet {
        Incoming::PingResp => Ok(self.handle_incoming_pingresp()),
        Incoming::Publish(publish) => Ok(self.handle_incoming_publish(publish)),
        Incoming::SubAck(suback) => Ok(self.handle_incoming_suback(suback)),
        Incoming::UnsubAck(unsuback) => Ok(self.handle_incoming_unsuback(unsuback)),
        Incoming::PubAck(puback) => self.handle_incoming_puback(puback),
        Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec),
        Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel),
        Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp),
        _ => {
            error!("Invalid incoming packet = {packet:?}");
            Err(StateError::WrongPacket)
        }
    };

    self.events
        .insert(incoming_event_index, Event::Incoming(packet));

    handled.map(|reply| {
        self.last_incoming = Instant::now();
        reply
    })
}
/// Discards the parked collision publish and its notice, and resets the collision ping
/// counter.
pub fn clear_collision(&mut self) {
    drop(self.collision.take());
    drop(self.collision_notice.take());
    self.collision_ping_count = 0;
}
/// Completes the tracked subscribe request matching the SUBACK's packet id, if there is
/// one. Never produces a reply packet.
fn handle_incoming_suback(&mut self, suback: &SubAck) -> Option<Packet> {
    let tracked = self.tracked_subscribe.remove(&suback.pkid);
    if let Some((_request, tx)) = tracked {
        tx.success(suback.clone());
    }
    None
}
/// Completes the tracked unsubscribe request matching the UNSUBACK's packet id, if there
/// is one. Never produces a reply packet.
fn handle_incoming_unsuback(&mut self, unsuback: &UnsubAck) -> Option<Packet> {
    let tracked = self.tracked_unsubscribe.remove(&unsuback.pkid);
    if let Some((_request, tx)) = tracked {
        tx.success(unsuback.clone());
    }
    None
}
/// Handles an incoming publish and returns the automatic acknowledgement, if one is due.
///
/// QoS 2 packet ids are always recorded so the later PUBREL can be matched. With manual
/// acks enabled, or for QoS 0, no reply is produced; otherwise QoS 1 yields a PUBACK and
/// QoS 2 a PUBREC.
fn handle_incoming_publish(&mut self, publish: &Publish) -> Option<Packet> {
    let pkid = publish.pkid;
    if publish.qos == QoS::ExactlyOnce {
        self.incoming_pub.insert(pkid as usize);
    }
    if self.manual_acks {
        return None;
    }
    match publish.qos {
        QoS::AtMostOnce => None,
        QoS::AtLeastOnce => Some(self.outgoing_puback(PubAck::new(pkid))),
        QoS::ExactlyOnce => Some(self.outgoing_pubrec(PubRec::new(pkid))),
    }
}
/// Handles a PUBACK: releases the stored publish, advances the acknowledgement frontier,
/// completes the publish notice and decrements the in-flight count.
///
/// If a collision was parked on this packet id it is replayed and returned as the packet to
/// send; otherwise the tracking tables are shrunk when the state is drained.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when no publish is stored under the packet id.
fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<Option<Packet>, StateError> {
    let pkid = puback.pkid;
    let index = pkid as usize;

    let Some(slot) = self.outgoing_pub.get_mut(index) else {
        return Err(StateError::Unsolicited(pkid));
    };
    if slot.take().is_none() {
        error!("Unsolicited puback packet: {:?}", puback.pkid);
        return Err(StateError::Unsolicited(pkid));
    }

    self.mark_outgoing_packet_id_complete(pkid);
    let notice = self.outgoing_pub_notice[index].take();
    if let Some(tx) = notice {
        tx.success(PublishResult::Qos1(puback.clone()));
    }
    self.inflight -= 1;

    match self.replay_collision_publish(pkid) {
        Some(replayed) => Ok(Some(replayed)),
        None => {
            self.maybe_shrink_outgoing_tracking_capacity();
            Ok(None)
        }
    }
}
/// Handles a PUBREC: drops the stored publish, marks the packet id as awaiting PUBCOMP,
/// moves its notice to the release table and returns the PUBREL to send.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when no publish is stored under the packet id.
fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<Option<Packet>, StateError> {
    let pkid = pubrec.pkid;
    let index = pkid as usize;

    let Some(slot) = self.outgoing_pub.get_mut(index) else {
        return Err(StateError::Unsolicited(pkid));
    };
    if slot.take().is_none() {
        error!("Unsolicited pubrec packet: {:?}", pubrec.pkid);
        return Err(StateError::Unsolicited(pkid));
    }

    // The exchange continues under the same id, so the notice follows it.
    let carried_notice = self.outgoing_pub_notice[index].take();
    self.outgoing_rel.insert(index);
    self.outgoing_rel_notice[index] = carried_notice;

    self.events
        .push_back(Event::Outgoing(Outgoing::PubRel(pkid)));
    Ok(Some(Packet::PubRel(PubRel { pkid })))
}
/// Handles a PUBREL for an incoming QoS 2 publish: forgets the packet id and returns the
/// PUBCOMP to send.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when the packet id was not recorded as an incoming
/// QoS 2 publish.
fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
    let pkid = pubrel.pkid;
    let index = pkid as usize;

    if self.incoming_pub.contains(index) {
        self.incoming_pub.set(index, false);
        self.events
            .push_back(Event::Outgoing(Outgoing::PubComp(pkid)));
        return Ok(Some(Packet::PubComp(PubComp { pkid })));
    }

    error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
    Err(StateError::Unsolicited(pkid))
}
/// Handles a PUBCOMP: clears the outstanding release, advances the acknowledgement
/// frontier, completes the notice and decrements the in-flight count.
///
/// If a collision was parked on this packet id it is replayed and returned as the packet to
/// send; otherwise the tracking tables are shrunk when the state is drained.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when no release is outstanding for the packet id.
fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
    let pkid = pubcomp.pkid;
    let index = pkid as usize;

    if !self.outgoing_rel.contains(index) {
        error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
        return Err(StateError::Unsolicited(pkid));
    }

    self.outgoing_rel.set(index, false);
    self.mark_outgoing_packet_id_complete(pkid);
    let notice = self.outgoing_rel_notice[index].take();
    if let Some(tx) = notice {
        tx.success(PublishResult::Qos2Completed(pubcomp.clone()));
    }
    self.inflight -= 1;

    match self.replay_collision_publish(pkid) {
        Some(replayed) => Ok(Some(replayed)),
        None => {
            self.maybe_shrink_outgoing_tracking_capacity();
            Ok(None)
        }
    }
}
/// Handles a PINGRESP by clearing the "awaiting ping response" flag.
///
/// Always returns `None`: a ping response never triggers a reply packet.
const fn handle_incoming_pingresp(&mut self) -> Option<Packet> {
    // The keep-alive round trip is complete; a new ping may be requested again.
    self.await_pingresp = false;
    None
}
/// Test-only convenience: sends a publish without a notice and immediately completes any
/// QoS 0 flush notice that comes back.
#[cfg(test)]
fn outgoing_publish(&mut self, publish: Publish) -> Result<Option<Packet>, StateError> {
    self.outgoing_publish_with_notice(publish, None)
        .map(|(packet, flush_notice)| {
            if let Some(tx) = flush_notice {
                tx.success(PublishResult::Qos0Flushed);
            }
            packet
        })
}
/// Registers an outgoing publish and returns the packet to send.
///
/// QoS 0 publishes are passed straight through; their notice is handed back to the caller
/// so it can be completed once the packet is flushed. For QoS 1/2 a packet id is allocated
/// when the publish has none (`pkid == 0`), the publish and its notice are stored under
/// that id, and the in-flight count is incremented.
///
/// If the packet id is still busy — either an unacknowledged publish is stored under it or
/// a QoS 2 release for it is still awaiting PUBCOMP — the publish is parked as a collision
/// and `(None, None)` is returned; it is replayed when the acknowledgement for that id
/// arrives.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when the packet id is outside `1..=max_inflight`.
fn outgoing_publish_with_notice(
    &mut self,
    mut publish: Publish,
    notice: Option<PublishNoticeTx>,
) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
    let mut notice = notice;
    if publish.qos != QoS::AtMostOnce {
        if publish.pkid == 0 {
            publish.pkid = self.next_pkid();
        }
        let pkid = publish.pkid;
        let index = pkid as usize;
        self.validate_outgoing_pkid_bound(pkid)?;
        self.ensure_outgoing_tracking_capacity(index + 1);

        let awaiting_ack = self
            .outgoing_pub
            .get(index)
            .ok_or(StateError::Unsolicited(pkid))?
            .is_some();
        // A packet id stays reserved until PUBCOMP arrives, so an outstanding release
        // blocks reuse just like an unacknowledged publish does. The PUBCOMP handler
        // replays the parked publish.
        let awaiting_pubcomp = self.outgoing_rel.contains(index);

        if awaiting_ack || awaiting_pubcomp {
            info!("Collision on packet id = {:?}", publish.pkid);
            self.collision = Some(publish);
            self.collision_notice = notice.take();
            let event = Event::Outgoing(Outgoing::AwaitAck(pkid));
            self.events.push_back(event);
            return Ok((None, None));
        }

        self.outgoing_pub[index] = Some(publish.clone());
        self.outgoing_pub_notice[index] = notice.take();
        // The id is being reused: forget any completion mark left by its previous owner.
        self.outgoing_pub_ack.set(index, false);
        self.inflight += 1;
    }

    debug!(
        "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}",
        String::from_utf8_lossy(&publish.topic),
        publish.pkid,
        publish.payload.len()
    );
    let event = Event::Outgoing(Outgoing::Publish(publish.pkid));
    self.events.push_back(event);

    if publish.qos == QoS::AtMostOnce {
        // QoS 0 has no acknowledgement; the caller completes the notice after flushing.
        Ok((Some(Packet::Publish(publish)), notice.take()))
    } else {
        Ok((Some(Packet::Publish(publish)), None))
    }
}
/// Registers an outgoing PUBREL (with its notice) and returns the packet to send. The
/// second tuple element is always `None`.
///
/// # Errors
/// Propagates the error from `save_pubrel_with_notice` for out-of-range packet ids.
fn outgoing_pubrel_with_notice(
    &mut self,
    pubrel: PubRel,
    notice: Option<PublishNoticeTx>,
) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
    let saved = self.save_pubrel_with_notice(pubrel, notice)?;
    debug!("Pubrel. Pkid = {}", saved.pkid);
    self.events
        .push_back(Event::Outgoing(Outgoing::PubRel(saved.pkid)));
    Ok((Some(Packet::PubRel(saved)), None))
}
/// Records the outgoing PUBACK event and wraps the acknowledgement into a packet.
fn outgoing_puback(&mut self, puback: PubAck) -> Packet {
    self.events
        .push_back(Event::Outgoing(Outgoing::PubAck(puback.pkid)));
    Packet::PubAck(puback)
}
/// Records the outgoing PUBREC event and wraps the acknowledgement into a packet.
fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Packet {
    self.events
        .push_back(Event::Outgoing(Outgoing::PubRec(pubrec.pkid)));
    Packet::PubRec(pubrec)
}
/// Emits a ping request and marks the state as awaiting the ping response.
///
/// # Errors
/// - [`StateError::CollisionTimeout`] when a collision is still parked at the second ping
///   requested since the collision counter was last reset.
/// - [`StateError::AwaitPingResp`] when the previous ping request is still unanswered.
fn outgoing_ping(&mut self) -> Result<Option<Packet>, StateError> {
// Sampled up front; only used for the debug log below.
let elapsed_in = self.last_incoming.elapsed();
let elapsed_out = self.last_outgoing.elapsed();
// A parked collision gets one ping interval of grace before it is treated as a timeout.
if self.collision.is_some() {
self.collision_ping_count += 1;
if self.collision_ping_count >= 2 {
return Err(StateError::CollisionTimeout);
}
}
if self.await_pingresp {
return Err(StateError::AwaitPingResp);
}
self.await_pingresp = true;
debug!(
"Pingreq,
last incoming packet before {} millisecs,
last outgoing request before {} millisecs",
elapsed_in.as_millis(),
elapsed_out.as_millis()
);
let event = Event::Outgoing(Outgoing::PingReq);
self.events.push_back(event);
Ok(Some(Packet::PingReq))
}
/// Assigns a packet id to a subscribe request, records the outgoing event and returns the
/// packet to send. The request is tracked until its SUBACK only when a notice is supplied.
///
/// # Errors
/// - [`StateError::EmptySubscription`] when the request has no filters.
/// - [`StateError::InvalidState`] when no packet id is free.
fn outgoing_subscribe(
    &mut self,
    mut subscription: Subscribe,
    notice: Option<SubscribeNoticeTx>,
) -> Result<Option<Packet>, StateError> {
    if subscription.filters.is_empty() {
        return Err(StateError::EmptySubscription);
    }

    subscription.pkid = self.next_control_pkid()?;
    let pkid = subscription.pkid;
    debug!(
        "Subscribe. Topics = {:?}, Pkid = {:?}",
        subscription.filters, subscription.pkid
    );
    self.events
        .push_back(Event::Outgoing(Outgoing::Subscribe(pkid)));

    if let Some(tx) = notice {
        self.tracked_subscribe
            .insert(pkid, (subscription.clone(), tx));
    }
    Ok(Some(Packet::Subscribe(subscription)))
}
/// Assigns a packet id to an unsubscribe request, records the outgoing event and returns
/// the packet to send. The request is tracked until its UNSUBACK only when a notice is
/// supplied.
///
/// # Errors
/// Returns [`StateError::InvalidState`] when no packet id is free.
fn outgoing_unsubscribe(
    &mut self,
    mut unsub: Unsubscribe,
    notice: Option<UnsubscribeNoticeTx>,
) -> Result<Packet, StateError> {
    unsub.pkid = self.next_control_pkid()?;
    let pkid = unsub.pkid;
    debug!(
        "Unsubscribe. Topics = {:?}, Pkid = {:?}",
        unsub.topics, unsub.pkid
    );
    self.events
        .push_back(Event::Outgoing(Outgoing::Unsubscribe(pkid)));

    if let Some(tx) = notice {
        self.tracked_unsubscribe.insert(pkid, (unsub.clone(), tx));
    }
    Ok(Packet::Unsubscribe(unsub))
}
/// Records the outgoing disconnect event and returns the DISCONNECT packet.
fn outgoing_disconnect(&mut self) -> Packet {
    debug!("Disconnect");
    self.events
        .push_back(Event::Outgoing(Outgoing::Disconnect));
    Packet::Disconnect
}
fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option<PublishNoticeTx>)> {
if let Some(publish) = &self.collision
&& publish.pkid == pkid
{
return self
.collision
.take()
.map(|publish| (publish, self.collision_notice.take()));
}
None
}
/// Records a PUBREL as outstanding: allocates a packet id when it has none, marks the id in
/// the release set, stores the notice and increments the in-flight count.
///
/// # Errors
/// Returns [`StateError::Unsolicited`] when the packet id is outside `1..=max_inflight`.
fn save_pubrel_with_notice(
    &mut self,
    mut pubrel: PubRel,
    notice: Option<PublishNoticeTx>,
) -> Result<PubRel, StateError> {
    if pubrel.pkid == 0 {
        pubrel.pkid = self.next_pkid();
    }
    self.validate_outgoing_pkid_bound(pubrel.pkid)?;

    let index = pubrel.pkid as usize;
    self.ensure_outgoing_tracking_capacity(index + 1);
    self.outgoing_rel.insert(index);
    self.outgoing_rel_notice[index] = notice;
    self.inflight += 1;
    Ok(pubrel)
}
/// Replays the parked collision publish if it was waiting for `pkid` to become free.
///
/// The publish and its notice are stored under their packet id, the in-flight count is
/// incremented, the outgoing event is recorded and the collision ping counter is reset.
/// Returns the publish packet to send, or `None` when no collision is parked on `pkid`.
fn replay_collision_publish(&mut self, pkid: u16) -> Option<Packet> {
    let (publish, notice) = self.check_collision(pkid)?;
    let publish_pkid = publish.pkid;
    let index = publish_pkid as usize;

    self.ensure_outgoing_tracking_capacity(index + 1);
    self.outgoing_pub[index] = Some(publish.clone());
    self.outgoing_pub_notice[index] = notice;
    // Same as the regular send path: the id now belongs to a new, unacknowledged publish,
    // so a completion mark left by its previous owner (possible when that ack arrived ahead
    // of the frontier) must not let the frontier advance past this publish.
    self.outgoing_pub_ack.set(index, false);
    self.inflight += 1;

    self.events
        .push_back(Event::Outgoing(Outgoing::Publish(publish_pkid)));
    self.collision_ping_count = 0;
    Some(Packet::Publish(publish))
}
/// Marks `pkid` as completed and then moves the acknowledgement frontier as far forward as
/// the completion marks allow.
fn mark_outgoing_packet_id_complete(&mut self, pkid: u16) {
    let index = usize::from(pkid);
    self.outgoing_pub_ack.set(index, true);
    self.advance_last_puback_frontier();
}
/// Advances `last_puback` over every consecutively completed packet id, consuming each
/// completion mark it passes. Stops at the first id that is not marked complete.
fn advance_last_puback_frontier(&mut self) {
    loop {
        let candidate = self.next_puback_boundary_pkid(self.last_puback);
        let index = usize::from(candidate);
        // 0 means there is no valid id at all (max_inflight == 0).
        if candidate == 0 || !self.outgoing_pub_ack.contains(index) {
            break;
        }
        self.outgoing_pub_ack.set(index, false);
        self.last_puback = candidate;
    }
}
/// Successor of `pkid` in the cycle `1..=max_inflight`, or 0 when `max_inflight` is 0 and
/// no valid packet id exists.
const fn next_puback_boundary_pkid(&self, pkid: u16) -> u16 {
    match self.max_inflight {
        0 => 0,
        limit if pkid >= limit => 1,
        _ => pkid + 1,
    }
}
/// Finds the first unused publish packet id, scanning at most `max_inflight` candidates in
/// cyclic order starting just after `last_pkid`. Returns `None` when all are in use.
fn next_publish_pkid(&self) -> Option<u16> {
    let first = self.next_publish_pkid_after(self.last_pkid);
    std::iter::successors(Some(first), |&current| {
        Some(self.next_publish_pkid_after(current))
    })
    .take(usize::from(self.max_inflight))
    .find(|&candidate| !self.packet_identifier_in_use(candidate))
}
/// Allocates the next publish packet id and updates the cursor.
///
/// Prefers a free id; when every id is busy it falls back to the plain successor of
/// `last_pkid`. The cursor is reset to 0 after handing out `max_inflight`.
fn next_pkid(&mut self) -> u16 {
    let pkid = match self.next_publish_pkid() {
        Some(free) => free,
        None => self.next_publish_pkid_after(self.last_pkid),
    };
    self.last_pkid = if pkid == self.max_inflight { 0 } else { pkid };
    pkid
}
/// Allocates a packet id for a control packet (subscribe/unsubscribe) from the whole
/// non-zero `u16` range, scanning cyclically from just after `last_pkid`, and moves the
/// cursor to the chosen id.
///
/// # Errors
/// Returns [`StateError::InvalidState`] when every non-zero id is in use.
fn next_control_pkid(&mut self) -> Result<u16, StateError> {
    let base = self.last_pkid;
    let free = (1..=u16::MAX)
        .map(|offset| base.wrapping_add(offset))
        .find(|&candidate| candidate != 0 && !self.packet_identifier_in_use(candidate));

    match free {
        Some(pkid) => {
            self.last_pkid = pkid;
            Ok(pkid)
        }
        None => Err(StateError::InvalidState),
    }
}
}
/// Manual `Clone`: protocol data is copied, but completion notices are not — the clone
/// gets empty notice tables of the same lengths, no collision notice, and no tracked
/// subscribe/unsubscribe requests.
impl Clone for MqttState {
    fn clone(&self) -> Self {
        // Notice tables keep their lengths so they stay index-compatible with the
        // cloned publish and release tables.
        let pub_notice_slots = Self::new_notice_slots_with_len(self.outgoing_pub_notice.len());
        let rel_notice_slots = Self::new_notice_slots_with_len(self.outgoing_rel_notice.len());

        Self {
            // Keep-alive and collision bookkeeping.
            await_pingresp: self.await_pingresp,
            collision_ping_count: self.collision_ping_count,
            last_incoming: self.last_incoming,
            last_outgoing: self.last_outgoing,
            // Packet-id cursors and limits.
            last_pkid: self.last_pkid,
            last_puback: self.last_puback,
            inflight: self.inflight,
            max_inflight: self.max_inflight,
            // Tracking tables.
            outgoing_pub: self.outgoing_pub.clone(),
            outgoing_pub_notice: pub_notice_slots,
            outgoing_pub_ack: self.outgoing_pub_ack.clone(),
            outgoing_rel: self.outgoing_rel.clone(),
            outgoing_rel_notice: rel_notice_slots,
            incoming_pub: self.incoming_pub.clone(),
            collision: self.collision.clone(),
            collision_notice: None,
            tracked_subscribe: BTreeMap::new(),
            tracked_unsubscribe: BTreeMap::new(),
            events: self.events.clone(),
            manual_acks: self.manual_acks,
        }
    }
}
#[cfg(test)]
mod test {
use super::{MqttState, StateError};
use crate::mqttbytes::v4::*;
use crate::mqttbytes::*;
use crate::notice::{
PublishNoticeTx, PublishResult, SubscribeNoticeError, SubscribeNoticeTx,
UnsubscribeNoticeError, UnsubscribeNoticeTx,
};
use crate::{Event, Incoming, NoticeFailureReason, Outgoing, Request};
use bytes::Bytes;
fn build_outgoing_publish(qos: QoS) -> Publish {
let topic = "hello/world".to_owned();
let payload = vec![1, 2, 3];
let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
publish.qos = qos;
publish
}
fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
let topic = "hello/world".to_owned();
let payload = vec![1, 2, 3];
let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
publish.pkid = pkid;
publish.qos = qos;
publish
}
fn build_mqttstate() -> MqttState {
MqttState::builder(100).build()
}
fn queue_publish_with_notice(mqtt: &mut MqttState, publish: Publish) -> crate::PublishNotice {
let (tx, notice) = PublishNoticeTx::new();
let (packet, flush_notice) = mqtt
.outgoing_publish_with_notice(publish, Some(tx))
.unwrap();
assert!(packet.is_some());
assert!(flush_notice.is_none());
notice
}
#[test]
fn new_state_preallocates_event_queue_for_read_batch_bursts() {
let mqtt = MqttState::builder(10).build();
assert!(mqtt.events.capacity() >= MqttState::initial_events_capacity());
}
#[test]
fn new_state_uses_warm_tracking_floor() {
let mqtt = MqttState::builder(100).build();
assert_eq!(mqtt.outgoing_pub.len(), 33);
assert_eq!(mqtt.outgoing_pub_notice.len(), 33);
assert_eq!(mqtt.outgoing_rel_notice.len(), 33);
assert_eq!(mqtt.outgoing_pub_ack.len(), 33);
assert_eq!(mqtt.outgoing_rel.len(), 33);
}
#[test]
fn new_state_uses_full_tracking_len_when_max_inflight_is_below_warm_floor() {
let mqtt = MqttState::builder(10).build();
assert_eq!(mqtt.outgoing_pub.len(), 11);
assert_eq!(mqtt.outgoing_pub_notice.len(), 11);
assert_eq!(mqtt.outgoing_rel_notice.len(), 11);
assert_eq!(mqtt.outgoing_pub_ack.len(), 11);
assert_eq!(mqtt.outgoing_rel.len(), 11);
}
#[test]
fn clean_pending_capacity_counts_publish_rel_and_tracked_requests() {
let mut mqtt = MqttState::builder(10).build();
mqtt.outgoing_pub[1] = Some(build_outgoing_publish(QoS::AtLeastOnce));
mqtt.outgoing_pub[2] = Some(build_outgoing_publish(QoS::ExactlyOnce));
mqtt.outgoing_rel.insert(3);
mqtt.outgoing_rel.insert(4);
let (sub_notice, _) = SubscribeNoticeTx::new();
mqtt.tracked_subscribe
.insert(5, (Subscribe::new("a/b", QoS::AtMostOnce), sub_notice));
let (unsub_notice, _) = UnsubscribeNoticeTx::new();
mqtt.tracked_unsubscribe
.insert(6, (Unsubscribe::new("a/b"), unsub_notice));
assert_eq!(mqtt.clean_pending_capacity(), 6);
}
#[test]
fn tracked_request_len_helpers_report_counts() {
let mut mqtt = MqttState::builder(10).build();
let (sub_notice, _) = SubscribeNoticeTx::new();
mqtt.tracked_subscribe
.insert(5, (Subscribe::new("a/b", QoS::AtMostOnce), sub_notice));
let (unsub_notice, _) = UnsubscribeNoticeTx::new();
mqtt.tracked_unsubscribe
.insert(6, (Unsubscribe::new("a/b"), unsub_notice));
assert_eq!(mqtt.tracked_subscribe_len(), 1);
assert_eq!(mqtt.tracked_unsubscribe_len(), 1);
assert!(!mqtt.tracked_requests_is_empty());
mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
assert!(mqtt.tracked_requests_is_empty());
}
#[test]
fn drain_tracked_requests_as_failed_reports_session_reset_and_returns_count() {
let mut mqtt = MqttState::builder(10).build();
let (sub_notice_tx, sub_notice) = SubscribeNoticeTx::new();
mqtt.tracked_subscribe
.insert(5, (Subscribe::new("a/b", QoS::AtMostOnce), sub_notice_tx));
let (unsub_notice_tx, unsub_notice) = UnsubscribeNoticeTx::new();
mqtt.tracked_unsubscribe
.insert(6, (Unsubscribe::new("a/b"), unsub_notice_tx));
let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
assert_eq!(drained, 2);
assert!(mqtt.tracked_requests_is_empty());
assert_eq!(
sub_notice.wait().unwrap_err(),
SubscribeNoticeError::SessionReset
);
assert_eq!(
unsub_notice.wait().unwrap_err(),
UnsubscribeNoticeError::SessionReset
);
}
#[test]
fn drain_tracked_requests_as_failed_is_noop_when_empty() {
let mut mqtt = MqttState::builder(10).build();
let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
assert_eq!(drained, 0);
assert!(mqtt.tracked_requests_is_empty());
}
#[test]
fn tracked_puback_returns_ack_and_preserves_incoming_event() {
let mut mqtt = build_mqttstate();
let notice = queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::AtLeastOnce));
mqtt.events.clear();
let puback = PubAck::new(1);
assert!(
mqtt.handle_incoming_packet(Incoming::PubAck(puback.clone()))
.unwrap()
.is_none()
);
assert_eq!(notice.wait(), Ok(PublishResult::Qos1(puback.clone())));
assert_eq!(
mqtt.events.pop_front(),
Some(Event::Incoming(Packet::PubAck(puback)))
);
}
#[test]
fn tracked_suback_returns_ack_and_preserves_incoming_event() {
let mut mqtt = build_mqttstate();
let (tx, notice) = SubscribeNoticeTx::new();
mqtt.outgoing_subscribe(Subscribe::new("a/b", QoS::AtMostOnce), Some(tx))
.unwrap();
mqtt.events.clear();
let suback = SubAck::new(1, vec![SubscribeReasonCode::Failure]);
assert!(
mqtt.handle_incoming_packet(Incoming::SubAck(suback.clone()))
.unwrap()
.is_none()
);
assert_eq!(notice.wait(), Ok(suback.clone()));
assert_eq!(
mqtt.events.pop_front(),
Some(Event::Incoming(Packet::SubAck(suback)))
);
}
#[test]
fn tracked_unsuback_returns_ack_and_preserves_incoming_event() {
let mut mqtt = build_mqttstate();
let (tx, notice) = UnsubscribeNoticeTx::new();
mqtt.outgoing_unsubscribe(Unsubscribe::new("a/b"), Some(tx))
.unwrap();
mqtt.events.clear();
let unsuback = UnsubAck::new(1);
assert!(
mqtt.handle_incoming_packet(Incoming::UnsubAck(unsuback.clone()))
.unwrap()
.is_none()
);
assert_eq!(notice.wait(), Ok(unsuback.clone()));
assert_eq!(
mqtt.events.pop_front(),
Some(Event::Incoming(Packet::UnsubAck(unsuback)))
);
}
#[test]
fn outgoing_publish_grows_tracking_capacity_on_demand() {
let mut mqtt = build_mqttstate();
mqtt.last_pkid = 32;
mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
.unwrap();
assert_eq!(mqtt.outgoing_pub.len(), 34);
assert_eq!(mqtt.outgoing_pub_notice.len(), 34);
assert_eq!(mqtt.outgoing_rel_notice.len(), 34);
assert_eq!(mqtt.outgoing_pub_ack.len(), 34);
assert_eq!(mqtt.outgoing_rel.len(), 34);
assert!(mqtt.outgoing_pub[33].is_some());
}
#[test]
fn incoming_puback_shrinks_tracking_when_state_becomes_empty() {
let mut mqtt = build_mqttstate();
mqtt.last_pkid = 32;
mqtt.last_puback = 32;
mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
.unwrap();
assert_eq!(mqtt.outgoing_pub.len(), 34);
mqtt.handle_incoming_puback(&PubAck::new(33)).unwrap();
assert_eq!(mqtt.outgoing_pub.len(), 33);
assert_eq!(mqtt.outgoing_pub_notice.len(), 33);
assert_eq!(mqtt.outgoing_rel_notice.len(), 33);
assert_eq!(mqtt.outgoing_pub_ack.len(), 33);
assert_eq!(mqtt.outgoing_rel.len(), 33);
assert_eq!(mqtt.last_pkid, 0);
assert_eq!(mqtt.last_puback, 0);
}
#[test]
fn incoming_puback_does_not_shrink_tracking_when_state_is_non_empty() {
let mut mqtt = build_mqttstate();
mqtt.last_pkid = 32;
mqtt.last_puback = 32;
mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
.unwrap();
mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
.unwrap();
assert_eq!(mqtt.outgoing_pub.len(), 35);
mqtt.handle_incoming_puback(&PubAck::new(33)).unwrap();
assert_eq!(mqtt.outgoing_pub.len(), 35);
assert_eq!(mqtt.inflight, 1);
}
#[test]
fn clean_preserves_packet_id_frontier_when_pending_state_is_exported() {
    let mut state = build_mqttstate();
    for _ in 0..2 {
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        state.outgoing_publish(publish).unwrap();
    }
    assert_eq!(state.last_pkid, 2);

    // Exporting the pending requests must leave `last_pkid` where it was.
    let pending = state.clean();
    assert_eq!(pending.len(), 2);
    assert_eq!(state.last_pkid, 2);
    assert_eq!(state.last_puback, 0);

    // Replayed requests are expected to go out under their previous ids.
    for request in pending {
        let packet = state.handle_outgoing_packet(request).unwrap().unwrap();
        if let Packet::Publish(publish) = &packet {
            assert!(matches!(publish.pkid, 1 | 2));
        } else {
            panic!("Unexpected replay packet: {packet:?}");
        }
    }

    // A brand-new publish after the replay must receive the next id, 3.
    let fresh = Request::Publish(build_outgoing_publish(QoS::AtLeastOnce));
    let packet = state.handle_outgoing_packet(fresh).unwrap().unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 3);
    } else {
        panic!("Unexpected fresh packet after replay: {packet:?}");
    }
    assert!(state.collision.is_none());
}
#[test]
fn clone_preserves_current_tracking_lengths_after_shrink() {
    let mut state = build_mqttstate();
    state.last_pkid = 32;
    state.last_puback = 32;

    // Publish once and acknowledge it straight away before taking the clone.
    let publish = build_outgoing_publish(QoS::AtLeastOnce);
    state.outgoing_publish(publish).unwrap();
    let ack = PubAck::new(33);
    state.handle_incoming_puback(&ack).unwrap();

    // The clone is expected to report 33 entries in every tracking structure.
    let copy = state.clone();
    let expected_len = 33;
    assert_eq!(copy.outgoing_pub.len(), expected_len);
    assert_eq!(copy.outgoing_pub_notice.len(), expected_len);
    assert_eq!(copy.outgoing_rel_notice.len(), expected_len);
    assert_eq!(copy.outgoing_pub_ack.len(), expected_len);
    assert_eq!(copy.outgoing_rel.len(), expected_len);
}
#[test]
fn next_pkid_increments_as_expected() {
    let mut state = build_mqttstate();

    // The first 99 ids must come back as 1, 2, ..., 99 in order.
    for expected in 1..100 {
        let pkid = state.next_pkid();
        assert_eq!(expected, pkid);
    }

    // A 100th id is still requested, but its value is intentionally not checked.
    let _ = state.next_pkid();
}
#[test]
fn can_send_publish_searches_free_pkid_after_control_ids_pass_inflight_limit() {
    let mut state = MqttState::builder(4).build();

    // Hand-craft a state where slot 1 is occupied and `last_pkid` (5) is
    // already larger than the builder argument (4).
    let active_publish = Publish {
        pkid: 1,
        ..build_outgoing_publish(QoS::AtLeastOnce)
    };
    state.outgoing_pub[1] = Some(active_publish);
    state.inflight = 1;
    state.last_pkid = 5;

    let candidate = build_outgoing_publish(QoS::AtLeastOnce);
    assert!(state.can_send_publish(&candidate));

    // The publish that is actually sent must be given pkid 2.
    let packet = state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap()
        .unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 2);
    } else {
        panic!("Unexpected packet: {packet:?}");
    }
}
#[test]
fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() {
    let mut state = build_mqttstate();

    // A QoS 0 publish must neither consume a packet id nor count as in flight.
    let fire_and_forget = build_outgoing_publish(QoS::AtMostOnce);
    state.outgoing_publish(fire_and_forget).unwrap();
    assert_eq!(state.last_pkid, 0);
    assert_eq!(state.inflight, 0);

    // Each QoS 1 / QoS 2 publish (sent twice per level) must bump both
    // counters by exactly one.
    let mut expected: u16 = 0;
    for qos in [QoS::AtLeastOnce, QoS::ExactlyOnce] {
        let publish = build_outgoing_publish(qos);
        for outgoing in [publish.clone(), publish] {
            state.outgoing_publish(outgoing).unwrap();
            expected += 1;
            assert_eq!(state.last_pkid, expected);
            assert_eq!(state.inflight, expected);
        }
    }
}
#[test]
fn incoming_publish_should_be_added_to_queue_correctly() {
    let mut state = build_mqttstate();

    // One incoming publish per QoS level, with pkids 1, 2 and 3.
    let incoming = [
        build_incoming_publish(QoS::AtMostOnce, 1),
        build_incoming_publish(QoS::AtLeastOnce, 2),
        build_incoming_publish(QoS::ExactlyOnce, 3),
    ];
    for publish in &incoming {
        // The derived reply is irrelevant to this test.
        let _ = state.handle_incoming_publish(publish);
    }

    // Only the QoS 2 pkid is checked for being recorded.
    assert!(state.incoming_pub.contains(3));
}
#[test]
fn incoming_publish_should_be_acked() {
    let mut state = build_mqttstate();
    let qos0 = build_incoming_publish(QoS::AtMostOnce, 1);
    let qos1 = build_incoming_publish(QoS::AtLeastOnce, 2);
    let qos2 = build_incoming_publish(QoS::ExactlyOnce, 3);

    // QoS 0 produces no reply at all.
    assert!(state.handle_incoming_publish(&qos0).is_none());

    // QoS 1 is answered with a PubAck carrying the same pkid.
    match state.handle_incoming_publish(&qos1).unwrap() {
        Packet::PubAck(puback) => assert_eq!(puback.pkid, 2),
        _ => panic!("missing puback"),
    }

    // QoS 2 is answered with a PubRec carrying the same pkid.
    match state.handle_incoming_publish(&qos2).unwrap() {
        Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 3),
        _ => panic!("missing PubRec"),
    }
}
#[test]
fn incoming_publish_should_not_be_acked_with_manual_acks() {
    let mut state = build_mqttstate();
    state.manual_acks = true;

    let incoming = [
        build_incoming_publish(QoS::AtMostOnce, 1),
        build_incoming_publish(QoS::AtLeastOnce, 2),
        build_incoming_publish(QoS::ExactlyOnce, 3),
    ];

    // With manual acks enabled no QoS level may yield an automatic reply.
    for publish in &incoming {
        assert!(state.handle_incoming_publish(publish).is_none());
    }

    // The QoS 2 pkid is still recorded, and no event has been queued.
    assert!(state.incoming_pub.contains(3));
    assert!(state.events.is_empty());
}
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos1_ack() {
    let mut state = build_mqttstate();
    let publish = build_incoming_publish(QoS::AtLeastOnce, 42);

    // Expected queue contents: the incoming publish first, its PubAck second.
    let incoming_event = Event::Incoming(Incoming::Publish(publish.clone()));
    let derived_event = Event::Outgoing(Outgoing::PubAck(42));

    state
        .handle_incoming_packet(Incoming::Publish(publish))
        .unwrap();

    assert_eq!(state.events.len(), 2);
    assert_eq!(state.events[0], incoming_event);
    assert_eq!(state.events[1], derived_event);
}
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos2_ack() {
    let mut state = build_mqttstate();
    let publish = build_incoming_publish(QoS::ExactlyOnce, 43);

    // Expected queue contents: the incoming publish first, its PubRec second.
    let incoming_event = Event::Incoming(Incoming::Publish(publish.clone()));
    let derived_event = Event::Outgoing(Outgoing::PubRec(43));

    state
        .handle_incoming_packet(Incoming::Publish(publish))
        .unwrap();

    assert_eq!(state.events.len(), 2);
    assert_eq!(state.events[0], incoming_event);
    assert_eq!(state.events[1], derived_event);
}
#[test]
fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
    let mut state = build_mqttstate();
    let incoming = build_incoming_publish(QoS::ExactlyOnce, 1);

    // The reply to a QoS 2 publish must be a PubRec echoing the pkid.
    let packet = state.handle_incoming_publish(&incoming).unwrap();
    if let Packet::PubRec(pubrec) = &packet {
        assert_eq!(pubrec.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
#[test]
fn incoming_puback_should_remove_correct_publish_from_queue() {
    let mut state = build_mqttstate();

    // Queue one QoS 1 and one QoS 2 publish; they take pkids 1 and 2.
    let queued = [
        build_outgoing_publish(QoS::AtLeastOnce),
        build_outgoing_publish(QoS::ExactlyOnce),
    ];
    for publish in queued {
        state.outgoing_publish(publish).unwrap();
    }
    assert_eq!(state.inflight, 2);

    // Every PubAck must lower the in-flight count by one.
    for (pkid, remaining) in [(1, 1), (2, 0)] {
        state.handle_incoming_puback(&PubAck::new(pkid)).unwrap();
        assert_eq!(state.inflight, remaining);
    }

    // Both slots must be empty afterwards.
    assert!(state.outgoing_pub[1].is_none());
    assert!(state.outgoing_pub[2].is_none());
}
#[test]
fn incoming_puback_advances_last_puback_only_on_contiguous_boundary() {
    let mut state = build_mqttstate();
    for _ in 0..3 {
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        state.outgoing_publish(publish).unwrap();
    }
    assert_eq!(state.last_puback, 0);

    // (acked pkid, expected `last_puback` afterwards): acking 2 ahead of 1
    // leaves the cursor at 0, acking 1 then moves it to 2, and 3 moves it to 3.
    for (acked, frontier) in [(2, 0), (1, 2), (3, 3)] {
        state.handle_incoming_puback(&PubAck::new(acked)).unwrap();
        assert_eq!(state.last_puback, frontier);
    }
}
#[test]
fn mixed_qos_completion_clears_outbound_drain_state() {
    let mut state = build_mqttstate();

    // pkid 1 and 4 are QoS 2, pkid 2 and 3 are QoS 1.
    for qos in [
        QoS::ExactlyOnce,
        QoS::AtLeastOnce,
        QoS::AtLeastOnce,
        QoS::ExactlyOnce,
    ] {
        state.outgoing_publish(build_outgoing_publish(qos)).unwrap();
    }

    // Complete all four handshakes, interleaving the QoS 1 acks with the
    // first QoS 2 exchange.
    state.handle_incoming_pubrec(&PubRec::new(1)).unwrap();
    for pkid in [2, 3] {
        state.handle_incoming_puback(&PubAck::new(pkid)).unwrap();
    }
    state.handle_incoming_pubcomp(&PubComp::new(1)).unwrap();
    state.handle_incoming_pubrec(&PubRec::new(4)).unwrap();
    state.handle_incoming_pubcomp(&PubComp::new(4)).unwrap();

    // Nothing may remain in flight, and neither bitset may have a bit set.
    assert_eq!(state.inflight, 0);
    assert!(state.outbound_requests_drained());
    assert_eq!(state.outgoing_pub_ack.ones().count(), 0);
    assert_eq!(state.outgoing_rel.ones().count(), 0);
}
#[test]
fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() {
    let mut state = build_mqttstate();

    // Nothing was published, so an ack for pkid 101 must be reported as an
    // unsolicited ack carrying that pkid.
    let e = state.handle_incoming_puback(&PubAck::new(101)).unwrap_err();
    if let StateError::Unsolicited(pkid) = &e {
        assert_eq!(*pkid, 101);
    } else {
        panic!("Unexpected error: {e}");
    }
}
#[test]
fn incoming_puback_with_pkid_beyond_allocated_tracking_is_unsolicited() {
    let mut state = build_mqttstate();

    // Nothing was published, so an ack for pkid 50 must be reported as an
    // unsolicited ack carrying that pkid.
    let e = state.handle_incoming_puback(&PubAck::new(50)).unwrap_err();
    if let StateError::Unsolicited(pkid) = &e {
        assert_eq!(*pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }
}
#[test]
fn outgoing_publish_with_pkid_above_max_inflight_is_unsolicited_and_does_not_grow_tracking() {
    let mut state = MqttState::builder(10).build();

    // A publish that arrives with a preset pkid of 50 must be rejected.
    let publish = Publish {
        pkid: 50,
        ..build_outgoing_publish(QoS::AtLeastOnce)
    };
    let e = state
        .handle_outgoing_packet(Request::Publish(publish))
        .unwrap_err();
    if let StateError::Unsolicited(pkid) = &e {
        assert_eq!(*pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }

    // The rejected request must leave the tracking tables at 11 entries and
    // the in-flight count at zero.
    let untouched_len = 11;
    assert_eq!(state.outgoing_pub.len(), untouched_len);
    assert_eq!(state.outgoing_pub_notice.len(), untouched_len);
    assert_eq!(state.outgoing_rel_notice.len(), untouched_len);
    assert_eq!(state.inflight, 0);
}
#[test]
fn outgoing_pubrel_with_pkid_above_max_inflight_is_unsolicited_and_does_not_grow_tracking() {
    let mut state = MqttState::builder(10).build();

    // A PubRel request for pkid 50 must be rejected.
    let release = Request::PubRel(PubRel::new(50));
    let e = state.handle_outgoing_packet(release).unwrap_err();
    if let StateError::Unsolicited(pkid) = &e {
        assert_eq!(*pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }

    // The rejected request must leave the tracking tables at 11 entries and
    // the in-flight count at zero.
    let untouched_len = 11;
    assert_eq!(state.outgoing_pub.len(), untouched_len);
    assert_eq!(state.outgoing_pub_notice.len(), untouched_len);
    assert_eq!(state.outgoing_rel_notice.len(), untouched_len);
    assert_eq!(state.inflight, 0);
}
#[test]
fn clean_keeps_oldest_unacked_publish_first_after_out_of_order_puback() {
    let mut state = build_mqttstate();
    for _ in 0..3 {
        let publish = build_outgoing_publish(QoS::AtLeastOnce);
        state.outgoing_publish(publish).unwrap();
    }

    // Acknowledge only the middle publish, leaving 1 and 3 outstanding.
    state.handle_incoming_puback(&PubAck::new(2)).unwrap();

    // The exported requests must all be publishes, oldest pkid first.
    let requests = state.clean();
    let mut pending_pkids: Vec<u16> = Vec::new();
    for req in requests.iter() {
        if let Request::Publish(publish) = req {
            pending_pkids.push(publish.pkid);
        } else {
            panic!("Unexpected request while cleaning: {req:?}");
        }
    }
    assert_eq!(pending_pkids, vec![1, 3]);
}
#[test]
fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() {
    let mut mqtt = build_mqttstate();
    let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
    let publish2 = build_outgoing_publish(QoS::ExactlyOnce);

    // Unwrap the setup calls so that a failing publish is reported right
    // here instead of surfacing later as a confusing assertion mismatch.
    mqtt.outgoing_publish(publish1).unwrap();
    mqtt.outgoing_publish(publish2).unwrap();

    // PubRec for the QoS 2 publish (pkid 2).
    mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap();

    // Both publishes still count as in flight after the PubRec.
    assert_eq!(mqtt.inflight, 2);

    // The QoS 1 publish in slot 1 is untouched; read its pkid through a
    // borrow rather than cloning the whole publish.
    let remaining_pkid = mqtt.outgoing_pub[1].as_ref().map(|publish| publish.pkid);
    assert_eq!(remaining_pkid, Some(1));

    // pkid 2 is now marked in the release bitset.
    assert!(mqtt.outgoing_rel.contains(2));
}
#[test]
fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
    let mut state = build_mqttstate();

    // The outgoing QoS 2 publish goes out with pkid 1.
    let outgoing = build_outgoing_publish(QoS::ExactlyOnce);
    let packet = state.outgoing_publish(outgoing).unwrap().unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }

    // The matching PubRec must be answered with a PubRel for the same pkid.
    let received = PubRec::new(1);
    let packet = state.handle_incoming_pubrec(&received).unwrap().unwrap();
    if let Packet::PubRel(pubrel) = &packet {
        assert_eq!(pubrel.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
#[test]
fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
    let mut state = build_mqttstate();

    // An incoming QoS 2 publish is first answered with a PubRec.
    let incoming = build_incoming_publish(QoS::ExactlyOnce, 1);
    let packet = state.handle_incoming_publish(&incoming).unwrap();
    if let Packet::PubRec(pubrec) = &packet {
        assert_eq!(pubrec.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }

    // The subsequent PubRel must be answered with a PubComp for the same pkid.
    let release = PubRel::new(1);
    let packet = state.handle_incoming_pubrel(&release).unwrap().unwrap();
    if let Packet::PubComp(pubcomp) = &packet {
        assert_eq!(pubcomp.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
#[test]
fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
    let mut state = build_mqttstate();

    // Run one complete QoS 2 exchange: publish, PubRec, PubComp.
    let outgoing = build_outgoing_publish(QoS::ExactlyOnce);
    state.outgoing_publish(outgoing).unwrap();
    let received = PubRec::new(1);
    state.handle_incoming_pubrec(&received).unwrap();
    let completed = PubComp::new(1);
    state.handle_incoming_pubcomp(&completed).unwrap();

    // Nothing may be left in flight once the exchange is complete.
    assert_eq!(state.inflight, 0);
}
#[test]
fn incoming_pubcomp_collision_replay_should_restore_qos2_tracking() {
    let mut state = build_mqttstate();
    state
        .outgoing_publish(build_outgoing_publish(QoS::ExactlyOnce))
        .unwrap();

    // A second publish that reuses pkid 1 must be parked as a collision
    // instead of being sent.
    let collided = Publish {
        pkid: 1,
        ..build_outgoing_publish(QoS::ExactlyOnce)
    };
    assert!(state.outgoing_publish(collided).unwrap().is_none());
    assert!(state.collision.is_some());

    // Completing the first exchange must hand back the parked publish.
    state.handle_incoming_pubrec(&PubRec::new(1)).unwrap();
    let packet = state
        .handle_incoming_pubcomp(&PubComp::new(1))
        .unwrap()
        .unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
    assert!(state.outgoing_pub[1].is_some());
    assert_eq!(state.inflight, 1);

    // The replayed publish must itself progress to PubRel on its PubRec.
    let packet = state
        .handle_incoming_pubrec(&PubRec::new(1))
        .unwrap()
        .unwrap();
    if let Packet::PubRel(pubrel) = &packet {
        assert_eq!(pubrel.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
#[test]
fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
    let mut state = build_mqttstate();
    state.outgoing_ping().unwrap();

    // Other traffic (a publish and its ack) flows, but no PingResp arrives.
    let publish = build_outgoing_publish(QoS::AtLeastOnce);
    state
        .handle_outgoing_packet(Request::Publish(publish))
        .unwrap();
    state
        .handle_incoming_packet(Incoming::PubAck(PubAck::new(1)))
        .unwrap();

    // A second ping must therefore fail with `AwaitPingResp`.
    let e = match state.outgoing_ping() {
        Ok(_) => panic!("Should throw pingresp await error"),
        Err(e) => e,
    };
    if !matches!(e, StateError::AwaitPingResp) {
        panic!("Should throw pingresp await error. Error = {e:?}");
    }
}
#[test]
fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
    let mut state = build_mqttstate();

    // First ping, then its response.
    state.outgoing_ping().unwrap();
    let response = Incoming::PingResp;
    state.handle_incoming_packet(response).unwrap();

    // With the response received, a second ping must go through.
    state.outgoing_ping().unwrap();
}
#[test]
fn clean_is_calculating_pending_correctly() {
    // Builds an occupied tracking slot: a QoS 0 publish on topic `test` with
    // an empty payload and the given packet id.
    fn slot(pkid: u16) -> Option<Publish> {
        Some(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic: Bytes::from_static(b"test"),
            pkid,
            payload: "".into(),
        })
    }

    // Publish table with pkids 1, 2, 3 and 6 occupied and slots 0, 4, 5 empty.
    fn build_outgoing_pub() -> Vec<Option<Publish>> {
        vec![None, slot(1), slot(2), slot(3), None, None, slot(6)]
    }

    // Extracts the pkid of every exported request, panicking on anything
    // that is not a publish.
    fn pending_pkids(requests: impl IntoIterator<Item = Request>) -> Vec<u16> {
        requests
            .into_iter()
            .map(|req| match req {
                Request::Publish(publish) => publish.pkid,
                req => panic!("Unexpected request while cleaning: {req:?}"),
            })
            .collect()
    }

    let mut mqtt = build_mqttstate();

    // (last_puback, expected replay order). Entries after `last_puback` are
    // expected first, followed by the ones up to and including it.
    let cases = [
        (3, vec![6, 1, 2, 3]),
        (0, vec![1, 2, 3, 6]),
        (6, vec![1, 2, 3, 6]),
    ];
    for (last_puback, expected) in cases {
        mqtt.outgoing_pub = build_outgoing_pub();
        mqtt.last_puback = last_puback;

        // Compare the complete list: zipping against the expectation would
        // silently accept a result that is too short (or empty) and would
        // never look at surplus requests.
        let actual = pending_pkids(mqtt.clean());
        assert_eq!(actual, expected, "last_puback = {last_puback}");
    }
}
}