mqtt_async_client/client/
client.rs

1use bytes::BytesMut;
2#[cfg(feature = "websocket")]
3use tokio_tungstenite::tungstenite::http::Uri;
4use crate::{
5    client::{
6        builder::ClientBuilder,
7        value_types::{
8            KeepAlive,
9            Publish,
10            ReadResult,
11            Subscribe,
12            SubscribeResult,
13            Unsubscribe,
14        },
15    },
16    Error,
17    Result,
18    util::{
19        AsyncStream,
20        FreePidList,
21        TokioRuntime,
22    }
23};
24use futures_util::{
25    future::{
26        FutureExt,
27        pending,
28    },
29    select,
30};
31#[cfg(feature = "websocket")]
32use http::request::Request;
33use log::{debug, error, info, trace};
34use mqttrs::{
35    ConnectReturnCode,
36    Packet,
37    Pid,
38    QoS,
39    QosPid,
40    self,
41    SubscribeTopic,
42};
43#[cfg(feature = "tls")]
44use rustls;
45use std::{
46    cmp::min,
47    collections::BTreeMap,
48    fmt,
49    sync::{
50        Arc,
51        atomic::{AtomicBool, Ordering},
52        Mutex,
53    },
54};
55use tokio::{
56    io::{
57        AsyncReadExt,
58        AsyncWriteExt,
59    },
60    net::TcpStream,
61    sync::{
62        mpsc,
63        oneshot,
64    },
65    time::{
66        sleep,
67        sleep_until,
68        Duration,
69        error::Elapsed,
70        Instant,
71        timeout,
72    },
73};
74#[cfg(feature = "tls")]
75use tokio_rustls::{self, webpki::DNSNameRef, TlsConnector};
76use url::Url;
77
78/// An MQTT client.
79///
80/// Start building an instance by calling Client::builder() to get a
81/// ClientBuilder, using the fluent builder pattern on ClientBuilder,
82/// then calling ClientBuilder::build(). For example:
83///
84/// ```
85/// # use mqtt_async_client::client::Client;
86/// let client =
87///     Client::builder()
88///        .set_url_string("mqtt://example.com").unwrap()
89///        .build();
90/// ```
91///
92/// `Client` is expected to be `Send` (passable between threads), but not
93/// `Sync` (usable by multiple threads at the same time).
94pub struct Client {
95    /// Options configured for the client
96    options: ClientOptions,
97
98    /// Handle values to communicate with the IO task
99    io_task_handle: Option<IoTaskHandle>,
100
101    /// Tracks which Pids (MQTT packet IDs) are in use.
102    ///
103    /// This field uses a Mutex for interior mutability so that
104    /// `Client` is `Send`. It's not expected to be `Sync`.
105    free_write_pids: Mutex<FreePidList>,
106}
107
108impl fmt::Debug for Client {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("Client")
111         .field("options", &self.options)
112         .finish()
113    }
114}
115
116#[derive(Clone)]
117pub(crate) struct ClientOptions {
118    // See ClientBuilder methods for per-field documentation.
119    pub(crate) url: Url,
120    pub(crate) username: Option<String>,
121    pub(crate) password: Option<Vec<u8>>,
122    pub(crate) keep_alive: KeepAlive,
123    pub(crate) runtime: TokioRuntime,
124    pub(crate) client_id: Option<String>,
125    pub(crate) packet_buffer_len: usize,
126    pub(crate) max_packet_len: usize,
127    pub(crate) operation_timeout: Duration,
128    pub(crate) connection_mode: ConnectionMode,
129    pub(crate) automatic_connect: bool,
130    pub(crate) connect_retry_delay: Duration,
131}
132
133impl fmt::Debug for ClientOptions {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("ClientOptions")
136         .field("url", &self.url)
137         .field("username", &self.username)
138         // Deliberately skipping password field here to
139         // avoid accidentially leaking it
140         .field("keep_alive", &self.keep_alive)
141         .field("client_id", &self.client_id)
142         .field("packet_buffer_len", &self.packet_buffer_len)
143         .field("max_packet_len", &self.max_packet_len)
144         .field("operation_timeout", &self.operation_timeout)
145         .field("automatic_connect", &self.automatic_connect)
146         .field("connect_retry_delay", &self.connect_retry_delay)
147         .finish()
148    }
149}
150
151/// The client side of the communication channels to an IO task.
152struct IoTaskHandle {
153    /// Sender to send IO requests to the IO task.
154    tx_io_requests: mpsc::Sender<IoRequest>,
155
156    /// Receiver to receive Publish packets from the IO task.
157    rx_recv_published: mpsc::Receiver<Result<Packet>>,
158
159    /// Signal to the IO task to shutdown. Shared with IoTask.
160    halt: Arc<AtomicBool>,
161}
162
163/// The state held by the IO task, a long-running tokio future. The IO
164/// task manages the underlying TCP connection, sends periodic
165/// keep-alive ping packets, and sends response packets to tasks that
166/// are waiting.
167struct IoTask {
168    /// Options configured for the client.
169    options: ClientOptions,
170
171    /// Receiver to receive IO requests for the IO task.
172    rx_io_requests: mpsc::Receiver<IoRequest>,
173
174    /// Sender to send Publish packets from the IO task.
175    tx_recv_published: mpsc::Sender<Result<Packet>>,
176
177    /// enum value describing the current state as disconnected or connected.
178    state: IoTaskState,
179
180    /// Keeps track of active subscriptions in case they need to be
181    /// replayed after reconnecting.
182    subscriptions: BTreeMap<String, QoS>,
183
184    /// Signal to the IO task to shutdown. Shared with IoTaskHandle.
185    halt: Arc<AtomicBool>,
186}
187
188enum IoTaskState {
189    Halted,
190    Disconnected,
191    Connected(IoTaskConnected),
192}
193
194/// The state associated with a network connection to an MQTT broker
195struct IoTaskConnected {
196    /// The stream connected to an MQTT broker.
197    stream: AsyncStream,
198
199    /// A buffer with data read from `stream`.
200    read_buf: BytesMut,
201
202    /// The number of bytes at the start of `read_buf` that have been
203    /// read from `stream`.
204    read_bufn: usize,
205
206    /// The time the last packet was written to `stream`.
207    /// Used to calculate when to send a Pingreq
208    last_write_time: Instant,
209
210    /// The time the last Pingreq packet was written to `stream`.
211    last_pingreq_time: Instant,
212
213    /// The time the last Pingresp packet was read from `stream`.
214    last_pingresp_time: Instant,
215
216    /// A map from response Pid to the IoRequest that initiated the
217    /// request that will be responded to.
218    pid_response_map: BTreeMap<Pid, IoRequest>,
219}
220
221/// An IO request from `Client` to the IO task.
222#[derive(Debug)]
223struct IoRequest {
224    /// A one-shot channel Sender to send the result of the IO request.
225    tx_result: Option<oneshot::Sender<IoResult>>,
226
227    /// Represents the data needed to carry out the IO request.
228    io_type: IoType,
229}
230
231/// The data the IO task needs to carry out an IO request.
232#[derive(Debug)]
233enum IoType {
234    /// A packet to write that expects no response.
235    WriteOnly { packet: Packet },
236
237    /// A packet to write that expects a response with a certain `Pid`.
238    WriteAndResponse { packet: Packet, response_pid: Pid },
239
240    /// A request to shut down the TCP connection gracefully.
241    ShutdownConnection,
242}
243
244/// The result of an IO request sent by the IO task, which may contain a packet.
245#[derive(Debug)]
246struct IoResult {
247    result: Result<Option<Packet>>,
248}
249
250impl Client {
251    /// Start a fluent builder interface to construct a `Client`.
252    pub fn builder() -> ClientBuilder {
253        ClientBuilder::default()
254    }
255
256    pub(crate) fn new(opts: ClientOptions) -> Result<Client> {
257        Ok(Client {
258            options: opts,
259            io_task_handle: None,
260            free_write_pids: Mutex::new(FreePidList::new()),
261        })
262    }
263
264    /// Open a connection to the configured MQTT broker.
265    pub async fn connect(&mut self) -> Result<()> {
266        self.spawn_io_task()?;
267        Ok(())
268    }
269
270    fn spawn_io_task(&mut self) -> Result<()> {
271        self.check_no_io_task()?;
272        let (tx_io_requests, rx_io_requests) =
273            mpsc::channel::<IoRequest>(self.options.packet_buffer_len);
274        // TODO: Change this to allow control messages, e.g. disconnected?
275        let (tx_recv_published, rx_recv_published) =
276            mpsc::channel::<Result<Packet>>(self.options.packet_buffer_len);
277        let halt = Arc::new(AtomicBool::new(false));
278        self.io_task_handle = Some(IoTaskHandle {
279            tx_io_requests,
280            rx_recv_published,
281            halt: halt.clone(),
282        });
283        let io = IoTask {
284            options: self.options.clone(),
285            rx_io_requests,
286            tx_recv_published,
287            state: IoTaskState::Disconnected,
288            subscriptions: BTreeMap::new(),
289            halt,
290        };
291        self.options.runtime.spawn(io.run());
292        Ok(())
293    }
294
295    /// Publish some data on a topic.
296    ///
297    /// Note that this method takes `&self`. This means a caller can
298    /// create several publish futures to publish several payloads of
299    /// data simultaneously without waiting for responses.
300    pub async fn publish(&self, p: &Publish) -> Result<()> {
301        let qos = p.qos();
302        if qos == QoS::ExactlyOnce {
303            return Err("QoS::ExactlyOnce is not supported".into());
304        }
305        let p2 = Packet::Publish(mqttrs::Publish {
306            dup: false, // TODO.
307            qospid: match qos {
308                QoS::AtMostOnce => QosPid::AtMostOnce,
309                QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?),
310                QoS::ExactlyOnce => panic!("Not reached"),
311            },
312            retain: p.retain(),
313            topic_name: p.topic().to_owned(),
314            payload: p.payload().to_owned(),
315        });
316        match qos {
317            QoS::AtMostOnce => {
318                let res = timeout(self.options.operation_timeout,
319                                  self.write_only_packet(&p2)).await;
320                if let Err(Elapsed { .. }) = res {
321                    return Err(format!("Timeout writing publish after {}ms",
322                                       self.options.operation_timeout.as_millis()).into());
323                }
324                res.expect("No timeout")?;
325            }
326            QoS::AtLeastOnce => {
327                let res = timeout(self.options.operation_timeout,
328                                  self.write_response_packet(&p2)).await;
329                if let Err(Elapsed { .. }) = res {
330                    // We report this but can't really deal with it properly.
331                    // The protocol says we can't re-use the packet ID so we have to leak it
332                    // and potentially run out of packet IDs.
333                    return Err(format!("Timeout waiting for Puback after {}ms",
334                                       self.options.operation_timeout.as_millis()).into());
335                }
336                let res = res.expect("No timeout")?;
337                match res {
338                    Packet::Puback(pid) => self.free_write_pid(pid)?,
339                    _ => error!("Bad packet response for publish: {:#?}", res),
340                }
341            },
342            QoS::ExactlyOnce => panic!("Not reached"),
343        };
344        Ok(())
345    }
346
347    /// Subscribe to some topics.`read_subscriptions` will return
348    /// data for them.
349    pub async fn subscribe(&mut self, s: Subscribe) -> Result<SubscribeResult> {
350        let pid = self.alloc_write_pid()?;
351        // TODO: Support subscribe to qos == ExactlyOnce.
352        if s.topics().iter().any(|t| t.qos == QoS::ExactlyOnce) {
353            return Err("Qos::ExactlyOnce is not supported right now".into())
354        }
355        let p = Packet::Subscribe(mqttrs::Subscribe {
356            pid,
357            topics: s.topics().to_owned(),
358        });
359        let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
360        if let Err(Elapsed { .. }) = res {
361            // We report this but can't really deal with it properly.
362            // The protocol says we can't re-use the packet ID so we have to leak it
363            // and potentially run out of packet IDs.
364            return Err(format!("Timeout waiting for Suback after {}ms",
365                               self.options.operation_timeout.as_millis()).into());
366        }
367        let res = res.expect("No timeout")?;
368        match res {
369            Packet::Suback(mqttrs::Suback {
370                pid: suback_pid,
371                return_codes: rcs,
372            }) if suback_pid == pid => {
373                self.free_write_pid(pid)?;
374                Ok(SubscribeResult {
375                    return_codes: rcs
376                })
377            },
378            _ => {
379                return Err(format!("Unexpected packet waiting for Suback(Pid={:?}): {:#?}",
380                                   pid, res)
381                           .into());
382            }
383        }
384    }
385
386    /// Unsubscribe from some topics. `read_subscriptions` will no
387    /// longer return data for them.
388    pub async fn unsubscribe(&mut self, u: Unsubscribe) -> Result<()> {
389        let pid = self.alloc_write_pid()?;
390        let p = Packet::Unsubscribe(mqttrs::Unsubscribe {
391            pid,
392            topics: u.topics().iter().map(|ut| ut.topic_name().to_owned())
393                     .collect::<Vec<String>>(),
394        });
395        let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
396        if let Err(Elapsed { .. }) = res {
397            // We report this but can't really deal with it properly.
398            // The protocol says we can't re-use the packet ID so we have to leak it
399            // and potentially run out of packet IDs.
400            return Err(format!("Timeout waiting for Unsuback after {}ms",
401                               self.options.operation_timeout.as_millis()).into());
402        }
403        let res = res.expect("No timeout")?;
404        match res {
405            Packet::Unsuback(ack_pid)
406            if ack_pid == pid => {
407                self.free_write_pid(pid)?;
408                Ok(())
409            },
410            _ => {
411                return Err(format!("Unexpected packet waiting for Unsuback(Pid={:?}): {:#?}",
412                                   pid, res)
413                           .into());
414            }
415        }
416    }
417
418    /// Wait for the next Publish packet for one of this Client's subscriptions.
419    pub async fn read_subscriptions(&mut self) -> Result<ReadResult> {
420        let h = self.check_io_task_mut()?;
421        let r = match h.rx_recv_published.recv().await {
422            Some(r) => r?,
423            None => {
424                // Sender closed.
425                self.io_task_handle = None;
426                return Err(Error::Disconnected);
427            }
428        };
429        match r {
430            Packet::Publish(p) => {
431                match p.qospid {
432                    QosPid::AtMostOnce => (),
433                    QosPid::AtLeastOnce(pid) => {
434                        self.write_only_packet(&Packet::Puback(pid)).await?;
435                    },
436                    QosPid::ExactlyOnce(_) => {
437                        error!("Received publish with unimplemented QoS: ExactlyOnce");
438                    }
439                }
440                let rr = ReadResult {
441                    topic: p.topic_name,
442                    payload: p.payload,
443                };
444                Ok(rr)
445            },
446            _ => {
447                return Err(format!("Unexpected packet waiting for read: {:#?}", r).into());
448            }
449        }
450    }
451
452    /// Gracefully close the connection to the server.
453    pub async fn disconnect(&mut self) -> Result<()> {
454        self.check_io_task()?;
455        debug!("Disconnecting");
456        let p = Packet::Disconnect;
457        let res = timeout(self.options.operation_timeout,
458                          self.write_only_packet(&p)).await;
459        if let Err(Elapsed { .. }) = res {
460            return Err(format!("Timeout waiting for Disconnect to send after {}ms",
461                               self.options.operation_timeout.as_millis()).into());
462        }
463        res.expect("No timeout")?;
464        self.shutdown().await?;
465        Ok(())
466    }
467
468    fn alloc_write_pid(&self) -> Result<Pid> {
469        match self.free_write_pids.lock().expect("not poisoned").alloc() {
470            Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")),
471            None => Err(Error::from("No free Pids")),
472        }
473    }
474
475    fn free_write_pid(&self, p: Pid) -> Result<()> {
476        match self.free_write_pids.lock().expect("not poisoned").free(p.get()) {
477            true => Err(Error::from("Pid was already free")),
478            false => Ok(())
479        }
480    }
481
482    async fn shutdown(&mut self) -> Result <()> {
483        let c = self.check_io_task()?;
484        c.halt.store(true, Ordering::SeqCst);
485        self.write_request(IoType::ShutdownConnection, None).await?;
486        self.io_task_handle = None;
487        Ok(())
488    }
489
490    async fn write_only_packet(&self, p: &Packet) -> Result<()> {
491        self.write_request(IoType::WriteOnly { packet: p.clone(), }, None)
492            .await.map(|_v| ())
493
494    }
495
496    async fn write_response_packet(&self, p: &Packet) -> Result<Packet> {
497        let io_type = IoType::WriteAndResponse {
498            packet: p.clone(),
499            response_pid: packet_pid(p).expect("packet_pid"),
500        };
501        let (tx, rx) = oneshot::channel::<IoResult>();
502        self.write_request(io_type, Some(tx))
503            .await?;
504        // TODO: Add a timeout?
505        let res = rx.await.map_err(Error::from_std_err)?;
506        res.result.map(|v| v.expect("return packet"))
507    }
508
509    async fn write_request(&self, io_type: IoType, tx_result: Option<oneshot::Sender<IoResult>>) -> Result<()> {
510        // NB: Some duplication in IoTask::replay_subscriptions.
511
512        let c = self.check_io_task()?;
513        let req = IoRequest { tx_result, io_type };
514        c.tx_io_requests.clone().send(req).await
515            .map_err(|e| Error::from_std_err(e))?;
516        Ok(())
517    }
518
519    fn check_io_task_mut(&mut self) -> Result<&mut IoTaskHandle> {
520        match self.io_task_handle {
521            Some(ref mut h) => Ok(h),
522            None => Err("No IO task, did you call connect?".into()),
523        }
524    }
525
526    fn check_io_task(&self) -> Result<&IoTaskHandle> {
527        match self.io_task_handle {
528            Some(ref h) => Ok(h),
529            None => Err("No IO task, did you call connect?".into()),
530        }
531    }
532
533    fn check_no_io_task(&self) -> Result<()> {
534        match self.io_task_handle {
535            Some(_) => Err("Already spawned IO task".into()),
536            None => Ok(()),
537        }
538    }
539}
540
541/// Start network connection to the server.
542async fn connect_stream(opts: &ClientOptions) -> Result<AsyncStream> {
543    debug!("Connecting to {}", opts.url);
544    let host = opts
545        .url
546        .host_str()
547        .ok_or(Error::String("Missing host".to_owned()))?;
548    match opts.connection_mode {
549        #[cfg(feature = "tls")]
550        ConnectionMode::Tls(ref c) => {
551            let port = opts.url.port().unwrap_or(8883);
552            let connector = TlsConnector::from(c.clone());
553            let domain = DNSNameRef::try_from_ascii_str(host)
554                .map_err(|e| Error::from_std_err(e))?;
555            let tcp = TcpStream::connect((host, port)).await?;
556            let conn = connector.connect(domain, tcp).await?;
557            Ok(AsyncStream::TlsStream(conn))
558        },
559        ConnectionMode::Tcp => {
560            let port = opts.url.port().unwrap_or(1883);
561            let tcp = TcpStream::connect((host, port)).await?;
562            Ok(AsyncStream::TcpStream(tcp))
563        }
564        #[cfg(feature = "websocket")]
565        ConnectionMode::Websocket => {
566            let port = opts.url.port().unwrap_or(80);
567            let path_and_query = format!(
568                "{}{}",
569                opts.url.path(),
570                opts.url
571                    .query()
572                    .map_or("".to_owned(), |q| format!("?{}", q))
573            );
574            let tcp_connection =
575                tokio_tungstenite::MaybeTlsStream::Plain(TcpStream::connect((host, port)).await?);
576            let websocket = tokio_tungstenite::client_async(
577                Request::get(
578                    Uri::builder()
579                        .scheme("ws")
580                        .authority(format!("{}:{}", host, port).as_str())
581                        .path_and_query(path_and_query)
582                        .build()
583                        .map_err(Error::from_std_err)?,
584                )
585                .header("Sec-WebSocket-Protocol", "mqtt")
586                .body(())
587                .unwrap(),
588                tcp_connection,
589            )
590            .await
591            .map_err(crate::util::tungstenite_error_to_std_io_error)?
592            .0;
593            Ok(AsyncStream::WebSocket(websocket))
594        }
595        #[cfg(feature = "websocket")]
596        ConnectionMode::WebsocketSecure(ref c) => {
597            let port = opts.url.port().unwrap_or(443);
598            let tls_stream = TlsConnector::from(c.clone())
599                .connect(
600                    DNSNameRef::try_from_ascii_str(host).map_err(|e| Error::from_std_err(e))?,
601                    TcpStream::connect((host, port)).await?,
602                )
603                .await?;
604            let path_and_query = format!(
605                "{}{}",
606                opts.url.path(),
607                opts.url
608                    .query()
609                    .map_or("".to_owned(), |q| format!("?{}", q))
610            );
611            let websocket = tokio_tungstenite::client_async(
612                Request::get(
613                    Uri::builder()
614                        .scheme("wss")
615                        .authority(format!("{}:{}", host, port).as_str())
616                        .path_and_query(path_and_query)
617                        .build()
618                        .map_err(Error::from_std_err)?,
619                )
620                .header("Sec-WebSocket-Protocol", "mqtt")
621                .body(())
622                .unwrap(),
623                tokio_tungstenite::MaybeTlsStream::Rustls(tls_stream),
624            )
625            .await
626            .map_err(crate::util::tungstenite_error_to_std_io_error)?
627            .0;
628            Ok(AsyncStream::WebSocket(websocket))
629        }
630    }
631}
632
633/// Build a connect packet from ClientOptions.
634fn connect_packet(opts: &ClientOptions) -> Result<Packet> {
635    Ok(Packet::Connect(mqttrs::Connect {
636        protocol: mqttrs::Protocol::MQTT311,
637        keep_alive: match opts.keep_alive {
638            KeepAlive::Disabled => 0,
639            KeepAlive::Enabled { secs } => secs,
640        },
641        client_id: match &opts.client_id {
642            None => "".to_owned(),
643            Some(cid) => cid.to_owned(),
644        },
645        clean_session: true, // TODO
646        last_will: None, // TODO
647        username: opts.username.clone(),
648        password: opts.password.clone(),
649    }))
650}
651
652fn packet_pid(p: &Packet) -> Option<Pid> {
653    match p {
654        Packet::Connect(_) => None,
655        Packet::Connack(_) => None,
656        Packet::Publish(publish) => publish.qospid.pid(),
657        Packet::Puback(pid) => Some(pid.to_owned()),
658        Packet::Pubrec(pid) => Some(pid.to_owned()),
659        Packet::Pubrel(pid) => Some(pid.to_owned()),
660        Packet::Pubcomp(pid) => Some(pid.to_owned()),
661        Packet::Subscribe(sub) => Some(sub.pid),
662        Packet::Suback(suback) => Some(suback.pid),
663        Packet::Unsubscribe(unsub) => Some(unsub.pid),
664        Packet::Unsuback(pid) => Some(pid.to_owned()),
665        Packet::Pingreq => None,
666        Packet::Pingresp => None,
667        Packet::Disconnect => None,
668    }
669}
670
671/// Represents what happened "next" that we should handle.
672enum SelectResult {
673    /// An IO request from the Client
674    IoReq(Option<IoRequest>),
675
676    /// Read a packet from the network
677    Read(Result<Packet>),
678
679    /// Time to send a keep-alive ping request packet.
680    Ping,
681
682    /// Timeout waiting for a Pingresp.
683    PingrespExpected,
684}
685
686impl IoTask {
687    async fn run(mut self) {
688        loop {
689            if self.halt.load(Ordering::SeqCst) {
690                self.shutdown_conn().await;
691                debug!("IoTask: halting by request.");
692                self.state = IoTaskState::Halted;
693                return;
694            }
695
696            match self.state {
697                IoTaskState::Halted => return,
698                IoTaskState::Disconnected =>
699                    match Self::try_connect(&mut self).await {
700                        Err(e) => {
701                            error!("IoTask: Error connecting: {}", e);
702                            if self.options.automatic_connect {
703                                sleep(self.options.connect_retry_delay).await;
704                            } else {
705                                info!("IoTask: halting due to connection failure, auto connect is off.");
706                                self.state = IoTaskState::Halted;
707                                return;
708                            }
709                        },
710                        Ok(()) => {
711                            if let Err(e) = Self::replay_subscriptions(&mut self).await {
712                                error!("IoTask: Error replaying subscriptions on reconnect: {}",
713                                       e);
714                            }
715                        },
716                    },
717                IoTaskState::Connected(_) =>
718                    match Self::run_once_connected(&mut self).await {
719                        Err(Error::Disconnected) => {
720                            info!("IoTask: Disconnected, resetting state");
721                            self.state = IoTaskState::Disconnected;
722                        },
723                        Err(e) => {
724                            error!("IoTask: Quitting run loop due to error: {}", e);
725                            return;
726                        },
727                        _ => {},
728                    },
729            }
730        }
731    }
732
733    async fn try_connect(&mut self) -> Result<()> {
734        let stream = connect_stream(&self.options).await?;
735        self.state =  IoTaskState::Connected(IoTaskConnected {
736            stream,
737            read_buf: BytesMut::with_capacity(self.options.max_packet_len),
738            read_bufn: 0,
739            last_write_time: Instant::now(),
740            last_pingreq_time: Instant::now(),
741            last_pingresp_time: Instant::now(),
742            pid_response_map: BTreeMap::new(),
743        });
744        let c = match self.state {
745            IoTaskState::Connected(ref mut c) => c,
746            _ => panic!("Not reached"),
747        };
748        let conn = connect_packet(&self.options)?;
749        debug!("IoTask: Sending connect packet");
750        Self::write_packet(&self.options, c, &conn).await?;
751        let read = Self::read_packet(&mut c.stream,
752                                     &mut c.read_buf,
753                                     &mut c.read_bufn,
754                                     self.options.max_packet_len);
755        let res = match timeout(self.options.operation_timeout,
756                                read).await {
757            // Timeout
758            Err(Elapsed { .. }) =>
759                Err(format!("Timeout waiting for Connack after {}ms",
760                            self.options.operation_timeout.as_millis()).into()),
761
762            // Non-timeout error
763            Ok(Err(e)) => Err(e),
764
765            Ok(Ok(Packet::Connack(ca))) => {
766                match ca.code {
767                    ConnectReturnCode::Accepted => {
768                        debug!("IoTask: connack with code=Accepted.");
769                        Ok(())
770                    },
771                    _ => Err(format!("Bad connect return code: {:?}", ca.code).into()),
772                }
773            },
774
775            // Other unexpected packets.
776            Ok(Ok(p)) =>
777                Err(format!("Received packet not CONNACK after connect: {:?}", p).into()),
778        };
779        match res {
780            Ok(()) => Ok(()),
781            Err(e) => {
782                self.shutdown_conn().await;
783                Err(e)
784            },
785        }
786    }
787
788    /// Shutdown the network connection to the MQTT broker.
789    ///
790    /// Logs and swallows errors.
791    async fn shutdown_conn(&mut self) {
792        debug!("IoTask: shutdown_conn");
793        let c = match self.state {
794            // Already disconnected / halted, nothing more to do.
795            IoTaskState::Disconnected |
796            IoTaskState::Halted => return,
797
798            IoTaskState::Connected(ref mut c) => c,
799        };
800
801        if let Err(e) = c.stream.shutdown().await {
802            if e.kind() != std::io::ErrorKind::NotConnected {
803                error!("IoTask: Error on stream shutdown in shutdown_conn: {:?}", e);
804            }
805        }
806        self.state = IoTaskState::Disconnected;
807    }
808
809    async fn replay_subscriptions(&mut self) -> Result<()> {
810        // NB: Some duplication in Client::subscribe and Client::write_request.
811        let subs = self.subscriptions.clone();
812        for (t, qos) in subs.iter() {
813            trace!("Replaying subscription topic='{}' qos={:?}", t, qos);
814            // Pick a high pid to probably avoid collisions with one allocated
815            // by the Client.
816            let pid = Pid::try_from(65535).expect("non-zero pid");
817            let p = Packet::Subscribe(mqttrs::Subscribe {
818                pid,
819                topics: vec![SubscribeTopic { topic_path: t.to_owned(), qos: qos.to_owned() }]
820            });
821            let req = IoRequest {
822                io_type: IoType::WriteAndResponse { packet: p, response_pid: pid },
823                // TODO: I'm not sure how to receive the result; ignore it for now.
824                tx_result: None,
825            };
826            self.handle_io_req(req).await?;
827        }
828        Ok(())
829    }
830
831    /// Unhandled errors are returned and terminate the run loop.
832    async fn run_once_connected(&mut self) -> Result<()> {
833        let c = match self.state {
834            IoTaskState::Connected(ref mut c) => c,
835            _ => panic!("Not reached"),
836        };
837        let pingreq_next = self.options.keep_alive.as_duration()
838            .map(|dur| c.last_write_time + dur);
839
840        let pingresp_expected_by =
841            if self.options.keep_alive.is_enabled() &&
842                c.last_pingreq_time > c.last_pingresp_time
843            {
844                // Expect a ping response before the operation timeout and the keepalive interval.
845                // If the keepalive interval expired first then the "next operation" as
846                // returned by SelectResult below would be Ping even when Pingresp is expected,
847                // and we would never time out the connection.
848                let ka = self.options.keep_alive.as_duration().expect("enabled");
849                Some(c.last_pingreq_time + min(self.options.operation_timeout, ka))
850            } else {
851                None
852            };
853
854        // Select over futures to determine what to do next:
855        // * Handle a write request from the Client
856        // * Handle an incoming packet from the network
857        // * Handle a keep-alive period elapsing and send a ping request
858        // * Handle a PingrespExpected timeout and disconnect
859        //
860        // From these futures we compute an enum value in sel_res
861        // that encapsulates what to do next, then match over
862        // sel_res to actually do the work. The reason for this
863        // structure is just to keep the borrow checker happy.
864        // The futures calculation uses a mutable borrow on `stream`
865        // for the `read_packet` call, but the mutable borrow ends there.
866        // Then when we want to do the work we can take a new, separate mutable
867        // borrow to write packets based on IO requests.
868        // These two mutable borrows don't overlap.
869        let sel_res: SelectResult = {
870            let mut req_fut = Box::pin(self.rx_io_requests.recv().fuse());
871            let mut read_fut = Box::pin(
872                Self::read_packet(&mut c.stream, &mut c.read_buf, &mut c.read_bufn,
873                                  self.options.max_packet_len).fuse());
874            let mut ping_fut = match pingreq_next {
875                Some(t) => Box::pin(sleep_until(t).boxed().fuse()),
876                None => Box::pin(pending().boxed().fuse()),
877            };
878            let mut pingresp_expected_fut = match pingresp_expected_by {
879                Some(t) => Box::pin(sleep_until(t).boxed().fuse()),
880                None => Box::pin(pending().boxed().fuse()),
881            };
882            select! {
883                req = req_fut => SelectResult::IoReq(req),
884                read = read_fut => SelectResult::Read(read),
885                _ = ping_fut => SelectResult::Ping,
886                _ = pingresp_expected_fut => SelectResult::PingrespExpected,
887            }
888        };
889        match sel_res {
890            SelectResult::Read(read) => return self.handle_read(read).await,
891            SelectResult::IoReq(req) => match req {
892                None => {
893                    // Sender closed.
894                    debug!("IoTask: Req stream closed, shutting down.");
895                    self.shutdown_conn().await;
896                    return Err(Error::Disconnected);
897                },
898                Some(req) => return self.handle_io_req(req).await,
899            },
900            SelectResult::Ping => return self.send_ping().await,
901            SelectResult::PingrespExpected => {
902                // We timed out waiting for a ping response from
903                // the server, shutdown the stream.
904                debug!("IoTask: Timed out waiting for Pingresp, shutting down.");
905                self.shutdown_conn().await;
906                return Err(Error::Disconnected);
907            }
908        }
909    }
910
911    async fn handle_read(&mut self, read: Result<Packet>) -> Result<()> {
912        let c = match self.state {
913            IoTaskState::Connected(ref mut c) => c,
914            _ => panic!("Not reached"),
915        };
916
917        match read {
918            Err(Error::Disconnected) => {
919                self.tx_recv_published.send(Err(Error::Disconnected)).await
920                    .map_err(Error::from_std_err)?;
921            }
922            Err(e) => {
923                self.tx_recv_published.send(
924                        Err(format!("IoTask: Failed to read packet: {:?}", e).into())
925                    )
926                    .await
927                    .map_err(Error::from_std_err)?;
928            },
929            Ok(p) => {
930                match p {
931                    Packet::Pingresp => {
932                        debug!("IoTask: Received Pingresp");
933                        c.last_pingresp_time = Instant::now();
934                    },
935                    Packet::Publish(_) => {
936                        if let Err(e) = self.tx_recv_published.send(Ok(p)).await {
937                            error!("IoTask: Failed to send Packet: {:?}", e);
938                        }
939                    },
940                    Packet::Connack(_) => {
941                        error!("IoTask: Unexpected CONNACK in handle_read(): {:?}", p);
942                        self.shutdown_conn().await;
943                        return Err(Error::Disconnected);
944                    }
945                    _ => {
946                        let pid = packet_pid(&p);
947                        if let Some(pid) = pid {
948                            let pid_response = c.pid_response_map.remove(&pid);
949                            match pid_response {
950                                None => error!("Unknown PID: {:?}", pid),
951                                Some(req) => {
952                                    trace!("Sending response PID={:?} p={:?}",
953                                           pid, p);
954                                    let res = IoResult { result: Ok(Some(p)) };
955                                    Self::send_io_result(req, res)?;
956                                },
957                            }
958                        }
959                    },
960                }
961            },
962        }
963        Ok(())
964    }
965
966    async fn handle_io_req(&mut self, req: IoRequest) -> Result<()> {
967        let c = match self.state {
968            IoTaskState::Connected(ref mut c) => c,
969            _ => panic!("Not reached"),
970        };
971        let packet = req.io_type.packet();
972        if let Some(p) = packet {
973            c.last_write_time = Instant::now();
974            let res = Self::write_packet(&self.options, c, &p).await;
975            if let Err(e) = res {
976                error!("IoTask: Error writing packet: {:?}", e);
977                let res = IoResult { result: Err(e) };
978                Self::send_io_result(req, res)?;
979                return Ok(())
980            }
981            match p {
982                Packet::Subscribe(s) => {
983                    for st in s.topics.iter() {
984                        trace!("Tracking subscription topic='{}', qos={:?}",
985                               st.topic_path, st.qos);
986                        let _ = self.subscriptions.insert(st.topic_path.clone(), st.qos);
987                    }
988                },
989                Packet::Unsubscribe(u) => {
990                    for t in u.topics.iter() {
991                        trace!("Tracking unsubscription topic='{}'", t);
992                        let _ = self.subscriptions.remove(t);
993                    }
994                },
995                _ => {},
996            }
997            match req.io_type {
998                IoType::WriteOnly { .. } => {
999                    let res = IoResult { result: res.map(|_| None) };
1000                    Self::send_io_result(req, res)?;
1001                },
1002                IoType::WriteAndResponse { response_pid, .. } => {
1003                    c.pid_response_map.insert(response_pid, req);
1004                },
1005                IoType::ShutdownConnection => {
1006                    panic!("Not reached because ShutdownConnection has no packet")
1007                },
1008            }
1009        } else {
1010            match req.io_type {
1011                IoType::ShutdownConnection => {
1012                    debug!("IoTask: IoType::ShutdownConnection.");
1013                    self.shutdown_conn().await;
1014                    let res = IoResult { result: Ok(None) };
1015                    Self::send_io_result(req, res)?;
1016                    return Err(Error::Disconnected);
1017                }
1018                _ => (),
1019            }
1020        }
1021        Ok(())
1022    }
1023
1024    fn send_io_result(req: IoRequest, res: IoResult) -> Result<()> {
1025        match req.tx_result {
1026            Some(tx) => {
1027                if let Err(e) = tx.send(res) {
1028                    error!("IoTask: Failed to send IoResult={:?}", e);
1029                }
1030            },
1031            None => {
1032                debug!("IoTask: Ignored IoResult: {:?}", res);
1033            },
1034        }
1035        Ok(())
1036    }
1037
1038    async fn send_ping(&mut self) -> Result<()> {
1039        let c = match self.state {
1040            IoTaskState::Connected(ref mut c) => c,
1041            _ => panic!("Not reached"),
1042        };
1043        debug!("IoTask: Writing Pingreq");
1044        c.last_write_time = Instant::now();
1045        c.last_pingreq_time = Instant::now();
1046        let p = Packet::Pingreq;
1047        if let Err(e) = Self::write_packet(&self.options, c, &p).await {
1048            error!("IoTask: Failed to write ping: {:?}", e);
1049        }
1050        Ok(())
1051    }
1052
1053    async fn write_packet(
1054        opts: &ClientOptions,
1055        c: &mut IoTaskConnected,
1056        p: &Packet,
1057    ) -> Result<()> {
1058        if cfg!(feature = "unsafe-logging") {
1059            trace!("write_packet p={:#?}", p);
1060        }
1061        // TODO: Test long packets.
1062        let mut bytes = BytesMut::with_capacity(opts.max_packet_len);
1063        mqttrs::encode(&p, &mut bytes)?;
1064        if cfg!(feature = "unsafe-logging") {
1065            trace!("write_packet bytes p={:?}", &*bytes);
1066        }
1067        c.stream.write_all(&*bytes).await?;
1068        Ok(())
1069    }
1070
1071    async fn read_packet(
1072        stream: &mut AsyncStream,
1073        read_buf: &mut BytesMut,
1074        read_bufn: &mut usize,
1075        max_packet_len: usize
1076    ) -> Result<Packet> {
1077        // TODO: Test long packets.
1078        loop {
1079            if cfg!(feature = "unsafe-logging") {
1080                trace!("read_packet Decoding buf={:?}", &read_buf[0..*read_bufn]);
1081            }
1082            if *read_bufn > 0 {
1083                // We already have some bytes in the buffer. Try to decode a packet
1084                read_buf.split_off(*read_bufn);
1085                let old_len = read_buf.len();
1086                let decoded = mqttrs::decode(read_buf)?;
1087                if cfg!(feature = "unsafe-logging") {
1088                    trace!("read_packet decoded={:#?}", decoded);
1089                }
1090                if let Some(p) = decoded {
1091                    let new_len = read_buf.len();
1092                    trace!("read_packet old_len={} new_len={} read_bufn={}",
1093                           old_len, new_len, *read_bufn);
1094                    *read_bufn -= old_len - new_len;
1095                    if cfg!(feature = "unsafe-logging") {
1096                        trace!("read_packet Remaining buf={:?}", &read_buf[0..*read_bufn]);
1097                    }
1098                    return Ok(p);
1099                }
1100            }
1101            read_buf.resize(max_packet_len, 0u8);
1102            let readlen = read_buf.len();
1103            trace!("read_packet read read_bufn={} readlen={}", *read_bufn, readlen);
1104            let nread = stream.read(&mut read_buf[*read_bufn..readlen]).await?;
1105            *read_bufn += nread;
1106            if nread == 0 {
1107                // Socket disconnected
1108                error!("IoTask: Socket disconnected");
1109                return Err(Error::Disconnected);
1110            }
1111        }
1112    }
1113}
1114
1115impl IoType {
1116    fn packet(&self) -> Option<&Packet> {
1117        match self {
1118            IoType::ShutdownConnection => None,
1119            IoType::WriteOnly { packet } => Some(&packet),
1120            IoType::WriteAndResponse { packet, .. } => Some(&packet),
1121        }
1122    }
1123}
1124
/// An enum for specifying which mode we will use to connect to the broker.
///
/// Only `Tcp` is always available; the other variants exist only when the
/// matching cargo feature (`websocket` or `tls`) is enabled. `Tcp` is the
/// value returned by `ConnectionMode::default()`.
#[derive(Clone)]
pub enum ConnectionMode {
    /// Connect over plain TCP.
    Tcp,
    /// Connect over a WebSocket. Requires the `websocket` feature.
    #[cfg(feature = "websocket")]
    Websocket,
    // NOTE(review): this variant names `rustls` under the `websocket` feature
    // only, while the explicit `use rustls;` is gated on `tls`; presumably the
    // `websocket` feature also pulls in rustls — confirm in Cargo.toml.
    /// Connect over a secure WebSocket, carrying a shared rustls client
    /// configuration. Requires the `websocket` feature.
    #[cfg(feature = "websocket")]
    WebsocketSecure(Arc<rustls::ClientConfig>),
    /// Connect over TLS, carrying a shared rustls client configuration.
    /// Requires the `tls` feature.
    #[cfg(feature = "tls")]
    Tls(Arc<rustls::ClientConfig>),
}
1136impl Default for ConnectionMode {
1137    fn default() -> Self {
1138        Self::Tcp
1139    }
1140}
1141
1142
#[cfg(test)]
mod test {
    use super::Client;

    /// Compiles only if `T` is `Send`; does nothing at run time.
    fn assert_send<T: Send>(_: &T) {}

    #[test]
    fn client_is_send() {
        let client = Client::builder()
            .set_url_string("mqtt://localhost")
            .unwrap()
            .build()
            .unwrap();
        assert_send(&client);
    }
}