rumqttc_dev_patched/
eventloop.rs

1use crate::{framed::Network, Transport};
2use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
3use crate::{MqttOptions, Outgoing};
4
5use crate::framed::AsyncReadWrite;
6use crate::mqttbytes::v4::*;
7use flume::{bounded, Receiver, Sender};
8use tokio::net::{lookup_host, TcpSocket, TcpStream};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep};
11
12use std::collections::VecDeque;
13use std::io;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::time::Duration;
17
18#[cfg(unix)]
19use {std::path::Path, tokio::net::UnixStream};
20
21#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
22use crate::tls;
23
24#[cfg(feature = "websocket")]
25use {
26    crate::websockets::{split_url, validate_response_headers, UrlError},
27    async_tungstenite::tungstenite::client::IntoClientRequest,
28    ws_stream_tungstenite::WsStream,
29};
30
31#[cfg(feature = "proxy")]
32use crate::proxy::ProxyError;
33
34/// Critical errors during eventloop polling
35#[derive(Debug, thiserror::Error)]
36pub enum ConnectionError {
37    #[error("Mqtt state: {0}")]
38    MqttState(#[from] StateError),
39    #[error("Network timeout")]
40    NetworkTimeout,
41    #[error("Flush timeout")]
42    FlushTimeout,
43    #[cfg(feature = "websocket")]
44    #[error("Websocket: {0}")]
45    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
46    #[cfg(feature = "websocket")]
47    #[error("Websocket Connect: {0}")]
48    WsConnect(#[from] http::Error),
49    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
50    #[error("TLS: {0}")]
51    Tls(#[from] tls::Error),
52    #[error("I/O: {0}")]
53    Io(#[from] io::Error),
54    #[error("Connection refused, return code: `{0:?}`")]
55    ConnectionRefused(ConnectReturnCode),
56    #[error("Expected ConnAck packet, received: {0:?}")]
57    NotConnAck(Packet),
58    #[error("Requests done")]
59    RequestsDone,
60    #[cfg(feature = "websocket")]
61    #[error("Invalid Url: {0}")]
62    InvalidUrl(#[from] UrlError),
63    #[cfg(feature = "proxy")]
64    #[error("Proxy Connect: {0}")]
65    Proxy(#[from] ProxyError),
66    #[cfg(feature = "websocket")]
67    #[error("Websocket response validation error: ")]
68    ResponseValidation(#[from] crate::websockets::ValidationError),
69}
70
71/// Eventloop with all the state of a connection
72pub struct EventLoop {
73    /// Options of the current mqtt connection
74    pub mqtt_options: MqttOptions,
75    /// Current state of the connection
76    pub state: MqttState,
77    /// Request stream
78    requests_rx: Receiver<Request>,
79    /// Requests handle to send requests
80    pub(crate) requests_tx: Sender<Request>,
81    /// Pending packets from last session
82    pub pending: VecDeque<Request>,
83    /// Network connection to the broker
84    pub network: Option<Network>,
85    /// Keep alive time
86    keepalive_timeout: Option<Pin<Box<Sleep>>>,
87    pub network_options: NetworkOptions,
88}
89
90/// Events which can be yielded by the event loop
91#[derive(Debug, Clone, PartialEq, Eq)]
92pub enum Event {
93    Incoming(Incoming),
94    Outgoing(Outgoing),
95}
96
97impl EventLoop {
98    /// New MQTT `EventLoop`
99    ///
100    /// When connection encounters critical errors (like auth failure), user has a choice to
101    /// access and update `options`, `state` and `requests`.
102    pub fn new(mqtt_options: MqttOptions, cap: usize) -> EventLoop {
103        let (requests_tx, requests_rx) = bounded(cap);
104        let pending = VecDeque::new();
105        let max_inflight = mqtt_options.inflight;
106        let manual_acks = mqtt_options.manual_acks;
107
108        EventLoop {
109            mqtt_options,
110            state: MqttState::new(max_inflight, manual_acks),
111            requests_tx,
112            requests_rx,
113            pending,
114            network: None,
115            keepalive_timeout: None,
116            network_options: NetworkOptions::new(),
117        }
118    }
119
120    /// Last session might contain packets which aren't acked. MQTT says these packets should be
121    /// republished in the next session. Move pending messages from state to eventloop, drops the
122    /// underlying network connection and clears the keepalive timeout if any.
123    ///
124    /// > NOTE: Use only when EventLoop is blocked on network and unable to immediately handle disconnect.
125    /// > Also, while this helps prevent data loss, the pending list length should be managed properly.
126    /// > For this reason we recommend setting [`AsycClient`](crate::AsyncClient)'s channel capacity to `0`.
127    pub fn clean(&mut self) {
128        self.network = None;
129        self.keepalive_timeout = None;
130        self.pending.extend(self.state.clean());
131
132        // drain requests from channel which weren't yet received
133        let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();
134
135        requests_in_channel.retain(|request| {
136            match request {
137                Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack
138                _ => true,
139            }
140        });
141
142        self.pending.extend(requests_in_channel);
143    }
144
145    /// Yields Next notification or outgoing request and periodically pings
146    /// the broker. Continuing to poll will reconnect to the broker if there is
147    /// a disconnection.
148    /// **NOTE** Don't block this while iterating
149    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
150        if self.network.is_none() {
151            let (network, connack) = match time::timeout(
152                Duration::from_secs(self.network_options.connection_timeout()),
153                connect(&self.mqtt_options, self.network_options.clone()),
154            )
155            .await
156            {
157                Ok(inner) => inner?,
158                Err(_) => return Err(ConnectionError::NetworkTimeout),
159            };
160            // Last session might contain packets which aren't acked. If it's a new session, clear the pending packets.
161            if !connack.session_present {
162                self.pending.clear();
163            }
164            self.network = Some(network);
165
166            if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
167                self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
168            }
169
170            return Ok(Event::Incoming(Packet::ConnAck(connack)));
171        }
172
173        match self.select().await {
174            Ok(v) => Ok(v),
175            Err(e) => {
176                // MQTT requires that packets pending acknowledgement should be republished on session resume.
177                // Move pending messages from state to eventloop.
178                self.clean();
179                Err(e)
180            }
181        }
182    }
183
184    /// Select on network and requests and generate keepalive pings when necessary
185    async fn select(&mut self) -> Result<Event, ConnectionError> {
186        let network = self.network.as_mut().unwrap();
187        // let await_acks = self.state.await_acks;
188        let inflight_full = self.state.inflight >= self.mqtt_options.inflight;
189        let collision = self.state.collision.is_some();
190        let network_timeout = Duration::from_secs(self.network_options.connection_timeout());
191
192        // Read buffered events from previous polls before calling a new poll
193        if let Some(event) = self.state.events.pop_front() {
194            return Ok(event);
195        }
196
197        let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
198        // this loop is necessary since self.incoming.pop_front() might return None. In that case,
199        // instead of returning a None event, we try again.
200        select! {
201            // Pull a bunch of packets from network, reply in bunch and yield the first item
202            o = network.readb(&mut self.state) => {
203                o?;
204                // flush all the acks and return first incoming packet
205                match time::timeout(network_timeout, network.flush()).await {
206                    Ok(inner) => inner?,
207                    Err(_)=> return Err(ConnectionError::FlushTimeout),
208                };
209                Ok(self.state.events.pop_front().unwrap())
210            },
211             // Handles pending and new requests.
212            // If available, prioritises pending requests from previous session.
213            // Else, pulls next request from user requests channel.
214            // If conditions in the below branch are for flow control.
215            // The branch is disabled if there's no pending messages and new user requests
216            // cannot be serviced due flow control.
217            // We read next user user request only when inflight messages are < configured inflight
218            // and there are no collisions while handling previous outgoing requests.
219            //
220            // Flow control is based on ack count. If inflight packet count in the buffer is
221            // less than max_inflight setting, next outgoing request will progress. For this
222            // to work correctly, broker should ack in sequence (a lot of brokers won't)
223            //
224            // E.g If max inflight = 5, user requests will be blocked when inflight queue
225            // looks like this                 -> [1, 2, 3, 4, 5].
226            // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5].
227            // This pulls next user request. But because max packet id = max_inflight, next
228            // user request's packet id will roll to 1. This replaces existing packet id 1.
229            // Resulting in a collision
230            //
231            // Eventloop can stop receiving outgoing user requests when previous outgoing
232            // request collided. I.e collision state. Collision state will be cleared only
233            // when correct ack is received
234            // Full inflight queue will look like -> [1a, 2, 3, 4, 5].
235            // If 3 is acked instead of 1 first   -> [1a, 2, x, 4, 5].
236            // After collision with pkid 1        -> [1b ,2, x, 4, 5].
237            // 1a is saved to state and event loop is set to collision mode stopping new
238            // outgoing requests (along with 1b).
239            o = Self::next_request(
240                &mut self.pending,
241                &self.requests_rx,
242                self.mqtt_options.pending_throttle
243            ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
244                Ok(request) => {
245                    if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
246                        network.write(outgoing).await?;
247                    }
248                    match time::timeout(network_timeout, network.flush()).await {
249                        Ok(inner) => inner?,
250                        Err(_)=> return Err(ConnectionError::FlushTimeout),
251                    };
252                    Ok(self.state.events.pop_front().unwrap())
253                }
254                Err(_) => Err(ConnectionError::RequestsDone),
255            },
256            // We generate pings irrespective of network activity. This keeps the ping logic
257            // simple. We can change this behavior in future if necessary (to prevent extra pings)
258            _ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
259                if self.keepalive_timeout.is_some() && !self.mqtt_options.keep_alive.is_zero() => {
260                let timeout = self.keepalive_timeout.as_mut().unwrap();
261                timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);
262
263                if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? {
264                    network.write(outgoing).await?;
265                }
266                match time::timeout(network_timeout, network.flush()).await {
267                    Ok(inner) => inner?,
268                    Err(_)=> return Err(ConnectionError::FlushTimeout),
269                };
270                Ok(self.state.events.pop_front().unwrap())
271            }
272        }
273    }
274
275    pub fn network_options(&self) -> NetworkOptions {
276        self.network_options.clone()
277    }
278
279    pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
280        self.network_options = network_options;
281        self
282    }
283
284    async fn next_request(
285        pending: &mut VecDeque<Request>,
286        rx: &Receiver<Request>,
287        pending_throttle: Duration,
288    ) -> Result<Request, ConnectionError> {
289        if !pending.is_empty() {
290            time::sleep(pending_throttle).await;
291            // We must call .pop_front() AFTER sleep() otherwise we would have
292            // advanced the iterator but the future might be canceled before return
293            Ok(pending.pop_front().unwrap())
294        } else {
295            match rx.recv_async().await {
296                Ok(r) => Ok(r),
297                Err(_) => Err(ConnectionError::RequestsDone),
298            }
299        }
300    }
301}
302
303/// This stream internally processes requests from the request stream provided to the eventloop
304/// while also consuming byte stream from the network and yielding mqtt packets as the output of
305/// the stream.
306/// This function (for convenience) includes internal delays for users to perform internal sleeps
307/// between re-connections so that cancel semantics can be used during this sleep
308async fn connect(
309    mqtt_options: &MqttOptions,
310    network_options: NetworkOptions,
311) -> Result<(Network, ConnAck), ConnectionError> {
312    // connect to the broker
313    let mut network = network_connect(mqtt_options, network_options).await?;
314
315    // make MQTT connection request (which internally awaits for ack)
316    let connack = mqtt_connect(mqtt_options, &mut network).await?;
317
318    Ok((network, connack))
319}
320
321pub(crate) async fn socket_connect(
322    host: String,
323    network_options: NetworkOptions,
324) -> io::Result<TcpStream> {
325    let addrs = lookup_host(host).await?;
326    let mut last_err = None;
327
328    for addr in addrs {
329        let socket = match addr {
330            SocketAddr::V4(_) => TcpSocket::new_v4()?,
331            SocketAddr::V6(_) => TcpSocket::new_v6()?,
332        };
333
334        socket.set_nodelay(network_options.tcp_nodelay)?;
335
336        if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
337            socket.set_send_buffer_size(send_buff_size).unwrap();
338        }
339        if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
340            socket.set_recv_buffer_size(recv_buffer_size).unwrap();
341        }
342
343        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
344        {
345            if let Some(bind_device) = &network_options.bind_device {
346                // call the bind_device function only if the bind_device network option is defined
347                // If binding device is None or an empty string it removes the binding,
348                // which is causing PermissionDenied errors in AWS environment (lambda function).
349                socket.bind_device(Some(bind_device.as_bytes()))?;
350            }
351        }
352
353        match socket.connect(addr).await {
354            Ok(s) => return Ok(s),
355            Err(e) => {
356                last_err = Some(e);
357            }
358        };
359    }
360
361    Err(last_err.unwrap_or_else(|| {
362        io::Error::new(
363            io::ErrorKind::InvalidInput,
364            "could not resolve to any address",
365        )
366    }))
367}
368
369async fn network_connect(
370    options: &MqttOptions,
371    network_options: NetworkOptions,
372) -> Result<Network, ConnectionError> {
373    // Process Unix files early, as proxy is not supported for them.
374    #[cfg(unix)]
375    if matches!(options.transport(), Transport::Unix) {
376        let file = options.broker_addr.as_str();
377        let socket = UnixStream::connect(Path::new(file)).await?;
378        let network = Network::new(
379            socket,
380            options.max_incoming_packet_size,
381            options.max_outgoing_packet_size,
382        );
383        return Ok(network);
384    }
385
386    // For websockets domain and port are taken directly from `broker_addr` (which is a url).
387    let (domain, port) = match options.transport() {
388        #[cfg(feature = "websocket")]
389        Transport::Ws => split_url(&options.broker_addr)?,
390        #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
391        Transport::Wss(_) => split_url(&options.broker_addr)?,
392        _ => options.broker_address(),
393    };
394
395    let tcp_stream: Box<dyn AsyncReadWrite> = {
396        #[cfg(feature = "proxy")]
397        match options.proxy() {
398            Some(proxy) => proxy.connect(&domain, port, network_options).await?,
399            None => {
400                let addr = format!("{domain}:{port}");
401                let tcp = socket_connect(addr, network_options).await?;
402                Box::new(tcp)
403            }
404        }
405        #[cfg(not(feature = "proxy"))]
406        {
407            let addr = format!("{domain}:{port}");
408            let tcp = socket_connect(addr, network_options).await?;
409            Box::new(tcp)
410        }
411    };
412
413    let network = match options.transport() {
414        Transport::Tcp => Network::new(
415            tcp_stream,
416            options.max_incoming_packet_size,
417            options.max_outgoing_packet_size,
418        ),
419        #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
420        Transport::Tls(tls_config) => {
421            let socket =
422                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
423                    .await?;
424            Network::new(
425                socket,
426                options.max_incoming_packet_size,
427                options.max_outgoing_packet_size,
428            )
429        }
430        #[cfg(unix)]
431        Transport::Unix => unreachable!(),
432        #[cfg(feature = "websocket")]
433        Transport::Ws => {
434            let mut request = options.broker_addr.as_str().into_client_request()?;
435            request
436                .headers_mut()
437                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
438
439            if let Some(request_modifier) = options.request_modifier() {
440                request = request_modifier(request).await;
441            }
442
443            let (socket, response) =
444                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
445            validate_response_headers(response)?;
446
447            Network::new(
448                WsStream::new(socket),
449                options.max_incoming_packet_size,
450                options.max_outgoing_packet_size,
451            )
452        }
453        #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
454        Transport::Wss(tls_config) => {
455            let mut request = options.broker_addr.as_str().into_client_request()?;
456            request
457                .headers_mut()
458                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
459
460            if let Some(request_modifier) = options.request_modifier() {
461                request = request_modifier(request).await;
462            }
463
464            let connector = tls::rustls_connector(&tls_config).await?;
465
466            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
467                request,
468                tcp_stream,
469                Some(connector),
470            )
471            .await?;
472            validate_response_headers(response)?;
473
474            Network::new(
475                WsStream::new(socket),
476                options.max_incoming_packet_size,
477                options.max_outgoing_packet_size,
478            )
479        }
480    };
481
482    Ok(network)
483}
484
485async fn mqtt_connect(
486    options: &MqttOptions,
487    network: &mut Network,
488) -> Result<ConnAck, ConnectionError> {
489    let mut connect = Connect::new(options.client_id());
490    connect.keep_alive = options.keep_alive().as_secs() as u16;
491    connect.clean_session = options.clean_session();
492    connect.last_will = options.last_will();
493    connect.login = options.credentials();
494
495    // send mqtt connect packet
496    network.write(Packet::Connect(connect)).await?;
497    network.flush().await?;
498
499    // validate connack
500    match network.read().await? {
501        Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => Ok(connack),
502        Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
503        packet => Err(ConnectionError::NotConnAck(packet)),
504    }
505}