1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4#[cfg(all(feature = "use-rustls-ring", feature = "use-rustls-aws-lc"))]
5compile_error!(
6 "Features `use-rustls-ring` and `use-rustls-aws-lc` are mutually exclusive. Enable only one rustls provider feature."
7);
8
9#[macro_use]
10extern crate log;
11
12use bytes::Bytes;
13use std::fmt::{self, Debug, Formatter};
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::net::{TcpStream, lookup_host};
20use tokio::task::JoinSet;
21
22#[cfg(all(feature = "url", unix))]
23use percent_encoding::percent_decode_str;
24
25#[cfg(all(feature = "url", unix))]
26use std::{ffi::OsString, os::unix::ffi::OsStringExt};
27
28mod client;
29mod eventloop;
30mod framed;
31pub mod mqttbytes;
32mod notice;
33mod state;
34mod transport;
35
36#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
37mod tls;
38
39#[cfg(feature = "websocket")]
40mod websockets;
41
42#[cfg(feature = "websocket")]
43use std::{
44 future::{Future, IntoFuture},
45 pin::Pin,
46};
47
/// Boxed, sendable future shape shared by both kinds of websocket request modifier.
#[cfg(feature = "websocket")]
type BoxedModifierFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

/// Type-erased error a fallible request modifier may return.
#[cfg(feature = "websocket")]
type RequestModifierError = Box<dyn std::error::Error + Send + Sync>;

/// Shared callback that asynchronously rewrites the websocket handshake request.
#[cfg(feature = "websocket")]
type RequestModifierFn =
    Arc<dyn Fn(http::Request<()>) -> BoxedModifierFuture<http::Request<()>> + Send + Sync>;

/// Like [`RequestModifierFn`], but the callback may fail with a [`RequestModifierError`].
#[cfg(feature = "websocket")]
type FallibleRequestModifierFn = Arc<
    dyn Fn(http::Request<()>) -> BoxedModifierFuture<Result<http::Request<()>, RequestModifierError>>
        + Send
        + Sync,
