1use super::mqttbytes::v5::{
2 Auth, ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck,
3 PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish,
4 PublishProperties, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason,
5 Unsubscribe,
6};
7use super::mqttbytes::{self, Error as MqttError, QoS};
8use crate::auth::{AuthLifecycle, IncomingAuthEffect};
9use crate::notice::{
10 AuthNoticeError, PublishNoticeTx, PublishResult, SubscribeNoticeTx, TrackedNoticeTx,
11 UnsubscribeNoticeTx,
12};
13use crate::{
14 AuthContext, AuthError, AuthExchangeKind, Authenticator, NoticeFailureReason,
15 PublishNoticeError, TopicAliasPolicy,
16};
17
18use super::{Event, Incoming, Outgoing, Request};
19
20use bytes::Bytes;
21use fixedbitset::FixedBitSet;
22use std::collections::{BTreeMap, HashMap, VecDeque};
23use std::sync::{Arc, Mutex};
24use std::{io, time::Instant};
25
/// An automatically chosen topic-alias assignment that has not yet been
/// recorded; carried by [`AutoTopicAliasAction::New`].
#[derive(Clone, Debug, PartialEq, Eq)]
struct PendingAutoTopicAlias {
    /// Topic the alias is being assigned to.
    topic: Bytes,
    /// Alias number chosen for `topic`.
    alias: u16,
    /// NOTE(review): presumably the topic that `alias` was bound to before this
    /// assignment, if any (alias reuse) — confirm against the allocation code.
    previous_topic: Option<Bytes>,
}
32
/// Outcome of automatic topic-alias selection for an outgoing publish.
#[derive(Clone, Debug, PartialEq, Eq)]
enum AutoTopicAliasAction {
    /// An alias that already exists is being used. `original_topic` is the
    /// topic the publish carried; `restore_for_replay` writes it back into the
    /// publish before stripping the alias.
    Existing { original_topic: Bytes, alias: u16 },
    /// A new alias assignment is being introduced with this publish.
    New(PendingAutoTopicAlias),
}
38
39impl AutoTopicAliasAction {
40 const fn pending_alias(&self) -> Option<&PendingAutoTopicAlias> {
41 match self {
42 Self::Existing { .. } => None,
43 Self::New(pending) => Some(pending),
44 }
45 }
46
47 const fn existing_alias(&self) -> Option<u16> {
48 match self {
49 Self::Existing { alias, .. } => Some(*alias),
50 Self::New(_) => None,
51 }
52 }
53
54 fn restore_for_replay(self, publish: &mut Publish) {
55 match self {
56 Self::Existing { original_topic, .. } => {
57 publish.topic = original_topic;
58 MqttState::strip_publish_topic_alias(publish);
59 }
60 Self::New(_) => {
61 MqttState::strip_publish_topic_alias(publish);
62 }
63 }
64 }
65}
66
/// Errors produced while driving the MQTT client state machine.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
    /// An underlying I/O error (converted via `From<io::Error>`).
    #[error("Io error: {0:?}")]
    Io(#[from] io::Error),
    /// An integer conversion failed (converted via `From<TryFromIntError>`).
    ///
    /// NOTE(review): the variant name is misspelled, but it is part of the
    /// public API, so renaming it would break callers.
    #[error("Conversion error {0:?}")]
    Coversion(#[from] core::num::TryFromIntError),
    /// The requested operation is not valid in the current state.
    #[error("Invalid state for a given operation")]
    InvalidState,
    /// An acknowledgement arrived for a packet identifier that has no matching
    /// tracked request.
    #[error("Received unsolicited ack pkid: {0}")]
    Unsolicited(u16),
    /// The previous ping request has not been answered yet.
    #[error("Last pingreq isn't acked")]
    AwaitPingResp,
    /// An incoming packet type that the state machine does not handle.
    #[error("Received a wrong packet while waiting for another packet")]
    WrongPacket,
    /// A packet-identifier collision was not resolved in time.
    #[error("Timeout while waiting to resolve collision")]
    CollisionTimeout,
    /// A subscribe request carried no filters.
    #[error("A Subscribe packet must contain atleast one filter")]
    EmptySubscription,
    /// Any `mqttbytes` error other than `OutgoingPacketTooLarge`
    /// (see the `From<mqttbytes::Error>` impl).
    #[error("Mqtt serialization/deserialization error: {0}")]
    Deserialization(MqttError),
    /// The requested topic alias exceeds the broker's advertised maximum.
    #[error(
        "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'."
    )]
    InvalidAlias { alias: u16, max: u16 },
    /// An outgoing packet exceeds the broker's maximum packet size.
    #[error(
        "Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'"
    )]
    OutgoingPacketTooLarge { pkt_size: u32, max: u32 },
    /// An incoming packet exceeds the client's maximum packet size.
    #[error(
        "Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'"
    )]
    IncomingPacketTooLarge { pkt_size: usize, max: usize },
    /// The server sent a `Disconnect`; carries its reason code and optional
    /// reason string.
    #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")]
    ServerDisconnect {
        reason_code: DisconnectReasonCode,
        reason_string: Option<String>,
    },
    /// The `ConnAck` carried a return code other than success.
    #[error("Connection failed with reason '{reason:?}' ")]
    ConnFail { reason: ConnectReturnCode },
    /// The peer closed the connection abruptly.
    #[error("Connection closed by peer abruptly")]
    ConnectionAborted,
    /// Authentication failed; the payload is a human-readable description.
    #[error("Authentication error: {0}")]
    AuthError(String),
    /// An authenticator was required but none is configured.
    #[error("Authenticator not set")]
    AuthenticatorNotSet,
}
119
120impl From<mqttbytes::Error> for StateError {
121 fn from(value: MqttError) -> Self {
122 match value {
123 MqttError::OutgoingPacketTooLarge { pkt_size, max } => {
124 Self::OutgoingPacketTooLarge { pkt_size, max }
125 }
126 e => Self::Deserialization(e),
127 }
128 }
129}
130
/// Protocol state of one MQTT v5 client session: in-flight publish tracking,
/// tracked subscribe/unsubscribe requests, topic-alias tables, the
/// authentication lifecycle, and the queue of events produced while packets
/// are handled.
#[derive(Debug)]
pub struct MqttState {
    /// Ping-response wait flag; cleared when a `PingResp` arrives and when the
    /// state is cleaned.
    pub await_pingresp: bool,
    /// Counter tied to an unresolved packet-identifier collision; reset to 0
    /// whenever the collision is cleared or replayed.
    /// NOTE(review): presumably counts pings sent while the collision
    /// persists — confirm in the ping handling code.
    pub collision_ping_count: usize,
    /// Time the most recent incoming packet was handled successfully.
    last_incoming: Instant,
    /// Time the most recent outgoing request was handled successfully.
    last_outgoing: Instant,
    /// NOTE(review): presumably the packet identifier handed out most
    /// recently; reset to 0 when the tracking storage shrinks.
    pub(crate) last_pkid: u16,
    /// Index after which `clean_with_notices` starts collecting unacked
    /// publishes; reset to 0 when the tracking storage shrinks.
    pub(crate) last_puback: u16,
    /// Number of outgoing publishes currently counted as in flight.
    pub(crate) inflight: u16,
    /// Outgoing publishes awaiting acknowledgement, indexed by packet id.
    pub(crate) outgoing_pub: Vec<Option<Publish>>,
    /// Completion notices for the entries of `outgoing_pub`, same indexing.
    pub(crate) outgoing_pub_notice: Vec<Option<PublishNoticeTx>>,
    /// Per-packet-id bit set sized like `outgoing_pub`.
    /// NOTE(review): exact meaning is not visible in this part of the file.
    pub(crate) outgoing_pub_ack: FixedBitSet,
    /// Packet ids for which a `PubRec` was received and a `PubComp` is awaited.
    pub(crate) outgoing_rel: FixedBitSet,
    /// Completion notices for the entries of `outgoing_rel`, indexed by packet id.
    pub(crate) outgoing_rel_notice: Vec<Option<PublishNoticeTx>>,
    /// Packet ids of incoming QoS 2 publishes awaiting their `PubRel`.
    pub(crate) incoming_pub: FixedBitSet,
    /// Publish held back because of a packet-identifier collision; it is sent
    /// once the colliding identifier completes.
    pub collision: Option<Publish>,
    /// Completion notice belonging to `collision`.
    pub(crate) collision_notice: Option<PublishNoticeTx>,
    /// Subscribe requests awaiting a `SubAck`, keyed by packet id.
    pub(crate) tracked_subscribe: BTreeMap<u16, (Subscribe, SubscribeNoticeTx)>,
    /// Unsubscribe requests awaiting an `UnsubAck`, keyed by packet id.
    pub(crate) tracked_unsubscribe: BTreeMap<u16, (Unsubscribe, UnsubscribeNoticeTx)>,
    /// Events produced while handling packets, in emission order.
    pub events: VecDeque<Event>,
    /// When `false`, incoming QoS 1/2 publishes are acknowledged automatically.
    pub manual_acks: bool,
    /// Alias-to-topic table learned from incoming publishes.
    incoming_topic_aliases: HashMap<u16, Bytes>,
    /// Alias-to-topic table for outgoing publishes; cloned to resolve aliases
    /// when requests are prepared for replay.
    outgoing_topic_aliases: HashMap<u16, Bytes>,
    /// Topic-to-alias table for automatically assigned outgoing aliases.
    auto_outgoing_topic_aliases: HashMap<Bytes, u16>,
    /// Next alias number for automatic assignment; starts at `Some(1)`.
    /// NOTE(review): `None` presumably means the alias space is exhausted.
    next_auto_topic_alias: Option<u16>,
    /// Whether automatic outgoing topic aliasing is enabled.
    auto_topic_aliases: bool,
    /// Policy used when assigning automatic topic aliases.
    auto_topic_alias_policy: TopicAliasPolicy,
    /// NOTE(review): presumably the recency order of automatic aliases for an
    /// LRU-style policy — confirm against the allocation code.
    auto_topic_alias_lru: VecDeque<u16>,
    /// Topic alias maximum advertised by the broker in `ConnAck`; 0 until then.
    pub broker_topic_alias_max: u16,
    /// Effective limit on in-flight outgoing publishes (may be lowered by the
    /// broker's receive maximum).
    pub(crate) max_outgoing_inflight: u16,
    /// Configured upper bound that `max_outgoing_inflight` never exceeds.
    max_outgoing_inflight_upper_limit: u16,
    /// Optional user-supplied authenticator, shared behind a mutex.
    authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
    /// Authentication exchange lifecycle.
    auth: AuthLifecycle,
}
200
/// Builder for [`MqttState`]; every field is forwarded to
/// `MqttState::new_internal` by `build`.
#[derive(Debug)]
pub struct MqttStateBuilder {
    /// Maximum number of in-flight outgoing publishes.
    max_inflight: u16,
    /// Whether the application acknowledges incoming publishes itself.
    manual_acks: bool,
    /// Whether automatic outgoing topic aliasing is enabled.
    auto_topic_aliases: bool,
    /// Policy used for automatic topic alias assignment.
    auto_topic_alias_policy: TopicAliasPolicy,
    /// Optional authenticator, shared behind a mutex.
    authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
    /// Optional authentication method name.
    authentication_method: Option<String>,
}
215
216impl MqttStateBuilder {
217 #[must_use]
219 pub const fn new(max_inflight: u16) -> Self {
220 Self {
221 max_inflight,
222 manual_acks: false,
223 auto_topic_aliases: false,
224 auto_topic_alias_policy: TopicAliasPolicy::Monotonic,
225 authenticator: None,
226 authentication_method: None,
227 }
228 }
229
230 #[must_use]
232 pub const fn manual_acks(mut self, manual_acks: bool) -> Self {
233 self.manual_acks = manual_acks;
234 self
235 }
236
237 #[must_use]
239 pub const fn auto_topic_aliases(mut self, auto_topic_aliases: bool) -> Self {
240 self.auto_topic_aliases = auto_topic_aliases;
241 self
242 }
243
244 #[must_use]
246 pub const fn topic_alias_policy(mut self, auto_topic_alias_policy: TopicAliasPolicy) -> Self {
247 self.auto_topic_alias_policy = auto_topic_alias_policy;
248 self
249 }
250
251 #[must_use]
253 pub fn authentication_method(mut self, authentication_method: Option<String>) -> Self {
254 self.authentication_method = authentication_method;
255 self
256 }
257
258 #[must_use]
260 pub fn authenticator(mut self, authenticator: Arc<Mutex<dyn Authenticator>>) -> Self {
261 self.authenticator = Some(authenticator);
262 self
263 }
264
265 #[must_use]
267 pub fn auth_manager(mut self, authenticator: Arc<Mutex<dyn Authenticator>>) -> Self {
268 self.authenticator = Some(authenticator);
269 self
270 }
271
272 #[must_use]
274 pub fn build(self) -> MqttState {
275 MqttState::new_internal(
276 self.max_inflight,
277 self.manual_acks,
278 self.auto_topic_aliases,
279 self.auto_topic_alias_policy,
280 self.authentication_method,
281 self.authenticator,
282 )
283 }
284}
285
286impl MqttState {
287 const fn initial_events_capacity() -> usize {
288 128
289 }
290
291 fn outgoing_tracking_len(max_inflight: u16) -> usize {
292 usize::from(max_inflight) + 1
293 }
294
295 fn new_notice_slots_with_len(size: usize) -> Vec<Option<PublishNoticeTx>> {
296 std::iter::repeat_with(|| None).take(size).collect()
297 }
298
299 fn new_notice_slots(max_inflight: u16) -> Vec<Option<PublishNoticeTx>> {
300 Self::new_notice_slots_with_len(Self::outgoing_tracking_len(max_inflight))
301 }
302
303 fn clean_pending_capacity(&self) -> usize {
304 self.outgoing_pub
305 .iter()
306 .filter(|publish| publish.is_some())
307 .count()
308 + self.outgoing_rel.ones().count()
309 + self.tracked_subscribe.len()
310 + self.tracked_unsubscribe.len()
311 }
312
313 const fn next_publish_pkid_after(&self, pkid: u16) -> u16 {
314 if pkid >= self.max_outgoing_inflight {
315 1
316 } else {
317 pkid + 1
318 }
319 }
320
321 fn packet_identifier_in_use(&self, pkid: u16) -> bool {
322 let index = usize::from(pkid);
323 self.outgoing_pub.get(index).is_some_and(Option::is_some)
324 || self.outgoing_rel.contains(index)
325 || self.tracked_subscribe.contains_key(&pkid)
326 || self.tracked_unsubscribe.contains_key(&pkid)
327 }
328
329 pub(crate) fn can_send_publish(&self, publish: &Publish) -> bool {
330 if publish.qos == QoS::AtMostOnce {
331 return true;
332 }
333
334 if self.inflight >= self.max_outgoing_inflight || self.collision.is_some() {
335 return false;
336 }
337
338 if publish.pkid == 0 {
339 return self.next_publish_pkid().is_some();
340 }
341
342 publish.pkid != 0
343 && publish.pkid <= self.max_outgoing_inflight
344 && !self.packet_identifier_in_use(publish.pkid)
345 }
346
347 pub(crate) fn control_packet_identifier_available(&self) -> bool {
348 (1..=u16::MAX).any(|pkid| !self.packet_identifier_in_use(pkid))
349 }
350
    /// Returns a [`MqttStateBuilder`] seeded with `max_inflight`.
    #[must_use]
    pub const fn builder(max_inflight: u16) -> MqttStateBuilder {
        MqttStateBuilder::new(max_inflight)
    }
356
357 #[must_use]
361 pub(crate) fn new_internal(
362 max_inflight: u16,
363 manual_acks: bool,
364 auto_topic_aliases: bool,
365 auto_topic_alias_policy: TopicAliasPolicy,
366 authentication_method: Option<String>,
367 authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
368 ) -> Self {
369 Self {
370 await_pingresp: false,
371 collision_ping_count: 0,
372 last_incoming: Instant::now(),
373 last_outgoing: Instant::now(),
374 last_pkid: 0,
375 last_puback: 0,
376 inflight: 0,
377 outgoing_pub: vec![None; max_inflight as usize + 1],
379 outgoing_pub_notice: Self::new_notice_slots(max_inflight),
380 outgoing_pub_ack: FixedBitSet::with_capacity(max_inflight as usize + 1),
381 outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1),
382 outgoing_rel_notice: Self::new_notice_slots(max_inflight),
383 incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1),
384 collision: None,
385 collision_notice: None,
386 tracked_subscribe: BTreeMap::new(),
387 tracked_unsubscribe: BTreeMap::new(),
388 events: VecDeque::with_capacity(Self::initial_events_capacity()),
389 manual_acks,
390 incoming_topic_aliases: HashMap::new(),
391 outgoing_topic_aliases: HashMap::new(),
392 auto_outgoing_topic_aliases: HashMap::new(),
393 next_auto_topic_alias: Some(1),
394 auto_topic_aliases,
395 auto_topic_alias_policy,
396 auto_topic_alias_lru: VecDeque::new(),
397 broker_topic_alias_max: 0,
399 max_outgoing_inflight: max_inflight,
400 max_outgoing_inflight_upper_limit: max_inflight,
401 authenticator,
402 auth: AuthLifecycle::new(authentication_method),
403 }
404 }
405
    /// Replaces the authentication method on the auth lifecycle
    /// (`None` clears it).
    pub fn set_authentication_method(&mut self, authentication_method: Option<String>) {
        self.auth.set_method(authentication_method);
    }
415
416 pub(crate) fn begin_authentication_connect(
417 &mut self,
418 authentication_method: Option<String>,
419 ) -> Result<Option<crate::mqttbytes::v5::AuthProperties>, StateError> {
420 self.auth
421 .begin_connect(authentication_method, &mut self.events);
422 let Some(method) = self.auth.method().map(str::to_owned) else {
423 return Ok(None);
424 };
425 let Some(authenticator) = self.authenticator.clone() else {
426 return Ok(None);
427 };
428 let context = AuthContext {
429 kind: AuthExchangeKind::InitialConnect,
430 method: &method,
431 };
432 let start_result = authenticator.lock().unwrap().start(context);
433 let properties = match start_result {
434 Ok(properties) => properties,
435 Err(err) => return Err(self.fail_authenticator(&err)),
436 };
437 properties
438 .map(|properties| crate::auth::normalize_auth_properties(&method, Some(properties)))
439 .transpose()
440 }
441
    /// Validates a successful `ConnAck` against the auth lifecycle by
    /// delegating to `AuthLifecycle::validate_successful_connack`.
    ///
    /// # Errors
    ///
    /// Propagates the validation error unchanged.
    pub(crate) fn validate_successful_connack_authentication_method(
        &self,
        connack: &ConnAck,
    ) -> Result<(), StateError> {
        self.auth.validate_successful_connack(connack)
    }
