Skip to main content

rumqttc/
state.rs

1use super::mqttbytes::v5::{
2    Auth, ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck,
3    PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish,
4    PublishProperties, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason,
5    Unsubscribe,
6};
7use super::mqttbytes::{self, Error as MqttError, QoS};
8use crate::auth::{AuthLifecycle, IncomingAuthEffect};
9use crate::notice::{
10    AuthNoticeError, PublishNoticeTx, PublishResult, SubscribeNoticeTx, TrackedNoticeTx,
11    UnsubscribeNoticeTx,
12};
13use crate::{
14    AuthContext, AuthError, AuthExchangeKind, Authenticator, NoticeFailureReason,
15    PublishNoticeError, TopicAliasPolicy,
16};
17
18use super::{Event, Incoming, Outgoing, Request};
19
20use bytes::Bytes;
21use fixedbitset::FixedBitSet;
22use std::collections::{BTreeMap, HashMap, VecDeque};
23use std::sync::{Arc, Mutex};
24use std::{io, time::Instant};
25
/// A newly chosen automatic outgoing topic alias, carried by
/// `AutoTopicAliasAction::New`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct PendingAutoTopicAlias {
    /// Topic the alias is being assigned to.
    topic: Bytes,
    /// Alias number chosen for `topic`.
    alias: u16,
    /// NOTE(review): presumably the topic `alias` was bound to before this
    /// assignment (i.e. when an alias is reused) — confirm against callers.
    previous_topic: Option<Bytes>,
}
32
/// Result of automatic topic alias handling for one outgoing publish.
#[derive(Clone, Debug, PartialEq, Eq)]
enum AutoTopicAliasAction {
    /// The publish reuses an alias that is already assigned. `original_topic`
    /// is the publish's topic, which `restore_for_replay` writes back into it.
    Existing { original_topic: Bytes, alias: u16 },
    /// The publish introduces a new alias assignment.
    New(PendingAutoTopicAlias),
}
38
39impl AutoTopicAliasAction {
40    const fn pending_alias(&self) -> Option<&PendingAutoTopicAlias> {
41        match self {
42            Self::Existing { .. } => None,
43            Self::New(pending) => Some(pending),
44        }
45    }
46
47    const fn existing_alias(&self) -> Option<u16> {
48        match self {
49            Self::Existing { alias, .. } => Some(*alias),
50            Self::New(_) => None,
51        }
52    }
53
54    fn restore_for_replay(self, publish: &mut Publish) {
55        match self {
56            Self::Existing { original_topic, .. } => {
57                publish.topic = original_topic;
58                MqttState::strip_publish_topic_alias(publish);
59            }
60            Self::New(_) => {
61                MqttState::strip_publish_topic_alias(publish);
62            }
63        }
64    }
65}
66
67/// Errors during state handling
68#[derive(Debug, thiserror::Error)]
69pub enum StateError {
70    /// Io Error while state is passed to network
71    #[error("Io error: {0:?}")]
72    Io(#[from] io::Error),
73    #[error("Conversion error {0:?}")]
74    Coversion(#[from] core::num::TryFromIntError),
75    /// Invalid state for a given operation
76    #[error("Invalid state for a given operation")]
77    InvalidState,
78    /// Received a packet (ack) which isn't asked for
79    #[error("Received unsolicited ack pkid: {0}")]
80    Unsolicited(u16),
81    /// Last pingreq isn't acked
82    #[error("Last pingreq isn't acked")]
83    AwaitPingResp,
84    /// Received a wrong packet while waiting for another packet
85    #[error("Received a wrong packet while waiting for another packet")]
86    WrongPacket,
87    #[error("Timeout while waiting to resolve collision")]
88    CollisionTimeout,
89    #[error("A Subscribe packet must contain atleast one filter")]
90    EmptySubscription,
91    #[error("Mqtt serialization/deserialization error: {0}")]
92    Deserialization(MqttError),
93    #[error(
94        "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'."
95    )]
96    InvalidAlias { alias: u16, max: u16 },
97    #[error(
98        "Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'"
99    )]
100    OutgoingPacketTooLarge { pkt_size: u32, max: u32 },
101    #[error(
102        "Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'"
103    )]
104    IncomingPacketTooLarge { pkt_size: usize, max: usize },
105    #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")]
106    ServerDisconnect {
107        reason_code: DisconnectReasonCode,
108        reason_string: Option<String>,
109    },
110    #[error("Connection failed with reason '{reason:?}' ")]
111    ConnFail { reason: ConnectReturnCode },
112    #[error("Connection closed by peer abruptly")]
113    ConnectionAborted,
114    #[error("Authentication error: {0}")]
115    AuthError(String),
116    #[error("Authenticator not set")]
117    AuthenticatorNotSet,
118}
119
120impl From<mqttbytes::Error> for StateError {
121    fn from(value: MqttError) -> Self {
122        match value {
123            MqttError::OutgoingPacketTooLarge { pkt_size, max } => {
124                Self::OutgoingPacketTooLarge { pkt_size, max }
125            }
126            e => Self::Deserialization(e),
127        }
128    }
129}
130
/// State of the mqtt connection.
// Design: Methods only modify the state of the object; they perform no network operations.
// Design: Outgoing inflight tracking is kept in pre-initialized vecs/bitsets indexed by
// packet id (index 0 is unused because 0 is not a valid packet id). This is done for 2 reasons:
// - bad or out-of-order acks don't trigger O(n) scans that cause CPU spikes
// - acks missing from the broker are detected the next time the packet id is recycled
#[derive(Debug)]
pub struct MqttState {
    /// Status of last ping: set while a ping response is still awaited
    pub await_pingresp: bool,
    /// Collision ping count. Collisions stop user requests,
    /// which in turn trigger pings. Multiple pings without
    /// resolving collisions will result in an error
    pub collision_ping_count: usize,
    /// Last incoming packet time
    last_incoming: Instant,
    /// Last outgoing packet time
    last_outgoing: Instant,
    /// Packet id of the last outgoing packet
    pub(crate) last_pkid: u16,
    /// Packet id of the last acked publish
    pub(crate) last_puback: u16,
    /// Number of outgoing inflight publishes
    pub(crate) inflight: u16,
    /// Outgoing `QoS` 1, 2 publishes which aren't acked yet, indexed by packet id
    pub(crate) outgoing_pub: Vec<Option<Publish>>,
    /// Notice handles for outgoing `QoS` 1, 2 publishes, indexed by packet id
    pub(crate) outgoing_pub_notice: Vec<Option<PublishNoticeTx>>,
    /// Packet ids acked by broker while waiting to advance last contiguous ack boundary
    pub(crate) outgoing_pub_ack: FixedBitSet,
    /// Packet ids of released `QoS` 2 publishes
    pub(crate) outgoing_rel: FixedBitSet,
    /// Notice handles for outgoing `QoS` 2 pubrels, indexed by packet id
    pub(crate) outgoing_rel_notice: Vec<Option<PublishNoticeTx>>,
    /// Packet ids of incoming `QoS` 2 publishes (sized to cover every u16 packet id)
    pub(crate) incoming_pub: FixedBitSet,
    /// Last collision due to broker not acking in order
    pub collision: Option<Publish>,
    /// Notice handle for the collision publish
    pub(crate) collision_notice: Option<PublishNoticeTx>,
    /// Tracked subscribe requests waiting for `SubAck`, keyed by packet id
    pub(crate) tracked_subscribe: BTreeMap<u16, (Subscribe, SubscribeNoticeTx)>,
    /// Tracked unsubscribe requests waiting for `UnsubAck`, keyed by packet id
    pub(crate) tracked_unsubscribe: BTreeMap<u16, (Unsubscribe, UnsubscribeNoticeTx)>,
    /// Buffered events (e.g. incoming packets) for the user to consume
    pub events: VecDeque<Event>,
    /// Whether acknowledgements for incoming publishes are sent manually by the
    /// user instead of automatically
    pub manual_acks: bool,
    /// Server-to-client topic aliases scoped to the current network connection.
    incoming_topic_aliases: HashMap<u16, Bytes>,
    /// Client-to-server topic aliases scoped to the current network connection.
    outgoing_topic_aliases: HashMap<u16, Bytes>,
    /// Automatically assigned client-to-server topic aliases for the current network connection.
    auto_outgoing_topic_aliases: HashMap<Bytes, u16>,
    /// Next alias number to try for automatic assignment; reset to `Some(1)`
    /// for every network connection.
    /// NOTE(review): `None` presumably means the alias range is exhausted — confirm.
    next_auto_topic_alias: Option<u16>,
    /// Whether automatic outgoing topic alias assignment is enabled.
    auto_topic_aliases: bool,
    /// Policy used for automatic outgoing topic alias assignment.
    auto_topic_alias_policy: TopicAliasPolicy,
    /// Automatically assigned aliases in usage order; cleared per connection.
    /// NOTE(review): presumably consulted by an LRU-style alias policy — confirm.
    auto_topic_alias_lru: VecDeque<u16>,
    /// `topic_alias_maximum` RECEIVED via connack packet
    pub broker_topic_alias_max: u16,
    /// Maximum number of allowed inflight `QoS1` & `QoS2` requests
    pub(crate) max_outgoing_inflight: u16,
    /// Upper limit on the maximum number of allowed inflight `QoS1` & `QoS2` requests
    max_outgoing_inflight_upper_limit: u16,
    /// Authentication callback
    authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
    /// Authentication lifecycle state.
    auth: AuthLifecycle,
}
200
/// Builder for low-level MQTT 5 protocol state.
///
/// Most users should configure clients through [`crate::MqttOptions`] and
/// construct them with [`crate::Client::builder`] or [`crate::AsyncClient::builder`].
/// This builder is intended for users driving [`MqttState`] directly.
#[derive(Debug)]
pub struct MqttStateBuilder {
    /// Maximum number of outgoing `QoS` 1 and `QoS` 2 publishes in flight.
    max_inflight: u16,
    /// Whether incoming publish acknowledgements are sent manually.
    manual_acks: bool,
    /// Whether outgoing topic aliases are assigned automatically.
    auto_topic_aliases: bool,
    /// Policy used for automatic outgoing topic alias assignment.
    auto_topic_alias_policy: TopicAliasPolicy,
    /// Callback used for MQTT 5 enhanced authentication.
    authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
    /// Authentication Method used in the CONNECT packet.
    authentication_method: Option<String>,
}
215
216impl MqttStateBuilder {
217    /// Create a new [`MqttState`] builder.
218    #[must_use]
219    pub const fn new(max_inflight: u16) -> Self {
220        Self {
221            max_inflight,
222            manual_acks: false,
223            auto_topic_aliases: false,
224            auto_topic_alias_policy: TopicAliasPolicy::Monotonic,
225            authenticator: None,
226            authentication_method: None,
227        }
228    }
229
230    /// Set whether incoming publish acknowledgements should be sent manually.
231    #[must_use]
232    pub const fn manual_acks(mut self, manual_acks: bool) -> Self {
233        self.manual_acks = manual_acks;
234        self
235    }
236
237    /// Enable or disable automatic outgoing topic alias assignment.
238    #[must_use]
239    pub const fn auto_topic_aliases(mut self, auto_topic_aliases: bool) -> Self {
240        self.auto_topic_aliases = auto_topic_aliases;
241        self
242    }
243
244    /// Set the policy used for automatic outgoing topic alias assignment.
245    #[must_use]
246    pub const fn topic_alias_policy(mut self, auto_topic_alias_policy: TopicAliasPolicy) -> Self {
247        self.auto_topic_alias_policy = auto_topic_alias_policy;
248        self
249    }
250
251    /// Set the Authentication Method used in the CONNECT packet.
252    #[must_use]
253    pub fn authentication_method(mut self, authentication_method: Option<String>) -> Self {
254        self.authentication_method = authentication_method;
255        self
256    }
257
258    /// Set the authentication callback used for MQTT 5 enhanced authentication.
259    #[must_use]
260    pub fn authenticator(mut self, authenticator: Arc<Mutex<dyn Authenticator>>) -> Self {
261        self.authenticator = Some(authenticator);
262        self
263    }
264
265    /// Set the authentication callback used for MQTT 5 enhanced authentication.
266    #[must_use]
267    pub fn auth_manager(mut self, authenticator: Arc<Mutex<dyn Authenticator>>) -> Self {
268        self.authenticator = Some(authenticator);
269        self
270    }
271
272    /// Build the configured [`MqttState`].
273    #[must_use]
274    pub fn build(self) -> MqttState {
275        MqttState::new_internal(
276            self.max_inflight,
277            self.manual_acks,
278            self.auto_topic_aliases,
279            self.auto_topic_alias_policy,
280            self.authentication_method,
281            self.authenticator,
282        )
283    }
284}
285
286impl MqttState {
287    const fn initial_events_capacity() -> usize {
288        128
289    }
290
291    fn outgoing_tracking_len(max_inflight: u16) -> usize {
292        usize::from(max_inflight) + 1
293    }
294
295    fn new_notice_slots_with_len(size: usize) -> Vec<Option<PublishNoticeTx>> {
296        std::iter::repeat_with(|| None).take(size).collect()
297    }
298
299    fn new_notice_slots(max_inflight: u16) -> Vec<Option<PublishNoticeTx>> {
300        Self::new_notice_slots_with_len(Self::outgoing_tracking_len(max_inflight))
301    }
302
303    fn clean_pending_capacity(&self) -> usize {
304        self.outgoing_pub
305            .iter()
306            .filter(|publish| publish.is_some())
307            .count()
308            + self.outgoing_rel.ones().count()
309            + self.tracked_subscribe.len()
310            + self.tracked_unsubscribe.len()
311    }
312
313    const fn next_publish_pkid_after(&self, pkid: u16) -> u16 {
314        if pkid >= self.max_outgoing_inflight {
315            1
316        } else {
317            pkid + 1
318        }
319    }
320
321    fn packet_identifier_in_use(&self, pkid: u16) -> bool {
322        let index = usize::from(pkid);
323        self.outgoing_pub.get(index).is_some_and(Option::is_some)
324            || self.outgoing_rel.contains(index)
325            || self.tracked_subscribe.contains_key(&pkid)
326            || self.tracked_unsubscribe.contains_key(&pkid)
327    }
328
329    pub(crate) fn can_send_publish(&self, publish: &Publish) -> bool {
330        if publish.qos == QoS::AtMostOnce {
331            return true;
332        }
333
334        if self.inflight >= self.max_outgoing_inflight || self.collision.is_some() {
335            return false;
336        }
337
338        if publish.pkid == 0 {
339            return self.next_publish_pkid().is_some();
340        }
341
342        publish.pkid != 0
343            && publish.pkid <= self.max_outgoing_inflight
344            && !self.packet_identifier_in_use(publish.pkid)
345    }
346
347    pub(crate) fn control_packet_identifier_available(&self) -> bool {
348        (1..=u16::MAX).any(|pkid| !self.packet_identifier_in_use(pkid))
349    }
350
351    /// Create a builder for low-level MQTT 5 protocol state.
352    #[must_use]
353    pub const fn builder(max_inflight: u16) -> MqttStateBuilder {
354        MqttStateBuilder::new(max_inflight)
355    }
356
357    /// Creates new mqtt state. Same state should be used during a
358    /// connection for persistent sessions while new state should
359    /// instantiated for clean sessions.
360    #[must_use]
361    pub(crate) fn new_internal(
362        max_inflight: u16,
363        manual_acks: bool,
364        auto_topic_aliases: bool,
365        auto_topic_alias_policy: TopicAliasPolicy,
366        authentication_method: Option<String>,
367        authenticator: Option<Arc<Mutex<dyn Authenticator>>>,
368    ) -> Self {
369        Self {
370            await_pingresp: false,
371            collision_ping_count: 0,
372            last_incoming: Instant::now(),
373            last_outgoing: Instant::now(),
374            last_pkid: 0,
375            last_puback: 0,
376            inflight: 0,
377            // index 0 is wasted as 0 is not a valid packet id
378            outgoing_pub: vec![None; max_inflight as usize + 1],
379            outgoing_pub_notice: Self::new_notice_slots(max_inflight),
380            outgoing_pub_ack: FixedBitSet::with_capacity(max_inflight as usize + 1),
381            outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1),
382            outgoing_rel_notice: Self::new_notice_slots(max_inflight),
383            incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1),
384            collision: None,
385            collision_notice: None,
386            tracked_subscribe: BTreeMap::new(),
387            tracked_unsubscribe: BTreeMap::new(),
388            events: VecDeque::with_capacity(Self::initial_events_capacity()),
389            manual_acks,
390            incoming_topic_aliases: HashMap::new(),
391            outgoing_topic_aliases: HashMap::new(),
392            auto_outgoing_topic_aliases: HashMap::new(),
393            next_auto_topic_alias: Some(1),
394            auto_topic_aliases,
395            auto_topic_alias_policy,
396            auto_topic_alias_lru: VecDeque::new(),
397            // Set via CONNACK
398            broker_topic_alias_max: 0,
399            max_outgoing_inflight: max_inflight,
400            max_outgoing_inflight_upper_limit: max_inflight,
401            authenticator,
402            auth: AuthLifecycle::new(authentication_method),
403        }
404    }
405
406    /// Set the Authentication Method used in the CONNECT packet for this state.
407    ///
408    /// Low-level users that send or process MQTT 5 AUTH packets through
409    /// [`MqttState`] must keep this value in sync with the CONNECT packet's
410    /// Authentication Method property. The event loop updates it automatically
411    /// before each CONNECT attempt.
412    pub fn set_authentication_method(&mut self, authentication_method: Option<String>) {
413        self.auth.set_method(authentication_method);
414    }
415
416    pub(crate) fn begin_authentication_connect(
417        &mut self,
418        authentication_method: Option<String>,
419    ) -> Result<Option<crate::mqttbytes::v5::AuthProperties>, StateError> {
420        self.auth
421            .begin_connect(authentication_method, &mut self.events);
422        let Some(method) = self.auth.method().map(str::to_owned) else {
423            return Ok(None);
424        };
425        let Some(authenticator) = self.authenticator.clone() else {
426            return Ok(None);
427        };
428        let context = AuthContext {
429            kind: AuthExchangeKind::InitialConnect,
430            method: &method,
431        };
432        let start_result = authenticator.lock().unwrap().start(context);
433        let properties = match start_result {
434            Ok(properties) => properties,
435            Err(err) => return Err(self.fail_authenticator(&err)),
436        };
437        properties
438            .map(|properties| crate::auth::normalize_auth_properties(&method, Some(properties)))
439            .transpose()
440    }
441
442    pub(crate) fn validate_successful_connack_authentication_method(
443        &self,
444        connack: &ConnAck,
445    ) -> Result<(), StateError> {
446        self.auth.validate_successful_connack(connack)
447    }
448
449    fn ensure_outgoing_tracking_capacity(&mut self, target_len: usize) {
450        if self.outgoing_pub.len() < target_len {
451            self.outgoing_pub.resize_with(target_len, || None);
452        }
453
454        if self.outgoing_pub_notice.len() < target_len {
455            self.outgoing_pub_notice.resize_with(target_len, || None);
456        }
457
458        if self.outgoing_rel_notice.len() < target_len {
459            self.outgoing_rel_notice.resize_with(target_len, || None);
460        }
461
462        if self.outgoing_pub_ack.len() < target_len {
463            self.outgoing_pub_ack.grow(target_len);
464        }
465
466        if self.outgoing_rel.len() < target_len {
467            self.outgoing_rel.grow(target_len);
468        }
469    }
470
471    pub(crate) fn outbound_requests_drained(&self) -> bool {
472        self.inflight == 0
473            && self.collision.is_none()
474            && self.collision_notice.is_none()
475            && self.tracked_subscribe.is_empty()
476            && self.tracked_unsubscribe.is_empty()
477            && self.outgoing_pub.iter().all(Option::is_none)
478            && self.outgoing_pub_notice.iter().all(Option::is_none)
479            && self.outgoing_rel_notice.iter().all(Option::is_none)
480            && self.outgoing_pub_ack.ones().next().is_none()
481            && self.outgoing_rel.ones().next().is_none()
482    }
483
484    fn maybe_shrink_outgoing_tracking_capacity(&mut self, target_len: usize, pending_empty: bool) {
485        if !pending_empty
486            || self.outgoing_pub.len() <= target_len
487            || !self.outbound_requests_drained()
488        {
489            return;
490        }
491
492        self.outgoing_pub.truncate(target_len);
493        self.outgoing_pub_notice.truncate(target_len);
494        self.outgoing_rel_notice.truncate(target_len);
495        self.outgoing_pub_ack = FixedBitSet::with_capacity(target_len);
496        self.outgoing_rel = FixedBitSet::with_capacity(target_len);
497        // Ensure future packet id reuse starts from the beginning of the new range.
498        self.last_pkid = 0;
499        self.last_puback = 0;
500    }
501
502    pub(crate) fn reconcile_outgoing_tracking_capacity(&mut self, pending_empty: bool) {
503        let target_len = Self::outgoing_tracking_len(self.max_outgoing_inflight);
504        self.ensure_outgoing_tracking_capacity(target_len);
505        self.maybe_shrink_outgoing_tracking_capacity(target_len, pending_empty);
506    }
507
508    pub(crate) fn reset_connection_scoped_state(&mut self) {
509        self.incoming_topic_aliases.clear();
510        self.outgoing_topic_aliases.clear();
511        self.auto_outgoing_topic_aliases.clear();
512        self.next_auto_topic_alias = Some(1);
513        self.auto_topic_alias_lru.clear();
514        self.broker_topic_alias_max = 0;
515    }
516
517    pub(crate) fn replay_topic_aliases(&self) -> HashMap<u16, Bytes> {
518        self.outgoing_topic_aliases.clone()
519    }
520
521    pub(crate) fn prepare_publish_for_replay_with_aliases(
522        publish: &mut Publish,
523        topic_aliases: &mut HashMap<u16, Bytes>,
524    ) -> Result<(), PublishNoticeError> {
525        let Some(alias) = Self::publish_topic_alias(publish) else {
526            return Ok(());
527        };
528
529        if !publish.topic.is_empty() {
530            topic_aliases.insert(alias, publish.topic.clone());
531            Self::strip_publish_topic_alias(publish);
532            return Ok(());
533        }
534
535        if let Some(topic) = topic_aliases.get(&alias) {
536            topic.clone_into(&mut publish.topic);
537            topic_aliases.insert(alias, publish.topic.clone());
538            Self::strip_publish_topic_alias(publish);
539            return Ok(());
540        }
541
542        Err(PublishNoticeError::TopicAliasReplayUnavailable(alias))
543    }
544
545    pub(crate) fn prepare_request_for_replay_with_aliases(
546        request: &mut Request,
547        topic_aliases: &mut HashMap<u16, Bytes>,
548    ) -> Result<(), PublishNoticeError> {
549        if let Request::Publish(publish) = request {
550            Self::prepare_publish_for_replay_with_aliases(publish, topic_aliases)?;
551        }
552
553        Ok(())
554    }
555
556    pub(crate) fn clean_with_notices(&mut self) -> Vec<(Request, Option<TrackedNoticeTx>)> {
557        let mut pending = Vec::with_capacity(self.clean_pending_capacity());
558        let (first_half, second_half) = self
559            .outgoing_pub
560            .split_at_mut(self.last_puback as usize + 1);
561        let (notice_first_half, notice_second_half) = self
562            .outgoing_pub_notice
563            .split_at_mut(self.last_puback as usize + 1);
564
565        for (publish, notice) in second_half
566            .iter_mut()
567            .zip(notice_second_half.iter_mut())
568            .chain(first_half.iter_mut().zip(notice_first_half.iter_mut()))
569        {
570            if let Some(publish) = publish.take() {
571                let request = Request::Publish(publish);
572                pending.push((request, notice.take().map(TrackedNoticeTx::Publish)));
573            } else {
574                _ = notice.take();
575            }
576        }
577
578        // remove and collect pending releases
579        for pkid in self.outgoing_rel.ones() {
580            let pkid = u16::try_from(pkid).expect("fixedbitset index always fits in u16");
581            let request = Request::PubRel(PubRel::new(pkid, None));
582            pending.push((
583                request,
584                self.outgoing_rel_notice[pkid as usize]
585                    .take()
586                    .map(TrackedNoticeTx::Publish),
587            ));
588        }
589        self.outgoing_rel.clear();
590        self.outgoing_pub_ack.clear();
591
592        for (pkid, (mut subscribe, notice)) in std::mem::take(&mut self.tracked_subscribe) {
593            subscribe.pkid = pkid;
594            pending.push((
595                Request::Subscribe(subscribe),
596                Some(TrackedNoticeTx::Subscribe(notice)),
597            ));
598        }
599        for (pkid, (mut unsubscribe, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
600            unsubscribe.pkid = pkid;
601            pending.push((
602                Request::Unsubscribe(unsubscribe),
603                Some(TrackedNoticeTx::Unsubscribe(notice)),
604            ));
605        }
606
607        // remove packed ids of incoming qos2 publishes
608        self.incoming_pub.clear();
609
610        self.await_pingresp = false;
611        self.collision_ping_count = 0;
612        self.inflight = 0;
613        pending
614    }
615
616    /// Returns inflight outgoing packets and clears internal queues.
617    ///
618    /// MQTT 5 topic aliases are scoped to a single network connection. During
619    /// cleanup, replayed publishes that only contain a topic alias are repaired
620    /// with the remembered topic when possible. If the topic cannot be
621    /// recovered, the publish is omitted because replaying it on a new
622    /// connection would be protocol-invalid.
623    pub fn clean(&mut self) -> Vec<Request> {
624        let mut replay_topic_aliases = self.replay_topic_aliases();
625        let mut pending = Vec::with_capacity(self.clean_pending_capacity());
626
627        for (mut request, _) in self.clean_with_notices() {
628            if Self::prepare_request_for_replay_with_aliases(
629                &mut request,
630                &mut replay_topic_aliases,
631            )
632            .is_ok()
633            {
634                pending.push(request);
635            }
636        }
637
638        self.reset_connection_scoped_state();
639        pending
640    }
641
642    pub const fn inflight(&self) -> u16 {
643        self.inflight
644    }
645
646    pub fn tracked_subscribe_len(&self) -> usize {
647        self.tracked_subscribe.len()
648    }
649
650    pub fn tracked_unsubscribe_len(&self) -> usize {
651        self.tracked_unsubscribe.len()
652    }
653
654    pub fn tracked_requests_is_empty(&self) -> bool {
655        self.tracked_subscribe.is_empty() && self.tracked_unsubscribe.is_empty()
656    }
657
658    pub fn drain_tracked_requests_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
659        let mut drained = 0;
660        for (_, (_, notice)) in std::mem::take(&mut self.tracked_subscribe) {
661            drained += 1;
662            notice.error(reason.subscribe_error());
663        }
664        for (_, (_, notice)) in std::mem::take(&mut self.tracked_unsubscribe) {
665            drained += 1;
666            notice.error(reason.unsubscribe_error());
667        }
668
669        drained
670    }
671
672    pub(crate) fn fail_pending_notices(&mut self) {
673        for notice in &mut self.outgoing_pub_notice {
674            if let Some(tx) = notice.take() {
675                tx.error(PublishNoticeError::SessionReset);
676            }
677        }
678
679        for notice in &mut self.outgoing_rel_notice {
680            if let Some(tx) = notice.take() {
681                tx.error(PublishNoticeError::SessionReset);
682            }
683        }
684
685        if let Some(tx) = self.collision_notice.take() {
686            tx.error(PublishNoticeError::SessionReset);
687        }
688        self.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
689        self.clear_collision();
690    }
691
692    pub(crate) fn fail_auth_exchange_due_to_session_reset(&mut self) {
693        self.fail_auth_exchange(
694            AuthNoticeError::SessionReset,
695            AuthError::Failed("authentication exchange was reset with the session".to_owned()),
696        );
697    }
698
699    pub(crate) fn fail_reauth_exchange_due_to_session_reset(&mut self) {
700        let Some((method, notice_error)) = self
701            .auth
702            .reset_reauth(AuthNoticeError::SessionReset, &mut self.events)
703        else {
704            return;
705        };
706
707        if let Some(authenticator) = self.authenticator.clone() {
708            authenticator.lock().unwrap().failure(
709                AuthContext {
710                    kind: AuthExchangeKind::Reauthentication,
711                    method: &method,
712                },
713                AuthError::Failed(notice_error.to_string()),
714            );
715        }
716    }
717
718    pub(crate) fn fail_auth_exchange_due_to_connection_closed(&mut self) {
719        self.fail_auth_exchange(
720            AuthNoticeError::ConnectionClosed,
721            AuthError::Failed("connection closed before authentication completed".to_owned()),
722        );
723    }
724
725    pub(crate) fn fail_auth_exchange_due_to_client_disconnect(&mut self) {
726        self.fail_auth_exchange(
727            AuthNoticeError::ConnectionClosed,
728            AuthError::Failed("authentication aborted by client disconnect".to_owned()),
729        );
730    }
731
732    /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should
733    /// be put on to the network by the eventloop
734    ///
735    /// # Errors
736    ///
737    /// Returns an error if the outgoing request is invalid for the current
738    /// client state.
739    pub fn handle_outgoing_packet(
740        &mut self,
741        request: Request,
742    ) -> Result<Option<Packet>, StateError> {
743        let (packet, flush_notice) = self.handle_outgoing_packet_with_notice(request, None)?;
744        if let Some(tx) = flush_notice {
745            tx.success(PublishResult::Qos0Flushed);
746        }
747        Ok(packet)
748    }
749
750    pub(crate) fn handle_outgoing_packet_with_notice(
751        &mut self,
752        request: Request,
753        notice: Option<TrackedNoticeTx>,
754    ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
755        let result = match request {
756            Request::Publish(publish) => {
757                let publish_notice = match notice {
758                    Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
759                    Some(
760                        TrackedNoticeTx::Subscribe(_)
761                        | TrackedNoticeTx::Unsubscribe(_)
762                        | TrackedNoticeTx::Auth(_),
763                    )
764                    | None => None,
765                };
766                self.outgoing_publish_with_notice(publish, publish_notice)?
767            }
768            Request::PubRel(pubrel) => {
769                let publish_notice = match notice {
770                    Some(TrackedNoticeTx::Publish(notice)) => Some(notice),
771                    Some(
772                        TrackedNoticeTx::Subscribe(_)
773                        | TrackedNoticeTx::Unsubscribe(_)
774                        | TrackedNoticeTx::Auth(_),
775                    )
776                    | None => None,
777                };
778                self.outgoing_pubrel_with_notice(pubrel, publish_notice)
779            }
780            Request::Subscribe(subscribe) => {
781                let request_notice = match notice {
782                    Some(TrackedNoticeTx::Subscribe(notice)) => Some(notice),
783                    Some(
784                        TrackedNoticeTx::Publish(_)
785                        | TrackedNoticeTx::Unsubscribe(_)
786                        | TrackedNoticeTx::Auth(_),
787                    )
788                    | None => None,
789                };
790                (self.outgoing_subscribe(subscribe, request_notice)?, None)
791            }
792            Request::Unsubscribe(unsubscribe) => {
793                let request_notice = match notice {
794                    Some(TrackedNoticeTx::Unsubscribe(notice)) => Some(notice),
795                    Some(
796                        TrackedNoticeTx::Publish(_)
797                        | TrackedNoticeTx::Subscribe(_)
798                        | TrackedNoticeTx::Auth(_),
799                    )
800                    | None => None,
801                };
802                (
803                    Some(self.outgoing_unsubscribe(unsubscribe, request_notice)?),
804                    None,
805                )
806            }
807            Request::PingReq => (self.outgoing_ping()?, None),
808            Request::Disconnect(_) | Request::DisconnectWithTimeout(_, _) => {
809                unreachable!("graceful disconnect requests are handled by the event loop")
810            }
811            Request::DisconnectNow(disconnect) => {
812                (Some(self.outgoing_disconnect(disconnect)), None)
813            }
814            Request::PubAck(puback) => (Some(self.outgoing_puback(puback)), None),
815            Request::PubRec(pubrec) => (Some(self.outgoing_pubrec(pubrec)), None),
816            Request::Auth(auth) => {
817                let auth_notice = match notice {
818                    Some(TrackedNoticeTx::Auth(notice)) => Some(notice),
819                    Some(
820                        TrackedNoticeTx::Publish(_)
821                        | TrackedNoticeTx::Subscribe(_)
822                        | TrackedNoticeTx::Unsubscribe(_),
823                    )
824                    | None => None,
825                };
826                (Some(self.outgoing_auth(auth, auth_notice)?), None)
827            }
828            _ => unimplemented!(),
829        };
830
831        self.last_outgoing = Instant::now();
832        Ok(result)
833    }
834
835    /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the
836    /// user to consume and `Packet` which for the eventloop to put on the network
837    /// E.g For incoming `QoS1` publish packet, this method returns (Publish, Puback). Publish packet will
838    /// be forwarded to user and Pubck packet will be written to network
839    ///
840    /// # Errors
841    ///
842    /// Returns an error if the incoming packet is invalid for the current
843    /// client state.
844    pub fn handle_incoming_packet(
845        &mut self,
846        mut packet: Incoming,
847    ) -> Result<Option<Packet>, StateError> {
848        let events_len_before = self.events.len();
849        let outgoing = match &mut packet {
850            Incoming::PingResp(_) => Ok(self.handle_incoming_pingresp()),
851            Incoming::Publish(publish) => self.handle_incoming_publish(publish),
852            Incoming::SubAck(suback) => Ok(self.handle_incoming_suback(suback)),
853            Incoming::UnsubAck(unsuback) => Ok(self.handle_incoming_unsuback(unsuback)),
854            Incoming::PubAck(puback) => self.handle_incoming_puback(puback),
855            Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec),
856            Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel),
857            Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp),
858            Incoming::ConnAck(connack) => self.handle_incoming_connack(connack),
859            Incoming::Disconnect(disconn) => Self::handle_incoming_disconn(disconn),
860            Incoming::Auth(auth) => self.handle_incoming_auth(auth),
861            _ => {
862                error!("Invalid incoming packet = {packet:?}");
863                Err(StateError::WrongPacket)
864            }
865        };
866
867        let skip_incoming_event = matches!(
868            (&packet, &outgoing),
869            (Incoming::Publish(_), Ok(Some(Packet::Disconnect(_))))
870        );
871
872        // Preserve original event ordering (Incoming first, derived Outgoing next)
873        // without cloning the incoming packet.
874        if !skip_incoming_event {
875            self.events
876                .insert(events_len_before, Event::Incoming(packet));
877        }
878
879        let outgoing = outgoing?;
880        self.last_incoming = Instant::now();
881        Ok(outgoing)
882    }
883
884    /// Builds the protocol-error disconnect response for the current state.
885    ///
886    /// # Errors
887    ///
888    /// Returns an error if the disconnect packet cannot be produced for the
889    /// current state.
890    pub fn handle_protocol_error(&mut self) -> Result<Option<Packet>, StateError> {
891        // send DISCONNECT packet with REASON_CODE 0x82
892        let disconnect = Disconnect::new(DisconnectReasonCode::ProtocolError);
893        Ok(Some(self.outgoing_disconnect(disconnect)))
894    }
895
896    pub fn clear_collision(&mut self) {
897        self.collision = None;
898        self.collision_notice = None;
899        self.collision_ping_count = 0;
900    }
901
902    fn handle_incoming_suback(&mut self, suback: &SubAck) -> Option<Packet> {
903        for reason in &suback.return_codes {
904            match reason {
905                SubscribeReasonCode::Success(qos) => {
906                    debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos);
907                }
908                _ => {
909                    warn!("SubAck Pkid = {:?}, Reason = {:?}", suback.pkid, reason);
910                }
911            }
912        }
913        if let Some((_, notice)) = self.tracked_subscribe.remove(&suback.pkid) {
914            notice.success(suback.clone());
915        }
916        None
917    }
918
919    fn handle_incoming_unsuback(&mut self, unsuback: &UnsubAck) -> Option<Packet> {
920        for reason in &unsuback.reasons {
921            if reason != &UnsubAckReason::Success {
922                warn!("UnsubAck Pkid = {:?}, Reason = {:?}", unsuback.pkid, reason);
923            }
924        }
925        if let Some((_, notice)) = self.tracked_unsubscribe.remove(&unsuback.pkid) {
926            notice.success(unsuback.clone());
927        }
928        None
929    }
930
931    fn handle_incoming_connack(&mut self, connack: &ConnAck) -> Result<Option<Packet>, StateError> {
932        if connack.code != ConnectReturnCode::Success {
933            return Err(StateError::ConnFail {
934                reason: connack.code,
935            });
936        }
937
938        self.auth.validate_successful_connack(connack)?;
939        self.reset_connection_scoped_state();
940        self.auth.complete_initial_connack(&mut self.events);
941
942        if let Some(props) = &connack.properties
943            && let Some(topic_alias_max) = props.topic_alias_max
944        {
945            self.broker_topic_alias_max = topic_alias_max;
946        }
947
948        if let Some(props) = &connack.properties
949            && let Some(max_inflight) = props.receive_max
950        {
951            self.max_outgoing_inflight = max_inflight.min(self.max_outgoing_inflight_upper_limit);
952            // Shrinking depends on pending retransmission state in eventloop.
953            // Grow immediately so incoming/outgoing packet-id indexed tracking stays valid.
954            self.reconcile_outgoing_tracking_capacity(false);
955        }
956        Ok(None)
957    }
958
959    fn handle_incoming_disconn(disconn: &Disconnect) -> Result<Option<Packet>, StateError> {
960        let reason_code = disconn.reason_code;
961        let reason_string = disconn
962            .properties
963            .as_ref()
964            .and_then(|props| props.reason_string.clone());
965        Err(StateError::ServerDisconnect {
966            reason_code,
967            reason_string,
968        })
969    }
970
971    /// Results in a publish notification in all the `QoS` cases. Replys with an ack
972    /// in case of `QoS1` and Replys rec in case of `QoS` while also storing the message
973    fn handle_incoming_publish(
974        &mut self,
975        publish: &mut Publish,
976    ) -> Result<Option<Packet>, StateError> {
977        let qos = publish.qos;
978
979        let topic_alias = publish
980            .properties
981            .as_ref()
982            .and_then(|props| props.topic_alias);
983
984        if !publish.topic.is_empty() {
985            if let Some(alias) = topic_alias {
986                self.incoming_topic_aliases
987                    .insert(alias, publish.topic.clone());
988            }
989        } else if let Some(alias) = topic_alias
990            && let Some(topic) = self.incoming_topic_aliases.get(&alias)
991        {
992            topic.clone_into(&mut publish.topic);
993        } else if topic_alias.is_some() {
994            return self.handle_protocol_error();
995        }
996
997        match qos {
998            QoS::AtMostOnce => Ok(None),
999            QoS::AtLeastOnce => {
1000                if !self.manual_acks {
1001                    let puback = PubAck::new(publish.pkid, None);
1002                    return Ok(Some(self.outgoing_puback(puback)));
1003                }
1004                Ok(None)
1005            }
1006            QoS::ExactlyOnce => {
1007                let pkid = publish.pkid;
1008                self.incoming_pub.insert(pkid as usize);
1009
1010                if !self.manual_acks {
1011                    let pubrec = PubRec::new(pkid, None);
1012                    return Ok(Some(self.outgoing_pubrec(pubrec)));
1013                }
1014                Ok(None)
1015            }
1016        }
1017    }
1018
1019    fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<Option<Packet>, StateError> {
1020        let publish = self
1021            .outgoing_pub
1022            .get_mut(puback.pkid as usize)
1023            .ok_or(StateError::Unsolicited(puback.pkid))?;
1024
1025        if publish.take().is_none() {
1026            error!("Unsolicited puback packet: {:?}", puback.pkid);
1027            return Err(StateError::Unsolicited(puback.pkid));
1028        }
1029        self.mark_outgoing_packet_id_complete(puback.pkid);
1030
1031        let notice = self.outgoing_pub_notice[puback.pkid as usize].take();
1032        self.inflight -= 1;
1033
1034        if puback.reason != PubAckReason::Success
1035            && puback.reason != PubAckReason::NoMatchingSubscribers
1036        {
1037            warn!(
1038                "PubAck Pkid = {:?}, reason: {:?}",
1039                puback.pkid, puback.reason
1040            );
1041        }
1042        if let Some(tx) = notice {
1043            tx.success(PublishResult::Qos1(puback.clone()));
1044        }
1045
1046        Ok(self.replay_collision_publish(puback.pkid))
1047    }
1048
1049    fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<Option<Packet>, StateError> {
1050        let publish = self
1051            .outgoing_pub
1052            .get_mut(pubrec.pkid as usize)
1053            .ok_or(StateError::Unsolicited(pubrec.pkid))?;
1054
1055        if publish.take().is_none() {
1056            error!("Unsolicited pubrec packet: {:?}", pubrec.pkid);
1057            return Err(StateError::Unsolicited(pubrec.pkid));
1058        }
1059
1060        let notice = self.outgoing_pub_notice[pubrec.pkid as usize].take();
1061        if pubrec.reason != PubRecReason::Success
1062            && pubrec.reason != PubRecReason::NoMatchingSubscribers
1063        {
1064            warn!(
1065                "PubRec Pkid = {:?}, reason: {:?}",
1066                pubrec.pkid, pubrec.reason
1067            );
1068            if let Some(tx) = notice {
1069                tx.success(PublishResult::Qos2PubRecRejected(pubrec.clone()));
1070            }
1071            self.mark_outgoing_packet_id_complete(pubrec.pkid);
1072            self.inflight -= 1;
1073            return Ok(self.replay_collision_publish(pubrec.pkid));
1074        }
1075
1076        // NOTE: Inflight - 1 for qos2 in comp
1077        self.outgoing_rel.insert(pubrec.pkid as usize);
1078        self.outgoing_rel_notice[pubrec.pkid as usize] = notice;
1079        let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
1080        self.events.push_back(event);
1081
1082        Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None))))
1083    }
1084
1085    fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
1086        if !self.incoming_pub.contains(pubrel.pkid as usize) {
1087            error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
1088            return Err(StateError::Unsolicited(pubrel.pkid));
1089        }
1090        self.incoming_pub.set(pubrel.pkid as usize, false);
1091
1092        if pubrel.reason != PubRelReason::Success {
1093            warn!(
1094                "PubRel Pkid = {:?}, reason: {:?}",
1095                pubrel.pkid, pubrel.reason
1096            );
1097            return Ok(None);
1098        }
1099
1100        let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
1101        self.events.push_back(event);
1102
1103        Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None))))
1104    }
1105
1106    fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
1107        if !self.outgoing_rel.contains(pubcomp.pkid as usize) {
1108            error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
1109            return Err(StateError::Unsolicited(pubcomp.pkid));
1110        }
1111        self.outgoing_rel.set(pubcomp.pkid as usize, false);
1112        let notice = self.outgoing_rel_notice[pubcomp.pkid as usize].take();
1113        self.mark_outgoing_packet_id_complete(pubcomp.pkid);
1114        self.inflight -= 1;
1115
1116        if pubcomp.reason != PubCompReason::Success {
1117            warn!(
1118                "PubComp Pkid = {:?}, reason: {:?}",
1119                pubcomp.pkid, pubcomp.reason
1120            );
1121        }
1122        if let Some(tx) = notice {
1123            tx.success(PublishResult::Qos2Completed(pubcomp.clone()));
1124        }
1125
1126        Ok(self.replay_collision_publish(pubcomp.pkid))
1127    }
1128
1129    fn replay_collision_publish(&mut self, pkid: u16) -> Option<Packet> {
1130        self.check_collision(pkid).map(|(publish, notice)| {
1131            let pkid = publish.pkid;
1132            let replay_publish = self.publish_for_replay_tracking(&publish);
1133            self.outgoing_pub[pkid as usize] = Some(replay_publish);
1134            self.outgoing_pub_notice[pkid as usize] = notice;
1135            self.inflight += 1;
1136            self.record_outgoing_topic_alias(&publish);
1137
1138            let event = Event::Outgoing(Outgoing::Publish(pkid));
1139            self.events.push_back(event);
1140            self.collision_ping_count = 0;
1141
1142            Packet::Publish(publish)
1143        })
1144    }
1145
1146    const fn handle_incoming_pingresp(&mut self) -> Option<Packet> {
1147        self.await_pingresp = false;
1148        None
1149    }
1150
    /// Applies an incoming AUTH packet to the authentication lifecycle and the
    /// registered authenticator.
    ///
    /// Lifecycle errors:
    /// * A protocol error (`StateError::Deserialization(ProtocolError)`) first
    ///   fails the active exchange — the authenticator's `failure` callback runs
    ///   and the lifecycle is reset — and is then returned.
    /// * Any other lifecycle error is returned unchanged.
    ///
    /// On `IncomingAuthEffect::Success`:
    /// * the authenticator's `success` callback runs first, if one is registered;
    /// * an error from it fails the exchange and is returned as
    ///   `StateError::AuthError`;
    /// * otherwise the exchange is completed and nothing is sent back.
    ///
    /// On `IncomingAuthEffect::Continue`:
    /// * an authenticator is required (`StateError::AuthenticatorNotSet` otherwise);
    /// * it is asked for the next step, using the method from the incoming AUTH
    ///   properties (empty string if absent);
    /// * the resulting client AUTH packet is returned for sending.
    ///
    /// The `lock().unwrap()` calls panic if the authenticator mutex is poisoned.
    fn handle_incoming_auth(&mut self, auth: &Auth) -> Result<Option<Packet>, StateError> {
        let effect = match self.auth.incoming_auth(auth, &mut self.events) {
            Ok(effect) => effect,
            Err(err @ StateError::Deserialization(mqttbytes::Error::ProtocolError)) => {
                // Tear down the exchange so the authenticator and any pending
                // notice learn about the failure before the error propagates.
                self.fail_auth_exchange(
                    AuthNoticeError::ProtocolError,
                    AuthError::Failed("authentication protocol error".to_owned()),
                );
                return Err(err);
            }
            Err(err) => return Err(err),
        };

        match effect {
            IncomingAuthEffect::Success { kind, method } => {
                // Without a registered authenticator the exchange completes directly.
                if let Some(authenticator) = self.authenticator.clone() {
                    let context = AuthContext {
                        kind,
                        method: &method,
                    };
                    let auth_result = authenticator
                        .lock()
                        .unwrap()
                        .success(context, auth.properties.clone());
                    if let Err(err) = auth_result {
                        return Err(self.fail_authenticator(&err));
                    }
                }
                self.auth.complete_success(kind, method, &mut self.events);
                Ok(None)
            }
            IncomingAuthEffect::Continue { kind } => {
                let authenticator = self
                    .authenticator
                    .clone()
                    .ok_or(StateError::AuthenticatorNotSet)?;
                // Method comes from the broker's AUTH properties; "" if absent.
                let method = auth
                    .properties
                    .as_ref()
                    .and_then(|props| props.method.as_deref())
                    .unwrap_or_default();
                let context = AuthContext { kind, method };
                let continue_result = authenticator
                    .lock()
                    .unwrap()
                    .continue_auth(context, auth.properties.clone());
                let action = match continue_result {
                    Ok(action) => action,
                    Err(err) => return Err(self.fail_authenticator(&err)),
                };

                // The authenticator's action must yield properties for the reply AUTH.
                let out_auth_props = match action.into_continue_properties() {
                    Ok(properties) => properties,
                    Err(err) => return Err(self.fail_authenticator(&err)),
                };
                let client_auth = self.auth.outgoing_continue(out_auth_props)?;
                Ok(Some(self.outgoing_auth_packet(client_auth)))
            }
        }
    }