>;
67
68#[cfg(feature = "proxy")]
69mod proxy;
70
71pub use client::{
72 AsyncClient, AsyncClientBuilder, Client, ClientBuilder, ClientError, Connection, InvalidTopic,
73 Iter, ManualAck, PublishTopic, RecvError, RecvTimeoutError, TryRecvError, ValidatedTopic,
74};
75pub use eventloop::{ConnectionError, Event, EventLoop};
76pub use mqttbytes::v4::*;
77pub use mqttbytes::*;
78pub use notice::{
79 NoticeFailureReason, PublishNotice, PublishNoticeError, PublishResult, SubscribeNotice,
80 SubscribeNoticeError, UnsubscribeNotice, UnsubscribeNoticeError,
81};
82pub use rumqttc_core::NetworkOptions;
83#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
84pub use rumqttc_core::TlsConfiguration;
85pub use rumqttc_core::default_socket_connect;
86pub use state::{MqttState, MqttStateBuilder, StateError};
87#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
88pub use tls::Error as TlsError;
89#[cfg(feature = "use-native-tls")]
90pub use tokio_native_tls;
91#[cfg(feature = "use-rustls-no-provider")]
92pub use tokio_rustls;
93pub use transport::Transport;
94
95#[cfg(feature = "proxy")]
96pub use proxy::{Proxy, ProxyAuth, ProxyType};
97
/// Alias for [`Packet`]; the `Incoming` name mirrors [`Outgoing`].
pub type Incoming = Packet;
99
/// Summary of a packet on the outgoing side of the connection.
///
/// The `u16` payloads are presumably MQTT packet identifiers — confirm against
/// the state module, which is not visible from this file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Outgoing {
    /// A `PUBLISH` was sent.
    Publish(u16),
    /// A `SUBSCRIBE` was sent.
    Subscribe(u16),
    /// An `UNSUBSCRIBE` was sent.
    Unsubscribe(u16),
    /// A `PUBACK` was sent.
    PubAck(u16),
    /// A `PUBREC` was sent.
    PubRec(u16),
    /// A `PUBREL` was sent.
    PubRel(u16),
    /// A `PUBCOMP` was sent.
    PubComp(u16),
    /// A `PINGREQ` was sent.
    PingReq,
    /// A `PINGRESP` was sent.
    PingResp,
    /// A `DISCONNECT` was sent.
    Disconnect,
    /// An acknowledgement is being awaited (exact semantics live in the state machine — TODO confirm).
    AwaitAck(u16),
}
126
127#[derive(Clone, Debug, PartialEq, Eq)]
130pub enum Request {
131 Publish(Publish),
132 PubAck(PubAck),
133 PubRec(PubRec),
134 PubComp(PubComp),
135 PubRel(PubRel),
136 PingReq(PingReq),
137 PingResp(PingResp),
138 Subscribe(Subscribe),
139 SubAck(SubAck),
140 Unsubscribe(Unsubscribe),
141 UnsubAck(UnsubAck),
142 Disconnect(Disconnect),
143 DisconnectNow(Disconnect),
144 DisconnectWithTimeout(Disconnect, Duration),
145}
146
147impl From<Publish> for Request {
148 fn from(publish: Publish) -> Self {
149 Self::Publish(publish)
150 }
151}
152
153impl From<Subscribe> for Request {
154 fn from(subscribe: Subscribe) -> Self {
155 Self::Subscribe(subscribe)
156 }
157}
158
159impl From<Unsubscribe> for Request {
160 fn from(unsubscribe: Unsubscribe) -> Self {
161 Self::Unsubscribe(unsubscribe)
162 }
163}
164
/// Crate-local alias for [`rumqttc_core::SocketConnector`], the connector type
/// stored in `MqttOptions::socket_connector`.
pub(crate) type SocketConnector = rumqttc_core::SocketConnector;
167
/// Stagger step used by `connect_with_retry_mode`: the n-th candidate (0-based)
/// starts connecting `n * 100 ms` after the first one.
const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(100);
169
170async fn first_success_with_stagger<T, I, F, Fut>(
171 items: I,
172 attempt_delay: Duration,
173 connect_fn: F,
174) -> io::Result<T>
175where
176 T: Send + 'static,
177 I: IntoIterator,
178 I::Item: Send + 'static,
179 F: Fn(I::Item) -> Fut + Send + Sync + Clone + 'static,
180 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
181{
182 let mut join_set = JoinSet::new();
183 let mut item_count = 0usize;
184
185 for (index, item) in items.into_iter().enumerate() {
186 item_count += 1;
187 let delay = attempt_delay.saturating_mul(u32::try_from(index).unwrap_or(u32::MAX));
188 let connect_fn = connect_fn.clone();
189 join_set.spawn(async move {
190 tokio::time::sleep(delay).await;
191 connect_fn(item).await
192 });
193 }
194
195 if item_count == 0 {
196 return Err(io::Error::new(
197 io::ErrorKind::InvalidInput,
198 "could not resolve to any address",
199 ));
200 }
201
202 let mut last_err = None;
203
204 while let Some(task_result) = join_set.join_next().await {
205 match task_result {
206 Ok(Ok(stream)) => {
207 join_set.abort_all();
208 return Ok(stream);
209 }
210 Ok(Err(err)) => {
211 last_err = Some(err);
212 }
213 Err(err) => {
214 last_err = Some(io::Error::other(format!(
215 "concurrent connect task failed: {err}"
216 )));
217 }
218 }
219 }
220
221 Err(last_err.unwrap_or_else(|| {
222 io::Error::new(
223 io::ErrorKind::InvalidInput,
224 "could not resolve to any address",
225 )
226 }))
227}
228
/// Tries `connect_fn` on each item in order, one at a time, and returns the
/// first success without touching the remaining items.
///
/// # Errors
///
/// * `InvalidInput` ("could not resolve to any address") when `items` is empty.
/// * Otherwise the error returned by the last attempt.
async fn first_success_sequential<T, I, F, Fut>(items: I, connect_fn: F) -> io::Result<T>
where
    I: IntoIterator,
    F: Fn(I::Item) -> Fut,
    Fut: std::future::Future<Output = io::Result<T>>,
{
    // Seeded with the "nothing to try" error; every attempt overwrites it, so
    // after the loop it holds either the first success or the last failure.
    let mut outcome: io::Result<T> = Err(io::Error::new(
        io::ErrorKind::InvalidInput,
        "could not resolve to any address",
    ));

    for candidate in items {
        outcome = connect_fn(candidate).await;
        if outcome.is_ok() {
            break;
        }
    }

    outcome
}
260
261fn should_stagger_connect_attempts(network_options: &NetworkOptions) -> bool {
262 network_options
263 .bind_addr()
264 .is_none_or(|bind_addr| bind_addr.port() == 0)
265}
266
267async fn connect_with_retry_mode<T, I, F, Fut>(
268 items: I,
269 network_options: NetworkOptions,
270 connect_fn: F,
271) -> io::Result<T>
272where
273 T: Send + 'static,
274 I: IntoIterator,
275 I::Item: Send + 'static,
276 F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
277 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
278{
279 connect_with_retry_mode_and_delay(items, network_options, CONNECTION_ATTEMPT_DELAY, connect_fn)
280 .await
281}
282
283async fn connect_with_retry_mode_and_delay<T, I, F, Fut>(
284 items: I,
285 network_options: NetworkOptions,
286 connection_attempt_delay: Duration,
287 connect_fn: F,
288) -> io::Result<T>
289where
290 T: Send + 'static,
291 I: IntoIterator,
292 I::Item: Send + 'static,
293 F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
294 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
295{
296 if should_stagger_connect_attempts(&network_options) {
297 first_success_with_stagger(items, connection_attempt_delay, move |item| {
298 let network_options = network_options.clone();
299 let connect_fn = connect_fn.clone();
300 async move { connect_fn(item, network_options).await }
301 })
302 .await
303 } else {
304 first_success_sequential(items, move |item| {
305 let network_options = network_options.clone();
306 let connect_fn = connect_fn.clone();
307 async move { connect_fn(item, network_options).await }
308 })
309 .await
310 }
311}
312
313async fn connect_resolved_addrs_staggered(
314 addrs: Vec<SocketAddr>,
315 network_options: NetworkOptions,
316) -> io::Result<TcpStream> {
317 connect_with_retry_mode(
318 addrs,
319 network_options,
320 move |addr, network_options| async move {
321 rumqttc_core::connect_socket_addr(addr, network_options).await
322 },
323 )
324 .await
325}
326
327async fn default_socket_connect_staggered(
328 host: String,
329 network_options: NetworkOptions,
330) -> io::Result<TcpStream> {
331 let addrs = lookup_host(host).await?.collect::<Vec<_>>();
332 connect_resolved_addrs_staggered(addrs, network_options).await
333}
334
335fn default_socket_connector() -> SocketConnector {
336 Arc::new(|host, network_options| {
337 Box::pin(async move {
338 let tcp = default_socket_connect_staggered(host, network_options).await?;
339 Ok(Box::new(tcp) as Box<dyn crate::framed::AsyncReadWrite>)
340 })
341 })
342}
343
/// Port used when none is supplied: by `From<&str>` / `From<String>` for
/// `Broker` and when an `mqtt://` or `tcp://` URL omits the port.
const DEFAULT_BROKER_PORT: u16 = 1883;
345
346#[derive(Clone, Debug, PartialEq, Eq)]
348pub struct Broker {
349 inner: BrokerInner,
350}
351
/// Private representation behind [`Broker`].
#[derive(Debug, Clone, Eq, PartialEq)]
enum BrokerInner {
    /// Host name (or address) and port reached over TCP.
    Tcp { host: String, port: u16 },
    /// Filesystem path of a Unix domain socket.
    #[cfg(unix)]
    Unix { path: PathBuf },
    /// Websocket URL, stored exactly as supplied by the caller.
    #[cfg(feature = "websocket")]
    Websocket { url: String },
}
367
368impl Broker {
369 #[must_use]
370 pub fn tcp<S: Into<String>>(host: S, port: u16) -> Self {
371 Self {
372 inner: BrokerInner::Tcp {
373 host: host.into(),
374 port,
375 },
376 }
377 }
378
379 #[cfg(unix)]
380 #[must_use]
381 pub fn unix<P: Into<PathBuf>>(path: P) -> Self {
382 Self {
383 inner: BrokerInner::Unix { path: path.into() },
384 }
385 }
386
387 #[cfg(feature = "websocket")]
388 pub fn websocket<S: Into<String>>(url: S) -> Result<Self, OptionError> {
394 let url = url.into();
395 let uri = url
396 .parse::<http::Uri>()
397 .map_err(|_| OptionError::WebsocketUrl)?;
398
399 match uri.scheme_str() {
400 Some("ws") => {
401 rumqttc_core::split_url(&url).map_err(|_| OptionError::WebsocketUrl)?;
402 Ok(Self {
403 inner: BrokerInner::Websocket { url },
404 })
405 }
406 Some("wss") => Err(OptionError::WssRequiresExplicitTransport),
407 _ => Err(OptionError::Scheme),
408 }
409 }
410
411 #[must_use]
412 pub const fn tcp_address(&self) -> Option<(&str, u16)> {
413 match &self.inner {
414 BrokerInner::Tcp { host, port } => Some((host.as_str(), *port)),
415 #[cfg(unix)]
416 BrokerInner::Unix { .. } => None,
417 #[cfg(feature = "websocket")]
418 BrokerInner::Websocket { .. } => None,
419 }
420 }
421
422 #[cfg(unix)]
423 #[must_use]
424 pub fn unix_path(&self) -> Option<&std::path::Path> {
425 match &self.inner {
426 BrokerInner::Unix { path } => Some(path.as_path()),
427 BrokerInner::Tcp { .. } => None,
428 #[cfg(feature = "websocket")]
429 BrokerInner::Websocket { .. } => None,
430 }
431 }
432
433 #[cfg(feature = "websocket")]
434 #[must_use]
435 pub const fn websocket_url(&self) -> Option<&str> {
436 match &self.inner {
437 BrokerInner::Websocket { url } => Some(url.as_str()),
438 BrokerInner::Tcp { .. } => None,
439 #[cfg(unix)]
440 BrokerInner::Unix { .. } => None,
441 }
442 }
443
444 pub(crate) const fn default_transport(&self) -> Transport {
445 match &self.inner {
446 BrokerInner::Tcp { .. } => Transport::tcp(),
447 #[cfg(unix)]
448 BrokerInner::Unix { .. } => Transport::unix(),
449 #[cfg(feature = "websocket")]
450 BrokerInner::Websocket { .. } => Transport::Ws,
451 }
452 }
453}
454
455impl From<&str> for Broker {
456 fn from(host: &str) -> Self {
457 Self::tcp(host, DEFAULT_BROKER_PORT)
458 }
459}
460
461impl From<String> for Broker {
462 fn from(host: String) -> Self {
463 Self::tcp(host, DEFAULT_BROKER_PORT)
464 }
465}
466
467impl<S: Into<String>> From<(S, u16)> for Broker {
468 fn from((host, port): (S, u16)) -> Self {
469 Self::tcp(host, port)
470 }
471}
472
473#[derive(Clone)]
475pub struct MqttOptions {
476 broker: Broker,
478 transport: Transport,
479 keep_alive: Duration,
481 clean_session: bool,
483 client_id: String,
485 auth: ConnectAuth,
487 max_incoming_packet_size: usize,
489 max_outgoing_packet_size: usize,
491 request_channel_capacity: usize,
493 max_request_batch: usize,
495 read_batch_size: usize,
498 pending_throttle: Duration,
501 inflight: u16,
503 last_will: Option<LastWill>,
505 manual_acks: bool,
509 #[cfg(feature = "proxy")]
510 proxy: Option<Proxy>,
512 #[cfg(feature = "websocket")]
513 request_modifier: Option<RequestModifierFn>,
514 #[cfg(feature = "websocket")]
515 fallible_request_modifier: Option<FallibleRequestModifierFn>,
516 socket_connector: Option<SocketConnector>,
517}
518
519impl MqttOptions {
520 pub fn new<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> Self {
529 let broker = broker.into();
530 Self {
531 transport: broker.default_transport(),
532 broker,
533 keep_alive: Duration::from_secs(60),
534 clean_session: true,
535 client_id: id.into(),
536 auth: ConnectAuth::None,
537 max_incoming_packet_size: 10 * 1024,
538 max_outgoing_packet_size: 10 * 1024,
539 request_channel_capacity: 10,
540 max_request_batch: 0,
541 read_batch_size: 0,
542 pending_throttle: Duration::from_micros(0),
543 inflight: 100,
544 last_will: None,
545 manual_acks: false,
546 #[cfg(feature = "proxy")]
547 proxy: None,
548 #[cfg(feature = "websocket")]
549 request_modifier: None,
550 #[cfg(feature = "websocket")]
551 fallible_request_modifier: None,
552 socket_connector: None,
553 }
554 }
555
556 #[must_use]
566 pub fn builder<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> MqttOptionsBuilder {
567 MqttOptionsBuilder::new(id, broker)
568 }
569
570 #[cfg(feature = "url")]
571 pub fn parse_url<S: Into<String>>(url: S) -> Result<Self, OptionError> {
605 let url = url::Url::parse(&url.into())?;
606 let options = Self::try_from(url)?;
607
608 Ok(options)
609 }
610
611 pub const fn broker(&self) -> &Broker {
613 &self.broker
614 }
615
616 pub fn set_last_will(&mut self, will: LastWill) -> &mut Self {
617 self.last_will = Some(will);
618 self
619 }
620
621 pub fn last_will(&self) -> Option<LastWill> {
622 self.last_will.clone()
623 }
624
625 pub fn set_client_id(&mut self, client_id: String) -> &mut Self {
626 self.client_id = client_id;
627 self
628 }
629
630 #[cfg(not(any(feature = "use-rustls-no-provider", feature = "use-native-tls")))]
631 pub const fn set_transport(&mut self, transport: Transport) -> &mut Self {
632 self.transport = transport;
633 self
634 }
635
636 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
637 pub fn set_transport(&mut self, transport: Transport) -> &mut Self {
638 self.transport = transport;
639 self
640 }
641
642 pub fn transport(&self) -> Transport {
644 self.transport.clone()
645 }
646
647 pub fn set_keep_alive(&mut self, seconds: u16) -> &mut Self {
650 self.keep_alive = Duration::from_secs(u64::from(seconds));
651 self
652 }
653
654 pub const fn keep_alive(&self) -> Duration {
656 self.keep_alive
657 }
658
659 pub fn client_id(&self) -> String {
661 self.client_id.clone()
662 }
663
664 pub const fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self {
666 self.max_incoming_packet_size = incoming;
667 self.max_outgoing_packet_size = outgoing;
668 self
669 }
670
671 pub const fn max_packet_size(&self) -> usize {
673 self.max_incoming_packet_size
674 }
675
676 pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self {
693 assert!(
694 !self.client_id.is_empty() || clean_session,
695 "Cannot unset clean session when client id is empty"
696 );
697 self.clean_session = clean_session;
698 self
699 }
700
701 pub const fn clean_session(&self) -> bool {
703 self.clean_session
704 }
705
706 pub fn set_auth(&mut self, auth: ConnectAuth) -> &mut Self {
717 self.auth = auth;
718 self
719 }
720
721 pub fn clear_auth(&mut self) -> &mut Self {
723 self.auth = ConnectAuth::None;
724 self
725 }
726
727 pub fn set_username<U: Into<String>>(&mut self, username: U) -> &mut Self {
743 self.auth = ConnectAuth::Username {
744 username: username.into(),
745 };
746 self
747 }
748
749 pub fn set_credentials<U: Into<String>, P: Into<Bytes>>(
767 &mut self,
768 username: U,
769 password: P,
770 ) -> &mut Self {
771 self.auth = ConnectAuth::UsernamePassword {
772 username: username.into(),
773 password: password.into(),
774 };
775 self
776 }
777
778 pub const fn auth(&self) -> &ConnectAuth {
792 &self.auth
793 }
794
795 pub const fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self {
797 self.request_channel_capacity = capacity;
798 self
799 }
800
801 pub const fn request_channel_capacity(&self) -> usize {
803 self.request_channel_capacity
804 }
805
806 pub const fn set_max_request_batch(&mut self, max: usize) -> &mut Self {
810 self.max_request_batch = max;
811 self
812 }
813
814 pub const fn max_request_batch(&self) -> usize {
816 self.max_request_batch
817 }
818
819 pub const fn set_read_batch_size(&mut self, size: usize) -> &mut Self {
823 self.read_batch_size = size;
824 self
825 }
826
827 pub const fn read_batch_size(&self) -> usize {
831 self.read_batch_size
832 }
833
834 pub const fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self {
836 self.pending_throttle = duration;
837 self
838 }
839
840 pub const fn pending_throttle(&self) -> Duration {
842 self.pending_throttle
843 }
844
845 pub fn set_inflight(&mut self, inflight: u16) -> &mut Self {
851 assert!(inflight != 0, "zero in flight is not allowed");
852
853 self.inflight = inflight;
854 self
855 }
856
857 pub const fn inflight(&self) -> u16 {
859 self.inflight
860 }
861
862 pub const fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self {
864 self.manual_acks = manual_acks;
865 self
866 }
867
868 pub const fn manual_acks(&self) -> bool {
870 self.manual_acks
871 }
872
873 #[cfg(feature = "proxy")]
874 pub fn set_proxy(&mut self, proxy: Proxy) -> &mut Self {
875 self.proxy = Some(proxy);
876 self
877 }
878
879 #[cfg(feature = "proxy")]
880 pub fn proxy(&self) -> Option<Proxy> {
881 self.proxy.clone()
882 }
883
884 #[cfg(feature = "websocket")]
888 pub fn set_request_modifier<F, O>(&mut self, request_modifier: F) -> &mut Self
889 where
890 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
891 O: IntoFuture<Output = http::Request<()>> + 'static,
892 O::IntoFuture: Send,
893 {
894 self.request_modifier = Some(Arc::new(move |request| {
895 let request_modifier = request_modifier(request).into_future();
896 Box::pin(request_modifier)
897 }));
898 self.fallible_request_modifier = None;
899 self
900 }
901
902 #[cfg(feature = "websocket")]
908 pub fn set_fallible_request_modifier<F, O, E>(&mut self, request_modifier: F) -> &mut Self
909 where
910 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
911 O: IntoFuture<Output = Result<http::Request<()>, E>> + 'static,
912 O::IntoFuture: Send,
913 E: std::error::Error + Send + Sync + 'static,
914 {
915 self.fallible_request_modifier = Some(Arc::new(move |request| {
916 let request_modifier = request_modifier(request).into_future();
917 Box::pin(async move {
918 request_modifier
919 .await
920 .map_err(|error| Box::new(error) as RequestModifierError)
921 })
922 }));
923 self.request_modifier = None;
924 self
925 }
926
927 #[cfg(feature = "websocket")]
928 pub fn request_modifier(&self) -> Option<RequestModifierFn> {
929 self.request_modifier.clone()
930 }
931
932 #[cfg(feature = "websocket")]
933 pub(crate) fn fallible_request_modifier(&self) -> Option<FallibleRequestModifierFn> {
934 self.fallible_request_modifier.clone()
935 }
936
937 pub fn set_socket_connector<F, Fut, S>(&mut self, f: F) -> &mut Self
959 where
960 F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
961 Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
962 S: crate::framed::AsyncReadWrite + 'static,
963 {
964 self.socket_connector = Some(Arc::new(move |host, network_options| {
965 let stream_future = f(host, network_options);
966 let future = async move {
967 let stream = stream_future.await?;
968 Ok(Box::new(stream) as Box<dyn crate::framed::AsyncReadWrite>)
969 };
970 Box::pin(future)
971 }));
972 self
973 }
974
975 pub fn has_socket_connector(&self) -> bool {
977 self.socket_connector.is_some()
978 }
979
980 pub(crate) fn effective_socket_connector(&self) -> SocketConnector {
981 self.socket_connector
982 .clone()
983 .unwrap_or_else(default_socket_connector)
984 }
985
986 pub(crate) async fn socket_connect(
987 &self,
988 host: String,
989 network_options: NetworkOptions,
990 ) -> std::io::Result<Box<dyn crate::framed::AsyncReadWrite>> {
991 let connector = self.effective_socket_connector();
992 connector(host, network_options).await
993 }
994}
995
/// Chainable, by-value builder for [`MqttOptions`].
///
/// Every method forwards to the matching `MqttOptions` setter, so the same
/// panics apply (see `clean_session` and `inflight`).
pub struct MqttOptionsBuilder {
    /// Options being assembled; returned by `build`.
    options: MqttOptions,
}
1000
1001impl MqttOptionsBuilder {
1002 #[must_use]
1004 pub fn new<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> Self {
1005 Self {
1006 options: MqttOptions::new(id, broker),
1007 }
1008 }
1009
1010 #[must_use]
1012 pub fn build(self) -> MqttOptions {
1013 self.options
1014 }
1015
1016 #[must_use]
1018 pub fn last_will(mut self, will: LastWill) -> Self {
1019 self.options.set_last_will(will);
1020 self
1021 }
1022
1023 #[must_use]
1025 pub fn client_id(mut self, client_id: String) -> Self {
1026 self.options.set_client_id(client_id);
1027 self
1028 }
1029
1030 #[cfg(not(any(feature = "use-rustls-no-provider", feature = "use-native-tls")))]
1032 #[must_use]
1033 pub const fn transport(mut self, transport: Transport) -> Self {
1034 self.options.set_transport(transport);
1035 self
1036 }
1037
1038 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
1040 #[must_use]
1041 pub fn transport(mut self, transport: Transport) -> Self {
1042 self.options.set_transport(transport);
1043 self
1044 }
1045
1046 #[must_use]
1048 pub fn keep_alive(mut self, seconds: u16) -> Self {
1049 self.options.set_keep_alive(seconds);
1050 self
1051 }
1052
1053 #[must_use]
1055 pub const fn max_packet_size(mut self, incoming: usize, outgoing: usize) -> Self {
1056 self.options.set_max_packet_size(incoming, outgoing);
1057 self
1058 }
1059
1060 #[must_use]
1066 pub fn clean_session(mut self, clean_session: bool) -> Self {
1067 self.options.set_clean_session(clean_session);
1068 self
1069 }
1070
1071 #[must_use]
1073 pub fn auth(mut self, auth: ConnectAuth) -> Self {
1074 self.options.set_auth(auth);
1075 self
1076 }
1077
1078 #[must_use]
1080 pub fn clear_auth(mut self) -> Self {
1081 self.options.clear_auth();
1082 self
1083 }
1084
1085 #[must_use]
1087 pub fn username<U: Into<String>>(mut self, username: U) -> Self {
1088 self.options.set_username(username);
1089 self
1090 }
1091
1092 #[must_use]
1094 pub fn credentials<U: Into<String>, P: Into<Bytes>>(
1095 mut self,
1096 username: U,
1097 password: P,
1098 ) -> Self {
1099 self.options.set_credentials(username, password);
1100 self
1101 }
1102
1103 #[must_use]
1105 pub const fn request_channel_capacity(mut self, capacity: usize) -> Self {
1106 self.options.set_request_channel_capacity(capacity);
1107 self
1108 }
1109
1110 #[must_use]
1112 pub const fn max_request_batch(mut self, max: usize) -> Self {
1113 self.options.set_max_request_batch(max);
1114 self
1115 }
1116
1117 #[must_use]
1119 pub const fn read_batch_size(mut self, size: usize) -> Self {
1120 self.options.set_read_batch_size(size);
1121 self
1122 }
1123
1124 #[must_use]
1126 pub const fn pending_throttle(mut self, duration: Duration) -> Self {
1127 self.options.set_pending_throttle(duration);
1128 self
1129 }
1130
1131 #[must_use]
1137 pub fn inflight(mut self, inflight: u16) -> Self {
1138 self.options.set_inflight(inflight);
1139 self
1140 }
1141
1142 #[must_use]
1144 pub const fn manual_acks(mut self, manual_acks: bool) -> Self {
1145 self.options.set_manual_acks(manual_acks);
1146 self
1147 }
1148
1149 #[cfg(feature = "proxy")]
1151 #[must_use]
1152 pub fn proxy(mut self, proxy: Proxy) -> Self {
1153 self.options.set_proxy(proxy);
1154 self
1155 }
1156
1157 #[cfg(feature = "websocket")]
1159 #[must_use]
1160 pub fn request_modifier<F, O>(mut self, request_modifier: F) -> Self
1161 where
1162 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
1163 O: IntoFuture<Output = http::Request<()>> + 'static,
1164 O::IntoFuture: Send,
1165 {
1166 self.options.set_request_modifier(request_modifier);
1167 self
1168 }
1169
1170 #[cfg(feature = "websocket")]
1172 #[must_use]
1173 pub fn fallible_request_modifier<F, O, E>(mut self, request_modifier: F) -> Self
1174 where
1175 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
1176 O: IntoFuture<Output = Result<http::Request<()>, E>> + 'static,
1177 O::IntoFuture: Send,
1178 E: std::error::Error + Send + Sync + 'static,
1179 {
1180 self.options.set_fallible_request_modifier(request_modifier);
1181 self
1182 }
1183
1184 #[must_use]
1186 pub fn socket_connector<F, Fut, S>(mut self, f: F) -> Self
1187 where
1188 F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
1189 Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
1190 S: crate::framed::AsyncReadWrite + 'static,
1191 {
1192 self.options.set_socket_connector(f);
1193 self
1194 }
1195}
1196
1197#[derive(Debug, PartialEq, Eq, thiserror::Error)]
1198pub enum OptionError {
1199 #[error("Unsupported URL scheme.")]
1200 Scheme,
1201
1202 #[error(
1203 "Secure MQTT URL schemes require explicit TLS transport configuration via MqttOptions::set_transport(...)."
1204 )]
1205 SecureUrlRequiresExplicitTransport,
1206
1207 #[error("Missing client ID.")]
1208 ClientId,
1209
1210 #[error("Invalid Unix socket path.")]
1211 UnixSocketPath,
1212
1213 #[cfg(feature = "websocket")]
1214 #[error("Invalid websocket url.")]
1215 WebsocketUrl,
1216
1217 #[cfg(feature = "websocket")]
1218 #[error(
1219 "Secure websocket URLs require Broker::websocket(\"ws://...\") plus MqttOptions::set_transport(Transport::wss_with_config(...))."
1220 )]
1221 WssRequiresExplicitTransport,
1222
1223 #[error("Invalid keep-alive value.")]
1224 KeepAlive,
1225
1226 #[error("Invalid clean-session value.")]
1227 CleanSession,
1228
1229 #[error("Invalid max-incoming-packet-size value.")]
1230 MaxIncomingPacketSize,
1231
1232 #[error("Invalid max-outgoing-packet-size value.")]
1233 MaxOutgoingPacketSize,
1234
1235 #[error("Invalid request-channel-capacity value.")]
1236 RequestChannelCapacity,
1237
1238 #[error("Invalid max-request-batch value.")]
1239 MaxRequestBatch,
1240
1241 #[error("Invalid read-batch-size value.")]
1242 ReadBatchSize,
1243
1244 #[error("Invalid pending-throttle value.")]
1245 PendingThrottle,
1246
1247 #[error("Invalid inflight value.")]
1248 Inflight,
1249
1250 #[error("Unknown option: {0}")]
1251 Unknown(String),
1252
1253 #[cfg(feature = "url")]
1254 #[error("Couldn't parse option from url: {0}")]
1255 Parse(#[from] url::ParseError),
1256}
1257
/// Builds [`MqttOptions`] from a URL.
///
/// Supported schemes are `mqtt`/`tcp`, `unix` (Unix only) and `ws` (with the
/// `websocket` feature). Options are read from the query string:
/// `client_id` (required), `keep_alive_secs`, `clean_session`,
/// `max_incoming_packet_size_bytes`, `max_outgoing_packet_size_bytes`,
/// `request_channel_capacity_num`, `max_request_batch_num`,
/// `read_batch_size_num`, `pending_throttle_usecs` and `inflight_num`.
/// Credentials are taken from the URL's user-info part.
#[cfg(feature = "url")]
impl std::convert::TryFrom<url::Url> for MqttOptions {
    type Error = OptionError;

