1use crate::notice::{
2 PublishNoticeTx, PublishResult, SubscribeNoticeTx, TrackedNoticeTx, UnsubscribeNoticeTx,
3};
4use crate::{Event, Incoming, NoticeFailureReason, Outgoing, PublishNoticeError, Request};
5
6use crate::mqttbytes::v4::{
7 Packet, PubAck, PubComp, PubRec, PubRel, Publish, SubAck, Subscribe, UnsubAck, Unsubscribe,
8};
9use crate::mqttbytes::{self, QoS};
10use fixedbitset::FixedBitSet;
11use std::collections::{BTreeMap, VecDeque};
12use std::{io, time::Instant};
13
/// Errors reported by [`MqttState`] while handling incoming packets and
/// outgoing requests.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
    /// An I/O error, convertible from [`io::Error`] via `?`.
    #[error("Io error: {0:?}")]
    Io(#[from] io::Error),
    /// No packet identifier could be allocated for a subscribe/unsubscribe
    /// request because every identifier is currently in use.
    #[error("Invalid state for a given operation")]
    InvalidState,
    /// An acknowledgement arrived for a packet id that is not being tracked.
    /// Also returned when an outgoing publish/pubrel carries a packet id that
    /// is zero or larger than `max_inflight`.
    #[error("Received unsolicited ack pkid: {0}")]
    Unsolicited(u16),
    /// A ping was requested while the previous `PingReq` is still unanswered.
    #[error("Last pingreq isn't acked")]
    AwaitPingResp,
    /// The incoming packet type is not one this state machine handles.
    #[error("Received a wrong packet while waiting for another packet")]
    WrongPacket,
    /// A packet-id collision was still unresolved on the second ping attempt.
    #[error("Timeout while waiting to resolve collision")]
    CollisionTimeout,
    /// A subscribe request was submitted with an empty filter list.
    #[error("A Subscribe packet must contain atleast one filter")]
    EmptySubscription,
    /// An MQTT encode/decode error, convertible from [`mqttbytes::Error`].
    #[error("Mqtt serialization/deserialization error: {0}")]
    Deserialization(#[from] mqttbytes::Error),
    /// The peer closed the connection abruptly.
    // NOTE(review): not constructed in this part of the file; presumably
    // raised by the connection layer — confirm against callers.
    #[error("Connection closed by peer abruptly")]
    ConnectionAborted,
}
41
/// Protocol state of a single MQTT connection: packet-id allocation, inflight
/// tracking for QoS 1/2 publishes, tracked subscribe/unsubscribe requests and
/// the queue of events produced while handling packets.
#[derive(Debug)]
pub struct MqttState {
    /// `true` after a `PingReq` was produced and until a `PingResp` arrives.
    pub await_pingresp: bool,
    /// Number of pings requested while a collision is outstanding.
    pub collision_ping_count: usize,
    /// Time the last incoming packet was handled successfully.
    last_incoming: Instant,
    /// Time the last outgoing request was handled successfully.
    last_outgoing: Instant,
    /// Most recently allocated packet id (0 means "start again from 1").
    pub(crate) last_pkid: u16,
    /// Highest packet id of the contiguous run of completed outgoing packets;
    /// used as the rotation point when replaying pending publishes.
    pub(crate) last_puback: u16,
    /// Number of outgoing QoS 1/2 exchanges that are not yet complete.
    pub(crate) inflight: u16,
    /// Upper bound for outgoing publish packet ids and for `inflight`.
    pub(crate) max_inflight: u16,
    /// Publishes awaiting `PubAck`/`PubRec`, indexed by packet id.
    pub(crate) outgoing_pub: Vec<Option<Publish>>,
    /// Notice senders for the entries of `outgoing_pub`, same indexing.
    pub(crate) outgoing_pub_notice: Vec<Option<PublishNoticeTx>>,
    /// Completed packet ids that have not yet been folded into `last_puback`.
    pub(crate) outgoing_pub_ack: FixedBitSet,
    /// Packet ids for which a `PubRel` was sent and a `PubComp` is awaited.
    pub(crate) outgoing_rel: FixedBitSet,
    /// Notice senders for the entries of `outgoing_rel`, same indexing.
    pub(crate) outgoing_rel_notice: Vec<Option<PublishNoticeTx>>,
    /// Packet ids of incoming QoS 2 publishes awaiting the peer's `PubRel`.
    pub(crate) incoming_pub: FixedBitSet,
    /// Publish held back because its packet id slot is still occupied.
    pub collision: Option<Publish>,
    /// Notice sender belonging to `collision`.
    pub(crate) collision_notice: Option<PublishNoticeTx>,
    /// Subscribe requests with a notice attached, awaiting `SubAck`, by packet id.
    pub(crate) tracked_subscribe: BTreeMap<u16, (Subscribe, SubscribeNoticeTx)>,
    /// Unsubscribe requests with a notice attached, awaiting `UnsubAck`, by packet id.
    pub(crate) tracked_unsubscribe: BTreeMap<u16, (Unsubscribe, UnsubscribeNoticeTx)>,
    /// Incoming/outgoing events generated by the state machine, oldest first.
    pub events: VecDeque<Event>,
    /// When `true`, incoming QoS 1/2 publishes are not acknowledged automatically.
    pub manual_acks: bool,
}
93
/// Builder for [`MqttState`]; collects the settings passed to
/// `MqttState::new_internal`.
#[derive(Debug)]
pub struct MqttStateBuilder {
    /// Maximum number of concurrently tracked outgoing QoS 1/2 exchanges.
    max_inflight: u16,
    /// Whether incoming publishes are acknowledged manually (default `false`).
    manual_acks: bool,
}
104
105impl MqttStateBuilder {
106 #[must_use]
108 pub const fn new(max_inflight: u16) -> Self {
109 Self {
110 max_inflight,
111 manual_acks: false,
112 }
113 }
114
115 #[must_use]
117 pub const fn manual_acks(mut self, manual_acks: bool) -> Self {
118 self.manual_acks = manual_acks;
119 self
120 }
121
122 #[must_use]
124 pub fn build(self) -> MqttState {
125 MqttState::new_internal(self.max_inflight, self.manual_acks)
126 }
127}
128
129impl MqttState {
/// Number of packet-id slots (besides the unused slot 0) that the tracking
/// tables keep allocated when idle; see `warm_tracking_len`.
const WARM_TRACKING_SLOTS: usize = 32;
131
/// Capacity reserved up front for the `events` queue.
const fn initial_events_capacity() -> usize {
    const INITIAL_EVENTS_CAPACITY: usize = 128;
    INITIAL_EVENTS_CAPACITY
}
135
/// Length of a tracking table able to hold packet ids `0..=max_inflight`.
const fn outgoing_tracking_len(max_inflight: u16) -> usize {
    // Widen before adding so `u16::MAX` does not overflow.
    let highest_pkid = max_inflight as usize;
    highest_pkid + 1
}
139
140 const fn warm_tracking_len(max_inflight: u16) -> usize {
141 let full_len = Self::outgoing_tracking_len(max_inflight);
142 let warm_len = Self::WARM_TRACKING_SLOTS + 1;
143 if full_len < warm_len {
144 full_len
145 } else {
146 warm_len
147 }
148 }
149
150 fn new_notice_slots_with_len(len: usize) -> Vec<Option<PublishNoticeTx>> {
151 std::iter::repeat_with(|| None).take(len).collect()
152 }
153
154 fn ensure_outgoing_tracking_capacity(&mut self, target_len: usize) {
155 if self.outgoing_pub.len() < target_len {
156 self.outgoing_pub.resize_with(target_len, || None);
157 }
158
159 if self.outgoing_pub_notice.len() < target_len {
160 self.outgoing_pub_notice.resize_with(target_len, || None);
161 }
162
163 if self.outgoing_rel_notice.len() < target_len {
164 self.outgoing_rel_notice.resize_with(target_len, || None);
165 }
166
167 if self.outgoing_pub_ack.len() < target_len {
168 self.outgoing_pub_ack.grow(target_len);
169 }
170
171 if self.outgoing_rel.len() < target_len {
172 self.outgoing_rel.grow(target_len);
173 }
174 }
175
176 pub(crate) fn outbound_requests_drained(&self) -> bool {
177 self.inflight == 0
178 && self.collision.is_none()
179 && self.collision_notice.is_none()
180 && self.tracked_subscribe.is_empty()
181 && self.tracked_unsubscribe.is_empty()
182 && self.outgoing_pub.iter().all(Option::is_none)
183 && self.outgoing_pub_notice.iter().all(Option::is_none)
184 && self.outgoing_rel_notice.iter().all(Option::is_none)
185 && self.outgoing_pub_ack.ones().next().is_none()
186 && self.outgoing_rel.ones().next().is_none()
187 }
188
189 fn maybe_shrink_outgoing_tracking_capacity(&mut self) {
190 let target_len = Self::warm_tracking_len(self.max_inflight);
191 if self.outgoing_pub.len() <= target_len || !self.outbound_requests_drained() {
192 return;
193 }
194
195 self.outgoing_pub.truncate(target_len);
196 self.outgoing_pub_notice.truncate(target_len);
197 self.outgoing_rel_notice.truncate(target_len);
198 self.outgoing_pub_ack = FixedBitSet::with_capacity(target_len);
199 self.outgoing_rel = FixedBitSet::with_capacity(target_len);
200 self.last_pkid = 0;
201 self.last_puback = 0;
202 }
203
204 const fn validate_outgoing_pkid_bound(&self, pkid: u16) -> Result<(), StateError> {
205 if pkid == 0 || pkid > self.max_inflight {
206 return Err(StateError::Unsolicited(pkid));
207 }
208
209 Ok(())
210 }
211
212 const fn next_publish_pkid_after(&self, pkid: u16) -> u16 {
213 if pkid >= self.max_inflight {
214 1
215 } else {
216 pkid + 1
217 }
218 }
219
220 fn packet_identifier_in_use(&self, pkid: u16) -> bool {
221 let index = usize::from(pkid);
222 self.outgoing_pub.get(index).is_some_and(Option::is_some)
223 || self.outgoing_rel.contains(index)
224 || self.tracked_subscribe.contains_key(&pkid)
225 || self.tracked_unsubscribe.contains_key(&pkid)
226 }
227
228 pub(crate) fn can_send_publish(&self, publish: &Publish) -> bool {
229 if publish.qos == QoS::AtMostOnce {
230 return true;
231 }
232
233 if self.inflight >= self.max_inflight || self.collision.is_some() {
234 return false;
235 }
236
237 if publish.pkid == 0 {
238 return self.next_publish_pkid().is_some();
239 }
240
241 self.validate_outgoing_pkid_bound(publish.pkid).is_ok()
242 && !self.packet_identifier_in_use(publish.pkid)
243 }
244
245 pub(crate) fn control_packet_identifier_available(&self) -> bool {
246 (1..=u16::MAX).any(|pkid| !self.packet_identifier_in_use(pkid))
247 }
248
249 fn clean_pending_capacity(&self) -> usize {
250 self.outgoing_pub
251 .iter()
252 .filter(|publish| publish.is_some())
253 .count()
254 + self.outgoing_rel.ones().count()
255 + self.tracked_subscribe.len()
256 + self.tracked_unsubscribe.len()
257 }
258
/// Returns a [`MqttStateBuilder`] for a state limited to `max_inflight`
/// concurrent outgoing QoS 1/2 exchanges.
#[must_use]
pub const fn builder(max_inflight: u16) -> MqttStateBuilder {
    MqttStateBuilder::new(max_inflight)
}
264
265 #[must_use]
269 pub(crate) fn new_internal(max_inflight: u16, manual_acks: bool) -> Self {
270 let tracking_len = Self::warm_tracking_len(max_inflight);
271 Self {
272 await_pingresp: false,
273 collision_ping_count: 0,
274 last_incoming: Instant::now(),
275 last_outgoing: Instant::now(),
276 last_pkid: 0,
277 last_puback: 0,
278 inflight: 0,
279 max_inflight,
280 outgoing_pub: std::iter::repeat_with(|| None).take(tracking_len).collect(),
282 outgoing_pub_notice: Self::new_notice_slots_with_len(tracking_len),
283 outgoing_pub_ack: FixedBitSet::with_capacity(tracking_len),
284 outgoing_rel: FixedBitSet::with_capacity(tracking_len),
285 outgoing_rel_notice: Self::new_notice_slots_with_len(tracking_len),
286 incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1),
287 collision: None,
288 collision_notice: None,
289 tracked_subscribe: BTreeMap::new(),
290 tracked_unsubscribe: BTreeMap::new(),
291 events: VecDeque::with_capacity(Self::initial_events_capacity()),
292 manual_acks,
293 }
294 }
295
/// Resets the per-connection state and returns every pending outbound
/// request together with its notice sender (if any), in replay order:
/// publishes (starting just after `last_puback` and wrapping around),
/// then releases, then tracked subscribes, then tracked unsubscribes.
///
/// `collision` / `collision_notice` are left untouched by this method.
pub(crate) fn clean_with_notices(&mut self) -> Vec<(Request, Option<TrackedNoticeTx>)> {
    let mut pending = Vec::with_capacity(self.clean_pending_capacity());
    // Split both tables at the same point so publish and notice stay paired.
    let (first_half, second_half) = self
        .outgoing_pub
        .split_at_mut(self.last_puback as usize + 1);
    let (notice_first_half, notice_second_half) = self
        .outgoing_pub_notice
        .split_at_mut(self.last_puback as usize + 1);

    // Ids after the ack frontier are visited first, then the wrapped-around ones.
    for (publish, notice) in second_half
        .iter_mut()
        .zip(notice_second_half.iter_mut())
        .chain(first_half.iter_mut().zip(notice_first_half.iter_mut()))
    {
        if let Some(publish) = publish.take() {
            let request = Request::Publish(publish);
            pending.push((request, notice.take().map(TrackedNoticeTx::Publish)));
        } else {
            // A notice without a publish is dropped without being resolved here.
            _ = notice.take();
        }
    }

    // Pending releases are re-queued as PubRel requests.
    for pkid in self.outgoing_rel.ones() {
        let pkid = u16::try_from(pkid).expect("fixedbitset index always fits in u16");
        let request = Request::PubRel(PubRel::new(pkid));
        pending.push((
            request,
            self.outgoing_rel_notice[pkid as usize]
                .take()
                .map(TrackedNoticeTx::Publish),
        ));
    }
    self.outgoing_rel.clear();
    self.outgoing_pub_ack.clear();

    // Tracked requests are handed back with the packet id they were sent with.
    for (pkid, (mut subscribe, notice)) in std::mem::take(&mut self.tracked_subscribe) {
        subscribe.pkid = pkid;
        pending.push((
            Request::Subscribe(subscribe),
            Some(TrackedNoticeTx::Subscribe(notice)),
        ));
    }

    for (pkid, (mut unsubscribe, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
        unsubscribe.pkid = pkid;
        pending.push((
            Request::Unsubscribe(unsubscribe),
            Some(TrackedNoticeTx::Unsubscribe(notice)),
        ));
    }

    // Forget incoming QoS 2 publishes that were awaiting a PubRel.
    self.incoming_pub.clear();

    self.await_pingresp = false;
    self.collision_ping_count = 0;
    self.inflight = 0;
    // Shrinking is only attempted when nothing has to be replayed.
    if pending.is_empty() {
        self.maybe_shrink_outgoing_tracking_capacity();
    }
    pending
}
359
360 pub fn clean(&mut self) -> Vec<Request> {
362 self.clean_with_notices()
363 .into_iter()
364 .map(|(request, _)| request)
365 .collect()
366 }
367
/// Number of outgoing QoS 1/2 exchanges currently awaiting completion.
pub const fn inflight(&self) -> u16 {
    self.inflight
}
371
/// Number of tracked subscribe requests still awaiting a `SubAck`.
pub fn tracked_subscribe_len(&self) -> usize {
    self.tracked_subscribe.len()
}
375
/// Number of tracked unsubscribe requests still awaiting an `UnsubAck`.
pub fn tracked_unsubscribe_len(&self) -> usize {
    self.tracked_unsubscribe.len()
}
379
380 pub fn tracked_requests_is_empty(&self) -> bool {
381 self.tracked_subscribe.is_empty() && self.tracked_unsubscribe.is_empty()
382 }
383
384 pub fn drain_tracked_requests_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
385 let mut drained = 0;
386 for (_, (_, notice)) in std::mem::take(&mut self.tracked_subscribe) {
387 drained += 1;
388 notice.error(reason.subscribe_error());
389 }
390 for (_, (_, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
391 drained += 1;
392 notice.error(reason.unsubscribe_error());
393 }
394
395 self.maybe_shrink_outgoing_tracking_capacity();
396 drained
397 }
398
399 pub(crate) fn fail_pending_notices(&mut self) {
400 for notice in &mut self.outgoing_pub_notice {
401 if let Some(tx) = notice.take() {
402 tx.error(PublishNoticeError::SessionReset);
403 }
404 }
405
406 for notice in &mut self.outgoing_rel_notice {
407 if let Some(tx) = notice.take() {
408 tx.error(PublishNoticeError::SessionReset);
409 }
410 }
411
412 if let Some(tx) = self.collision_notice.take() {
413 tx.error(PublishNoticeError::SessionReset);
414 }
415
416 self.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
417 self.clear_collision();
418 self.maybe_shrink_outgoing_tracking_capacity();
419 }
420
421 pub fn handle_outgoing_packet(
429 &mut self,
430 request: Request,
431 ) -> Result<Option<Packet>, StateError> {
432 let (packet, flush_notice) = self.handle_outgoing_packet_with_notice(request, None)?;
433 if let Some(tx) = flush_notice {
434 tx.success(PublishResult::Qos0Flushed);
435 }
436
437 self.last_outgoing = Instant::now();
438 Ok(packet)
439 }
440
441 pub(crate) fn handle_outgoing_packet_with_notice(
442 &mut self,
443 request: Request,
444 notice: Option<TrackedNoticeTx>,
445 ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
446 let result =
447 match request {
448 Request::Publish(publish) => {
449 let publish_notice = match notice {
450 Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
451 Some(TrackedNoticeTx::Subscribe(_) | TrackedNoticeTx::Unsubscribe(_))
452 | None => None,
453 };
454 self.outgoing_publish_with_notice(publish, publish_notice)?
455 }
456 Request::PubRel(pubrel) => {
457 let publish_notice = match notice {
458 Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
459 Some(TrackedNoticeTx::Subscribe(_) | TrackedNoticeTx::Unsubscribe(_))
460 | None => None,
461 };
462 self.outgoing_pubrel_with_notice(pubrel, publish_notice)?
463 }
464 Request::Subscribe(subscribe) => {
465 let request_notice = match notice {
466 Some(TrackedNoticeTx::Subscribe(notice)) => Some(notice),
467 Some(TrackedNoticeTx::Publish(_) | TrackedNoticeTx::Unsubscribe(_))
468 | None => None,
469 };
470 (self.outgoing_subscribe(subscribe, request_notice)?, None)
471 }
472 Request::Unsubscribe(unsubscribe) => {
473 let request_notice = match notice {
474 Some(TrackedNoticeTx::Unsubscribe(notice)) => Some(notice),
475 Some(TrackedNoticeTx::Publish(_) | TrackedNoticeTx::Subscribe(_))
476 | None => None,
477 };
478 (
479 Some(self.outgoing_unsubscribe(unsubscribe, request_notice)?),
480 None,
481 )
482 }
483 Request::PingReq(_) => (self.outgoing_ping()?, None),
484 Request::Disconnect(_) | Request::DisconnectWithTimeout(_, _) => {
485 unreachable!("graceful disconnect requests are handled by the event loop")
486 }
487 Request::DisconnectNow(_) => (Some(self.outgoing_disconnect()), None),
488 Request::PubAck(puback) => (Some(self.outgoing_puback(puback)), None),
489 Request::PubRec(pubrec) => (Some(self.outgoing_pubrec(pubrec)), None),
490 _ => unimplemented!(),
491 };
492
493 self.last_outgoing = Instant::now();
494 Ok(result)
495 }
496
497 pub fn handle_incoming_packet(
507 &mut self,
508 packet: Incoming,
509 ) -> Result<Option<Packet>, StateError> {
510 let events_len_before = self.events.len();
511 let outgoing = match &packet {
512 Incoming::PingResp => Ok(self.handle_incoming_pingresp()),
513 Incoming::Publish(publish) => Ok(self.handle_incoming_publish(publish)),
514 Incoming::SubAck(suback) => Ok(self.handle_incoming_suback(suback)),
515 Incoming::UnsubAck(unsuback) => Ok(self.handle_incoming_unsuback(unsuback)),
516 Incoming::PubAck(puback) => self.handle_incoming_puback(puback),
517 Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec),
518 Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel),
519 Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp),
520 _ => {
521 error!("Invalid incoming packet = {packet:?}");
522 Err(StateError::WrongPacket)
523 }
524 };
525
526 self.events
529 .insert(events_len_before, Event::Incoming(packet));
530 let outgoing = outgoing?;
531 self.last_incoming = Instant::now();
532
533 Ok(outgoing)
534 }
535
/// Drops the held-back collision publish and its notice sender, and resets
/// the collision ping counter.
pub fn clear_collision(&mut self) {
    self.collision = None;
    self.collision_notice = None;
    self.collision_ping_count = 0;
}
541
542 fn handle_incoming_suback(&mut self, suback: &SubAck) -> Option<Packet> {
543 if let Some((_, notice)) = self.tracked_subscribe.remove(&suback.pkid) {
544 notice.success(suback.clone());
545 }
546 None
547 }
548
549 fn handle_incoming_unsuback(&mut self, unsuback: &UnsubAck) -> Option<Packet> {
550 if let Some((_, notice)) = self.tracked_unsubscribe.remove(&unsuback.pkid) {
551 notice.success(unsuback.clone());
552 }
553 None
554 }
555
556 fn handle_incoming_publish(&mut self, publish: &Publish) -> Option<Packet> {
559 let qos = publish.qos;
560
561 match qos {
562 QoS::AtMostOnce => None,
563 QoS::AtLeastOnce => {
564 if !self.manual_acks {
565 let puback = PubAck::new(publish.pkid);
566 return Some(self.outgoing_puback(puback));
567 }
568 None
569 }
570 QoS::ExactlyOnce => {
571 let pkid = publish.pkid;
572 self.incoming_pub.insert(pkid as usize);
573
574 if !self.manual_acks {
575 let pubrec = PubRec::new(pkid);
576 return Some(self.outgoing_pubrec(pubrec));
577 }
578 None
579 }
580 }
581 }
582
583 fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<Option<Packet>, StateError> {
584 let publish = self
585 .outgoing_pub
586 .get_mut(puback.pkid as usize)
587 .ok_or(StateError::Unsolicited(puback.pkid))?;
588
589 if publish.take().is_none() {
590 error!("Unsolicited puback packet: {:?}", puback.pkid);
591 return Err(StateError::Unsolicited(puback.pkid));
592 }
593 self.mark_outgoing_packet_id_complete(puback.pkid);
594
595 if let Some(tx) = self.outgoing_pub_notice[puback.pkid as usize].take() {
596 tx.success(PublishResult::Qos1(puback.clone()));
597 }
598
599 self.inflight -= 1;
600 let packet = self.replay_collision_publish(puback.pkid);
601 if packet.is_none() {
602 self.maybe_shrink_outgoing_tracking_capacity();
603 }
604
605 Ok(packet)
606 }
607
608 fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<Option<Packet>, StateError> {
609 let publish = self
610 .outgoing_pub
611 .get_mut(pubrec.pkid as usize)
612 .ok_or(StateError::Unsolicited(pubrec.pkid))?;
613
614 if publish.take().is_none() {
615 error!("Unsolicited pubrec packet: {:?}", pubrec.pkid);
616 return Err(StateError::Unsolicited(pubrec.pkid));
617 }
618
619 let notice = self.outgoing_pub_notice[pubrec.pkid as usize].take();
620 self.outgoing_rel.insert(pubrec.pkid as usize);
622 self.outgoing_rel_notice[pubrec.pkid as usize] = notice;
623 let release = PubRel { pkid: pubrec.pkid };
624 let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
625 self.events.push_back(event);
626
627 Ok(Some(Packet::PubRel(release)))
628 }
629
630 fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
631 if !self.incoming_pub.contains(pubrel.pkid as usize) {
632 error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
633 return Err(StateError::Unsolicited(pubrel.pkid));
634 }
635
636 self.incoming_pub.set(pubrel.pkid as usize, false);
637 let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
638 let pubcomp = PubComp { pkid: pubrel.pkid };
639 self.events.push_back(event);
640
641 Ok(Some(Packet::PubComp(pubcomp)))
642 }
643
644 fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
645 if !self.outgoing_rel.contains(pubcomp.pkid as usize) {
646 error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
647 return Err(StateError::Unsolicited(pubcomp.pkid));
648 }
649
650 self.outgoing_rel.set(pubcomp.pkid as usize, false);
651 self.mark_outgoing_packet_id_complete(pubcomp.pkid);
652 if let Some(tx) = self.outgoing_rel_notice[pubcomp.pkid as usize].take() {
653 tx.success(PublishResult::Qos2Completed(pubcomp.clone()));
654 }
655 self.inflight -= 1;
656 let packet = self.replay_collision_publish(pubcomp.pkid);
657 if packet.is_none() {
658 self.maybe_shrink_outgoing_tracking_capacity();
659 }
660
661 Ok(packet)
662 }
663
/// Marks the outstanding ping as answered. Never produces a reply packet.
const fn handle_incoming_pingresp(&mut self) -> Option<Packet> {
    self.await_pingresp = false;

    None
}
669
/// Test helper: sends a publish without a notice attached and resolves any
/// flush notice that comes back.
#[cfg(test)]
fn outgoing_publish(&mut self, publish: Publish) -> Result<Option<Packet>, StateError> {
    let (packet, flush_notice) = self.outgoing_publish_with_notice(publish, None)?;

    match flush_notice {
        Some(tx) => tx.success(PublishResult::Qos0Flushed),
        None => {}
    }

    Ok(packet)
}
680
681 fn outgoing_publish_with_notice(
682 &mut self,
683 mut publish: Publish,
684 notice: Option<PublishNoticeTx>,
685 ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
686 let mut notice = notice;
687 if publish.qos != QoS::AtMostOnce {
688 if publish.pkid == 0 {
689 publish.pkid = self.next_pkid();
690 }
691
692 let pkid = publish.pkid;
693 self.validate_outgoing_pkid_bound(pkid)?;
694 self.ensure_outgoing_tracking_capacity(pkid as usize + 1);
695 if self
696 .outgoing_pub
697 .get(publish.pkid as usize)
698 .ok_or(StateError::Unsolicited(publish.pkid))?
699 .is_some()
700 {
701 info!("Collision on packet id = {:?}", publish.pkid);
702 self.collision = Some(publish);
703 self.collision_notice = notice.take();
704 let event = Event::Outgoing(Outgoing::AwaitAck(pkid));
705 self.events.push_back(event);
706 return Ok((None, None));
707 }
708
709 self.outgoing_pub[pkid as usize] = Some(publish.clone());
712 self.outgoing_pub_notice[pkid as usize] = notice.take();
713 self.outgoing_pub_ack.set(pkid as usize, false);
714 self.inflight += 1;
715 }
716
717 debug!(
718 "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}",
719 String::from_utf8_lossy(&publish.topic),
720 publish.pkid,
721 publish.payload.len()
722 );
723
724 let event = Event::Outgoing(Outgoing::Publish(publish.pkid));
725 self.events.push_back(event);
726
727 if publish.qos == QoS::AtMostOnce {
728 Ok((Some(Packet::Publish(publish)), notice.take()))
729 } else {
730 Ok((Some(Packet::Publish(publish)), None))
731 }
732 }
733
734 fn outgoing_pubrel_with_notice(
735 &mut self,
736 pubrel: PubRel,
737 notice: Option<PublishNoticeTx>,
738 ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
739 let pubrel = self.save_pubrel_with_notice(pubrel, notice)?;
740
741 debug!("Pubrel. Pkid = {}", pubrel.pkid);
742 let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid));
743 self.events.push_back(event);
744
745 Ok((Some(Packet::PubRel(pubrel)), None))
746 }
747
748 fn outgoing_puback(&mut self, puback: PubAck) -> Packet {
749 let event = Event::Outgoing(Outgoing::PubAck(puback.pkid));
750 self.events.push_back(event);
751
752 Packet::PubAck(puback)
753 }
754
755 fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Packet {
756 let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid));
757 self.events.push_back(event);
758
759 Packet::PubRec(pubrec)
760 }
761
762 fn outgoing_ping(&mut self) -> Result<Option<Packet>, StateError> {
766 let elapsed_in = self.last_incoming.elapsed();
767 let elapsed_out = self.last_outgoing.elapsed();
768
769 if self.collision.is_some() {
770 self.collision_ping_count += 1;
771 if self.collision_ping_count >= 2 {
772 return Err(StateError::CollisionTimeout);
773 }
774 }
775
776 if self.await_pingresp {
778 return Err(StateError::AwaitPingResp);
779 }
780
781 self.await_pingresp = true;
782
783 debug!(
784 "Pingreq,
785 last incoming packet before {} millisecs,
786 last outgoing request before {} millisecs",
787 elapsed_in.as_millis(),
788 elapsed_out.as_millis()
789 );
790
791 let event = Event::Outgoing(Outgoing::PingReq);
792 self.events.push_back(event);
793
794 Ok(Some(Packet::PingReq))
795 }
796
797 fn outgoing_subscribe(
798 &mut self,
799 mut subscription: Subscribe,
800 notice: Option<SubscribeNoticeTx>,
801 ) -> Result<Option<Packet>, StateError> {
802 if subscription.filters.is_empty() {
803 return Err(StateError::EmptySubscription);
804 }
805
806 let pkid = self.next_control_pkid()?;
807 subscription.pkid = pkid;
808
809 debug!(
810 "Subscribe. Topics = {:?}, Pkid = {:?}",
811 subscription.filters, subscription.pkid
812 );
813
814 let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid));
815 self.events.push_back(event);
816 if let Some(notice) = notice {
817 self.tracked_subscribe
818 .insert(subscription.pkid, (subscription.clone(), notice));
819 }
820
821 Ok(Some(Packet::Subscribe(subscription)))
822 }
823
824 fn outgoing_unsubscribe(
825 &mut self,
826 mut unsub: Unsubscribe,
827 notice: Option<UnsubscribeNoticeTx>,
828 ) -> Result<Packet, StateError> {
829 let pkid = self.next_control_pkid()?;
830 unsub.pkid = pkid;
831
832 debug!(
833 "Unsubscribe. Topics = {:?}, Pkid = {:?}",
834 unsub.topics, unsub.pkid
835 );
836
837 let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid));
838 self.events.push_back(event);
839 if let Some(notice) = notice {
840 self.tracked_unsubscribe
841 .insert(unsub.pkid, (unsub.clone(), notice));
842 }
843
844 Ok(Packet::Unsubscribe(unsub))
845 }
846
847 fn outgoing_disconnect(&mut self) -> Packet {
848 debug!("Disconnect");
849
850 let event = Event::Outgoing(Outgoing::Disconnect);
851 self.events.push_back(event);
852
853 Packet::Disconnect
854 }
855
856 fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option<PublishNoticeTx>)> {
857 if let Some(publish) = &self.collision
858 && publish.pkid == pkid
859 {
860 return self
861 .collision
862 .take()
863 .map(|publish| (publish, self.collision_notice.take()));
864 }
865
866 None
867 }
868
869 fn save_pubrel_with_notice(
870 &mut self,
871 mut pubrel: PubRel,
872 notice: Option<PublishNoticeTx>,
873 ) -> Result<PubRel, StateError> {
874 let pubrel = match pubrel.pkid {
875 0 => {
877 pubrel.pkid = self.next_pkid();
878 pubrel
879 }
880 _ => pubrel,
881 };
882
883 self.validate_outgoing_pkid_bound(pubrel.pkid)?;
884 self.ensure_outgoing_tracking_capacity(pubrel.pkid as usize + 1);
885 self.outgoing_rel.insert(pubrel.pkid as usize);
886 self.outgoing_rel_notice[pubrel.pkid as usize] = notice;
887 self.inflight += 1;
888 Ok(pubrel)
889 }
890
891 fn replay_collision_publish(&mut self, pkid: u16) -> Option<Packet> {
892 self.check_collision(pkid).map(|(publish, notice)| {
893 let publish_pkid = publish.pkid;
894 self.ensure_outgoing_tracking_capacity(publish_pkid as usize + 1);
895 self.outgoing_pub[publish_pkid as usize] = Some(publish.clone());
896 self.outgoing_pub_notice[publish_pkid as usize] = notice;
897 self.inflight += 1;
898
899 let event = Event::Outgoing(Outgoing::Publish(publish_pkid));
900 self.events.push_back(event);
901 self.collision_ping_count = 0;
902
903 Packet::Publish(publish)
904 })
905 }
906
907 fn mark_outgoing_packet_id_complete(&mut self, pkid: u16) {
908 self.outgoing_pub_ack.set(pkid as usize, true);
909 self.advance_last_puback_frontier();
910 }
911
912 fn advance_last_puback_frontier(&mut self) {
913 let mut next = self.next_puback_boundary_pkid(self.last_puback);
914 while next != 0 && self.outgoing_pub_ack.contains(next as usize) {
915 self.outgoing_pub_ack.set(next as usize, false);
916 self.last_puback = next;
917 next = self.next_puback_boundary_pkid(self.last_puback);
918 }
919 }
920
921 const fn next_puback_boundary_pkid(&self, pkid: u16) -> u16 {
922 if self.max_inflight == 0 {
923 return 0;
924 }
925
926 if pkid >= self.max_inflight {
927 1
928 } else {
929 pkid + 1
930 }
931 }
932
933 fn next_publish_pkid(&self) -> Option<u16> {
937 let mut pkid = self.next_publish_pkid_after(self.last_pkid);
938 for _ in 0..usize::from(self.max_inflight) {
939 if !self.packet_identifier_in_use(pkid) {
940 return Some(pkid);
941 }
942 pkid = self.next_publish_pkid_after(pkid);
943 }
944
945 None
946 }
947
948 fn next_pkid(&mut self) -> u16 {
949 let pkid = self
950 .next_publish_pkid()
951 .unwrap_or_else(|| self.next_publish_pkid_after(self.last_pkid));
952 if pkid == self.max_inflight {
953 self.last_pkid = 0;
954 } else {
955 self.last_pkid = pkid;
956 }
957
958 pkid
959 }
960
961 fn next_control_pkid(&mut self) -> Result<u16, StateError> {
962 for offset in 1..=u16::MAX {
963 let pkid = self.last_pkid.wrapping_add(offset);
964 if pkid != 0 && !self.packet_identifier_in_use(pkid) {
965 self.last_pkid = pkid;
966 return Ok(pkid);
967 }
968 }
969
970 Err(StateError::InvalidState)
971 }
972}
973
974impl Clone for MqttState {
975 fn clone(&self) -> Self {
976 let tracking_len = self.outgoing_pub_notice.len();
977 Self {
978 await_pingresp: self.await_pingresp,
979 collision_ping_count: self.collision_ping_count,
980 last_incoming: self.last_incoming,
981 last_outgoing: self.last_outgoing,
982 last_pkid: self.last_pkid,
983 last_puback: self.last_puback,
984 inflight: self.inflight,
985 max_inflight: self.max_inflight,
986 outgoing_pub: self.outgoing_pub.clone(),
987 outgoing_pub_notice: Self::new_notice_slots_with_len(tracking_len),
988 outgoing_pub_ack: self.outgoing_pub_ack.clone(),
989 outgoing_rel: self.outgoing_rel.clone(),
990 outgoing_rel_notice: Self::new_notice_slots_with_len(self.outgoing_rel_notice.len()),
991 incoming_pub: self.incoming_pub.clone(),
992 collision: self.collision.clone(),
993 collision_notice: None,
994 tracked_subscribe: BTreeMap::new(),
995 tracked_unsubscribe: BTreeMap::new(),
996 events: self.events.clone(),
997 manual_acks: self.manual_acks,
998 }
999 }
1000}
1001
1002#[cfg(test)]
1003mod test {
1004 use super::{MqttState, StateError};
1005 use crate::mqttbytes::v4::*;
1006 use crate::mqttbytes::*;
1007 use crate::notice::{
1008 PublishNoticeTx, PublishResult, SubscribeNoticeError, SubscribeNoticeTx,
1009 UnsubscribeNoticeError, UnsubscribeNoticeTx,
1010 };
1011 use crate::{Event, Incoming, NoticeFailureReason, Outgoing, Request};
1012 use bytes::Bytes;
1013
1014 fn build_outgoing_publish(qos: QoS) -> Publish {
1015 let topic = "hello/world".to_owned();
1016 let payload = vec![1, 2, 3];
1017
1018 let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
1019 publish.qos = qos;
1020 publish
1021 }
1022
1023 fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
1024 let topic = "hello/world".to_owned();
1025 let payload = vec![1, 2, 3];
1026
1027 let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload);
1028 publish.pkid = pkid;
1029 publish.qos = qos;
1030 publish
1031 }
1032
/// Builds a state with `max_inflight` of 100 and automatic acks.
fn build_mqttstate() -> MqttState {
    MqttState::builder(100).build()
}
1036
1037 fn queue_publish_with_notice(mqtt: &mut MqttState, publish: Publish) -> crate::PublishNotice {
1038 let (tx, notice) = PublishNoticeTx::new();
1039 let (packet, flush_notice) = mqtt
1040 .outgoing_publish_with_notice(publish, Some(tx))
1041 .unwrap();
1042 assert!(packet.is_some());
1043 assert!(flush_notice.is_none());
1044 notice
1045 }
1046
// A fresh state reserves at least the initial event capacity.
#[test]
fn new_state_preallocates_event_queue_for_read_batch_bursts() {
    let mqtt = MqttState::builder(10).build();
    let reserved = mqtt.events.capacity();
    assert!(reserved >= MqttState::initial_events_capacity());
}
1052
// With max_inflight above the warm floor, every table starts at 32 + 1 slots.
#[test]
fn new_state_uses_warm_tracking_floor() {
    let mqtt = MqttState::builder(100).build();

    let table_lens = [
        mqtt.outgoing_pub.len(),
        mqtt.outgoing_pub_notice.len(),
        mqtt.outgoing_rel_notice.len(),
        mqtt.outgoing_pub_ack.len(),
        mqtt.outgoing_rel.len(),
    ];
    assert_eq!(table_lens, [33; 5]);
}
1063
// With max_inflight below the warm floor, every table has max_inflight + 1 slots.
#[test]
fn new_state_uses_full_tracking_len_when_max_inflight_is_below_warm_floor() {
    let mqtt = MqttState::builder(10).build();

    let table_lens = [
        mqtt.outgoing_pub.len(),
        mqtt.outgoing_pub_notice.len(),
        mqtt.outgoing_rel_notice.len(),
        mqtt.outgoing_pub_ack.len(),
        mqtt.outgoing_rel.len(),
    ];
    assert_eq!(table_lens, [11; 5]);
}
1074
// Two publishes + two releases + one subscribe + one unsubscribe = 6.
#[test]
fn clean_pending_capacity_counts_publish_rel_and_tracked_requests() {
    let mut mqtt = MqttState::builder(10).build();

    for (slot, qos) in [(1, QoS::AtLeastOnce), (2, QoS::ExactlyOnce)] {
        mqtt.outgoing_pub[slot] = Some(build_outgoing_publish(qos));
    }
    for pkid in [3, 4] {
        mqtt.outgoing_rel.insert(pkid);
    }

    let (sub_notice, _) = SubscribeNoticeTx::new();
    let subscribe = Subscribe::new("a/b", QoS::AtMostOnce);
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice));

    let (unsub_notice, _) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b");
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice));

    assert_eq!(mqtt.clean_pending_capacity(), 6);
}
1093
// The length helpers reflect tracked requests and go back to empty on drain.
#[test]
fn tracked_request_len_helpers_report_counts() {
    let mut mqtt = MqttState::builder(10).build();

    let (sub_notice, _) = SubscribeNoticeTx::new();
    let subscribe = Subscribe::new("a/b", QoS::AtMostOnce);
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice));

    let (unsub_notice, _) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b");
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice));

    assert_eq!(mqtt.tracked_subscribe_len(), 1);
    assert_eq!(mqtt.tracked_unsubscribe_len(), 1);
    assert!(!mqtt.tracked_requests_is_empty());

    mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
    assert!(mqtt.tracked_requests_is_empty());
}
1111
// Draining fails both notices with SessionReset and reports two requests.
#[test]
fn drain_tracked_requests_as_failed_reports_session_reset_and_returns_count() {
    let mut mqtt = MqttState::builder(10).build();

    let (sub_notice_tx, sub_notice) = SubscribeNoticeTx::new();
    let subscribe = Subscribe::new("a/b", QoS::AtMostOnce);
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice_tx));

    let (unsub_notice_tx, unsub_notice) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b");
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice_tx));

    let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);

    assert_eq!(drained, 2);
    assert!(mqtt.tracked_requests_is_empty());

    let sub_err = sub_notice.wait().unwrap_err();
    assert_eq!(sub_err, SubscribeNoticeError::SessionReset);

    let unsub_err = unsub_notice.wait().unwrap_err();
    assert_eq!(unsub_err, UnsubscribeNoticeError::SessionReset);
}
1135
#[test]
fn drain_tracked_requests_as_failed_is_noop_when_empty() {
    // Nothing is tracked on a fresh state, so the drain must report zero.
    let mut state = MqttState::builder(10).build();

    let drained = state.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);

    assert_eq!(drained, 0);
    assert!(state.tracked_requests_is_empty());
}
1144
#[test]
fn tracked_puback_returns_ack_and_preserves_incoming_event() {
    let mut state = build_mqttstate();
    let notice = queue_publish_with_notice(&mut state, build_outgoing_publish(QoS::AtLeastOnce));
    // Start from an empty event queue so the next popped event is the ack's.
    state.events.clear();

    let puback = PubAck::new(1);
    let nothing_to_send = state
        .handle_incoming_packet(Incoming::PubAck(puback.clone()))
        .unwrap()
        .is_none();
    assert!(nothing_to_send);

    // The waiter receives the ack and the event queue still records it.
    assert_eq!(notice.wait(), Ok(PublishResult::Qos1(puback.clone())));
    let recorded = state.events.pop_front();
    assert_eq!(recorded, Some(Event::Incoming(Packet::PubAck(puback))));
}
1164
#[test]
fn tracked_suback_returns_ack_and_preserves_incoming_event() {
    let mut state = build_mqttstate();
    let (notice_tx, notice) = SubscribeNoticeTx::new();
    let subscribe = Subscribe::new("a/b", QoS::AtMostOnce);
    state.outgoing_subscribe(subscribe, Some(notice_tx)).unwrap();
    // Start from an empty event queue so the next popped event is the ack's.
    state.events.clear();

    let suback = SubAck::new(1, vec![SubscribeReasonCode::Failure]);
    let nothing_to_send = state
        .handle_incoming_packet(Incoming::SubAck(suback.clone()))
        .unwrap()
        .is_none();
    assert!(nothing_to_send);

    // The waiter receives the ack and the event queue still records it.
    assert_eq!(notice.wait(), Ok(suback.clone()));
    let recorded = state.events.pop_front();
    assert_eq!(recorded, Some(Event::Incoming(Packet::SubAck(suback))));
}
1186
#[test]
fn tracked_unsuback_returns_ack_and_preserves_incoming_event() {
    let mut state = build_mqttstate();
    let (notice_tx, notice) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b");
    state
        .outgoing_unsubscribe(unsubscribe, Some(notice_tx))
        .unwrap();
    // Start from an empty event queue so the next popped event is the ack's.
    state.events.clear();

    let unsuback = UnsubAck::new(1);
    let nothing_to_send = state
        .handle_incoming_packet(Incoming::UnsubAck(unsuback.clone()))
        .unwrap()
        .is_none();
    assert!(nothing_to_send);

    // The waiter receives the ack and the event queue still records it.
    assert_eq!(notice.wait(), Ok(unsuback.clone()));
    let recorded = state.events.pop_front();
    assert_eq!(recorded, Some(Event::Incoming(Packet::UnsubAck(unsuback))));
}
1208
#[test]
fn outgoing_publish_grows_tracking_capacity_on_demand() {
    let mut state = build_mqttstate();
    // Pretend ids up to 32 were already handed out; the next publish gets 33.
    state.last_pkid = 32;

    state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap();

    // Slot 33 must exist, so every tracking structure holds 34 entries.
    let grown_len = 34;
    assert_eq!(state.outgoing_pub.len(), grown_len);
    assert_eq!(state.outgoing_pub_notice.len(), grown_len);
    assert_eq!(state.outgoing_rel_notice.len(), grown_len);
    assert_eq!(state.outgoing_pub_ack.len(), grown_len);
    assert_eq!(state.outgoing_rel.len(), grown_len);
    assert!(state.outgoing_pub[33].is_some());
}
1224
#[test]
fn incoming_puback_shrinks_tracking_when_state_becomes_empty() {
    let mut state = build_mqttstate();
    state.last_pkid = 32;
    state.last_puback = 32;

    state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap();
    assert_eq!(state.outgoing_pub.len(), 34);

    // Acking the only inflight publish empties the state.
    state.handle_incoming_puback(&PubAck::new(33)).unwrap();

    let shrunk_len = 33;
    assert_eq!(state.outgoing_pub.len(), shrunk_len);
    assert_eq!(state.outgoing_pub_notice.len(), shrunk_len);
    assert_eq!(state.outgoing_rel_notice.len(), shrunk_len);
    assert_eq!(state.outgoing_pub_ack.len(), shrunk_len);
    assert_eq!(state.outgoing_rel.len(), shrunk_len);

    // Both id cursors are expected back at zero.
    assert_eq!(state.last_pkid, 0);
    assert_eq!(state.last_puback, 0);
}
1245
#[test]
fn incoming_puback_does_not_shrink_tracking_when_state_is_non_empty() {
    let mut state = build_mqttstate();
    state.last_pkid = 32;
    state.last_puback = 32;

    // Two publishes take ids 33 and 34.
    for _ in 0..2 {
        state
            .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }
    assert_eq!(state.outgoing_pub.len(), 35);

    // Only one of them is acked, so tracking must keep its size.
    state.handle_incoming_puback(&PubAck::new(33)).unwrap();

    assert_eq!(state.outgoing_pub.len(), 35);
    assert_eq!(state.inflight, 1);
}
1263
#[test]
fn clean_preserves_packet_id_frontier_when_pending_state_is_exported() {
    let mut state = build_mqttstate();
    for _ in 0..2 {
        state
            .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }
    assert_eq!(state.last_pkid, 2);

    // Exporting the pending requests must not rewind the id frontier.
    let pending = state.clean();
    assert_eq!(pending.len(), 2);
    assert_eq!(state.last_pkid, 2);
    assert_eq!(state.last_puback, 0);

    // Replayed requests go out again under one of their original ids.
    for request in pending {
        let packet = state.handle_outgoing_packet(request).unwrap().unwrap();
        if let Packet::Publish(publish) = &packet {
            assert!(matches!(publish.pkid, 1 | 2));
        } else {
            panic!("Unexpected replay packet: {packet:?}");
        }
    }

    // A fresh publish continues after the preserved frontier.
    let fresh = Request::Publish(build_outgoing_publish(QoS::AtLeastOnce));
    let packet = state.handle_outgoing_packet(fresh).unwrap().unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 3);
    } else {
        panic!("Unexpected fresh packet after replay: {packet:?}");
    }

    assert!(state.collision.is_none());
}
1297
#[test]
fn clone_preserves_current_tracking_lengths_after_shrink() {
    let mut state = build_mqttstate();
    state.last_pkid = 32;
    state.last_puback = 32;

    // Grow by publishing id 33, then ack it.
    state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap();
    state.handle_incoming_puback(&PubAck::new(33)).unwrap();

    // The clone must report the post-ack lengths.
    let snapshot = state.clone();
    let expected_len = 33;
    assert_eq!(snapshot.outgoing_pub.len(), expected_len);
    assert_eq!(snapshot.outgoing_pub_notice.len(), expected_len);
    assert_eq!(snapshot.outgoing_rel_notice.len(), expected_len);
    assert_eq!(snapshot.outgoing_pub_ack.len(), expected_len);
    assert_eq!(snapshot.outgoing_rel.len(), expected_len);
}
1315
#[test]
fn next_pkid_increments_as_expected() {
    let mut state = build_mqttstate();

    for i in 1..=100 {
        let pkid = state.next_pkid();

        // The 100th id is requested but not checked; the loop stops there.
        match i % 100 {
            0 => break,
            expected => assert_eq!(expected, pkid),
        }
    }
}
1332
#[test]
fn can_send_publish_searches_free_pkid_after_control_ids_pass_inflight_limit() {
    let mut state = MqttState::builder(4).build();

    // Slot 1 is occupied and the id cursor already sits past the limit of 4.
    let mut occupied = build_outgoing_publish(QoS::AtLeastOnce);
    occupied.pkid = 1;
    state.outgoing_pub[1] = Some(occupied);
    state.inflight = 1;
    state.last_pkid = 5;

    let candidate = build_outgoing_publish(QoS::AtLeastOnce);
    assert!(state.can_send_publish(&candidate));

    // The publish that actually goes out must be assigned id 2.
    let packet = state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap()
        .unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 2);
    } else {
        panic!("Unexpected packet: {packet:?}");
    }
}
1353
#[test]
fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() {
    let mut state = build_mqttstate();

    // QoS 0: neither the id cursor nor the inflight count moves.
    let qos0 = build_outgoing_publish(QoS::AtMostOnce);
    state.outgoing_publish(qos0).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (0, 0));

    // QoS 1: each publish takes the next id and one inflight slot.
    let qos1 = build_outgoing_publish(QoS::AtLeastOnce);
    state.outgoing_publish(qos1.clone()).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (1, 1));

    state.outgoing_publish(qos1).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (2, 2));

    // QoS 2: same accounting as QoS 1.
    let qos2 = build_outgoing_publish(QoS::ExactlyOnce);
    state.outgoing_publish(qos2.clone()).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (3, 3));

    state.outgoing_publish(qos2).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (4, 4));
}
1392
#[test]
fn incoming_publish_should_be_added_to_queue_correctly() {
    let mut state = build_mqttstate();

    let incoming = [
        build_incoming_publish(QoS::AtMostOnce, 1),
        build_incoming_publish(QoS::AtLeastOnce, 2),
        build_incoming_publish(QoS::ExactlyOnce, 3),
    ];

    // The replies are irrelevant here; only the recorded state is checked.
    for publish in &incoming {
        let _ = state.handle_incoming_publish(publish);
    }

    // The QoS 2 publish (id 3) must be recorded as received.
    assert!(state.incoming_pub.contains(3));
}
1409
#[test]
fn incoming_publish_should_be_acked() {
    let mut state = build_mqttstate();

    let qos0 = build_incoming_publish(QoS::AtMostOnce, 1);
    let qos1 = build_incoming_publish(QoS::AtLeastOnce, 2);
    let qos2 = build_incoming_publish(QoS::ExactlyOnce, 3);

    // QoS 0 needs no reply.
    assert!(state.handle_incoming_publish(&qos0).is_none());

    // QoS 1 is answered with a PubAck carrying the same id.
    match state.handle_incoming_publish(&qos1).unwrap() {
        Packet::PubAck(puback) => {
            let pkid = puback.pkid;
            assert_eq!(pkid, 2);
        }
        _ => panic!("missing puback"),
    }

    // QoS 2 is answered with a PubRec carrying the same id.
    match state.handle_incoming_publish(&qos2).unwrap() {
        Packet::PubRec(pubrec) => {
            let pkid = pubrec.pkid;
            assert_eq!(pkid, 3);
        }
        _ => panic!("missing PubRec"),
    }
}
1437
#[test]
fn incoming_publish_should_not_be_acked_with_manual_acks() {
    let mut state = build_mqttstate();
    state.manual_acks = true;

    let incoming = [
        build_incoming_publish(QoS::AtMostOnce, 1),
        build_incoming_publish(QoS::AtLeastOnce, 2),
        build_incoming_publish(QoS::ExactlyOnce, 3),
    ];

    // With manual acks enabled no QoS level produces an automatic reply.
    for publish in &incoming {
        assert!(state.handle_incoming_publish(publish).is_none());
    }

    // The QoS 2 id is still recorded, and no events were queued.
    assert!(state.incoming_pub.contains(3));
    assert!(state.events.is_empty());
}
1456
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos1_ack() {
    let mut state = build_mqttstate();
    let publish = build_incoming_publish(QoS::AtLeastOnce, 42);

    state
        .handle_incoming_packet(Incoming::Publish(publish.clone()))
        .unwrap();

    // Exactly two events: the incoming publish first, then the derived PubAck.
    assert_eq!(state.events.len(), 2);
    let incoming_event = Event::Incoming(Incoming::Publish(publish));
    let ack_event = Event::Outgoing(Outgoing::PubAck(42));
    assert_eq!(state.events.front(), Some(&incoming_event));
    assert_eq!(state.events.back(), Some(&ack_event));
}
1469
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos2_ack() {
    let mut state = build_mqttstate();
    let publish = build_incoming_publish(QoS::ExactlyOnce, 43);

    state
        .handle_incoming_packet(Incoming::Publish(publish.clone()))
        .unwrap();

    // Exactly two events: the incoming publish first, then the derived PubRec.
    assert_eq!(state.events.len(), 2);
    let incoming_event = Event::Incoming(Incoming::Publish(publish));
    let rec_event = Event::Outgoing(Outgoing::PubRec(43));
    assert_eq!(state.events.front(), Some(&incoming_event));
    assert_eq!(state.events.back(), Some(&rec_event));
}
1482
#[test]
fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
    let mut state = build_mqttstate();
    let incoming = build_incoming_publish(QoS::ExactlyOnce, 1);

    // A QoS 2 publish must be answered with a PubRec for the same id.
    let packet = state.handle_incoming_publish(&incoming).unwrap();
    if let Packet::PubRec(pubrec) = &packet {
        assert_eq!(pubrec.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
1494
#[test]
fn incoming_puback_should_remove_correct_publish_from_queue() {
    let mut state = build_mqttstate();

    let first = build_outgoing_publish(QoS::AtLeastOnce);
    let second = build_outgoing_publish(QoS::ExactlyOnce);

    state.outgoing_publish(first).unwrap();
    state.outgoing_publish(second).unwrap();
    assert_eq!(state.inflight, 2);

    // Each ack releases exactly one inflight slot.
    for (pkid, remaining) in [(1, 1), (2, 0)] {
        state.handle_incoming_puback(&PubAck::new(pkid)).unwrap();
        assert_eq!(state.inflight, remaining);
    }

    // Both stored publishes are gone afterwards.
    assert!(state.outgoing_pub[1].is_none());
    assert!(state.outgoing_pub[2].is_none());
}
1515
#[test]
fn incoming_puback_advances_last_puback_only_on_contiguous_boundary() {
    let mut state = build_mqttstate();

    // Three QoS 1 publishes take ids 1, 2 and 3.
    for _ in 0..3 {
        state
            .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }
    assert_eq!(state.last_puback, 0);

    // Ack order 2, 1, 3: the cursor only moves once id 1 closes the gap.
    for (acked, expected_cursor) in [(2, 0), (1, 2), (3, 3)] {
        state.handle_incoming_puback(&PubAck::new(acked)).unwrap();
        assert_eq!(state.last_puback, expected_cursor);
    }
}
1537
#[test]
fn mixed_qos_completion_clears_outbound_drain_state() {
    let mut state = build_mqttstate();

    // Ids 1 and 4 are QoS 2, ids 2 and 3 are QoS 1.
    for qos in [
        QoS::ExactlyOnce,
        QoS::AtLeastOnce,
        QoS::AtLeastOnce,
        QoS::ExactlyOnce,
    ] {
        state.outgoing_publish(build_outgoing_publish(qos)).unwrap();
    }

    // Complete every exchange, interleaving the QoS 1 and QoS 2 handshakes.
    state.handle_incoming_pubrec(&PubRec::new(1)).unwrap();
    state.handle_incoming_puback(&PubAck::new(2)).unwrap();
    state.handle_incoming_puback(&PubAck::new(3)).unwrap();
    state.handle_incoming_pubcomp(&PubComp::new(1)).unwrap();
    state.handle_incoming_pubrec(&PubRec::new(4)).unwrap();
    state.handle_incoming_pubcomp(&PubComp::new(4)).unwrap();

    // Nothing may remain tracked on the outbound side.
    assert_eq!(state.inflight, 0);
    assert!(state.outbound_requests_drained());
    assert_eq!(state.outgoing_pub_ack.ones().count(), 0);
    assert_eq!(state.outgoing_rel.ones().count(), 0);
}
1563
#[test]
fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() {
    let mut state = build_mqttstate();

    // An ack for an id that was never sent must surface as `Unsolicited`.
    let e = state.handle_incoming_puback(&PubAck::new(101)).unwrap_err();

    if let StateError::Unsolicited(pkid) = e {
        assert_eq!(pkid, 101);
    } else {
        panic!("Unexpected error: {e}");
    }
}
1575
#[test]
fn incoming_puback_with_pkid_beyond_allocated_tracking_is_unsolicited() {
    let mut state = build_mqttstate();

    // Id 50 was never published, so its ack must be rejected as unsolicited.
    let e = state.handle_incoming_puback(&PubAck::new(50)).unwrap_err();

    if let StateError::Unsolicited(pkid) = e {
        assert_eq!(pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }
}
1587
#[test]
fn outgoing_publish_with_pkid_above_max_inflight_is_unsolicited_and_does_not_grow_tracking() {
    let mut state = MqttState::builder(10).build();

    // A caller-supplied id of 50 exceeds the inflight limit of 10.
    let mut oversized = build_outgoing_publish(QoS::AtLeastOnce);
    oversized.pkid = 50;

    let e = state
        .handle_outgoing_packet(Request::Publish(oversized))
        .unwrap_err();

    if let StateError::Unsolicited(pkid) = e {
        assert_eq!(pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }

    // The rejected request must leave tracking sizes and inflight untouched.
    let untouched_len = 11;
    assert_eq!(state.outgoing_pub.len(), untouched_len);
    assert_eq!(state.outgoing_pub_notice.len(), untouched_len);
    assert_eq!(state.outgoing_rel_notice.len(), untouched_len);
    assert_eq!(state.inflight, 0);
}
1607
#[test]
fn outgoing_pubrel_with_pkid_above_max_inflight_is_unsolicited_and_does_not_grow_tracking() {
    let mut state = MqttState::builder(10).build();

    // A PubRel for id 50 exceeds the inflight limit of 10.
    let oversized = Request::PubRel(PubRel::new(50));
    let e = state.handle_outgoing_packet(oversized).unwrap_err();

    if let StateError::Unsolicited(pkid) = e {
        assert_eq!(pkid, 50);
    } else {
        panic!("Unexpected error: {e}");
    }

    // The rejected request must leave tracking sizes and inflight untouched.
    let untouched_len = 11;
    assert_eq!(state.outgoing_pub.len(), untouched_len);
    assert_eq!(state.outgoing_pub_notice.len(), untouched_len);
    assert_eq!(state.outgoing_rel_notice.len(), untouched_len);
    assert_eq!(state.inflight, 0);
}
1625
#[test]
fn clean_keeps_oldest_unacked_publish_first_after_out_of_order_puback() {
    let mut state = build_mqttstate();

    // Publish ids 1, 2 and 3, then ack only the middle one.
    for _ in 0..3 {
        state
            .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }
    state.handle_incoming_puback(&PubAck::new(2)).unwrap();

    let requests = state.clean();

    // Collect the ids of the exported publishes in the order they come back.
    let mut pending_pkids: Vec<u16> = Vec::new();
    for request in requests.iter() {
        match request {
            Request::Publish(publish) => pending_pkids.push(publish.pkid),
            req => panic!("Unexpected request while cleaning: {req:?}"),
        }
    }

    assert_eq!(pending_pkids, vec![1, 3]);
}
1650
#[test]
fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() {
    let mut mqtt = build_mqttstate();

    let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
    let publish2 = build_outgoing_publish(QoS::ExactlyOnce);

    // Unwrap the setup calls so a failure to queue either publish fails the
    // test here instead of being swallowed and showing up as a confusing
    // assertion failure further down.
    mqtt.outgoing_publish(publish1).unwrap();
    mqtt.outgoing_publish(publish2).unwrap();

    mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap();
    assert_eq!(mqtt.inflight, 2);

    // The QoS 1 publish (id 1) is still stored; borrow it instead of cloning.
    let remaining = mqtt.outgoing_pub[1].as_ref();
    assert_eq!(remaining.unwrap().pkid, 1);

    // The QoS 2 id is now awaiting release.
    assert!(mqtt.outgoing_rel.contains(2));
}
1671
#[test]
fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
    let mut state = build_mqttstate();

    // The first QoS 2 publish goes out with id 1.
    let outgoing = build_outgoing_publish(QoS::ExactlyOnce);
    let packet = state.outgoing_publish(outgoing).unwrap().unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }

    // Its PubRec must be answered with a PubRel for the same id.
    let packet = state
        .handle_incoming_pubrec(&PubRec::new(1))
        .unwrap()
        .unwrap();
    if let Packet::PubRel(pubrel) = &packet {
        assert_eq!(pubrel.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
1692
#[test]
fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
    let mut state = build_mqttstate();
    let incoming = build_incoming_publish(QoS::ExactlyOnce, 1);

    // The incoming QoS 2 publish is answered with a PubRec.
    let packet = state.handle_incoming_publish(&incoming).unwrap();
    if let Packet::PubRec(pubrec) = &packet {
        assert_eq!(pubrec.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }

    // The peer's PubRel is answered with a PubComp for the same id.
    let packet = state
        .handle_incoming_pubrel(&PubRel::new(1))
        .unwrap()
        .unwrap();
    if let Packet::PubComp(pubcomp) = &packet {
        assert_eq!(pubcomp.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
1713
#[test]
fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
    let mut state = build_mqttstate();

    // Run a full QoS 2 exchange for id 1: publish, PubRec, PubComp.
    state
        .outgoing_publish(build_outgoing_publish(QoS::ExactlyOnce))
        .unwrap();
    state.handle_incoming_pubrec(&PubRec::new(1)).unwrap();
    state.handle_incoming_pubcomp(&PubComp::new(1)).unwrap();

    // Completing the exchange frees the inflight slot.
    assert_eq!(state.inflight, 0);
}
1725
#[test]
fn incoming_pubcomp_collision_replay_should_restore_qos2_tracking() {
    let mut state = build_mqttstate();
    state
        .outgoing_publish(build_outgoing_publish(QoS::ExactlyOnce))
        .unwrap();

    // A second publish that reuses id 1 is not sent; it is held as a collision.
    let mut duplicate_id = build_outgoing_publish(QoS::ExactlyOnce);
    duplicate_id.pkid = 1;
    let held_back = state.outgoing_publish(duplicate_id).unwrap().is_none();
    assert!(held_back);
    assert!(state.collision.is_some());

    // Completing the first exchange hands back the held publish for sending.
    state.handle_incoming_pubrec(&PubRec::new(1)).unwrap();
    let packet = state
        .handle_incoming_pubcomp(&PubComp::new(1))
        .unwrap()
        .unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }

    // The replayed publish is tracked again as inflight under id 1.
    assert!(state.outgoing_pub[1].is_some());
    assert_eq!(state.inflight, 1);

    // Its own PubRec then produces a PubRel as usual.
    let packet = state
        .handle_incoming_pubrec(&PubRec::new(1))
        .unwrap()
        .unwrap();
    if let Packet::PubRel(pubrel) = &packet {
        assert_eq!(pubrel.pkid, 1);
    } else {
        panic!("Invalid network request: {packet:?}");
    }
}
1759
#[test]
fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
    let mut state = build_mqttstate();
    state.outgoing_ping().unwrap();

    // Other traffic in between does not count as a ping response.
    let request = Request::Publish(build_outgoing_publish(QoS::AtLeastOnce));
    state.handle_outgoing_packet(request).unwrap();
    state
        .handle_incoming_packet(Incoming::PubAck(PubAck::new(1)))
        .unwrap();

    // A second ping without a PingResp must be rejected.
    match state.outgoing_ping() {
        Err(StateError::AwaitPingResp) => {}
        Err(e) => panic!("Should throw pingresp await error. Error = {e:?}"),
        Ok(_) => panic!("Should throw pingresp await error"),
    }
}
1779
#[test]
fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
    let mut state = build_mqttstate();

    // First ping, answered by the peer.
    state.outgoing_ping().unwrap();
    state.handle_incoming_packet(Incoming::PingResp).unwrap();

    // With the response received, the next ping must be accepted.
    state.outgoing_ping().unwrap();
}
1791
#[test]
fn clean_is_calculating_pending_correctly() {
    // Builds a 7-slot publish table with ids 1, 2, 3 and 6 occupied and
    // slots 0, 4 and 5 empty.
    fn build_outgoing_pub() -> Vec<Option<Publish>> {
        let occupied = |pkid: u16| {
            Some(Publish {
                dup: false,
                qos: QoS::AtMostOnce,
                retain: false,
                topic: Bytes::from_static(b"test"),
                pkid,
                payload: "".into(),
            })
        };

        vec![
            None,
            occupied(1),
            occupied(2),
            occupied(3),
            None,
            None,
            occupied(6),
        ]
    }

    let mut state = build_mqttstate();

    // Each case is (last_puback, expected order of exported publish ids).
    // With the cursor at 3 the export starts after it (id 6) and then covers
    // ids 1, 2 and 3; with the cursor at 0 or 6 the ids come out ascending.
    let cases = [
        (3, [6, 1, 2, 3]),
        (0, [1, 2, 3, 6]),
        (6, [1, 2, 3, 6]),
    ];

    for (last_puback, expected_order) in cases {
        state.outgoing_pub = build_outgoing_pub();
        state.last_puback = last_puback;

        let requests = state.clean();
        for (request, expected_pkid) in requests.iter().zip(expected_order) {
            match request {
                Request::Publish(publish) => assert_eq!(publish.pkid, expected_pkid),
                _ => unreachable!(),
            }
        }
    }
}
1871}