1211
1212    fn fail_authenticator(&mut self, error: &AuthError) -> StateError {
1213        self.fail_auth_exchange(
1214            AuthNoticeError::AuthenticationFailed(error.to_string()),
1215            error.clone(),
1216        );
1217        StateError::AuthError(error.to_string())
1218    }
1219
1220    fn fail_auth_exchange(&mut self, notice_error: AuthNoticeError, callback_error: AuthError) {
1221        if let Some((kind, method)) = self.auth.active_exchange()
1222            && let Some(authenticator) = self.authenticator.clone()
1223        {
1224            authenticator.lock().unwrap().failure(
1225                AuthContext {
1226                    kind,
1227                    method: &method,
1228                },
1229                callback_error,
1230            );
1231        }
1232        self.auth.reset(notice_error, &mut self.events);
1233    }
1234
1235    /// Adds next packet identifier to `QoS` 1 and 2 publish packets and returns
1236    /// it by wrapping the publish in a packet.
1237    #[cfg(test)]
1238    fn outgoing_publish(&mut self, publish: Publish) -> Result<Option<Packet>, StateError> {
1239        let (packet, flush_notice) = self.outgoing_publish_with_notice(publish, None)?;
1240        if let Some(tx) = flush_notice {
1241            tx.success(PublishResult::Qos0Flushed);
1242        }
1243        Ok(packet)
1244    }
1245
1246    fn outgoing_publish_with_notice(
1247        &mut self,
1248        mut publish: Publish,
1249        notice: Option<PublishNoticeTx>,
1250    ) -> Result<(Option<Packet>, Option<PublishNoticeTx>), StateError> {
1251        let mut notice = notice;
1252        let auto_topic_alias_action = self.apply_auto_topic_alias(&mut publish);
1253        self.validate_outgoing_topic_alias(&publish)?;
1254
1255        if publish.qos != QoS::AtMostOnce {
1256            if publish.pkid == 0 {
1257                publish.pkid = self.next_pkid();
1258            }
1259
1260            let pkid = publish.pkid;
1261            if self
1262                .outgoing_pub
1263                .get(publish.pkid as usize)
1264                .ok_or(StateError::Unsolicited(publish.pkid))?
1265                .is_some()
1266            {
1267                info!("Collision on packet id = {:?}", publish.pkid);
1268                if let Some(action) = auto_topic_alias_action {
1269                    action.restore_for_replay(&mut publish);
1270                }
1271                self.collision = Some(publish);
1272                self.collision_notice = notice.take();
1273                let event = Event::Outgoing(Outgoing::AwaitAck(pkid));
1274                self.events.push_back(event);
1275                return Ok((None, None));
1276            }
1277
1278            // if there is an existing publish at this pkid, this implies that broker hasn't acked this
1279            // packet yet. This error is possible only when broker isn't acking sequentially
1280            let replay_publish = self.publish_for_replay_tracking(&publish);
1281            self.outgoing_pub[pkid as usize] = Some(replay_publish);
1282            self.outgoing_pub_notice[pkid as usize] = notice.take();
1283            self.outgoing_pub_ack.set(pkid as usize, false);
1284            self.inflight += 1;
1285        }
1286
1287        debug!(
1288            "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}",
1289            String::from_utf8_lossy(&publish.topic),
1290            publish.pkid,
1291            publish.payload.len()
1292        );
1293
1294        let pkid = publish.pkid;
1295        if let Some(pending_auto_topic_alias) = auto_topic_alias_action
1296            .as_ref()
1297            .and_then(AutoTopicAliasAction::pending_alias)
1298        {
1299            self.record_auto_topic_alias(pending_auto_topic_alias.clone());
1300        } else if let Some(alias) = auto_topic_alias_action
1301            .as_ref()
1302            .and_then(AutoTopicAliasAction::existing_alias)
1303        {
1304            self.record_auto_topic_alias_use(alias);
1305        }
1306        self.record_outgoing_topic_alias(&publish);
1307
1308        let event = Event::Outgoing(Outgoing::Publish(pkid));
1309        self.events.push_back(event);
1310
1311        if publish.qos == QoS::AtMostOnce {
1312            Ok((Some(Packet::Publish(publish)), notice.take()))
1313        } else {
1314            Ok((Some(Packet::Publish(publish)), None))
1315        }
1316    }
1317
1318    fn outgoing_pubrel_with_notice(
1319        &mut self,
1320        pubrel: PubRel,
1321        notice: Option<PublishNoticeTx>,
1322    ) -> (Option<Packet>, Option<PublishNoticeTx>) {
1323        let pubrel = self.save_pubrel_with_notice(pubrel, notice);
1324
1325        debug!("Pubrel. Pkid = {}", pubrel.pkid);
1326
1327        let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid));
1328        self.events.push_back(event);
1329
1330        (Some(Packet::PubRel(PubRel::new(pubrel.pkid, None))), None)
1331    }
1332
1333    fn outgoing_puback(&mut self, puback: PubAck) -> Packet {
1334        let pkid = puback.pkid;
1335        let event = Event::Outgoing(Outgoing::PubAck(pkid));
1336        self.events.push_back(event);
1337
1338        Packet::PubAck(puback)
1339    }
1340
1341    fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Packet {
1342        let pkid = pubrec.pkid;
1343        let event = Event::Outgoing(Outgoing::PubRec(pkid));
1344        self.events.push_back(event);
1345
1346        Packet::PubRec(pubrec)
1347    }
1348
1349    /// check when the last control packet/pingreq packet is received and return
1350    /// the status which tells if keep alive time has exceeded
1351    /// NOTE: status will be checked for zero keepalive times also
1352    fn outgoing_ping(&mut self) -> Result<Option<Packet>, StateError> {
1353        let elapsed_in = self.last_incoming.elapsed();
1354        let elapsed_out = self.last_outgoing.elapsed();
1355
1356        if self.collision.is_some() {
1357            self.collision_ping_count += 1;
1358            if self.collision_ping_count >= 2 {
1359                return Err(StateError::CollisionTimeout);
1360            }
1361        }
1362
1363        // raise error if last ping didn't receive ack
1364        if self.await_pingresp {
1365            return Err(StateError::AwaitPingResp);
1366        }
1367
1368        self.await_pingresp = true;
1369
1370        debug!(
1371            "Pingreq, last incoming packet before {elapsed_in:?}, last outgoing request before {elapsed_out:?}",
1372        );
1373
1374        let event = Event::Outgoing(Outgoing::PingReq);
1375        self.events.push_back(event);
1376
1377        Ok(Some(Packet::PingReq(PingReq)))
1378    }
1379
1380    fn outgoing_subscribe(
1381        &mut self,
1382        mut subscription: Subscribe,
1383        notice: Option<SubscribeNoticeTx>,
1384    ) -> Result<Option<Packet>, StateError> {
1385        if subscription.filters.is_empty() {
1386            return Err(StateError::EmptySubscription);
1387        }
1388
1389        let pkid = self.next_control_pkid()?;
1390        subscription.pkid = pkid;
1391
1392        debug!(
1393            "Subscribe. Topics = {:?}, Pkid = {:?}",
1394            subscription.filters, subscription.pkid
1395        );
1396
1397        let pkid = subscription.pkid;
1398        let event = Event::Outgoing(Outgoing::Subscribe(pkid));
1399        self.events.push_back(event);
1400        if let Some(notice) = notice {
1401            self.tracked_subscribe
1402                .insert(subscription.pkid, (subscription.clone(), notice));
1403        }
1404
1405        Ok(Some(Packet::Subscribe(subscription)))
1406    }
1407
1408    fn outgoing_unsubscribe(
1409        &mut self,
1410        mut unsub: Unsubscribe,
1411        notice: Option<UnsubscribeNoticeTx>,
1412    ) -> Result<Packet, StateError> {
1413        let pkid = self.next_control_pkid()?;
1414        unsub.pkid = pkid;
1415
1416        debug!(
1417            "Unsubscribe. Topics = {:?}, Pkid = {:?}",
1418            unsub.filters, unsub.pkid
1419        );
1420
1421        let pkid = unsub.pkid;
1422        let event = Event::Outgoing(Outgoing::Unsubscribe(pkid));
1423        self.events.push_back(event);
1424        if let Some(notice) = notice {
1425            self.tracked_unsubscribe
1426                .insert(unsub.pkid, (unsub.clone(), notice));
1427        }
1428
1429        Ok(Packet::Unsubscribe(unsub))
1430    }
1431
1432    fn outgoing_disconnect(&mut self, disconnect: Disconnect) -> Packet {
1433        self.fail_auth_exchange_due_to_client_disconnect();
1434        let reason = disconnect.reason_code;
1435        debug!("Disconnect with {reason:?}");
1436        let event = Event::Outgoing(Outgoing::Disconnect);
1437        self.events.push_back(event);
1438
1439        Packet::Disconnect(disconnect)
1440    }
1441
    /// Starts a client-initiated re-authentication (`AuthExchangeKind::Reauthentication`).
    ///
    /// Steps:
    /// 1. The method is taken from `self.auth.reauth_method()`.
    /// 2. If an authenticator is registered, its `start` callback is invoked.
    ///    Properties it returns are used only when the caller's `auth` has none.
    /// 3. The exchange is registered with `begin_reauth` (which receives
    ///    `notice`), and the resulting AUTH packet is returned for sending.
    ///
    /// # Errors
    ///
    /// If no re-auth method is available, or `start` fails, `notice` (when
    /// present) is resolved with the error and `StateError::AuthError` is
    /// returned. Errors from `begin_reauth` are propagated.
    fn outgoing_auth(
        &mut self,
        mut auth: Auth,
        mut notice: Option<crate::notice::AuthNoticeTx>,
    ) -> Result<Packet, StateError> {
        let method = match self.auth.reauth_method() {
            Ok(method) => method.to_owned(),
            Err(err) => {
                if let Some(notice) = notice.take() {
                    notice.error(err.clone());
                }
                return Err(StateError::AuthError(err.to_string()));
            }
        };

        if let Some(authenticator) = self.authenticator.clone() {
            let context = AuthContext {
                kind: AuthExchangeKind::Reauthentication,
                method: &method,
            };
            // Panics if the authenticator mutex is poisoned.
            let start_result = authenticator.lock().unwrap().start(context);
            match start_result {
                // Caller-supplied properties take precedence over the authenticator's.
                Ok(Some(properties)) if auth.properties.is_none() => {
                    auth.properties = Some(properties);
                }
                Ok(_) => {}
                Err(err) => {
                    if let Some(notice) = notice.take() {
                        notice.error(AuthNoticeError::AuthenticationFailed(err.to_string()));
                    }
                    return Err(StateError::AuthError(err.to_string()));
                }
            }
        }
        let auth = self
            .auth
            .begin_reauth(auth.properties, notice, &mut self.events)?;
        Ok(self.outgoing_auth_packet(auth))
    }