448
449 fn ensure_outgoing_tracking_capacity(&mut self, target_len: usize) {
450 if self.outgoing_pub.len() < target_len {
451 self.outgoing_pub.resize_with(target_len, || None);
452 }
453
454 if self.outgoing_pub_notice.len() < target_len {
455 self.outgoing_pub_notice.resize_with(target_len, || None);
456 }
457
458 if self.outgoing_rel_notice.len() < target_len {
459 self.outgoing_rel_notice.resize_with(target_len, || None);
460 }
461
462 if self.outgoing_pub_ack.len() < target_len {
463 self.outgoing_pub_ack.grow(target_len);
464 }
465
466 if self.outgoing_rel.len() < target_len {
467 self.outgoing_rel.grow(target_len);
468 }
469 }
470
471 pub(crate) fn outbound_requests_drained(&self) -> bool {
472 self.inflight == 0
473 && self.collision.is_none()
474 && self.collision_notice.is_none()
475 && self.tracked_subscribe.is_empty()
476 && self.tracked_unsubscribe.is_empty()
477 && self.outgoing_pub.iter().all(Option::is_none)
478 && self.outgoing_pub_notice.iter().all(Option::is_none)
479 && self.outgoing_rel_notice.iter().all(Option::is_none)
480 && self.outgoing_pub_ack.ones().next().is_none()
481 && self.outgoing_rel.ones().next().is_none()
482 }
483
484 fn maybe_shrink_outgoing_tracking_capacity(&mut self, target_len: usize, pending_empty: bool) {
485 if !pending_empty
486 || self.outgoing_pub.len() <= target_len
487 || !self.outbound_requests_drained()
488 {
489 return;
490 }
491
492 self.outgoing_pub.truncate(target_len);
493 self.outgoing_pub_notice.truncate(target_len);
494 self.outgoing_rel_notice.truncate(target_len);
495 self.outgoing_pub_ack = FixedBitSet::with_capacity(target_len);
496 self.outgoing_rel = FixedBitSet::with_capacity(target_len);
497 self.last_pkid = 0;
499 self.last_puback = 0;
500 }
501
    /// Brings the per-packet-id tracking storage in line with the current
    /// `max_outgoing_inflight`: grows it if it is too small, then attempts to
    /// shrink it (the shrink only happens when `pending_empty` is `true` and
    /// the other conditions in `maybe_shrink_outgoing_tracking_capacity` hold).
    pub(crate) fn reconcile_outgoing_tracking_capacity(&mut self, pending_empty: bool) {
        let target_len = Self::outgoing_tracking_len(self.max_outgoing_inflight);
        self.ensure_outgoing_tracking_capacity(target_len);
        self.maybe_shrink_outgoing_tracking_capacity(target_len, pending_empty);
    }
507
508 pub(crate) fn reset_connection_scoped_state(&mut self) {
509 self.incoming_topic_aliases.clear();
510 self.outgoing_topic_aliases.clear();
511 self.auto_outgoing_topic_aliases.clear();
512 self.next_auto_topic_alias = Some(1);
513 self.auto_topic_alias_lru.clear();
514 self.broker_topic_alias_max = 0;
515 }
516
    /// Returns a copy of the outgoing alias-to-topic table, for resolving
    /// aliases while requests are prepared for replay.
    pub(crate) fn replay_topic_aliases(&self) -> HashMap<u16, Bytes> {
        self.outgoing_topic_aliases.clone()
    }
520
521 pub(crate) fn prepare_publish_for_replay_with_aliases(
522 publish: &mut Publish,
523 topic_aliases: &mut HashMap<u16, Bytes>,
524 ) -> Result<(), PublishNoticeError> {
525 let Some(alias) = Self::publish_topic_alias(publish) else {
526 return Ok(());
527 };
528
529 if !publish.topic.is_empty() {
530 topic_aliases.insert(alias, publish.topic.clone());
531 Self::strip_publish_topic_alias(publish);
532 return Ok(());
533 }
534
535 if let Some(topic) = topic_aliases.get(&alias) {
536 topic.clone_into(&mut publish.topic);
537 topic_aliases.insert(alias, publish.topic.clone());
538 Self::strip_publish_topic_alias(publish);
539 return Ok(());
540 }
541
542 Err(PublishNoticeError::TopicAliasReplayUnavailable(alias))
543 }
544
545 pub(crate) fn prepare_request_for_replay_with_aliases(
546 request: &mut Request,
547 topic_aliases: &mut HashMap<u16, Bytes>,
548 ) -> Result<(), PublishNoticeError> {
549 if let Request::Publish(publish) = request {
550 Self::prepare_publish_for_replay_with_aliases(publish, topic_aliases)?;
551 }
552
553 Ok(())
554 }
555
    /// Drains every unfinished outbound request together with its notice and
    /// resets the per-connection tracking state.
    ///
    /// The returned list contains, in this order: unacked publishes, a `PubRel`
    /// for every unreleased packet id, tracked subscribes, then tracked
    /// unsubscribes. Topic-alias state is not touched here.
    ///
    /// # Panics
    ///
    /// Panics if `last_puback + 1` exceeds the length of the publish tracking
    /// storage (from `split_at_mut`).
    pub(crate) fn clean_with_notices(&mut self) -> Vec<(Request, Option<TrackedNoticeTx>)> {
        let mut pending = Vec::with_capacity(self.clean_pending_capacity());
        // Split both parallel vectors at the same index so publishes and their
        // notices can be walked in lockstep.
        let (first_half, second_half) = self
            .outgoing_pub
            .split_at_mut(self.last_puback as usize + 1);
        let (notice_first_half, notice_second_half) = self
            .outgoing_pub_notice
            .split_at_mut(self.last_puback as usize + 1);

        // Walk the slots after `last_puback` first, then wrap around to the
        // start. NOTE(review): presumably this keeps the original send order —
        // confirm against packet id allocation.
        for (publish, notice) in second_half
            .iter_mut()
            .zip(notice_second_half.iter_mut())
            .chain(first_half.iter_mut().zip(notice_first_half.iter_mut()))
        {
            if let Some(publish) = publish.take() {
                let request = Request::Publish(publish);
                pending.push((request, notice.take().map(TrackedNoticeTx::Publish)));
            } else {
                // No publish in this slot: discard any notice left behind.
                _ = notice.take();
            }
        }

        // Every packet id still waiting for `PubComp` is replayed as a `PubRel`.
        for pkid in self.outgoing_rel.ones() {
            let pkid = u16::try_from(pkid).expect("fixedbitset index always fits in u16");
            let request = Request::PubRel(PubRel::new(pkid, None));
            pending.push((
                request,
                self.outgoing_rel_notice[pkid as usize]
                    .take()
                    .map(TrackedNoticeTx::Publish),
            ));
        }
        self.outgoing_rel.clear();
        self.outgoing_pub_ack.clear();

        // Tracked requests are handed back with the packet id they were
        // tracked under.
        for (pkid, (mut subscribe, notice)) in std::mem::take(&mut self.tracked_subscribe) {
            subscribe.pkid = pkid;
            pending.push((
                Request::Subscribe(subscribe),
                Some(TrackedNoticeTx::Subscribe(notice)),
            ));
        }
        for (pkid, (mut unsubscribe, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
            unsubscribe.pkid = pkid;
            pending.push((
                Request::Unsubscribe(unsubscribe),
                Some(TrackedNoticeTx::Unsubscribe(notice)),
            ));
        }

        // Forget incoming QoS 2 publishes that were waiting for their `PubRel`.
        self.incoming_pub.clear();

        self.await_pingresp = false;
        self.collision_ping_count = 0;
        self.inflight = 0;
        pending
    }
615
616 pub fn clean(&mut self) -> Vec<Request> {
624 let mut replay_topic_aliases = self.replay_topic_aliases();
625 let mut pending = Vec::with_capacity(self.clean_pending_capacity());
626
627 for (mut request, _) in self.clean_with_notices() {
628 if Self::prepare_request_for_replay_with_aliases(
629 &mut request,
630 &mut replay_topic_aliases,
631 )
632 .is_ok()
633 {
634 pending.push(request);
635 }
636 }
637
638 self.reset_connection_scoped_state();
639 pending
640 }
641
    /// Returns the number of outgoing publishes currently counted as in flight.
    pub const fn inflight(&self) -> u16 {
        self.inflight
    }
645
    /// Returns the number of subscribe requests still being tracked.
    pub fn tracked_subscribe_len(&self) -> usize {
        self.tracked_subscribe.len()
    }
649
    /// Returns the number of unsubscribe requests still being tracked.
    pub fn tracked_unsubscribe_len(&self) -> usize {
        self.tracked_unsubscribe.len()
    }
653
654 pub fn tracked_requests_is_empty(&self) -> bool {
655 self.tracked_subscribe.is_empty() && self.tracked_unsubscribe.is_empty()
656 }
657
658 pub fn drain_tracked_requests_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
659 let mut drained = 0;
660 for (_, (_, notice)) in std::mem::take(&mut self.tracked_subscribe) {
661 drained += 1;
662 notice.error(reason.subscribe_error());
663 }
664 for (_, (_, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
665 drained += 1;
666 notice.error(reason.unsubscribe_error());
667 }
668
669 drained
670 }
671
672 pub(crate) fn fail_pending_notices(&mut self) {
673 for notice in &mut self.outgoing_pub_notice {
674 if let Some(tx) = notice.take() {
675 tx.error(PublishNoticeError::SessionReset);
676 }
677 }
678
679 for notice in &mut self.outgoing_rel_notice {
680 if let Some(tx) = notice.take() {
681 tx.error(PublishNoticeError::SessionReset);
682 }
683 }
684
685 if let Some(tx) = self.collision_notice.take() {
686 tx.error(PublishNoticeError::SessionReset);
687 }
688 self.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
689 self.clear_collision();
690 }
691
692 pub(crate) fn fail_auth_exchange_due_to_session_reset(&mut self) {
693 self.fail_auth_exchange(
694 AuthNoticeError::SessionReset,
695 AuthError::Failed("authentication exchange was reset with the session".to_owned()),
696 );
697 }
698
699 pub(crate) fn fail_reauth_exchange_due_to_session_reset(&mut self) {
700 let Some((method, notice_error)) = self
701 .auth
702 .reset_reauth(AuthNoticeError::SessionReset, &mut self.events)
703 else {
704 return;
705 };
706
707 if let Some(authenticator) = self.authenticator.clone() {
708 authenticator.lock().unwrap().failure(
709 AuthContext {
710 kind: AuthExchangeKind::Reauthentication,
711 method: &method,
712 },
713 AuthError::Failed(notice_error.to_string()),
714 );
715 }
716 }
717
718 pub(crate) fn fail_auth_exchange_due_to_connection_closed(&mut self) {
719 self.fail_auth_exchange(
720 AuthNoticeError::ConnectionClosed,
721 AuthError::Failed("connection closed before authentication completed".to_owned()),
722 );
723 }
724
725 pub(crate) fn fail_auth_exchange_due_to_client_disconnect(&mut self) {
726 self.fail_auth_exchange(
727 AuthNoticeError::ConnectionClosed,
728 AuthError::Failed("authentication aborted by client disconnect".to_owned()),
729 );
730 }
731
732 pub fn handle_outgoing_packet(
740 &mut self,
741 request: Request,
742 ) -> Result<Option<Packet>, StateError> {
743 let (packet, flush_notice) = self.handle_outgoing_packet_with_notice(request, None)?;
744 if let Some(tx) = flush_notice {
745 tx.success(PublishResult::Qos0Flushed);
746 }
747 Ok(packet)
748 }
749
    /// Dispatches an outgoing request to its handler, passing along `notice`
    /// only when its kind matches the request kind.
    ///
    /// Returns the packet to write (if any) plus an optional publish notice
    /// that the caller must complete itself; only the publish and `PubRel`
    /// handlers can return one.
    /// On success `last_outgoing` is updated.
    ///
    /// # Errors
    ///
    /// Propagates the error of the selected handler; `last_outgoing` is not
    /// updated in that case.
    ///
    /// # Panics
    ///
    /// Panics on `Request::Disconnect` / `Request::DisconnectWithTimeout`
    /// (expected to be handled by the event loop) and on any request variant
    /// not listed here.
    pub(crate) fn handle_outgoing_packet_with_notice(
        &mut self,
        request: Request,
        notice: Option<TrackedNoticeTx>,
    ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
        // Each arm spells out every notice kind instead of using `_`, so
        // adding a `TrackedNoticeTx` variant forces a decision here.
        let result = match request {
            Request::Publish(publish) => {
                let publish_notice = match notice {
                    Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
                    Some(
                        TrackedNoticeTx::Subscribe(_)
                        | TrackedNoticeTx::Unsubscribe(_)
                        | TrackedNoticeTx::Auth(_),
                    )
                    | None => None,
                };
                self.outgoing_publish_with_notice(publish, publish_notice)?
            }
            Request::PubRel(pubrel) => {
                let publish_notice = match notice {
                    Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
                    Some(
                        TrackedNoticeTx::Subscribe(_)
                        | TrackedNoticeTx::Unsubscribe(_)
                        | TrackedNoticeTx::Auth(_),
                    )
                    | None => None,
                };
                self.outgoing_pubrel_with_notice(pubrel, publish_notice)
            }
            Request::Subscribe(subscribe) => {
                let request_notice = match notice {
                    Some(TrackedNoticeTx::Subscribe(notice)) => Some(notice),
                    Some(
                        TrackedNoticeTx::Publish(_)
                        | TrackedNoticeTx::Unsubscribe(_)
                        | TrackedNoticeTx::Auth(_),
                    )
                    | None => None,
                };
                (self.outgoing_subscribe(subscribe, request_notice)?, None)
            }
            Request::Unsubscribe(unsubscribe) => {
                let request_notice = match notice {
                    Some(TrackedNoticeTx::Unsubscribe(notice)) => Some(notice),
                    Some(
                        TrackedNoticeTx::Publish(_)
                        | TrackedNoticeTx::Subscribe(_)
                        | TrackedNoticeTx::Auth(_),
                    )
                    | None => None,
                };
                (
                    Some(self.outgoing_unsubscribe(unsubscribe, request_notice)?),
                    None,
                )
            }
            Request::PingReq => (self.outgoing_ping()?, None),
            Request::Disconnect(_) | Request::DisconnectWithTimeout(_, _) => {
                unreachable!("graceful disconnect requests are handled by the event loop")
            }
            Request::DisconnectNow(disconnect) => {
                (Some(self.outgoing_disconnect(disconnect)), None)
            }
            Request::PubAck(puback) => (Some(self.outgoing_puback(puback)), None),
            Request::PubRec(pubrec) => (Some(self.outgoing_pubrec(pubrec)), None),
            Request::Auth(auth) => {
                let auth_notice = match notice {
                    Some(TrackedNoticeTx::Auth(notice)) => Some(notice),
                    Some(
                        TrackedNoticeTx::Publish(_)
                        | TrackedNoticeTx::Subscribe(_)
                        | TrackedNoticeTx::Unsubscribe(_),
                    )
                    | None => None,
                };
                (Some(self.outgoing_auth(auth, auth_notice)?), None)
            }
            _ => unimplemented!(),
        };

        // Only reached when the handler succeeded.
        self.last_outgoing = Instant::now();
        Ok(result)
    }
834
    /// Dispatches an incoming packet to its handler and records it as an
    /// `Event::Incoming`.
    ///
    /// Returns the packet to send in response, if any. On success
    /// `last_incoming` is updated.
    ///
    /// # Errors
    ///
    /// Propagates the handler's error, or returns `StateError::WrongPacket`
    /// for packet types without a handler. The incoming event is still
    /// recorded before the error is returned.
    pub fn handle_incoming_packet(
        &mut self,
        mut packet: Incoming,
    ) -> Result<Option<Packet>, StateError> {
        // Remember where the queue ended so the incoming event can be placed
        // ahead of any events the handler pushes.
        let events_len_before = self.events.len();
        let outgoing = match &mut packet {
            Incoming::PingResp(_) => Ok(self.handle_incoming_pingresp()),
            Incoming::Publish(publish) => self.handle_incoming_publish(publish),
            Incoming::SubAck(suback) => Ok(self.handle_incoming_suback(suback)),
            Incoming::UnsubAck(unsuback) => Ok(self.handle_incoming_unsuback(unsuback)),
            Incoming::PubAck(puback) => self.handle_incoming_puback(puback),
            Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec),
            Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel),
            Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp),
            Incoming::ConnAck(connack) => self.handle_incoming_connack(connack),
            Incoming::Disconnect(disconn) => Self::handle_incoming_disconn(disconn),
            Incoming::Auth(auth) => self.handle_incoming_auth(auth),
            _ => {
                error!("Invalid incoming packet = {packet:?}");
                Err(StateError::WrongPacket)
            }
        };

        // A publish answered with a `Disconnect` (protocol error) is not
        // surfaced to the application as an incoming event.
        let skip_incoming_event = matches!(
            (&packet, &outgoing),
            (Incoming::Publish(_), Ok(Some(Packet::Disconnect(_))))
        );

        if !skip_incoming_event {
            self.events
                .insert(events_len_before, Event::Incoming(packet));
        }

        let outgoing = outgoing?;
        // Only reached when the handler succeeded.
        self.last_incoming = Instant::now();
        Ok(outgoing)
    }