    /// Converts `url` into options.
    ///
    /// This conversion never panics on malformed input: combinations that the
    /// setters reject with an assertion are reported as errors instead.
    ///
    /// # Errors
    ///
    /// * [`OptionError::SecureUrlRequiresExplicitTransport`] for `mqtts`/`ssl`,
    ///   `OptionError::WssRequiresExplicitTransport` for `wss`, and
    ///   [`OptionError::Scheme`] for any other unsupported scheme.
    /// * [`OptionError::ClientId`] if `client_id` is missing.
    /// * The per-option error if a value fails to parse.
    /// * [`OptionError::CleanSession`] if `clean_session=false` is combined
    ///   with an empty `client_id`.
    /// * [`OptionError::Inflight`] if `inflight_num` is `0`.
    /// * [`OptionError::Unknown`] naming the alphabetically first unrecognised
    ///   query parameter.
    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        use std::collections::HashMap;

        let broker = match url.scheme() {
            "mqtts" | "ssl" => return Err(OptionError::SecureUrlRequiresExplicitTransport),
            "mqtt" | "tcp" => Broker::tcp(
                url.host_str().unwrap_or_default(),
                url.port().unwrap_or(DEFAULT_BROKER_PORT),
            ),
            #[cfg(unix)]
            "unix" => Broker::unix(parse_unix_socket_path(&url)?),
            #[cfg(feature = "websocket")]
            "ws" => Broker::websocket(url.as_str().to_owned())?,
            #[cfg(feature = "websocket")]
            "wss" => return Err(OptionError::WssRequiresExplicitTransport),
            _ => return Err(OptionError::Scheme),
        };