1481
1482    fn outgoing_auth_packet(&mut self, auth: Auth) -> Packet {
1483        let props = auth
1484            .properties
1485            .as_ref()
1486            .expect("AUTH packets created by state always contain properties");
1487        debug!(
1488            "Auth packet sent. Auth Method: {:?}. Auth Data: {:?}",
1489            props.method, props.data
1490        );
1491        let event = Event::Outgoing(Outgoing::Auth);
1492        self.events.push_back(event);
1493        Packet::Auth(auth)
1494    }
1495
1496    fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option<PublishNoticeTx>)> {
1497        if let Some(publish) = &self.collision
1498            && publish.pkid == pkid
1499        {
1500            return self
1501                .collision
1502                .take()
1503                .map(|publish| (publish, self.collision_notice.take()));
1504        }
1505
1506        None
1507    }
1508
1509    fn save_pubrel_with_notice(
1510        &mut self,
1511        mut pubrel: PubRel,
1512        notice: Option<PublishNoticeTx>,
1513    ) -> PubRel {
1514        let pubrel = match pubrel.pkid {
1515            // consider PacketIdentifier(0) as uninitialized packets
1516            0 => {
1517                pubrel.pkid = self.next_pkid();
1518                pubrel
1519            }
1520            _ => pubrel,
1521        };
1522
1523        self.outgoing_rel.insert(pubrel.pkid as usize);
1524        self.outgoing_rel_notice[pubrel.pkid as usize] = notice;
1525        self.inflight += 1;
1526        pubrel
1527    }
1528
1529    fn mark_outgoing_packet_id_complete(&mut self, pkid: u16) {
1530        self.outgoing_pub_ack.set(pkid as usize, true);
1531        self.advance_last_puback_frontier();
1532    }
1533
1534    fn advance_last_puback_frontier(&mut self) {
1535        let mut next = self.next_puback_boundary_pkid(self.last_puback);
1536        while next != 0 && self.outgoing_pub_ack.contains(next as usize) {
1537            self.outgoing_pub_ack.set(next as usize, false);
1538            self.last_puback = next;
1539            next = self.next_puback_boundary_pkid(self.last_puback);
1540        }
1541    }
1542
1543    const fn next_puback_boundary_pkid(&self, pkid: u16) -> u16 {
1544        if self.max_outgoing_inflight == 0 {
1545            return 0;
1546        }
1547
1548        if pkid >= self.max_outgoing_inflight {
1549            1
1550        } else {
1551            pkid + 1
1552        }
1553    }
1554
1555    /// <http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation>
1556    /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that.
1557    ///
1558    fn next_publish_pkid(&self) -> Option<u16> {
1559        let mut pkid = self.next_publish_pkid_after(self.last_pkid);
1560        for _ in 0..usize::from(self.max_outgoing_inflight) {
1561            if !self.packet_identifier_in_use(pkid) {
1562                return Some(pkid);
1563            }
1564            pkid = self.next_publish_pkid_after(pkid);
1565        }
1566
1567        None
1568    }
1569
1570    fn next_pkid(&mut self) -> u16 {
1571        let next_pkid = self
1572            .next_publish_pkid()
1573            .unwrap_or_else(|| self.next_publish_pkid_after(self.last_pkid));
1574
1575        // When next packet id is at the edge of inflight queue,
1576        // set await flag. This instructs eventloop to stop
1577        // processing requests until all the inflight publishes
1578        // are acked
1579        if next_pkid == self.max_outgoing_inflight {
1580            self.last_pkid = 0;
1581            return next_pkid;
1582        }
1583
1584        self.last_pkid = next_pkid;
1585        next_pkid
1586    }
1587
1588    fn next_control_pkid(&mut self) -> Result<u16, StateError> {
1589        for offset in 1..=u16::MAX {
1590            let pkid = self.last_pkid.wrapping_add(offset);
1591            if pkid != 0 && !self.packet_identifier_in_use(pkid) {
1592                self.last_pkid = pkid;
1593                return Ok(pkid);
1594            }
1595        }
1596
1597        Err(StateError::InvalidState)
1598    }
1599
1600    fn publish_topic_alias(publish: &Publish) -> Option<u16> {
1601        publish
1602            .properties
1603            .as_ref()
1604            .and_then(|props| props.topic_alias)
1605    }
1606
1607    fn set_publish_topic_alias(publish: &mut Publish, alias: u16) {
1608        publish
1609            .properties
1610            .get_or_insert_with(PublishProperties::default)
1611            .topic_alias = Some(alias);
1612    }
1613
1614    fn apply_auto_topic_alias(&self, publish: &mut Publish) -> Option<AutoTopicAliasAction> {
1615        if !self.auto_topic_aliases
1616            || self.broker_topic_alias_max == 0
1617            || publish.topic.is_empty()
1618            || Self::publish_topic_alias(publish).is_some()
1619        {
1620            return None;
1621        }
1622
1623        if let Some(alias) = self
1624            .auto_outgoing_topic_aliases
1625            .get(&publish.topic)
1626            .copied()
1627        {
1628            let original_topic = publish.topic.clone();
1629            Self::set_publish_topic_alias(publish, alias);
1630            publish.topic = Bytes::new();
1631            return Some(AutoTopicAliasAction::Existing {
1632                original_topic,
1633                alias,
1634            });
1635        }
1636
1637        let (alias, previous_topic) = self.next_auto_topic_alias_assignment()?;
1638
1639        let pending = PendingAutoTopicAlias {
1640            topic: publish.topic.clone(),
1641            alias,
1642            previous_topic,
1643        };
1644        Self::set_publish_topic_alias(publish, alias);
1645        Some(AutoTopicAliasAction::New(pending))
1646    }
1647
1648    fn next_auto_topic_alias_assignment(&self) -> Option<(u16, Option<Bytes>)> {
1649        if let Some(alias) = self.next_available_auto_topic_alias() {
1650            return Some((alias, None));
1651        }
1652
1653        if self.auto_topic_alias_policy != TopicAliasPolicy::Lru {
1654            return None;
1655        }
1656
1657        self.least_recent_auto_topic_alias()
1658    }
1659
1660    fn next_available_auto_topic_alias(&self) -> Option<u16> {
1661        let next_auto_topic_alias = self.next_auto_topic_alias?;
1662        (next_auto_topic_alias..=self.broker_topic_alias_max)
1663            .find(|&alias| !self.outgoing_topic_aliases.contains_key(&alias))
1664    }
1665
1666    fn record_auto_topic_alias(&mut self, pending: PendingAutoTopicAlias) {
1667        if let Some(previous_topic) = pending.previous_topic {
1668            self.auto_outgoing_topic_aliases.remove(&previous_topic);
1669        }
1670        self.auto_outgoing_topic_aliases
1671            .insert(pending.topic.clone(), pending.alias);
1672        self.outgoing_topic_aliases
1673            .insert(pending.alias, pending.topic.clone());
1674        self.record_auto_topic_alias_use(pending.alias);
1675        self.advance_next_auto_topic_alias();
1676    }
1677
1678    fn record_auto_topic_alias_use(&mut self, alias: u16) {
1679        if self.auto_topic_alias_policy != TopicAliasPolicy::Lru {
1680            return;
1681        }
1682
1683        self.auto_topic_alias_lru.retain(|entry| *entry != alias);
1684        self.auto_topic_alias_lru.push_back(alias);
1685    }
1686
1687    fn least_recent_auto_topic_alias(&self) -> Option<(u16, Option<Bytes>)> {
1688        for &alias in &self.auto_topic_alias_lru {
1689            if alias == 0 || alias > self.broker_topic_alias_max {
1690                continue;
1691            }
1692
1693            let Some(topic) = self.outgoing_topic_aliases.get(&alias) else {
1694                continue;
1695            };
1696
1697            if self
1698                .auto_outgoing_topic_aliases
1699                .get(topic)
1700                .is_some_and(|mapped_alias| *mapped_alias == alias)
1701            {
1702                return Some((alias, Some(topic.clone())));
1703            }
1704        }
1705
1706        None
1707    }
1708
1709    fn advance_next_auto_topic_alias(&mut self) {
1710        let Some(mut alias) = self.next_auto_topic_alias else {
1711            return;
1712        };
1713
1714        while alias <= self.broker_topic_alias_max
1715            && self.outgoing_topic_aliases.contains_key(&alias)
1716        {
1717            let Some(next_alias) = alias.checked_add(1) else {
1718                self.next_auto_topic_alias = None;
1719                return;
1720            };
1721            alias = next_alias;
1722        }
1723
1724        self.next_auto_topic_alias = (alias <= self.broker_topic_alias_max).then_some(alias);
1725    }
1726
1727    const fn strip_publish_topic_alias(publish: &mut Publish) {
1728        if let Some(props) = &mut publish.properties {
1729            props.topic_alias = None;
1730        }
1731    }
1732
1733    fn publish_for_replay_tracking(&self, publish: &Publish) -> Publish {
1734        let mut replay_publish = publish.clone();
1735        if replay_publish.topic.is_empty()
1736            && let Some(alias) = Self::publish_topic_alias(&replay_publish)
1737            && let Some(topic) = self.outgoing_topic_aliases.get(&alias)
1738        {
1739            topic.clone_into(&mut replay_publish.topic);
1740        }
1741
1742        replay_publish
1743    }
1744
1745    fn validate_outgoing_topic_alias(&self, publish: &Publish) -> Result<(), StateError> {
1746        if let Some(alias) = Self::publish_topic_alias(publish)
1747            && alias > self.broker_topic_alias_max
1748        {
1749            // We MUST NOT send a Topic Alias that is greater than the
1750            // broker's Topic Alias Maximum.
1751            return Err(StateError::InvalidAlias {
1752                alias,
1753                max: self.broker_topic_alias_max,
1754            });
1755        }
1756
1757        Ok(())
1758    }
1759
1760    fn record_outgoing_topic_alias(&mut self, publish: &Publish) {
1761        if !publish.topic.is_empty()
1762            && let Some(alias) = Self::publish_topic_alias(publish)
1763        {
1764            if let Some(previous_topic) = self
1765                .outgoing_topic_aliases
1766                .insert(alias, publish.topic.clone())
1767                && previous_topic != publish.topic
1768            {
1769                self.auto_outgoing_topic_aliases.remove(&previous_topic);
1770                self.auto_topic_alias_lru.retain(|entry| *entry != alias);
1771            }
1772            self.auto_outgoing_topic_aliases
1773                .retain(|topic, mapped_alias| *mapped_alias != alias || topic == &publish.topic);
1774        }
1775    }
1776}
1777
1778impl Clone for MqttState {
1779    fn clone(&self) -> Self {
1780        Self {
1781            await_pingresp: self.await_pingresp,
1782            collision_ping_count: self.collision_ping_count,
1783            last_incoming: self.last_incoming,
1784            last_outgoing: self.last_outgoing,
1785            last_pkid: self.last_pkid,
1786            last_puback: self.last_puback,
1787            inflight: self.inflight,
1788            outgoing_pub: self.outgoing_pub.clone(),
1789            outgoing_pub_notice: Self::new_notice_slots_with_len(self.outgoing_pub.len()),
1790            outgoing_pub_ack: self.outgoing_pub_ack.clone(),
1791            outgoing_rel: self.outgoing_rel.clone(),
1792            outgoing_rel_notice: Self::new_notice_slots_with_len(self.outgoing_rel_notice.len()),
1793            incoming_pub: self.incoming_pub.clone(),
1794            collision: self.collision.clone(),
1795            collision_notice: None,
1796            tracked_subscribe: BTreeMap::new(),
1797            tracked_unsubscribe: BTreeMap::new(),
1798            events: self.events.clone(),
1799            manual_acks: self.manual_acks,
1800            incoming_topic_aliases: self.incoming_topic_aliases.clone(),
1801            outgoing_topic_aliases: self.outgoing_topic_aliases.clone(),
1802            auto_outgoing_topic_aliases: self.auto_outgoing_topic_aliases.clone(),
1803            next_auto_topic_alias: self.next_auto_topic_alias,
1804            auto_topic_aliases: self.auto_topic_aliases,
1805            auto_topic_alias_policy: self.auto_topic_alias_policy,
1806            auto_topic_alias_lru: self.auto_topic_alias_lru.clone(),
1807            broker_topic_alias_max: self.broker_topic_alias_max,
1808            max_outgoing_inflight: self.max_outgoing_inflight,
1809            max_outgoing_inflight_upper_limit: self.max_outgoing_inflight_upper_limit,
1810            authenticator: self.authenticator.clone(),
1811            auth: self.auth.clone(),
1812        }
1813    }
1814}
1815
1816#[cfg(test)]
1817mod test {
1818    use super::mqttbytes::v5::*;
1819    use super::mqttbytes::*;
1820    use super::{Event, Incoming, Outgoing, Request};
1821    use super::{MqttState, StateError};
1822    use crate::notice::{
1823        AuthNoticeError, AuthNoticeTx, PublishNotice, PublishNoticeError, PublishNoticeTx,
1824        PublishResult, SubscribeNoticeError, SubscribeNoticeTx, UnsubscribeNoticeError,
1825        UnsubscribeNoticeTx,
1826    };
1827    use crate::{NoticeFailureReason, TopicAliasPolicy};
1828    use bytes::Bytes;
1829    use std::collections::{HashMap, VecDeque};
1830    use std::sync::{Arc, Mutex};
1831
    /// Placeholder authentication method name for tests in this module.
    const AUTH_METHOD: &str = "test-method";
1833
1834    fn build_outgoing_publish(qos: QoS) -> Publish {
1835        let topic = "hello/world".to_owned();
1836        let payload = vec![1, 2, 3];
1837
1838        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
1839        publish.qos = qos;
1840        publish
1841    }
1842
1843    fn publish_properties_with_alias(alias: u16) -> PublishProperties {
1844        PublishProperties {
1845            topic_alias: Some(alias),
1846            ..Default::default()
1847        }
1848    }
1849
1850    fn build_outgoing_publish_with_alias(topic: &str, qos: QoS, alias: u16) -> Publish {
1851        Publish::new(
1852            topic,
1853            qos,
1854            vec![1, 2, 3],
1855            Some(publish_properties_with_alias(alias)),
1856        )
1857    }
1858
1859    fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
1860        let topic = "hello/world".to_owned();
1861        let payload = vec![1, 2, 3];
1862
1863        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
1864        publish.pkid = pkid;
1865        publish.qos = qos;
1866        publish
1867    }
1868
1869    fn build_mqttstate() -> MqttState {
1870        MqttState::builder(u16::MAX).build()
1871    }
1872
1873    fn build_lru_auto_alias_mqttstate(max_inflight: u16, broker_topic_alias_max: u16) -> MqttState {
1874        let mut mqtt = MqttState::builder(max_inflight)
1875            .auto_topic_aliases(true)
1876            .topic_alias_policy(TopicAliasPolicy::Lru)
1877            .build();
1878        mqtt.broker_topic_alias_max = broker_topic_alias_max;
1879        mqtt
1880    }
1881
1882    fn assert_publish(packet: Packet, topic: &'static [u8], alias: Option<u16>) {
1883        match packet {
1884            Packet::Publish(publish) => {
1885                assert_eq!(publish.topic, Bytes::from_static(topic));
1886                assert_eq!(
1887                    publish
1888                        .properties
1889                        .as_ref()
1890                        .and_then(|props| props.topic_alias),
1891                    alias
1892                );
1893            }
1894            packet => panic!("expected publish, got {packet:?}"),
1895        }
1896    }
1897
1898    fn build_auth_mqttstate(authentication_method: Option<&str>) -> MqttState {
1899        MqttState::builder(10)
1900            .authentication_method(authentication_method.map(str::to_owned))
1901            .build()
1902    }
1903
1904    fn auth_properties(authentication_method: Option<&str>) -> AuthProperties {
1905        AuthProperties {
1906            method: authentication_method.map(str::to_owned),
1907            data: Some(Bytes::from_static(b"auth-data")),
1908            reason: None,
1909            user_properties: vec![],
1910        }
1911    }
1912
    /// Test authenticator whose `continue_auth` replays a canned response.
    ///
    /// - `Ok(Some(props))` maps to `AuthAction::Send(props)`.
    /// - `Ok(None)` maps to `AuthAction::Complete`.
    /// - `Err(msg)` becomes an `AuthError` built from `msg`.
    #[derive(Debug)]
    struct StaticAuthManager {
        /// Canned result handed back from every `continue_auth` call.
        response: Result<Option<AuthProperties>, String>,
    }