883
884 pub fn handle_protocol_error(&mut self) -> Result<Option<Packet>, StateError> {
891 let disconnect = Disconnect::new(DisconnectReasonCode::ProtocolError);
893 Ok(Some(self.outgoing_disconnect(disconnect)))
894 }
895
896 pub fn clear_collision(&mut self) {
897 self.collision = None;
898 self.collision_notice = None;
899 self.collision_ping_count = 0;
900 }
901
902 fn handle_incoming_suback(&mut self, suback: &SubAck) -> Option<Packet> {
903 for reason in &suback.return_codes {
904 match reason {
905 SubscribeReasonCode::Success(qos) => {
906 debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos);
907 }
908 _ => {
909 warn!("SubAck Pkid = {:?}, Reason = {:?}", suback.pkid, reason);
910 }
911 }
912 }
913 if let Some((_, notice)) = self.tracked_subscribe.remove(&suback.pkid) {
914 notice.success(suback.clone());
915 }
916 None
917 }
918
919 fn handle_incoming_unsuback(&mut self, unsuback: &UnsubAck) -> Option<Packet> {
920 for reason in &unsuback.reasons {
921 if reason != &UnsubAckReason::Success {
922 warn!("UnsubAck Pkid = {:?}, Reason = {:?}", unsuback.pkid, reason);
923 }
924 }
925 if let Some((_, notice)) = self.tracked_unsubscribe.remove(&unsuback.pkid) {
926 notice.success(unsuback.clone());
927 }
928 None
929 }
930
931 fn handle_incoming_connack(&mut self, connack: &ConnAck) -> Result<Option<Packet>, StateError> {
932 if connack.code != ConnectReturnCode::Success {
933 return Err(StateError::ConnFail {
934 reason: connack.code,
935 });
936 }
937
938 self.auth.validate_successful_connack(connack)?;
939 self.reset_connection_scoped_state();
940 self.auth.complete_initial_connack(&mut self.events);
941
942 if let Some(props) = &connack.properties
943 && let Some(topic_alias_max) = props.topic_alias_max
944 {
945 self.broker_topic_alias_max = topic_alias_max;
946 }
947
948 if let Some(props) = &connack.properties
949 && let Some(max_inflight) = props.receive_max
950 {
951 self.max_outgoing_inflight = max_inflight.min(self.max_outgoing_inflight_upper_limit);
952 self.reconcile_outgoing_tracking_capacity(false);
955 }
956 Ok(None)
957 }
958
959 fn handle_incoming_disconn(disconn: &Disconnect) -> Result<Option<Packet>, StateError> {
960 let reason_code = disconn.reason_code;
961 let reason_string = disconn
962 .properties
963 .as_ref()
964 .and_then(|props| props.reason_string.clone());
965 Err(StateError::ServerDisconnect {
966 reason_code,
967 reason_string,
968 })
969 }
970
971 fn handle_incoming_publish(
974 &mut self,
975 publish: &mut Publish,
976 ) -> Result<Option<Packet>, StateError> {
977 let qos = publish.qos;
978
979 let topic_alias = publish
980 .properties
981 .as_ref()
982 .and_then(|props| props.topic_alias);
983
984 if !publish.topic.is_empty() {
985 if let Some(alias) = topic_alias {
986 self.incoming_topic_aliases
987 .insert(alias, publish.topic.clone());
988 }
989 } else if let Some(alias) = topic_alias
990 && let Some(topic) = self.incoming_topic_aliases.get(&alias)
991 {
992 topic.clone_into(&mut publish.topic);
993 } else if topic_alias.is_some() {
994 return self.handle_protocol_error();
995 }
996
997 match qos {
998 QoS::AtMostOnce => Ok(None),
999 QoS::AtLeastOnce => {
1000 if !self.manual_acks {
1001 let puback = PubAck::new(publish.pkid, None);
1002 return Ok(Some(self.outgoing_puback(puback)));
1003 }
1004 Ok(None)
1005 }
1006 QoS::ExactlyOnce => {
1007 let pkid = publish.pkid;
1008 self.incoming_pub.insert(pkid as usize);
1009
1010 if !self.manual_acks {
1011 let pubrec = PubRec::new(pkid, None);
1012 return Ok(Some(self.outgoing_pubrec(pubrec)));
1013 }
1014 Ok(None)
1015 }
1016 }
1017 }
1018
1019 fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<Option<Packet>, StateError> {
1020 let publish = self
1021 .outgoing_pub
1022 .get_mut(puback.pkid as usize)
1023 .ok_or(StateError::Unsolicited(puback.pkid))?;
1024
1025 if publish.take().is_none() {
1026 error!("Unsolicited puback packet: {:?}", puback.pkid);
1027 return Err(StateError::Unsolicited(puback.pkid));
1028 }
1029 self.mark_outgoing_packet_id_complete(puback.pkid);
1030
1031 let notice = self.outgoing_pub_notice[puback.pkid as usize].take();
1032 self.inflight -= 1;
1033
1034 if puback.reason != PubAckReason::Success
1035 && puback.reason != PubAckReason::NoMatchingSubscribers
1036 {
1037 warn!(
1038 "PubAck Pkid = {:?}, reason: {:?}",
1039 puback.pkid, puback.reason
1040 );
1041 }
1042 if let Some(tx) = notice {
1043 tx.success(PublishResult::Qos1(puback.clone()));
1044 }
1045
1046 Ok(self.replay_collision_publish(puback.pkid))
1047 }
1048
1049 fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<Option<Packet>, StateError> {
1050 let publish = self
1051 .outgoing_pub
1052 .get_mut(pubrec.pkid as usize)
1053 .ok_or(StateError::Unsolicited(pubrec.pkid))?;
1054
1055 if publish.take().is_none() {
1056 error!("Unsolicited pubrec packet: {:?}", pubrec.pkid);
1057 return Err(StateError::Unsolicited(pubrec.pkid));
1058 }
1059
1060 let notice = self.outgoing_pub_notice[pubrec.pkid as usize].take();
1061 if pubrec.reason != PubRecReason::Success
1062 && pubrec.reason != PubRecReason::NoMatchingSubscribers
1063 {
1064 warn!(
1065 "PubRec Pkid = {:?}, reason: {:?}",
1066 pubrec.pkid, pubrec.reason
1067 );
1068 if let Some(tx) = notice {
1069 tx.success(PublishResult::Qos2PubRecRejected(pubrec.clone()));
1070 }
1071 self.mark_outgoing_packet_id_complete(pubrec.pkid);
1072 self.inflight -= 1;
1073 return Ok(self.replay_collision_publish(pubrec.pkid));
1074 }
1075
1076 self.outgoing_rel.insert(pubrec.pkid as usize);
1078 self.outgoing_rel_notice[pubrec.pkid as usize] = notice;
1079 let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
1080 self.events.push_back(event);
1081
1082 Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None))))
1083 }
1084
1085 fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
1086 if !self.incoming_pub.contains(pubrel.pkid as usize) {
1087 error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
1088 return Err(StateError::Unsolicited(pubrel.pkid));
1089 }
1090 self.incoming_pub.set(pubrel.pkid as usize, false);
1091
1092 if pubrel.reason != PubRelReason::Success {
1093 warn!(
1094 "PubRel Pkid = {:?}, reason: {:?}",
1095 pubrel.pkid, pubrel.reason
1096 );
1097 return Ok(None);
1098 }
1099
1100 let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
1101 self.events.push_back(event);
1102
1103 Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None))))
1104 }
1105
1106 fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
1107 if !self.outgoing_rel.contains(pubcomp.pkid as usize) {
1108 error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
1109 return Err(StateError::Unsolicited(pubcomp.pkid));
1110 }
1111 self.outgoing_rel.set(pubcomp.pkid as usize, false);
1112 let notice = self.outgoing_rel_notice[pubcomp.pkid as usize].take();
1113 self.mark_outgoing_packet_id_complete(pubcomp.pkid);
1114 self.inflight -= 1;
1115
1116 if pubcomp.reason != PubCompReason::Success {
1117 warn!(
1118 "PubComp Pkid = {:?}, reason: {:?}",
1119 pubcomp.pkid, pubcomp.reason
1120 );
1121 }
1122 if let Some(tx) = notice {
1123 tx.success(PublishResult::Qos2Completed(pubcomp.clone()));
1124 }
1125
1126 Ok(self.replay_collision_publish(pubcomp.pkid))
1127 }
1128
1129 fn replay_collision_publish(&mut self, pkid: u16) -> Option<Packet> {
1130 self.check_collision(pkid).map(|(publish, notice)| {
1131 let pkid = publish.pkid;
1132 let replay_publish = self.publish_for_replay_tracking(&publish);
1133 self.outgoing_pub[pkid as usize] = Some(replay_publish);
1134 self.outgoing_pub_notice[pkid as usize] = notice;
1135 self.inflight += 1;
1136 self.record_outgoing_topic_alias(&publish);
1137
1138 let event = Event::Outgoing(Outgoing::Publish(pkid));
1139 self.events.push_back(event);
1140 self.collision_ping_count = 0;
1141
1142 Packet::Publish(publish)
1143 })
1144 }
1145
/// Handles an incoming PINGRESP by clearing the `await_pingresp` flag.
///
/// Always returns `None`.
const fn handle_incoming_pingresp(&mut self) -> Option<Packet> {
    self.await_pingresp = false;
    None
}
1150
1151 fn handle_incoming_auth(&mut self, auth: &Auth) -> Result<Option<Packet>, StateError> {
1152 let effect = match self.auth.incoming_auth(auth, &mut self.events) {
1153 Ok(effect) => effect,
1154 Err(err @ StateError::Deserialization(mqttbytes::Error::ProtocolError)) => {
1155 self.fail_auth_exchange(
1156 AuthNoticeError::ProtocolError,
1157 AuthError::Failed("authentication protocol error".to_owned()),
1158 );
1159 return Err(err);
1160 }
1161 Err(err) => return Err(err),
1162 };
1163
1164 match effect {
1165 IncomingAuthEffect::Success { kind, method } => {
1166 if let Some(authenticator) = self.authenticator.clone() {
1167 let context = AuthContext {
1168 kind,
1169 method: &method,
1170 };
1171 let auth_result = authenticator
1172 .lock()
1173 .unwrap()
1174 .success(context, auth.properties.clone());
1175 if let Err(err) = auth_result {
1176 return Err(self.fail_authenticator(&err));
1177 }
1178 }
1179 self.auth.complete_success(kind, method, &mut self.events);
1180 Ok(None)
1181 }
1182 IncomingAuthEffect::Continue { kind } => {
1183 let authenticator = self
1184 .authenticator
1185 .clone()
1186 .ok_or(StateError::AuthenticatorNotSet)?;
1187 let method = auth
1188 .properties
1189 .as_ref()
1190 .and_then(|props| props.method.as_deref())
1191 .unwrap_or_default();
1192 let context = AuthContext { kind, method };
1193 let continue_result = authenticator
1194 .lock()
1195 .unwrap()
1196 .continue_auth(context, auth.properties.clone());
1197 let action = match continue_result {
1198 Ok(action) => action,
1199 Err(err) => return Err(self.fail_authenticator(&err)),
1200 };
1201
1202 let out_auth_props = match action.into_continue_properties() {
1203 Ok(properties) => properties,
1204 Err(err) => return Err(self.fail_authenticator(&err)),
1205 };
1206 let client_auth = self.auth.outgoing_continue(out_auth_props)?;
1207 Ok(Some(self.outgoing_auth_packet(client_auth)))
1208 }
1209 }
1210 }
1211
1212 fn fail_authenticator(&mut self, error: &AuthError) -> StateError {
1213 self.fail_auth_exchange(
1214 AuthNoticeError::AuthenticationFailed(error.to_string()),
1215 error.clone(),
1216 );
1217 StateError::AuthError(error.to_string())
1218 }
1219
1220 fn fail_auth_exchange(&mut self, notice_error: AuthNoticeError, callback_error: AuthError) {
1221 if let Some((kind, method)) = self.auth.active_exchange()
1222 && let Some(authenticator) = self.authenticator.clone()
1223 {
1224 authenticator.lock().unwrap().failure(
1225 AuthContext {
1226 kind,
1227 method: &method,
1228 },
1229 callback_error,
1230 );
1231 }
1232 self.auth.reset(notice_error, &mut self.events);
1233 }
1234
/// Test-only convenience wrapper around `outgoing_publish_with_notice` that
/// passes no notice and resolves any returned flush notice as
/// `PublishResult::Qos0Flushed`.
#[cfg(test)]
fn outgoing_publish(&mut self, publish: Publish) -> Result<Option<Packet>, StateError> {
    self.outgoing_publish_with_notice(publish, None)
        .map(|(packet, flush_notice)| {
            if let Some(tx) = flush_notice {
                tx.success(PublishResult::Qos0Flushed);
            }
            packet
        })
}
1245
1246 fn outgoing_publish_with_notice(
1247 &mut self,
1248 mut publish: Publish,
1249 notice: Option<PublishNoticeTx>,
1250 ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
1251 let mut notice = notice;
1252 let auto_topic_alias_action = self.apply_auto_topic_alias(&mut publish);
1253 self.validate_outgoing_topic_alias(&publish)?;
1254
1255 if publish.qos != QoS::AtMostOnce {
1256 if publish.pkid == 0 {
1257 publish.pkid = self.next_pkid();
1258 }
1259
1260 let pkid = publish.pkid;
1261 if self
1262 .outgoing_pub
1263 .get(publish.pkid as usize)
1264 .ok_or(StateError::Unsolicited(publish.pkid))?
1265 .is_some()
1266 {
1267 info!("Collision on packet id = {:?}", publish.pkid);
1268 if let Some(action) = auto_topic_alias_action {
1269 action.restore_for_replay(&mut publish);
1270 }
1271 self.collision = Some(publish);
1272 self.collision_notice = notice.take();
1273 let event = Event::Outgoing(Outgoing::AwaitAck(pkid));
1274 self.events.push_back(event);
1275 return Ok((None, None));
1276 }
1277
1278 let replay_publish = self.publish_for_replay_tracking(&publish);
1281 self.outgoing_pub[pkid as usize] = Some(replay_publish);
1282 self.outgoing_pub_notice[pkid as usize] = notice.take();
1283 self.outgoing_pub_ack.set(pkid as usize, false);
1284 self.inflight += 1;
1285 }
1286
1287 debug!(
1288 "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}",
1289 String::from_utf8_lossy(&publish.topic),
1290 publish.pkid,
1291 publish.payload.len()
1292 );
1293
1294 let pkid = publish.pkid;
1295 if let Some(pending_auto_topic_alias) = auto_topic_alias_action
1296 .as_ref()
1297 .and_then(AutoTopicAliasAction::pending_alias)
1298 {
1299 self.record_auto_topic_alias(pending_auto_topic_alias.clone());
1300 } else if let Some(alias) = auto_topic_alias_action
1301 .as_ref()
1302 .and_then(AutoTopicAliasAction::existing_alias)
1303 {
1304 self.record_auto_topic_alias_use(alias);
1305 }
1306 self.record_outgoing_topic_alias(&publish);
1307
1308 let event = Event::Outgoing(Outgoing::Publish(pkid));
1309 self.events.push_back(event);
1310
1311 if publish.qos == QoS::AtMostOnce {
1312 Ok((Some(Packet::Publish(publish)), notice.take()))
1313 } else {
1314 Ok((Some(Packet::Publish(publish)), None))
1315 }
1316 }
1317
1318 fn outgoing_pubrel_with_notice(
1319 &mut self,
1320 pubrel: PubRel,
1321 notice: Option<PublishNoticeTx>,
1322 ) -> (Option<Packet>, Option<PublishNoticeTx>) {
1323 let pubrel = self.save_pubrel_with_notice(pubrel, notice);
1324
1325 debug!("Pubrel. Pkid = {}", pubrel.pkid);
1326
1327 let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid));
1328 self.events.push_back(event);
1329
1330 (Some(Packet::PubRel(PubRel::new(pubrel.pkid, None))), None)
1331 }
1332
1333 fn outgoing_puback(&mut self, puback: PubAck) -> Packet {
1334 let pkid = puback.pkid;
1335 let event = Event::Outgoing(Outgoing::PubAck(pkid));
1336 self.events.push_back(event);
1337
1338 Packet::PubAck(puback)
1339 }
1340
1341 fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Packet {
1342 let pkid = pubrec.pkid;
1343 let event = Event::Outgoing(Outgoing::PubRec(pkid));
1344 self.events.push_back(event);
1345
1346 Packet::PubRec(pubrec)
1347 }
1348
1349 fn outgoing_ping(&mut self) -> Result<Option<Packet>, StateError> {
1353 let elapsed_in = self.last_incoming.elapsed();
1354 let elapsed_out = self.last_outgoing.elapsed();
1355
1356 if self.collision.is_some() {
1357 self.collision_ping_count += 1;
1358 if self.collision_ping_count >= 2 {
1359 return Err(StateError::CollisionTimeout);
1360 }
1361 }
1362
1363 if self.await_pingresp {
1365 return Err(StateError::AwaitPingResp);
1366 }
1367
1368 self.await_pingresp = true;
1369
1370 debug!(
1371 "Pingreq, last incoming packet before {elapsed_in:?}, last outgoing request before {elapsed_out:?}",
1372 );
1373
1374 let event = Event::Outgoing(Outgoing::PingReq);
1375 self.events.push_back(event);
1376
1377 Ok(Some(Packet::PingReq(PingReq)))
1378 }
1379
1380 fn outgoing_subscribe(
1381 &mut self,
1382 mut subscription: Subscribe,
1383 notice: Option<SubscribeNoticeTx>,
1384 ) -> Result<Option<Packet>, StateError> {
1385 if subscription.filters.is_empty() {
1386 return Err(StateError::EmptySubscription);
1387 }
1388
1389 let pkid = self.next_control_pkid()?;
1390 subscription.pkid = pkid;
1391
1392 debug!(
1393 "Subscribe. Topics = {:?}, Pkid = {:?}",
1394 subscription.filters, subscription.pkid
1395 );
1396
1397 let pkid = subscription.pkid;
1398 let event = Event::Outgoing(Outgoing::Subscribe(pkid));
1399 self.events.push_back(event);
1400 if let Some(notice) = notice {
1401 self.tracked_subscribe
1402 .insert(subscription.pkid, (subscription.clone(), notice));
1403 }
1404
1405 Ok(Some(Packet::Subscribe(subscription)))
1406 }
1407
1408 fn outgoing_unsubscribe(
1409 &mut self,
1410 mut unsub: Unsubscribe,
1411 notice: Option<UnsubscribeNoticeTx>,
1412 ) -> Result<Packet, StateError> {
1413 let pkid = self.next_control_pkid()?;
1414 unsub.pkid = pkid;
1415
1416 debug!(
1417 "Unsubscribe. Topics = {:?}, Pkid = {:?}",
1418 unsub.filters, unsub.pkid
1419 );
1420
1421 let pkid = unsub.pkid;
1422 let event = Event::Outgoing(Outgoing::Unsubscribe(pkid));
1423 self.events.push_back(event);
1424 if let Some(notice) = notice {
1425 self.tracked_unsubscribe
1426 .insert(unsub.pkid, (unsub.clone(), notice));
1427 }
1428
1429 Ok(Packet::Unsubscribe(unsub))
1430 }
1431
1432 fn outgoing_disconnect(&mut self, disconnect: Disconnect) -> Packet {
1433 self.fail_auth_exchange_due_to_client_disconnect();
1434 let reason = disconnect.reason_code;
1435 debug!("Disconnect with {reason:?}");
1436 let event = Event::Outgoing(Outgoing::Disconnect);
1437 self.events.push_back(event);
1438
1439 Packet::Disconnect(disconnect)
1440 }
1441
1442 fn outgoing_auth(
1443 &mut self,
1444 mut auth: Auth,
1445 mut notice: Option<crate::notice::AuthNoticeTx>,
1446 ) -> Result<Packet, StateError> {
1447 let method = match self.auth.reauth_method() {
1448 Ok(method) => method.to_owned(),
1449 Err(err) => {
1450 if let Some(notice) = notice.take() {
1451 notice.error(err.clone());
1452 }
1453 return Err(StateError::AuthError(err.to_string()));
1454 }
1455 };
1456
1457 if let Some(authenticator) = self.authenticator.clone() {
1458 let context = AuthContext {
1459 kind: AuthExchangeKind::Reauthentication,
1460 method: &method,
1461 };
1462 let start_result = authenticator.lock().unwrap().start(context);
1463 match start_result {
1464 Ok(Some(properties)) if auth.properties.is_none() => {
1465 auth.properties = Some(properties);
1466 }
1467 Ok(_) => {}
1468 Err(err) => {
1469 if let Some(notice) = notice.take() {
1470 notice.error(AuthNoticeError::AuthenticationFailed(err.to_string()));
1471 }
1472 return Err(StateError::AuthError(err.to_string()));
1473 }
1474 }
1475 }
1476 let auth = self
1477 .auth
1478 .begin_reauth(auth.properties, notice, &mut self.events)?;
1479 Ok(self.outgoing_auth_packet(auth))
1480 }
1481
1482 fn outgoing_auth_packet(&mut self, auth: Auth) -> Packet {
1483 let props = auth
1484 .properties
1485 .as_ref()
1486 .expect("AUTH packets created by state always contain properties");
1487 debug!(
1488 "Auth packet sent. Auth Method: {:?}. Auth Data: {:?}",
1489 props.method, props.data
1490 );
1491 let event = Event::Outgoing(Outgoing::Auth);
1492 self.events.push_back(event);
1493 Packet::Auth(auth)
1494 }
1495
1496 fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option<PublishNoticeTx>)> {
1497 if let Some(publish) = &self.collision
1498 && publish.pkid == pkid
1499 {
1500 return self
1501 .collision
1502 .take()
1503 .map(|publish| (publish, self.collision_notice.take()));
1504 }
1505
1506 None
1507 }
1508
1509 fn save_pubrel_with_notice(
1510 &mut self,
1511 mut pubrel: PubRel,
1512 notice: Option<PublishNoticeTx>,
1513 ) -> PubRel {
1514 let pubrel = match pubrel.pkid {
1515 0 => {
1517 pubrel.pkid = self.next_pkid();
1518 pubrel
1519 }
1520 _ => pubrel,
1521 };
1522
1523 self.outgoing_rel.insert(pubrel.pkid as usize);
1524 self.outgoing_rel_notice[pubrel.pkid as usize] = notice;
1525 self.inflight += 1;
1526 pubrel
1527 }
1528
1529 fn mark_outgoing_packet_id_complete(&mut self, pkid: u16) {
1530 self.outgoing_pub_ack.set(pkid as usize, true);
1531 self.advance_last_puback_frontier();
1532 }
1533
1534 fn advance_last_puback_frontier(&mut self) {
1535 let mut next = self.next_puback_boundary_pkid(self.last_puback);
1536 while next != 0 && self.outgoing_pub_ack.contains(next as usize) {
1537 self.outgoing_pub_ack.set(next as usize, false);
1538 self.last_puback = next;
1539 next = self.next_puback_boundary_pkid(self.last_puback);
1540 }
1541 }
1542
1543 const fn next_puback_boundary_pkid(&self, pkid: u16) -> u16 {
1544 if self.max_outgoing_inflight == 0 {
1545 return 0;
1546 }
1547
1548 if pkid >= self.max_outgoing_inflight {
1549 1
1550 } else {
1551 pkid + 1
1552 }
1553 }
1554
1555 fn next_publish_pkid(&self) -> Option<u16> {
1559 let mut pkid = self.next_publish_pkid_after(self.last_pkid);
1560 for _ in 0..usize::from(self.max_outgoing_inflight) {
1561 if !self.packet_identifier_in_use(pkid) {
1562 return Some(pkid);
1563 }
1564 pkid = self.next_publish_pkid_after(pkid);
1565 }
1566
1567 None
1568 }
1569
1570 fn next_pkid(&mut self) -> u16 {
1571 let next_pkid = self
1572 .next_publish_pkid()
1573 .unwrap_or_else(|| self.next_publish_pkid_after(self.last_pkid));
1574
1575 if next_pkid == self.max_outgoing_inflight {
1580 self.last_pkid = 0;
1581 return next_pkid;
1582 }
1583
1584 self.last_pkid = next_pkid;
1585 next_pkid
1586 }
1587
1588 fn next_control_pkid(&mut self) -> Result<u16, StateError> {
1589 for offset in 1..=u16::MAX {
1590 let pkid = self.last_pkid.wrapping_add(offset);
1591 if pkid != 0 && !self.packet_identifier_in_use(pkid) {
1592 self.last_pkid = pkid;
1593 return Ok(pkid);
1594 }
1595 }
1596
1597 Err(StateError::InvalidState)
1598 }
1599
1600 fn publish_topic_alias(publish: &Publish) -> Option<u16> {
1601 publish
1602 .properties
1603 .as_ref()
1604 .and_then(|props| props.topic_alias)
1605 }
1606
1607 fn set_publish_topic_alias(publish: &mut Publish, alias: u16) {
1608 publish
1609 .properties
1610 .get_or_insert_with(PublishProperties::default)
1611 .topic_alias = Some(alias);
1612 }
1613
1614 fn apply_auto_topic_alias(&self, publish: &mut Publish) -> Option<AutoTopicAliasAction> {
1615 if !self.auto_topic_aliases
1616 || self.broker_topic_alias_max == 0
1617 || publish.topic.is_empty()
1618 || Self::publish_topic_alias(publish).is_some()
1619 {
1620 return None;
1621 }
1622
1623 if let Some(alias) = self
1624 .auto_outgoing_topic_aliases
1625 .get(&publish.topic)
1626 .copied()
1627 {
1628 let original_topic = publish.topic.clone();
1629 Self::set_publish_topic_alias(publish, alias);
1630 publish.topic = Bytes::new();
1631 return Some(AutoTopicAliasAction::Existing {
1632 original_topic,
1633 alias,
1634 });
1635 }
1636
1637 let (alias, previous_topic) = self.next_auto_topic_alias_assignment()?;
1638
1639 let pending = PendingAutoTopicAlias {
1640 topic: publish.topic.clone(),
1641 alias,
1642 previous_topic,
1643 };
1644 Self::set_publish_topic_alias(publish, alias);
1645 Some(AutoTopicAliasAction::New(pending))
1646 }
1647
1648 fn next_auto_topic_alias_assignment(&self) -> Option<(u16, Option<Bytes>)> {
1649 if let Some(alias) = self.next_available_auto_topic_alias() {
1650 return Some((alias, None));
1651 }
1652
1653 if self.auto_topic_alias_policy != TopicAliasPolicy::Lru {
1654 return None;
1655 }
1656
1657 self.least_recent_auto_topic_alias()
1658 }
1659
1660 fn next_available_auto_topic_alias(&self) -> Option<u16> {
1661 let next_auto_topic_alias = self.next_auto_topic_alias?;
1662 (next_auto_topic_alias..=self.broker_topic_alias_max)
1663 .find(|&alias| !self.outgoing_topic_aliases.contains_key(&alias))
1664 }
1665
1666 fn record_auto_topic_alias(&mut self, pending: PendingAutoTopicAlias) {
1667 if let Some(previous_topic) = pending.previous_topic {
1668 self.auto_outgoing_topic_aliases.remove(&previous_topic);
1669 }
1670 self.auto_outgoing_topic_aliases
1671 .insert(pending.topic.clone(), pending.alias);
1672 self.outgoing_topic_aliases
1673 .insert(pending.alias, pending.topic.clone());
1674 self.record_auto_topic_alias_use(pending.alias);
1675 self.advance_next_auto_topic_alias();
1676 }
1677
1678 fn record_auto_topic_alias_use(&mut self, alias: u16) {
1679 if self.auto_topic_alias_policy != TopicAliasPolicy::Lru {
1680 return;
1681 }
1682
1683 self.auto_topic_alias_lru.retain(|entry| *entry != alias);
1684 self.auto_topic_alias_lru.push_back(alias);
1685 }
1686
1687 fn least_recent_auto_topic_alias(&self) -> Option<(u16, Option<Bytes>)> {
1688 for &alias in &self.auto_topic_alias_lru {
1689 if alias == 0 || alias > self.broker_topic_alias_max {
1690 continue;
1691 }
1692
1693 let Some(topic) = self.outgoing_topic_aliases.get(&alias) else {
1694 continue;
1695 };
1696
1697 if self
1698 .auto_outgoing_topic_aliases
1699 .get(topic)
1700 .is_some_and(|mapped_alias| *mapped_alias == alias)
1701 {
1702 return Some((alias, Some(topic.clone())));
1703 }
1704 }
1705
1706 None
1707 }
1708
1709 fn advance_next_auto_topic_alias(&mut self) {
1710 let Some(mut alias) = self.next_auto_topic_alias else {
1711 return;
1712 };
1713
1714 while alias <= self.broker_topic_alias_max
1715 && self.outgoing_topic_aliases.contains_key(&alias)
1716 {
1717 let Some(next_alias) = alias.checked_add(1) else {
1718 self.next_auto_topic_alias = None;
1719 return;
1720 };
1721 alias = next_alias;
1722 }
1723
1724 self.next_auto_topic_alias = (alias <= self.broker_topic_alias_max).then_some(alias);
1725 }
1726
1727 const fn strip_publish_topic_alias(publish: &mut Publish) {
1728 if let Some(props) = &mut publish.properties {
1729 props.topic_alias = None;
1730 }
1731 }
1732
1733 fn publish_for_replay_tracking(&self, publish: &Publish) -> Publish {
1734 let mut replay_publish = publish.clone();
1735 if replay_publish.topic.is_empty()
1736 && let Some(alias) = Self::publish_topic_alias(&replay_publish)
1737 && let Some(topic) = self.outgoing_topic_aliases.get(&alias)
1738 {
1739 topic.clone_into(&mut replay_publish.topic);
1740 }
1741
1742 replay_publish
1743 }
1744
1745 fn validate_outgoing_topic_alias(&self, publish: &Publish) -> Result<(), StateError> {
1746 if let Some(alias) = Self::publish_topic_alias(publish)
1747 && alias > self.broker_topic_alias_max
1748 {
1749 return Err(StateError::InvalidAlias {
1752 alias,
1753 max: self.broker_topic_alias_max,
1754 });
1755 }
1756
1757 Ok(())
1758 }
1759
1760 fn record_outgoing_topic_alias(&mut self, publish: &Publish) {
1761 if !publish.topic.is_empty()
1762 && let Some(alias) = Self::publish_topic_alias(publish)
1763 {
1764 if let Some(previous_topic) = self
1765 .outgoing_topic_aliases
1766 .insert(alias, publish.topic.clone())
1767 && previous_topic != publish.topic
1768 {
1769 self.auto_outgoing_topic_aliases.remove(&previous_topic);
1770 self.auto_topic_alias_lru.retain(|entry| *entry != alias);
1771 }
1772 self.auto_outgoing_topic_aliases
1773 .retain(|topic, mapped_alias| *mapped_alias != alias || topic == &publish.topic);
1774 }
1775 }
1776}
1777
1778impl Clone for MqttState {
1779 fn clone(&self) -> Self {
1780 Self {
1781 await_pingresp: self.await_pingresp,
1782 collision_ping_count: self.collision_ping_count,
1783 last_incoming: self.last_incoming,
1784 last_outgoing: self.last_outgoing,
1785 last_pkid: self.last_pkid,
1786 last_puback: self.last_puback,
1787 inflight: self.inflight,
1788 outgoing_pub: self.outgoing_pub.clone(),
1789 outgoing_pub_notice: Self::new_notice_slots_with_len(self.outgoing_pub.len()),
1790 outgoing_pub_ack: self.outgoing_pub_ack.clone(),
1791 outgoing_rel: self.outgoing_rel.clone(),
1792 outgoing_rel_notice: Self::new_notice_slots_with_len(self.outgoing_rel_notice.len()),
1793 incoming_pub: self.incoming_pub.clone(),
1794 collision: self.collision.clone(),
1795 collision_notice: None,
1796 tracked_subscribe: BTreeMap::new(),
1797 tracked_unsubscribe: BTreeMap::new(),
1798 events: self.events.clone(),
1799 manual_acks: self.manual_acks,
1800 incoming_topic_aliases: self.incoming_topic_aliases.clone(),
1801 outgoing_topic_aliases: self.outgoing_topic_aliases.clone(),
1802 auto_outgoing_topic_aliases: self.auto_outgoing_topic_aliases.clone(),
1803 next_auto_topic_alias: self.next_auto_topic_alias,
1804 auto_topic_aliases: self.auto_topic_aliases,
1805 auto_topic_alias_policy: self.auto_topic_alias_policy,
1806 auto_topic_alias_lru: self.auto_topic_alias_lru.clone(),
1807 broker_topic_alias_max: self.broker_topic_alias_max,
1808 max_outgoing_inflight: self.max_outgoing_inflight,
1809 max_outgoing_inflight_upper_limit: self.max_outgoing_inflight_upper_limit,
1810 authenticator: self.authenticator.clone(),
1811 auth: self.auth.clone(),
1812 }
1813 }
1814}
1815
1816#[cfg(test)]
1817mod test {
1818 use super::mqttbytes::v5::*;
1819 use super::mqttbytes::*;
1820 use super::{Event, Incoming, Outgoing, Request};
1821 use super::{MqttState, StateError};
1822 use crate::notice::{
1823 AuthNoticeError, AuthNoticeTx, PublishNotice, PublishNoticeError, PublishNoticeTx,
1824 PublishResult, SubscribeNoticeError, SubscribeNoticeTx, UnsubscribeNoticeError,
1825 UnsubscribeNoticeTx,
1826 };
1827 use crate::{NoticeFailureReason, TopicAliasPolicy};
1828 use bytes::Bytes;
1829 use std::collections::{HashMap, VecDeque};
1830 use std::sync::{Arc, Mutex};
1831
// Authentication method name used as a fixture value by this test module.
// NOTE(review): its users are outside this excerpt — confirm which tests rely on it.
const AUTH_METHOD: &str = "test-method";
1833
1834 fn build_outgoing_publish(qos: QoS) -> Publish {
1835 let topic = "hello/world".to_owned();
1836 let payload = vec![1, 2, 3];
1837
1838 let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
1839 publish.qos = qos;
1840 publish
1841 }
1842
1843 fn publish_properties_with_alias(alias: u16) -> PublishProperties {
1844 PublishProperties {
1845 topic_alias: Some(alias),
1846 ..Default::default()
1847 }
1848 }
1849
1850 fn build_outgoing_publish_with_alias(topic: &str, qos: QoS, alias: u16) -> Publish {
1851 Publish::new(
1852 topic,
1853 qos,
1854 vec![1, 2, 3],
1855 Some(publish_properties_with_alias(alias)),
1856 )
1857 }
1858
1859 fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
1860 let topic = "hello/world".to_owned();
1861 let payload = vec![1, 2, 3];
1862
1863 let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
1864 publish.pkid = pkid;
1865 publish.qos = qos;
1866 publish
1867 }
1868
1869 fn build_mqttstate() -> MqttState {
1870 MqttState::builder(u16::MAX).build()
1871 }
1872
1873 fn build_lru_auto_alias_mqttstate(max_inflight: u16, broker_topic_alias_max: u16) -> MqttState {
1874 let mut mqtt = MqttState::builder(max_inflight)
1875 .auto_topic_aliases(true)
1876 .topic_alias_policy(TopicAliasPolicy::Lru)
1877 .build();
1878 mqtt.broker_topic_alias_max = broker_topic_alias_max;
1879 mqtt
1880 }
1881
1882 fn assert_publish(packet: Packet, topic: &'static [u8], alias: Option<u16>) {
1883 match packet {
1884 Packet::Publish(publish) => {
1885 assert_eq!(publish.topic, Bytes::from_static(topic));
1886 assert_eq!(
1887 publish
1888 .properties
1889 .as_ref()
1890 .and_then(|props| props.topic_alias),
1891 alias
1892 );
1893 }
1894 packet => panic!("expected publish, got {packet:?}"),
1895 }
1896 }
1897
1898 fn build_auth_mqttstate(authentication_method: Option<&str>) -> MqttState {
1899 MqttState::builder(10)
1900 .authentication_method(authentication_method.map(str::to_owned))
1901 .build()
1902 }
1903
1904 fn auth_properties(authentication_method: Option<&str>) -> AuthProperties {
1905 AuthProperties {
1906 method: authentication_method.map(str::to_owned),
1907 data: Some(Bytes::from_static(b"auth-data")),
1908 reason: None,
1909 user_properties: vec![],
1910 }
1911 }
1912
1913 #[derive(Debug)]
1914 struct StaticAuthManager {
1915 response: Result<Option<AuthProperties>, String>,
1916 }
1917
1918 impl crate::Authenticator for StaticAuthManager {
1919 fn start(
1920 &mut self,
1921 _context: crate::AuthContext<'_>,
1922 ) -> Result<Option<AuthProperties>, crate::AuthError> {
1923 Ok(None)
1924 }
1925
1926 fn continue_auth(
1927 &mut self,
1928 _context: crate::AuthContext<'_>,
1929 _auth_prop: Option<AuthProperties>,
1930 ) -> Result<crate::AuthAction, crate::AuthError> {
1931 self.response
1932 .clone()
1933 .map(|props| props.map_or(crate::AuthAction::Complete, crate::AuthAction::Send))
1934 .map_err(crate::AuthError::from)
1935 }
1936
1937 fn success(
1938 &mut self,
1939 _context: crate::AuthContext<'_>,
1940 _incoming: Option<AuthProperties>,
1941 ) -> Result<(), crate::AuthError> {
1942 Ok(())
1943 }
1944
1945 fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
1946 }
1947
1948 #[derive(Debug)]
1949 struct StartAuthManager {
1950 response: Option<AuthProperties>,
1951 }
1952
1953 impl crate::Authenticator for StartAuthManager {
1954 fn start(
1955 &mut self,
1956 _context: crate::AuthContext<'_>,
1957 ) -> Result<Option<AuthProperties>, crate::AuthError> {
1958 Ok(self.response.clone())
1959 }
1960
1961 fn continue_auth(
1962 &mut self,
1963 _context: crate::AuthContext<'_>,
1964 _auth_prop: Option<AuthProperties>,
1965 ) -> Result<crate::AuthAction, crate::AuthError> {
1966 Ok(crate::AuthAction::Complete)
1967 }
1968
1969 fn success(
1970 &mut self,
1971 _context: crate::AuthContext<'_>,
1972 _incoming: Option<AuthProperties>,
1973 ) -> Result<(), crate::AuthError> {
1974 Ok(())
1975 }
1976
1977 fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
1978 }
1979
1980 #[derive(Debug)]
1981 struct FailingStartAuthManager;
1982
1983 impl crate::Authenticator for FailingStartAuthManager {
1984 fn start(
1985 &mut self,
1986 _context: crate::AuthContext<'_>,
1987 ) -> Result<Option<AuthProperties>, crate::AuthError> {
1988 Err(crate::AuthError::from("start failed"))
1989 }
1990
1991 fn continue_auth(
1992 &mut self,
1993 _context: crate::AuthContext<'_>,
1994 _auth_prop: Option<AuthProperties>,
1995 ) -> Result<crate::AuthAction, crate::AuthError> {
1996 Ok(crate::AuthAction::Complete)
1997 }
1998
1999 fn success(
2000 &mut self,
2001 _context: crate::AuthContext<'_>,
2002 _incoming: Option<AuthProperties>,
2003 ) -> Result<(), crate::AuthError> {
2004 Ok(())
2005 }
2006
2007 fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
2008 }
2009
2010 fn queue_publish_with_notice(mqtt: &mut MqttState, publish: Publish) -> PublishNotice {
2011 let (tx, notice) = PublishNoticeTx::new();
2012 let (packet, flush_notice) = mqtt
2013 .outgoing_publish_with_notice(publish, Some(tx))
2014 .unwrap();
2015 assert!(packet.is_some());
2016 assert!(flush_notice.is_none());
2017 notice
2018 }
2019
// A freshly built state must reserve at least the advertised initial event capacity.
#[test]
fn new_state_preallocates_event_queue_for_read_batch_bursts() {
    let mqtt = MqttState::builder(10).build();

    let reserved = mqtt.events.capacity();
    let required = MqttState::initial_events_capacity();

    assert!(reserved >= required);
}
2025
// Two publishes + two pending releases + one tracked subscribe + one tracked
// unsubscribe must be reported as a pending capacity of six.
#[test]
fn clean_pending_capacity_counts_publish_rel_and_tracked_requests() {
    let mut mqtt = MqttState::builder(10).build();

    for (slot, qos) in [(1, QoS::AtLeastOnce), (2, QoS::ExactlyOnce)] {
        mqtt.outgoing_pub[slot] = Some(build_outgoing_publish(qos));
    }
    for slot in [3, 4] {
        mqtt.outgoing_rel.insert(slot);
    }

    let subscribe = Subscribe::new(Filter::new("a/b", QoS::AtMostOnce), None);
    let (sub_notice, _) = SubscribeNoticeTx::new();
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice));

    let (unsub_notice, _) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b", None);
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice));

    assert_eq!(mqtt.clean_pending_capacity(), 6);
}
2045
// The tracked-request length helpers must reflect inserted entries and report
// empty again after draining.
#[test]
fn tracked_request_len_helpers_report_counts() {
    let mut mqtt = MqttState::builder(10).build();

    let subscribe = Subscribe::new(Filter::new("a/b", QoS::AtMostOnce), None);
    let (sub_notice, _) = SubscribeNoticeTx::new();
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice));

    let (unsub_notice, _) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b", None);
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice));

    assert_eq!(mqtt.tracked_subscribe_len(), 1);
    assert_eq!(mqtt.tracked_unsubscribe_len(), 1);
    assert!(!mqtt.tracked_requests_is_empty());

    mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
    assert!(mqtt.tracked_requests_is_empty());
}
2064
// Draining with `SessionReset` must fail both tracked notices with the
// session-reset error and return how many requests were drained.
#[test]
fn drain_tracked_requests_as_failed_reports_session_reset_and_returns_count() {
    let mut mqtt = MqttState::builder(10).build();

    let subscribe = Subscribe::new(Filter::new("a/b", QoS::AtMostOnce), None);
    let (sub_notice_tx, sub_notice) = SubscribeNoticeTx::new();
    mqtt.tracked_subscribe.insert(5, (subscribe, sub_notice_tx));

    let (unsub_notice_tx, unsub_notice) = UnsubscribeNoticeTx::new();
    let unsubscribe = Unsubscribe::new("a/b", None);
    mqtt.tracked_unsubscribe
        .insert(6, (unsubscribe, unsub_notice_tx));

    let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);

    assert_eq!(drained, 2);
    assert!(mqtt.tracked_requests_is_empty());

    let sub_error = sub_notice.wait().unwrap_err();
    assert_eq!(sub_error, SubscribeNoticeError::SessionReset);

    let unsub_error = unsub_notice.wait().unwrap_err();
    assert_eq!(unsub_error, UnsubscribeNoticeError::SessionReset);
}
2089
// Draining a state without tracked requests must report zero and stay empty.
#[test]
fn drain_tracked_requests_as_failed_is_noop_when_empty() {
    let mut mqtt = MqttState::builder(10).build();

    let drained_count =
        mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);

    assert_eq!(drained_count, 0);
    assert!(mqtt.tracked_requests_is_empty());
}
2098
// A PUBACK for a tracked QoS 1 publish must resolve the notice with the full
// acknowledgement and still surface the packet as an incoming event.
#[test]
fn tracked_puback_returns_ack_and_preserves_incoming_event() {
    let mut mqtt = build_mqttstate();
    let notice = queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::AtLeastOnce));
    mqtt.events.clear();

    let mut puback = PubAck::new(1, None);
    puback.reason = PubAckReason::NoMatchingSubscribers;
    puback.properties = Some(PubAckProperties {
        reason_string: Some("accepted without subscribers".to_owned()),
        user_properties: vec![("k".to_owned(), "v".to_owned())],
    });

    let reply = mqtt
        .handle_incoming_packet(Incoming::PubAck(puback.clone()))
        .unwrap();
    assert!(reply.is_none());

    assert_eq!(notice.wait(), Ok(PublishResult::Qos1(puback.clone())));

    let first_event = mqtt.events.pop_front();
    assert_eq!(first_event, Some(Event::Incoming(Packet::PubAck(puback))));
}
2123
// A SUBACK for a tracked subscribe must resolve the notice with the full
// acknowledgement (properties included) and still surface the incoming event.
#[test]
fn tracked_suback_returns_ack_with_properties_and_preserves_incoming_event() {
    let mut mqtt = build_mqttstate();
    let (tx, notice) = SubscribeNoticeTx::new();

    let request = Subscribe::new(Filter::new("a/b", QoS::AtMostOnce), None);
    mqtt.outgoing_subscribe(request, Some(tx)).unwrap();
    mqtt.events.clear();

    let properties = SubAckProperties {
        reason_string: Some("denied".to_owned()),
        user_properties: vec![("scope".to_owned(), "missing".to_owned())],
    };
    let suback = SubAck {
        pkid: 1,
        return_codes: vec![SubscribeReasonCode::Unspecified],
        properties: Some(properties),
    };

    let reply = mqtt
        .handle_incoming_packet(Incoming::SubAck(suback.clone()))
        .unwrap();
    assert!(reply.is_none());

    assert_eq!(notice.wait(), Ok(suback.clone()));

    let first_event = mqtt.events.pop_front();
    assert_eq!(first_event, Some(Event::Incoming(Packet::SubAck(suback))));
}
2155
// An UNSUBACK for a tracked unsubscribe must resolve the notice with the full
// acknowledgement (properties included) and still surface the incoming event.
#[test]
fn tracked_unsuback_returns_ack_with_properties_and_preserves_incoming_event() {
    let mut mqtt = build_mqttstate();
    let (tx, notice) = UnsubscribeNoticeTx::new();

    let request = Unsubscribe::new("a/b", None);
    mqtt.outgoing_unsubscribe(request, Some(tx)).unwrap();
    mqtt.events.clear();

    let properties = UnsubAckProperties {
        reason_string: Some("failed".to_owned()),
        user_properties: vec![("detail".to_owned(), "x".to_owned())],
    };
    let unsuback = UnsubAck {
        pkid: 1,
        reasons: vec![UnsubAckReason::UnspecifiedError],
        properties: Some(properties),
    };

    let reply = mqtt
        .handle_incoming_packet(Incoming::UnsubAck(unsuback.clone()))
        .unwrap();
    assert!(reply.is_none());

    assert_eq!(notice.wait(), Ok(unsuback.clone()));

    let first_event = mqtt.events.pop_front();
    assert_eq!(
        first_event,
        Some(Event::Incoming(Packet::UnsubAck(unsuback)))
    );
}
2184
2185 fn build_connack_with_receive_max(receive_max: u16) -> ConnAck {
2186 ConnAck {
2187 session_present: false,
2188 code: ConnectReturnCode::Success,
2189 properties: Some(ConnAckProperties {
2190 session_expiry_interval: None,
2191 receive_max: Some(receive_max),
2192 max_qos: None,
2193 retain_available: None,
2194 max_packet_size: None,
2195 assigned_client_identifier: None,
2196 topic_alias_max: None,
2197 reason_string: None,
2198 user_properties: vec![],
2199 wildcard_subscription_available: None,
2200 subscription_identifiers_available: None,
2201 shared_subscription_available: None,
2202 server_keep_alive: None,
2203 response_information: None,
2204 server_reference: None,
2205 authentication_method: None,
2206 authentication_data: None,
2207 }),
2208 }
2209 }
2210
/// After shrinking the tracking queues for a CONNACK with `receive_max` 4, a
/// later CONNACK with `receive_max` 9 is expected to grow every tracking
/// queue to 10 slots.
#[test]
fn connack_receive_max_can_grow_tracking_capacity_after_previous_shrink() {
    let mut state = MqttState::builder(10).build();

    let small_connack = Incoming::ConnAck(build_connack_with_receive_max(4));
    state.handle_incoming_packet(small_connack).unwrap();
    state.reconcile_outgoing_tracking_capacity(true);
    assert_eq!(state.outgoing_pub.len(), 5);

    let larger_connack = Incoming::ConnAck(build_connack_with_receive_max(9));
    state.handle_incoming_packet(larger_connack).unwrap();

    let queue_lengths = [
        ("outgoing_pub", state.outgoing_pub.len()),
        ("outgoing_pub_notice", state.outgoing_pub_notice.len()),
        ("outgoing_rel_notice", state.outgoing_rel_notice.len()),
        ("outgoing_pub_ack", state.outgoing_pub_ack.len()),
        ("outgoing_rel", state.outgoing_rel.len()),
    ];
    for (queue, len) in queue_lengths {
        assert_eq!(len, 10, "{queue}");
    }
}
2227
/// With nothing tracked, a CONNACK with `receive_max` 3 leaves the queues at
/// 11 slots until `reconcile_outgoing_tracking_capacity(true)` runs, after
/// which they are 4 slots long and the pkid cursors are back at 0.
#[test]
fn connack_receive_max_shrinks_when_tracking_is_empty_and_pending_is_empty() {
    let mut state = MqttState::builder(10).build();
    state.last_pkid = 9;
    state.last_puback = 8;

    let connack = Incoming::ConnAck(build_connack_with_receive_max(3));
    state.handle_incoming_packet(connack).unwrap();
    // The CONNACK alone must not have shrunk anything yet.
    assert_eq!(state.outgoing_pub.len(), 11);

    state.reconcile_outgoing_tracking_capacity(true);

    let queue_lengths = [
        ("outgoing_pub", state.outgoing_pub.len()),
        ("outgoing_pub_notice", state.outgoing_pub_notice.len()),
        ("outgoing_rel_notice", state.outgoing_rel_notice.len()),
        ("outgoing_pub_ack", state.outgoing_pub_ack.len()),
        ("outgoing_rel", state.outgoing_rel.len()),
    ];
    for (queue, len) in queue_lengths {
        assert_eq!(len, 4, "{queue}");
    }
    assert_eq!((state.last_pkid, state.last_puback), (0, 0));
}
2247
/// A CONNACK that carries no topic-alias maximum is expected to zero
/// `broker_topic_alias_max` and leave alias 2 (registered before the CONNACK)
/// unavailable for replay.
#[test]
fn connack_resets_connection_scoped_alias_state_when_topic_alias_maximum_is_omitted() {
    let mut state = MqttState::builder(10).build();
    state.broker_topic_alias_max = 10;
    let aliased = build_outgoing_publish_with_alias("hello/replay", QoS::AtMostOnce, 2);
    state.outgoing_publish(aliased).unwrap();

    let connack = Incoming::ConnAck(build_connack_with_receive_max(5));
    state.handle_incoming_packet(connack).unwrap();

    assert_eq!(state.broker_topic_alias_max, 0);

    // An alias-only publish (empty topic) can only be replayed if alias 2 is still known.
    let mut alias_only = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2);
    let mut known_aliases = state.replay_topic_aliases();
    let err =
        MqttState::prepare_publish_for_replay_with_aliases(&mut alias_only, &mut known_aliases)
            .unwrap_err();
    assert_eq!(err, PublishNoticeError::TopicAliasReplayUnavailable(2));
}
2274
/// While a publish is still tracked at pkid 8, a smaller `receive_max` must
/// not shrink the tracking queues, even after reconciliation.
#[test]
fn connack_receive_max_does_not_shrink_when_tracking_is_non_empty() {
    let mut state = MqttState::builder(10).build();

    // Plant one inflight QoS 1 publish directly into slot 8.
    let mut tracked = build_outgoing_publish(QoS::AtLeastOnce);
    tracked.pkid = 8;
    state.outgoing_pub[8] = Some(tracked);
    state.inflight = 1;

    let connack = Incoming::ConnAck(build_connack_with_receive_max(3));
    state.handle_incoming_packet(connack).unwrap();
    state.reconcile_outgoing_tracking_capacity(true);

    let lengths = (state.outgoing_pub.len(), state.outgoing_rel_notice.len());
    assert_eq!(lengths, (11, 11));
}
2290
/// Cloning a state whose queues were shrunk to 4 slots must yield a clone
/// with the same (shrunk) queue lengths.
#[test]
fn clone_preserves_current_tracking_queue_lengths() {
    let mut original = MqttState::builder(10).build();
    let connack = Incoming::ConnAck(build_connack_with_receive_max(3));
    original.handle_incoming_packet(connack).unwrap();
    original.reconcile_outgoing_tracking_capacity(true);

    let copy = original.clone();

    let queue_lengths = [
        ("outgoing_pub", copy.outgoing_pub.len()),
        ("outgoing_pub_notice", copy.outgoing_pub_notice.len()),
        ("outgoing_rel_notice", copy.outgoing_rel_notice.len()),
    ];
    for (queue, len) in queue_lengths {
        assert_eq!(len, 4, "{queue}");
    }
}
2303
/// The first 99 ids handed out by `next_pkid` are expected to be 1..=99 in
/// order; the 100th call is made but its value is not checked.
#[test]
fn next_pkid_increments_as_expected() {
    let mut state = build_mqttstate();

    for i in 1..=100 {
        let pkid = state.next_pkid();

        match i % 100 {
            // The 100th id is drawn but deliberately left unchecked.
            0 => break,
            expected => assert_eq!(expected, pkid),
        }
    }
}
2320
/// With `last_pkid` (5) already beyond the inflight limit (4) and pkid 1
/// occupied, a new QoS 1 publish must still be sendable and is expected to
/// receive pkid 2.
#[test]
fn can_send_publish_searches_free_pkid_after_control_ids_pass_inflight_limit() {
    let mut state = MqttState::builder(4).build();

    // Occupy slot 1 and push the pkid cursor past the publish range.
    let mut occupied = build_outgoing_publish(QoS::AtLeastOnce);
    occupied.pkid = 1;
    state.outgoing_pub[1] = Some(occupied);
    state.inflight = 1;
    state.last_pkid = 5;

    let candidate = build_outgoing_publish(QoS::AtLeastOnce);
    assert!(state.can_send_publish(&candidate));

    let packet = state
        .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
        .unwrap()
        .unwrap();
    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.pkid, 2);
    } else {
        panic!("Unexpected packet: {packet:?}");
    }
}
2341
/// Sends one QoS 0 publish followed by two QoS 1 and two QoS 2 publishes and
/// checks `last_pkid` / `inflight` after each: QoS 0 leaves both at 0, every
/// later publish bumps both by one.
#[test]
fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() {
    let mut state = build_mqttstate();

    // (qos to send, expected last_pkid afterwards, expected inflight afterwards)
    let steps = [
        (QoS::AtMostOnce, 0, 0),
        (QoS::AtLeastOnce, 1, 1),
        (QoS::AtLeastOnce, 2, 2),
        (QoS::ExactlyOnce, 3, 3),
        (QoS::ExactlyOnce, 4, 4),
    ];

    for (qos, last_pkid, inflight) in steps {
        state.outgoing_publish(build_outgoing_publish(qos)).unwrap();
        assert_eq!(state.last_pkid, last_pkid);
        assert_eq!(state.inflight, inflight);
    }
}
2380
/// With a max inflight of 2, a third QoS 2 publish is expected to land in
/// `collision` without raising `inflight`; after pubacks for pkids 1 and 2 a
/// further publish brings `inflight` back to 2.
#[test]
fn outgoing_publish_with_max_inflight_is_ok() {
    let mut state = MqttState::builder(2).build();
    let template = build_outgoing_publish(QoS::ExactlyOnce);

    state.outgoing_publish(template.clone()).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (1, 1));

    state.outgoing_publish(template.clone()).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (0, 2));

    // Third publish exceeds the inflight window.
    state.outgoing_publish(template.clone()).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (1, 2));
    assert!(state.collision.is_some());

    for acked_pkid in [1, 2] {
        let puback = PubAck::new(acked_pkid, None);
        state.handle_incoming_puback(&puback).unwrap();
    }
    assert_eq!(state.inflight, 1);

    state.outgoing_publish(template).unwrap();
    assert_eq!((state.last_pkid, state.inflight), (0, 2));
}
2412
/// Checks the order in which `clean()` hands back pending publishes for
/// several `last_puback` positions: the expected order starts with the slots
/// after `last_puback` and then wraps around to the beginning.
///
/// The complete list of returned packet ids is compared against the expected
/// list. A plain `zip` over both would stop at the shorter side and let the
/// test pass even if `clean()` returned too few (or no) requests.
#[test]
fn clean_is_calculating_pending_correctly() {
    fn build_publish_with_pkid(pkid: u16) -> Publish {
        let mut publish = Publish::new("test".to_owned(), QoS::AtLeastOnce, vec![], None);
        publish.pkid = pkid;
        publish
    }

    /// Tracking table with publishes at pkids 1, 2, 3 and 6; all other slots empty.
    fn build_outgoing_pub() -> Vec<Option<Publish>> {
        vec![
            None,
            Some(build_publish_with_pkid(1)),
            Some(build_publish_with_pkid(2)),
            Some(build_publish_with_pkid(3)),
            None,
            None,
            Some(build_publish_with_pkid(6)),
        ]
    }

    /// Collects the packet ids of the returned requests in order; this test
    /// only ever expects publish requests.
    fn publish_pkids<'a>(requests: impl Iterator<Item = &'a Request>) -> Vec<u16> {
        requests
            .map(|request| match request {
                Request::Publish(publish) => publish.pkid,
                _ => unreachable!(),
            })
            .collect()
    }

    let mut mqtt = build_mqttstate();

    // (last_puback, expected replay order)
    let cases = [
        (3, vec![6, 1, 2, 3]),
        (0, vec![1, 2, 3, 6]),
        (6, vec![1, 2, 3, 6]),
    ];

    for (last_puback, expected) in cases {
        mqtt.outgoing_pub = build_outgoing_pub();
        mqtt.last_puback = last_puback;

        let requests = mqtt.clean();

        assert_eq!(
            publish_pkids(requests.iter()),
            expected,
            "replay order with last_puback = {last_puback}"
        );
    }
}
2470
/// Feeds one incoming publish per QoS level (pkids 1, 2, 3) and checks that
/// the QoS 2 pkid is recorded in `incoming_pub`.
#[test]
fn incoming_publish_should_be_added_to_queue_correctly() {
    let mut state = build_mqttstate();

    let incoming = [
        (QoS::AtMostOnce, 1),
        (QoS::AtLeastOnce, 2),
        (QoS::ExactlyOnce, 3),
    ];
    for (qos, pkid) in incoming {
        let mut publish = build_incoming_publish(qos, pkid);
        state.handle_incoming_publish(&mut publish).unwrap();
    }

    assert!(state.incoming_pub.contains(3));
}
2487
/// Feeds one incoming publish per QoS level and expects the first two queued
/// events to be an outgoing PUBACK for pkid 2 and an outgoing PUBREC for
/// pkid 3, in that order.
#[test]
fn incoming_publish_should_be_acked() {
    let mut state = build_mqttstate();

    let incoming = [
        (QoS::AtMostOnce, 1),
        (QoS::AtLeastOnce, 2),
        (QoS::ExactlyOnce, 3),
    ];
    for (qos, pkid) in incoming {
        let mut publish = build_incoming_publish(qos, pkid);
        state.handle_incoming_publish(&mut publish).unwrap();
    }

    match state.events[0] {
        Event::Outgoing(Outgoing::PubAck(pkid)) => assert_eq!(pkid, 2),
        _ => panic!("missing puback"),
    }

    match state.events[1] {
        Event::Outgoing(Outgoing::PubRec(pkid)) => assert_eq!(pkid, 3),
        _ => panic!("missing PubRec"),
    }
}
2513
/// With `manual_acks` enabled, incoming publishes must not queue any events,
/// while the QoS 2 pkid is still recorded in `incoming_pub`.
#[test]
fn incoming_publish_should_not_be_acked_with_manual_acks() {
    let mut state = build_mqttstate();
    state.manual_acks = true;

    let incoming = [
        (QoS::AtMostOnce, 1),
        (QoS::AtLeastOnce, 2),
        (QoS::ExactlyOnce, 3),
    ];
    for (qos, pkid) in incoming {
        let mut publish = build_incoming_publish(qos, pkid);
        state.handle_incoming_publish(&mut publish).unwrap();
    }

    assert!(state.incoming_pub.contains(3));
    assert!(state.events.is_empty());
}
2531
/// An incoming publish with an empty topic and an alias that was never
/// registered must produce a DISCONNECT with a protocol-error reason and
/// leave the publish's topic empty.
#[test]
fn unknown_incoming_topic_alias_returns_protocol_error_disconnect() {
    let mut state = build_mqttstate();
    let mut alias_only = build_incoming_publish(QoS::AtMostOnce, 0);
    alias_only.topic = Bytes::new();
    alias_only.properties = Some(publish_properties_with_alias(1));

    let reply = state
        .handle_incoming_publish(&mut alias_only)
        .unwrap()
        .unwrap();

    let is_protocol_error_disconnect = matches!(
        &reply,
        Packet::Disconnect(disconnect)
            if disconnect.reason_code == DisconnectReasonCode::ProtocolError
    );
    assert!(is_protocol_error_disconnect);
    assert!(alias_only.topic.is_empty());
}
2548
/// Routing an alias-only publish with an unknown alias through
/// `handle_incoming_packet` must return a protocol-error DISCONNECT and queue
/// exactly one event: the outgoing disconnect. No incoming publish event may
/// be queued.
#[test]
fn handle_incoming_packet_does_not_surface_unknown_topic_alias_publish() {
    let mut state = build_mqttstate();
    let mut alias_only = build_incoming_publish(QoS::AtMostOnce, 0);
    alias_only.topic = Bytes::new();
    alias_only.properties = Some(publish_properties_with_alias(1));

    let reply = state
        .handle_incoming_packet(Incoming::Publish(alias_only))
        .unwrap()
        .unwrap();

    let is_protocol_error_disconnect = matches!(
        &reply,
        Packet::Disconnect(disconnect)
            if disconnect.reason_code == DisconnectReasonCode::ProtocolError
    );
    assert!(is_protocol_error_disconnect);

    let no_publish_surfaced = state
        .events
        .iter()
        .all(|event| !matches!(event, Event::Incoming(Incoming::Publish(_))));
    assert!(no_publish_surfaced);

    let expected_events = VecDeque::from([Event::Outgoing(Outgoing::Disconnect)]);
    assert_eq!(state.events, expected_events);
}
2577
/// A re-authenticate AUTH sent without properties is expected to go out with
/// the configured authentication method filled in and its reason code intact.
#[test]
fn outgoing_reauth_without_properties_synthesizes_connect_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));

    let sent = match state.handle_outgoing_packet(request).unwrap().unwrap() {
        Packet::Auth(auth) => auth,
        _ => panic!("expected AUTH packet"),
    };

    let Auth {
        code, properties, ..
    } = sent;
    let properties = properties.unwrap();
    assert_eq!(properties.method.as_deref(), Some(AUTH_METHOD));
    assert_eq!(code, AuthReasonCode::ReAuthenticate);
}
2595
/// Re-authentication must fail with an auth error when the state was built
/// without an authentication method.
#[test]
fn outgoing_reauth_without_connect_authentication_method_fails() {
    let mut state = build_auth_mqttstate(None);
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));

    let outcome = state.handle_outgoing_packet(request);

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
}
2607
/// Setting the authentication method through `set_authentication_method` on a
/// plain builder-made state is enough for an outgoing re-authenticate AUTH to
/// be accepted and sent with that method.
#[test]
fn public_state_authentication_method_setter_enables_outgoing_reauth() {
    let mut state = MqttState::builder(10).build();
    state.set_authentication_method(Some(AUTH_METHOD.to_owned()));

    let properties = auth_properties(Some(AUTH_METHOD));
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, Some(properties)));

    let sent = match state.handle_outgoing_packet(request).unwrap().unwrap() {
        Packet::Auth(auth) => auth,
        _ => panic!("expected AUTH packet"),
    };

    let method = sent.properties.unwrap().method;
    assert_eq!(method.as_deref(), Some(AUTH_METHOD));
}
2630
/// When the caller supplies AUTH properties without a method, the outgoing
/// packet is expected to carry the configured authentication method.
#[test]
fn outgoing_reauth_fills_missing_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let methodless = auth_properties(None);
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, Some(methodless)));

    let sent = match state.handle_outgoing_packet(request).unwrap().unwrap() {
        Packet::Auth(auth) => auth,
        _ => panic!("expected AUTH packet"),
    };

    let method = sent.properties.unwrap().method;
    assert_eq!(method.as_deref(), Some(AUTH_METHOD));
}
2649
/// A re-authenticate AUTH naming a method other than the configured one must
/// be rejected with an auth error.
#[test]
fn outgoing_reauth_rejects_mismatched_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let foreign = auth_properties(Some("other-method"));
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, Some(foreign)));

    let outcome = state.handle_outgoing_packet(request);

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
}
2664
/// A tracked re-authentication without a configured method must fail the call
/// with an auth error and resolve the notice with
/// `MissingAuthenticationMethod`.
#[test]
fn tracked_reauth_missing_method_notice_fails_with_specific_error() {
    let mut state = build_auth_mqttstate(None);
    let (notice_tx, notice_rx) = AuthNoticeTx::new();

    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    let outcome = state.handle_outgoing_packet_with_notice(request, Some(tracker));

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
    let notice_err = notice_rx.wait().unwrap_err();
    assert_eq!(notice_err, AuthNoticeError::MissingAuthenticationMethod);
}
2683
/// A tracked re-authentication naming a different method must fail the call
/// with an auth error and resolve the notice with `AuthenticationFailed`.
#[test]
fn tracked_reauth_mismatched_method_notice_fails_with_auth_error() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let (notice_tx, notice_rx) = AuthNoticeTx::new();

    let foreign = auth_properties(Some("other-method"));
    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, Some(foreign)));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    let outcome = state.handle_outgoing_packet_with_notice(request, Some(tracker));

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
    assert!(matches!(
        notice_rx.wait(),
        Err(AuthNoticeError::AuthenticationFailed(_))
    ));
}
2705
/// When the authenticator in use is `FailingStartAuthManager`, a tracked
/// re-authentication must fail the call with an auth error and resolve the
/// notice with `AuthenticationFailed`.
#[test]
fn tracked_reauth_start_failure_notice_fails_with_auth_error() {
    let authenticator = Arc::new(Mutex::new(FailingStartAuthManager));
    let builder = MqttState::builder(10)
        .authentication_method(Some(AUTH_METHOD.to_owned()))
        .authenticator(authenticator);
    let mut state = builder.build();
    let (notice_tx, notice_rx) = AuthNoticeTx::new();

    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    let outcome = state.handle_outgoing_packet_with_notice(request, Some(tracker));

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
    assert!(matches!(
        notice_rx.wait(),
        Err(AuthNoticeError::AuthenticationFailed(_))
    ));
}
2728
/// Starting the initial authentication with an authenticator whose response
/// omits the method is expected to return properties that carry the
/// configured method and the `auth-data` payload.
#[test]
fn initial_auth_start_returns_normalized_auth_properties() {
    let methodless_response = auth_properties(None);
    let authenticator = Arc::new(Mutex::new(StartAuthManager {
        response: Some(methodless_response),
    }));
    let builder = MqttState::builder(10)
        .authentication_method(Some(AUTH_METHOD.to_owned()))
        .authenticator(authenticator);
    let mut state = builder.build();

    let started = state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap()
        .unwrap();

    assert_eq!(started.method.as_deref(), Some(AUTH_METHOD));
    let expected_data = Some(Bytes::from_static(b"auth-data"));
    assert_eq!(started.data, expected_data);
}
2747
/// A second re-authenticate request issued while the first is still pending
/// must be rejected with an auth error.
#[test]
fn outgoing_reauth_rejects_overlapping_attempt() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let reauth = Auth::new(AuthReasonCode::ReAuthenticate, None);

    let first = Request::Auth(reauth.clone());
    state.handle_outgoing_packet(first).unwrap();

    let second = Request::Auth(reauth);
    let outcome = state.handle_outgoing_packet(second);

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
}
2761
/// A tracked re-authentication followed by an incoming AUTH success with the
/// matching method must resolve the notice with `AuthOutcome::Success`.
#[test]
fn tracked_reauth_notice_completes_on_matching_auth_success() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let (notice_tx, notice_rx) = AuthNoticeTx::new();

    let request = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    state
        .handle_outgoing_packet_with_notice(request, Some(tracker))
        .unwrap();

    let success = Auth::new(
        AuthReasonCode::Success,
        Some(auth_properties(Some(AUTH_METHOD))),
    );
    state
        .handle_incoming_packet(Incoming::Auth(success))
        .unwrap();

    let outcome = notice_rx.wait().unwrap();
    assert_eq!(outcome, crate::AuthOutcome::Success);
}
2781
/// A `DisconnectNow` request issued while a tracked re-authentication is
/// pending must resolve the notice with `ConnectionClosed` and queue a
/// re-authentication failure event with the connection-closed reason.
#[test]
fn disconnect_now_fails_active_tracked_reauth_notice() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let (notice_tx, notice_rx) = AuthNoticeTx::new();

    let reauth = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    state
        .handle_outgoing_packet_with_notice(reauth, Some(tracker))
        .unwrap();

    let disconnect = Disconnect::new(DisconnectReasonCode::NormalDisconnection);
    state
        .handle_outgoing_packet(Request::DisconnectNow(disconnect))
        .unwrap();

    let notice_err = notice_rx.wait().unwrap_err();
    assert_eq!(notice_err, AuthNoticeError::ConnectionClosed);

    let saw_closed_reauth_failure = state.events.iter().any(|event| {
        matches!(
            event,
            Event::Auth(crate::AuthEvent::Failed {
                kind: crate::AuthExchangeKind::Reauthentication,
                reason: crate::AuthFailureReason::ConnectionClosed,
                ..
            })
        )
    });
    assert!(saw_closed_reauth_failure);
}
2812
/// A tracked re-authentication issued while an untracked one is still pending
/// must fail the call with an auth error and resolve the notice with
/// `OverlappingReauth`.
#[test]
fn tracked_overlapping_reauth_notice_fails() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));

    // First (untracked) attempt stays pending.
    let pending = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    state.handle_outgoing_packet(pending).unwrap();

    let (notice_tx, notice_rx) = AuthNoticeTx::new();
    let overlapping = Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None));
    let tracker = crate::notice::TrackedNoticeTx::Auth(notice_tx);
    let outcome = state.handle_outgoing_packet_with_notice(overlapping, Some(tracker));

    assert!(matches!(outcome, Err(StateError::AuthError(_))));
    let notice_err = notice_rx.wait().unwrap_err();
    assert_eq!(notice_err, AuthNoticeError::OverlappingReauth);
}
2836
/// An AUTH success that arrives when no authentication exchange was started
/// must be reported as a protocol error.
#[test]
fn incoming_auth_success_without_active_exchange_is_protocol_error() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let properties = auth_properties(Some(AUTH_METHOD));
    let success = Auth::new(AuthReasonCode::Success, Some(properties));

    let outcome = state.handle_incoming_packet(Incoming::Auth(success));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