        // Recognised keys are removed as they are consumed; whatever is left
        // at the end is an unknown option.
        let mut queries = url.query_pairs().collect::<HashMap<_, _>>();

        let id = queries
            .remove("client_id")
            .ok_or(OptionError::ClientId)?
            .into_owned();

        let mut options = Self::new(id, broker);

        if let Some(keep_alive) = queries
            .remove("keep_alive_secs")
            .map(|v| v.parse::<u16>().map_err(|_| OptionError::KeepAlive))
            .transpose()?
        {
            options.set_keep_alive(keep_alive);
        }

        if let Some(clean_session) = queries
            .remove("clean_session")
            .map(|v| v.parse::<bool>().map_err(|_| OptionError::CleanSession))
            .transpose()?
        {
            // `set_clean_session` asserts this invariant; a URL is untrusted
            // input, so report the violation instead of panicking.
            if !clean_session && options.client_id.is_empty() {
                return Err(OptionError::CleanSession);
            }
            options.set_clean_session(clean_session);
        }

        set_url_credentials(&mut options, &url);

        let incoming = queries
            .remove("max_incoming_packet_size_bytes")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::MaxIncomingPacketSize)
            })
            .transpose()?;
        let outgoing = queries
            .remove("max_outgoing_packet_size_bytes")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::MaxOutgoingPacketSize)
            })
            .transpose()?;
        // Honour either limit on its own; the one not given keeps its current
        // (default) value rather than causing the other to be dropped.
        if incoming.is_some() || outgoing.is_some() {
            let incoming = incoming.unwrap_or(options.max_incoming_packet_size);
            let outgoing = outgoing.unwrap_or(options.max_outgoing_packet_size);
            options.set_max_packet_size(incoming, outgoing);
        }

        if let Some(request_channel_capacity) = queries
            .remove("request_channel_capacity_num")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::RequestChannelCapacity)
            })
            .transpose()?
        {
            options.request_channel_capacity = request_channel_capacity;
        }

        if let Some(max_request_batch) = queries
            .remove("max_request_batch_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::MaxRequestBatch))
            .transpose()?
        {
            options.max_request_batch = max_request_batch;
        }

        if let Some(read_batch_size) = queries
            .remove("read_batch_size_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::ReadBatchSize))
            .transpose()?
        {
            options.read_batch_size = read_batch_size;
        }

        if let Some(pending_throttle) = queries
            .remove("pending_throttle_usecs")
            .map(|v| v.parse::<u64>().map_err(|_| OptionError::PendingThrottle))
            .transpose()?
        {
            options.set_pending_throttle(Duration::from_micros(pending_throttle));
        }

        if let Some(inflight) = queries
            .remove("inflight_num")
            .map(|v| v.parse::<u16>().map_err(|_| OptionError::Inflight))
            .transpose()?
        {
            // `set_inflight` asserts a non-zero value; reject it here instead.
            if inflight == 0 {
                return Err(OptionError::Inflight);
            }
            options.set_inflight(inflight);
        }

        // Pick the smallest leftover key so the reported option does not
        // depend on hash-map iteration order.
        if let Some(opt) = queries.into_keys().min() {
            return Err(OptionError::Unknown(opt.into_owned()));
        }

        Ok(options)
    }
}
1376
/// Copies the user-info part of `url` into `options`.
///
/// * password present → username + password (the username may be empty);
/// * only a non-empty username → username-only authentication;
/// * neither → `options` is left untouched.
///
/// NOTE(review): the `url` crate documents `username()` / `password()` as
/// returning percent-encoded text, so reserved characters appear to be
/// forwarded still encoded — confirm whether decoding is wanted here.
#[cfg(feature = "url")]
fn set_url_credentials(options: &mut MqttOptions, url: &url::Url) {
    match (url.username(), url.password()) {
        (username, Some(password)) => {
            options.set_credentials(username, password.to_owned());
        }
        ("", None) => {}
        (username, None) => {
            options.set_username(username);
        }
    }
}
1386
/// Extracts the socket path from a `unix:` URL.
///
/// The URL path is percent-decoded to raw bytes and turned into a `PathBuf`
/// without requiring valid UTF-8.
///
/// # Errors
///
/// [`OptionError::UnixSocketPath`] if the URL has a host component, or if the
/// decoded path is empty or just `/`.
#[cfg(all(feature = "url", unix))]
fn parse_unix_socket_path(url: &url::Url) -> Result<PathBuf, OptionError> {
    if url.host_str().is_some() {
        return Err(OptionError::UnixSocketPath);
    }

    let decoded: Vec<u8> = percent_decode_str(url.path()).collect();
    let names_no_socket = matches!(decoded.as_slice(), [] | [b'/']);
    if names_no_socket {
        Err(OptionError::UnixSocketPath)
    } else {
        Ok(PathBuf::from(OsString::from_vec(decoded)))
    }
}
1400
1401impl Debug for MqttOptions {
1404 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
1405 f.debug_struct("MqttOptions")
1406 .field("broker", &self.broker)
1407 .field("keep_alive", &self.keep_alive)
1408 .field("clean_session", &self.clean_session)
1409 .field("client_id", &self.client_id)
1410 .field("auth", &self.auth)
1411 .field("max_packet_size", &self.max_incoming_packet_size)
1412 .field("request_channel_capacity", &self.request_channel_capacity)
1413 .field("max_request_batch", &self.max_request_batch)
1414 .field("read_batch_size", &self.read_batch_size)
1415 .field("pending_throttle", &self.pending_throttle)
1416 .field("inflight", &self.inflight)
1417 .field("last_will", &self.last_will)
1418 .field("manual_acks", &self.manual_acks)
1419 .finish_non_exhaustive()
1420 }
1421}
1422
1423#[cfg(test)]
1424mod test {
1425 use super::*;
1426 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
1427 use std::sync::Arc;
1428 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
1429 use tokio::net::{TcpListener, TcpSocket};
1430 use tokio::runtime::Builder;
1431 use tokio::sync::Notify;
1432
1433 fn runtime() -> tokio::runtime::Runtime {
1434 Builder::new_current_thread().enable_all().build().unwrap()
1435 }
1436
1437 #[test]
1438 fn staggered_attempts_allow_later_success_to_win() {
1439 runtime().block_on(async {
1440 let started = Arc::new(AtomicUsize::new(0));
1441 let started_for_connect = Arc::clone(&started);
1442 let begin = std::time::Instant::now();
1443
1444 let result = first_success_with_stagger(
1445 [0_u8, 1_u8],
1446 std::time::Duration::from_millis(10),
1447 move |attempt| {
1448 let started = Arc::clone(&started_for_connect);
1449 async move {
1450 started.fetch_add(1, Ordering::SeqCst);
1451 if attempt == 0 {
1452 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1453 Err(std::io::Error::other("slow failure"))
1454 } else {
1455 Ok(42_u8)
1456 }
1457 }
1458 },
1459 )
1460 .await
1461 .unwrap();
1462
1463 assert_eq!(result, 42);
1464 assert_eq!(started.load(Ordering::SeqCst), 2);
1465 assert!(begin.elapsed() < std::time::Duration::from_millis(150));
1466 });
1467 }
1468
1469 #[test]
1470 fn staggered_connect_returns_invalid_input_for_empty_candidates() {
1471 runtime().block_on(async {
1472 let err = connect_resolved_addrs_staggered(Vec::new(), NetworkOptions::new())
1473 .await
1474 .unwrap_err();
1475
1476 assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
1477 assert_eq!(err.to_string(), "could not resolve to any address");
1478 });
1479 }
1480
1481 #[test]
1482 fn staggered_connect_tries_later_candidates() {
1483 runtime().block_on(async {
1484 let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1485 let good_addr = listener.local_addr().unwrap();
1486
1487 let unused_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1488 let bad_addr = unused_listener.local_addr().unwrap();
1489 drop(unused_listener);
1490
1491 let accept_task = tokio::spawn(async move {
1492 let (_stream, _) = listener.accept().await.unwrap();
1493 });
1494
1495 let stream =
1496 connect_resolved_addrs_staggered(vec![bad_addr, good_addr], NetworkOptions::new())
1497 .await
1498 .unwrap();
1499 assert_eq!(stream.peer_addr().unwrap(), good_addr);
1500
1501 accept_task.await.unwrap();
1502 });
1503 }
1504
1505 #[test]
1506 fn fixed_bind_port_retry_mode_keeps_slow_first_candidate_alive() {
1507 runtime().block_on(async {
1508 let reserved = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1509 let bind_port = reserved.local_addr().unwrap().port();
1510 drop(reserved);
1511
1512 let mut network_options = NetworkOptions::new();
1513 network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(
1514 Ipv4Addr::LOCALHOST,
1515 bind_port,
1516 )));
1517
1518 let first_attempt_started = Arc::new(Notify::new());
1519 let second_attempt_started = Arc::new(AtomicBool::new(false));
1520
1521 let mut connect_task = tokio::spawn({
1522 let first_attempt_started = Arc::clone(&first_attempt_started);
1523 let second_attempt_started = Arc::clone(&second_attempt_started);
1524 let network_options = network_options.clone();
1525 async move {
1526 connect_with_retry_mode_and_delay(
1527 [0_u8, 1_u8],
1528 network_options,
1529 Duration::from_millis(10),
1530 move |attempt, network_options| {
1531 let first_attempt_started = Arc::clone(&first_attempt_started);
1532 let second_attempt_started = Arc::clone(&second_attempt_started);
1533 async move {
1534 if attempt == 0 {
1535 let bind_addr = network_options.bind_addr().unwrap();
1536 let socket = match bind_addr {
1537 SocketAddr::V4(_) => TcpSocket::new_v4()?,
1538 SocketAddr::V6(_) => TcpSocket::new_v6()?,
1539 };
1540 socket.bind(bind_addr)?;
1541 first_attempt_started.notify_one();
1542 std::future::pending::<io::Result<()>>().await
1543 } else {
1544 second_attempt_started.store(true, Ordering::SeqCst);
1545 let _ = network_options;
1546 Ok(())
1547 }
1548 }
1549 },
1550 )
1551 .await
1552 }
1553 });
1554
1555 first_attempt_started.notified().await;
1556
1557 assert!(
1558 tokio::time::timeout(Duration::from_millis(50), &mut connect_task)
1559 .await
1560 .is_err(),
1561 "fixed-port dialing should keep the first slow candidate alive instead of capping it to the stagger delay"
1562 );
1563 assert!(
1564 !second_attempt_started.load(Ordering::SeqCst),
1565 "fixed-port dialing should not start later same-family candidates while the first is still pending"
1566 );
1567 connect_task.abort();
1568 });
1569 }
1570
1571 #[test]
1572 fn fixed_bind_port_resolved_addrs_try_later_candidates() {
1573 runtime().block_on(async {
1574 let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1575 let good_addr = listener.local_addr().unwrap();
1576
1577 let unused_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1578 let bad_addr = unused_listener.local_addr().unwrap();
1579 drop(unused_listener);
1580
1581 let reserved = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1582 let bind_port = reserved.local_addr().unwrap().port();
1583 drop(reserved);
1584
1585 let mut network_options = NetworkOptions::new();
1586 network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(
1587 Ipv4Addr::LOCALHOST,
1588 bind_port,
1589 )));
1590
1591 let accept_task = tokio::spawn(async move {
1592 let (stream, peer_addr) = listener.accept().await.unwrap();
1593 drop(stream);
1594 peer_addr
1595 });
1596
1597 let stream =
1598 connect_resolved_addrs_staggered(vec![bad_addr, good_addr], network_options)
1599 .await
1600 .unwrap();
1601 assert_eq!(stream.peer_addr().unwrap(), good_addr);
1602 drop(stream);
1603
1604 let peer_addr = accept_task.await.unwrap();
1605 assert_eq!(peer_addr.port(), bind_port);
1606 assert!(peer_addr.ip().is_loopback());
1607 });
1608 }
1609
1610 #[test]
1611 fn socket_connect_uses_custom_connector_over_default() {
1612 runtime().block_on(async {
1613 let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
1614 let good_addr = listener.local_addr().unwrap();
1615 let used_custom = Arc::new(AtomicUsize::new(0));
1616 let used_custom_for_connector = Arc::clone(&used_custom);
1617
1618 let accept_task = tokio::spawn(async move {
1619 let (_stream, _) = listener.accept().await.unwrap();
1620 });
1621
1622 let mut options = MqttOptions::new("test-client", "localhost");
1623 options.set_socket_connector(move |_host, _network_options| {
1624 let used_custom = Arc::clone(&used_custom_for_connector);
1625 async move {
1626 used_custom.fetch_add(1, Ordering::SeqCst);
1627 TcpStream::connect(good_addr).await
1628 }
1629 });
1630
1631 assert!(options.has_socket_connector());
1632 options
1633 .socket_connect("invalid.invalid:1883".to_owned(), NetworkOptions::new())
1634 .await
1635 .unwrap();
1636
1637 assert_eq!(used_custom.load(Ordering::SeqCst), 1);
1638 accept_task.await.unwrap();
1639 });
1640 }
1641
1642 #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
1643 mod request_modifier_tests {
1644 use super::{Broker, MqttOptions};
1645
1646 #[derive(Debug)]
1647 struct TestError;
1648
1649 impl std::fmt::Display for TestError {
1650 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1651 write!(f, "test error")
1652 }
1653 }
1654
1655 impl std::error::Error for TestError {}
1656
1657 #[test]
1658 fn infallible_modifier_is_set() {
1659 let mut options = MqttOptions::new(
1660 "test",
1661 Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
1662 );
1663 options.set_request_modifier(|req| async move { req });
1664 assert!(options.request_modifier().is_some());
1665 assert!(options.fallible_request_modifier().is_none());
1666 }
1667
1668 #[test]
1669 fn fallible_modifier_is_set() {
1670 let mut options = MqttOptions::new(
1671 "test",
1672 Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
1673 );
1674 options.set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });
1675 assert!(options.request_modifier().is_none());
1676 assert!(options.fallible_request_modifier().is_some());
1677 }
1678
1679 #[test]
1680 fn last_setter_call_wins() {
1681 let mut options = MqttOptions::new(
1682 "test",
1683 Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
1684 );
1685
1686 options
1687 .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) })
1688 .set_request_modifier(|req| async move { req });
1689 assert!(options.request_modifier().is_some());
1690 assert!(options.fallible_request_modifier().is_none());
1691
1692 options
1693 .set_request_modifier(|req| async move { req })
1694 .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });
1695 assert!(options.request_modifier().is_none());
1696 assert!(options.fallible_request_modifier().is_some());
1697 }
1698 }
1699
1700 #[test]
1701 #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
1702 fn websocket_transport_can_be_explicitly_upgraded_to_wss() {
1703 let broker = Broker::websocket(
1704 "ws://a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host",
1705 )
1706 .expect("valid websocket broker");
1707 let mut mqttoptions = MqttOptions::new("client_a", broker);
1708
1709 assert!(matches!(mqttoptions.transport(), crate::Transport::Ws));
1710 mqttoptions.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None));
1711
1712 if let crate::Transport::Wss(TlsConfiguration::Simple {
1713 ca,
1714 client_auth,
1715 alpn,
1716 }) = mqttoptions.transport()
1717 {
1718 assert_eq!(ca.as_slice(), b"Test CA");
1719 assert_eq!(client_auth, None);
1720 assert_eq!(alpn, None);
1721 } else {
1722 panic!("Unexpected transport!");
1723 }
1724
1725 assert_eq!(
1726 mqttoptions.broker().websocket_url(),
1727 Some(
1728 "ws://a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"
1729 )
1730 );
1731 }
1732
1733 #[test]
1734 #[cfg(feature = "websocket")]
1735 fn wss_websocket_urls_require_explicit_transport() {
1736 assert_eq!(
1737 Broker::websocket("wss://example.com/mqtt"),
1738 Err(OptionError::WssRequiresExplicitTransport)
1739 );
1740 }
1741
1742 #[test]
1743 #[cfg(all(
1744 feature = "url",
1745 feature = "use-rustls-no-provider",
1746 feature = "websocket"
1747 ))]
1748 fn parse_url_ws_transport_can_be_explicitly_upgraded_to_wss() {
1749 let mut mqttoptions =
1750 MqttOptions::parse_url("ws://example.com:443/mqtt?client_id=client_a")
1751 .expect("valid websocket options");
1752
1753 assert!(matches!(mqttoptions.transport(), crate::Transport::Ws));
1754 mqttoptions.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None));
1755
1756 if let crate::Transport::Wss(TlsConfiguration::Simple {
1757 ca,
1758 client_auth,
1759 alpn,
1760 }) = mqttoptions.transport()
1761 {
1762 assert_eq!(ca.as_slice(), b"Test CA");
1763 assert_eq!(client_auth, None);
1764 assert_eq!(alpn, None);
1765 } else {
1766 panic!("Unexpected transport!");
1767 }
1768 }
1769
1770 #[test]
1771 #[cfg(all(feature = "url", feature = "use-rustls-no-provider"))]
1772 fn parse_url_mqtt_transport_can_be_explicitly_upgraded_to_tls() {
1773 let mut mqttoptions = MqttOptions::parse_url("mqtt://example.com:8883?client_id=client_a")
1774 .expect("valid tls options");
1775
1776 assert!(matches!(mqttoptions.transport(), crate::Transport::Tcp));
1777 mqttoptions.set_transport(crate::Transport::tls(Vec::from("Test CA"), None, None));
1778
1779 if let crate::Transport::Tls(TlsConfiguration::Simple {
1780 ca,
1781 client_auth,
1782 alpn,
1783 }) = mqttoptions.transport()
1784 {
1785 assert_eq!(ca.as_slice(), b"Test CA");
1786 assert_eq!(client_auth, None);
1787 assert_eq!(alpn, None);
1788 } else {
1789 panic!("Unexpected transport!");
1790 }
1791 }
1792
1793 #[test]
1794 #[cfg(feature = "url")]
1795 fn parse_url_rejects_secure_url_schemes() {
1796 assert!(matches!(
1797 MqttOptions::parse_url("mqtts://example.com:8883?client_id=client_a"),
1798 Err(OptionError::SecureUrlRequiresExplicitTransport)
1799 ));
1800 assert!(matches!(
1801 MqttOptions::parse_url("ssl://example.com:8883?client_id=client_a"),
1802 Err(OptionError::SecureUrlRequiresExplicitTransport)
1803 ));
1804
1805 #[cfg(feature = "websocket")]
1806 assert!(matches!(
1807 MqttOptions::parse_url("wss://example.com:443/mqtt?client_id=client_a"),
1808 Err(OptionError::WssRequiresExplicitTransport)
1809 ));
1810 }
1811
1812 #[test]
1813 #[cfg(feature = "url")]
1814 fn from_url() {
1815 fn opt(s: &str) -> Result<MqttOptions, OptionError> {
1816 MqttOptions::parse_url(s)
1817 }
1818 fn ok(s: &str) -> MqttOptions {
1819 opt(s).expect("valid options")
1820 }
1821 fn err(s: &str) -> OptionError {
1822 opt(s).expect_err("invalid options")
1823 }
1824
1825 let v = ok("mqtt://host:42?client_id=foo");
1826 assert_eq!(v.broker().tcp_address(), Some(("host", 42)));
1827 assert_eq!(v.client_id(), "foo".to_owned());
1828
1829 let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5");
1830 assert_eq!(v.keep_alive, Duration::from_secs(5));
1831 let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=0");
1832 assert_eq!(v.keep_alive, Duration::from_secs(0));
1833 let v = ok("mqtt://host:42?client_id=foo&read_batch_size_num=32");
1834 assert_eq!(v.read_batch_size(), 32);
1835 let v = ok("mqtt://user@host:42?client_id=foo");
1836 assert_eq!(
1837 v.auth(),
1838 &ConnectAuth::Username {
1839 username: "user".to_owned(),
1840 }
1841 );
1842 let v = ok("mqtt://user:pw@host:42?client_id=foo");
1843 assert_eq!(
1844 v.auth(),
1845 &ConnectAuth::UsernamePassword {
1846 username: "user".to_owned(),
1847 password: Bytes::from_static(b"pw"),
1848 }
1849 );
1850 let v = ok("mqtt://:pw@host:42?client_id=foo");
1851 assert_eq!(
1852 v.auth(),
1853 &ConnectAuth::UsernamePassword {
1854 username: String::new(),
1855 password: Bytes::from_static(b"pw"),
1856 }
1857 );
1858
1859 assert_eq!(err("mqtt://host:42"), OptionError::ClientId);
1860 assert_eq!(
1861 err("mqtt://host:42?client_id=foo&foo=bar"),
1862 OptionError::Unknown("foo".to_owned())
1863 );
1864 assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme);
1865 assert_eq!(
1866 err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"),
1867 OptionError::KeepAlive
1868 );
1869 assert_eq!(
1870 err("mqtt://host:42?client_id=foo&keep_alive_secs=65536"),
1871 OptionError::KeepAlive
1872 );
1873 assert_eq!(
1874 err("mqtt://host:42?client_id=foo&clean_session=foo"),
1875 OptionError::CleanSession
1876 );
1877 assert_eq!(
1878 err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"),
1879 OptionError::MaxIncomingPacketSize
1880 );
1881 assert_eq!(
1882 err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"),
1883 OptionError::MaxOutgoingPacketSize
1884 );
1885 assert_eq!(
1886 err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"),
1887 OptionError::RequestChannelCapacity
1888 );
1889 assert_eq!(
1890 err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"),
1891 OptionError::MaxRequestBatch
1892 );
1893 assert_eq!(
1894 err("mqtt://host:42?client_id=foo&read_batch_size_num=foo"),
1895 OptionError::ReadBatchSize
1896 );
1897 assert_eq!(
1898 err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"),
1899 OptionError::PendingThrottle
1900 );
1901 assert_eq!(
1902 err("mqtt://host:42?client_id=foo&inflight_num=foo"),
1903 OptionError::Inflight
1904 );
1905 }
1906
1907 #[test]
1908 #[cfg(unix)]
1909 fn unix_broker_sets_unix_transport_and_preserves_defaults() {
1910 let options = MqttOptions::new("client_id", Broker::unix("/tmp/mqtt.sock"));
1911 let baseline = MqttOptions::new("client_id", "127.0.0.1");
1912
1913 assert!(matches!(options.transport(), Transport::Unix));
1914 assert_eq!(
1915 options.broker().unix_path(),
1916 Some(std::path::Path::new("/tmp/mqtt.sock"))
1917 );
1918 assert_eq!(options.keep_alive, baseline.keep_alive);
1919 assert_eq!(options.clean_session, baseline.clean_session);
1920 assert_eq!(options.client_id, baseline.client_id);
1921 assert_eq!(
1922 options.max_incoming_packet_size,
1923 baseline.max_incoming_packet_size
1924 );
1925 assert_eq!(
1926 options.max_outgoing_packet_size,
1927 baseline.max_outgoing_packet_size
1928 );
1929 assert_eq!(
1930 options.request_channel_capacity,
1931 baseline.request_channel_capacity
1932 );
1933 assert_eq!(options.max_request_batch, baseline.max_request_batch);
1934 assert_eq!(options.read_batch_size, baseline.read_batch_size);
1935 assert_eq!(options.pending_throttle, baseline.pending_throttle);
1936 assert_eq!(options.inflight, baseline.inflight);
1937 assert_eq!(options.manual_acks, baseline.manual_acks);
1938 }
1939
1940 #[test]
1941 #[cfg(all(feature = "url", unix))]
1942 fn from_url_supports_unix_socket_paths() {
1943 let options = MqttOptions::parse_url(
1944 "unix:///tmp/mqtt.sock?client_id=foo&keep_alive_secs=5&read_batch_size_num=32",
1945 )
1946 .expect("valid unix socket options");
1947
1948 assert!(matches!(options.transport(), Transport::Unix));
1949 assert_eq!(
1950 options.broker().unix_path(),
1951 Some(std::path::Path::new("/tmp/mqtt.sock"))
1952 );
1953 assert_eq!(options.client_id(), "foo");
1954 assert_eq!(options.keep_alive, Duration::from_secs(5));
1955 assert_eq!(options.read_batch_size(), 32);
1956 }
1957
1958 #[test]
1959 #[cfg(all(feature = "url", unix))]
1960 fn from_url_decodes_percent_escaped_unix_socket_paths() {
1961 let options =
1962 MqttOptions::parse_url("unix:///tmp/mqtt%20broker.sock?client_id=foo").unwrap();
1963
1964 assert_eq!(
1965 options.broker().unix_path(),
1966 Some(std::path::Path::new("/tmp/mqtt broker.sock"))
1967 );
1968 }
1969
1970 #[test]
1971 #[cfg(all(feature = "url", unix))]
1972 fn from_url_preserves_percent_decoded_unix_socket_bytes() {
1973 use std::os::unix::ffi::OsStrExt;
1974
1975 let options = MqttOptions::parse_url("unix:///tmp/mqtt%FF.sock?client_id=foo").unwrap();
1976
1977 assert_eq!(
1978 options.broker().unix_path().unwrap().as_os_str().as_bytes(),
1979 b"/tmp/mqtt\xff.sock"
1980 );
1981 }
1982
1983 #[test]
1984 #[cfg(all(feature = "url", unix))]
1985 fn from_url_rejects_invalid_unix_socket_paths() {
1986 fn err(s: &str) -> OptionError {
1987 MqttOptions::parse_url(s).expect_err("invalid unix socket url")
1988 }
1989
1990 assert_eq!(err("unix:///tmp/mqtt.sock"), OptionError::ClientId);
1991 assert_eq!(
1992 err("unix://localhost/tmp/mqtt.sock?client_id=foo"),
1993 OptionError::UnixSocketPath
1994 );
1995 assert_eq!(err("unix:///?client_id=foo"), OptionError::UnixSocketPath);
1996 }
1997
1998 #[test]
1999 fn accept_empty_client_id() {
2000 let _mqtt_opts = MqttOptions::new("", "127.0.0.1").set_clean_session(true);
2001 }
2002
2003 #[test]
2004 fn mqtt_options_builder_matches_setter_configuration() {
2005 let will = LastWill::new("hello/world", "good bye", QoS::AtLeastOnce, false);
2006 let mut expected = MqttOptions::new("client", ("localhost", 1884));
2007 expected
2008 .set_keep_alive(5)
2009 .set_last_will(will.clone())
2010 .set_clean_session(false)
2011 .set_credentials("user", Bytes::from_static(b"password"))
2012 .set_request_channel_capacity(16)
2013 .set_max_request_batch(8)
2014 .set_read_batch_size(32)
2015 .set_pending_throttle(Duration::from_micros(250))
2016 .set_inflight(4)
2017 .set_manual_acks(true)
2018 .set_max_packet_size(4096, 2048);
2019
2020 let actual = MqttOptions::builder("client", ("localhost", 1884))
2021 .keep_alive(5)
2022 .last_will(will)
2023 .clean_session(false)
2024 .credentials("user", Bytes::from_static(b"password"))
2025 .request_channel_capacity(16)
2026 .max_request_batch(8)
2027 .read_batch_size(32)
2028 .pending_throttle(Duration::from_micros(250))
2029 .inflight(4)
2030 .manual_acks(true)
2031 .max_packet_size(4096, 2048)
2032 .build();
2033
2034 assert_eq!(
2035 actual.broker().tcp_address(),
2036 expected.broker().tcp_address()
2037 );
2038 assert_eq!(actual.keep_alive(), expected.keep_alive());
2039 assert_eq!(actual.last_will(), expected.last_will());
2040 assert_eq!(actual.clean_session(), expected.clean_session());
2041 assert_eq!(actual.auth(), expected.auth());
2042 assert_eq!(
2043 actual.request_channel_capacity(),
2044 expected.request_channel_capacity()
2045 );
2046 assert_eq!(actual.max_request_batch(), expected.max_request_batch());
2047 assert_eq!(actual.read_batch_size(), expected.read_batch_size());
2048 assert_eq!(actual.pending_throttle(), expected.pending_throttle());
2049 assert_eq!(actual.inflight(), expected.inflight());
2050 assert_eq!(actual.manual_acks(), expected.manual_acks());
2051 assert_eq!(
2052 actual.max_incoming_packet_size,
2053 expected.max_incoming_packet_size
2054 );
2055 assert_eq!(
2056 actual.max_outgoing_packet_size,
2057 expected.max_outgoing_packet_size
2058 );
2059 }
2060
2061 #[test]
2062 fn mqtt_options_builder_can_replace_and_clear_auth() {
2063 let actual = MqttOptions::builder("client", "localhost")
2064 .username("user")
2065 .clear_auth()
2066 .auth(ConnectAuth::Username {
2067 username: "next".to_owned(),
2068 })
2069 .build();
2070
2071 assert_eq!(
2072 actual.auth(),
2073 &ConnectAuth::Username {
2074 username: "next".to_owned(),
2075 }
2076 );
2077 }
2078
2079 #[test]
2080 fn mqtt_options_builder_request_capacity_feeds_client_builder_default() {
2081 let mqttoptions = MqttOptions::builder("test-1", "localhost")
2082 .request_channel_capacity(1)
2083 .build();
2084 let (client, _eventloop) = AsyncClient::builder(mqttoptions).build();
2085
2086 client
2087 .try_publish("hello/world", QoS::AtMostOnce, false, "one")
2088 .expect("first request should fit configured capacity");
2089 assert!(matches!(
2090 client.try_publish("hello/world", QoS::AtMostOnce, false, "two"),
2091 Err(ClientError::TryRequest(Request::Publish(_)))
2092 ));
2093 }
2094
2095 #[test]
2096 fn set_clean_session_when_client_id_present() {
2097 let mut options = MqttOptions::new("client_id", "127.0.0.1");
2098 options.set_clean_session(false);
2099 options.set_clean_session(true);
2100 }
2101
2102 #[test]
2103 fn read_batch_size_defaults_to_adaptive() {
2104 let options = MqttOptions::new("client_id", "127.0.0.1");
2105 assert_eq!(options.read_batch_size(), 0);
2106 }
2107
2108 #[test]
2109 fn set_read_batch_size() {
2110 let mut options = MqttOptions::new("client_id", "127.0.0.1");
2111 options.set_read_batch_size(48);
2112 assert_eq!(options.read_batch_size(), 48);
2113 }
2114}