1917
1918    impl crate::Authenticator for StaticAuthManager {
1919        fn start(
1920            &mut self,
1921            _context: crate::AuthContext<'_>,
1922        ) -> Result<Option<AuthProperties>, crate::AuthError> {
1923            Ok(None)
1924        }
1925
1926        fn continue_auth(
1927            &mut self,
1928            _context: crate::AuthContext<'_>,
1929            _auth_prop: Option<AuthProperties>,
1930        ) -> Result<crate::AuthAction, crate::AuthError> {
1931            self.response
1932                .clone()
1933                .map(|props| props.map_or(crate::AuthAction::Complete, crate::AuthAction::Send))
1934                .map_err(crate::AuthError::from)
1935        }
1936
1937        fn success(
1938            &mut self,
1939            _context: crate::AuthContext<'_>,
1940            _incoming: Option<AuthProperties>,
1941        ) -> Result<(), crate::AuthError> {
1942            Ok(())
1943        }
1944
1945        fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
1946    }
1947
    /// Test authenticator whose `start` returns a preset initial AUTH payload
    /// and whose `continue_auth` always completes.
    #[derive(Debug)]
    struct StartAuthManager {
        /// Value handed back (cloned) from every `start` call.
        response: Option<AuthProperties>,
    }
1952
1953    impl crate::Authenticator for StartAuthManager {
1954        fn start(
1955            &mut self,
1956            _context: crate::AuthContext<'_>,
1957        ) -> Result<Option<AuthProperties>, crate::AuthError> {
1958            Ok(self.response.clone())
1959        }
1960
1961        fn continue_auth(
1962            &mut self,
1963            _context: crate::AuthContext<'_>,
1964            _auth_prop: Option<AuthProperties>,
1965        ) -> Result<crate::AuthAction, crate::AuthError> {
1966            Ok(crate::AuthAction::Complete)
1967        }
1968
1969        fn success(
1970            &mut self,
1971            _context: crate::AuthContext<'_>,
1972            _incoming: Option<AuthProperties>,
1973        ) -> Result<(), crate::AuthError> {
1974            Ok(())
1975        }
1976
1977        fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
1978    }
1979
    /// Test authenticator whose `start` always fails with `"start failed"`.
    #[derive(Debug)]
    struct FailingStartAuthManager;
1982
1983    impl crate::Authenticator for FailingStartAuthManager {
1984        fn start(
1985            &mut self,
1986            _context: crate::AuthContext<'_>,
1987        ) -> Result<Option<AuthProperties>, crate::AuthError> {
1988            Err(crate::AuthError::from("start failed"))
1989        }
1990
1991        fn continue_auth(
1992            &mut self,
1993            _context: crate::AuthContext<'_>,
1994            _auth_prop: Option<AuthProperties>,
1995        ) -> Result<crate::AuthAction, crate::AuthError> {
1996            Ok(crate::AuthAction::Complete)
1997        }
1998
1999        fn success(
2000            &mut self,
2001            _context: crate::AuthContext<'_>,
2002            _incoming: Option<AuthProperties>,
2003        ) -> Result<(), crate::AuthError> {
2004            Ok(())
2005        }
2006
2007        fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
2008    }
2009
2010    fn queue_publish_with_notice(mqtt: &mut MqttState, publish: Publish) -> PublishNotice {
2011        let (tx, notice) = PublishNoticeTx::new();
2012        let (packet, flush_notice) = mqtt
2013            .outgoing_publish_with_notice(publish, Some(tx))
2014            .unwrap();
2015        assert!(packet.is_some());
2016        assert!(flush_notice.is_none());
2017        notice
2018    }
2019
2020    #[test]
2021    fn new_state_preallocates_event_queue_for_read_batch_bursts() {
2022        let mqtt = MqttState::builder(10).build();
2023        assert!(mqtt.events.capacity() >= MqttState::initial_events_capacity());
2024    }
2025
2026    #[test]
2027    fn clean_pending_capacity_counts_publish_rel_and_tracked_requests() {
2028        let mut mqtt = MqttState::builder(10).build();
2029        mqtt.outgoing_pub[1] = Some(build_outgoing_publish(QoS::AtLeastOnce));
2030        mqtt.outgoing_pub[2] = Some(build_outgoing_publish(QoS::ExactlyOnce));
2031        mqtt.outgoing_rel.insert(3);
2032        mqtt.outgoing_rel.insert(4);
2033
2034        let filter = Filter::new("a/b", QoS::AtMostOnce);
2035        let (sub_notice, _) = SubscribeNoticeTx::new();
2036        mqtt.tracked_subscribe
2037            .insert(5, (Subscribe::new(filter, None), sub_notice));
2038
2039        let (unsub_notice, _) = UnsubscribeNoticeTx::new();
2040        mqtt.tracked_unsubscribe
2041            .insert(6, (Unsubscribe::new("a/b", None), unsub_notice));
2042
2043        assert_eq!(mqtt.clean_pending_capacity(), 6);
2044    }
2045
    /// One tracked SUBSCRIBE (pkid 5) and one tracked UNSUBSCRIBE (pkid 6) are
    /// reflected by the length helpers, and draining them empties both maps.
    #[test]
    fn tracked_request_len_helpers_report_counts() {
        let mut mqtt = MqttState::builder(10).build();
        let filter = Filter::new("a/b", QoS::AtMostOnce);
        let (sub_notice, _) = SubscribeNoticeTx::new();
        mqtt.tracked_subscribe
            .insert(5, (Subscribe::new(filter, None), sub_notice));
        let (unsub_notice, _) = UnsubscribeNoticeTx::new();
        mqtt.tracked_unsubscribe
            .insert(6, (Unsubscribe::new("a/b", None), unsub_notice));

        assert_eq!(mqtt.tracked_subscribe_len(), 1);
        assert_eq!(mqtt.tracked_unsubscribe_len(), 1);
        assert!(!mqtt.tracked_requests_is_empty());

        // Draining fails every tracked request and leaves nothing behind.
        mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
        assert!(mqtt.tracked_requests_is_empty());
    }