2854
/// After the initial authentication was started, an AUTH success with the
/// matching method is accepted and produces no reply packet.
#[test]
fn incoming_auth_success_accepts_matching_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();

    let properties = auth_properties(Some(AUTH_METHOD));
    let success = Auth::new(AuthReasonCode::Success, Some(properties));

    let reply = state
        .handle_incoming_packet(Incoming::Auth(success))
        .unwrap();

    assert!(reply.is_none());
}
2869
/// Calling `fail_pending_notices` between starting the initial authentication
/// and receiving the CONNACK must not fail that exchange: a success event for
/// the initial connect is expected and no failure event for it may appear.
#[test]
fn initial_auth_survives_fresh_session_pending_notice_cleanup() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();
    state.fail_pending_notices();

    let mut connack = build_connack_with_receive_max(10);
    let connack_properties = connack.properties.as_mut().unwrap();
    connack_properties.authentication_method = Some(AUTH_METHOD.to_owned());
    state
        .handle_incoming_packet(Incoming::ConnAck(connack))
        .unwrap();

    let initial_auth_succeeded = state.events.iter().any(|event| {
        matches!(
            event,
            Event::Auth(crate::AuthEvent::Succeeded {
                kind: crate::AuthExchangeKind::InitialConnect,
                ..
            })
        )
    });
    let initial_auth_failed = state.events.iter().any(|event| {
        matches!(
            event,
            Event::Auth(crate::AuthEvent::Failed {
                kind: crate::AuthExchangeKind::InitialConnect,
                ..
            })
        )
    });

    assert!(initial_auth_succeeded);
    assert!(!initial_auth_failed);
}
2901
/// During an active initial authentication, an AUTH success without any
/// properties (and therefore without a method) must be a protocol error.
#[test]
fn incoming_auth_success_rejects_missing_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();

    let bare_success = Auth::new(AuthReasonCode::Success, None);
    let outcome = state.handle_incoming_packet(Incoming::Auth(bare_success));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
