1use super::framed::Network;
2use super::mqttbytes::v5::{ConnAck, Connect, Packet, Publish, Subscribe, Unsubscribe};
3use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
4use crate::framed::AsyncReadWrite;
5use crate::notice::{PublishNoticeTx, RequestNoticeTx, TrackedNoticeTx};
6use crate::{NoticeFailureReason, PublishNoticeError};
7
8use flume::{Receiver, Sender, TryRecvError, bounded};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep, error::Elapsed};
11
12use std::collections::VecDeque;
13use std::io;
14use std::pin::Pin;
15use std::time::Duration;
16
17use super::mqttbytes::v5::ConnectReturnCode;
18
19#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
20use crate::tls;
21
22#[cfg(unix)]
23use {std::path::Path, tokio::net::UnixStream};
24
25#[cfg(feature = "websocket")]
26use {
27 crate::websockets::WsAdapter,
28 crate::websockets::{UrlError, split_url, validate_response_headers},
29 async_tungstenite::tungstenite::client::IntoClientRequest,
30};
31
32#[cfg(feature = "proxy")]
33use crate::proxy::ProxyError;
34
/// A client request queued for the event loop, optionally paired with the
/// notice used to report its outcome back to the caller.
#[derive(Debug)]
pub struct RequestEnvelope {
    /// The MQTT request to send.
    request: Request,
    /// Completion notice for tracked publish/subscribe/unsubscribe requests;
    /// `None` for plain (untracked) requests.
    notice: Option<TrackedNoticeTx>,
}
40
41impl RequestEnvelope {
42 pub(crate) const fn from_parts(request: Request, notice: Option<TrackedNoticeTx>) -> Self {
43 Self { request, notice }
44 }
45
46 pub(crate) const fn plain(request: Request) -> Self {
47 Self {
48 request,
49 notice: None,
50 }
51 }
52
53 pub(crate) const fn tracked_publish(publish: Publish, notice: PublishNoticeTx) -> Self {
54 Self {
55 request: Request::Publish(publish),
56 notice: Some(TrackedNoticeTx::Publish(notice)),
57 }
58 }
59
60 pub(crate) const fn tracked_subscribe(subscribe: Subscribe, notice: RequestNoticeTx) -> Self {
61 Self {
62 request: Request::Subscribe(subscribe),
63 notice: Some(TrackedNoticeTx::Request(notice)),
64 }
65 }
66
67 pub(crate) const fn tracked_unsubscribe(
68 unsubscribe: Unsubscribe,
69 notice: RequestNoticeTx,
70 ) -> Self {
71 Self {
72 request: Request::Unsubscribe(unsubscribe),
73 notice: Some(TrackedNoticeTx::Request(notice)),
74 }
75 }
76
77 pub(crate) fn into_parts(self) -> (Request, Option<TrackedNoticeTx>) {
78 (self.request, self.notice)
79 }
80}
81
82#[derive(Debug, thiserror::Error)]
84pub enum ConnectionError {
85 #[error("Mqtt state: {0}")]
86 MqttState(#[from] StateError),
87 #[error("Timeout")]
88 Timeout(#[from] Elapsed),
89 #[cfg(feature = "websocket")]
90 #[error("Websocket: {0}")]
91 Websocket(#[from] async_tungstenite::tungstenite::error::Error),
92 #[cfg(feature = "websocket")]
93 #[error("Websocket Connect: {0}")]
94 WsConnect(#[from] http::Error),
95 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
96 #[error("TLS: {0}")]
97 Tls(#[from] tls::Error),
98 #[error("I/O: {0}")]
99 Io(#[from] io::Error),
100 #[error("Connection refused, return code: `{0:?}`")]
101 ConnectionRefused(ConnectReturnCode),
102 #[error("Expected ConnAck packet, received: {0:?}")]
103 NotConnAck(Box<Packet>),
104 #[error("Broker replied with session_present={session_present} for clean_start={clean_start}")]
105 SessionStateMismatch {
106 clean_start: bool,
107 session_present: bool,
108 },
109 #[error("Broker target is incompatible with the selected transport")]
110 BrokerTransportMismatch,
111 #[error("Requests done")]
114 RequestsDone,
115 #[error("Auth processing error")]
116 AuthProcessingError,
117 #[cfg(feature = "websocket")]
118 #[error("Invalid Url: {0}")]
119 InvalidUrl(#[from] UrlError),
120 #[cfg(feature = "proxy")]
121 #[error("Proxy Connect: {0}")]
122 Proxy(#[from] ProxyError),
123 #[cfg(feature = "websocket")]
124 #[error("Websocket response validation error: ")]
125 ResponseValidation(#[from] crate::websockets::ValidationError),
126 #[cfg(feature = "websocket")]
127 #[error("Websocket request modifier failed: {0}")]
128 RequestModifier(#[source] Box<dyn std::error::Error + Send + Sync>),
129}
130
/// Drives a single MQTT v5 connection: connects on demand, sends client
/// requests, reads broker packets, and keeps the connection alive.
pub struct EventLoop {
    /// Connection options; keep alive and session expiry may be overwritten
    /// from the broker's ConnAck properties during the handshake.
    pub options: MqttOptions,
    /// Protocol state: in-flight tracking, queued events, and tracked notices.
    pub state: MqttState,
    /// Requests submitted by clients.
    requests_rx: Receiver<RequestEnvelope>,
    /// Sender retained by [`EventLoop::new`] so the channel stays connected;
    /// `None` when the only sender was handed to an async client instead.
    _requests_tx: Option<Sender<RequestEnvelope>>,
    /// Requests sent ahead of new channel requests, e.g. work re-queued by
    /// [`EventLoop::clean`] after a disconnect.
    pending: VecDeque<RequestEnvelope>,
    /// Active connection; `None` while disconnected, in which case
    /// [`EventLoop::poll`] connects first.
    network: Option<Network>,
    /// Timer that triggers a PingReq; `None` while disconnected or when keep
    /// alive is zero.
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
}
149
/// An event yielded by [`EventLoop::poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): lint silenced deliberately; the variants presumably differ a lot
// in size, and boxing one would change this public enum's shape.
#[allow(clippy::large_enum_variant)]
pub enum Event {
    /// A packet received from the broker.
    Incoming(Incoming),
    /// Outgoing activity recorded by the protocol state.
    Outgoing(Outgoing),
}
157
158impl EventLoop {
159 fn reconcile_connack_session(&mut self, session_present: bool) -> Result<(), ConnectionError> {
160 let clean_start = self.options.clean_start();
161 if clean_start && session_present {
162 return Err(ConnectionError::SessionStateMismatch {
163 clean_start,
164 session_present,
165 });
166 }
167
168 if !session_present {
169 self.reset_session_state();
170 }
171
172 Ok(())
173 }
174
175 pub fn new(options: MqttOptions, cap: usize) -> Self {
180 let (requests_tx, requests_rx) = bounded(cap);
181 Self::with_channel(options, requests_rx, Some(requests_tx))
182 }
183
184 pub(crate) fn new_for_async_client(
189 options: MqttOptions,
190 cap: usize,
191 ) -> (Self, Sender<RequestEnvelope>) {
192 let (requests_tx, requests_rx) = bounded(cap);
193 let eventloop = Self::with_channel(options, requests_rx, None);
194 (eventloop, requests_tx)
195 }
196
197 fn with_channel(
198 options: MqttOptions,
199 requests_rx: Receiver<RequestEnvelope>,
200 requests_tx: Option<Sender<RequestEnvelope>>,
201 ) -> Self {
202 let pending = VecDeque::new();
203 let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
204 let manual_acks = options.manual_acks;
205
206 let auth_manager = options.auth_manager();
207
208 Self {
209 options,
210 state: MqttState::new(inflight_limit, manual_acks, auth_manager),
211 requests_rx,
212 _requests_tx: requests_tx,
213 pending,
214 network: None,
215 keepalive_timeout: None,
216 }
217 }
218
219 pub fn clean(&mut self) {
228 self.network = None;
229 self.keepalive_timeout = None;
230 for (request, notice) in self.state.clean_with_notices() {
231 self.pending
232 .push_back(RequestEnvelope::from_parts(request, notice));
233 }
234
235 for envelope in self.requests_rx.drain() {
237 if !matches!(&envelope.request, Request::PubAck(_) | Request::PubRec(_)) {
240 self.pending.push_back(envelope);
241 }
242 }
243 }
244
245 pub fn pending_len(&self) -> usize {
247 self.pending.len()
248 }
249
250 pub fn pending_is_empty(&self) -> bool {
252 self.pending.is_empty()
253 }
254
255 pub fn drain_pending_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
259 let mut drained = 0;
260 for envelope in self.pending.drain(..) {
261 drained += 1;
262 if let Some(notice) = envelope.notice {
263 match notice {
264 TrackedNoticeTx::Publish(notice) => {
265 notice.error(reason.publish_error());
266 }
267 TrackedNoticeTx::Request(notice) => {
268 notice.error(reason.request_error());
269 }
270 }
271 }
272 }
273
274 drained
275 }
276
277 pub fn reset_session_state(&mut self) {
279 self.drain_pending_as_failed(NoticeFailureReason::SessionReset);
280 self.state.fail_pending_notices();
281 }
282
283 fn reconcile_outgoing_tracking_after_connack(&mut self) {
284 self.state
285 .reconcile_outgoing_tracking_capacity(self.pending.is_empty());
286 }
287
288 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
298 if self.network.is_none() {
299 let (network, connack) = time::timeout(
300 self.options.connect_timeout(),
301 connect(&mut self.options, &mut self.state),
302 )
303 .await??;
304 self.reconcile_connack_session(connack.session_present)?;
305 self.network = Some(network);
306
307 if self.keepalive_timeout.is_none() && !self.options.keep_alive.is_zero() {
308 self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
309 }
310
311 self.state
312 .handle_incoming_packet(Incoming::ConnAck(connack))?;
313 self.reconcile_outgoing_tracking_after_connack();
314 }
315
316 match self.select().await {
317 Ok(v) => Ok(v),
318 Err(e) => {
319 self.clean();
322 Err(e)
323 }
324 }
325 }
326
327 #[allow(clippy::too_many_lines)]
329 async fn select(&mut self) -> Result<Event, ConnectionError> {
330 let read_batch_size = self.effective_read_batch_size();
331 let network = self.network.as_mut().unwrap();
332 let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
335 let collision = self.state.collision.is_some();
336
337 if let Some(event) = self.state.events.pop_front() {
339 return Ok(event);
340 }
341
342 let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
343 select! {
346 o = Self::next_request(
375 &mut self.pending,
376 &self.requests_rx,
377 self.options.pending_throttle
378 ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
379 Ok((request, notice)) => {
380 let max_request_batch = self.options.max_request_batch.max(1);
381 let mut should_flush = false;
382 let mut qos0_notices = Vec::new();
383
384 let (outgoing, flush_notice) =
385 self.state.handle_outgoing_packet_with_notice(request, notice)?;
386 if let Some(notice) = flush_notice {
387 qos0_notices.push(notice);
388 }
389 if let Some(outgoing) = outgoing {
390 if let Err(err) = network.write(outgoing).await {
391 for notice in qos0_notices {
392 notice.error(PublishNoticeError::Qos0NotFlushed);
393 }
394 return Err(ConnectionError::MqttState(err));
395 }
396 should_flush = true;
397 }
398
399 for _ in 1..max_request_batch {
400 let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
401 let collision = self.state.collision.is_some();
402
403 if self.pending.is_empty() && (inflight_full || collision) {
404 break;
405 }
406
407 let Some((next_request, next_notice)) = Self::try_next_request(
408 &mut self.pending,
409 &self.requests_rx,
410 self.options.pending_throttle,
411 ).await else {
412 break;
413 };
414
415 let (outgoing, flush_notice) = self
416 .state
417 .handle_outgoing_packet_with_notice(next_request, next_notice)?;
418 if let Some(notice) = flush_notice {
419 qos0_notices.push(notice);
420 }
421 if let Some(outgoing) = outgoing {
422 if let Err(err) = network.write(outgoing).await {
423 for notice in qos0_notices {
424 notice.error(PublishNoticeError::Qos0NotFlushed);
425 }
426 return Err(ConnectionError::MqttState(err));
427 }
428 should_flush = true;
429 }
430 }
431
432 if should_flush {
433 match network.flush().await {
434 Ok(()) => {
435 for notice in qos0_notices {
436 notice.success();
437 }
438 }
439 Err(err) => {
440 for notice in qos0_notices {
441 notice.error(PublishNoticeError::Qos0NotFlushed);
442 }
443 return Err(ConnectionError::MqttState(err));
444 }
445 }
446 }
447 Ok(self.state.events.pop_front().unwrap())
448 }
449 Err(_) => Err(ConnectionError::RequestsDone),
450 },
451 o = network.readb(&mut self.state, read_batch_size) => {
453 o?;
454 network.flush().await?;
456 Ok(self.state.events.pop_front().unwrap())
457 },
458 () = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
461 if self.keepalive_timeout.is_some() && !self.options.keep_alive.is_zero() => {
462 let timeout = self.keepalive_timeout.as_mut().unwrap();
463 timeout.as_mut().reset(Instant::now() + self.options.keep_alive);
464
465 let (outgoing, _flush_notice) = self
466 .state
467 .handle_outgoing_packet_with_notice(Request::PingReq, None)?;
468 if let Some(outgoing) = outgoing {
469 network.write(outgoing).await?;
470 }
471 network.flush().await?;
472 Ok(self.state.events.pop_front().unwrap())
473 }
474 }
475 }
476
477 async fn try_next_request(
478 pending: &mut VecDeque<RequestEnvelope>,
479 rx: &Receiver<RequestEnvelope>,
480 pending_throttle: Duration,
481 ) -> Option<(Request, Option<TrackedNoticeTx>)> {
482 if !pending.is_empty() {
483 if pending_throttle.is_zero() {
484 tokio::task::yield_now().await;
485 } else {
486 time::sleep(pending_throttle).await;
487 }
488 return pending.pop_front().map(RequestEnvelope::into_parts);
491 }
492
493 match rx.try_recv() {
494 Ok(envelope) => return Some(envelope.into_parts()),
495 Err(TryRecvError::Disconnected) => return None,
496 Err(TryRecvError::Empty) => {}
497 }
498
499 None
500 }
501
502 async fn next_request(
503 pending: &mut VecDeque<RequestEnvelope>,
504 rx: &Receiver<RequestEnvelope>,
505 pending_throttle: Duration,
506 ) -> Result<(Request, Option<TrackedNoticeTx>), ConnectionError> {
507 if pending.is_empty() {
508 rx.recv_async()
509 .await
510 .map(RequestEnvelope::into_parts)
511 .map_err(|_| ConnectionError::RequestsDone)
512 } else {
513 if pending_throttle.is_zero() {
514 tokio::task::yield_now().await;
515 } else {
516 time::sleep(pending_throttle).await;
517 }
518 Ok(pending.pop_front().unwrap().into_parts())
521 }
522 }
523
524 fn effective_read_batch_size(&self) -> usize {
525 const MAX_READ_BATCH_SIZE: usize = 128;
526 const PENDING_FAIRNESS_CAP: usize = 16;
527
528 let configured = self.options.read_batch_size();
529 if configured > 0 {
530 return configured.clamp(1, MAX_READ_BATCH_SIZE);
531 }
532
533 let request_batch = self.options.max_request_batch().max(1);
534 let inflight = usize::from(self.state.max_outgoing_inflight);
535 let mut adaptive = request_batch.max(inflight / 2).max(8);
536
537 if !self.pending.is_empty() || !self.requests_rx.is_empty() {
538 adaptive = adaptive.min(PENDING_FAIRNESS_CAP);
539 }
540
541 adaptive.clamp(1, MAX_READ_BATCH_SIZE)
542 }
543}
544
545async fn connect(
551 options: &mut MqttOptions,
552 state: &mut MqttState,
553) -> Result<(Network, ConnAck), ConnectionError> {
554 let mut network = network_connect(options).await?;
556
557 let connack = mqtt_connect(options, &mut network, state).await?;
559
560 Ok((network, connack))
561}
562
563#[allow(clippy::too_many_lines)]
564async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
565 let max_incoming_pkt_size = options.max_incoming_packet_size();
566 let transport = options.transport();
567
568 #[cfg(unix)]
570 if matches!(&transport, Transport::Unix) {
571 let file = options
572 .broker()
573 .unix_path()
574 .ok_or(ConnectionError::BrokerTransportMismatch)?;
575 let socket = UnixStream::connect(Path::new(file)).await?;
576 let network = Network::new(socket, max_incoming_pkt_size);
577 return Ok(network);
578 }
579
580 let (domain, port) = match &transport {
582 #[cfg(feature = "websocket")]
583 Transport::Ws => split_url(
584 options
585 .broker()
586 .websocket_url()
587 .ok_or(ConnectionError::BrokerTransportMismatch)?,
588 )?,
589 #[cfg(all(
590 any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
591 feature = "websocket"
592 ))]
593 Transport::Wss(_) => split_url(
594 options
595 .broker()
596 .websocket_url()
597 .ok_or(ConnectionError::BrokerTransportMismatch)?,
598 )?,
599 _ => options
600 .broker()
601 .tcp_address()
602 .map(|(host, port)| (host.to_owned(), port))
603 .ok_or(ConnectionError::BrokerTransportMismatch)?,
604 };
605
606 let tcp_stream: Box<dyn AsyncReadWrite> = {
607 #[cfg(feature = "proxy")]
608 if let Some(proxy) = options.proxy() {
609 proxy
610 .connect(
611 &domain,
612 port,
613 options.network_options(),
614 Some(options.effective_socket_connector()),
615 )
616 .await?
617 } else {
618 let addr = format!("{domain}:{port}");
619 options
620 .socket_connect(addr, options.network_options())
621 .await?
622 }
623 #[cfg(not(feature = "proxy"))]
624 {
625 let addr = format!("{domain}:{port}");
626 options
627 .socket_connect(addr, options.network_options())
628 .await?
629 }
630 };
631
632 let network = match transport {
633 Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
634 #[cfg(any(feature = "use-native-tls", feature = "use-rustls-no-provider"))]
635 Transport::Tls(tls_config) => {
636 let (host, port) = options
637 .broker()
638 .tcp_address()
639 .expect("tls transport requires a tcp broker");
640 let socket = tls::tls_connect(host, port, &tls_config, tcp_stream).await?;
641 Network::new(socket, max_incoming_pkt_size)
642 }
643 #[cfg(unix)]
644 Transport::Unix => unreachable!(),
645 #[cfg(feature = "websocket")]
646 Transport::Ws => {
647 let mut request = options
648 .broker()
649 .websocket_url()
650 .expect("ws transport requires a websocket broker")
651 .into_client_request()?;
652 request
653 .headers_mut()
654 .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
655
656 if let Some(request_modifier) = options.fallible_request_modifier() {
657 request = request_modifier(request)
658 .await
659 .map_err(ConnectionError::RequestModifier)?;
660 } else if let Some(request_modifier) = options.request_modifier() {
661 request = request_modifier(request).await;
662 }
663
664 let (socket, response) =
665 async_tungstenite::tokio::client_async(request, tcp_stream).await?;
666 validate_response_headers(response)?;
667
668 Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
669 }
670 #[cfg(all(
671 any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
672 feature = "websocket"
673 ))]
674 Transport::Wss(tls_config) => {
675 let mut request = options
676 .broker()
677 .websocket_url()
678 .expect("wss transport requires a websocket broker")
679 .into_client_request()?;
680 request
681 .headers_mut()
682 .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
683
684 if let Some(request_modifier) = options.fallible_request_modifier() {
685 request = request_modifier(request)
686 .await
687 .map_err(ConnectionError::RequestModifier)?;
688 } else if let Some(request_modifier) = options.request_modifier() {
689 request = request_modifier(request).await;
690 }
691
692 let tls_stream = tls::tls_connect(&domain, port, &tls_config, tcp_stream).await?;
693 let (socket, response) =
694 async_tungstenite::tokio::client_async(request, tls_stream).await?;
695 validate_response_headers(response)?;
696
697 Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
698 }
699 };
700
701 Ok(network)
702}
703
704async fn mqtt_connect(
705 options: &mut MqttOptions,
706 network: &mut Network,
707 state: &mut MqttState,
708) -> Result<ConnAck, ConnectionError> {
709 let packet = Packet::Connect(
710 Connect {
711 client_id: options.client_id(),
712 keep_alive: u16::try_from(options.keep_alive().as_secs()).unwrap_or(u16::MAX),
713 clean_start: options.clean_start(),
714 properties: options.connect_properties(),
715 },
716 options.last_will(),
717 options.auth().clone(),
718 );
719
720 network.write(packet).await?;
722 network.flush().await?;
723
724 loop {
726 match network.read().await? {
727 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
728 if let Some(props) = &connack.properties
729 && let Some(keep_alive) = props.server_keep_alive
730 {
731 options.keep_alive = Duration::from_secs(u64::from(keep_alive));
732 }
733
734 if let Some(props) = &connack.properties {
735 network.set_max_outgoing_size(props.max_packet_size);
736
737 if props.session_expiry_interval.is_some() {
739 options.set_session_expiry_interval(props.session_expiry_interval);
740 }
741 }
742 return Ok(connack);
743 }
744 Incoming::ConnAck(connack) => {
745 return Err(ConnectionError::ConnectionRefused(connack.code));
746 }
747 Incoming::Auth(auth) => {
748 if let Some(outgoing) = state.handle_incoming_packet(Incoming::Auth(auth))? {
749 network.write(outgoing).await?;
750 network.flush().await?;
751 } else {
752 return Err(ConnectionError::AuthProcessingError);
753 }
754 }
755 packet => return Err(ConnectionError::NotConnAck(Box::new(packet))),
756 }
757 }
758}
759
760#[cfg(test)]
761mod tests {
762 use super::*;
763 use crate::{ConnAckProperties, Filter, PubAck, PubRec};
764 use flume::TryRecvError;
765
766 fn build_connack_with_receive_max(receive_max: u16) -> ConnAck {
767 ConnAck {
768 session_present: false,
769 code: ConnectReturnCode::Success,
770 properties: Some(ConnAckProperties {
771 session_expiry_interval: None,
772 receive_max: Some(receive_max),
773 max_qos: None,
774 retain_available: None,
775 max_packet_size: None,
776 assigned_client_identifier: None,
777 topic_alias_max: None,
778 reason_string: None,
779 user_properties: vec![],
780 wildcard_subscription_available: None,
781 subscription_identifiers_available: None,
782 shared_subscription_available: None,
783 server_keep_alive: None,
784 response_information: None,
785 server_reference: None,
786 authentication_method: None,
787 authentication_data: None,
788 }),
789 }
790 }
791
792 fn push_pending(eventloop: &mut EventLoop, request: Request) {
793 eventloop.pending.push_back(RequestEnvelope::plain(request));
794 }
795
796 fn pending_front_request(eventloop: &EventLoop) -> Option<&Request> {
797 eventloop.pending.front().map(|envelope| &envelope.request)
798 }
799
800 fn build_eventloop_with_pending(clean_start: bool) -> EventLoop {
801 let mut options = MqttOptions::new("test-client", "localhost");
802 options.set_clean_start(clean_start);
803
804 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
805 push_pending(&mut eventloop, Request::PingReq);
806 eventloop
807 }
808
809 #[test]
810 fn eventloop_new_keeps_internal_sender_alive() {
811 let options = MqttOptions::new("test-client", "localhost");
812 let eventloop = EventLoop::new(options, 1);
813
814 assert!(matches!(
815 eventloop.requests_rx.try_recv(),
816 Err(TryRecvError::Empty)
817 ));
818 }
819
820 #[test]
821 fn async_client_constructor_path_allows_channel_shutdown() {
822 let options = MqttOptions::new("test-client", "localhost");
823 let (eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
824 drop(request_tx);
825
826 assert!(matches!(
827 eventloop.requests_rx.try_recv(),
828 Err(TryRecvError::Disconnected)
829 ));
830 }
831
832 #[test]
833 fn clean_drops_ack_requests_drained_from_channel() {
834 let options = MqttOptions::new("test-client", "localhost");
835 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 3);
836 request_tx
837 .send(RequestEnvelope::plain(Request::PubAck(PubAck::new(
838 7, None,
839 ))))
840 .unwrap();
841 request_tx
842 .send(RequestEnvelope::plain(Request::PubRec(PubRec::new(
843 8, None,
844 ))))
845 .unwrap();
846 request_tx
847 .send(RequestEnvelope::plain(Request::PingReq))
848 .unwrap();
849
850 eventloop.clean();
851
852 assert_eq!(eventloop.pending_len(), 1);
853 assert!(matches!(
854 pending_front_request(&eventloop),
855 Some(Request::PingReq)
856 ));
857 }
858
859 #[tokio::test]
860 #[cfg(unix)]
861 async fn network_connect_rejects_unix_broker_with_tcp_transport() {
862 let mut options = MqttOptions::new("test-client", crate::Broker::unix("/tmp/mqtt.sock"));
863 options.set_transport(Transport::tcp());
864
865 match network_connect(&options).await {
866 Err(ConnectionError::BrokerTransportMismatch) => {}
867 Err(err) => panic!("unexpected error: {err:?}"),
868 Ok(_) => panic!("mismatched broker and transport should fail"),
869 }
870 }
871
872 #[tokio::test]
873 #[cfg(feature = "websocket")]
874 async fn network_connect_rejects_tcp_broker_with_websocket_transport() {
875 let mut options = MqttOptions::new("test-client", "localhost");
876 options.set_transport(Transport::Ws);
877
878 match network_connect(&options).await {
879 Err(ConnectionError::BrokerTransportMismatch) => {}
880 Err(err) => panic!("unexpected error: {err:?}"),
881 Ok(_) => panic!("mismatched broker and transport should fail"),
882 }
883 }
884
885 #[tokio::test]
886 #[cfg(feature = "websocket")]
887 async fn network_connect_rejects_websocket_broker_with_tcp_transport() {
888 let broker = crate::Broker::websocket("ws://localhost:9001/mqtt").unwrap();
889 let mut options = MqttOptions::new("test-client", broker);
890 options.set_transport(Transport::tcp());
891
892 match network_connect(&options).await {
893 Err(ConnectionError::BrokerTransportMismatch) => {}
894 Err(err) => panic!("unexpected error: {err:?}"),
895 Ok(_) => panic!("mismatched broker and transport should fail"),
896 }
897 }
898
899 #[test]
900 fn connack_resize_skips_shrink_until_pending_retransmit_queue_is_empty() {
901 let mut options = MqttOptions::new("test-client", "localhost");
902 options.set_outgoing_inflight_upper_limit(10);
903 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
904 let mut publish = Publish::new(
905 "hello/world",
906 crate::mqttbytes::QoS::AtLeastOnce,
907 "payload",
908 None,
909 );
910 publish.pkid = 8;
911 push_pending(&mut eventloop, Request::Publish(publish));
912
913 eventloop
914 .state
915 .handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
916 .unwrap();
917
918 eventloop.reconcile_outgoing_tracking_after_connack();
919 assert_eq!(eventloop.state.outgoing_pub.len(), 11);
920
921 eventloop.pending.clear();
922 eventloop.reconcile_outgoing_tracking_after_connack();
923 assert_eq!(eventloop.state.outgoing_pub.len(), 4);
924 assert_eq!(eventloop.state.outgoing_pub_notice.len(), 4);
925 assert_eq!(eventloop.state.outgoing_rel_notice.len(), 4);
926 }
927
928 #[tokio::test]
929 async fn async_client_path_reports_requests_done_after_pending_drain() {
930 let options = MqttOptions::new("test-client", "localhost");
931 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
932 push_pending(&mut eventloop, Request::PingReq);
933 drop(request_tx);
934
935 let request = EventLoop::next_request(
936 &mut eventloop.pending,
937 &eventloop.requests_rx,
938 Duration::ZERO,
939 )
940 .await
941 .unwrap();
942 assert!(matches!(request, (Request::PingReq, None)));
943
944 let err = EventLoop::next_request(
945 &mut eventloop.pending,
946 &eventloop.requests_rx,
947 Duration::ZERO,
948 )
949 .await
950 .unwrap_err();
951 assert!(matches!(err, ConnectionError::RequestsDone));
952 }
953
954 #[tokio::test]
955 async fn next_request_is_cancellation_safe_for_pending_queue() {
956 let options = MqttOptions::new("test-client", "localhost");
957 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
958 push_pending(&mut eventloop, Request::PingReq);
959
960 let delayed = EventLoop::next_request(
961 &mut eventloop.pending,
962 &eventloop.requests_rx,
963 Duration::from_millis(50),
964 );
965 let timed_out = time::timeout(Duration::from_millis(5), delayed).await;
966
967 assert!(timed_out.is_err());
968 assert!(matches!(
969 pending_front_request(&eventloop),
970 Some(Request::PingReq)
971 ));
972 }
973
974 #[tokio::test]
975 async fn try_next_request_applies_pending_throttle_for_followup_pending_item() {
976 let options = MqttOptions::new("test-client", "localhost");
977 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 2);
978 push_pending(&mut eventloop, Request::PingReq);
979 push_pending(&mut eventloop, Request::PingResp);
980
981 let first = EventLoop::next_request(
982 &mut eventloop.pending,
983 &eventloop.requests_rx,
984 Duration::ZERO,
985 )
986 .await
987 .unwrap();
988 assert!(matches!(first, (Request::PingReq, None)));
989
990 let delayed = EventLoop::try_next_request(
991 &mut eventloop.pending,
992 &eventloop.requests_rx,
993 Duration::from_millis(50),
994 );
995 let timed_out = time::timeout(Duration::from_millis(5), delayed).await;
996
997 assert!(timed_out.is_err());
998 assert!(matches!(
999 pending_front_request(&eventloop),
1000 Some(Request::PingResp)
1001 ));
1002 }
1003
1004 #[tokio::test]
1005 async fn try_next_request_does_not_throttle_when_pending_queue_is_empty() {
1006 let options = MqttOptions::new("test-client", "localhost");
1007 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
1008 request_tx
1009 .send_async(RequestEnvelope::plain(Request::PingReq))
1010 .await
1011 .unwrap();
1012
1013 let received = time::timeout(
1014 Duration::from_millis(20),
1015 EventLoop::try_next_request(
1016 &mut eventloop.pending,
1017 &eventloop.requests_rx,
1018 Duration::from_secs(1),
1019 ),
1020 )
1021 .await
1022 .unwrap();
1023
1024 assert!(matches!(received, Some((Request::PingReq, None))));
1025 }
1026
1027 #[tokio::test]
1028 async fn next_request_prioritizes_pending_over_channel_messages() {
1029 let options = MqttOptions::new("test-client", "localhost");
1030 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
1031 push_pending(&mut eventloop, Request::PingReq);
1032 request_tx
1033 .send_async(RequestEnvelope::plain(Request::PingReq))
1034 .await
1035 .unwrap();
1036
1037 let first = EventLoop::next_request(
1038 &mut eventloop.pending,
1039 &eventloop.requests_rx,
1040 Duration::ZERO,
1041 )
1042 .await
1043 .unwrap();
1044 assert!(matches!(first, (Request::PingReq, None)));
1045 assert!(eventloop.pending.is_empty());
1046
1047 let second = EventLoop::next_request(
1048 &mut eventloop.pending,
1049 &eventloop.requests_rx,
1050 Duration::ZERO,
1051 )
1052 .await
1053 .unwrap();
1054 assert!(matches!(second, (Request::PingReq, None)));
1055 }
1056
1057 #[tokio::test]
1058 async fn next_request_preserves_fifo_order_for_plain_and_tracked_requests() {
1059 let options = MqttOptions::new("test-client", "localhost");
1060 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
1061 let (notice_tx, _notice) = PublishNoticeTx::new();
1062 let tracked_publish = Publish::new(
1063 "hello/world",
1064 crate::mqttbytes::QoS::AtLeastOnce,
1065 "payload",
1066 None,
1067 );
1068
1069 request_tx
1070 .send_async(RequestEnvelope::plain(Request::PingReq))
1071 .await
1072 .unwrap();
1073 request_tx
1074 .send_async(RequestEnvelope::tracked_publish(
1075 tracked_publish.clone(),
1076 notice_tx,
1077 ))
1078 .await
1079 .unwrap();
1080 request_tx
1081 .send_async(RequestEnvelope::plain(Request::PingResp))
1082 .await
1083 .unwrap();
1084
1085 let first = EventLoop::next_request(
1086 &mut eventloop.pending,
1087 &eventloop.requests_rx,
1088 Duration::ZERO,
1089 )
1090 .await
1091 .unwrap();
1092 assert!(matches!(first, (Request::PingReq, None)));
1093
1094 let second = EventLoop::next_request(
1095 &mut eventloop.pending,
1096 &eventloop.requests_rx,
1097 Duration::ZERO,
1098 )
1099 .await
1100 .unwrap();
1101 assert!(matches!(
1102 second,
1103 (Request::Publish(publish), Some(_)) if publish == tracked_publish
1104 ));
1105
1106 let third = EventLoop::next_request(
1107 &mut eventloop.pending,
1108 &eventloop.requests_rx,
1109 Duration::ZERO,
1110 )
1111 .await
1112 .unwrap();
1113 assert!(matches!(third, (Request::PingResp, None)));
1114 }
1115
1116 #[tokio::test]
1117 async fn tracked_qos0_notice_reports_not_flushed_on_first_write_failure() {
1118 let options = MqttOptions::new("test-client", "localhost");
1119 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
1120 let (client, _peer) = tokio::io::duplex(1024);
1121 let mut network = Network::new(client, Some(1024));
1122 network.set_max_outgoing_size(Some(16));
1123 eventloop.network = Some(network);
1124
1125 let (notice_tx, notice) = PublishNoticeTx::new();
1126 let publish = Publish::new(
1127 "hello/world",
1128 crate::mqttbytes::QoS::AtMostOnce,
1129 vec![1; 128],
1130 None,
1131 );
1132 request_tx
1133 .send_async(RequestEnvelope::tracked_publish(publish, notice_tx))
1134 .await
1135 .unwrap();
1136
1137 let err = eventloop.select().await.unwrap_err();
1138 assert!(matches!(err, ConnectionError::MqttState(_)));
1139 assert_eq!(
1140 notice.wait_async().await.unwrap_err(),
1141 PublishNoticeError::Qos0NotFlushed
1142 );
1143 }
1144
1145 #[tokio::test]
1146 async fn tracked_qos0_notices_report_not_flushed_on_batched_write_failure() {
1147 let mut options = MqttOptions::new("test-client", "localhost");
1148 options.set_max_request_batch(2);
1149 let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
1150 let (client, _peer) = tokio::io::duplex(1024);
1151 let mut network = Network::new(client, Some(1024));
1152 network.set_max_outgoing_size(Some(80));
1153 eventloop.network = Some(network);
1154
1155 let small_publish = Publish::new(
1156 "hello/world",
1157 crate::mqttbytes::QoS::AtMostOnce,
1158 vec![1],
1159 None,
1160 );
1161 let large_publish = Publish::new(
1162 "hello/world",
1163 crate::mqttbytes::QoS::AtMostOnce,
1164 vec![2; 256],
1165 None,
1166 );
1167
1168 let (first_notice_tx, first_notice) = PublishNoticeTx::new();
1169 request_tx
1170 .send_async(RequestEnvelope::tracked_publish(
1171 small_publish,
1172 first_notice_tx,
1173 ))
1174 .await
1175 .unwrap();
1176
1177 let (second_notice_tx, second_notice) = PublishNoticeTx::new();
1178 request_tx
1179 .send_async(RequestEnvelope::tracked_publish(
1180 large_publish,
1181 second_notice_tx,
1182 ))
1183 .await
1184 .unwrap();
1185
1186 let err = eventloop.select().await.unwrap_err();
1187 assert!(matches!(err, ConnectionError::MqttState(_)));
1188 assert_eq!(
1189 first_notice.wait_async().await.unwrap_err(),
1190 PublishNoticeError::Qos0NotFlushed
1191 );
1192 assert_eq!(
1193 second_notice.wait_async().await.unwrap_err(),
1194 PublishNoticeError::Qos0NotFlushed
1195 );
1196 }
1197
1198 #[tokio::test]
1199 async fn drain_pending_as_failed_drains_all_and_returns_count() {
1200 let options = MqttOptions::new("test-client", "localhost");
1201 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1202 let (notice_tx, notice) = PublishNoticeTx::new();
1203 let publish = Publish::new(
1204 "hello/world",
1205 crate::mqttbytes::QoS::AtLeastOnce,
1206 "payload",
1207 None,
1208 );
1209 eventloop
1210 .pending
1211 .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));
1212 eventloop
1213 .pending
1214 .push_back(RequestEnvelope::plain(Request::PingReq));
1215
1216 let drained = eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);
1217
1218 assert_eq!(drained, 2);
1219 assert!(eventloop.pending.is_empty());
1220 assert_eq!(
1221 notice.wait_async().await.unwrap_err(),
1222 PublishNoticeError::SessionReset
1223 );
1224 }
1225
1226 #[tokio::test]
1227 async fn drain_pending_as_failed_reports_session_reset_for_tracked_notices() {
1228 let options = MqttOptions::new("test-client", "localhost");
1229 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1230 let (publish_notice_tx, publish_notice) = PublishNoticeTx::new();
1231 let publish = Publish::new(
1232 "hello/world",
1233 crate::mqttbytes::QoS::AtLeastOnce,
1234 "payload",
1235 None,
1236 );
1237 eventloop
1238 .pending
1239 .push_back(RequestEnvelope::tracked_publish(publish, publish_notice_tx));
1240
1241 let (request_notice_tx, request_notice) = RequestNoticeTx::new();
1242 let subscribe = Subscribe::new(
1243 Filter::new("hello/world", crate::mqttbytes::QoS::AtMostOnce),
1244 None,
1245 );
1246 eventloop
1247 .pending
1248 .push_back(RequestEnvelope::tracked_subscribe(
1249 subscribe,
1250 request_notice_tx,
1251 ));
1252
1253 eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);
1254
1255 assert_eq!(
1256 publish_notice.wait_async().await.unwrap_err(),
1257 PublishNoticeError::SessionReset
1258 );
1259 assert_eq!(
1260 request_notice.wait_async().await.unwrap_err(),
1261 crate::RequestNoticeError::SessionReset
1262 );
1263 }
1264
1265 #[tokio::test]
1266 async fn reset_session_state_reports_session_reset_for_pending_tracked_notice() {
1267 let options = MqttOptions::new("test-client", "localhost");
1268 let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
1269 let (notice_tx, notice) = PublishNoticeTx::new();
1270 let publish = Publish::new(
1271 "hello/world",
1272 crate::mqttbytes::QoS::AtLeastOnce,
1273 "payload",
1274 None,
1275 );
1276 eventloop
1277 .pending
1278 .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));
1279
1280 eventloop.reset_session_state();
1281
1282 assert!(eventloop.pending.is_empty());
1283 assert_eq!(
1284 notice.wait_async().await.unwrap_err(),
1285 PublishNoticeError::SessionReset
1286 );
1287 }
1288
1289 #[test]
1290 fn connack_reconcile_rejects_clean_start_with_session_present() {
1291 let mut eventloop = build_eventloop_with_pending(true);
1292
1293 let err = eventloop.reconcile_connack_session(true).unwrap_err();
1294
1295 assert!(matches!(
1296 err,
1297 ConnectionError::SessionStateMismatch {
1298 clean_start: true,
1299 session_present: true
1300 }
1301 ));
1302 assert_eq!(eventloop.pending_len(), 1);
1303 }
1304
1305 #[test]
1306 fn connack_reconcile_resets_pending_when_clean_start_gets_new_session() {
1307 let mut eventloop = build_eventloop_with_pending(true);
1308
1309 eventloop.reconcile_connack_session(false).unwrap();
1310
1311 assert!(eventloop.pending_is_empty());
1312 }
1313
1314 #[test]
1315 fn connack_reconcile_resets_pending_when_resumed_session_is_missing() {
1316 let mut eventloop = build_eventloop_with_pending(false);
1317
1318 eventloop.reconcile_connack_session(false).unwrap();
1319
1320 assert!(eventloop.pending_is_empty());
1321 }
1322
1323 #[test]
1324 fn connack_reconcile_keeps_pending_when_resumed_session_exists() {
1325 let mut eventloop = build_eventloop_with_pending(false);
1326
1327 eventloop.reconcile_connack_session(true).unwrap();
1328
1329 assert_eq!(eventloop.pending_len(), 1);
1330 }
1331}