1use super::framed::Network;
2use super::mqttbytes::v5::{
3 ConnAck, Connect, ConnectProperties, Disconnect, DisconnectReasonCode, Packet, Publish,
4 Subscribe, Unsubscribe,
5};
6use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
7use crate::framed::AsyncReadWrite;
8use crate::notice::{
9 AuthNoticeTx, PublishNoticeTx, PublishResult, SubscribeNoticeTx, TrackedNoticeTx,
10 UnsubscribeNoticeTx,
11};
12use crate::{AuthEvent, NoticeFailureReason, PublishNoticeError};
13
14use flume::{Receiver, Sender, TryRecvError, bounded, unbounded};
15use rumqttc_core::{OutboundScheduler, RequestClass, RequestReadiness, ScheduledRequest};
16use tokio::select;
17use tokio::time::{self, Instant, Sleep, error::Elapsed};
18
19use std::collections::VecDeque;
20use std::io;
21use std::pin::Pin;
22use std::time::Duration;
23
24use super::mqttbytes::v5::ConnectReturnCode;
25
26#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
27use crate::tls;
28
29#[cfg(unix)]
30use {std::path::Path, tokio::net::UnixStream};
31
32#[cfg(feature = "websocket")]
33use {
34 crate::websockets::WsAdapter,
35 crate::websockets::{UrlError, split_url, validate_response_headers},
36 async_tungstenite::tungstenite::client::IntoClientRequest,
37};
38
39#[cfg(feature = "proxy")]
40use crate::proxy::ProxyError;
41
/// A request handed to the event loop, optionally paired with a tracked
/// notice used to report the request's success or failure back to the caller.
#[derive(Debug)]
pub struct RequestEnvelope {
    /// The MQTT request to process.
    request: Request,
    /// Completion notice for tracked requests; `None` for plain requests.
    notice: Option<TrackedNoticeTx>,
}
47
48impl RequestEnvelope {
49 pub(crate) const fn from_parts(request: Request, notice: Option<TrackedNoticeTx>) -> Self {
50 Self { request, notice }
51 }
52
53 pub(crate) const fn plain(request: Request) -> Self {
54 Self {
55 request,
56 notice: None,
57 }
58 }
59
60 pub(crate) const fn tracked_publish(publish: Publish, notice: PublishNoticeTx) -> Self {
61 Self {
62 request: Request::Publish(publish),
63 notice: Some(TrackedNoticeTx::Publish(notice)),
64 }
65 }
66
67 pub(crate) const fn tracked_subscribe(subscribe: Subscribe, notice: SubscribeNoticeTx) -> Self {
68 Self {
69 request: Request::Subscribe(subscribe),
70 notice: Some(TrackedNoticeTx::Subscribe(notice)),
71 }
72 }
73
74 pub(crate) const fn tracked_unsubscribe(
75 unsubscribe: Unsubscribe,
76 notice: UnsubscribeNoticeTx,
77 ) -> Self {
78 Self {
79 request: Request::Unsubscribe(unsubscribe),
80 notice: Some(TrackedNoticeTx::Unsubscribe(notice)),
81 }
82 }
83
84 pub(crate) const fn tracked_auth(
85 auth: super::mqttbytes::v5::Auth,
86 notice: AuthNoticeTx,
87 ) -> Self {
88 Self {
89 request: Request::Auth(auth),
90 notice: Some(TrackedNoticeTx::Auth(notice)),
91 }
92 }
93
94 pub(crate) fn into_parts(self) -> (Request, Option<TrackedNoticeTx>) {
95 (self.request, self.notice)
96 }
97}
98
/// Capacity of the normal and control request channels created for an async
/// client. The immediate-disconnect channel is always unbounded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestChannelCapacity {
    /// Channels hold at most this many queued requests.
    Bounded(usize),
    /// Channels never apply back-pressure.
    Unbounded,
}
104
/// A graceful disconnect waiting for outbound protocol state to drain before
/// its DISCONNECT packet is sent.
struct PendingDisconnect {
    /// The DISCONNECT packet to send once draining completes.
    disconnect: super::mqttbytes::v5::Disconnect,
    /// Point after which draining is abandoned with `DisconnectTimeout`;
    /// `None` waits indefinitely.
    deadline: Option<Instant>,
}
109
110impl PendingDisconnect {
111 fn new(disconnect: super::mqttbytes::v5::Disconnect, timeout: Option<Duration>) -> Self {
112 Self {
113 disconnect,
114 deadline: timeout.map(|timeout| Instant::now() + timeout),
115 }
116 }
117}
118
119#[derive(Debug, thiserror::Error)]
121pub enum ConnectionError {
122 #[error("Mqtt state: {0}")]
123 MqttState(#[from] StateError),
124 #[error("Timeout")]
125 Timeout(#[from] Elapsed),
126 #[error("Graceful disconnect timed out before outbound protocol state drained")]
127 DisconnectTimeout,
128 #[cfg(feature = "websocket")]
129 #[error("Websocket: {0}")]
130 Websocket(#[from] async_tungstenite::tungstenite::error::Error),
131 #[cfg(feature = "websocket")]
132 #[error("Websocket Connect: {0}")]
133 WsConnect(#[from] http::Error),
134 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
135 #[error("TLS: {0}")]
136 Tls(#[from] tls::Error),
137 #[error("I/O: {0}")]
138 Io(#[from] io::Error),
139 #[error("Connection refused, return code: `{0:?}`")]
140 ConnectionRefused(ConnectReturnCode),
141 #[error("Expected ConnAck packet, received: {0:?}")]
142 NotConnAck(Box<Packet>),
143 #[error("Broker replied with session_present={session_present} for clean_start={clean_start}")]
144 SessionStateMismatch {
145 clean_start: bool,
146 session_present: bool,
147 },
148 #[error("Broker target is incompatible with the selected transport")]
149 BrokerTransportMismatch,
150 #[error("Requests done")]
153 RequestsDone,
154 #[error("Auth processing error")]
155 AuthProcessingError,
156 #[cfg(feature = "websocket")]
157 #[error("Invalid Url: {0}")]
158 InvalidUrl(#[from] UrlError),
159 #[cfg(feature = "proxy")]
160 #[error("Proxy Connect: {0}")]
161 Proxy(#[from] ProxyError),
162 #[cfg(feature = "websocket")]
163 #[error("Websocket response validation error: ")]
164 ResponseValidation(#[from] crate::websockets::ValidationError),
165 #[cfg(feature = "websocket")]
166 #[error("Websocket request modifier failed: {0}")]
167 RequestModifier(#[source] Box<dyn std::error::Error + Send + Sync>),
168}
169
/// Drives an MQTT v5 connection: (re)connects when needed, schedules outbound
/// requests, reads incoming packets and sends keep-alive pings.
pub struct EventLoop {
    /// Options used for every (re)connect.
    pub options: MqttOptions,
    /// MQTT protocol state (in-flight tracking, buffered events, authentication).
    pub state: MqttState,
    /// Normal client requests, admitted subject to scheduler readiness.
    requests_rx: Receiver<RequestEnvelope>,
    /// Control requests, queued without the normal admission check.
    control_requests_rx: Receiver<RequestEnvelope>,
    /// Disconnect-now requests, checked ahead of all other work.
    immediate_disconnect_rx: Receiver<RequestEnvelope>,
    /// Sender halves kept by `EventLoop::new` so its channels stay connected
    /// while the loop lives; `None` when the senders are handed to a client.
    _requests_tx: Option<Sender<RequestEnvelope>>,
    /// See `_requests_tx`.
    _control_requests_tx: Option<Sender<RequestEnvelope>>,
    /// See `_requests_tx`.
    _immediate_disconnect_tx: Option<Sender<RequestEnvelope>>,
    /// Requests staged for replay after a reconnect; consumed before `requests_rx`.
    pending: VecDeque<RequestEnvelope>,
    /// Admitted requests awaiting readiness-based scheduling.
    queued: OutboundScheduler<RequestEnvelope>,
    /// Live connection; `None` before connecting and after teardown.
    network: Option<Network>,
    /// Keep-alive timer; `None` while disconnected or when keep-alive is zero.
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
    /// Lazily created, effectively never-firing sleep used as a placeholder
    /// future for the (then disabled) keep-alive branch of `select!`.
    no_sleep: Option<Pin<Box<Sleep>>>,
    /// Graceful disconnect waiting for outbound state to drain.
    pending_disconnect: Option<PendingDisconnect>,
    /// Set once a disconnect has finished; `poll` then returns `RequestsDone`.
    disconnect_complete: bool,
}
201
/// An item yielded by `EventLoop::poll`.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): the size imbalance between variants appears to be accepted
// deliberately rather than boxed; confirm before changing.
#[allow(clippy::large_enum_variant)]
pub enum Event {
    /// An incoming-packet event (e.g. the CONNACK processed after connecting).
    Incoming(Incoming),
    /// An outgoing-request event produced by protocol state.
    Outgoing(Outgoing),
    /// An authentication-exchange event.
    Auth(AuthEvent),
}
210
211impl EventLoop {
212 fn reconcile_connack_session(&mut self, session_present: bool) -> Result<(), ConnectionError> {
213 let clean_start = self.options.clean_start();
214 if clean_start && session_present {
215 return Err(ConnectionError::SessionStateMismatch {
216 clean_start,
217 session_present,
218 });
219 }
220
221 if !session_present {
222 self.reset_session_state();
223 }
224
225 Ok(())
226 }
227
228 pub fn new(options: MqttOptions, cap: usize) -> Self {
233 let (requests_tx, requests_rx) = bounded(cap);
234 let (control_requests_tx, control_requests_rx) = bounded(cap);
235 let (immediate_disconnect_tx, immediate_disconnect_rx) = unbounded();
236 Self::with_channel(
237 options,
238 requests_rx,
239 control_requests_rx,
240 immediate_disconnect_rx,
241 Some(requests_tx),
242 Some(control_requests_tx),
243 Some(immediate_disconnect_tx),
244 )
245 }
246
247 pub(crate) fn new_for_async_client_with_capacity(
252 options: MqttOptions,
253 capacity: RequestChannelCapacity,
254 ) -> (
255 Self,
256 Sender<RequestEnvelope>,
257 Sender<RequestEnvelope>,
258 Sender<RequestEnvelope>,
259 ) {
260 let (requests_tx, requests_rx) = match capacity {
261 RequestChannelCapacity::Bounded(cap) => bounded(cap),
262 RequestChannelCapacity::Unbounded => unbounded(),
263 };
264 let (control_requests_tx, control_requests_rx) = match capacity {
265 RequestChannelCapacity::Bounded(cap) => bounded(cap),
266 RequestChannelCapacity::Unbounded => unbounded(),
267 };
268 let (immediate_disconnect_tx, immediate_disconnect_rx) = unbounded();
269 let eventloop = Self::with_channel(
270 options,
271 requests_rx,
272 control_requests_rx,
273 immediate_disconnect_rx,
274 None,
275 None,
276 None,
277 );
278 (
279 eventloop,
280 requests_tx,
281 control_requests_tx,
282 immediate_disconnect_tx,
283 )
284 }
285
286 #[cfg(test)]
288 pub(crate) fn new_for_async_client(
289 options: MqttOptions,
290 cap: usize,
291 ) -> (Self, Sender<RequestEnvelope>) {
292 let (eventloop, request_tx, _control_request_tx, _immediate_disconnect_tx) =
293 Self::new_for_async_client_with_capacity(options, RequestChannelCapacity::Bounded(cap));
294 (eventloop, request_tx)
295 }
296
297 fn with_channel(
298 options: MqttOptions,
299 requests_rx: Receiver<RequestEnvelope>,
300 control_requests_rx: Receiver<RequestEnvelope>,
301 immediate_disconnect_rx: Receiver<RequestEnvelope>,
302 requests_tx: Option<Sender<RequestEnvelope>>,
303 control_requests_tx: Option<Sender<RequestEnvelope>>,
304 immediate_disconnect_tx: Option<Sender<RequestEnvelope>>,
305 ) -> Self {
306 let pending = VecDeque::new();
307 let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
308 let manual_acks = options.manual_acks;
309 let auto_topic_aliases = options.auto_topic_aliases();
310 let topic_alias_policy = options.topic_alias_policy();
311
312 let authenticator = options.authenticator();
313 let authentication_method = options.authentication_method();
314
315 Self {
316 options,
317 state: MqttState::new_internal(
318 inflight_limit,
319 manual_acks,
320 auto_topic_aliases,
321 topic_alias_policy,
322 authentication_method,
323 authenticator,
324 ),
325 requests_rx,
326 control_requests_rx,
327 immediate_disconnect_rx,
328 _requests_tx: requests_tx,
329 _control_requests_tx: control_requests_tx,
330 _immediate_disconnect_tx: immediate_disconnect_tx,
331 pending,
332 queued: OutboundScheduler::default(),
333 network: None,
334 keepalive_timeout: None,
335 no_sleep: None,
336 pending_disconnect: None,
337 disconnect_complete: false,
338 }
339 }
340
341 pub fn clean(&mut self) {
350 self.network = None;
351 self.keepalive_timeout = None;
352 self.pending_disconnect = None;
353 let mut replay_topic_aliases = self.state.replay_topic_aliases();
354
355 for (request, notice) in self.state.clean_with_notices() {
356 self.push_replay_envelope(
357 RequestEnvelope::from_parts(request, notice),
358 &mut replay_topic_aliases,
359 );
360 }
361
362 let queued: Vec<_> = self.queued.drain().collect();
363 for envelope in queued {
364 if should_replay_after_reconnect(&envelope.request) {
365 self.push_replay_envelope(envelope, &mut replay_topic_aliases);
366 }
367 }
368
369 let drained_requests: Vec<_> = self.requests_rx.drain().collect();
371 for envelope in drained_requests {
372 if should_replay_after_reconnect(&envelope.request) {
375 self.push_replay_envelope(envelope, &mut replay_topic_aliases);
376 }
377 }
378
379 let drained_control_requests: Vec<_> = self.control_requests_rx.drain().collect();
380 for envelope in drained_control_requests {
381 if should_replay_after_reconnect(&envelope.request) {
384 self.push_replay_envelope(envelope, &mut replay_topic_aliases);
385 }
386 }
387
388 self.state.fail_auth_exchange_due_to_session_reset();
389 self.state.reset_connection_scoped_state();
390 }
391
392 fn push_replay_envelope(
393 &mut self,
394 mut envelope: RequestEnvelope,
395 replay_topic_aliases: &mut std::collections::HashMap<u16, bytes::Bytes>,
396 ) {
397 if let Err(err) = MqttState::prepare_request_for_replay_with_aliases(
398 &mut envelope.request,
399 replay_topic_aliases,
400 ) {
401 if let Some(TrackedNoticeTx::Publish(notice)) = envelope.notice {
402 notice.error(err);
403 }
404 return;
405 }
406
407 self.pending.push_back(envelope);
408 }
409
410 pub fn pending_len(&self) -> usize {
412 self.pending.len() + self.queued.len()
413 }
414
415 pub fn pending_is_empty(&self) -> bool {
417 self.pending.is_empty() && self.queued.is_empty()
418 }
419
420 pub fn drain_pending_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
424 let mut drained = 0;
425 for envelope in self.pending.drain(..) {
426 drained += 1;
427 if let Some(notice) = envelope.notice {
428 Self::fail_tracked_notice(notice, reason);
429 }
430 }
431 for envelope in self.queued.drain() {
432 drained += 1;
433 if let Some(notice) = envelope.notice {
434 Self::fail_tracked_notice(notice, reason);
435 }
436 }
437
438 drained
439 }
440
441 fn fail_tracked_notice(notice: TrackedNoticeTx, reason: NoticeFailureReason) {
442 match notice {
443 TrackedNoticeTx::Publish(notice) => {
444 notice.error(reason.publish_error());
445 }
446 TrackedNoticeTx::Subscribe(notice) => {
447 notice.error(reason.subscribe_error());
448 }
449 TrackedNoticeTx::Unsubscribe(notice) => {
450 notice.error(reason.unsubscribe_error());
451 }
452 TrackedNoticeTx::Auth(notice) => {
453 notice.error(reason.auth_error());
454 }
455 }
456 }
457
458 fn drop_unprocessed_requests(&mut self) {
459 self.pending.clear();
462 self.queued.clear();
463 self.requests_rx.drain().for_each(drop);
464 self.control_requests_rx.drain().for_each(drop);
465 }
466
    /// Discards all locally held session state. Used, for example, when the
    /// broker's CONNACK reports that no session was resumed.
    ///
    /// Steps, in order:
    /// 1. staged and admitted requests are failed with
    ///    `NoticeFailureReason::SessionReset`;
    /// 2. notices tracked by protocol state are failed;
    /// 3. any re-authentication exchange is failed;
    /// 4. connection-scoped protocol state is reset.
    pub fn reset_session_state(&mut self) {
        self.drain_pending_as_failed(NoticeFailureReason::SessionReset);
        self.state.fail_pending_notices();
        self.state.fail_reauth_exchange_due_to_session_reset();
        self.state.reset_connection_scoped_state();
    }
474
475 fn reconcile_outgoing_tracking_after_connack(&mut self) {
476 self.state
477 .reconcile_outgoing_tracking_capacity(self.pending.is_empty());
478 }
479
480 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
490 if self.disconnect_complete {
491 return Err(ConnectionError::RequestsDone);
492 }
493
494 if self.network.is_none() {
495 if let Ok(envelope) = self.immediate_disconnect_rx.try_recv() {
496 self.disconnect_complete = true;
497 if let Some(notice) = envelope.notice {
498 drop(notice);
499 }
500 return Err(ConnectionError::RequestsDone);
501 }
502
503 let (network, connack) = time::timeout(
504 self.options.connect_timeout(),
505 connect(&mut self.options, &mut self.state),
506 )
507 .await??;
508 self.reconcile_connack_session(connack.session_present)?;
509 self.network = Some(network);
510
511 if self.keepalive_timeout.is_none() && !self.options.keep_alive.is_zero() {
512 self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
513 }
514
515 self.state
516 .handle_incoming_packet(Incoming::ConnAck(connack))?;
517 self.reconcile_outgoing_tracking_after_connack();
518 }
519
520 match self.select().await {
521 Ok(v) => Ok(v),
522 Err(ConnectionError::DisconnectTimeout) => {
523 self.network = None;
524 self.keepalive_timeout = None;
525 self.pending_disconnect = None;
526 self.drop_unprocessed_requests();
527 self.disconnect_complete = true;
528 Err(ConnectionError::DisconnectTimeout)
529 }
530 Err(e) => {
531 self.clean();
534 Err(e)
535 }
536 }
537 }
538
    /// Runs the connected event loop until it has an event to return.
    ///
    /// Each iteration checks, in priority order:
    /// 1. events already buffered in protocol state;
    /// 2. an already-queued disconnect-now request;
    /// 3. `RequestsDone`, once every request source is closed and drained and
    ///    outbound protocol state has drained;
    /// 4. ready scheduled requests (skipped while a graceful disconnect is pending);
    /// 5. a pending graceful disconnect: sent once outbound state has drained,
    ///    otherwise the network keeps being read until it does;
    /// 6. otherwise a `biased` `select!` over disconnect-now requests, control
    ///    requests, normal requests, network reads and the keep-alive timer.
    #[allow(clippy::too_many_lines)]
    async fn select(&mut self) -> Result<Event, ConnectionError> {
        loop {
            if let Some(event) = self.state.events.pop_front() {
                return Ok(event);
            }

            if let Ok(envelope) = self.immediate_disconnect_rx.try_recv() {
                return self.handle_immediate_disconnect(envelope).await;
            }

            // Nothing left to send and nobody left to send it.
            if self.queued.is_empty()
                && self.pending.is_empty()
                && self.pending_disconnect.is_none()
                && self.requests_rx.is_disconnected()
                && self.requests_rx.is_empty()
                && self.control_requests_rx.is_disconnected()
                && self.control_requests_rx.is_empty()
                && self.state.outbound_requests_drained()
            {
                return Err(ConnectionError::RequestsDone);
            }

            if self.pending_disconnect.is_none() && self.handle_ready_requests().await? {
                if let Some(event) = self.state.events.pop_front() {
                    return Ok(event);
                }
                continue;
            }

            if self.pending_disconnect.is_some() {
                if self.state.outbound_requests_drained() {
                    return self.send_pending_disconnect().await;
                }

                if let Some(event) = self.poll_disconnect_drain().await? {
                    return Ok(event);
                }
                continue;
            }

            let read_batch_size = self.effective_read_batch_size();
            // Replay-staged requests are admitted even when the scheduler has
            // nothing ready.
            let normal_request_admission_allowed =
                self.normal_request_admission_allowed() || !self.pending.is_empty();
            // Placeholder future for the keep-alive branch when no timer exists;
            // that branch's guard disables it in that case.
            let no_sleep = self
                .no_sleep
                .get_or_insert_with(|| Box::pin(time::sleep(Duration::MAX)));

            return select! {
                biased;
                o = self.immediate_disconnect_rx.recv_async(), if !self.immediate_disconnect_rx.is_disconnected() => match o {
                    Ok(envelope) => self.handle_immediate_disconnect(envelope).await,
                    Err(_) => continue,
                },
                // Control requests skip the admission check. Normal requests
                // that are already waiting are queued ahead of them first.
                o = self.control_requests_rx.recv_async(),
                    if self.pending_disconnect.is_none()
                        && (!self.control_requests_rx.is_empty()
                            || !self.control_requests_rx.is_disconnected()) => match o {
                    Ok(envelope) => {
                        self.try_admit_existing_normal_requests().await;
                        self.queued.push_back(envelope);
                        continue;
                    }
                    Err(_) => continue,
                },
                o = Self::next_request(
                    &mut self.pending,
                    &self.requests_rx,
                    self.options.pending_throttle
                ), if self.pending_disconnect.is_none()
                    && normal_request_admission_allowed
                    && (!self.pending.is_empty()
                        || !self.requests_rx.is_empty()
                        || !self.requests_rx.is_disconnected()) => match o {
                    Ok((request, notice)) => {
                        self.admit_normal_request_batch((request, notice)).await;
                        continue;
                    }
                    Err(_) => continue,
                },
                // Assumes a successful `readb` queued at least one event
                // (the `unwrap` below) — NOTE(review): confirm in MqttState.
                o = self.network.as_mut().unwrap().readb(&mut self.state, read_batch_size) => {
                    o?;
                    self.network.as_mut().unwrap().flush().await?;
                    Ok(self.state.events.pop_front().unwrap())
                },
                // Keep-alive tick: re-arm the timer, then send PINGREQ.
                () = self.keepalive_timeout.as_mut().unwrap_or(no_sleep),
                    if self.keepalive_timeout.is_some() && !self.options.keep_alive.is_zero() => {
                    let timeout = self.keepalive_timeout.as_mut().unwrap();
                    timeout.as_mut().reset(Instant::now() + self.options.keep_alive);

                    let (outgoing, _flush_notice) = self
                        .state
                        .handle_outgoing_packet_with_notice(Request::PingReq, None)?;
                    if let Some(outgoing) = outgoing {
                        self.network.as_mut().unwrap().write(outgoing).await?;
                    }
                    self.network.as_mut().unwrap().flush().await?;
                    Ok(self.state.events.pop_front().unwrap())
                }
            };
        }
    }
642
643 async fn handle_immediate_disconnect(
644 &mut self,
645 envelope: RequestEnvelope,
646 ) -> Result<Event, ConnectionError> {
647 let (request, notice) = envelope.into_parts();
648 let mut should_flush = false;
649 let mut qos0_notices = Vec::new();
650 self.handle_request(request, notice, &mut should_flush, &mut qos0_notices)
651 .await?;
652 self.flush_request_batch(should_flush, qos0_notices).await?;
653 Ok(self.state.events.pop_front().unwrap())
654 }
655
656 async fn handle_ready_requests(&mut self) -> Result<bool, ConnectionError> {
657 let Some((request, notice)) = self.next_scheduled_request() else {
658 return Ok(false);
659 };
660
661 let mut should_flush = false;
662 let mut qos0_notices = Vec::new();
663
664 if self
665 .handle_request(request, notice, &mut should_flush, &mut qos0_notices)
666 .await?
667 {
668 for _ in 1..self.options.max_request_batch.max(1) {
669 let Some((next_request, next_notice)) = self.next_scheduled_request() else {
670 break;
671 };
672
673 if !self
674 .handle_request(
675 next_request,
676 next_notice,
677 &mut should_flush,
678 &mut qos0_notices,
679 )
680 .await?
681 {
682 break;
683 }
684 }
685 }
686
687 self.flush_request_batch(should_flush, qos0_notices).await?;
688 Ok(true)
689 }
690
691 fn next_scheduled_request(&mut self) -> Option<(Request, Option<TrackedNoticeTx>)> {
692 let state = &self.state;
693 self.queued
694 .pop_next(|envelope| classify_request(state, &envelope.request))
695 .map(RequestEnvelope::into_parts)
696 }
697
698 fn normal_request_admission_allowed(&self) -> bool {
699 self.queued.is_empty()
700 || self
701 .queued
702 .has_ready(|envelope| classify_request(&self.state, &envelope.request))
703 }
704
705 async fn try_admit_existing_normal_requests(&mut self) {
706 let ready = self.pending.len() + self.requests_rx.len();
707 for _ in 0..ready {
708 if !(self.normal_request_admission_allowed() || !self.pending.is_empty()) {
709 break;
710 }
711
712 let Some((request, notice)) = Self::try_next_request(
713 &mut self.pending,
714 &self.requests_rx,
715 self.options.pending_throttle,
716 )
717 .await
718 else {
719 break;
720 };
721
722 let stop_batch = is_disconnect_request(&request);
723 self.queued
724 .push_back(RequestEnvelope::from_parts(request, notice));
725 if stop_batch {
726 break;
727 }
728 }
729 }
730
731 async fn admit_normal_request_batch(&mut self, first: (Request, Option<TrackedNoticeTx>)) {
732 let (request, notice) = first;
733 let stop_batch = is_disconnect_request(&request);
734 self.queued
735 .push_back(RequestEnvelope::from_parts(request, notice));
736 if stop_batch || !(self.normal_request_admission_allowed() || !self.pending.is_empty()) {
737 return;
738 }
739
740 for _ in 1..self.options.max_request_batch.max(1) {
741 let Some((request, notice)) = Self::try_next_request(
742 &mut self.pending,
743 &self.requests_rx,
744 self.options.pending_throttle,
745 )
746 .await
747 else {
748 break;
749 };
750 let stop_batch = is_disconnect_request(&request);
751 self.queued
752 .push_back(RequestEnvelope::from_parts(request, notice));
753 if stop_batch {
754 break;
755 }
756 }
757 }
758
759 async fn handle_request(
760 &mut self,
761 request: Request,
762 notice: Option<TrackedNoticeTx>,
763 should_flush: &mut bool,
764 qos0_notices: &mut Vec<PublishNoticeTx>,
765 ) -> Result<bool, ConnectionError> {
766 match request {
767 Request::Disconnect(disconnect) => {
768 self.state.fail_auth_exchange_due_to_client_disconnect();
769 self.pending_disconnect = Some(PendingDisconnect::new(disconnect, None));
770 Ok(false)
771 }
772 Request::DisconnectWithTimeout(disconnect, timeout) => {
773 self.state.fail_auth_exchange_due_to_client_disconnect();
774 self.pending_disconnect = Some(PendingDisconnect::new(disconnect, Some(timeout)));
775 Ok(false)
776 }
777 Request::DisconnectNow(_) => {
778 self.state.fail_auth_exchange_due_to_client_disconnect();
779 let (outgoing, _) = self
780 .state
781 .handle_outgoing_packet_with_notice(request, notice)?;
782 if let Some(outgoing) = outgoing {
783 if let Err(err) = self.network.as_mut().unwrap().write(outgoing).await {
784 return Err(ConnectionError::MqttState(err));
785 }
786 *should_flush = true;
787 }
788 self.disconnect_complete = true;
789 Ok(false)
790 }
791 request => {
792 let (outgoing, flush_notice) = self
793 .state
794 .handle_outgoing_packet_with_notice(request, notice)?;
795 if let Some(notice) = flush_notice {
796 qos0_notices.push(notice);
797 }
798 if let Some(outgoing) = outgoing {
799 if let Err(err) = self.network.as_mut().unwrap().write(outgoing).await {
800 for notice in qos0_notices.drain(..) {
801 notice.error(PublishNoticeError::Qos0NotFlushed);
802 }
803 return Err(ConnectionError::MqttState(err));
804 }
805 *should_flush = true;
806 }
807 Ok(true)
808 }
809 }
810 }
811
812 async fn flush_request_batch(
813 &mut self,
814 should_flush: bool,
815 qos0_notices: Vec<PublishNoticeTx>,
816 ) -> Result<(), ConnectionError> {
817 if !should_flush {
818 return Ok(());
819 }
820
821 match self.network.as_mut().unwrap().flush().await {
822 Ok(()) => {
823 for notice in qos0_notices {
824 notice.success(PublishResult::Qos0Flushed);
825 }
826 Ok(())
827 }
828 Err(err) => {
829 for notice in qos0_notices {
830 notice.error(PublishNoticeError::Qos0NotFlushed);
831 }
832 Err(ConnectionError::MqttState(err))
833 }
834 }
835 }
836
    /// Makes progress while a graceful disconnect waits for outbound state to drain.
    ///
    /// Reads one batch from the network so incoming packets are processed, then
    /// flushes. A disconnect-now request arriving first pre-empts the read. If
    /// the pending disconnect has a deadline, reaching it yields
    /// [`ConnectionError::DisconnectTimeout`]. These `select!`s are not
    /// `biased`, so among simultaneously ready branches tokio picks at random.
    ///
    /// Returns `Ok(Some(event))` only for a disconnect-now request. Otherwise it
    /// returns `Ok(None)`, leaving any events produced by the read in
    /// `state.events` for the caller. If the immediate-disconnect channel closes
    /// while the read is pending, this also returns `Ok(None)` without flushing.
    async fn poll_disconnect_drain(&mut self) -> Result<Option<Event>, ConnectionError> {
        let read_batch_size = self.effective_read_batch_size();
        // Created up front so both `select!` variants below can race it.
        let read = self
            .network
            .as_mut()
            .unwrap()
            .readb(&mut self.state, read_batch_size);

        if let Some(deadline) = self
            .pending_disconnect
            .as_ref()
            .and_then(|pending| pending.deadline)
        {
            select! {
                o = self.immediate_disconnect_rx.recv_async(), if !self.immediate_disconnect_rx.is_disconnected() => match o {
                    Ok(envelope) => return self.handle_immediate_disconnect(envelope).await.map(Some),
                    Err(_) => return Ok(None),
                },
                result = read => result?,
                () = time::sleep_until(deadline) => return Err(ConnectionError::DisconnectTimeout),
            }
        } else {
            select! {
                o = self.immediate_disconnect_rx.recv_async(), if !self.immediate_disconnect_rx.is_disconnected() => match o {
                    Ok(envelope) => return self.handle_immediate_disconnect(envelope).await.map(Some),
                    Err(_) => return Ok(None),
                },
                result = read => result?,
            }
        }

        // Push out anything the read produced (e.g. acknowledgements to send).
        self.network.as_mut().unwrap().flush().await?;
        Ok(None)
    }
871
872 async fn send_pending_disconnect(&mut self) -> Result<Event, ConnectionError> {
873 let disconnect = self
874 .pending_disconnect
875 .take()
876 .expect("pending disconnect checked by caller")
877 .disconnect;
878 let (outgoing, _) = self
879 .state
880 .handle_outgoing_packet_with_notice(Request::DisconnectNow(disconnect), None)?;
881
882 if let Some(outgoing) = outgoing {
883 self.network.as_mut().unwrap().write(outgoing).await?;
884 self.network.as_mut().unwrap().flush().await?;
885 }
886
887 self.drop_unprocessed_requests();
888 self.disconnect_complete = true;
889 Ok(self.state.events.pop_front().unwrap())
890 }
891
892 async fn try_next_request(
893 pending: &mut VecDeque<RequestEnvelope>,
894 rx: &Receiver<RequestEnvelope>,
895 pending_throttle: Duration,
896 ) -> Option<(Request, Option<TrackedNoticeTx>)> {
897 if !pending.is_empty() {
898 if pending_throttle.is_zero() {
899 tokio::task::yield_now().await;
900 } else {
901 time::sleep(pending_throttle).await;
902 }
903 return pending.pop_front().map(RequestEnvelope::into_parts);
906 }
907
908 match rx.try_recv() {
909 Ok(envelope) => return Some(envelope.into_parts()),
910 Err(TryRecvError::Disconnected) => return None,
911 Err(TryRecvError::Empty) => {}
912 }
913
914 None
915 }
916
917 async fn next_request(
918 pending: &mut VecDeque<RequestEnvelope>,
919 rx: &Receiver<RequestEnvelope>,
920 pending_throttle: Duration,
921 ) -> Result<(Request, Option<TrackedNoticeTx>), ConnectionError> {
922 if pending.is_empty() {
923 rx.recv_async()
924 .await
925 .map(RequestEnvelope::into_parts)
926 .map_err(|_| ConnectionError::RequestsDone)
927 } else {
928 if pending_throttle.is_zero() {
929 tokio::task::yield_now().await;
930 } else {
931 time::sleep(pending_throttle).await;
932 }
933 Ok(pending.pop_front().unwrap().into_parts())
936 }
937 }
938
939 fn effective_read_batch_size(&self) -> usize {
940 const MAX_READ_BATCH_SIZE: usize = 128;
941 const PENDING_FAIRNESS_CAP: usize = 16;
942
943 let configured = self.options.read_batch_size();
944 if configured > 0 {
945 return configured.clamp(1, MAX_READ_BATCH_SIZE);
946 }
947
948 let request_batch = self.options.max_request_batch().max(1);
949 let inflight = usize::from(self.state.max_outgoing_inflight);
950 let mut adaptive = request_batch.max(inflight / 2).max(8);
951
952 if !self.pending.is_empty()
953 || !self.queued.is_empty()
954 || !self.requests_rx.is_empty()
955 || !self.control_requests_rx.is_empty()
956 {
957 adaptive = adaptive.min(PENDING_FAIRNESS_CAP);
958 }
959
960 adaptive.clamp(1, MAX_READ_BATCH_SIZE)
961 }
962}
963
964fn classify_request(state: &MqttState, request: &Request) -> ScheduledRequest {
965 match request {
966 Request::Publish(publish) if publish.qos != crate::mqttbytes::QoS::AtMostOnce => {
967 ScheduledRequest {
968 class: RequestClass::FlowControlledPublish,
969 readiness: if state.can_send_publish(publish) {
970 RequestReadiness::Ready
971 } else {
972 RequestReadiness::Blocked
973 },
974 }
975 }
976 Request::Publish(_) => ScheduledRequest {
977 class: RequestClass::Publish,
978 readiness: RequestReadiness::Ready,
979 },
980 Request::Subscribe(_) | Request::Unsubscribe(_) => ScheduledRequest {
981 class: RequestClass::Control,
982 readiness: if state.control_packet_identifier_available() {
983 RequestReadiness::Ready
984 } else {
985 RequestReadiness::Blocked
986 },
987 },
988 _ => ScheduledRequest {
989 class: RequestClass::Control,
990 readiness: RequestReadiness::Ready,
991 },
992 }
993}
994
995const fn is_disconnect_request(request: &Request) -> bool {
996 matches!(
997 request,
998 Request::Disconnect(_) | Request::DisconnectWithTimeout(_, _) | Request::DisconnectNow(_)
999 )
1000}
1001
1002async fn connect(
1008 options: &mut MqttOptions,
1009 state: &mut MqttState,
1010) -> Result<(Network, ConnAck), ConnectionError> {
1011 let mut network = network_connect(options).await?;
1013
1014 let connack = mqtt_connect(options, &mut network, state).await?;
1016
1017 Ok((network, connack))
1018}
1019
/// Opens the transport-level connection to the broker described by `options`.
///
/// On `cfg(unix)`, a Unix-socket transport connects directly to the broker's
/// socket path. Every other transport works in three steps:
/// 1. resolve a host/port (from the websocket URL for `Ws`/`Wss`, otherwise
///    from the TCP address);
/// 2. open a stream, through the configured proxy when the `proxy` feature is
///    enabled and a proxy is set;
/// 3. layer TLS and/or the websocket handshake on top as the transport requires.
///
/// # Errors
/// [`ConnectionError::BrokerTransportMismatch`] when the broker target does not
/// fit the transport. Socket, proxy, TLS, websocket, response-validation and
/// request-modifier errors are propagated.
///
/// # Panics
/// The `expect`s in the TLS/WS/WSS arms should be unreachable: step 1 already
/// rejected brokers lacking the required TCP address or websocket URL.
#[allow(clippy::too_many_lines)]
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let max_incoming_pkt_size = options.max_incoming_packet_size();
    let transport = options.transport();

    // Unix-domain sockets need neither host/port resolution nor TLS/websocket layering.
    #[cfg(unix)]
    if matches!(&transport, Transport::Unix) {
        let file = options
            .broker()
            .unix_path()
            .ok_or(ConnectionError::BrokerTransportMismatch)?;
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, max_incoming_pkt_size);
        return Ok(network);
    }

    // Step 1: resolve host/port. Broker/transport mismatches are rejected here.
    let (domain, port) = match &transport {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(
            options
                .broker()
                .websocket_url()
                .ok_or(ConnectionError::BrokerTransportMismatch)?,
        )?,
        #[cfg(all(
            any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
            feature = "websocket"
        ))]
        Transport::Wss(_) => split_url(
            options
                .broker()
                .websocket_url()
                .ok_or(ConnectionError::BrokerTransportMismatch)?,
        )?,
        _ => options
            .broker()
            .tcp_address()
            .map(|(host, port)| (host.to_owned(), port))
            .ok_or(ConnectionError::BrokerTransportMismatch)?,
    };

    // Step 2: open the underlying byte stream, via the proxy if one is configured.
    let tcp_stream: Box<dyn AsyncReadWrite> = {
        #[cfg(feature = "proxy")]
        if let Some(proxy) = options.proxy() {
            proxy
                .connect(
                    &domain,
                    port,
                    options.network_options(),
                    Some(options.effective_socket_connector()),
                )
                .await?
        } else {
            let addr = format!("{domain}:{port}");
            options
                .socket_connect(addr, options.network_options())
                .await?
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            options
                .socket_connect(addr, options.network_options())
                .await?
        }
    };

    // Step 3: layer TLS and/or the websocket handshake on top of the stream.
    let network = match transport {
        Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
        #[cfg(any(feature = "use-native-tls", feature = "use-rustls-no-provider"))]
        Transport::Tls(tls_config) => {
            let (host, port) = options
                .broker()
                .tcp_address()
                .expect("tls transport requires a tcp broker");
            let socket = tls::tls_connect(host, port, &tls_config, tcp_stream).await?;
            Network::new(socket, max_incoming_pkt_size)
        }
        // Handled by the early return above.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let mut request = options
                .broker()
                .websocket_url()
                .expect("ws transport requires a websocket broker")
                .into_client_request()?;
            // Negotiate the MQTT websocket subprotocol.
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // A fallible modifier takes precedence over an infallible one.
            if let Some(request_modifier) = options.fallible_request_modifier() {
                request = request_modifier(request)
                    .await
                    .map_err(ConnectionError::RequestModifier)?;
            } else if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
        }
        #[cfg(all(
            any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
            feature = "websocket"
        ))]
        Transport::Wss(tls_config) => {
            let mut request = options
                .broker()
                .websocket_url()
                .expect("wss transport requires a websocket broker")
                .into_client_request()?;
            // Negotiate the MQTT websocket subprotocol.
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // A fallible modifier takes precedence over an infallible one.
            if let Some(request_modifier) = options.fallible_request_modifier() {
                request = request_modifier(request)
                    .await
                    .map_err(ConnectionError::RequestModifier)?;
            } else if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            // TLS is established first; the websocket handshake runs over it.
            let tls_stream = tls::tls_connect(&domain, port, &tls_config, tcp_stream).await?;
            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tls_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
        }
    };

    Ok(network)
}
1160
1161async fn mqtt_connect(
1162 options: &mut MqttOptions,
1163 network: &mut Network,
1164 state: &mut MqttState,
1165) -> Result<ConnAck, ConnectionError> {
1166 let authentication_method = options.authentication_method();
1167 let auth_exchange_started = authentication_method.is_some();
1168 let start_auth_properties = state.begin_authentication_connect(authentication_method)?;
1169 let mut connect_properties = options.connect_properties();
1170 if let Some(auth_properties) = start_auth_properties {
1171 let properties = connect_properties.get_or_insert_with(ConnectProperties::new);
1172 properties.authentication_method = auth_properties.method;
1173 properties.authentication_data = auth_properties.data;
1174 }
1175
1176 let result = mqtt_connect_inner(options, network, state, connect_properties).await;
1177 if result.is_err() && auth_exchange_started {
1178 state.fail_auth_exchange_due_to_connection_closed();
1179 }
1180 result
1181}
1182
1183async fn mqtt_connect_inner(
1184 options: &mut MqttOptions,
1185 network: &mut Network,
1186 state: &mut MqttState,
1187 connect_properties: Option<ConnectProperties>,
1188) -> Result<ConnAck, ConnectionError> {
1189 network
1191 .write(Packet::Connect(
1192 Connect {
1193 client_id: options.client_id(),
1194 keep_alive: u16::try_from(options.keep_alive().as_secs()).unwrap_or(u16::MAX),
1195 clean_start: options.clean_start(),
1196 properties: connect_properties,
1197 },
1198 options.last_will(),
1199 options.auth().clone(),
1200 ))
1201 .await?;
1202 network.flush().await?;
1203
1204 loop {
1206 match network.read().await? {
1207 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
1208 if let Err(err) = state.validate_successful_connack_authentication_method(&connack)
1209 {
1210 send_protocol_error_disconnect(network).await?;
1211 return Err(err.into());
1212 }
1213
1214 if let Some(props) = &connack.properties
1215 && let Some(keep_alive) = props.server_keep_alive
1216 {
1217 options.keep_alive = Duration::from_secs(u64::from(keep_alive));
1218 }
1219
1220 if let Some(props) = &connack.properties {
1221 network.set_max_outgoing_size(props.max_packet_size);
1222
1223 if props.session_expiry_interval.is_some() {
1225 options.set_session_expiry_interval(props.session_expiry_interval);
1226 }
1227 }
1228 return Ok(connack);
1229 }
1230 Incoming::ConnAck(connack) => {
1231 return Err(ConnectionError::ConnectionRefused(connack.code));
1232 }
1233 Incoming::Auth(auth) => match state.handle_incoming_packet(Incoming::Auth(auth)) {
1234 Ok(Some(outgoing)) => {
1235 network.write(outgoing).await?;
1236 network.flush().await?;
1237 }
1238 Ok(None) => return Err(ConnectionError::AuthProcessingError),
1239 Err(err @ StateError::Deserialization(super::mqttbytes::Error::ProtocolError)) => {
1240 send_protocol_error_disconnect(network).await?;
1241 return Err(err.into());
1242 }
1243 Err(err) => return Err(err.into()),
1244 },
1245 packet => return Err(ConnectionError::NotConnAck(Box::new(packet))),
1246 }
1247 }
1248}
1249
1250async fn send_protocol_error_disconnect(network: &mut Network) -> Result<(), ConnectionError> {
1251 network
1252 .write(Packet::Disconnect(Disconnect::new(
1253 DisconnectReasonCode::ProtocolError,
1254 )))
1255 .await?;
1256 network.flush().await?;
1257 Ok(())
1258}
1259
1260const fn should_replay_after_reconnect(request: &Request) -> bool {
1261 !matches!(request, Request::PubAck(_) | Request::PubRec(_))
1262}
1263
1264#[cfg(test)]
1265mod tests {
1266 use super::*;
1267 use crate::mqttbytes::{Error as MqttError, QoS};
1268 use crate::{Auth, AuthProperties, AuthReasonCode};
1269 use crate::{ConnAckProperties, Filter, PubAck, PubComp, PubRec, PubRel, PublishProperties};
1270 use bytes::{Bytes, BytesMut};
1271 use flume::TryRecvError;
1272 use std::sync::{Arc, Mutex};
1273 use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
1274
1275 #[derive(Debug)]
1276 struct StaticAuthManager {
1277 response: Result<Option<AuthProperties>, String>,
1278 }
1279
1280 impl crate::Authenticator for StaticAuthManager {
1281 fn start(
1282 &mut self,
1283 _context: crate::AuthContext<'_>,
1284 ) -> Result<Option<AuthProperties>, crate::AuthError> {
1285 Ok(None)
1286 }
1287
1288 fn continue_auth(
1289 &mut self,
1290 _context: crate::AuthContext<'_>,
1291 _auth_prop: Option<AuthProperties>,
1292 ) -> Result<crate::AuthAction, crate::AuthError> {
1293 self.response
1294 .clone()
1295 .map(|props| props.map_or(crate::AuthAction::Complete, crate::AuthAction::Send))
1296 .map_err(crate::AuthError::from)
1297 }
1298
1299 fn success(
1300 &mut self,
1301 _context: crate::AuthContext<'_>,
1302 _incoming: Option<AuthProperties>,
1303 ) -> Result<(), crate::AuthError> {
1304 Ok(())
1305 }
1306
1307 fn failure(&mut self, _context: crate::AuthContext<'_>, _error: crate::AuthError) {}
1308 }
1309
1310 fn build_connack_with_receive_max(receive_max: u16) -> ConnAck {
1311 ConnAck {
1312 session_present: false,
1313 code: ConnectReturnCode::Success,
1314 properties: Some(ConnAckProperties {
1315 session_expiry_interval: None,
1316 receive_max: Some(receive_max),
1317 max_qos: None,
1318 retain_available: None,
1319 max_packet_size: None,
1320 assigned_client_identifier: None,
1321 topic_alias_max: None,
1322 reason_string: None,
1323 user_properties: vec![],
1324 wildcard_subscription_available: None,
1325 subscription_identifiers_available: None,
1326 shared_subscription_available: None,
1327 server_keep_alive: None,
1328 response_information: None,
1329 server_reference: None,
1330 authentication_method: None,
1331 authentication_data: None,
1332 }),
1333 }
1334 }
1335
1336 fn build_connack_with_authentication_method(authentication_method: Option<&str>) -> ConnAck {
1337 let mut connack = build_connack_with_receive_max(10);
1338 connack.properties.as_mut().unwrap().authentication_method =
1339 authentication_method.map(str::to_owned);
1340 connack
1341 }
1342
1343 async fn read_packet_bytes(peer: &mut DuplexStream) -> Vec<u8> {
1344 let byte1 = peer.read_u8().await.unwrap();
1345 let mut multiplier = 1usize;
1346 let mut remaining_len = 0usize;
1347 let mut remaining_len_bytes = Vec::new();
1348
1349 loop {
1350 let encoded_byte = peer.read_u8().await.unwrap();
1351 remaining_len_bytes.push(encoded_byte);
1352 remaining_len += usize::from(encoded_byte & 0x7F) * multiplier;
1353 if encoded_byte & 0x80 == 0 {
1354 break;
1355 }
1356 multiplier *= 128;
1357 }
1358
1359 let mut payload = vec![0; remaining_len];
1360 peer.read_exact(&mut payload).await.unwrap();
1361
1362 let mut packet = Vec::with_capacity(1 + remaining_len_bytes.len() + payload.len());
1363 packet.push(byte1);
1364 packet.extend(remaining_len_bytes);
1365 packet.extend(payload);
1366 packet
1367 }
1368
1369 async fn run_mqtt_connect_with_connack(
1370 mut options: MqttOptions,
1371 connack: ConnAck,
1372 ) -> (Result<ConnAck, ConnectionError>, Vec<u8>) {
1373 let (client, mut peer) = tokio::io::duplex(1024);
1374 let mut network = Network::new(client, Some(1024));
1375 let mut state = MqttState::new_internal(
1376 10,
1377 false,
1378 options.auto_topic_aliases(),
1379 options.topic_alias_policy(),
1380 options.authentication_method(),
1381 options.auth_manager(),
1382 );
1383
1384 let broker = async {
1385 let _connect = read_packet_bytes(&mut peer).await;
1386 let mut encoded_connack = BytesMut::new();
1387 connack.write(&mut encoded_connack).unwrap();
1388 peer.write_all(&encoded_connack).await.unwrap();
1389 read_packet_bytes(&mut peer).await
1390 };
1391
1392 tokio::join!(mqtt_connect(&mut options, &mut network, &mut state), broker)
1393 }
1394
1395 async fn run_mqtt_connect_with_connack_and_return_state(
1396 mut options: MqttOptions,
1397 connack: ConnAck,
1398 ) -> (Result<ConnAck, ConnectionError>, MqttState) {
1399 let (client, mut peer) = tokio::io::duplex(1024);
1400 let mut network = Network::new(client, Some(1024));
1401 let mut state = MqttState::new_internal(
1402 10,
1403 false,
1404 options.auto_topic_aliases(),
1405 options.topic_alias_policy(),
1406 options.authentication_method(),
1407 options.auth_manager(),
1408 );
1409
1410 let broker = async {
1411 let _connect = read_packet_bytes(&mut peer).await;
1412 let mut encoded_connack = BytesMut::new();
1413 connack.write(&mut encoded_connack).unwrap();
1414 peer.write_all(&encoded_connack).await.unwrap();
1415 };
1416
1417 let (result, ()) =
1418 tokio::join!(mqtt_connect(&mut options, &mut network, &mut state), broker);
1419 (result, state)
1420 }
1421
1422 async fn run_mqtt_connect_with_stale_state_auth_method(
1423 mut options: MqttOptions,
1424 stale_authentication_method: Option<&str>,
1425 incoming: Vec<Packet>,
1426 ) -> Result<ConnAck, ConnectionError> {
1427 let (client, mut peer) = tokio::io::duplex(1024);
1428 let mut network = Network::new(client, Some(1024));
1429 let mut state = MqttState::new_internal(
1430 10,
1431 false,
1432 options.auto_topic_aliases(),
1433 options.topic_alias_policy(),
1434 stale_authentication_method.map(str::to_owned),
1435 options.auth_manager(),
1436 );
1437
1438 let broker = async {
1439 let _connect = read_packet_bytes(&mut peer).await;
1440 for packet in incoming {
1441 let mut encoded = BytesMut::new();
1442 packet.write(&mut encoded, None).unwrap();
1443 peer.write_all(&encoded).await.unwrap();
1444 }
1445 };
1446
1447 let (result, ()) =
1448 tokio::join!(mqtt_connect(&mut options, &mut network, &mut state), broker);
1449 result
1450 }
1451
1452 async fn run_successful_mqtt_connect_with_connack(
1453 mut options: MqttOptions,
1454 connack: ConnAck,
1455 ) -> Result<ConnAck, ConnectionError> {
1456 let (client, mut peer) = tokio::io::duplex(1024);
1457 let mut network = Network::new(client, Some(1024));
1458 let mut state = MqttState::new_internal(
1459 10,
1460 false,
1461 options.auto_topic_aliases(),
1462 options.topic_alias_policy(),
1463 options.authentication_method(),
1464 options.auth_manager(),
1465 );
1466
1467 let broker = async {
1468 let _connect = read_packet_bytes(&mut peer).await;
1469 let mut encoded_connack = BytesMut::new();
1470 connack.write(&mut encoded_connack).unwrap();
1471 peer.write_all(&encoded_connack).await.unwrap();
1472 };
1473
1474 let (result, ()) =
1475 tokio::join!(mqtt_connect(&mut options, &mut network, &mut state), broker);
1476 result
1477 }
1478
1479 fn push_pending(eventloop: &mut EventLoop, request: Request) {
1480 eventloop.pending.push_back(RequestEnvelope::plain(request));
1481 }
1482
1483 fn pending_front_request(eventloop: &EventLoop) -> Option<&Request> {
1484 eventloop.pending.front().map(|envelope| &envelope.request)
1485 }
1486
1487 #[tokio::test]
1488 async fn graceful_disconnect_fails_active_tracked_reauth_notice() {
1489 let mut options = MqttOptions::new("test-client", "localhost");
1490 options.set_authentication_method(Some("test-method".to_owned()));
1491 let (mut eventloop, _) = EventLoop::new_for_async_client(options, 10);
1492 let (client, _peer) = tokio::io::duplex(1024);
1493 eventloop.network = Some(Network::new(client, Some(1024)));
1494 let (notice_tx, notice) = AuthNoticeTx::new();
1495 let mut should_flush = false;
1496 let mut qos0_notices = Vec::new();
1497
1498 assert!(
1499 eventloop
1500 .handle_request(
1501 Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
1502 Some(TrackedNoticeTx::Auth(notice_tx)),
1503 &mut should_flush,
1504 &mut qos0_notices,
1505 )
1506 .await
1507 .unwrap()
1508 );
1509 assert!(eventloop.state.outbound_requests_drained());
1510
1511 let handled = eventloop
1512 .handle_request(
1513 Request::Disconnect(Disconnect::new(DisconnectReasonCode::NormalDisconnection)),
1514 None,
1515 &mut should_flush,
1516 &mut qos0_notices,
1517 )
1518 .await
1519 .unwrap();
1520
1521 assert!(!handled);
1522 assert!(eventloop.pending_disconnect.is_some());
1523 assert_eq!(
1524 notice.wait_async().await.unwrap_err(),
1525 crate::notice::AuthNoticeError::ConnectionClosed
1526 );
1527 assert!(eventloop.state.events.iter().any(|event| {
1528 matches!(
1529 event,
1530 Event::Auth(crate::AuthEvent::Failed {
1531 kind: crate::AuthExchangeKind::Reauthentication,
1532 reason: crate::AuthFailureReason::ConnectionClosed,
1533 ..
1534 })
1535 )
1536 }));
1537 }
1538
1539 #[tokio::test]
1540 async fn mqtt_connect_accepts_matching_connack_authentication_method() {
1541 let mut options = MqttOptions::new("test-client", "localhost");
1542 options.set_authentication_method(Some("test-method".to_owned()));
1543 let connack = build_connack_with_authentication_method(Some("test-method"));
1544
1545 let result = run_successful_mqtt_connect_with_connack(options, connack)
1546 .await
1547 .unwrap();
1548
1549 assert_eq!(
1550 result.properties.unwrap().authentication_method.as_deref(),
1551 Some("test-method")
1552 );
1553 }
1554
1555 #[tokio::test]
1556 async fn mqtt_connect_accepts_connack_without_authentication_method_when_connect_omits_it() {
1557 let options = MqttOptions::new("test-client", "localhost");
1558 let connack = ConnAck {
1559 session_present: false,
1560 code: ConnectReturnCode::Success,
1561 properties: None,
1562 };
1563
1564 let result = run_successful_mqtt_connect_with_connack(options, connack)
1565 .await
1566 .unwrap();
1567
1568 assert!(result.properties.is_none());
1569 }
1570
1571 #[tokio::test]
1572 async fn mqtt_connect_refreshes_state_authentication_method_from_mutated_options() {
1573 let mut options = MqttOptions::new("test-client", "localhost");
1574 options.set_authentication_method(Some("new-method".to_owned()));
1575 options.set_auth_manager(Arc::new(Mutex::new(StaticAuthManager {
1576 response: Ok(Some(AuthProperties {
1577 method: Some("new-method".to_owned()),
1578 data: None,
1579 reason: None,
1580 user_properties: vec![],
1581 })),
1582 })));
1583 let auth = Auth::new(
1584 AuthReasonCode::Continue,
1585 Some(AuthProperties {
1586 method: Some("new-method".to_owned()),
1587 data: None,
1588 reason: None,
1589 user_properties: vec![],
1590 }),
1591 );
1592 let connack = build_connack_with_authentication_method(Some("new-method"));
1593
1594 let result = run_mqtt_connect_with_stale_state_auth_method(
1595 options,
1596 Some("old-method"),
1597 vec![Packet::Auth(auth), Packet::ConnAck(connack)],
1598 )
1599 .await
1600 .unwrap();
1601
1602 assert_eq!(
1603 result.properties.unwrap().authentication_method.as_deref(),
1604 Some("new-method")
1605 );
1606 }
1607
1608 #[tokio::test]
1609 async fn mqtt_connect_rejects_missing_connack_authentication_method() {
1610 let mut options = MqttOptions::new("test-client", "localhost");
1611 options.set_authentication_method(Some("test-method".to_owned()));
1612 let connack = ConnAck {
1613 session_present: false,
1614 code: ConnectReturnCode::Success,
1615 properties: None,
1616 };
1617
1618 let (result, disconnect) = run_mqtt_connect_with_connack(options, connack).await;
1619
1620 assert!(matches!(
1621 result,
1622 Err(ConnectionError::MqttState(StateError::Deserialization(
1623 MqttError::ProtocolError
1624 )))
1625 ));
1626 assert_eq!(
1627 disconnect,
1628 vec![0xE0, 0x01, DisconnectReasonCode::ProtocolError as u8]
1629 );
1630 }
1631
1632 #[tokio::test]
1633 async fn mqtt_connect_failure_resets_initial_auth_exchange() {
1634 let mut options = MqttOptions::new("test-client", "localhost");
1635 options.set_authentication_method(Some("test-method".to_owned()));
1636 let connack = ConnAck {
1637 session_present: false,
1638 code: ConnectReturnCode::BadAuthenticationMethod,
1639 properties: None,
1640 };
1641
1642 let (result, state) =
1643 run_mqtt_connect_with_connack_and_return_state(options, connack).await;
1644
1645 assert!(matches!(
1646 result,
1647 Err(ConnectionError::ConnectionRefused(
1648 ConnectReturnCode::BadAuthenticationMethod
1649 ))
1650 ));
1651 assert!(state.events.iter().any(|event| {
1652 matches!(
1653 event,
1654 Event::Auth(crate::AuthEvent::Failed {
1655 kind: crate::AuthExchangeKind::InitialConnect,
1656 reason: crate::AuthFailureReason::ConnectionClosed,
1657 ..
1658 })
1659 )
1660 }));
1661 }
1662
1663 #[tokio::test]
1664 async fn mqtt_connect_protocol_error_resets_initial_auth_exchange() {
1665 let mut options = MqttOptions::new("test-client", "localhost");
1666 options.set_authentication_method(Some("test-method".to_owned()));
1667 let connack = ConnAck {
1668 session_present: false,
1669 code: ConnectReturnCode::Success,
1670 properties: None,
1671 };
1672
1673 let (result, state) =
1674 run_mqtt_connect_with_connack_and_return_state(options, connack).await;
1675
1676 assert!(matches!(
1677 result,
1678 Err(ConnectionError::MqttState(StateError::Deserialization(
1679 MqttError::ProtocolError
1680 )))
1681 ));
1682 assert!(state.events.iter().any(|event| {
1683 matches!(
1684 event,
1685 Event::Auth(crate::AuthEvent::Failed {
1686 kind: crate::AuthExchangeKind::InitialConnect,
1687 reason: crate::AuthFailureReason::ConnectionClosed,
1688 ..
1689 })
1690 )
1691 }));
1692 }
1693
1694 #[tokio::test]
1695 async fn mqtt_connect_rejects_mismatched_connack_authentication_method() {
1696 let mut options = MqttOptions::new("test-client", "localhost");
1697 options.set_authentication_method(Some("test-method".to_owned()));
1698 let connack = build_connack_with_authentication_method(Some("other-method"));
1699
1700 let (result, disconnect) = run_mqtt_connect_with_connack(options, connack).await;
1701
1702 assert!(matches!(
1703 result,
1704 Err(ConnectionError::MqttState(StateError::Deserialization(
1705 MqttError::ProtocolError
1706 )))
1707 ));
1708 assert_eq!(
1709 disconnect,
1710 vec![0xE0, 0x01, DisconnectReasonCode::ProtocolError as u8]
1711 );
1712 }
1713
1714 #[tokio::test]
1715 async fn mqtt_connect_rejects_connack_authentication_method_when_connect_omits_it() {
1716 let options = MqttOptions::new("test-client", "localhost");
1717 let connack = build_connack_with_authentication_method(Some("test-method"));
1718
1719 let (result, disconnect) = run_mqtt_connect_with_connack(options, connack).await;
1720
1721 assert!(matches!(
1722 result,
1723 Err(ConnectionError::MqttState(StateError::Deserialization(
1724 MqttError::ProtocolError
1725 )))
1726 ));
1727 assert_eq!(
1728 disconnect,
1729 vec![0xE0, 0x01, DisconnectReasonCode::ProtocolError as u8]
1730 );
1731 }
1732
1733 fn build_eventloop_with_pending(clean_start: bool) -> EventLoop {
1734 let mut options = MqttOptions::new("test-client", "localhost");
1735 options.set_clean_start(clean_start);
1736
1737 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1738 push_pending(&mut eventloop, Request::PingReq);
1739 eventloop
1740 }
1741
1742 fn publish(qos: QoS) -> Publish {
1743 Publish::new("hello/world", qos, "payload", None)
1744 }
1745
1746 fn subscribe() -> Subscribe {
1747 Subscribe::new(Filter::new("hello/world", QoS::AtMostOnce), None)
1748 }
1749
1750 fn fill_publish_window(eventloop: &mut EventLoop) {
1751 let mut active = publish(QoS::AtLeastOnce);
1752 active.pkid = 1;
1753 eventloop.state.outgoing_pub[1] = Some(active);
1754 eventloop.state.inflight = 1;
1755 }
1756
1757 fn next_after_blocked_publish(request: Request) -> Option<Request> {
1758 let mut options = MqttOptions::new("test-client", "localhost");
1759 options.set_outgoing_inflight_upper_limit(1);
1760 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1761 fill_publish_window(&mut eventloop);
1762 eventloop
1763 .queued
1764 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1765 QoS::AtLeastOnce,
1766 ))));
1767 eventloop.queued.push_back(RequestEnvelope::plain(request));
1768
1769 eventloop
1770 .next_scheduled_request()
1771 .map(|(request, _notice)| request)
1772 }
1773
1774 #[test]
1775 fn scheduler_sends_control_after_receive_max_blocked_publish() {
1776 let mut options = MqttOptions::new("test-client", "localhost");
1777 options.set_outgoing_inflight_upper_limit(1);
1778 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1779 fill_publish_window(&mut eventloop);
1780 eventloop
1781 .queued
1782 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1783 QoS::AtLeastOnce,
1784 ))));
1785 eventloop
1786 .queued
1787 .push_back(RequestEnvelope::plain(Request::Subscribe(subscribe())));
1788
1789 let (request, _) = eventloop.next_scheduled_request().unwrap();
1790
1791 assert!(matches!(request, Request::Subscribe(_)));
1792 match eventloop.state.handle_outgoing_packet(request).unwrap() {
1793 Some(Packet::Subscribe(subscribe)) => assert_ne!(subscribe.pkid, 1),
1794 packet => panic!("expected subscribe packet, got {packet:?}"),
1795 }
1796 assert!(matches!(
1797 eventloop
1798 .queued
1799 .drain()
1800 .next()
1801 .map(|envelope| envelope.request),
1802 Some(Request::Publish(_))
1803 ));
1804 }
1805
1806 #[test]
1807 fn scheduler_does_not_let_qos0_publish_pass_blocked_publish() {
1808 let mut options = MqttOptions::new("test-client", "localhost");
1809 options.set_outgoing_inflight_upper_limit(1);
1810 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1811 fill_publish_window(&mut eventloop);
1812 eventloop
1813 .queued
1814 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1815 QoS::AtLeastOnce,
1816 ))));
1817 eventloop
1818 .queued
1819 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1820 QoS::AtMostOnce,
1821 ))));
1822
1823 assert!(eventloop.next_scheduled_request().is_none());
1824 }
1825
1826 #[test]
1827 fn scheduler_allows_each_progress_and_control_class_after_blocked_publish() {
1828 assert!(matches!(
1829 next_after_blocked_publish(Request::PubAck(PubAck::new(7, None))),
1830 Some(Request::PubAck(_))
1831 ));
1832 assert!(matches!(
1833 next_after_blocked_publish(Request::PubRec(PubRec::new(8, None))),
1834 Some(Request::PubRec(_))
1835 ));
1836 assert!(matches!(
1837 next_after_blocked_publish(Request::PubRel(PubRel::new(9, None))),
1838 Some(Request::PubRel(_))
1839 ));
1840 assert!(matches!(
1841 next_after_blocked_publish(Request::PubComp(PubComp::new(10, None))),
1842 Some(Request::PubComp(_))
1843 ));
1844 assert!(matches!(
1845 next_after_blocked_publish(Request::PingReq),
1846 Some(Request::PingReq)
1847 ));
1848 assert!(matches!(
1849 next_after_blocked_publish(Request::Subscribe(subscribe())),
1850 Some(Request::Subscribe(_))
1851 ));
1852 assert!(matches!(
1853 next_after_blocked_publish(Request::Unsubscribe(Unsubscribe::new("hello/world", None))),
1854 Some(Request::Unsubscribe(_))
1855 ));
1856 assert!(matches!(
1857 next_after_blocked_publish(Request::Auth(Auth::new(
1858 AuthReasonCode::ReAuthenticate,
1859 None
1860 ))),
1861 Some(Request::Auth(_))
1862 ));
1863 assert!(matches!(
1864 next_after_blocked_publish(Request::Disconnect(Disconnect::new(
1865 DisconnectReasonCode::NormalDisconnection
1866 ))),
1867 Some(Request::Disconnect(_))
1868 ));
1869 }
1870
1871 #[test]
1872 fn scheduler_unsubscribe_after_blocked_publish_uses_non_conflicting_packet_id() {
1873 let mut options = MqttOptions::new("test-client", "localhost");
1874 options.set_outgoing_inflight_upper_limit(1);
1875 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1876 fill_publish_window(&mut eventloop);
1877 eventloop
1878 .queued
1879 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1880 QoS::AtLeastOnce,
1881 ))));
1882 eventloop
1883 .queued
1884 .push_back(RequestEnvelope::plain(Request::Unsubscribe(
1885 Unsubscribe::new("hello/world", None),
1886 )));
1887
1888 let (request, _) = eventloop.next_scheduled_request().unwrap();
1889
1890 assert!(matches!(request, Request::Unsubscribe(_)));
1891 match eventloop.state.handle_outgoing_packet(request).unwrap() {
1892 Some(Packet::Unsubscribe(unsubscribe)) => assert_ne!(unsubscribe.pkid, 1),
1893 packet => panic!("expected unsubscribe packet, got {packet:?}"),
1894 }
1895 }
1896
1897 #[test]
1898 fn scheduler_preserves_sendable_publish_before_later_control() {
1899 let (mut eventloop, _request_tx) =
1900 EventLoop::new_for_async_client(MqttOptions::new("test-client", "localhost"), 1);
1901 eventloop
1902 .queued
1903 .push_back(RequestEnvelope::plain(Request::Publish(publish(
1904 QoS::AtLeastOnce,
1905 ))));
1906 eventloop
1907 .queued
1908 .push_back(RequestEnvelope::plain(Request::Subscribe(subscribe())));
1909
1910 let (request, _) = eventloop.next_scheduled_request().unwrap();
1911
1912 assert!(matches!(request, Request::Publish(_)));
1913 }
1914
1915 #[tokio::test]
1916 async fn select_admits_control_request_after_ready_publish_backlog_snapshot() {
1917 let mut options = MqttOptions::new("test-client", "localhost");
1918 options.set_max_request_batch(1);
1919 let (mut eventloop, request_tx, control_request_tx, _immediate_disconnect_tx) =
1920 EventLoop::new_for_async_client_with_capacity(
1921 options,
1922 RequestChannelCapacity::Unbounded,
1923 );
1924 let (client, _peer) = tokio::io::duplex(1024);
1925 let mut network = Network::new(client, Some(1024));
1926 network.set_max_outgoing_size(Some(1024));
1927 eventloop.network = Some(network);
1928
1929 request_tx
1930 .send_async(RequestEnvelope::plain(Request::Publish(publish(
1931 QoS::AtMostOnce,
1932 ))))
1933 .await
1934 .unwrap();
1935 request_tx
1936 .send_async(RequestEnvelope::plain(Request::Publish(publish(
1937 QoS::AtMostOnce,
1938 ))))
1939 .await
1940 .unwrap();
1941 control_request_tx
1942 .send_async(RequestEnvelope::plain(Request::Disconnect(
1943 Disconnect::new(DisconnectReasonCode::NormalDisconnection),
1944 )))
1945 .await
1946 .unwrap();
1947
1948 let first = time::timeout(Duration::from_secs(1), eventloop.select())
1949 .await
1950 .expect("timed out waiting for first request")
1951 .expect("select should not fail");
1952 request_tx
1953 .send_async(RequestEnvelope::plain(Request::Publish(publish(
1954 QoS::AtMostOnce,
1955 ))))
1956 .await
1957 .unwrap();
1958 let second = time::timeout(Duration::from_secs(1), eventloop.select())
1959 .await
1960 .expect("timed out waiting for second request")
1961 .expect("select should not fail");
1962 let third = time::timeout(Duration::from_secs(1), eventloop.select())
1963 .await
1964 .expect("timed out waiting for control request")
1965 .expect("select should not fail");
1966
1967 assert!(matches!(first, Event::Outgoing(Outgoing::Publish(_))));
1968 assert!(matches!(second, Event::Outgoing(Outgoing::Publish(_))));
1969 assert!(matches!(third, Event::Outgoing(Outgoing::Disconnect)));
1970 }
1971
1972 fn publish_properties_with_alias(alias: u16) -> PublishProperties {
1973 PublishProperties {
1974 topic_alias: Some(alias),
1975 ..Default::default()
1976 }
1977 }
1978
1979 fn publish_with_alias(topic: &str, qos: QoS, alias: u16) -> Publish {
1980 Publish::new(
1981 topic,
1982 qos,
1983 "payload",
1984 Some(publish_properties_with_alias(alias)),
1985 )
1986 }
1987
1988 #[test]
1989 fn eventloop_new_keeps_internal_sender_alive() {
1990 let options = MqttOptions::new("test-client", "localhost");
1991 let eventloop = EventLoop::new(options, 1);
1992
1993 assert!(matches!(
1994 eventloop.requests_rx.try_recv(),
1995 Err(TryRecvError::Empty)
1996 ));
1997 }
1998
1999 #[test]
2000 fn async_client_constructor_path_allows_channel_shutdown() {
2001 let options = MqttOptions::new("test-client", "localhost");
2002 let (eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
2003 drop(request_tx);
2004
2005 assert!(matches!(
2006 eventloop.requests_rx.try_recv(),
2007 Err(TryRecvError::Disconnected)
2008 ));
2009 }
2010
2011 #[test]
2012 fn clean_drops_ack_requests_drained_from_channel() {
2013 let options = MqttOptions::new("test-client", "localhost");
2014 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 3);
2015 request_tx
2016 .send(RequestEnvelope::plain(Request::PubAck(PubAck::new(
2017 7, None,
2018 ))))
2019 .unwrap();
2020 request_tx
2021 .send(RequestEnvelope::plain(Request::PubRec(PubRec::new(
2022 8, None,
2023 ))))
2024 .unwrap();
2025 request_tx
2026 .send(RequestEnvelope::plain(Request::PingReq))
2027 .unwrap();
2028
2029 eventloop.clean();
2030
2031 assert_eq!(eventloop.pending_len(), 1);
2032 assert!(matches!(
2033 pending_front_request(&eventloop),
2034 Some(Request::PingReq)
2035 ));
2036 }
2037
2038 #[test]
2039 fn clean_drops_ack_requests_drained_from_queued_scheduler() {
2040 let options = MqttOptions::new("test-client", "localhost");
2041 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 3);
2042 eventloop
2043 .queued
2044 .push_back(RequestEnvelope::plain(Request::PubAck(PubAck::new(
2045 7, None,
2046 ))));
2047 eventloop
2048 .queued
2049 .push_back(RequestEnvelope::plain(Request::PubRec(PubRec::new(
2050 8, None,
2051 ))));
2052 eventloop
2053 .queued
2054 .push_back(RequestEnvelope::plain(Request::PingReq));
2055
2056 eventloop.clean();
2057
2058 assert_eq!(eventloop.pending_len(), 1);
2059 assert!(matches!(
2060 pending_front_request(&eventloop),
2061 Some(Request::PingReq)
2062 ));
2063 }
2064
2065 #[test]
2066 fn clean_rewrites_alias_only_pending_publish_when_mapping_is_known() {
2067 let options = MqttOptions::new("test-client", "localhost");
2068 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
2069 eventloop.state.broker_topic_alias_max = 10;
2070 eventloop
2071 .state
2072 .handle_outgoing_packet(Request::Publish(publish_with_alias(
2073 "hello/replay",
2074 QoS::AtMostOnce,
2075 4,
2076 )))
2077 .unwrap();
2078
2079 request_tx
2080 .send(RequestEnvelope::plain(Request::Publish(
2081 publish_with_alias("", QoS::AtLeastOnce, 4),
2082 )))
2083 .unwrap();
2084
2085 eventloop.clean();
2086
2087 assert_eq!(eventloop.pending_len(), 1);
2088 match pending_front_request(&eventloop) {
2089 Some(Request::Publish(publish)) => {
2090 assert_eq!(publish.topic, Bytes::from_static(b"hello/replay"));
2091 assert_eq!(
2092 publish
2093 .properties
2094 .as_ref()
2095 .and_then(|props| props.topic_alias),
2096 None
2097 );
2098 }
2099 request => panic!("expected replay publish, got {request:?}"),
2100 }
2101 assert_eq!(eventloop.state.broker_topic_alias_max, 0);
2102 }
2103
2104 #[test]
2105 fn clean_uses_earlier_drained_publish_to_rewrite_later_alias_only_publish() {
2106 let options = MqttOptions::new("test-client", "localhost");
2107 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
2108
2109 request_tx
2110 .send(RequestEnvelope::plain(Request::Publish(
2111 publish_with_alias("fresh/topic", QoS::AtLeastOnce, 1),
2112 )))
2113 .unwrap();
2114 request_tx
2115 .send(RequestEnvelope::plain(Request::Publish(
2116 publish_with_alias("", QoS::AtLeastOnce, 1),
2117 )))
2118 .unwrap();
2119
2120 eventloop.clean();
2121
2122 assert_eq!(eventloop.pending_len(), 2);
2123 assert!(matches!(
2124 eventloop.pending.pop_front().map(|envelope| envelope.request),
2125 Some(Request::Publish(publish)) if publish.topic == Bytes::from_static(b"fresh/topic")
2126 ));
2127 assert!(matches!(
2128 eventloop.pending.pop_front().map(|envelope| envelope.request),
2129 Some(Request::Publish(publish)) if publish.topic == Bytes::from_static(b"fresh/topic")
2130 ));
2131 }
2132
2133 #[test]
2134 fn clean_prefers_earlier_replay_alias_mapping_over_stale_previous_mapping() {
2135 let options = MqttOptions::new("test-client", "localhost");
2136 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
2137 eventloop.state.broker_topic_alias_max = 10;
2138 eventloop
2139 .state
2140 .handle_outgoing_packet(Request::Publish(publish_with_alias(
2141 "stale/topic",
2142 QoS::AtMostOnce,
2143 1,
2144 )))
2145 .unwrap();
2146
2147 request_tx
2148 .send(RequestEnvelope::plain(Request::Publish(
2149 publish_with_alias("fresh/topic", QoS::AtLeastOnce, 1),
2150 )))
2151 .unwrap();
2152 request_tx
2153 .send(RequestEnvelope::plain(Request::Publish(
2154 publish_with_alias("", QoS::AtLeastOnce, 1),
2155 )))
2156 .unwrap();
2157
2158 eventloop.clean();
2159
2160 assert_eq!(eventloop.pending_len(), 2);
2161 _ = eventloop.pending.pop_front();
2162 assert!(matches!(
2163 eventloop.pending.pop_front().map(|envelope| envelope.request),
2164 Some(Request::Publish(publish)) if publish.topic == Bytes::from_static(b"fresh/topic")
2165 ));
2166 }
2167
2168 #[test]
2169 fn clean_fails_tracked_alias_only_publish_when_mapping_is_unknown() {
2170 let options = MqttOptions::new("test-client", "localhost");
2171 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
2172 let (notice_tx, notice) = PublishNoticeTx::new();
2173
2174 request_tx
2175 .send(RequestEnvelope::tracked_publish(
2176 publish_with_alias("", QoS::AtLeastOnce, 5),
2177 notice_tx,
2178 ))
2179 .unwrap();
2180
2181 eventloop.clean();
2182
2183 assert!(eventloop.pending_is_empty());
2184 assert_eq!(
2185 notice.wait().unwrap_err(),
2186 PublishNoticeError::TopicAliasReplayUnavailable(5)
2187 );
2188 }
2189
2190 #[test]
2191 fn clean_preserves_surviving_pending_order_when_alias_replay_is_filtered() {
2192 let options = MqttOptions::new("test-client", "localhost");
2193 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
2194 push_pending(&mut eventloop, Request::PingReq);
2195 let (notice_tx, notice) = PublishNoticeTx::new();
2196 request_tx
2197 .send(RequestEnvelope::tracked_publish(
2198 publish_with_alias("", QoS::AtLeastOnce, 6),
2199 notice_tx,
2200 ))
2201 .unwrap();
2202 request_tx
2203 .send(RequestEnvelope::plain(Request::PingResp))
2204 .unwrap();
2205
2206 eventloop.clean();
2207
2208 assert_eq!(
2209 notice.wait().unwrap_err(),
2210 PublishNoticeError::TopicAliasReplayUnavailable(6)
2211 );
2212 assert_eq!(eventloop.pending_len(), 2);
2213 assert!(matches!(
2214 eventloop
2215 .pending
2216 .pop_front()
2217 .map(|envelope| envelope.request),
2218 Some(Request::PingReq)
2219 ));
2220 assert!(matches!(
2221 eventloop
2222 .pending
2223 .pop_front()
2224 .map(|envelope| envelope.request),
2225 Some(Request::PingResp)
2226 ));
2227 }
2228
2229 #[tokio::test]
2230 #[cfg(unix)]
2231 async fn network_connect_rejects_unix_broker_with_tcp_transport() {
2232 let mut options = MqttOptions::new("test-client", crate::Broker::unix("/tmp/mqtt.sock"));
2233 options.set_transport(Transport::tcp());
2234
2235 match network_connect(&options).await {
2236 Err(ConnectionError::BrokerTransportMismatch) => {}
2237 Err(err) => panic!("unexpected error: {err:?}"),
2238 Ok(_) => panic!("mismatched broker and transport should fail"),
2239 }
2240 }
2241
2242 #[tokio::test]
2243 #[cfg(feature = "websocket")]
2244 async fn network_connect_rejects_tcp_broker_with_websocket_transport() {
2245 let mut options = MqttOptions::new("test-client", "localhost");
2246 options.set_transport(Transport::Ws);
2247
2248 match network_connect(&options).await {
2249 Err(ConnectionError::BrokerTransportMismatch) => {}
2250 Err(err) => panic!("unexpected error: {err:?}"),
2251 Ok(_) => panic!("mismatched broker and transport should fail"),
2252 }
2253 }
2254
2255 #[tokio::test]
2256 #[cfg(feature = "websocket")]
2257 async fn network_connect_rejects_websocket_broker_with_tcp_transport() {
2258 let broker = crate::Broker::websocket("ws://localhost:9001/mqtt").unwrap();
2259 let mut options = MqttOptions::new("test-client", broker);
2260 options.set_transport(Transport::tcp());
2261
2262 match network_connect(&options).await {
2263 Err(ConnectionError::BrokerTransportMismatch) => {}
2264 Err(err) => panic!("unexpected error: {err:?}"),
2265 Ok(_) => panic!("mismatched broker and transport should fail"),
2266 }
2267 }
2268
2269 #[test]
2270 fn connack_resize_skips_shrink_until_pending_retransmit_queue_is_empty() {
2271 let mut options = MqttOptions::new("test-client", "localhost");
2272 options.set_outgoing_inflight_upper_limit(10);
2273 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2274 let mut publish = Publish::new(
2275 "hello/world",
2276 crate::mqttbytes::QoS::AtLeastOnce,
2277 "payload",
2278 None,
2279 );
2280 publish.pkid = 8;
2281 push_pending(&mut eventloop, Request::Publish(publish));
2282
2283 eventloop
2284 .state
2285 .handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
2286 .unwrap();
2287
2288 eventloop.reconcile_outgoing_tracking_after_connack();
2289 assert_eq!(eventloop.state.outgoing_pub.len(), 11);
2290
2291 eventloop.pending.clear();
2292 eventloop.reconcile_outgoing_tracking_after_connack();
2293 assert_eq!(eventloop.state.outgoing_pub.len(), 4);
2294 assert_eq!(eventloop.state.outgoing_pub_notice.len(), 4);
2295 assert_eq!(eventloop.state.outgoing_rel_notice.len(), 4);
2296 }
2297
2298 #[tokio::test]
2299 async fn async_client_path_reports_requests_done_after_pending_drain() {
2300 let options = MqttOptions::new("test-client", "localhost");
2301 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
2302 push_pending(&mut eventloop, Request::PingReq);
2303 drop(request_tx);
2304
2305 let request = EventLoop::next_request(
2306 &mut eventloop.pending,
2307 &eventloop.requests_rx,
2308 Duration::ZERO,
2309 )
2310 .await
2311 .unwrap();
2312 assert!(matches!(request, (Request::PingReq, None)));
2313
2314 let err = EventLoop::next_request(
2315 &mut eventloop.pending,
2316 &eventloop.requests_rx,
2317 Duration::ZERO,
2318 )
2319 .await
2320 .unwrap_err();
2321 assert!(matches!(err, ConnectionError::RequestsDone));
2322 }
2323
2324 #[tokio::test]
2325 async fn next_request_is_cancellation_safe_for_pending_queue() {
2326 let options = MqttOptions::new("test-client", "localhost");
2327 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2328 push_pending(&mut eventloop, Request::PingReq);
2329
2330 let delayed = EventLoop::next_request(
2331 &mut eventloop.pending,
2332 &eventloop.requests_rx,
2333 Duration::from_millis(50),
2334 );
2335 let timed_out = time::timeout(Duration::from_millis(5), delayed).await;
2336
2337 assert!(timed_out.is_err());
2338 assert!(matches!(
2339 pending_front_request(&eventloop),
2340 Some(Request::PingReq)
2341 ));
2342 }
2343
2344 #[tokio::test]
2345 async fn try_next_request_applies_pending_throttle_for_followup_pending_item() {
2346 let options = MqttOptions::new("test-client", "localhost");
2347 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 2);
2348 push_pending(&mut eventloop, Request::PingReq);
2349 push_pending(&mut eventloop, Request::PingResp);
2350
2351 let first = EventLoop::next_request(
2352 &mut eventloop.pending,
2353 &eventloop.requests_rx,
2354 Duration::ZERO,
2355 )
2356 .await
2357 .unwrap();
2358 assert!(matches!(first, (Request::PingReq, None)));
2359
2360 let delayed = EventLoop::try_next_request(
2361 &mut eventloop.pending,
2362 &eventloop.requests_rx,
2363 Duration::from_millis(50),
2364 );
2365 let timed_out = time::timeout(Duration::from_millis(5), delayed).await;
2366
2367 assert!(timed_out.is_err());
2368 assert!(matches!(
2369 pending_front_request(&eventloop),
2370 Some(Request::PingResp)
2371 ));
2372 }
2373
2374 #[tokio::test]
2375 async fn try_next_request_does_not_throttle_when_pending_queue_is_empty() {
2376 let options = MqttOptions::new("test-client", "localhost");
2377 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
2378 request_tx
2379 .send_async(RequestEnvelope::plain(Request::PingReq))
2380 .await
2381 .unwrap();
2382
2383 let received = time::timeout(
2384 Duration::from_millis(20),
2385 EventLoop::try_next_request(
2386 &mut eventloop.pending,
2387 &eventloop.requests_rx,
2388 Duration::from_secs(1),
2389 ),
2390 )
2391 .await
2392 .unwrap();
2393
2394 assert!(matches!(received, Some((Request::PingReq, None))));
2395 }
2396
2397 #[tokio::test]
2398 async fn next_request_prioritizes_pending_over_channel_messages() {
2399 let options = MqttOptions::new("test-client", "localhost");
2400 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
2401 push_pending(&mut eventloop, Request::PingReq);
2402 request_tx
2403 .send_async(RequestEnvelope::plain(Request::PingReq))
2404 .await
2405 .unwrap();
2406
2407 let first = EventLoop::next_request(
2408 &mut eventloop.pending,
2409 &eventloop.requests_rx,
2410 Duration::ZERO,
2411 )
2412 .await
2413 .unwrap();
2414 assert!(matches!(first, (Request::PingReq, None)));
2415 assert!(eventloop.pending.is_empty());
2416
2417 let second = EventLoop::next_request(
2418 &mut eventloop.pending,
2419 &eventloop.requests_rx,
2420 Duration::ZERO,
2421 )
2422 .await
2423 .unwrap();
2424 assert!(matches!(second, (Request::PingReq, None)));
2425 }
2426
2427 #[tokio::test]
2428 async fn next_request_preserves_fifo_order_for_plain_and_tracked_requests() {
2429 let options = MqttOptions::new("test-client", "localhost");
2430 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
2431 let (notice_tx, _notice) = PublishNoticeTx::new();
2432 let tracked_publish = Publish::new(
2433 "hello/world",
2434 crate::mqttbytes::QoS::AtLeastOnce,
2435 "payload",
2436 None,
2437 );
2438
2439 request_tx
2440 .send_async(RequestEnvelope::plain(Request::PingReq))
2441 .await
2442 .unwrap();
2443 request_tx
2444 .send_async(RequestEnvelope::tracked_publish(
2445 tracked_publish.clone(),
2446 notice_tx,
2447 ))
2448 .await
2449 .unwrap();
2450 request_tx
2451 .send_async(RequestEnvelope::plain(Request::PingResp))
2452 .await
2453 .unwrap();
2454
2455 let first = EventLoop::next_request(
2456 &mut eventloop.pending,
2457 &eventloop.requests_rx,
2458 Duration::ZERO,
2459 )
2460 .await
2461 .unwrap();
2462 assert!(matches!(first, (Request::PingReq, None)));
2463
2464 let second = EventLoop::next_request(
2465 &mut eventloop.pending,
2466 &eventloop.requests_rx,
2467 Duration::ZERO,
2468 )
2469 .await
2470 .unwrap();
2471 assert!(matches!(
2472 second,
2473 (Request::Publish(publish), Some(_)) if publish == tracked_publish
2474 ));
2475
2476 let third = EventLoop::next_request(
2477 &mut eventloop.pending,
2478 &eventloop.requests_rx,
2479 Duration::ZERO,
2480 )
2481 .await
2482 .unwrap();
2483 assert!(matches!(third, (Request::PingResp, None)));
2484 }
2485
2486 #[tokio::test]
2487 async fn tracked_qos0_notice_reports_not_flushed_on_first_write_failure() {
2488 let options = MqttOptions::new("test-client", "localhost");
2489 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
2490 let (client, _peer) = tokio::io::duplex(1024);
2491 let mut network = Network::new(client, Some(1024));
2492 network.set_max_outgoing_size(Some(16));
2493 eventloop.network = Some(network);
2494
2495 let (notice_tx, notice) = PublishNoticeTx::new();
2496 let publish = Publish::new(
2497 "hello/world",
2498 crate::mqttbytes::QoS::AtMostOnce,
2499 vec![1; 128],
2500 None,
2501 );
2502 request_tx
2503 .send_async(RequestEnvelope::tracked_publish(publish, notice_tx))
2504 .await
2505 .unwrap();
2506
2507 let err = eventloop.select().await.unwrap_err();
2508 assert!(matches!(err, ConnectionError::MqttState(_)));
2509 assert_eq!(
2510 notice.wait_async().await.unwrap_err(),
2511 PublishNoticeError::Qos0NotFlushed
2512 );
2513 }
2514
2515 #[tokio::test]
2516 async fn tracked_qos0_notices_report_not_flushed_on_batched_write_failure() {
2517 let mut options = MqttOptions::new("test-client", "localhost");
2518 options.set_max_request_batch(2);
2519 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
2520 let (client, _peer) = tokio::io::duplex(1024);
2521 let mut network = Network::new(client, Some(1024));
2522 network.set_max_outgoing_size(Some(80));
2523 eventloop.network = Some(network);
2524
2525 let small_publish = Publish::new(
2526 "hello/world",
2527 crate::mqttbytes::QoS::AtMostOnce,
2528 vec![1],
2529 None,
2530 );
2531 let large_publish = Publish::new(
2532 "hello/world",
2533 crate::mqttbytes::QoS::AtMostOnce,
2534 vec![2; 256],
2535 None,
2536 );
2537
2538 let (first_notice_tx, first_notice) = PublishNoticeTx::new();
2539 request_tx
2540 .send_async(RequestEnvelope::tracked_publish(
2541 small_publish,
2542 first_notice_tx,
2543 ))
2544 .await
2545 .unwrap();
2546
2547 let (second_notice_tx, second_notice) = PublishNoticeTx::new();
2548 request_tx
2549 .send_async(RequestEnvelope::tracked_publish(
2550 large_publish,
2551 second_notice_tx,
2552 ))
2553 .await
2554 .unwrap();
2555
2556 let err = eventloop.select().await.unwrap_err();
2557 assert!(matches!(err, ConnectionError::MqttState(_)));
2558 assert_eq!(
2559 first_notice.wait_async().await.unwrap_err(),
2560 PublishNoticeError::Qos0NotFlushed
2561 );
2562 assert_eq!(
2563 second_notice.wait_async().await.unwrap_err(),
2564 PublishNoticeError::Qos0NotFlushed
2565 );
2566 }
2567
2568 #[tokio::test]
2569 async fn drain_pending_as_failed_drains_all_and_returns_count() {
2570 let options = MqttOptions::new("test-client", "localhost");
2571 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2572 let (notice_tx, notice) = PublishNoticeTx::new();
2573 let publish = Publish::new(
2574 "hello/world",
2575 crate::mqttbytes::QoS::AtLeastOnce,
2576 "payload",
2577 None,
2578 );
2579 eventloop
2580 .pending
2581 .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));
2582 eventloop
2583 .pending
2584 .push_back(RequestEnvelope::plain(Request::PingReq));
2585
2586 let drained = eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);
2587
2588 assert_eq!(drained, 2);
2589 assert!(eventloop.pending.is_empty());
2590 assert_eq!(
2591 notice.wait_async().await.unwrap_err(),
2592 PublishNoticeError::SessionReset
2593 );
2594 }
2595
2596 #[tokio::test]
2597 async fn drain_pending_as_failed_reports_session_reset_for_tracked_notices() {
2598 let options = MqttOptions::new("test-client", "localhost");
2599 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2600 let (publish_notice_tx, publish_notice) = PublishNoticeTx::new();
2601 let publish = Publish::new(
2602 "hello/world",
2603 crate::mqttbytes::QoS::AtLeastOnce,
2604 "payload",
2605 None,
2606 );
2607 eventloop
2608 .pending
2609 .push_back(RequestEnvelope::tracked_publish(publish, publish_notice_tx));
2610
2611 let (request_notice_tx, request_notice) = SubscribeNoticeTx::new();
2612 let subscribe = Subscribe::new(
2613 Filter::new("hello/world", crate::mqttbytes::QoS::AtMostOnce),
2614 None,
2615 );
2616 eventloop
2617 .pending
2618 .push_back(RequestEnvelope::tracked_subscribe(
2619 subscribe,
2620 request_notice_tx,
2621 ));
2622
2623 eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);
2624
2625 assert_eq!(
2626 publish_notice.wait_async().await.unwrap_err(),
2627 PublishNoticeError::SessionReset
2628 );
2629 assert_eq!(
2630 request_notice.wait_async().await.unwrap_err(),
2631 crate::SubscribeNoticeError::SessionReset
2632 );
2633 }
2634
2635 #[tokio::test]
2636 async fn reset_session_state_reports_session_reset_for_pending_tracked_notice() {
2637 let options = MqttOptions::new("test-client", "localhost");
2638 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2639 let (notice_tx, notice) = PublishNoticeTx::new();
2640 let publish = Publish::new(
2641 "hello/world",
2642 crate::mqttbytes::QoS::AtLeastOnce,
2643 "payload",
2644 None,
2645 );
2646 eventloop
2647 .pending
2648 .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));
2649
2650 eventloop.reset_session_state();
2651
2652 assert!(eventloop.pending.is_empty());
2653 assert_eq!(
2654 notice.wait_async().await.unwrap_err(),
2655 PublishNoticeError::SessionReset
2656 );
2657 }
2658
2659 #[tokio::test]
2660 async fn reset_session_state_fails_active_tracked_reauth_notice() {
2661 let mut options = MqttOptions::new("test-client", "localhost");
2662 options.set_authentication_method(Some("test-method".to_owned()));
2663 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2664 let (notice_tx, notice) = AuthNoticeTx::new();
2665
2666 eventloop
2667 .state
2668 .handle_outgoing_packet_with_notice(
2669 Request::Auth(Auth::new(AuthReasonCode::ReAuthenticate, None)),
2670 Some(TrackedNoticeTx::Auth(notice_tx)),
2671 )
2672 .unwrap();
2673
2674 eventloop.reset_session_state();
2675
2676 assert_eq!(
2677 notice.wait_async().await.unwrap_err(),
2678 crate::notice::AuthNoticeError::SessionReset
2679 );
2680 assert!(eventloop.state.events.iter().any(|event| {
2681 matches!(
2682 event,
2683 Event::Auth(crate::AuthEvent::Failed {
2684 kind: crate::AuthExchangeKind::Reauthentication,
2685 reason: crate::AuthFailureReason::SessionReset,
2686 ..
2687 })
2688 )
2689 }));
2690 eventloop
2691 .state
2692 .handle_outgoing_packet(Request::Auth(Auth::new(
2693 AuthReasonCode::ReAuthenticate,
2694 None,
2695 )))
2696 .unwrap();
2697 }
2698
2699 #[tokio::test]
2700 async fn reset_session_state_does_not_fail_initial_auth_exchange() {
2701 let mut options = MqttOptions::new("test-client", "localhost");
2702 options.set_authentication_method(Some("test-method".to_owned()));
2703 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
2704
2705 eventloop
2706 .state
2707 .begin_authentication_connect(Some("test-method".to_owned()))
2708 .unwrap();
2709 eventloop.reset_session_state();
2710
2711 assert!(!eventloop.state.events.iter().any(|event| {
2712 matches!(
2713 event,
2714 Event::Auth(crate::AuthEvent::Failed {
2715 kind: crate::AuthExchangeKind::InitialConnect,
2716 ..
2717 })
2718 )
2719 }));
2720 eventloop
2721 .state
2722 .handle_incoming_packet(Incoming::ConnAck(build_connack_with_authentication_method(
2723 Some("test-method"),
2724 )))
2725 .unwrap();
2726 assert!(eventloop.state.events.iter().any(|event| {
2727 matches!(
2728 event,
2729 Event::Auth(crate::AuthEvent::Succeeded {
2730 kind: crate::AuthExchangeKind::InitialConnect,
2731 ..
2732 })
2733 )
2734 }));
2735 }
2736
2737 #[test]
2738 fn connack_reconcile_rejects_clean_start_with_session_present() {
2739 let mut eventloop = build_eventloop_with_pending(true);
2740
2741 let err = eventloop.reconcile_connack_session(true).unwrap_err();
2742
2743 assert!(matches!(
2744 err,
2745 ConnectionError::SessionStateMismatch {
2746 clean_start: true,
2747 session_present: true
2748 }
2749 ));
2750 assert_eq!(eventloop.pending_len(), 1);
2751 }
2752
2753 #[test]
2754 fn connack_reconcile_resets_pending_when_clean_start_gets_new_session() {
2755 let mut eventloop = build_eventloop_with_pending(true);
2756
2757 eventloop.reconcile_connack_session(false).unwrap();
2758
2759 assert!(eventloop.pending_is_empty());
2760 }
2761
2762 #[test]
2763 fn connack_reconcile_resets_pending_when_resumed_session_is_missing() {
2764 let mut eventloop = build_eventloop_with_pending(false);
2765
2766 eventloop.reconcile_connack_session(false).unwrap();
2767
2768 assert!(eventloop.pending_is_empty());
2769 }
2770
2771 #[test]
2772 fn connack_reconcile_keeps_pending_when_resumed_session_exists() {
2773 let mut eventloop = build_eventloop_with_pending(false);
2774
2775 eventloop.reconcile_connack_session(true).unwrap();
2776
2777 assert_eq!(eventloop.pending_len(), 1);
2778 }
2779}