2064
    /// Draining with `SessionReset` returns the number of drained requests and
    /// resolves each waiting notice with the matching `SessionReset` error.
    #[test]
    fn drain_tracked_requests_as_failed_reports_session_reset_and_returns_count() {
        let mut mqtt = MqttState::builder(10).build();
        let filter = Filter::new("a/b", QoS::AtMostOnce);
        let (sub_notice_tx, sub_notice) = SubscribeNoticeTx::new();
        mqtt.tracked_subscribe
            .insert(5, (Subscribe::new(filter, None), sub_notice_tx));
        let (unsub_notice_tx, unsub_notice) = UnsubscribeNoticeTx::new();
        mqtt.tracked_unsubscribe
            .insert(6, (Unsubscribe::new("a/b", None), unsub_notice_tx));

        let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);

        // One subscribe plus one unsubscribe were pending.
        assert_eq!(drained, 2);
        assert!(mqtt.tracked_requests_is_empty());
        assert_eq!(
            sub_notice.wait().unwrap_err(),
            SubscribeNoticeError::SessionReset
        );
        assert_eq!(
            unsub_notice.wait().unwrap_err(),
            UnsubscribeNoticeError::SessionReset
        );
    }
2089
2090    #[test]
2091    fn drain_tracked_requests_as_failed_is_noop_when_empty() {
2092        let mut mqtt = MqttState::builder(10).build();
2093        let drained = mqtt.drain_tracked_requests_as_failed(NoticeFailureReason::SessionReset);
2094
2095        assert_eq!(drained, 0);
2096        assert!(mqtt.tracked_requests_is_empty());
2097    }
2098
    /// A PUBACK for a tracked QoS 1 publish does the following:
    /// - resolves the publish notice with the full PUBACK (reason code and
    ///   properties);
    /// - still emits the PUBACK as an incoming event;
    /// - makes `handle_incoming_packet` return no outgoing packet.
    #[test]
    fn tracked_puback_returns_ack_and_preserves_incoming_event() {
        let mut mqtt = build_mqttstate();
        let notice = queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::AtLeastOnce));
        // Discard events produced while queueing so only the PUBACK event remains.
        mqtt.events.clear();

        // The first QoS 1 publish on a fresh state is expected to get pkid 1.
        let mut puback = PubAck::new(1, None);
        puback.reason = PubAckReason::NoMatchingSubscribers;
        puback.properties = Some(PubAckProperties {
            reason_string: Some("accepted without subscribers".to_owned()),
            user_properties: vec![("k".to_owned(), "v".to_owned())],
        });
        assert!(
            mqtt.handle_incoming_packet(Incoming::PubAck(puback.clone()))
                .unwrap()
                .is_none()
        );

        assert_eq!(notice.wait(), Ok(PublishResult::Qos1(puback.clone())));
        assert_eq!(
            mqtt.events.pop_front(),
            Some(Event::Incoming(Packet::PubAck(puback)))
        );
    }
2123
    /// A SUBACK for a tracked SUBSCRIBE resolves the subscribe notice with the
    /// full SUBACK, including properties, and still emits it as an incoming
    /// event.
    #[test]
    fn tracked_suback_returns_ack_with_properties_and_preserves_incoming_event() {
        let mut mqtt = build_mqttstate();
        let (tx, notice) = SubscribeNoticeTx::new();
        mqtt.outgoing_subscribe(
            Subscribe::new(Filter::new("a/b", QoS::AtMostOnce), None),
            Some(tx),
        )
        .unwrap();
        // Discard the outgoing-subscribe event so only the SUBACK event remains.
        mqtt.events.clear();

        // pkid 1 is the first control packet id on a fresh state.
        let suback = SubAck {
            pkid: 1,
            return_codes: vec![SubscribeReasonCode::Unspecified],
            properties: Some(SubAckProperties {
                reason_string: Some("denied".to_owned()),
                user_properties: vec![("scope".to_owned(), "missing".to_owned())],
            }),
        };
        assert!(
            mqtt.handle_incoming_packet(Incoming::SubAck(suback.clone()))
                .unwrap()
                .is_none()
        );

        assert_eq!(notice.wait(), Ok(suback.clone()));
        assert_eq!(
            mqtt.events.pop_front(),
            Some(Event::Incoming(Packet::SubAck(suback)))
        );
    }
2155
    /// An UNSUBACK for a tracked UNSUBSCRIBE resolves the unsubscribe notice
    /// with the full UNSUBACK, including properties, and still emits it as an
    /// incoming event.
    #[test]
    fn tracked_unsuback_returns_ack_with_properties_and_preserves_incoming_event() {
        let mut mqtt = build_mqttstate();
        let (tx, notice) = UnsubscribeNoticeTx::new();
        mqtt.outgoing_unsubscribe(Unsubscribe::new("a/b", None), Some(tx))
            .unwrap();
        // Discard the outgoing-unsubscribe event so only the UNSUBACK event remains.
        mqtt.events.clear();

        // pkid 1 is the first control packet id on a fresh state.
        let unsuback = UnsubAck {
            pkid: 1,
            reasons: vec![UnsubAckReason::UnspecifiedError],
            properties: Some(UnsubAckProperties {
                reason_string: Some("failed".to_owned()),
                user_properties: vec![("detail".to_owned(), "x".to_owned())],
            }),
        };
        assert!(
            mqtt.handle_incoming_packet(Incoming::UnsubAck(unsuback.clone()))
                .unwrap()
                .is_none()
        );

        assert_eq!(notice.wait(), Ok(unsuback.clone()));
        assert_eq!(
            mqtt.events.pop_front(),
            Some(Event::Incoming(Packet::UnsubAck(unsuback)))
        );
    }
2184
2185    fn build_connack_with_receive_max(receive_max: u16) -> ConnAck {
2186        ConnAck {
2187            session_present: false,
2188            code: ConnectReturnCode::Success,
2189            properties: Some(ConnAckProperties {
2190                session_expiry_interval: None,
2191                receive_max: Some(receive_max),
2192                max_qos: None,
2193                retain_available: None,
2194                max_packet_size: None,
2195                assigned_client_identifier: None,
2196                topic_alias_max: None,
2197                reason_string: None,
2198                user_properties: vec![],
2199                wildcard_subscription_available: None,
2200                subscription_identifiers_available: None,
2201                shared_subscription_available: None,
2202                server_keep_alive: None,
2203                response_information: None,
2204                server_reference: None,
2205                authentication_method: None,
2206                authentication_data: None,
2207            }),
2208        }
2209    }
2210
    /// After a small Receive Maximum shrinks the tracking vectors, a later
    /// CONNACK with a larger Receive Maximum grows every tracking structure
    /// again. The asserted lengths are `receive_max + 1` in each case.
    #[test]
    fn connack_receive_max_can_grow_tracking_capacity_after_previous_shrink() {
        let mut mqtt = MqttState::builder(10).build();
        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(4)))
            .unwrap();
        // Shrinking only takes effect once tracking is reconciled.
        mqtt.reconcile_outgoing_tracking_capacity(true);
        assert_eq!(mqtt.outgoing_pub.len(), 5);

        // Growing happens directly on CONNACK handling.
        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(9)))
            .unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 10);
        assert_eq!(mqtt.outgoing_pub_notice.len(), 10);
        assert_eq!(mqtt.outgoing_rel_notice.len(), 10);
        assert_eq!(mqtt.outgoing_pub_ack.len(), 10);
        assert_eq!(mqtt.outgoing_rel.len(), 10);
    }
2227
    /// With nothing in flight, a smaller Receive Maximum does not shrink the
    /// tracking vectors on CONNACK itself. Reconciliation then does the
    /// following:
    /// - shrinks every tracking structure;
    /// - resets the `last_pkid` / `last_puback` cursors to 0.
    #[test]
    fn connack_receive_max_shrinks_when_tracking_is_empty_and_pending_is_empty() {
        let mut mqtt = MqttState::builder(10).build();
        // Cursors that would be out of range after the shrink.
        mqtt.last_pkid = 9;
        mqtt.last_puback = 8;

        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
            .unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 11);

        mqtt.reconcile_outgoing_tracking_capacity(true);
        assert_eq!(mqtt.outgoing_pub.len(), 4);
        assert_eq!(mqtt.outgoing_pub_notice.len(), 4);
        assert_eq!(mqtt.outgoing_rel_notice.len(), 4);
        assert_eq!(mqtt.outgoing_pub_ack.len(), 4);
        assert_eq!(mqtt.outgoing_rel.len(), 4);
        assert_eq!(mqtt.last_pkid, 0);
        assert_eq!(mqtt.last_puback, 0);
    }
2247
    /// A CONNACK without a Topic Alias Maximum does the following:
    /// - resets `broker_topic_alias_max` to 0;
    /// - discards aliases learned on the previous connection.
    ///
    /// As a result, replaying an alias-only publish (empty topic, alias 2)
    /// fails with `TopicAliasReplayUnavailable(2)`.
    #[test]
    fn connack_resets_connection_scoped_alias_state_when_topic_alias_maximum_is_omitted() {
        let mut mqtt = MqttState::builder(10).build();
        mqtt.broker_topic_alias_max = 10;
        // Establish alias 2 -> "hello/replay" on the current connection.
        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
            "hello/replay",
            QoS::AtMostOnce,
            2,
        ))
        .unwrap();

        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(5)))
            .unwrap();

        assert_eq!(mqtt.broker_topic_alias_max, 0);
        let mut replay = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2);
        let mut replay_topic_aliases = mqtt.replay_topic_aliases();
        assert_eq!(
            MqttState::prepare_publish_for_replay_with_aliases(
                &mut replay,
                &mut replay_topic_aliases
            )
            .unwrap_err(),
            PublishNoticeError::TopicAliasReplayUnavailable(2)
        );
    }
2274
    /// An in-flight publish at pkid 8 lies beyond a new Receive Maximum of 3.
    /// Because of it, reconciliation keeps the original tracking length
    /// instead of shrinking.
    #[test]
    fn connack_receive_max_does_not_shrink_when_tracking_is_non_empty() {
        let mut mqtt = MqttState::builder(10).build();
        let mut publish = build_outgoing_publish(QoS::AtLeastOnce);
        publish.pkid = 8;
        mqtt.outgoing_pub[8] = Some(publish);
        mqtt.inflight = 1;

        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
            .unwrap();
        mqtt.reconcile_outgoing_tracking_capacity(true);

        assert_eq!(mqtt.outgoing_pub.len(), 11);
        assert_eq!(mqtt.outgoing_rel_notice.len(), 11);
    }
2290
    /// Cloning a shrunk state keeps the shrunk lengths for the publish vector
    /// and for both rebuilt notice-slot vectors.
    #[test]
    fn clone_preserves_current_tracking_queue_lengths() {
        let mut mqtt = MqttState::builder(10).build();
        mqtt.handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
            .unwrap();
        mqtt.reconcile_outgoing_tracking_capacity(true);

        let cloned = mqtt.clone();
        assert_eq!(cloned.outgoing_pub.len(), 4);
        assert_eq!(cloned.outgoing_pub_notice.len(), 4);
        assert_eq!(cloned.outgoing_rel_notice.len(), 4);
    }
2303
2304    #[test]
2305    fn next_pkid_increments_as_expected() {
2306        let mut mqtt = build_mqttstate();
2307
2308        for i in 1..=100 {
2309            let pkid = mqtt.next_pkid();
2310
2311            // loops between 0-99. % 100 == 0 implies border
2312            let expected = i % 100;
2313            if expected == 0 {
2314                break;
2315            }
2316
2317            assert_eq!(expected, pkid);
2318        }
2319    }
2320
    /// A control packet id has pushed `last_pkid` (5) past the inflight window
    /// (4), and pkid 1 is still occupied. A new publish must still be
    /// accepted, and it must get the first free id in the window, which is 2.
    #[test]
    fn can_send_publish_searches_free_pkid_after_control_ids_pass_inflight_limit() {
        let mut mqtt = MqttState::builder(4).build();
        let mut active_publish = build_outgoing_publish(QoS::AtLeastOnce);
        active_publish.pkid = 1;
        mqtt.outgoing_pub[1] = Some(active_publish);
        mqtt.inflight = 1;
        mqtt.last_pkid = 5;

        assert!(mqtt.can_send_publish(&build_outgoing_publish(QoS::AtLeastOnce)));

        let packet = mqtt
            .outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
            .unwrap()
            .unwrap();
        match packet {
            Packet::Publish(publish) => assert_eq!(publish.pkid, 2),
            packet => panic!("Unexpected packet: {packet:?}"),
        }
    }
2341
    /// QoS 0 publishes consume no packet id and add nothing to `inflight`.
    /// Each QoS 1 and QoS 2 publish takes the next packet id and increments
    /// `inflight`.
    #[test]
    fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() {
        let mut mqtt = build_mqttstate();

        // QoS0 Publish
        let publish = build_outgoing_publish(QoS::AtMostOnce);

        // QoS 0 publish shouldn't be saved in queue
        mqtt.outgoing_publish(publish).unwrap();
        assert_eq!(mqtt.last_pkid, 0);
        assert_eq!(mqtt.inflight, 0);

        // QoS1 Publish
        let publish = build_outgoing_publish(QoS::AtLeastOnce);

        // Packet id should be set and publish should be saved in queue
        mqtt.outgoing_publish(publish.clone()).unwrap();
        assert_eq!(mqtt.last_pkid, 1);
        assert_eq!(mqtt.inflight, 1);

        // Packet id should be incremented and publish should be saved in queue
        mqtt.outgoing_publish(publish).unwrap();
        assert_eq!(mqtt.last_pkid, 2);
        assert_eq!(mqtt.inflight, 2);

        // QoS2 Publish
        let publish = build_outgoing_publish(QoS::ExactlyOnce);

        // Packet id should be set and publish should be saved in queue
        mqtt.outgoing_publish(publish.clone()).unwrap();
        assert_eq!(mqtt.last_pkid, 3);
        assert_eq!(mqtt.inflight, 3);

        // Packet id should be incremented and publish should be saved in queue
        mqtt.outgoing_publish(publish).unwrap();
        assert_eq!(mqtt.last_pkid, 4);
        assert_eq!(mqtt.inflight, 4);
    }
2380
    /// With an inflight window of 2, the test walks through these steps:
    /// - the second publish takes the top id and rewinds `last_pkid` to 0;
    /// - a third publish collides with the still-unacknowledged pkid 1;
    /// - after the acks, a slot frees up and publishing proceeds again.
    #[test]
    fn outgoing_publish_with_max_inflight_is_ok() {
        let mut mqtt = MqttState::builder(2).build();

        // QoS2 publish
        let publish = build_outgoing_publish(QoS::ExactlyOnce);

        mqtt.outgoing_publish(publish.clone()).unwrap();
        assert_eq!(mqtt.last_pkid, 1);
        assert_eq!(mqtt.inflight, 1);

        // Packet id should be set back down to 0, since we hit the limit
        mqtt.outgoing_publish(publish.clone()).unwrap();
        assert_eq!(mqtt.last_pkid, 0);
        assert_eq!(mqtt.inflight, 2);

        // This should cause a collision
        mqtt.outgoing_publish(publish.clone()).unwrap();
        assert_eq!(mqtt.last_pkid, 1);
        assert_eq!(mqtt.inflight, 2);
        assert!(mqtt.collision.is_some());

        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
        assert_eq!(mqtt.inflight, 1);

        // Now there should be space in the outgoing queue
        mqtt.outgoing_publish(publish).unwrap();
        assert_eq!(mqtt.last_pkid, 0);
        assert_eq!(mqtt.inflight, 2);
    }
