Skip to main content

rumqttc/
eventloop.rs

1use super::framed::Network;
2use super::mqttbytes::v5::{ConnAck, Connect, Packet, Publish, Subscribe, Unsubscribe};
3use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
4use crate::framed::AsyncReadWrite;
5use crate::notice::{PublishNoticeTx, RequestNoticeTx, TrackedNoticeTx};
6use crate::{NoticeFailureReason, PublishNoticeError};
7
8use flume::{Receiver, Sender, TryRecvError, bounded};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep, error::Elapsed};
11
12use std::collections::VecDeque;
13use std::io;
14use std::pin::Pin;
15use std::time::Duration;
16
17use super::mqttbytes::v5::ConnectReturnCode;
18
19#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
20use crate::tls;
21
22#[cfg(unix)]
23use {std::path::Path, tokio::net::UnixStream};
24
25#[cfg(feature = "websocket")]
26use {
27    crate::websockets::WsAdapter,
28    crate::websockets::{UrlError, split_url, validate_response_headers},
29    async_tungstenite::tungstenite::client::IntoClientRequest,
30};
31
32#[cfg(feature = "proxy")]
33use crate::proxy::ProxyError;
34
/// An outgoing request paired with an optional completion-notice channel.
#[derive(Debug)]
pub struct RequestEnvelope {
    // The MQTT request to be sent to the broker.
    request: Request,
    // Sender used to notify the caller when the request resolves
    // (acknowledged, flushed, or failed); `None` for untracked requests.
    notice: Option<TrackedNoticeTx>,
}
40
41impl RequestEnvelope {
42    pub(crate) const fn from_parts(request: Request, notice: Option<TrackedNoticeTx>) -> Self {
43        Self { request, notice }
44    }
45
46    pub(crate) const fn plain(request: Request) -> Self {
47        Self {
48            request,
49            notice: None,
50        }
51    }
52
53    pub(crate) const fn tracked_publish(publish: Publish, notice: PublishNoticeTx) -> Self {
54        Self {
55            request: Request::Publish(publish),
56            notice: Some(TrackedNoticeTx::Publish(notice)),
57        }
58    }
59
60    pub(crate) const fn tracked_subscribe(subscribe: Subscribe, notice: RequestNoticeTx) -> Self {
61        Self {
62            request: Request::Subscribe(subscribe),
63            notice: Some(TrackedNoticeTx::Request(notice)),
64        }
65    }
66
67    pub(crate) const fn tracked_unsubscribe(
68        unsubscribe: Unsubscribe,
69        notice: RequestNoticeTx,
70    ) -> Self {
71        Self {
72            request: Request::Unsubscribe(unsubscribe),
73            notice: Some(TrackedNoticeTx::Request(notice)),
74        }
75    }
76
77    pub(crate) fn into_parts(self) -> (Request, Option<TrackedNoticeTx>) {
78        (self.request, self.notice)
79    }
80}
81
82/// Critical errors during eventloop polling
83#[derive(Debug, thiserror::Error)]
84pub enum ConnectionError {
85    #[error("Mqtt state: {0}")]
86    MqttState(#[from] StateError),
87    #[error("Timeout")]
88    Timeout(#[from] Elapsed),
89    #[cfg(feature = "websocket")]
90    #[error("Websocket: {0}")]
91    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
92    #[cfg(feature = "websocket")]
93    #[error("Websocket Connect: {0}")]
94    WsConnect(#[from] http::Error),
95    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
96    #[error("TLS: {0}")]
97    Tls(#[from] tls::Error),
98    #[error("I/O: {0}")]
99    Io(#[from] io::Error),
100    #[error("Connection refused, return code: `{0:?}`")]
101    ConnectionRefused(ConnectReturnCode),
102    #[error("Expected ConnAck packet, received: {0:?}")]
103    NotConnAck(Box<Packet>),
104    #[error("Broker replied with session_present={session_present} for clean_start={clean_start}")]
105    SessionStateMismatch {
106        clean_start: bool,
107        session_present: bool,
108    },
109    #[error("Broker target is incompatible with the selected transport")]
110    BrokerTransportMismatch,
111    /// All request senders have been dropped. Use `AsyncClient::disconnect` for MQTT-level
112    /// graceful shutdown with a DISCONNECT packet.
113    #[error("Requests done")]
114    RequestsDone,
115    #[error("Auth processing error")]
116    AuthProcessingError,
117    #[cfg(feature = "websocket")]
118    #[error("Invalid Url: {0}")]
119    InvalidUrl(#[from] UrlError),
120    #[cfg(feature = "proxy")]
121    #[error("Proxy Connect: {0}")]
122    Proxy(#[from] ProxyError),
123    #[cfg(feature = "websocket")]
124    #[error("Websocket response validation error: ")]
125    ResponseValidation(#[from] crate::websockets::ValidationError),
126    #[cfg(feature = "websocket")]
127    #[error("Websocket request modifier failed: {0}")]
128    RequestModifier(#[source] Box<dyn std::error::Error + Send + Sync>),
129}
130
/// Eventloop with all the state of a connection
pub struct EventLoop {
    /// Options of the current mqtt connection
    pub options: MqttOptions,
    /// Current state of the connection
    pub state: MqttState,
    /// Request stream
    requests_rx: Receiver<RequestEnvelope>,
    /// Internal request sender retained for compatibility in `EventLoop::new`.
    /// This is intentionally dropped in the `AsyncClient::new` constructor path.
    _requests_tx: Option<Sender<RequestEnvelope>>,
    /// Pending packets from last session, queued for retransmission
    pending: VecDeque<RequestEnvelope>,
    /// Network connection to the broker; `None` while disconnected
    network: Option<Network>,
    /// Keep alive time; armed lazily on connect, `None` while disconnected
    /// or when keep alive is disabled
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
}
149
/// Events which can be yielded by the event loop
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(clippy::large_enum_variant)]
pub enum Event {
    /// A packet received from the broker.
    Incoming(Incoming),
    /// An outgoing action performed by the event loop.
    Outgoing(Outgoing),
}
157
158impl EventLoop {
159    fn reconcile_connack_session(&mut self, session_present: bool) -> Result<(), ConnectionError> {
160        let clean_start = self.options.clean_start();
161        if clean_start && session_present {
162            return Err(ConnectionError::SessionStateMismatch {
163                clean_start,
164                session_present,
165            });
166        }
167
168        if !session_present {
169            self.reset_session_state();
170        }
171
172        Ok(())
173    }
174
175    /// New MQTT `EventLoop`
176    ///
177    /// When connection encounters critical errors (like auth failure), user has a choice to
178    /// access and update `options`, `state` and `requests`.
179    pub fn new(options: MqttOptions, cap: usize) -> Self {
180        let (requests_tx, requests_rx) = bounded(cap);
181        Self::with_channel(options, requests_rx, Some(requests_tx))
182    }
183
184    /// Internal constructor used by `AsyncClient::new`.
185    ///
186    /// Unlike `EventLoop::new`, this does not keep an internal sender handle, so dropping all
187    /// `AsyncClient` handles can terminate polling with `ConnectionError::RequestsDone`.
188    pub(crate) fn new_for_async_client(
189        options: MqttOptions,
190        cap: usize,
191    ) -> (Self, Sender<RequestEnvelope>) {
192        let (requests_tx, requests_rx) = bounded(cap);
193        let eventloop = Self::with_channel(options, requests_rx, None);
194        (eventloop, requests_tx)
195    }
196
197    fn with_channel(
198        options: MqttOptions,
199        requests_rx: Receiver<RequestEnvelope>,
200        requests_tx: Option<Sender<RequestEnvelope>>,
201    ) -> Self {
202        let pending = VecDeque::new();
203        let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
204        let manual_acks = options.manual_acks;
205
206        let auth_manager = options.auth_manager();
207
208        Self {
209            options,
210            state: MqttState::new(inflight_limit, manual_acks, auth_manager),
211            requests_rx,
212            _requests_tx: requests_tx,
213            pending,
214            network: None,
215            keepalive_timeout: None,
216        }
217    }
218
219    /// Last session might contain packets which aren't acked. MQTT says these packets should be
220    /// republished in the next session. Move pending messages from state to eventloop, drops the
221    /// underlying network connection and clears the keepalive timeout if any.
222    ///
223    /// > NOTE: Use only when EventLoop is blocked on network and unable to immediately handle disconnect.
224    /// > Pending requests are managed internally by the event loop.
225    /// > Use [`pending_len`](Self::pending_len) or [`pending_is_empty`](Self::pending_is_empty)
226    /// > for observation-only checks.
227    pub fn clean(&mut self) {
228        self.network = None;
229        self.keepalive_timeout = None;
230        for (request, notice) in self.state.clean_with_notices() {
231            self.pending
232                .push_back(RequestEnvelope::from_parts(request, notice));
233        }
234
235        // drain requests from channel which weren't yet received
236        for envelope in self.requests_rx.drain() {
237            // Wait for publish retransmission, else the broker could be confused by an unexpected
238            // inbound acknowledgment replayed from a previous connection.
239            if !matches!(&envelope.request, Request::PubAck(_) | Request::PubRec(_)) {
240                self.pending.push_back(envelope);
241            }
242        }
243    }
244
    /// Number of pending requests queued for retransmission.
    ///
    /// Observation-only; the queue itself is managed internally by the event loop.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }
249
    /// Returns true when there are no pending requests queued for retransmission.
    ///
    /// Observation-only counterpart of [`pending_len`](Self::pending_len).
    pub fn pending_is_empty(&self) -> bool {
        self.pending.is_empty()
    }
254
255    /// Drains pending retransmission queue and fails tracked notices with the given reason.
256    ///
257    /// Returns the number of pending requests removed from the queue.
258    pub fn drain_pending_as_failed(&mut self, reason: NoticeFailureReason) -> usize {
259        let mut drained = 0;
260        for envelope in self.pending.drain(..) {
261            drained += 1;
262            if let Some(notice) = envelope.notice {
263                match notice {
264                    TrackedNoticeTx::Publish(notice) => {
265                        notice.error(reason.publish_error());
266                    }
267                    TrackedNoticeTx::Request(notice) => {
268                        notice.error(reason.request_error());
269                    }
270                }
271            }
272        }
273
274        drained
275    }
276
277    /// Clears eventloop and state tracking bound to a previous session.
278    pub fn reset_session_state(&mut self) {
279        self.drain_pending_as_failed(NoticeFailureReason::SessionReset);
280        self.state.fail_pending_notices();
281    }
282
283    fn reconcile_outgoing_tracking_after_connack(&mut self) {
284        self.state
285            .reconcile_outgoing_tracking_capacity(self.pending.is_empty());
286    }
287
    /// Yields Next notification or outgoing request and periodically pings
    /// the broker. Continuing to poll will reconnect to the broker if there is
    /// a disconnection.
    /// **NOTE** Don't block this while iterating
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectionError`] if connecting, reading, writing, or
    /// protocol handling fails.
    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
        if self.network.is_none() {
            // (Re)connect: transport + MQTT CONNECT/CONNACK handshake, bounded
            // by the configured connect timeout.
            let (network, connack) = time::timeout(
                self.options.connect_timeout(),
                connect(&mut self.options, &mut self.state),
            )
            .await??;
            // Validate session_present against clean_start and drop stale local
            // session tracking when the broker kept no session.
            self.reconcile_connack_session(connack.session_present)?;
            self.network = Some(network);

            // Arm the keepalive timer lazily; keep_alive == 0 disables pings.
            if self.keepalive_timeout.is_none() && !self.options.keep_alive.is_zero() {
                self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
            }

            // Feed the CONNACK through the state machine so it is yielded as an
            // event and updates connection bookkeeping.
            self.state
                .handle_incoming_packet(Incoming::ConnAck(connack))?;
            self.reconcile_outgoing_tracking_after_connack();
        }

        match self.select().await {
            Ok(v) => Ok(v),
            Err(e) => {
                // MQTT requires that packets pending acknowledgement should be republished on session resume.
                // Move pending messages from state to eventloop.
                self.clean();
                Err(e)
            }
        }
    }
326
    /// Select on network and requests and generate keepalive pings when necessary
    #[allow(clippy::too_many_lines)]
    async fn select(&mut self) -> Result<Event, ConnectionError> {
        let read_batch_size = self.effective_read_batch_size();
        // `poll` guarantees the network is connected before calling `select`.
        let network = self.network.as_mut().unwrap();

        let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
        let collision = self.state.collision.is_some();

        // Read buffered events from previous polls before calling a new poll
        if let Some(event) = self.state.events.pop_front() {
            return Ok(event);
        }

        // Fallback timer polled by the keepalive branch when no keepalive
        // timeout is armed; that branch's precondition keeps it from actually
        // firing a ping in that case.
        let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
        select! {
            // Handles pending and new requests.
            // If available, prioritises pending requests from previous session.
            // Else, pulls next request from user requests channel.
            // The `if` conditions on the branch below implement flow control.
            // The branch is disabled if there's no pending messages and new user requests
            // cannot be serviced due to flow control.
            // We read the next user request only when inflight messages are < configured inflight
            // and there are no collisions while handling previous outgoing requests.
            //
            // Flow control is based on ack count. If inflight packet count in the buffer is
            // less than max_inflight setting, next outgoing request will progress. For this
            // to work correctly, broker should ack in sequence (a lot of brokers won't)
            //
            // E.g If max inflight = 5, user requests will be blocked when inflight queue
            // looks like this                 -> [1, 2, 3, 4, 5].
            // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5].
            // This pulls next user request. But because max packet id = max_inflight, next
            // user request's packet id will roll to 1. This replaces existing packet id 1.
            // Resulting in a collision
            //
            // Eventloop can stop receiving outgoing user requests when previous outgoing
            // request collided. I.e collision state. Collision state will be cleared only
            // when correct ack is received
            // Full inflight queue will look like -> [1a, 2, 3, 4, 5].
            // If 3 is acked instead of 1 first   -> [1a, 2, x, 4, 5].
            // After collision with pkid 1        -> [1b ,2, x, 4, 5].
            // 1a is saved to state and event loop is set to collision mode stopping new
            // outgoing requests (along with 1b).
            o = Self::next_request(
                &mut self.pending,
                &self.requests_rx,
                self.options.pending_throttle
            ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
                Ok((request, notice)) => {
                    let max_request_batch = self.options.max_request_batch.max(1);
                    let mut should_flush = false;
                    // QoS 0 publish notices resolve on flush; collect them so
                    // they can be completed or failed together below.
                    let mut qos0_notices = Vec::new();

                    let (outgoing, flush_notice) =
                        self.state.handle_outgoing_packet_with_notice(request, notice)?;
                    if let Some(notice) = flush_notice {
                        qos0_notices.push(notice);
                    }
                    if let Some(outgoing) = outgoing {
                        if let Err(err) = network.write(outgoing).await {
                            // Write failed: nothing buffered will be flushed.
                            for notice in qos0_notices {
                                notice.error(PublishNoticeError::Qos0NotFlushed);
                            }
                            return Err(ConnectionError::MqttState(err));
                        }
                        should_flush = true;
                    }

                    // Opportunistically batch up to max_request_batch requests
                    // behind a single flush, re-checking flow control each time.
                    for _ in 1..max_request_batch {
                        let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
                        let collision = self.state.collision.is_some();

                        if self.pending.is_empty() && (inflight_full || collision) {
                            break;
                        }

                        let Some((next_request, next_notice)) = Self::try_next_request(
                            &mut self.pending,
                            &self.requests_rx,
                            self.options.pending_throttle,
                        ).await else {
                            break;
                        };

                        let (outgoing, flush_notice) = self
                            .state
                            .handle_outgoing_packet_with_notice(next_request, next_notice)?;
                        if let Some(notice) = flush_notice {
                            qos0_notices.push(notice);
                        }
                        if let Some(outgoing) = outgoing {
                            if let Err(err) = network.write(outgoing).await {
                                for notice in qos0_notices {
                                    notice.error(PublishNoticeError::Qos0NotFlushed);
                                }
                                return Err(ConnectionError::MqttState(err));
                            }
                            should_flush = true;
                        }
                    }

                    if should_flush {
                        match network.flush().await {
                            Ok(()) => {
                                // Bytes are on the wire: QoS 0 publishes done.
                                for notice in qos0_notices {
                                    notice.success();
                                }
                            }
                            Err(err) => {
                                for notice in qos0_notices {
                                    notice.error(PublishNoticeError::Qos0NotFlushed);
                                }
                                return Err(ConnectionError::MqttState(err));
                            }
                        }
                    }
                    Ok(self.state.events.pop_front().unwrap())
                }
                Err(_) => Err(ConnectionError::RequestsDone),
            },
            // Pull a bunch of packets from network, reply in bunch and yield the first item
            o = network.readb(&mut self.state, read_batch_size) => {
                o?;
                // flush all the acks and return first incoming packet
                network.flush().await?;
                Ok(self.state.events.pop_front().unwrap())
            },
            // We generate pings irrespective of network activity. This keeps the ping logic
            // simple. We can change this behavior in future if necessary (to prevent extra pings)
            () = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
                if self.keepalive_timeout.is_some() && !self.options.keep_alive.is_zero() => {
                    let timeout = self.keepalive_timeout.as_mut().unwrap();
                    timeout.as_mut().reset(Instant::now() + self.options.keep_alive);

                    let (outgoing, _flush_notice) = self
                        .state
                        .handle_outgoing_packet_with_notice(Request::PingReq, None)?;
                    if let Some(outgoing) = outgoing {
                        network.write(outgoing).await?;
                    }
                    network.flush().await?;
                    Ok(self.state.events.pop_front().unwrap())
            }
        }
    }
476
477    async fn try_next_request(
478        pending: &mut VecDeque<RequestEnvelope>,
479        rx: &Receiver<RequestEnvelope>,
480        pending_throttle: Duration,
481    ) -> Option<(Request, Option<TrackedNoticeTx>)> {
482        if !pending.is_empty() {
483            if pending_throttle.is_zero() {
484                tokio::task::yield_now().await;
485            } else {
486                time::sleep(pending_throttle).await;
487            }
488            // We must call .next() AFTER sleep() otherwise .next() would
489            // advance the iterator but the future might be canceled before return
490            return pending.pop_front().map(RequestEnvelope::into_parts);
491        }
492
493        match rx.try_recv() {
494            Ok(envelope) => return Some(envelope.into_parts()),
495            Err(TryRecvError::Disconnected) => return None,
496            Err(TryRecvError::Empty) => {}
497        }
498
499        None
500    }
501
502    async fn next_request(
503        pending: &mut VecDeque<RequestEnvelope>,
504        rx: &Receiver<RequestEnvelope>,
505        pending_throttle: Duration,
506    ) -> Result<(Request, Option<TrackedNoticeTx>), ConnectionError> {
507        if pending.is_empty() {
508            rx.recv_async()
509                .await
510                .map(RequestEnvelope::into_parts)
511                .map_err(|_| ConnectionError::RequestsDone)
512        } else {
513            if pending_throttle.is_zero() {
514                tokio::task::yield_now().await;
515            } else {
516                time::sleep(pending_throttle).await;
517            }
518            // We must call .next() AFTER sleep() otherwise .next() would
519            // advance the iterator but the future might be canceled before return
520            Ok(pending.pop_front().unwrap().into_parts())
521        }
522    }
523
524    fn effective_read_batch_size(&self) -> usize {
525        const MAX_READ_BATCH_SIZE: usize = 128;
526        const PENDING_FAIRNESS_CAP: usize = 16;
527
528        let configured = self.options.read_batch_size();
529        if configured > 0 {
530            return configured.clamp(1, MAX_READ_BATCH_SIZE);
531        }
532
533        let request_batch = self.options.max_request_batch().max(1);
534        let inflight = usize::from(self.state.max_outgoing_inflight);
535        let mut adaptive = request_batch.max(inflight / 2).max(8);
536
537        if !self.pending.is_empty() || !self.requests_rx.is_empty() {
538            adaptive = adaptive.min(PENDING_FAIRNESS_CAP);
539        }
540
541        adaptive.clamp(1, MAX_READ_BATCH_SIZE)
542    }
543}
544
545/// This stream internally processes requests from the request stream provided to the eventloop
546/// while also consuming byte stream from the network and yielding mqtt packets as the output of
547/// the stream.
548/// This function (for convenience) includes internal delays for users to perform internal sleeps
549/// between re-connections so that cancel semantics can be used during this sleep
550async fn connect(
551    options: &mut MqttOptions,
552    state: &mut MqttState,
553) -> Result<(Network, ConnAck), ConnectionError> {
554    // connect to the broker
555    let mut network = network_connect(options).await?;
556
557    // make MQTT connection request (which internally awaits for ack)
558    let connack = mqtt_connect(options, &mut network, state).await?;
559
560    Ok((network, connack))
561}
562
/// Establishes the transport-level connection to the broker (TCP, Unix socket,
/// TLS, or websocket — optionally through a proxy when that feature is enabled)
/// and wraps it in a [`Network`] configured with the max incoming packet size.
#[allow(clippy::too_many_lines)]
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let max_incoming_pkt_size = options.max_incoming_packet_size();
    let transport = options.transport();

    // Process Unix files early, as proxy is not supported for them.
    #[cfg(unix)]
    if matches!(&transport, Transport::Unix) {
        let file = options
            .broker()
            .unix_path()
            .ok_or(ConnectionError::BrokerTransportMismatch)?;
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, max_incoming_pkt_size);
        return Ok(network);
    }

    // For websockets domain and port are taken directly from the broker URL.
    let (domain, port) = match &transport {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(
            options
                .broker()
                .websocket_url()
                .ok_or(ConnectionError::BrokerTransportMismatch)?,
        )?,
        #[cfg(all(
            any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
            feature = "websocket"
        ))]
        Transport::Wss(_) => split_url(
            options
                .broker()
                .websocket_url()
                .ok_or(ConnectionError::BrokerTransportMismatch)?,
        )?,
        _ => options
            .broker()
            .tcp_address()
            .map(|(host, port)| (host.to_owned(), port))
            .ok_or(ConnectionError::BrokerTransportMismatch)?,
    };

    // Establish the raw byte stream, via the proxy when one is configured.
    let tcp_stream: Box<dyn AsyncReadWrite> = {
        #[cfg(feature = "proxy")]
        if let Some(proxy) = options.proxy() {
            proxy
                .connect(
                    &domain,
                    port,
                    options.network_options(),
                    Some(options.effective_socket_connector()),
                )
                .await?
        } else {
            let addr = format!("{domain}:{port}");
            options
                .socket_connect(addr, options.network_options())
                .await?
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            options
                .socket_connect(addr, options.network_options())
                .await?
        }
    };

    // Layer transport-specific framing (TLS and/or websocket) on top of the
    // raw stream.
    let network = match transport {
        Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
        #[cfg(any(feature = "use-native-tls", feature = "use-rustls-no-provider"))]
        Transport::Tls(tls_config) => {
            let (host, port) = options
                .broker()
                .tcp_address()
                .expect("tls transport requires a tcp broker");
            let socket = tls::tls_connect(host, port, &tls_config, tcp_stream).await?;
            Network::new(socket, max_incoming_pkt_size)
        }
        // Unix transport was fully handled (and returned) at the top.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let mut request = options
                .broker()
                .websocket_url()
                .expect("ws transport requires a websocket broker")
                .into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // A fallible request modifier takes precedence over the infallible one.
            if let Some(request_modifier) = options.fallible_request_modifier() {
                request = request_modifier(request)
                    .await
                    .map_err(ConnectionError::RequestModifier)?;
            } else if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
        }
        #[cfg(all(
            any(feature = "use-rustls-no-provider", feature = "use-native-tls"),
            feature = "websocket"
        ))]
        Transport::Wss(tls_config) => {
            let mut request = options
                .broker()
                .websocket_url()
                .expect("wss transport requires a websocket broker")
                .into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // A fallible request modifier takes precedence over the infallible one.
            if let Some(request_modifier) = options.fallible_request_modifier() {
                request = request_modifier(request)
                    .await
                    .map_err(ConnectionError::RequestModifier)?;
            } else if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            // TLS first, then the websocket handshake over the TLS stream.
            let tls_stream = tls::tls_connect(&domain, port, &tls_config, tcp_stream).await?;
            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tls_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsAdapter::new(socket), max_incoming_pkt_size)
        }
    };

    Ok(network)
}
703
/// Performs the MQTT-level handshake: sends CONNECT, handles any enhanced
/// authentication (AUTH) exchange, and returns the broker's successful CONNACK.
async fn mqtt_connect(
    options: &mut MqttOptions,
    network: &mut Network,
    state: &mut MqttState,
) -> Result<ConnAck, ConnectionError> {
    let packet = Packet::Connect(
        Connect {
            client_id: options.client_id(),
            // MQTT keep alive is a u16 in seconds; saturate instead of failing.
            keep_alive: u16::try_from(options.keep_alive().as_secs()).unwrap_or(u16::MAX),
            clean_start: options.clean_start(),
            properties: options.connect_properties(),
        },
        options.last_will(),
        options.auth().clone(),
    );

    // send mqtt connect packet
    network.write(packet).await?;
    network.flush().await?;

    // validate connack
    loop {
        match network.read().await? {
            Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
                // A server keep alive in the CONNACK overrides the local value.
                if let Some(props) = &connack.properties
                    && let Some(keep_alive) = props.server_keep_alive
                {
                    options.keep_alive = Duration::from_secs(u64::from(keep_alive));
                }

                if let Some(props) = &connack.properties {
                    // Respect the broker's maximum packet size for outgoing writes.
                    network.set_max_outgoing_size(props.max_packet_size);

                    // Override local session_expiry_interval value if set by server.
                    if props.session_expiry_interval.is_some() {
                        options.set_session_expiry_interval(props.session_expiry_interval);
                    }
                }
                return Ok(connack);
            }
            Incoming::ConnAck(connack) => {
                return Err(ConnectionError::ConnectionRefused(connack.code));
            }
            Incoming::Auth(auth) => {
                // Enhanced auth round-trip: the state machine produces the next
                // AUTH packet; no response means authentication cannot proceed.
                if let Some(outgoing) = state.handle_incoming_packet(Incoming::Auth(auth))? {
                    network.write(outgoing).await?;
                    network.flush().await?;
                } else {
                    return Err(ConnectionError::AuthProcessingError);
                }
            }
            packet => return Err(ConnectionError::NotConnAck(Box::new(packet))),
        }
    }
}
759
// Unit tests for the event loop's request plumbing: envelope queueing and
// FIFO draining, pending-queue cancellation safety and throttling, CONNACK
// session reconciliation, inflight-window resizing, and failure-notice
// delivery for tracked publishes/subscribes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ConnAckProperties, Filter, PubAck, PubRec};
    use flume::TryRecvError;

    /// Builds a successful CONNACK whose only populated property is
    /// `receive_max`; used to drive inflight-window resizing in tests.
    fn build_connack_with_receive_max(receive_max: u16) -> ConnAck {
        ConnAck {
            session_present: false,
            code: ConnectReturnCode::Success,
            properties: Some(ConnAckProperties {
                session_expiry_interval: None,
                receive_max: Some(receive_max),
                max_qos: None,
                retain_available: None,
                max_packet_size: None,
                assigned_client_identifier: None,
                topic_alias_max: None,
                reason_string: None,
                user_properties: vec![],
                wildcard_subscription_available: None,
                subscription_identifiers_available: None,
                shared_subscription_available: None,
                server_keep_alive: None,
                response_information: None,
                server_reference: None,
                authentication_method: None,
                authentication_data: None,
            }),
        }
    }

    /// Pushes an untracked (no-notice) request onto the event loop's
    /// pending retransmit queue.
    fn push_pending(eventloop: &mut EventLoop, request: Request) {
        eventloop.pending.push_back(RequestEnvelope::plain(request));
    }

    /// Peeks at the request at the front of the pending queue, if any,
    /// without consuming its envelope.
    fn pending_front_request(eventloop: &EventLoop) -> Option<&Request> {
        eventloop.pending.front().map(|envelope| &envelope.request)
    }

    /// Builds an event loop with the given `clean_start` flag and exactly
    /// one pending `PingReq` queued, for session-reconciliation tests.
    fn build_eventloop_with_pending(clean_start: bool) -> EventLoop {
        let mut options = MqttOptions::new("test-client", "localhost");
        options.set_clean_start(clean_start);

        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        push_pending(&mut eventloop, Request::PingReq);
        eventloop
    }

    // `EventLoop::new` keeps an internal sender, so the request channel
    // reads as empty rather than disconnected.
    #[test]
    fn eventloop_new_keeps_internal_sender_alive() {
        let options = MqttOptions::new("test-client", "localhost");
        let eventloop = EventLoop::new(options, 1);

        assert!(matches!(
            eventloop.requests_rx.try_recv(),
            Err(TryRecvError::Empty)
        ));
    }

    // The async-client constructor hands the only sender to the caller, so
    // dropping it disconnects the request channel.
    #[test]
    fn async_client_constructor_path_allows_channel_shutdown() {
        let options = MqttOptions::new("test-client", "localhost");
        let (eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
        drop(request_tx);

        assert!(matches!(
            eventloop.requests_rx.try_recv(),
            Err(TryRecvError::Disconnected)
        ));
    }

    // `clean()` drains the request channel into `pending`, discarding
    // PubAck/PubRec acknowledgements but keeping other requests.
    #[test]
    fn clean_drops_ack_requests_drained_from_channel() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 3);
        request_tx
            .send(RequestEnvelope::plain(Request::PubAck(PubAck::new(
                7, None,
            ))))
            .unwrap();
        request_tx
            .send(RequestEnvelope::plain(Request::PubRec(PubRec::new(
                8, None,
            ))))
            .unwrap();
        request_tx
            .send(RequestEnvelope::plain(Request::PingReq))
            .unwrap();

        eventloop.clean();

        // Only the PingReq survives; both acks were dropped.
        assert_eq!(eventloop.pending_len(), 1);
        assert!(matches!(
            pending_front_request(&eventloop),
            Some(Request::PingReq)
        ));
    }

    // A unix-socket broker address combined with a TCP transport must be
    // rejected before any connection attempt.
    #[tokio::test]
    #[cfg(unix)]
    async fn network_connect_rejects_unix_broker_with_tcp_transport() {
        let mut options = MqttOptions::new("test-client", crate::Broker::unix("/tmp/mqtt.sock"));
        options.set_transport(Transport::tcp());

        match network_connect(&options).await {
            Err(ConnectionError::BrokerTransportMismatch) => {}
            Err(err) => panic!("unexpected error: {err:?}"),
            Ok(_) => panic!("mismatched broker and transport should fail"),
        }
    }

    // A plain host broker with a websocket transport is a mismatch.
    #[tokio::test]
    #[cfg(feature = "websocket")]
    async fn network_connect_rejects_tcp_broker_with_websocket_transport() {
        let mut options = MqttOptions::new("test-client", "localhost");
        options.set_transport(Transport::Ws);

        match network_connect(&options).await {
            Err(ConnectionError::BrokerTransportMismatch) => {}
            Err(err) => panic!("unexpected error: {err:?}"),
            Ok(_) => panic!("mismatched broker and transport should fail"),
        }
    }

    // The inverse mismatch: websocket broker URL with a TCP transport.
    #[tokio::test]
    #[cfg(feature = "websocket")]
    async fn network_connect_rejects_websocket_broker_with_tcp_transport() {
        let broker = crate::Broker::websocket("ws://localhost:9001/mqtt").unwrap();
        let mut options = MqttOptions::new("test-client", broker);
        options.set_transport(Transport::tcp());

        match network_connect(&options).await {
            Err(ConnectionError::BrokerTransportMismatch) => {}
            Err(err) => panic!("unexpected error: {err:?}"),
            Ok(_) => panic!("mismatched broker and transport should fail"),
        }
    }

    // After a CONNACK lowers receive_max, the outgoing tracking buffers are
    // only shrunk once the pending retransmit queue is empty, so a queued
    // publish with a high pkid (8) is never orphaned by a premature resize.
    #[test]
    fn connack_resize_skips_shrink_until_pending_retransmit_queue_is_empty() {
        let mut options = MqttOptions::new("test-client", "localhost");
        options.set_outgoing_inflight_upper_limit(10);
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        let mut publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtLeastOnce,
            "payload",
            None,
        );
        publish.pkid = 8;
        push_pending(&mut eventloop, Request::Publish(publish));

        // Server advertises receive_max = 3, below the local limit of 10.
        eventloop
            .state
            .handle_incoming_packet(Incoming::ConnAck(build_connack_with_receive_max(3)))
            .unwrap();

        // Pending retransmit present: shrink is skipped.
        // NOTE(review): 11 presumably covers pkids 0..=10 for the configured
        // upper limit of 10 — confirm against MqttState's buffer layout.
        eventloop.reconcile_outgoing_tracking_after_connack();
        assert_eq!(eventloop.state.outgoing_pub.len(), 11);

        // Pending drained: shrink now applies across all tracking buffers.
        // NOTE(review): 4 presumably covers pkids 0..=3 for receive_max = 3.
        eventloop.pending.clear();
        eventloop.reconcile_outgoing_tracking_after_connack();
        assert_eq!(eventloop.state.outgoing_pub.len(), 4);
        assert_eq!(eventloop.state.outgoing_pub_notice.len(), 4);
        assert_eq!(eventloop.state.outgoing_rel_notice.len(), 4);
    }

    // Once the pending queue is drained and the channel is disconnected,
    // `next_request` reports `RequestsDone` instead of hanging.
    #[tokio::test]
    async fn async_client_path_reports_requests_done_after_pending_drain() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
        push_pending(&mut eventloop, Request::PingReq);
        drop(request_tx);

        let request = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(request, (Request::PingReq, None)));

        let err = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap_err();
        assert!(matches!(err, ConnectionError::RequestsDone));
    }

    // Dropping the `next_request` future mid-throttle (5 ms timeout vs a
    // 50 ms pending-throttle) must leave the queued request untouched.
    #[tokio::test]
    async fn next_request_is_cancellation_safe_for_pending_queue() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        push_pending(&mut eventloop, Request::PingReq);

        let delayed = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::from_millis(50),
        );
        let timed_out = time::timeout(Duration::from_millis(5), delayed).await;

        assert!(timed_out.is_err());
        assert!(matches!(
            pending_front_request(&eventloop),
            Some(Request::PingReq)
        ));
    }

    // After one pending item is consumed, `try_next_request` throttles the
    // follow-up pending item by the configured delay (it does not complete
    // within a 5 ms window when the throttle is 50 ms).
    #[tokio::test]
    async fn try_next_request_applies_pending_throttle_for_followup_pending_item() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 2);
        push_pending(&mut eventloop, Request::PingReq);
        push_pending(&mut eventloop, Request::PingResp);

        let first = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(first, (Request::PingReq, None)));

        let delayed = EventLoop::try_next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::from_millis(50),
        );
        let timed_out = time::timeout(Duration::from_millis(5), delayed).await;

        assert!(timed_out.is_err());
        assert!(matches!(
            pending_front_request(&eventloop),
            Some(Request::PingResp)
        ));
    }

    // The pending throttle applies only to the pending queue: a channel
    // message is delivered promptly even with a 1 s throttle configured.
    #[tokio::test]
    async fn try_next_request_does_not_throttle_when_pending_queue_is_empty() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 1);
        request_tx
            .send_async(RequestEnvelope::plain(Request::PingReq))
            .await
            .unwrap();

        let received = time::timeout(
            Duration::from_millis(20),
            EventLoop::try_next_request(
                &mut eventloop.pending,
                &eventloop.requests_rx,
                Duration::from_secs(1),
            ),
        )
        .await
        .unwrap();

        assert!(matches!(received, Some((Request::PingReq, None))));
    }

    // Pending (retransmit) items are served before new channel messages.
    #[tokio::test]
    async fn next_request_prioritizes_pending_over_channel_messages() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 2);
        push_pending(&mut eventloop, Request::PingReq);
        request_tx
            .send_async(RequestEnvelope::plain(Request::PingReq))
            .await
            .unwrap();

        let first = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(first, (Request::PingReq, None)));
        assert!(eventloop.pending.is_empty());

        let second = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(second, (Request::PingReq, None)));
    }

    // Tracked (notice-carrying) and plain envelopes share one FIFO: the
    // tracked publish comes out second, with its notice sender attached.
    #[tokio::test]
    async fn next_request_preserves_fifo_order_for_plain_and_tracked_requests() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
        let (notice_tx, _notice) = PublishNoticeTx::new();
        let tracked_publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtLeastOnce,
            "payload",
            None,
        );

        request_tx
            .send_async(RequestEnvelope::plain(Request::PingReq))
            .await
            .unwrap();
        request_tx
            .send_async(RequestEnvelope::tracked_publish(
                tracked_publish.clone(),
                notice_tx,
            ))
            .await
            .unwrap();
        request_tx
            .send_async(RequestEnvelope::plain(Request::PingResp))
            .await
            .unwrap();

        let first = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(first, (Request::PingReq, None)));

        let second = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(
            second,
            (Request::Publish(publish), Some(_)) if publish == tracked_publish
        ));

        let third = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await
        .unwrap();
        assert!(matches!(third, (Request::PingResp, None)));
    }

    // A QoS 0 publish larger than the negotiated max outgoing size (128-byte
    // payload vs a 16-byte cap) fails the write; the tracked notice must
    // resolve with `Qos0NotFlushed` rather than dangle.
    #[tokio::test]
    async fn tracked_qos0_notice_reports_not_flushed_on_first_write_failure() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
        let (client, _peer) = tokio::io::duplex(1024);
        let mut network = Network::new(client, Some(1024));
        network.set_max_outgoing_size(Some(16));
        eventloop.network = Some(network);

        let (notice_tx, notice) = PublishNoticeTx::new();
        let publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtMostOnce,
            vec![1; 128],
            None,
        );
        request_tx
            .send_async(RequestEnvelope::tracked_publish(publish, notice_tx))
            .await
            .unwrap();

        let err = eventloop.select().await.unwrap_err();
        assert!(matches!(err, ConnectionError::MqttState(_)));
        assert_eq!(
            notice.wait_async().await.unwrap_err(),
            PublishNoticeError::Qos0NotFlushed
        );
    }

    // With request batching (batch size 2), a write failure on the second
    // (oversized) publish must fail the notices of every QoS 0 publish in
    // the batch — including the small one that fit on its own.
    #[tokio::test]
    async fn tracked_qos0_notices_report_not_flushed_on_batched_write_failure() {
        let mut options = MqttOptions::new("test-client", "localhost");
        options.set_max_request_batch(2);
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(options, 4);
        let (client, _peer) = tokio::io::duplex(1024);
        let mut network = Network::new(client, Some(1024));
        network.set_max_outgoing_size(Some(80));
        eventloop.network = Some(network);

        let small_publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtMostOnce,
            vec![1],
            None,
        );
        let large_publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtMostOnce,
            vec![2; 256],
            None,
        );

        let (first_notice_tx, first_notice) = PublishNoticeTx::new();
        request_tx
            .send_async(RequestEnvelope::tracked_publish(
                small_publish,
                first_notice_tx,
            ))
            .await
            .unwrap();

        let (second_notice_tx, second_notice) = PublishNoticeTx::new();
        request_tx
            .send_async(RequestEnvelope::tracked_publish(
                large_publish,
                second_notice_tx,
            ))
            .await
            .unwrap();

        let err = eventloop.select().await.unwrap_err();
        assert!(matches!(err, ConnectionError::MqttState(_)));
        assert_eq!(
            first_notice.wait_async().await.unwrap_err(),
            PublishNoticeError::Qos0NotFlushed
        );
        assert_eq!(
            second_notice.wait_async().await.unwrap_err(),
            PublishNoticeError::Qos0NotFlushed
        );
    }

    // `drain_pending_as_failed` empties the whole queue (tracked and plain
    // alike), returns the drained count, and fails tracked notices with the
    // given reason.
    #[tokio::test]
    async fn drain_pending_as_failed_drains_all_and_returns_count() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        let (notice_tx, notice) = PublishNoticeTx::new();
        let publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtLeastOnce,
            "payload",
            None,
        );
        eventloop
            .pending
            .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));
        eventloop
            .pending
            .push_back(RequestEnvelope::plain(Request::PingReq));

        let drained = eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);

        assert_eq!(drained, 2);
        assert!(eventloop.pending.is_empty());
        assert_eq!(
            notice.wait_async().await.unwrap_err(),
            PublishNoticeError::SessionReset
        );
    }

    // Both notice flavours — publish and subscribe — receive a SessionReset
    // failure when the pending queue is drained with that reason.
    #[tokio::test]
    async fn drain_pending_as_failed_reports_session_reset_for_tracked_notices() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        let (publish_notice_tx, publish_notice) = PublishNoticeTx::new();
        let publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtLeastOnce,
            "payload",
            None,
        );
        eventloop
            .pending
            .push_back(RequestEnvelope::tracked_publish(publish, publish_notice_tx));

        let (request_notice_tx, request_notice) = RequestNoticeTx::new();
        let subscribe = Subscribe::new(
            Filter::new("hello/world", crate::mqttbytes::QoS::AtMostOnce),
            None,
        );
        eventloop
            .pending
            .push_back(RequestEnvelope::tracked_subscribe(
                subscribe,
                request_notice_tx,
            ));

        eventloop.drain_pending_as_failed(NoticeFailureReason::SessionReset);

        assert_eq!(
            publish_notice.wait_async().await.unwrap_err(),
            PublishNoticeError::SessionReset
        );
        assert_eq!(
            request_notice.wait_async().await.unwrap_err(),
            crate::RequestNoticeError::SessionReset
        );
    }

    // `reset_session_state` clears the pending queue and fails any tracked
    // notice queued there with `SessionReset`.
    #[tokio::test]
    async fn reset_session_state_reports_session_reset_for_pending_tracked_notice() {
        let options = MqttOptions::new("test-client", "localhost");
        let (mut eventloop, _request_tx) = EventLoop::new_for_async_client(options, 1);
        let (notice_tx, notice) = PublishNoticeTx::new();
        let publish = Publish::new(
            "hello/world",
            crate::mqttbytes::QoS::AtLeastOnce,
            "payload",
            None,
        );
        eventloop
            .pending
            .push_back(RequestEnvelope::tracked_publish(publish, notice_tx));

        eventloop.reset_session_state();

        assert!(eventloop.pending.is_empty());
        assert_eq!(
            notice.wait_async().await.unwrap_err(),
            PublishNoticeError::SessionReset
        );
    }

    // clean_start=true + session_present=true is a protocol violation; the
    // reconcile step must error out and leave the pending queue intact.
    #[test]
    fn connack_reconcile_rejects_clean_start_with_session_present() {
        let mut eventloop = build_eventloop_with_pending(true);

        let err = eventloop.reconcile_connack_session(true).unwrap_err();

        assert!(matches!(
            err,
            ConnectionError::SessionStateMismatch {
                clean_start: true,
                session_present: true
            }
        ));
        assert_eq!(eventloop.pending_len(), 1);
    }

    // clean_start=true + fresh session: pending state is discarded.
    #[test]
    fn connack_reconcile_resets_pending_when_clean_start_gets_new_session() {
        let mut eventloop = build_eventloop_with_pending(true);

        eventloop.reconcile_connack_session(false).unwrap();

        assert!(eventloop.pending_is_empty());
    }

    // clean_start=false but the broker has no session: pending state is
    // also discarded, since there is nothing to resume against.
    #[test]
    fn connack_reconcile_resets_pending_when_resumed_session_is_missing() {
        let mut eventloop = build_eventloop_with_pending(false);

        eventloop.reconcile_connack_session(false).unwrap();

        assert!(eventloop.pending_is_empty());
    }

    // clean_start=false and the broker resumed the session: pending
    // retransmissions are kept.
    #[test]
    fn connack_reconcile_keeps_pending_when_resumed_session_exists() {
        let mut eventloop = build_eventloop_with_pending(false);

        eventloop.reconcile_connack_session(true).unwrap();

        assert_eq!(eventloop.pending_len(), 1);
    }
}