2918
/// During an active initial authentication, an AUTH success naming a
/// different method must be a protocol error.
#[test]
fn incoming_auth_success_rejects_mismatched_authentication_method() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();

    let foreign = auth_properties(Some("other-method"));
    let success = Auth::new(AuthReasonCode::Success, Some(foreign));
    let outcome = state.handle_incoming_packet(Incoming::Auth(success));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
2938
/// An AUTH success received by a state that was built without an
/// authentication method must be a protocol error.
#[test]
fn incoming_auth_success_without_connect_authentication_method_is_protocol_error() {
    let mut state = build_auth_mqttstate(None);
    let properties = auth_properties(Some(AUTH_METHOD));
    let success = Auth::new(AuthReasonCode::Success, Some(properties));

    let outcome = state.handle_incoming_packet(Incoming::Auth(success));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
2956
/// When the auth manager answers an AUTH continue with `Ok(None)`, the reply
/// is still expected to be an AUTH continue carrying the configured method.
#[test]
fn incoming_auth_continue_synthesizes_method_when_auth_manager_omits_it() {
    let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
    let builder = MqttState::builder(10)
        .authentication_method(Some(AUTH_METHOD.to_owned()))
        .auth_manager(auth_manager);
    let mut state = builder.build();
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();

    let properties = auth_properties(Some(AUTH_METHOD));
    let challenge = Auth::new(AuthReasonCode::Continue, Some(properties));

    let reply = state
        .handle_incoming_packet(Incoming::Auth(challenge))
        .unwrap()
        .unwrap();
    let response = match reply {
        Packet::Auth(auth) => auth,
        _ => panic!("expected AUTH packet"),
    };

    assert_eq!(response.code, AuthReasonCode::Continue);
    let method = response.properties.unwrap().method;
    assert_eq!(method.as_deref(), Some(AUTH_METHOD));
}
2985
/// An AUTH continue received by a state that has an auth manager but no
/// configured authentication method must be a protocol error.
#[test]
fn incoming_auth_continue_without_connect_authentication_method_is_protocol_error() {
    let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
    let builder = MqttState::builder(10).auth_manager(auth_manager);
    let mut state = builder.build();

    let properties = auth_properties(Some(AUTH_METHOD));
    let challenge = Auth::new(AuthReasonCode::Continue, Some(properties));
    let outcome = state.handle_incoming_packet(Incoming::Auth(challenge));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
3004
/// During an active initial authentication, an AUTH continue naming a
/// different method must be a protocol error.
#[test]
fn incoming_auth_continue_rejects_mismatched_server_method() {
    let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
    let builder = MqttState::builder(10)
        .authentication_method(Some(AUTH_METHOD.to_owned()))
        .auth_manager(auth_manager);
    let mut state = builder.build();
    state
        .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
        .unwrap();

    let foreign = auth_properties(Some("other-method"));
    let challenge = Auth::new(AuthReasonCode::Continue, Some(foreign));
    let outcome = state.handle_incoming_packet(Incoming::Auth(challenge));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
3028
/// An incoming AUTH packet with the re-authenticate reason code must be a
/// protocol error.
#[test]
fn incoming_auth_reauthenticate_is_protocol_error() {
    let mut state = build_auth_mqttstate(Some(AUTH_METHOD));
    let properties = auth_properties(Some(AUTH_METHOD));
    let reauth = Auth::new(AuthReasonCode::ReAuthenticate, Some(properties));

    let outcome = state.handle_incoming_packet(Incoming::Auth(reauth));

    assert!(matches!(
        outcome,
        Err(StateError::Deserialization(Error::ProtocolError))
    ));
}
3046
/// Registers incoming alias 1, confirms an alias-only publish resolves to
/// `hello/world`, then checks that `reset_connection_scoped_state` zeroes
/// `broker_topic_alias_max` and makes the same alias-only publish answer
/// with a DISCONNECT.
#[test]
fn connection_scoped_alias_state_resets_incoming_aliases_and_broker_maximum() {
    // Publish that carries only alias 1 and an empty topic.
    let alias_only_publish = || {
        let mut publish = build_incoming_publish(QoS::AtMostOnce, 0);
        publish.topic = Bytes::new();
        publish.properties = Some(publish_properties_with_alias(1));
        publish
    };

    let mut state = build_mqttstate();
    state.broker_topic_alias_max = 10;

    // Topic + alias: registers the mapping.
    let mut registering = build_incoming_publish(QoS::AtMostOnce, 0);
    registering.properties = Some(publish_properties_with_alias(1));
    state.handle_incoming_publish(&mut registering).unwrap();

    let mut resolved = alias_only_publish();
    state.handle_incoming_publish(&mut resolved).unwrap();
    assert_eq!(resolved.topic, Bytes::from_static(b"hello/world"));

    state.reset_connection_scoped_state();

    assert_eq!(state.broker_topic_alias_max, 0);
    let mut stale = alias_only_publish();
    let reply = state.handle_incoming_publish(&mut stale).unwrap().unwrap();
    assert!(matches!(reply, Packet::Disconnect(_)));
}
3073
/// An alias-only publish whose alias was registered by an earlier outgoing
/// publish must get its topic restored and its alias removed when prepared
/// for replay.
#[test]
fn replay_publish_with_known_outgoing_alias_restores_topic() {
    let mut state = build_mqttstate();
    state.broker_topic_alias_max = 10;
    let registering = build_outgoing_publish_with_alias("hello/replay", QoS::AtMostOnce, 2);
    state.outgoing_publish(registering).unwrap();

    let mut alias_only = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2);
    let mut known_aliases = state.replay_topic_aliases();

    MqttState::prepare_publish_for_replay_with_aliases(&mut alias_only, &mut known_aliases)
        .unwrap();

    assert_eq!(alias_only.topic, Bytes::from_static(b"hello/replay"));
    let remaining_alias = alias_only
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(remaining_alias, None);
}
3100
/// Preparing a publish that has both a concrete topic and an alias keeps the
/// topic, removes the alias from the packet, and records the alias → topic
/// mapping in the supplied alias map.
#[test]
fn replay_publish_with_concrete_topic_strips_stale_alias() {
    let mut publish = build_outgoing_publish_with_alias("hello/replay", QoS::AtLeastOnce, 2);
    let mut known_aliases = HashMap::new();

    MqttState::prepare_publish_for_replay_with_aliases(&mut publish, &mut known_aliases).unwrap();

    let topic = Bytes::from_static(b"hello/replay");
    assert_eq!(publish.topic, topic);

    let remaining_alias = publish
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(remaining_alias, None);

    assert_eq!(known_aliases.get(&2), Some(&topic));
}
3122
/// A publish prepared for replay must be accepted by a fresh state whose
/// `broker_topic_alias_max` is 0.
#[test]
fn replay_publish_with_stripped_alias_is_valid_when_next_broker_allows_no_aliases() {
    let mut publish = build_outgoing_publish_with_alias("hello/replay", QoS::AtLeastOnce, 2);
    let mut known_aliases = HashMap::new();
    MqttState::prepare_publish_for_replay_with_aliases(&mut publish, &mut known_aliases).unwrap();

    let mut aliasless_state = build_mqttstate();
    aliasless_state.broker_topic_alias_max = 0;

    let request = Request::Publish(publish);
    aliasless_state.handle_outgoing_packet(request).unwrap();
}
3137
/// Preparing an alias-only publish whose alias was never registered must fail
/// with `TopicAliasReplayUnavailable` and leave the topic empty.
#[test]
fn replay_publish_with_unknown_outgoing_alias_fails() {
    let state = build_mqttstate();
    let mut alias_only = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 3);
    let mut known_aliases = state.replay_topic_aliases();

    let outcome =
        MqttState::prepare_publish_for_replay_with_aliases(&mut alias_only, &mut known_aliases);

    let err = outcome.unwrap_err();
    assert_eq!(err, PublishNoticeError::TopicAliasReplayUnavailable(3));
    assert!(alias_only.topic.is_empty());
}
3153
/// A state from `build_mqttstate` must send a publish with its full topic and
/// no alias, even though `broker_topic_alias_max` is 10.
#[test]
fn auto_topic_aliases_are_disabled_by_default() {
    let mut state = build_mqttstate();
    state.broker_topic_alias_max = 10;

    let packet = state
        .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
        .unwrap()
        .unwrap();

    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
        let alias = publish
            .properties
            .as_ref()
            .and_then(|props| props.topic_alias);
        assert_eq!(alias, None);
    } else {
        panic!("expected publish, got {packet:?}");
    }
}
3178
/// With auto topic aliases enabled, the first publish to a topic is expected
/// to carry the topic together with alias 1, and the second publish to the
/// same topic is expected to carry only alias 1 with an empty topic.
#[test]
fn auto_topic_aliases_send_topic_and_alias_before_alias_only_publish() {
    let mut state = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    state.broker_topic_alias_max = 10;

    let mut send_publish = || {
        state
            .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
            .unwrap()
            .unwrap()
    };
    let first = send_publish();
    let second = send_publish();

    if let Packet::Publish(publish) = &first {
        assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
        let alias = publish
            .properties
            .as_ref()
            .and_then(|props| props.topic_alias);
        assert_eq!(alias, Some(1));
    } else {
        let packet = &first;
        panic!("expected publish, got {packet:?}");
    }

    if let Packet::Publish(publish) = &second {
        assert!(publish.topic.is_empty());
        let alias = publish
            .properties
            .as_ref()
            .and_then(|props| props.topic_alias);
        assert_eq!(alias, Some(1));
    } else {
        let packet = &second;
        panic!("expected publish, got {packet:?}");
    }
}
3222
/// With auto topic aliases enabled but `broker_topic_alias_max` left
/// untouched after building, the publish must go out with its full topic and
/// no alias.
#[test]
fn auto_topic_aliases_do_nothing_when_broker_allows_no_aliases() {
    let mut state = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();

    let packet = state
        .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
        .unwrap()
        .unwrap();

    if let Packet::Publish(publish) = &packet {
        assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
        let alias = publish
            .properties
            .as_ref()
            .and_then(|props| props.topic_alias);
        assert_eq!(alias, None);
    } else {
        panic!("expected publish, got {packet:?}");
    }
}
3248
/// With a broker limit of one alias, a second distinct topic must be sent in full and
/// without an alias.
#[test]
fn auto_topic_aliases_stop_allocating_when_capacity_is_exhausted() {
    let mut mqtt = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    mqtt.broker_topic_alias_max = 1;

    // The first topic consumes the only alias the broker allows.
    let consuming = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(consuming).unwrap();

    let overflow = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    let publish = match mqtt.outgoing_publish(overflow).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = publish
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(publish.topic, Bytes::from_static(b"topic/two"));
    assert_eq!(alias, None);
}
3277
/// After a caller-supplied alias 1 has been sent, the automatic allocator must hand out
/// alias 2 for the next new topic.
#[test]
fn auto_topic_aliases_preserve_manual_aliases_and_skip_used_aliases() {
    let mut mqtt = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    mqtt.broker_topic_alias_max = 2;

    let manual = build_outgoing_publish_with_alias("manual/topic", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(manual).unwrap();

    let automatic = Publish::new("auto/topic", QoS::AtMostOnce, vec![], None);
    let publish = match mqtt.outgoing_publish(automatic).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = publish
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(publish.topic, Bytes::from_static(b"auto/topic"));
    assert_eq!(alias, Some(2));
}
3310
/// Once a manual publish rebinds alias 1 to another topic, the original auto-aliased topic
/// must be re-sent in full under a new alias (2).
#[test]
fn manual_rebind_clears_stale_auto_topic_alias_mapping() {
    let mut mqtt = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    mqtt.broker_topic_alias_max = 2;

    let initial = Publish::new("auto/topic", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(initial).unwrap();

    let rebind = build_outgoing_publish_with_alias("manual/topic", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(rebind).unwrap();

    let repeat = Publish::new("auto/topic", QoS::AtMostOnce, vec![], None);
    let publish = match mqtt.outgoing_publish(repeat).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = publish
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(publish.topic, Bytes::from_static(b"auto/topic"));
    assert_eq!(alias, Some(2));
}
3345
/// A QoS 1 publish sent after the topic's alias was established must come back from
/// `clean()` with the full topic and no alias.
#[test]
fn auto_topic_alias_qos_replay_uses_full_topic_after_clean() {
    let mut mqtt = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    mqtt.broker_topic_alias_max = 10;

    let establishing = build_outgoing_publish(QoS::AtMostOnce);
    mqtt.outgoing_publish(establishing).unwrap();
    let tracked = build_outgoing_publish(QoS::AtLeastOnce);
    mqtt.outgoing_publish(tracked).unwrap();

    let requests = mqtt.clean();
    assert_eq!(requests.len(), 1);

    let replay = match &requests[0] {
        Request::Publish(publish) => publish,
        request => panic!("expected replay publish, got {request:?}"),
    };
    let alias = replay
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replay.topic, Bytes::from_static(b"hello/world"));
    assert_eq!(alias, None);
}
3374
/// A publish parked by a packet-id collision produces no packet, so its topic must not yet
/// own an alias: the next publish on that topic goes out in full with alias 2.
#[test]
fn auto_topic_alias_collision_does_not_register_unsent_alias() {
    let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
    mqtt.broker_topic_alias_max = 10;

    let inflight = Publish::new("inflight/topic", QoS::AtLeastOnce, vec![], None);
    mqtt.outgoing_publish(inflight).unwrap();

    // Force the second publish onto packet id 1 so it collides.
    let mut collided = Publish::new("collided/topic", QoS::AtLeastOnce, vec![], None);
    collided.pkid = 1;
    let outcome = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());
    assert!(mqtt.collision.is_some());

    let follow_up = Publish::new("collided/topic", QoS::AtMostOnce, vec![], None);
    let publish = match mqtt.outgoing_publish(follow_up).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = publish
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(publish.topic, Bytes::from_static(b"collided/topic"));
    assert_eq!(alias, Some(2));
}
3418
/// When the PUBACK for packet id 1 releases a collided publish, the replayed packet must
/// carry the full topic and no alias.
#[test]
fn auto_topic_alias_collision_replay_does_not_send_uncommitted_alias() {
    let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
    mqtt.broker_topic_alias_max = 10;

    let inflight = Publish::new("inflight/topic", QoS::AtLeastOnce, vec![], None);
    let first_notice = queue_publish_with_notice(&mut mqtt, inflight);

    // Force the second publish onto packet id 1 so it collides.
    let mut collided = Publish::new("collided/topic", QoS::AtLeastOnce, vec![], None);
    collided.pkid = 1;
    let outcome = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());

    let puback = PubAck::new(1, None);
    let replayed = match mqtt.handle_incoming_puback(&puback).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = replayed
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replayed.topic, Bytes::from_static(b"collided/topic"));
    assert_eq!(alias, None);
    assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
}
3452
/// A collided publish on an already-aliased topic must replay with its original full topic
/// and no alias, even after alias 1 was manually rebound to another topic in between.
#[test]
fn auto_topic_alias_collision_replay_restores_reused_alias_topic_after_rebind() {
    let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
    mqtt.broker_topic_alias_max = 10;

    let inflight = Publish::new("aliased/topic", QoS::AtLeastOnce, vec![], None);
    let first_notice = queue_publish_with_notice(&mut mqtt, inflight);

    // Same topic, forced onto packet id 1 so it collides.
    let mut collided = Publish::new("aliased/topic", QoS::AtLeastOnce, vec![], None);
    collided.pkid = 1;
    let outcome = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());

    // Rebind alias 1 while the collided publish is still parked.
    let rebind = build_outgoing_publish_with_alias("manual/rebind", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(rebind).unwrap();

    let puback = PubAck::new(1, None);
    let replayed = match mqtt.handle_incoming_puback(&puback).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };

    let alias = replayed
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replayed.topic, Bytes::from_static(b"aliased/topic"));
    assert_eq!(alias, None);
    assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
}
3493
/// Handing out alias `u16::MAX` must set the allocator cursor to `None`; the next new topic
/// is then sent in full without an alias.
#[test]
fn auto_topic_aliases_exhaust_without_wrapping_at_u16_max() {
    let mut mqtt = MqttState::builder(u16::MAX)
        .auto_topic_aliases(true)
        .build();
    mqtt.broker_topic_alias_max = u16::MAX;
    // Start the allocator cursor at the last representable alias.
    mqtt.next_auto_topic_alias = Some(u16::MAX);

    let last_request = Publish::new("last/topic", QoS::AtMostOnce, vec![], None);
    let last = match mqtt.outgoing_publish(last_request).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };
    let last_alias = last.properties.as_ref().and_then(|props| props.topic_alias);
    assert_eq!(last.topic, Bytes::from_static(b"last/topic"));
    assert_eq!(last_alias, Some(u16::MAX));
    assert_eq!(mqtt.next_auto_topic_alias, None);

    let exhausted_request = Publish::new("exhausted/topic", QoS::AtMostOnce, vec![], None);
    let exhausted = match mqtt.outgoing_publish(exhausted_request).unwrap().unwrap() {
        Packet::Publish(publish) => publish,
        packet => panic!("expected publish, got {packet:?}"),
    };
    let exhausted_alias = exhausted
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(exhausted.topic, Bytes::from_static(b"exhausted/topic"));
    assert_eq!(exhausted_alias, None);
}
3545
/// With two alias slots, a third topic must take over alias 1 (held by the least recently
/// used topic) and be alias-only on its next publish.
#[test]
fn lru_auto_topic_aliases_evict_least_recent_topic() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);

    // All three publishes are sent before any result is inspected.
    let one = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    let first = mqtt.outgoing_publish(one).unwrap().unwrap();
    let two = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    let second = mqtt.outgoing_publish(two).unwrap().unwrap();
    let three = Publish::new("topic/three", QoS::AtMostOnce, vec![], None);
    let third = mqtt.outgoing_publish(three).unwrap().unwrap();

    assert_publish(first, b"topic/one", Some(1));
    assert_publish(second, b"topic/two", Some(2));
    assert_publish(third, b"topic/three", Some(1));

    let three_again = Publish::new("topic/three", QoS::AtMostOnce, vec![], None);
    let alias_only = mqtt.outgoing_publish(three_again).unwrap().unwrap();
    assert_publish(alias_only, b"", Some(1));
}
3573
/// Re-publishing on an aliased topic must refresh its recency, so the next eviction takes
/// the other topic's alias (2).
#[test]
fn lru_auto_topic_aliases_refresh_existing_topic_recency() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);

    let one = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(one).unwrap();
    let two = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(two).unwrap();

    // Touch topic/one again before the third topic arrives.
    let one_again = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    let refresh = mqtt.outgoing_publish(one_again).unwrap().unwrap();
    let three = Publish::new("topic/three", QoS::AtMostOnce, vec![], None);
    let evict = mqtt.outgoing_publish(three).unwrap().unwrap();

    assert_publish(refresh, b"", Some(1));
    assert_publish(evict, b"topic/three", Some(2));
}
3594
/// With a single alias slot, a new topic takes over alias 1 with its full topic once, then
/// goes alias-only.
#[test]
fn lru_auto_topic_aliases_rebound_alias_sends_full_topic_then_alias_only() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);

    let one = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(one).unwrap();

    let two = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    let rebound = mqtt.outgoing_publish(two).unwrap().unwrap();
    let two_again = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    let alias_only = mqtt.outgoing_publish(two_again).unwrap().unwrap();

    assert_publish(rebound, b"topic/two", Some(1));
    assert_publish(alias_only, b"", Some(1));
}
3613
/// Automatic topics must recycle alias 2 among themselves and leave the manually bound
/// alias 1 mapped to its topic.
#[test]
fn lru_auto_topic_aliases_do_not_evict_manual_aliases() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);

    let manual = build_outgoing_publish_with_alias("manual/topic", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(manual).unwrap();

    let auto_one = Publish::new("auto/one", QoS::AtMostOnce, vec![], None);
    let first_auto = mqtt.outgoing_publish(auto_one).unwrap().unwrap();
    let auto_two = Publish::new("auto/two", QoS::AtMostOnce, vec![], None);
    let second_auto = mqtt.outgoing_publish(auto_two).unwrap().unwrap();

    assert_publish(first_auto, b"auto/one", Some(2));
    assert_publish(second_auto, b"auto/two", Some(2));

    let manual_topic = Bytes::from_static(b"manual/topic");
    assert_eq!(mqtt.outgoing_topic_aliases.get(&1), Some(&manual_topic));
}
3640
/// After `reset_connection_scoped_state()`, a previously aliased topic must be sent in full
/// with alias 1 again.
#[test]
fn lru_auto_topic_aliases_reset_on_reconnect() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);

    let before_reset = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(before_reset).unwrap();

    mqtt.reset_connection_scoped_state();
    // The test re-applies the broker limit after the reset.
    mqtt.broker_topic_alias_max = 1;

    let after_reset = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    let packet = mqtt.outgoing_publish(after_reset).unwrap().unwrap();

    assert_publish(packet, b"topic/one", Some(1));
}
3656
/// A QoS 1 publish on a topic whose alias slot was later taken by another topic must come
/// back from `clean()` with its original full topic and no alias.
#[test]
fn lru_auto_topic_alias_qos_replay_after_eviction_uses_full_topic() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);

    let establishing = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(establishing).unwrap();
    let tracked = Publish::new("topic/one", QoS::AtLeastOnce, vec![], None);
    mqtt.outgoing_publish(tracked).unwrap();
    // A different topic now claims the only alias slot.
    let evicting = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(evicting).unwrap();

    let requests = mqtt.clean();
    assert_eq!(requests.len(), 1);

    let replay = match &requests[0] {
        Request::Publish(publish) => publish,
        request => panic!("expected replay publish, got {request:?}"),
    };
    let alias = replay
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replay.topic, Bytes::from_static(b"topic/one"));
    assert_eq!(alias, None);
}
3684
/// A collided publish that would have rebound the only alias slot must leave the mapping
/// intact: the original topic still goes out alias-only on alias 1.
#[test]
fn lru_auto_topic_alias_collision_during_rebind_does_not_commit_rebind() {
    let mut mqtt = build_lru_auto_alias_mqttstate(2, 1);

    let inflight = Publish::new("topic/one", QoS::AtLeastOnce, vec![], None);
    mqtt.outgoing_publish(inflight).unwrap();

    // A different topic, forced onto packet id 1 so it collides.
    let mut collided = Publish::new("topic/two", QoS::AtLeastOnce, vec![], None);
    collided.pkid = 1;
    let outcome = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());

    let follow_up = Publish::new("topic/one", QoS::AtMostOnce, vec![], None);
    let packet = mqtt.outgoing_publish(follow_up).unwrap().unwrap();
    assert_publish(packet, b"", Some(1));
}
3703
/// A collided publish must replay with its own full topic and no alias, even after another
/// topic claimed the single alias slot while it was parked.
#[test]
fn lru_auto_topic_alias_collision_replay_after_later_rebind_uses_original_topic() {
    let mut mqtt = build_lru_auto_alias_mqttstate(2, 1);

    let inflight = Publish::new("topic/one", QoS::AtLeastOnce, vec![], None);
    let first_notice = queue_publish_with_notice(&mut mqtt, inflight);

    // Same topic, forced onto packet id 1 so it collides.
    let mut collided = Publish::new("topic/one", QoS::AtLeastOnce, vec![], None);
    collided.pkid = 1;
    let outcome = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());

    // Another topic takes the alias slot before the collision is resolved.
    let rebinding = Publish::new("topic/two", QoS::AtMostOnce, vec![], None);
    mqtt.outgoing_publish(rebinding).unwrap();

    let puback = PubAck::new(1, None);
    let replayed = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();

    assert_publish(replayed, b"topic/one", None);
    assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
}
3727
/// After alias `u16::MAX` is handed out, the cursor must be `None` and the next new topic
/// must rebind alias `u16::MAX`.
#[test]
fn lru_auto_topic_aliases_do_not_wrap_at_u16_max() {
    let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, u16::MAX);
    // Start the allocator cursor at the last representable alias.
    mqtt.next_auto_topic_alias = Some(u16::MAX);

    let last = Publish::new("last/topic", QoS::AtMostOnce, vec![], None);
    let last_packet = mqtt.outgoing_publish(last).unwrap().unwrap();
    let rebound = Publish::new("rebound/topic", QoS::AtMostOnce, vec![], None);
    let rebound_packet = mqtt.outgoing_publish(rebound).unwrap().unwrap();

    assert_publish(last_packet, b"last/topic", Some(u16::MAX));
    assert_eq!(mqtt.next_auto_topic_alias, None);
    assert_publish(rebound_packet, b"rebound/topic", Some(u16::MAX));
}
3746
/// An alias-only QoS 1 publish whose alias was established earlier must come back from
/// `clean()` with the full topic and no alias, and the broker alias limit must read 0.
#[test]
fn public_clean_repairs_alias_only_publish_when_mapping_is_known() {
    let mut mqtt = build_mqttstate();
    mqtt.broker_topic_alias_max = 10;

    let establishing = build_outgoing_publish_with_alias("hello/replay", QoS::AtMostOnce, 2);
    mqtt.outgoing_publish(establishing).unwrap();
    let alias_only = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2);
    mqtt.outgoing_publish(alias_only).unwrap();

    let requests = mqtt.clean();
    assert_eq!(requests.len(), 1);

    let replay = match &requests[0] {
        Request::Publish(publish) => publish,
        request => panic!("expected publish replay, got {request:?}"),
    };
    let alias = replay
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replay.topic, Bytes::from_static(b"hello/replay"));
    assert_eq!(alias, None);
    assert_eq!(mqtt.broker_topic_alias_max, 0);
}
3778
/// An alias-only publish must replay with the topic its alias had at send time
/// (`topic/a`), not the topic the alias was rebound to afterwards (`topic/b`).
#[test]
fn public_clean_preserves_alias_only_publish_topic_from_send_time_after_rebind() {
    let mut mqtt = build_mqttstate();
    mqtt.broker_topic_alias_max = 10;

    let establishing = build_outgoing_publish_with_alias("topic/a", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(establishing).unwrap();
    let alias_only = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 1);
    mqtt.outgoing_publish(alias_only).unwrap();
    let rebinding = build_outgoing_publish_with_alias("topic/b", QoS::AtMostOnce, 1);
    mqtt.outgoing_publish(rebinding).unwrap();

    let requests = mqtt.clean();
    assert_eq!(requests.len(), 1);

    let replay = match &requests[0] {
        Request::Publish(publish) => publish,
        request => panic!("expected publish replay, got {request:?}"),
    };
    let alias = replay
        .properties
        .as_ref()
        .and_then(|props| props.topic_alias);
    assert_eq!(replay.topic, Bytes::from_static(b"topic/a"));
    assert_eq!(alias, None);
}
3815
/// An alias-only publish whose alias was never bound to a topic must be dropped by
/// `clean()`, and the broker alias limit must read 0 afterwards.
#[test]
fn public_clean_drops_alias_only_publish_when_mapping_is_unknown() {
    let mut mqtt = build_mqttstate();
    mqtt.broker_topic_alias_max = 10;

    // Alias 3 is never bound to a topic in this test.
    let orphan = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 3);
    mqtt.outgoing_publish(orphan).unwrap();

    let requests = mqtt.clean();

    assert!(requests.is_empty());
    assert_eq!(mqtt.broker_topic_alias_max, 0);
}
3828
/// An incoming QoS 1 publish must queue the `Incoming` event first and the derived
/// outgoing PUBACK event second.
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos1_ack() {
    let mut mqtt = build_mqttstate();
    let publish = build_incoming_publish(QoS::AtLeastOnce, 42);

    let incoming = Incoming::Publish(publish.clone());
    mqtt.handle_incoming_packet(incoming).unwrap();

    let expected_incoming = Event::Incoming(Incoming::Publish(publish));
    let expected_ack = Event::Outgoing(Outgoing::PubAck(42));
    assert_eq!(mqtt.events.len(), 2);
    assert_eq!(mqtt.events[0], expected_incoming);
    assert_eq!(mqtt.events[1], expected_ack);
}
3841
/// An incoming QoS 2 publish must queue the `Incoming` event first and the derived
/// outgoing PUBREC event second.
#[test]
fn handle_incoming_packet_should_emit_incoming_before_derived_qos2_ack() {
    let mut mqtt = build_mqttstate();
    let publish = build_incoming_publish(QoS::ExactlyOnce, 43);

    let incoming = Incoming::Publish(publish.clone());
    mqtt.handle_incoming_packet(incoming).unwrap();

    let expected_incoming = Event::Incoming(Incoming::Publish(publish));
    let expected_ack = Event::Outgoing(Outgoing::PubRec(43));
    assert_eq!(mqtt.events.len(), 2);
    assert_eq!(mqtt.events[0], expected_incoming);
    assert_eq!(mqtt.events[1], expected_ack);
}
3854
/// Handling an incoming QoS 2 publish must return a PUBREC with the same packet id.
#[test]
fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
    let mut mqtt = build_mqttstate();
    let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);

    let response = mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap();
    match response {
        Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
3865
/// Each PUBACK must decrement `inflight` by one and leave the acknowledged slots in
/// `outgoing_pub` empty.
#[test]
fn incoming_puback_should_remove_correct_publish_from_queue() {
    let mut mqtt = build_mqttstate();

    let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
    let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
    mqtt.outgoing_publish(publish1).unwrap();
    mqtt.outgoing_publish(publish2).unwrap();
    assert_eq!(mqtt.inflight, 2);

    // (acknowledged packet id, expected inflight count afterwards)
    for (pkid, remaining) in [(1, 1), (2, 0)] {
        mqtt.handle_incoming_puback(&PubAck::new(pkid, None)).unwrap();
        assert_eq!(mqtt.inflight, remaining);
    }

    assert!(mqtt.outgoing_pub[1].is_none());
    assert!(mqtt.outgoing_pub[2].is_none());
}
3886
/// In-order PUBACKs must advance `last_puback` to each acknowledged packet id.
#[test]
fn incoming_puback_updates_last_puback() {
    let mut mqtt = build_mqttstate();

    let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
    let publish2 = build_outgoing_publish(QoS::AtLeastOnce);
    mqtt.outgoing_publish(publish1).unwrap();
    mqtt.outgoing_publish(publish2).unwrap();
    assert_eq!(mqtt.last_puback, 0);

    for pkid in [1, 2] {
        mqtt.handle_incoming_puback(&PubAck::new(pkid, None)).unwrap();
        assert_eq!(mqtt.last_puback, pkid);
    }
}
3903
/// An out-of-order PUBACK must not move `last_puback`; it catches up once the gap is
/// acknowledged.
#[test]
fn incoming_puback_advances_last_puback_only_on_contiguous_boundary() {
    let mut mqtt = build_mqttstate();

    for _ in 0..3 {
        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }
    assert_eq!(mqtt.last_puback, 0);

    // (acknowledged packet id, expected `last_puback` afterwards)
    for (pkid, boundary) in [(2, 0), (1, 2), (3, 3)] {
        mqtt.handle_incoming_puback(&PubAck::new(pkid, None)).unwrap();
        assert_eq!(mqtt.last_puback, boundary);
    }
}
3925
/// After interleaved QoS 1 and QoS 2 flows all complete, nothing may remain inflight and
/// the ack/release bitsets must be empty.
#[test]
fn mixed_qos_completion_clears_outbound_drain_state() {
    let mut mqtt = build_mqttstate();

    // Packet ids 1..=4: QoS 2, QoS 1, QoS 1, QoS 2.
    let levels = [
        QoS::ExactlyOnce,
        QoS::AtLeastOnce,
        QoS::AtLeastOnce,
        QoS::ExactlyOnce,
    ];
    for qos in levels {
        mqtt.outgoing_publish(build_outgoing_publish(qos)).unwrap();
    }

    let rec_first = PubRec::new(1, None);
    mqtt.handle_incoming_pubrec(&rec_first).unwrap();
    let ack_second = PubAck::new(2, None);
    mqtt.handle_incoming_puback(&ack_second).unwrap();
    let ack_third = PubAck::new(3, None);
    mqtt.handle_incoming_puback(&ack_third).unwrap();
    let comp_first = PubComp::new(1, None);
    mqtt.handle_incoming_pubcomp(&comp_first).unwrap();
    let rec_fourth = PubRec::new(4, None);
    mqtt.handle_incoming_pubrec(&rec_fourth).unwrap();
    let comp_fourth = PubComp::new(4, None);
    mqtt.handle_incoming_pubcomp(&comp_fourth).unwrap();

    assert_eq!(mqtt.inflight, 0);
    assert!(mqtt.outbound_requests_drained());
    assert!(mqtt.outgoing_pub_ack.ones().next().is_none());
    assert!(mqtt.outgoing_rel.ones().next().is_none());
}
3953
/// After packet id 2 is acknowledged out of order, `clean()` must return the remaining
/// publishes in packet-id order (1, then 3).
#[test]
fn clean_keeps_oldest_unacked_publish_first_after_out_of_order_puback() {
    let mut mqtt = build_mqttstate();

    for _ in 0..3 {
        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap();
    }

    mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
    let requests = mqtt.clean();

    let mut pending_pkids: Vec<u16> = Vec::with_capacity(requests.len());
    for req in &requests {
        match req {
            Request::Publish(publish) => pending_pkids.push(publish.pkid),
            req => panic!("Unexpected request while cleaning: {req:?}"),
        }
    }

    assert_eq!(pending_pkids, vec![1, 3]);
}
3978
/// A PUBACK for packet id 101, which this test never sent, must be rejected with
/// `StateError::Unsolicited` carrying that id.
#[test]
fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() {
    let mut mqtt = build_mqttstate();

    let stray_ack = PubAck::new(101, None);
    let got = mqtt.handle_incoming_puback(&stray_ack).unwrap_err();

    if let StateError::Unsolicited(pkid) = got {
        assert_eq!(pkid, 101);
    } else {
        panic!("Unexpected error: {got}");
    }
}
3992
/// A PUBACK carrying a failure reason must still resolve the first publish's notice and
/// replay the publish that was parked on the same packet id.
#[test]
fn incoming_puback_failure_collision_replays_blocked_publish() {
    let mut mqtt = build_mqttstate();
    let first_notice =
        queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::AtLeastOnce));

    // Park a second publish on packet id 1; its notice receiver stays alive for the test.
    let (collided_tx, _collided_notice) = PublishNoticeTx::new();
    let mut collided = build_outgoing_publish(QoS::AtLeastOnce);
    collided.pkid = 1;
    let outcome = mqtt
        .outgoing_publish_with_notice(collided, Some(collided_tx))
        .unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());
    assert!(mqtt.collision.is_some());

    let mut puback = PubAck::new(1, None);
    puback.reason = PubAckReason::ImplementationSpecificError;

    let replayed = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();
    match replayed {
        Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
    assert_eq!(mqtt.inflight, 1);
    assert!(mqtt.collision.is_none());
}
4022
/// A PUBREC for the QoS 2 publish (packet id 2) must mark that id in `outgoing_rel`,
/// keep both publishes counted as inflight, and leave the QoS 1 publish (packet id 1)
/// stored in `outgoing_pub`.
#[test]
fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() {
    let mut mqtt = build_mqttstate();

    let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
    let publish2 = build_outgoing_publish(QoS::ExactlyOnce);

    // Unwrap so a failed send fails here, not in a later assertion.
    mqtt.outgoing_publish(publish1).unwrap();
    mqtt.outgoing_publish(publish2).unwrap();

    mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap();
    assert_eq!(mqtt.inflight, 2);

    // Read the stored pkid in place; no need to clone the publish.
    let stored_pkid = mqtt.outgoing_pub[1].as_ref().map(|publish| publish.pkid);
    assert_eq!(stored_pkid, Some(1));

    assert!(mqtt.outgoing_rel.contains(2));
}
4043
/// A PUBREC for an outgoing QoS 2 publish must be answered with a PUBREL on the same
/// packet id.
#[test]
fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
    let mut mqtt = build_mqttstate();

    let publish = build_outgoing_publish(QoS::ExactlyOnce);
    let sent = mqtt.outgoing_publish(publish).unwrap().unwrap();
    match sent {
        Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    let pubrec = PubRec::new(1, None);
    let response = mqtt.handle_incoming_pubrec(&pubrec).unwrap().unwrap();
    match response {
        Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
4063
/// A PUBREC with a failure reason must produce no packet, resolve the notice as rejected,
/// and clear the inflight count, the publish slot and the release bit.
#[test]
fn incoming_pubrec_failure_without_collision_decrements_inflight() {
    let mut mqtt = build_mqttstate();
    let first_notice =
        queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));

    let mut pubrec = PubRec::new(1, None);
    pubrec.reason = PubRecReason::ImplementationSpecificError;

    let response = mqtt.handle_incoming_pubrec(&pubrec).unwrap();
    assert!(response.is_none());

    let expected = Ok(PublishResult::Qos2PubRecRejected(pubrec));
    assert_eq!(first_notice.wait(), expected);
    assert_eq!(mqtt.inflight, 0);
    assert!(mqtt.outgoing_pub[1].is_none());
    assert!(!mqtt.outgoing_rel.contains(1));
}
4082
/// A failing PUBREC must resolve the first publish as rejected and replay the collided
/// publish on the same packet id; a later successful PUBREC then yields a PUBREL.
#[test]
fn incoming_pubrec_failure_releases_inflight_and_replays_collision() {
    let mut mqtt = build_mqttstate();
    let first_notice =
        queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));

    // Park a second QoS 2 publish on packet id 1; its notice receiver stays alive.
    let (collided_tx, _collided_notice) = PublishNoticeTx::new();
    let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
    collided.pkid = 1;
    let outcome = mqtt
        .outgoing_publish_with_notice(collided, Some(collided_tx))
        .unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());
    assert!(mqtt.collision.is_some());

    let mut pubrec = PubRec::new(1, None);
    pubrec.reason = PubRecReason::ImplementationSpecificError;

    let replayed = mqtt.handle_incoming_pubrec(&pubrec).unwrap().unwrap();
    match replayed {
        Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    let expected = Ok(PublishResult::Qos2PubRecRejected(pubrec));
    assert_eq!(first_notice.wait(), expected);
    assert_eq!(mqtt.inflight, 1);
    assert!(mqtt.collision.is_none());
    assert!(!mqtt.outgoing_rel.contains(1));

    // The replayed publish now runs through the normal QoS 2 handshake.
    let accepted = PubRec::new(1, None);
    let response = mqtt.handle_incoming_pubrec(&accepted).unwrap().unwrap();
    match response {
        Packet::PubRel(release) => assert_eq!(release.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
4125
/// After an incoming QoS 2 publish was answered with PUBREC, the peer's PUBREL must be
/// answered with a PUBCOMP on the same packet id.
#[test]
fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
    let mut mqtt = build_mqttstate();
    let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);

    let received = mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap();
    match received {
        Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    let pubrel = PubRel::new(1, None);
    let completed = mqtt.handle_incoming_pubrel(&pubrel).unwrap().unwrap();
    match completed {
        Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
4145
/// A PUBCOMP must remove its packet id from the release queue (`outgoing_rel`) and
/// bring the inflight count back to zero.
#[test]
fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
    let mut mqtt = build_mqttstate();
    let publish = build_outgoing_publish(QoS::ExactlyOnce);

    mqtt.outgoing_publish(publish).unwrap();
    mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
    // Precondition: the PUBREC queued packet id 1 for release, so the checks after the
    // PUBCOMP actually observe it being removed.
    assert!(mqtt.outgoing_rel.contains(1));
    assert_eq!(mqtt.inflight, 1);

    mqtt.handle_incoming_pubcomp(&PubComp::new(1, None))
        .unwrap();
    assert_eq!(mqtt.inflight, 0);
    assert!(!mqtt.outgoing_rel.contains(1));
}
4158
/// A PUBCOMP with a failure reason must still complete the QoS 2 flow: no packet, notice
/// resolved as completed, inflight back to zero and the release bit cleared.
#[test]
fn incoming_pubcomp_failure_without_collision_decrements_inflight() {
    let mut mqtt = build_mqttstate();
    let first_notice =
        queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));

    let pubrec = PubRec::new(1, None);
    mqtt.handle_incoming_pubrec(&pubrec).unwrap();

    let mut pubcomp = PubComp::new(1, None);
    pubcomp.reason = PubCompReason::PacketIdentifierNotFound;

    let response = mqtt.handle_incoming_pubcomp(&pubcomp).unwrap();
    assert!(response.is_none());

    let expected = Ok(PublishResult::Qos2Completed(pubcomp));
    assert_eq!(first_notice.wait(), expected);
    assert_eq!(mqtt.inflight, 0);
    assert!(!mqtt.outgoing_rel.contains(1));
}
4177
/// When a PUBCOMP frees packet id 1, the collided QoS 2 publish must be replayed and
/// tracked again so its own PUBREC is answered with a PUBREL.
#[test]
fn incoming_pubcomp_collision_replay_should_restore_qos2_tracking() {
    let mut mqtt = build_mqttstate();

    let original = build_outgoing_publish(QoS::ExactlyOnce);
    mqtt.outgoing_publish(original).unwrap();

    // Park a second QoS 2 publish on the same packet id.
    let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
    collided.pkid = 1;
    let blocked = mqtt.outgoing_publish(collided).unwrap();
    assert!(blocked.is_none());
    assert!(mqtt.collision.is_some());

    mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
    let pubcomp = PubComp::new(1, None);
    let replayed = mqtt.handle_incoming_pubcomp(&pubcomp).unwrap().unwrap();
    match replayed {
        Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    assert!(mqtt.outgoing_pub[1].is_some());
    assert_eq!(mqtt.inflight, 1);

    let second_rec = PubRec::new(1, None);
    let response = mqtt.handle_incoming_pubrec(&second_rec).unwrap().unwrap();
    match response {
        Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
4211
/// A failing PUBCOMP must resolve the first publish as completed, replay the collided
/// publish on the same packet id, and track it so its PUBREC yields a PUBREL.
#[test]
fn incoming_pubcomp_failure_replays_collision_and_preserves_qos2_tracking() {
    let mut mqtt = build_mqttstate();
    let first_notice =
        queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));

    // Park a second QoS 2 publish on packet id 1; its notice receiver stays alive.
    let (collided_tx, _collided_notice) = PublishNoticeTx::new();
    let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
    collided.pkid = 1;
    let outcome = mqtt
        .outgoing_publish_with_notice(collided, Some(collided_tx))
        .unwrap();
    assert!(outcome.0.is_none());
    assert!(outcome.1.is_none());
    assert!(mqtt.collision.is_some());

    mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();

    let mut pubcomp = PubComp::new(1, None);
    pubcomp.reason = PubCompReason::PacketIdentifierNotFound;

    let replayed = mqtt.handle_incoming_pubcomp(&pubcomp).unwrap().unwrap();
    match replayed {
        Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }

    let expected = Ok(PublishResult::Qos2Completed(pubcomp));
    assert_eq!(first_notice.wait(), expected);
    assert_eq!(mqtt.inflight, 1);
    assert!(mqtt.collision.is_none());
    assert!(!mqtt.outgoing_rel.contains(1));
    assert!(mqtt.outgoing_pub[1].is_some());

    // The replayed publish now runs through its own QoS 2 handshake.
    let second_rec = PubRec::new(1, None);
    let response = mqtt.handle_incoming_pubrec(&second_rec).unwrap().unwrap();
    match response {
        Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
        unexpected => panic!("Invalid network request: {unexpected:?}"),
    }
}
4257
#[test]
/// `Request::DisconnectNow` must emit the caller's `Disconnect` packet
/// unchanged (reason code and every property) and record an outgoing
/// disconnect as the most recent event.
fn outgoing_disconnect_should_preserve_reason_and_properties() {
    let mut mqtt = build_mqttstate();

    // Populate every property field so an equality check on the whole packet
    // would reveal any field that gets dropped or rewritten.
    let disconnect = Disconnect::new_with_properties(
        DisconnectReasonCode::ImplementationSpecificError,
        DisconnectProperties {
            session_expiry_interval: Some(60),
            reason_string: Some(String::from("disconnect test")),
            user_properties: vec![(String::from("key"), String::from("value"))],
            server_reference: Some(String::from("broker-2")),
        },
    );

    // The request takes a clone; the original is kept for the comparison below.
    let request = Request::DisconnectNow(disconnect.clone());
    let packet = mqtt.handle_outgoing_packet(request).unwrap().unwrap();

    assert_eq!(packet, Packet::Disconnect(disconnect));
    assert!(matches!(
        mqtt.events.back(),
        Some(Event::Outgoing(Outgoing::Disconnect))
    ));
}
4282
#[test]
/// A second ping issued while the first one is still unanswered must fail
/// with `StateError::AwaitPingResp`, even though other traffic (a publish and
/// its PubAck) was exchanged in between.
fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
    let mut mqtt = build_mqttstate();
    mqtt.outgoing_ping().unwrap();

    // Traffic that is not a PingResp: one QoS 1 publish followed by its PubAck.
    let request = Request::Publish(build_outgoing_publish(QoS::AtLeastOnce));
    mqtt.handle_outgoing_packet(request).unwrap();
    let puback = Incoming::PubAck(PubAck::new(1, None));
    mqtt.handle_incoming_packet(puback).unwrap();

    // No PingResp has been handled, so this ping is expected to be rejected.
    let second_ping = mqtt.outgoing_ping();
    match second_ping {
        Err(StateError::AwaitPingResp) => (),
        Err(e) => panic!("Should throw pingresp await error. Error = {e:?}"),
        Ok(_) => panic!("Should throw pingresp await error"),
    }
}
4302
#[test]
/// Once a `PingResp` has been handled for the previous ping, issuing the
/// next ping must succeed.
fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
    let mut mqtt = build_mqttstate();

    // First ping goes out, then its response is handled.
    mqtt.outgoing_ping().unwrap();
    let pingresp = Incoming::PingResp(PingResp);
    mqtt.handle_incoming_packet(pingresp).unwrap();

    // With the response handled, a second ping must not return an error.
    mqtt.outgoing_ping().unwrap();
}
4315}