2412
2413    #[test]
2414    fn clean_is_calculating_pending_correctly() {
2415        fn build_publish_with_pkid(pkid: u16) -> Publish {
2416            let mut publish = Publish::new("test".to_owned(), QoS::AtLeastOnce, vec![], None);
2417            publish.pkid = pkid;
2418            publish
2419        }
2420
2421        fn build_outgoing_pub() -> Vec<Option<Publish>> {
2422            vec![
2423                None,
2424                Some(build_publish_with_pkid(1)),
2425                Some(build_publish_with_pkid(2)),
2426                Some(build_publish_with_pkid(3)),
2427                None,
2428                None,
2429                Some(build_publish_with_pkid(6)),
2430            ]
2431        }
2432
2433        let mut mqtt = build_mqttstate();
2434        mqtt.outgoing_pub = build_outgoing_pub();
2435        mqtt.last_puback = 3;
2436        let requests = mqtt.clean();
2437        let expected = vec![6, 1, 2, 3];
2438        for (req, pkid) in requests.iter().zip(expected) {
2439            if let Request::Publish(publish) = req {
2440                assert_eq!(publish.pkid, pkid);
2441            } else {
2442                unreachable!();
2443            }
2444        }
2445
2446        mqtt.outgoing_pub = build_outgoing_pub();
2447        mqtt.last_puback = 0;
2448        let requests = mqtt.clean();
2449        let expected = vec![1, 2, 3, 6];
2450        for (req, pkid) in requests.iter().zip(expected) {
2451            if let Request::Publish(publish) = req {
2452                assert_eq!(publish.pkid, pkid);
2453            } else {
2454                unreachable!();
2455            }
2456        }
2457
2458        mqtt.outgoing_pub = build_outgoing_pub();
2459        mqtt.last_puback = 6;
2460        let requests = mqtt.clean();
2461        let expected = vec![1, 2, 3, 6];
2462        for (req, pkid) in requests.iter().zip(expected) {
2463            if let Request::Publish(publish) = req {
2464                assert_eq!(publish.pkid, pkid);
2465            } else {
2466                unreachable!();
2467            }
2468        }
2469    }
2470
2471    #[test]
2472    fn incoming_publish_should_be_added_to_queue_correctly() {
2473        let mut mqtt = build_mqttstate();
2474
2475        // QoS0, 1, 2 Publishes
2476        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
2477        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
2478        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
2479
2480        mqtt.handle_incoming_publish(&mut publish1).unwrap();
2481        mqtt.handle_incoming_publish(&mut publish2).unwrap();
2482        mqtt.handle_incoming_publish(&mut publish3).unwrap();
2483
2484        // only qos2 publish should be add to queue
2485        assert!(mqtt.incoming_pub.contains(3));
2486    }
2487
2488    #[test]
2489    fn incoming_publish_should_be_acked() {
2490        let mut mqtt = build_mqttstate();
2491
2492        // QoS0, 1, 2 Publishes
2493        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
2494        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
2495        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
2496
2497        mqtt.handle_incoming_publish(&mut publish1).unwrap();
2498        mqtt.handle_incoming_publish(&mut publish2).unwrap();
2499        mqtt.handle_incoming_publish(&mut publish3).unwrap();
2500
2501        if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] {
2502            assert_eq!(pkid, 2);
2503        } else {
2504            panic!("missing puback");
2505        }
2506
2507        if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] {
2508            assert_eq!(pkid, 3);
2509        } else {
2510            panic!("missing PubRec");
2511        }
2512    }
2513
2514    #[test]
2515    fn incoming_publish_should_not_be_acked_with_manual_acks() {
2516        let mut mqtt = build_mqttstate();
2517        mqtt.manual_acks = true;
2518
2519        // QoS0, 1, 2 Publishes
2520        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
2521        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
2522        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
2523
2524        mqtt.handle_incoming_publish(&mut publish1).unwrap();
2525        mqtt.handle_incoming_publish(&mut publish2).unwrap();
2526        mqtt.handle_incoming_publish(&mut publish3).unwrap();
2527
2528        assert!(mqtt.incoming_pub.contains(3));
2529        assert!(mqtt.events.is_empty());
2530    }
2531
2532    #[test]
2533    fn unknown_incoming_topic_alias_returns_protocol_error_disconnect() {
2534        let mut mqtt = build_mqttstate();
2535        let mut publish = build_incoming_publish(QoS::AtMostOnce, 0);
2536        publish.topic = Bytes::new();
2537        publish.properties = Some(publish_properties_with_alias(1));
2538
2539        let packet = mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap();
2540
2541        assert!(matches!(
2542            packet,
2543            Packet::Disconnect(disconnect)
2544                if disconnect.reason_code == DisconnectReasonCode::ProtocolError
2545        ));
2546        assert!(publish.topic.is_empty());
2547    }
2548
2549    #[test]
2550    fn handle_incoming_packet_does_not_surface_unknown_topic_alias_publish() {
2551        let mut mqtt = build_mqttstate();
2552        let mut publish = build_incoming_publish(QoS::AtMostOnce, 0);
2553        publish.topic = Bytes::new();
2554        publish.properties = Some(publish_properties_with_alias(1));
2555
2556        let packet = mqtt
2557            .handle_incoming_packet(Incoming::Publish(publish))
2558            .unwrap()
2559            .unwrap();
2560
2561        assert!(matches!(
2562            packet,
2563            Packet::Disconnect(disconnect)
2564                if disconnect.reason_code == DisconnectReasonCode::ProtocolError
2565        ));
2566        assert!(
2567            !mqtt
2568                .events
2569                .iter()
2570                .any(|event| matches!(event, Event::Incoming(Incoming::Publish(_))))
2571        );
2572        assert_eq!(
2573            mqtt.events,
2574            VecDeque::from([Event::Outgoing(Outgoing::Disconnect)])
2575        );
2576    }
2577
2578    #[test]
2579    fn outgoing_reauth_without_properties_synthesizes_connect_authentication_method() {
2580        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2581        let auth = Auth::new(AuthReasonCode::ReAuthenticate, None);
2582
2583        let packet = mqtt
2584            .handle_outgoing_packet(Request::Auth(auth))
2585            .unwrap()
2586            .unwrap();
2587
2588        let Packet::Auth(auth) = packet else {
2589            panic!("expected AUTH packet");
2590        };
2591        let properties = auth.properties.unwrap();
2592        assert_eq!(properties.method.as_deref(), Some(AUTH_METHOD));
2593        assert_eq!(auth.code, AuthReasonCode::ReAuthenticate);
2594    }
2595
2596    #[test]
2597    fn outgoing_reauth_without_connect_authentication_method_fails() {
2598        let mut mqtt = build_auth_mqttstate(None);
2599        let auth = Auth::new(AuthReasonCode::ReAuthenticate, None);
2600
2601        let err = mqtt
2602            .handle_outgoing_packet(Request::Auth(auth))
2603            .unwrap_err();
2604
2605        assert!(matches!(err, StateError::AuthError(_)));
2606    }
2607
2608    #[test]
2609    fn public_state_authentication_method_setter_enables_outgoing_reauth() {
2610        let mut mqtt = MqttState::builder(10).build();
2611        mqtt.set_authentication_method(Some(AUTH_METHOD.to_owned()));
2612        let auth = Auth::new(
2613            AuthReasonCode::ReAuthenticate,
2614            Some(auth_properties(Some(AUTH_METHOD))),
2615        );
2616
2617        let packet = mqtt
2618            .handle_outgoing_packet(Request::Auth(auth))
2619            .unwrap()
2620            .unwrap();
2621
2622        let Packet::Auth(auth) = packet else {
2623            panic!("expected AUTH packet");
2624        };
2625        assert_eq!(
2626            auth.properties.unwrap().method.as_deref(),
2627            Some(AUTH_METHOD)
2628        );
2629    }
2630
2631    #[test]
2632    fn outgoing_reauth_fills_missing_authentication_method() {
2633        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2634        let auth = Auth::new(AuthReasonCode::ReAuthenticate, Some(auth_properties(None)));
2635
2636        let packet = mqtt
2637            .handle_outgoing_packet(Request::Auth(auth))
2638            .unwrap()
2639            .unwrap();
2640
2641        let Packet::Auth(auth) = packet else {
2642            panic!("expected AUTH packet");
2643        };
2644        assert_eq!(
2645            auth.properties.unwrap().method.as_deref(),
2646            Some(AUTH_METHOD)
2647        );
2648    }
2649
2650    #[test]
2651    fn outgoing_reauth_rejects_mismatched_authentication_method() {
2652        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2653        let auth = Auth::new(
2654            AuthReasonCode::ReAuthenticate,
2655            Some(auth_properties(Some("other-method"))),
2656        );
2657
2658        let err = mqtt
2659            .handle_outgoing_packet(Request::Auth(auth))
2660            .unwrap_err();
2661
2662        assert!(matches!(err, StateError::AuthError(_)));
2663    }
2664
2665    #[test]
2666    fn tracked_reauth_missing_method_notice_fails_with_specific_error() {
2667        let mut mqtt = build_auth_mqttstate(None);
2668        let (notice_tx, notice) = AuthNoticeTx::new();
2669
2670        let err = mqtt
2671            .handle_outgoing_packet_with_notice(
2672                Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
2673                Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2674            )
2675            .unwrap_err();
2676
2677        assert!(matches!(err, StateError::AuthError(_)));
2678        assert_eq!(
2679            notice.wait().unwrap_err(),
2680            AuthNoticeError::MissingAuthenticationMethod
2681        );
2682    }
2683
2684    #[test]
2685    fn tracked_reauth_mismatched_method_notice_fails_with_auth_error() {
2686        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2687        let (notice_tx, notice) = AuthNoticeTx::new();
2688
2689        let err = mqtt
2690            .handle_outgoing_packet_with_notice(
2691                Request::Auth(Auth::new(
2692                    AuthReasonCode::ReAuthenticate,
2693                    Some(auth_properties(Some("other-method"))),
2694                )),
2695                Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2696            )
2697            .unwrap_err();
2698
2699        assert!(matches!(err, StateError::AuthError(_)));
2700        assert!(matches!(
2701            notice.wait().unwrap_err(),
2702            AuthNoticeError::AuthenticationFailed(_)
2703        ));
2704    }
2705
2706    #[test]
2707    fn tracked_reauth_start_failure_notice_fails_with_auth_error() {
2708        let authenticator = Arc::new(Mutex::new(FailingStartAuthManager));
2709        let mut mqtt = MqttState::builder(10)
2710            .authentication_method(Some(AUTH_METHOD.to_owned()))
2711            .authenticator(authenticator)
2712            .build();
2713        let (notice_tx, notice) = AuthNoticeTx::new();
2714
2715        let err = mqtt
2716            .handle_outgoing_packet_with_notice(
2717                Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
2718                Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2719            )
2720            .unwrap_err();
2721
2722        assert!(matches!(err, StateError::AuthError(_)));
2723        assert!(matches!(
2724            notice.wait().unwrap_err(),
2725            AuthNoticeError::AuthenticationFailed(_)
2726        ));
2727    }
2728
2729    #[test]
2730    fn initial_auth_start_returns_normalized_auth_properties() {
2731        let authenticator = Arc::new(Mutex::new(StartAuthManager {
2732            response: Some(auth_properties(None)),
2733        }));
2734        let mut mqtt = MqttState::builder(10)
2735            .authentication_method(Some(AUTH_METHOD.to_owned()))
2736            .authenticator(authenticator)
2737            .build();
2738
2739        let properties = mqtt
2740            .begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2741            .unwrap()
2742            .unwrap();
2743
2744        assert_eq!(properties.method.as_deref(), Some(AUTH_METHOD));
2745        assert_eq!(properties.data, Some(Bytes::from_static(b"auth-data")));
2746    }
2747
2748    #[test]
2749    fn outgoing_reauth_rejects_overlapping_attempt() {
2750        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2751        let auth = Auth::new(AuthReasonCode::ReAuthenticate, None);
2752
2753        mqtt.handle_outgoing_packet(Request::Auth(auth.clone()))
2754            .unwrap();
2755        let err = mqtt
2756            .handle_outgoing_packet(Request::Auth(auth))
2757            .unwrap_err();
2758
2759        assert!(matches!(err, StateError::AuthError(_)));
2760    }
2761
2762    #[test]
2763    fn tracked_reauth_notice_completes_on_matching_auth_success() {
2764        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2765        let (notice_tx, notice) = AuthNoticeTx::new();
2766        let auth = Auth::new(AuthReasonCode::ReAuthenticate, None);
2767
2768        mqtt.handle_outgoing_packet_with_notice(
2769            Request::Auth(auth),
2770            Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2771        )
2772        .unwrap();
2773        mqtt.handle_incoming_packet(Incoming::Auth(Auth::new(
2774            AuthReasonCode::Success,
2775            Some(auth_properties(Some(AUTH_METHOD))),
2776        )))
2777        .unwrap();
2778
2779        assert_eq!(notice.wait().unwrap(), crate::AuthOutcome::Success);
2780    }
2781
2782    #[test]
2783    fn disconnect_now_fails_active_tracked_reauth_notice() {
2784        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2785        let (notice_tx, notice) = AuthNoticeTx::new();
2786
2787        mqtt.handle_outgoing_packet_with_notice(
2788            Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
2789            Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2790        )
2791        .unwrap();
2792        mqtt.handle_outgoing_packet(Request::DisconnectNow(Disconnect::new(
2793            DisconnectReasonCode::NormalDisconnection,
2794        )))
2795        .unwrap();
2796
2797        assert_eq!(
2798            notice.wait().unwrap_err(),
2799            AuthNoticeError::ConnectionClosed
2800        );
2801        assert!(mqtt.events.iter().any(|event| {
2802            matches!(
2803                event,
2804                Event::Auth(crate::AuthEvent::Failed {
2805                    kind: crate::AuthExchangeKind::Reauthentication,
2806                    reason: crate::AuthFailureReason::ConnectionClosed,
2807                    ..
2808                })
2809            )
2810        }));
2811    }
2812
2813    #[test]
2814    fn tracked_overlapping_reauth_notice_fails() {
2815        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2816        mqtt.handle_outgoing_packet(Request::Auth(Auth::new(
2817            AuthReasonCode::ReAuthenticate,
2818            None,
2819        )))
2820        .unwrap();
2821        let (notice_tx, notice) = AuthNoticeTx::new();
2822
2823        let err = mqtt
2824            .handle_outgoing_packet_with_notice(
2825                Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
2826                Some(crate::notice::TrackedNoticeTx::Auth(notice_tx)),
2827            )
2828            .unwrap_err();
2829
2830        assert!(matches!(err, StateError::AuthError(_)));
2831        assert_eq!(
2832            notice.wait().unwrap_err(),
2833            AuthNoticeError::OverlappingReauth
2834        );
2835    }
2836
2837    #[test]
2838    fn incoming_auth_success_without_active_exchange_is_protocol_error() {
2839        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2840        let auth = Auth::new(
2841            AuthReasonCode::Success,
2842            Some(auth_properties(Some(AUTH_METHOD))),
2843        );
2844
2845        let err = mqtt
2846            .handle_incoming_packet(Incoming::Auth(auth))
2847            .unwrap_err();
2848
2849        assert!(matches!(
2850            err,
2851            StateError::Deserialization(Error::ProtocolError)
2852        ));
2853    }
2854
2855    #[test]
2856    fn incoming_auth_success_accepts_matching_authentication_method() {
2857        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2858        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2859            .unwrap();
2860        let auth = Auth::new(
2861            AuthReasonCode::Success,
2862            Some(auth_properties(Some(AUTH_METHOD))),
2863        );
2864
2865        let packet = mqtt.handle_incoming_packet(Incoming::Auth(auth)).unwrap();
2866
2867        assert!(packet.is_none());
2868    }
2869
2870    #[test]
2871    fn initial_auth_survives_fresh_session_pending_notice_cleanup() {
2872        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2873        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2874            .unwrap();
2875        mqtt.fail_pending_notices();
2876
2877        let mut connack = build_connack_with_receive_max(10);
2878        connack.properties.as_mut().unwrap().authentication_method = Some(AUTH_METHOD.to_owned());
2879        mqtt.handle_incoming_packet(Incoming::ConnAck(connack))
2880            .unwrap();
2881
2882        assert!(mqtt.events.iter().any(|event| {
2883            matches!(
2884                event,
2885                Event::Auth(crate::AuthEvent::Succeeded {
2886                    kind: crate::AuthExchangeKind::InitialConnect,
2887                    ..
2888                })
2889            )
2890        }));
2891        assert!(!mqtt.events.iter().any(|event| {
2892            matches!(
2893                event,
2894                Event::Auth(crate::AuthEvent::Failed {
2895                    kind: crate::AuthExchangeKind::InitialConnect,
2896                    ..
2897                })
2898            )
2899        }));
2900    }
2901
2902    #[test]
2903    fn incoming_auth_success_rejects_missing_authentication_method() {
2904        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2905        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2906            .unwrap();
2907        let auth = Auth::new(AuthReasonCode::Success, None);
2908
2909        let err = mqtt
2910            .handle_incoming_packet(Incoming::Auth(auth))
2911            .unwrap_err();
2912
2913        assert!(matches!(
2914            err,
2915            StateError::Deserialization(Error::ProtocolError)
2916        ));
2917    }
2918
2919    #[test]
2920    fn incoming_auth_success_rejects_mismatched_authentication_method() {
2921        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
2922        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2923            .unwrap();
2924        let auth = Auth::new(
2925            AuthReasonCode::Success,
2926            Some(auth_properties(Some("other-method"))),
2927        );
2928
2929        let err = mqtt
2930            .handle_incoming_packet(Incoming::Auth(auth))
2931            .unwrap_err();
2932
2933        assert!(matches!(
2934            err,
2935            StateError::Deserialization(Error::ProtocolError)
2936        ));
2937    }
2938
2939    #[test]
2940    fn incoming_auth_success_without_connect_authentication_method_is_protocol_error() {
2941        let mut mqtt = build_auth_mqttstate(None);
2942        let auth = Auth::new(
2943            AuthReasonCode::Success,
2944            Some(auth_properties(Some(AUTH_METHOD))),
2945        );
2946
2947        let err = mqtt
2948            .handle_incoming_packet(Incoming::Auth(auth))
2949            .unwrap_err();
2950
2951        assert!(matches!(
2952            err,
2953            StateError::Deserialization(Error::ProtocolError)
2954        ));
2955    }
2956
2957    #[test]
2958    fn incoming_auth_continue_synthesizes_method_when_auth_manager_omits_it() {
2959        let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
2960        let mut mqtt = MqttState::builder(10)
2961            .authentication_method(Some(AUTH_METHOD.to_owned()))
2962            .auth_manager(auth_manager)
2963            .build();
2964        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
2965            .unwrap();
2966        let auth = Auth::new(
2967            AuthReasonCode::Continue,
2968            Some(auth_properties(Some(AUTH_METHOD))),
2969        );
2970
2971        let packet = mqtt
2972            .handle_incoming_packet(Incoming::Auth(auth))
2973            .unwrap()
2974            .unwrap();
2975
2976        let Packet::Auth(auth) = packet else {
2977            panic!("expected AUTH packet");
2978        };
2979        assert_eq!(auth.code, AuthReasonCode::Continue);
2980        assert_eq!(
2981            auth.properties.unwrap().method.as_deref(),
2982            Some(AUTH_METHOD)
2983        );
2984    }
2985
2986    #[test]
2987    fn incoming_auth_continue_without_connect_authentication_method_is_protocol_error() {
2988        let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
2989        let mut mqtt = MqttState::builder(10).auth_manager(auth_manager).build();
2990        let auth = Auth::new(
2991            AuthReasonCode::Continue,
2992            Some(auth_properties(Some(AUTH_METHOD))),
2993        );
2994
2995        let err = mqtt
2996            .handle_incoming_packet(Incoming::Auth(auth))
2997            .unwrap_err();
2998
2999        assert!(matches!(
3000            err,
3001            StateError::Deserialization(Error::ProtocolError)
3002        ));
3003    }
3004
3005    #[test]
3006    fn incoming_auth_continue_rejects_mismatched_server_method() {
3007        let auth_manager = Arc::new(Mutex::new(StaticAuthManager { response: Ok(None) }));
3008        let mut mqtt = MqttState::builder(10)
3009            .authentication_method(Some(AUTH_METHOD.to_owned()))
3010            .auth_manager(auth_manager)
3011            .build();
3012        mqtt.begin_authentication_connect(Some(AUTH_METHOD.to_owned()))
3013            .unwrap();
3014        let auth = Auth::new(
3015            AuthReasonCode::Continue,
3016            Some(auth_properties(Some("other-method"))),
3017        );
3018
3019        let err = mqtt
3020            .handle_incoming_packet(Incoming::Auth(auth))
3021            .unwrap_err();
3022
3023        assert!(matches!(
3024            err,
3025            StateError::Deserialization(Error::ProtocolError)
3026        ));
3027    }
3028
3029    #[test]
3030    fn incoming_auth_reauthenticate_is_protocol_error() {
3031        let mut mqtt = build_auth_mqttstate(Some(AUTH_METHOD));
3032        let auth = Auth::new(
3033            AuthReasonCode::ReAuthenticate,
3034            Some(auth_properties(Some(AUTH_METHOD))),
3035        );
3036
3037        let err = mqtt
3038            .handle_incoming_packet(Incoming::Auth(auth))
3039            .unwrap_err();
3040
3041        assert!(matches!(
3042            err,
3043            StateError::Deserialization(Error::ProtocolError)
3044        ));
3045    }
3046
3047    #[test]
3048    fn connection_scoped_alias_state_resets_incoming_aliases_and_broker_maximum() {
3049        let mut mqtt = build_mqttstate();
3050        mqtt.broker_topic_alias_max = 10;
3051        let mut aliased = build_incoming_publish(QoS::AtMostOnce, 0);
3052        aliased.properties = Some(publish_properties_with_alias(1));
3053        mqtt.handle_incoming_publish(&mut aliased).unwrap();
3054
3055        let mut alias_only = build_incoming_publish(QoS::AtMostOnce, 0);
3056        alias_only.topic = Bytes::new();
3057        alias_only.properties = Some(publish_properties_with_alias(1));
3058        mqtt.handle_incoming_publish(&mut alias_only).unwrap();
3059        assert_eq!(alias_only.topic, Bytes::from_static(b"hello/world"));
3060
3061        mqtt.reset_connection_scoped_state();
3062
3063        assert_eq!(mqtt.broker_topic_alias_max, 0);
3064        let mut stale_alias = build_incoming_publish(QoS::AtMostOnce, 0);
3065        stale_alias.topic = Bytes::new();
3066        stale_alias.properties = Some(publish_properties_with_alias(1));
3067        let packet = mqtt
3068            .handle_incoming_publish(&mut stale_alias)
3069            .unwrap()
3070            .unwrap();
3071        assert!(matches!(packet, Packet::Disconnect(_)));
3072    }
3073
3074    #[test]
3075    fn replay_publish_with_known_outgoing_alias_restores_topic() {
3076        let mut mqtt = build_mqttstate();
3077        mqtt.broker_topic_alias_max = 10;
3078        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3079            "hello/replay",
3080            QoS::AtMostOnce,
3081            2,
3082        ))
3083        .unwrap();
3084
3085        let mut replay = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2);
3086        let mut replay_topic_aliases = mqtt.replay_topic_aliases();
3087
3088        MqttState::prepare_publish_for_replay_with_aliases(&mut replay, &mut replay_topic_aliases)
3089            .unwrap();
3090
3091        assert_eq!(replay.topic, Bytes::from_static(b"hello/replay"));
3092        assert_eq!(
3093            replay
3094                .properties
3095                .as_ref()
3096                .and_then(|props| props.topic_alias),
3097            None
3098        );
3099    }
3100
3101    #[test]
3102    fn replay_publish_with_concrete_topic_strips_stale_alias() {
3103        let mut replay = build_outgoing_publish_with_alias("hello/replay", QoS::AtLeastOnce, 2);
3104        let mut replay_topic_aliases = HashMap::new();
3105
3106        MqttState::prepare_publish_for_replay_with_aliases(&mut replay, &mut replay_topic_aliases)
3107            .unwrap();
3108
3109        assert_eq!(replay.topic, Bytes::from_static(b"hello/replay"));
3110        assert_eq!(
3111            replay
3112                .properties
3113                .as_ref()
3114                .and_then(|props| props.topic_alias),
3115            None
3116        );
3117        assert_eq!(
3118            replay_topic_aliases.get(&2),
3119            Some(&Bytes::from_static(b"hello/replay"))
3120        );
3121    }
3122
3123    #[test]
3124    fn replay_publish_with_stripped_alias_is_valid_when_next_broker_allows_no_aliases() {
3125        let mut replay = build_outgoing_publish_with_alias("hello/replay", QoS::AtLeastOnce, 2);
3126        let mut replay_topic_aliases = HashMap::new();
3127        MqttState::prepare_publish_for_replay_with_aliases(&mut replay, &mut replay_topic_aliases)
3128            .unwrap();
3129
3130        let mut next_connection = build_mqttstate();
3131        next_connection.broker_topic_alias_max = 0;
3132
3133        next_connection
3134            .handle_outgoing_packet(Request::Publish(replay))
3135            .unwrap();
3136    }
3137
3138    #[test]
3139    fn replay_publish_with_unknown_outgoing_alias_fails() {
3140        let mqtt = build_mqttstate();
3141        let mut replay = build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 3);
3142        let mut replay_topic_aliases = mqtt.replay_topic_aliases();
3143
3144        let err = MqttState::prepare_publish_for_replay_with_aliases(
3145            &mut replay,
3146            &mut replay_topic_aliases,
3147        )
3148        .unwrap_err();
3149
3150        assert_eq!(err, PublishNoticeError::TopicAliasReplayUnavailable(3));
3151        assert!(replay.topic.is_empty());
3152    }
3153
3154    #[test]
3155    fn auto_topic_aliases_are_disabled_by_default() {
3156        let mut mqtt = build_mqttstate();
3157        mqtt.broker_topic_alias_max = 10;
3158
3159        let packet = mqtt
3160            .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
3161            .unwrap()
3162            .unwrap();
3163
3164        match packet {
3165            Packet::Publish(publish) => {
3166                assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
3167                assert_eq!(
3168                    publish
3169                        .properties
3170                        .as_ref()
3171                        .and_then(|props| props.topic_alias),
3172                    None
3173                );
3174            }
3175            packet => panic!("expected publish, got {packet:?}"),
3176        }
3177    }
3178
3179    #[test]
3180    fn auto_topic_aliases_send_topic_and_alias_before_alias_only_publish() {
3181        let mut mqtt = MqttState::builder(u16::MAX)
3182            .auto_topic_aliases(true)
3183            .build();
3184        mqtt.broker_topic_alias_max = 10;
3185
3186        let first = mqtt
3187            .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
3188            .unwrap()
3189            .unwrap();
3190        let second = mqtt
3191            .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
3192            .unwrap()
3193            .unwrap();
3194
3195        match first {
3196            Packet::Publish(publish) => {
3197                assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
3198                assert_eq!(
3199                    publish
3200                        .properties
3201                        .as_ref()
3202                        .and_then(|props| props.topic_alias),
3203                    Some(1)
3204                );
3205            }
3206            packet => panic!("expected publish, got {packet:?}"),
3207        }
3208        match second {
3209            Packet::Publish(publish) => {
3210                assert!(publish.topic.is_empty());
3211                assert_eq!(
3212                    publish
3213                        .properties
3214                        .as_ref()
3215                        .and_then(|props| props.topic_alias),
3216                    Some(1)
3217                );
3218            }
3219            packet => panic!("expected publish, got {packet:?}"),
3220        }
3221    }
3222
3223    #[test]
3224    fn auto_topic_aliases_do_nothing_when_broker_allows_no_aliases() {
3225        let mut mqtt = MqttState::builder(u16::MAX)
3226            .auto_topic_aliases(true)
3227            .build();
3228
3229        let packet = mqtt
3230            .outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
3231            .unwrap()
3232            .unwrap();
3233
3234        match packet {
3235            Packet::Publish(publish) => {
3236                assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
3237                assert_eq!(
3238                    publish
3239                        .properties
3240                        .as_ref()
3241                        .and_then(|props| props.topic_alias),
3242                    None
3243                );
3244            }
3245            packet => panic!("expected publish, got {packet:?}"),
3246        }
3247    }
3248
3249    #[test]
3250    fn auto_topic_aliases_stop_allocating_when_capacity_is_exhausted() {
3251        let mut mqtt = MqttState::builder(u16::MAX)
3252            .auto_topic_aliases(true)
3253            .build();
3254        mqtt.broker_topic_alias_max = 1;
3255        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3256            .unwrap();
3257
3258        let packet = mqtt
3259            .outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3260            .unwrap()
3261            .unwrap();
3262
3263        match packet {
3264            Packet::Publish(publish) => {
3265                assert_eq!(publish.topic, Bytes::from_static(b"topic/two"));
3266                assert_eq!(
3267                    publish
3268                        .properties
3269                        .as_ref()
3270                        .and_then(|props| props.topic_alias),
3271                    None
3272                );
3273            }
3274            packet => panic!("expected publish, got {packet:?}"),
3275        }
3276    }
3277
3278    #[test]
3279    fn auto_topic_aliases_preserve_manual_aliases_and_skip_used_aliases() {
3280        let mut mqtt = MqttState::builder(u16::MAX)
3281            .auto_topic_aliases(true)
3282            .build();
3283        mqtt.broker_topic_alias_max = 2;
3284        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3285            "manual/topic",
3286            QoS::AtMostOnce,
3287            1,
3288        ))
3289        .unwrap();
3290
3291        let packet = mqtt
3292            .outgoing_publish(Publish::new("auto/topic", QoS::AtMostOnce, vec![], None))
3293            .unwrap()
3294            .unwrap();
3295
3296        match packet {
3297            Packet::Publish(publish) => {
3298                assert_eq!(publish.topic, Bytes::from_static(b"auto/topic"));
3299                assert_eq!(
3300                    publish
3301                        .properties
3302                        .as_ref()
3303                        .and_then(|props| props.topic_alias),
3304                    Some(2)
3305                );
3306            }
3307            packet => panic!("expected publish, got {packet:?}"),
3308        }
3309    }
3310
3311    #[test]
3312    fn manual_rebind_clears_stale_auto_topic_alias_mapping() {
3313        let mut mqtt = MqttState::builder(u16::MAX)
3314            .auto_topic_aliases(true)
3315            .build();
3316        mqtt.broker_topic_alias_max = 2;
3317        mqtt.outgoing_publish(Publish::new("auto/topic", QoS::AtMostOnce, vec![], None))
3318            .unwrap();
3319        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3320            "manual/topic",
3321            QoS::AtMostOnce,
3322            1,
3323        ))
3324        .unwrap();
3325
3326        let packet = mqtt
3327            .outgoing_publish(Publish::new("auto/topic", QoS::AtMostOnce, vec![], None))
3328            .unwrap()
3329            .unwrap();
3330
3331        match packet {
3332            Packet::Publish(publish) => {
3333                assert_eq!(publish.topic, Bytes::from_static(b"auto/topic"));
3334                assert_eq!(
3335                    publish
3336                        .properties
3337                        .as_ref()
3338                        .and_then(|props| props.topic_alias),
3339                    Some(2)
3340                );
3341            }
3342            packet => panic!("expected publish, got {packet:?}"),
3343        }
3344    }
3345
3346    #[test]
3347    fn auto_topic_alias_qos_replay_uses_full_topic_after_clean() {
3348        let mut mqtt = MqttState::builder(u16::MAX)
3349            .auto_topic_aliases(true)
3350            .build();
3351        mqtt.broker_topic_alias_max = 10;
3352        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtMostOnce))
3353            .unwrap();
3354        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3355            .unwrap();
3356
3357        let requests = mqtt.clean();
3358
3359        assert_eq!(requests.len(), 1);
3360        match &requests[0] {
3361            Request::Publish(publish) => {
3362                assert_eq!(publish.topic, Bytes::from_static(b"hello/world"));
3363                assert_eq!(
3364                    publish
3365                        .properties
3366                        .as_ref()
3367                        .and_then(|props| props.topic_alias),
3368                    None
3369                );
3370            }
3371            request => panic!("expected replay publish, got {request:?}"),
3372        }
3373    }
3374
3375    #[test]
3376    fn auto_topic_alias_collision_does_not_register_unsent_alias() {
3377        let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
3378        mqtt.broker_topic_alias_max = 10;
3379        mqtt.outgoing_publish(Publish::new(
3380            "inflight/topic",
3381            QoS::AtLeastOnce,
3382            vec![],
3383            None,
3384        ))
3385        .unwrap();
3386
3387        let mut collided = Publish::new("collided/topic", QoS::AtLeastOnce, vec![], None);
3388        collided.pkid = 1;
3389        let (packet, flush_notice) = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
3390        assert!(packet.is_none());
3391        assert!(flush_notice.is_none());
3392        assert!(mqtt.collision.is_some());
3393
3394        let packet = mqtt
3395            .outgoing_publish(Publish::new(
3396                "collided/topic",
3397                QoS::AtMostOnce,
3398                vec![],
3399                None,
3400            ))
3401            .unwrap()
3402            .unwrap();
3403
3404        match packet {
3405            Packet::Publish(publish) => {
3406                assert_eq!(publish.topic, Bytes::from_static(b"collided/topic"));
3407                assert_eq!(
3408                    publish
3409                        .properties
3410                        .as_ref()
3411                        .and_then(|props| props.topic_alias),
3412                    Some(2)
3413                );
3414            }
3415            packet => panic!("expected publish, got {packet:?}"),
3416        }
3417    }
3418
3419    #[test]
3420    fn auto_topic_alias_collision_replay_does_not_send_uncommitted_alias() {
3421        let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
3422        mqtt.broker_topic_alias_max = 10;
3423        let first_notice = queue_publish_with_notice(
3424            &mut mqtt,
3425            Publish::new("inflight/topic", QoS::AtLeastOnce, vec![], None),
3426        );
3427
3428        let mut collided = Publish::new("collided/topic", QoS::AtLeastOnce, vec![], None);
3429        collided.pkid = 1;
3430        let (packet, flush_notice) = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
3431        assert!(packet.is_none());
3432        assert!(flush_notice.is_none());
3433
3434        let puback = PubAck::new(1, None);
3435        let packet = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();
3436
3437        match packet {
3438            Packet::Publish(publish) => {
3439                assert_eq!(publish.topic, Bytes::from_static(b"collided/topic"));
3440                assert_eq!(
3441                    publish
3442                        .properties
3443                        .as_ref()
3444                        .and_then(|props| props.topic_alias),
3445                    None
3446                );
3447            }
3448            packet => panic!("expected publish, got {packet:?}"),
3449        }
3450        assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
3451    }
3452
3453    #[test]
3454    fn auto_topic_alias_collision_replay_restores_reused_alias_topic_after_rebind() {
3455        let mut mqtt = MqttState::builder(2).auto_topic_aliases(true).build();
3456        mqtt.broker_topic_alias_max = 10;
3457        let first_notice = queue_publish_with_notice(
3458            &mut mqtt,
3459            Publish::new("aliased/topic", QoS::AtLeastOnce, vec![], None),
3460        );
3461
3462        let mut collided = Publish::new("aliased/topic", QoS::AtLeastOnce, vec![], None);
3463        collided.pkid = 1;
3464        let (packet, flush_notice) = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
3465        assert!(packet.is_none());
3466        assert!(flush_notice.is_none());
3467
3468        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3469            "manual/rebind",
3470            QoS::AtMostOnce,
3471            1,
3472        ))
3473        .unwrap();
3474
3475        let puback = PubAck::new(1, None);
3476        let packet = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();
3477
3478        match packet {
3479            Packet::Publish(publish) => {
3480                assert_eq!(publish.topic, Bytes::from_static(b"aliased/topic"));
3481                assert_eq!(
3482                    publish
3483                        .properties
3484                        .as_ref()
3485                        .and_then(|props| props.topic_alias),
3486                    None
3487                );
3488            }
3489            packet => panic!("expected publish, got {packet:?}"),
3490        }
3491        assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
3492    }
3493
3494    #[test]
3495    fn auto_topic_aliases_exhaust_without_wrapping_at_u16_max() {
3496        let mut mqtt = MqttState::builder(u16::MAX)
3497            .auto_topic_aliases(true)
3498            .build();
3499        mqtt.broker_topic_alias_max = u16::MAX;
3500        mqtt.next_auto_topic_alias = Some(u16::MAX);
3501
3502        let last_packet = mqtt
3503            .outgoing_publish(Publish::new("last/topic", QoS::AtMostOnce, vec![], None))
3504            .unwrap()
3505            .unwrap();
3506        match last_packet {
3507            Packet::Publish(publish) => {
3508                assert_eq!(publish.topic, Bytes::from_static(b"last/topic"));
3509                assert_eq!(
3510                    publish
3511                        .properties
3512                        .as_ref()
3513                        .and_then(|props| props.topic_alias),
3514                    Some(u16::MAX)
3515                );
3516            }
3517            packet => panic!("expected publish, got {packet:?}"),
3518        }
3519        assert_eq!(mqtt.next_auto_topic_alias, None);
3520
3521        let exhausted_packet = mqtt
3522            .outgoing_publish(Publish::new(
3523                "exhausted/topic",
3524                QoS::AtMostOnce,
3525                vec![],
3526                None,
3527            ))
3528            .unwrap()
3529            .unwrap();
3530
3531        match exhausted_packet {
3532            Packet::Publish(publish) => {
3533                assert_eq!(publish.topic, Bytes::from_static(b"exhausted/topic"));
3534                assert_eq!(
3535                    publish
3536                        .properties
3537                        .as_ref()
3538                        .and_then(|props| props.topic_alias),
3539                    None
3540                );
3541            }
3542            packet => panic!("expected publish, got {packet:?}"),
3543        }
3544    }
3545
3546    #[test]
3547    fn lru_auto_topic_aliases_evict_least_recent_topic() {
3548        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);
3549
3550        let first = mqtt
3551            .outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3552            .unwrap()
3553            .unwrap();
3554        let second = mqtt
3555            .outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3556            .unwrap()
3557            .unwrap();
3558        let third = mqtt
3559            .outgoing_publish(Publish::new("topic/three", QoS::AtMostOnce, vec![], None))
3560            .unwrap()
3561            .unwrap();
3562
3563        assert_publish(first, b"topic/one", Some(1));
3564        assert_publish(second, b"topic/two", Some(2));
3565        assert_publish(third, b"topic/three", Some(1));
3566
3567        let packet = mqtt
3568            .outgoing_publish(Publish::new("topic/three", QoS::AtMostOnce, vec![], None))
3569            .unwrap()
3570            .unwrap();
3571        assert_publish(packet, b"", Some(1));
3572    }
3573
3574    #[test]
3575    fn lru_auto_topic_aliases_refresh_existing_topic_recency() {
3576        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);
3577
3578        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3579            .unwrap();
3580        mqtt.outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3581            .unwrap();
3582        let refresh = mqtt
3583            .outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3584            .unwrap()
3585            .unwrap();
3586        let evict = mqtt
3587            .outgoing_publish(Publish::new("topic/three", QoS::AtMostOnce, vec![], None))
3588            .unwrap()
3589            .unwrap();
3590
3591        assert_publish(refresh, b"", Some(1));
3592        assert_publish(evict, b"topic/three", Some(2));
3593    }
3594
3595    #[test]
3596    fn lru_auto_topic_aliases_rebound_alias_sends_full_topic_then_alias_only() {
3597        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);
3598
3599        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3600            .unwrap();
3601        let rebound = mqtt
3602            .outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3603            .unwrap()
3604            .unwrap();
3605        let alias_only = mqtt
3606            .outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3607            .unwrap()
3608            .unwrap();
3609
3610        assert_publish(rebound, b"topic/two", Some(1));
3611        assert_publish(alias_only, b"", Some(1));
3612    }
3613
3614    #[test]
3615    fn lru_auto_topic_aliases_do_not_evict_manual_aliases() {
3616        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 2);
3617        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3618            "manual/topic",
3619            QoS::AtMostOnce,
3620            1,
3621        ))
3622        .unwrap();
3623
3624        let first_auto = mqtt
3625            .outgoing_publish(Publish::new("auto/one", QoS::AtMostOnce, vec![], None))
3626            .unwrap()
3627            .unwrap();
3628        let second_auto = mqtt
3629            .outgoing_publish(Publish::new("auto/two", QoS::AtMostOnce, vec![], None))
3630            .unwrap()
3631            .unwrap();
3632
3633        assert_publish(first_auto, b"auto/one", Some(2));
3634        assert_publish(second_auto, b"auto/two", Some(2));
3635        assert_eq!(
3636            mqtt.outgoing_topic_aliases.get(&1),
3637            Some(&Bytes::from_static(b"manual/topic"))
3638        );
3639    }
3640
3641    #[test]
3642    fn lru_auto_topic_aliases_reset_on_reconnect() {
3643        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);
3644        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3645            .unwrap();
3646
3647        mqtt.reset_connection_scoped_state();
3648        mqtt.broker_topic_alias_max = 1;
3649        let packet = mqtt
3650            .outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3651            .unwrap()
3652            .unwrap();
3653
3654        assert_publish(packet, b"topic/one", Some(1));
3655    }
3656
3657    #[test]
3658    fn lru_auto_topic_alias_qos_replay_after_eviction_uses_full_topic() {
3659        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, 1);
3660        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3661            .unwrap();
3662        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtLeastOnce, vec![], None))
3663            .unwrap();
3664        mqtt.outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3665            .unwrap();
3666
3667        let requests = mqtt.clean();
3668
3669        assert_eq!(requests.len(), 1);
3670        match &requests[0] {
3671            Request::Publish(publish) => {
3672                assert_eq!(publish.topic, Bytes::from_static(b"topic/one"));
3673                assert_eq!(
3674                    publish
3675                        .properties
3676                        .as_ref()
3677                        .and_then(|props| props.topic_alias),
3678                    None
3679                );
3680            }
3681            request => panic!("expected replay publish, got {request:?}"),
3682        }
3683    }
3684
3685    #[test]
3686    fn lru_auto_topic_alias_collision_during_rebind_does_not_commit_rebind() {
3687        let mut mqtt = build_lru_auto_alias_mqttstate(2, 1);
3688        mqtt.outgoing_publish(Publish::new("topic/one", QoS::AtLeastOnce, vec![], None))
3689            .unwrap();
3690
3691        let mut collided = Publish::new("topic/two", QoS::AtLeastOnce, vec![], None);
3692        collided.pkid = 1;
3693        let (packet, flush_notice) = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
3694        assert!(packet.is_none());
3695        assert!(flush_notice.is_none());
3696
3697        let packet = mqtt
3698            .outgoing_publish(Publish::new("topic/one", QoS::AtMostOnce, vec![], None))
3699            .unwrap()
3700            .unwrap();
3701        assert_publish(packet, b"", Some(1));
3702    }
3703
3704    #[test]
3705    fn lru_auto_topic_alias_collision_replay_after_later_rebind_uses_original_topic() {
3706        let mut mqtt = build_lru_auto_alias_mqttstate(2, 1);
3707        let first_notice = queue_publish_with_notice(
3708            &mut mqtt,
3709            Publish::new("topic/one", QoS::AtLeastOnce, vec![], None),
3710        );
3711
3712        let mut collided = Publish::new("topic/one", QoS::AtLeastOnce, vec![], None);
3713        collided.pkid = 1;
3714        let (packet, flush_notice) = mqtt.outgoing_publish_with_notice(collided, None).unwrap();
3715        assert!(packet.is_none());
3716        assert!(flush_notice.is_none());
3717
3718        mqtt.outgoing_publish(Publish::new("topic/two", QoS::AtMostOnce, vec![], None))
3719            .unwrap();
3720
3721        let puback = PubAck::new(1, None);
3722        let packet = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();
3723
3724        assert_publish(packet, b"topic/one", None);
3725        assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
3726    }
3727
3728    #[test]
3729    fn lru_auto_topic_aliases_do_not_wrap_at_u16_max() {
3730        let mut mqtt = build_lru_auto_alias_mqttstate(u16::MAX, u16::MAX);
3731        mqtt.next_auto_topic_alias = Some(u16::MAX);
3732
3733        let last_packet = mqtt
3734            .outgoing_publish(Publish::new("last/topic", QoS::AtMostOnce, vec![], None))
3735            .unwrap()
3736            .unwrap();
3737        let rebound_packet = mqtt
3738            .outgoing_publish(Publish::new("rebound/topic", QoS::AtMostOnce, vec![], None))
3739            .unwrap()
3740            .unwrap();
3741
3742        assert_publish(last_packet, b"last/topic", Some(u16::MAX));
3743        assert_eq!(mqtt.next_auto_topic_alias, None);
3744        assert_publish(rebound_packet, b"rebound/topic", Some(u16::MAX));
3745    }
3746
3747    #[test]
3748    fn public_clean_repairs_alias_only_publish_when_mapping_is_known() {
3749        let mut mqtt = build_mqttstate();
3750        mqtt.broker_topic_alias_max = 10;
3751        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3752            "hello/replay",
3753            QoS::AtMostOnce,
3754            2,
3755        ))
3756        .unwrap();
3757        mqtt.outgoing_publish(build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 2))
3758            .unwrap();
3759
3760        let requests = mqtt.clean();
3761
3762        assert_eq!(requests.len(), 1);
3763        match &requests[0] {
3764            Request::Publish(publish) => {
3765                assert_eq!(publish.topic, Bytes::from_static(b"hello/replay"));
3766                assert_eq!(
3767                    publish
3768                        .properties
3769                        .as_ref()
3770                        .and_then(|props| props.topic_alias),
3771                    None
3772                );
3773            }
3774            request => panic!("expected publish replay, got {request:?}"),
3775        }
3776        assert_eq!(mqtt.broker_topic_alias_max, 0);
3777    }
3778
3779    #[test]
3780    fn public_clean_preserves_alias_only_publish_topic_from_send_time_after_rebind() {
3781        let mut mqtt = build_mqttstate();
3782        mqtt.broker_topic_alias_max = 10;
3783        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3784            "topic/a",
3785            QoS::AtMostOnce,
3786            1,
3787        ))
3788        .unwrap();
3789        mqtt.outgoing_publish(build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 1))
3790            .unwrap();
3791        mqtt.outgoing_publish(build_outgoing_publish_with_alias(
3792            "topic/b",
3793            QoS::AtMostOnce,
3794            1,
3795        ))
3796        .unwrap();
3797
3798        let requests = mqtt.clean();
3799
3800        assert_eq!(requests.len(), 1);
3801        match &requests[0] {
3802            Request::Publish(publish) => {
3803                assert_eq!(publish.topic, Bytes::from_static(b"topic/a"));
3804                assert_eq!(
3805                    publish
3806                        .properties
3807                        .as_ref()
3808                        .and_then(|props| props.topic_alias),
3809                    None
3810                );
3811            }
3812            request => panic!("expected publish replay, got {request:?}"),
3813        }
3814    }
3815
3816    #[test]
3817    fn public_clean_drops_alias_only_publish_when_mapping_is_unknown() {
3818        let mut mqtt = build_mqttstate();
3819        mqtt.broker_topic_alias_max = 10;
3820        mqtt.outgoing_publish(build_outgoing_publish_with_alias("", QoS::AtLeastOnce, 3))
3821            .unwrap();
3822
3823        let requests = mqtt.clean();
3824
3825        assert!(requests.is_empty());
3826        assert_eq!(mqtt.broker_topic_alias_max, 0);
3827    }
3828
3829    #[test]
3830    fn handle_incoming_packet_should_emit_incoming_before_derived_qos1_ack() {
3831        let mut mqtt = build_mqttstate();
3832        let publish = build_incoming_publish(QoS::AtLeastOnce, 42);
3833
3834        mqtt.handle_incoming_packet(Incoming::Publish(publish.clone()))
3835            .unwrap();
3836
3837        assert_eq!(mqtt.events.len(), 2);
3838        assert_eq!(mqtt.events[0], Event::Incoming(Incoming::Publish(publish)));
3839        assert_eq!(mqtt.events[1], Event::Outgoing(Outgoing::PubAck(42)));
3840    }
3841
3842    #[test]
3843    fn handle_incoming_packet_should_emit_incoming_before_derived_qos2_ack() {
3844        let mut mqtt = build_mqttstate();
3845        let publish = build_incoming_publish(QoS::ExactlyOnce, 43);
3846
3847        mqtt.handle_incoming_packet(Incoming::Publish(publish.clone()))
3848            .unwrap();
3849
3850        assert_eq!(mqtt.events.len(), 2);
3851        assert_eq!(mqtt.events[0], Event::Incoming(Incoming::Publish(publish)));
3852        assert_eq!(mqtt.events[1], Event::Outgoing(Outgoing::PubRec(43)));
3853    }
3854
3855    #[test]
3856    fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
3857        let mut mqtt = build_mqttstate();
3858        let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);
3859
3860        match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() {
3861            Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
3862            packet => panic!("Invalid network request: {packet:?}"),
3863        }
3864    }
3865
3866    #[test]
3867    fn incoming_puback_should_remove_correct_publish_from_queue() {
3868        let mut mqtt = build_mqttstate();
3869
3870        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
3871        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
3872
3873        mqtt.outgoing_publish(publish1).unwrap();
3874        mqtt.outgoing_publish(publish2).unwrap();
3875        assert_eq!(mqtt.inflight, 2);
3876
3877        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
3878        assert_eq!(mqtt.inflight, 1);
3879
3880        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
3881        assert_eq!(mqtt.inflight, 0);
3882
3883        assert!(mqtt.outgoing_pub[1].is_none());
3884        assert!(mqtt.outgoing_pub[2].is_none());
3885    }
3886
3887    #[test]
3888    fn incoming_puback_updates_last_puback() {
3889        let mut mqtt = build_mqttstate();
3890
3891        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
3892        let publish2 = build_outgoing_publish(QoS::AtLeastOnce);
3893        mqtt.outgoing_publish(publish1).unwrap();
3894        mqtt.outgoing_publish(publish2).unwrap();
3895        assert_eq!(mqtt.last_puback, 0);
3896
3897        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
3898        assert_eq!(mqtt.last_puback, 1);
3899
3900        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
3901        assert_eq!(mqtt.last_puback, 2);
3902    }
3903
3904    #[test]
3905    fn incoming_puback_advances_last_puback_only_on_contiguous_boundary() {
3906        let mut mqtt = build_mqttstate();
3907
3908        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3909            .unwrap();
3910        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3911            .unwrap();
3912        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3913            .unwrap();
3914        assert_eq!(mqtt.last_puback, 0);
3915
3916        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
3917        assert_eq!(mqtt.last_puback, 0);
3918
3919        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
3920        assert_eq!(mqtt.last_puback, 2);
3921
3922        mqtt.handle_incoming_puback(&PubAck::new(3, None)).unwrap();
3923        assert_eq!(mqtt.last_puback, 3);
3924    }
3925
3926    #[test]
3927    fn mixed_qos_completion_clears_outbound_drain_state() {
3928        let mut mqtt = build_mqttstate();
3929
3930        mqtt.outgoing_publish(build_outgoing_publish(QoS::ExactlyOnce))
3931            .unwrap();
3932        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3933            .unwrap();
3934        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3935            .unwrap();
3936        mqtt.outgoing_publish(build_outgoing_publish(QoS::ExactlyOnce))
3937            .unwrap();
3938
3939        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
3940        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
3941        mqtt.handle_incoming_puback(&PubAck::new(3, None)).unwrap();
3942        mqtt.handle_incoming_pubcomp(&PubComp::new(1, None))
3943            .unwrap();
3944        mqtt.handle_incoming_pubrec(&PubRec::new(4, None)).unwrap();
3945        mqtt.handle_incoming_pubcomp(&PubComp::new(4, None))
3946            .unwrap();
3947
3948        assert_eq!(mqtt.inflight, 0);
3949        assert!(mqtt.outbound_requests_drained());
3950        assert!(mqtt.outgoing_pub_ack.ones().next().is_none());
3951        assert!(mqtt.outgoing_rel.ones().next().is_none());
3952    }
3953
3954    #[test]
3955    fn clean_keeps_oldest_unacked_publish_first_after_out_of_order_puback() {
3956        let mut mqtt = build_mqttstate();
3957
3958        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3959            .unwrap();
3960        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3961            .unwrap();
3962        mqtt.outgoing_publish(build_outgoing_publish(QoS::AtLeastOnce))
3963            .unwrap();
3964
3965        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
3966        let requests = mqtt.clean();
3967
3968        let pending_pkids: Vec<u16> = requests
3969            .iter()
3970            .map(|req| match req {
3971                Request::Publish(publish) => publish.pkid,
3972                req => panic!("Unexpected request while cleaning: {req:?}"),
3973            })
3974            .collect();
3975
3976        assert_eq!(pending_pkids, vec![1, 3]);
3977    }
3978
3979    #[test]
3980    fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() {
3981        let mut mqtt = build_mqttstate();
3982
3983        let got = mqtt
3984            .handle_incoming_puback(&PubAck::new(101, None))
3985            .unwrap_err();
3986
3987        match got {
3988            StateError::Unsolicited(pkid) => assert_eq!(pkid, 101),
3989            e => panic!("Unexpected error: {e}"),
3990        }
3991    }
3992
3993    #[test]
3994    fn incoming_puback_failure_collision_replays_blocked_publish() {
3995        let mut mqtt = build_mqttstate();
3996        let first_notice =
3997            queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::AtLeastOnce));
3998
3999        let (collided_tx, _collided_notice) = PublishNoticeTx::new();
4000        let mut collided = build_outgoing_publish(QoS::AtLeastOnce);
4001        collided.pkid = 1;
4002        let (packet, flush_notice) = mqtt
4003            .outgoing_publish_with_notice(collided, Some(collided_tx))
4004            .unwrap();
4005        assert!(packet.is_none());
4006        assert!(flush_notice.is_none());
4007        assert!(mqtt.collision.is_some());
4008
4009        let mut puback = PubAck::new(1, None);
4010        puback.reason = PubAckReason::ImplementationSpecificError;
4011
4012        let packet = mqtt.handle_incoming_puback(&puback).unwrap().unwrap();
4013        match packet {
4014            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
4015            packet => panic!("Invalid network request: {packet:?}"),
4016        }
4017
4018        assert_eq!(first_notice.wait(), Ok(PublishResult::Qos1(puback)));
4019        assert_eq!(mqtt.inflight, 1);
4020        assert!(mqtt.collision.is_none());
4021    }
4022
4023    #[test]
4024    fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() {
4025        let mut mqtt = build_mqttstate();
4026
4027        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
4028        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
4029
4030        let _publish_out = mqtt.outgoing_publish(publish1);
4031        let _publish_out = mqtt.outgoing_publish(publish2);
4032
4033        mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap();
4034        assert_eq!(mqtt.inflight, 2);
4035
4036        // check if the remaining element's pkid is 1
4037        let backup = mqtt.outgoing_pub[1].clone();
4038        assert_eq!(backup.unwrap().pkid, 1);
4039
4040        // check if the qos2 element's release pkid is 2
4041        assert!(mqtt.outgoing_rel.contains(2));
4042    }
4043
4044    #[test]
4045    fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
4046        let mut mqtt = build_mqttstate();
4047
4048        let publish = build_outgoing_publish(QoS::ExactlyOnce);
4049        match mqtt.outgoing_publish(publish).unwrap().unwrap() {
4050            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
4051            packet => panic!("Invalid network request: {packet:?}"),
4052        }
4053
4054        match mqtt
4055            .handle_incoming_pubrec(&PubRec::new(1, None))
4056            .unwrap()
4057            .unwrap()
4058        {
4059            Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
4060            packet => panic!("Invalid network request: {packet:?}"),
4061        }
4062    }
4063
4064    #[test]
4065    fn incoming_pubrec_failure_without_collision_decrements_inflight() {
4066        let mut mqtt = build_mqttstate();
4067        let first_notice =
4068            queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));
4069
4070        let mut pubrec = PubRec::new(1, None);
4071        pubrec.reason = PubRecReason::ImplementationSpecificError;
4072
4073        assert!(mqtt.handle_incoming_pubrec(&pubrec).unwrap().is_none());
4074        assert_eq!(
4075            first_notice.wait(),
4076            Ok(PublishResult::Qos2PubRecRejected(pubrec))
4077        );
4078        assert_eq!(mqtt.inflight, 0);
4079        assert!(mqtt.outgoing_pub[1].is_none());
4080        assert!(!mqtt.outgoing_rel.contains(1));
4081    }
4082
4083    #[test]
4084    fn incoming_pubrec_failure_releases_inflight_and_replays_collision() {
4085        let mut mqtt = build_mqttstate();
4086        let first_notice =
4087            queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));
4088
4089        let (collided_tx, _collided_notice) = PublishNoticeTx::new();
4090        let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
4091        collided.pkid = 1;
4092        let (packet, flush_notice) = mqtt
4093            .outgoing_publish_with_notice(collided, Some(collided_tx))
4094            .unwrap();
4095        assert!(packet.is_none());
4096        assert!(flush_notice.is_none());
4097        assert!(mqtt.collision.is_some());
4098
4099        let mut pubrec = PubRec::new(1, None);
4100        pubrec.reason = PubRecReason::ImplementationSpecificError;
4101
4102        let packet = mqtt.handle_incoming_pubrec(&pubrec).unwrap().unwrap();
4103        match packet {
4104            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
4105            packet => panic!("Invalid network request: {packet:?}"),
4106        }
4107
4108        assert_eq!(
4109            first_notice.wait(),
4110            Ok(PublishResult::Qos2PubRecRejected(pubrec))
4111        );
4112        assert_eq!(mqtt.inflight, 1);
4113        assert!(mqtt.collision.is_none());
4114        assert!(!mqtt.outgoing_rel.contains(1));
4115
4116        let packet = mqtt
4117            .handle_incoming_pubrec(&PubRec::new(1, None))
4118            .unwrap()
4119            .unwrap();
4120        match packet {
4121            Packet::PubRel(release) => assert_eq!(release.pkid, 1),
4122            packet => panic!("Invalid network request: {packet:?}"),
4123        }
4124    }
4125
4126    #[test]
4127    fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
4128        let mut mqtt = build_mqttstate();
4129        let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);
4130
4131        match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() {
4132            Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
4133            packet => panic!("Invalid network request: {packet:?}"),
4134        }
4135
4136        match mqtt
4137            .handle_incoming_pubrel(&PubRel::new(1, None))
4138            .unwrap()
4139            .unwrap()
4140        {
4141            Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1),
4142            packet => panic!("Invalid network request: {packet:?}"),
4143        }
4144    }
4145
4146    #[test]
4147    fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
4148        let mut mqtt = build_mqttstate();
4149        let publish = build_outgoing_publish(QoS::ExactlyOnce);
4150
4151        mqtt.outgoing_publish(publish).unwrap();
4152        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
4153
4154        mqtt.handle_incoming_pubcomp(&PubComp::new(1, None))
4155            .unwrap();
4156        assert_eq!(mqtt.inflight, 0);
4157    }
4158
4159    #[test]
4160    fn incoming_pubcomp_failure_without_collision_decrements_inflight() {
4161        let mut mqtt = build_mqttstate();
4162        let first_notice =
4163            queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));
4164        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
4165
4166        let mut pubcomp = PubComp::new(1, None);
4167        pubcomp.reason = PubCompReason::PacketIdentifierNotFound;
4168
4169        assert!(mqtt.handle_incoming_pubcomp(&pubcomp).unwrap().is_none());
4170        assert_eq!(
4171            first_notice.wait(),
4172            Ok(PublishResult::Qos2Completed(pubcomp))
4173        );
4174        assert_eq!(mqtt.inflight, 0);
4175        assert!(!mqtt.outgoing_rel.contains(1));
4176    }
4177
4178    #[test]
4179    fn incoming_pubcomp_collision_replay_should_restore_qos2_tracking() {
4180        let mut mqtt = build_mqttstate();
4181        let publish = build_outgoing_publish(QoS::ExactlyOnce);
4182        mqtt.outgoing_publish(publish).unwrap();
4183
4184        let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
4185        collided.pkid = 1;
4186        assert!(mqtt.outgoing_publish(collided).unwrap().is_none());
4187        assert!(mqtt.collision.is_some());
4188
4189        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
4190        let packet = mqtt
4191            .handle_incoming_pubcomp(&PubComp::new(1, None))
4192            .unwrap()
4193            .unwrap();
4194        match packet {
4195            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
4196            packet => panic!("Invalid network request: {packet:?}"),
4197        }
4198
4199        assert!(mqtt.outgoing_pub[1].is_some());
4200        assert_eq!(mqtt.inflight, 1);
4201
4202        let packet = mqtt
4203            .handle_incoming_pubrec(&PubRec::new(1, None))
4204            .unwrap()
4205            .unwrap();
4206        match packet {
4207            Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
4208            packet => panic!("Invalid network request: {packet:?}"),
4209        }
4210    }
4211
4212    #[test]
4213    fn incoming_pubcomp_failure_replays_collision_and_preserves_qos2_tracking() {
4214        let mut mqtt = build_mqttstate();
4215        let first_notice =
4216            queue_publish_with_notice(&mut mqtt, build_outgoing_publish(QoS::ExactlyOnce));
4217
4218        let (collided_tx, _collided_notice) = PublishNoticeTx::new();
4219        let mut collided = build_outgoing_publish(QoS::ExactlyOnce);
4220        collided.pkid = 1;
4221        let (packet, flush_notice) = mqtt
4222            .outgoing_publish_with_notice(collided, Some(collided_tx))
4223            .unwrap();
4224        assert!(packet.is_none());
4225        assert!(flush_notice.is_none());
4226        assert!(mqtt.collision.is_some());
4227
4228        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
4229
4230        let mut pubcomp = PubComp::new(1, None);
4231        pubcomp.reason = PubCompReason::PacketIdentifierNotFound;
4232
4233        let packet = mqtt.handle_incoming_pubcomp(&pubcomp).unwrap().unwrap();
4234        match packet {
4235            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
4236            packet => panic!("Invalid network request: {packet:?}"),
4237        }
4238
4239        assert_eq!(
4240            first_notice.wait(),
4241            Ok(PublishResult::Qos2Completed(pubcomp))
4242        );
4243        assert_eq!(mqtt.inflight, 1);
4244        assert!(mqtt.collision.is_none());
4245        assert!(!mqtt.outgoing_rel.contains(1));
4246        assert!(mqtt.outgoing_pub[1].is_some());
4247
4248        let packet = mqtt
4249            .handle_incoming_pubrec(&PubRec::new(1, None))
4250            .unwrap()
4251            .unwrap();
4252        match packet {
4253            Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
4254            packet => panic!("Invalid network request: {packet:?}"),
4255        }
4256    }
4257
4258    #[test]
4259    fn outgoing_disconnect_should_preserve_reason_and_properties() {
4260        let mut mqtt = build_mqttstate();
4261        let properties = DisconnectProperties {
4262            session_expiry_interval: Some(60),
4263            reason_string: Some("disconnect test".to_string()),
4264            user_properties: vec![("key".to_string(), "value".to_string())],
4265            server_reference: Some("broker-2".to_string()),
4266        };
4267        let disconnect = Disconnect::new_with_properties(
4268            DisconnectReasonCode::ImplementationSpecificError,
4269            properties,
4270        );
4271
4272        let packet = mqtt
4273            .handle_outgoing_packet(Request::DisconnectNow(disconnect.clone()))
4274            .unwrap()
4275            .unwrap();
4276        assert_eq!(packet, Packet::Disconnect(disconnect));
4277        assert!(matches!(
4278            mqtt.events.back(),
4279            Some(Event::Outgoing(Outgoing::Disconnect))
4280        ));
4281    }
4282
4283    #[test]
4284    fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
4285        let mut mqtt = build_mqttstate();
4286        mqtt.outgoing_ping().unwrap();
4287
4288        // network activity other than pingresp
4289        let publish = build_outgoing_publish(QoS::AtLeastOnce);
4290        mqtt.handle_outgoing_packet(Request::Publish(publish))
4291            .unwrap();
4292        mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None)))
4293            .unwrap();
4294
4295        // should throw error because we didn't get pingresp for previous ping
4296        match mqtt.outgoing_ping() {
4297            Ok(_) => panic!("Should throw pingresp await error"),
4298            Err(StateError::AwaitPingResp) => (),
4299            Err(e) => panic!("Should throw pingresp await error. Error = {e:?}"),
4300        }
4301    }
4302
4303    #[test]
4304    fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
4305        let mut mqtt = build_mqttstate();
4306
4307        // should ping
4308        mqtt.outgoing_ping().unwrap();
4309        mqtt.handle_incoming_packet(Incoming::PingResp(PingResp))
4310            .unwrap();
4311
4312        // should ping
4313        mqtt.outgoing_ping().unwrap();
4314    }
4315}