1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
// The two rustls provider features are mutually exclusive; enabling both is a
// configuration error, so fail the build early with an actionable message.
#[cfg(all(feature = "use-rustls-ring", feature = "use-rustls-aws-lc"))]
compile_error!(
    "Features `use-rustls-ring` and `use-rustls-aws-lc` are mutually exclusive. Enable only one rustls provider feature."
);
8
9#[macro_use]
10extern crate log;
11
12use bytes::Bytes;
13use std::fmt::{self, Debug, Formatter};
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19use tokio::net::{TcpStream, lookup_host};
20use tokio::task::JoinSet;
21
22#[cfg(all(feature = "url", unix))]
23use percent_encoding::percent_decode_str;
24
25#[cfg(all(feature = "url", unix))]
26use std::{ffi::OsString, os::unix::ffi::OsStringExt};
27
28#[cfg(feature = "websocket")]
29use std::{
30 future::{Future, IntoFuture},
31 pin::Pin,
32};
33
34mod client;
35mod eventloop;
36mod framed;
37pub mod mqttbytes;
38mod notice;
39mod state;
40mod transport;
41
42#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
43mod tls;
44
45#[cfg(feature = "websocket")]
46mod websockets;
47
48#[cfg(feature = "proxy")]
49mod proxy;
50
51pub use client::{
52 AsyncClient, Client, ClientError, Connection, InvalidTopic, Iter, ManualAck, PublishTopic,
53 RecvError, RecvTimeoutError, TryRecvError, ValidatedTopic,
54};
55pub use eventloop::{ConnectionError, Event, EventLoop};
56pub use mqttbytes::v5::*;
57pub use mqttbytes::*;
58pub use notice::{
59 NoticeFailureReason, PublishNotice, PublishNoticeError, RequestNotice, RequestNoticeError,
60};
61pub use rumqttc_core::NetworkOptions;
62#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
63pub use rumqttc_core::TlsConfiguration;
64pub use rumqttc_core::default_socket_connect;
65pub use state::{MqttState, StateError};
66pub use transport::Transport;
67
68#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
69pub use crate::tls::Error as TlsError;
70
71#[cfg(feature = "proxy")]
72pub use crate::proxy::{Proxy, ProxyAuth, ProxyType};
73
74#[cfg(feature = "use-native-tls")]
75pub use tokio_native_tls;
76#[cfg(feature = "use-rustls-no-provider")]
77pub use tokio_rustls;
78
/// Type used for packets received from the network; an alias of [`Packet`].
pub type Incoming = Packet;

/// Outgoing activity reported by the event loop.
///
/// NOTE(review): the `u16` payloads are presumably MQTT packet identifiers
/// (the packet types that carry one); confirm against the state machine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Outgoing {
    /// A PUBLISH was written.
    Publish(u16),
    /// A SUBSCRIBE was written.
    Subscribe(u16),
    /// An UNSUBSCRIBE was written.
    Unsubscribe(u16),
    /// A PUBACK was written.
    PubAck(u16),
    /// A PUBREC was written.
    PubRec(u16),
    /// A PUBREL was written.
    PubRel(u16),
    /// A PUBCOMP was written.
    PubComp(u16),
    /// A PINGREQ was written.
    PingReq,
    /// A PINGRESP was written.
    PingResp,
    /// A DISCONNECT was written.
    Disconnect,
    /// NOTE(review): presumably signals the client is waiting for the ack of
    /// the given packet id before it can proceed — confirm in `state`.
    AwaitAck(u16),
    /// An AUTH packet was written.
    Auth,
}
109
/// Crate-local name for the connector callback type defined in `rumqttc_core`.
pub(crate) type SocketConnector = rumqttc_core::SocketConnector;

/// Spacing between the start times of consecutive staggered connection
/// attempts (attempt `n` starts after `n * CONNECTION_ATTEMPT_DELAY`).
const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(100);
114
115async fn first_success_with_stagger<T, I, F, Fut>(
116 items: I,
117 attempt_delay: Duration,
118 connect_fn: F,
119) -> io::Result<T>
120where
121 T: Send + 'static,
122 I: IntoIterator,
123 I::Item: Send + 'static,
124 F: Fn(I::Item) -> Fut + Send + Sync + Clone + 'static,
125 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
126{
127 let mut join_set = JoinSet::new();
128 let mut item_count = 0usize;
129
130 for (index, item) in items.into_iter().enumerate() {
131 item_count += 1;
132 let delay = attempt_delay.saturating_mul(u32::try_from(index).unwrap_or(u32::MAX));
133 let connect_fn = connect_fn.clone();
134 join_set.spawn(async move {
135 tokio::time::sleep(delay).await;
136 connect_fn(item).await
137 });
138 }
139
140 if item_count == 0 {
141 return Err(io::Error::new(
142 io::ErrorKind::InvalidInput,
143 "could not resolve to any address",
144 ));
145 }
146
147 let mut last_err = None;
148
149 while let Some(task_result) = join_set.join_next().await {
150 match task_result {
151 Ok(Ok(stream)) => {
152 join_set.abort_all();
153 return Ok(stream);
154 }
155 Ok(Err(err)) => {
156 last_err = Some(err);
157 }
158 Err(err) => {
159 last_err = Some(io::Error::other(format!(
160 "concurrent connect task failed: {err}"
161 )));
162 }
163 }
164 }
165
166 Err(last_err.unwrap_or_else(|| {
167 io::Error::new(
168 io::ErrorKind::InvalidInput,
169 "could not resolve to any address",
170 )
171 }))
172}
173
/// Tries each item in order, one at a time, and returns the first success.
///
/// Later items are not attempted once one succeeds.
///
/// # Errors
///
/// `InvalidInput` ("could not resolve to any address") when `items` is
/// empty; otherwise the error from the final (failed) attempt.
async fn first_success_sequential<T, I, F, Fut>(items: I, connect_fn: F) -> io::Result<T>
where
    I: IntoIterator,
    F: Fn(I::Item) -> Fut,
    Fut: std::future::Future<Output = io::Result<T>>,
{
    let mut most_recent_failure: Option<io::Error> = None;

    for candidate in items {
        let attempt = connect_fn(candidate).await;
        match attempt {
            Ok(connected) => return Ok(connected),
            Err(failure) => most_recent_failure = Some(failure),
        }
    }

    // Any attempted item leaves a failure behind, so `None` can only mean the
    // iterator was empty.
    Err(most_recent_failure.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )
    }))
}
205
206fn should_stagger_connect_attempts(network_options: &NetworkOptions) -> bool {
207 network_options
208 .bind_addr()
209 .is_none_or(|bind_addr| bind_addr.port() == 0)
210}
211
212async fn connect_with_retry_mode<T, I, F, Fut>(
213 items: I,
214 network_options: NetworkOptions,
215 connect_fn: F,
216) -> io::Result<T>
217where
218 T: Send + 'static,
219 I: IntoIterator,
220 I::Item: Send + 'static,
221 F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
222 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
223{
224 connect_with_retry_mode_and_delay(items, network_options, CONNECTION_ATTEMPT_DELAY, connect_fn)
225 .await
226}
227
228async fn connect_with_retry_mode_and_delay<T, I, F, Fut>(
229 items: I,
230 network_options: NetworkOptions,
231 connection_attempt_delay: Duration,
232 connect_fn: F,
233) -> io::Result<T>
234where
235 T: Send + 'static,
236 I: IntoIterator,
237 I::Item: Send + 'static,
238 F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
239 Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
240{
241 if should_stagger_connect_attempts(&network_options) {
242 first_success_with_stagger(items, connection_attempt_delay, move |item| {
243 let network_options = network_options.clone();
244 let connect_fn = connect_fn.clone();
245 async move { connect_fn(item, network_options).await }
246 })
247 .await
248 } else {
249 first_success_sequential(items, move |item| {
250 let network_options = network_options.clone();
251 let connect_fn = connect_fn.clone();
252 async move { connect_fn(item, network_options).await }
253 })
254 .await
255 }
256}
257
258async fn connect_resolved_addrs_staggered(
259 addrs: Vec<SocketAddr>,
260 network_options: NetworkOptions,
261) -> io::Result<TcpStream> {
262 connect_with_retry_mode(
263 addrs,
264 network_options,
265 move |addr, network_options| async move {
266 rumqttc_core::connect_socket_addr(addr, network_options).await
267 },
268 )
269 .await
270}
271
272async fn default_socket_connect_staggered(
273 host: String,
274 network_options: NetworkOptions,
275) -> io::Result<TcpStream> {
276 let addrs = lookup_host(host).await?.collect::<Vec<_>>();
277 connect_resolved_addrs_staggered(addrs, network_options).await
278}
279
280fn default_socket_connector() -> SocketConnector {
281 Arc::new(|host, network_options| {
282 Box::pin(async move {
283 let tcp = default_socket_connect_staggered(host, network_options).await?;
284 Ok(Box::new(tcp) as Box<dyn crate::framed::AsyncReadWrite>)
285 })
286 })
287}
288
/// Port used when a broker is given as a bare host (1883, the standard
/// unencrypted MQTT port).
const DEFAULT_BROKER_PORT: u16 = 1883;
290
/// Address of the MQTT broker to connect to (TCP, Unix socket, or websocket).
///
/// The variant is kept private so that construction goes through the
/// validating constructors on `Broker`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Broker {
    inner: BrokerInner,
}

/// Concrete broker address kinds; some variants only exist on certain
/// platforms or with certain features.
#[derive(Clone, Debug, PartialEq, Eq)]
enum BrokerInner {
    /// Host name or IP plus port.
    Tcp {
        host: String,
        port: u16,
    },
    /// Filesystem path of a Unix domain socket.
    #[cfg(unix)]
    Unix {
        path: PathBuf,
    },
    /// Full `ws://` URL (validated by `Broker::websocket`).
    #[cfg(feature = "websocket")]
    Websocket {
        url: String,
    },
}
312
impl Broker {
    /// Creates a TCP broker address from `host` and `port`.
    #[must_use]
    pub fn tcp<S: Into<String>>(host: S, port: u16) -> Self {
        Self {
            inner: BrokerInner::Tcp {
                host: host.into(),
                port,
            },
        }
    }

    /// Creates a Unix domain socket broker address from a filesystem path.
    #[cfg(unix)]
    #[must_use]
    pub fn unix<P: Into<PathBuf>>(path: P) -> Self {
        Self {
            inner: BrokerInner::Unix { path: path.into() },
        }
    }

    /// Creates a plain websocket (`ws://`) broker address.
    ///
    /// # Errors
    ///
    /// - [`OptionError::WebsocketUrl`] if `url` does not parse as a URI or is
    ///   rejected by `rumqttc_core::split_url`.
    /// - [`OptionError::WssRequiresExplicitTransport`] for `wss://` URLs; use a
    ///   `ws://` broker together with an explicit TLS transport instead.
    /// - [`OptionError::Scheme`] for any other or missing scheme.
    ///
    /// NOTE(review): the scheme comparison is case-sensitive, so `WS://...`
    /// is reported as `Scheme` even though URI schemes are case-insensitive.
    #[cfg(feature = "websocket")]
    pub fn websocket<S: Into<String>>(url: S) -> Result<Self, OptionError> {
        let url = url.into();
        let uri = url
            .parse::<http::Uri>()
            .map_err(|_| OptionError::WebsocketUrl)?;

        match uri.scheme_str() {
            Some("ws") => {
                // Validation only: the split result is discarded and the
                // original URL string is what gets stored.
                rumqttc_core::split_url(&url).map_err(|_| OptionError::WebsocketUrl)?;
                Ok(Self {
                    inner: BrokerInner::Websocket { url },
                })
            }
            Some("wss") => Err(OptionError::WssRequiresExplicitTransport),
            _ => Err(OptionError::Scheme),
        }
    }

    /// Returns `(host, port)` for a TCP broker, `None` for other kinds.
    #[must_use]
    pub const fn tcp_address(&self) -> Option<(&str, u16)> {
        match &self.inner {
            BrokerInner::Tcp { host, port } => Some((host.as_str(), *port)),
            #[cfg(unix)]
            BrokerInner::Unix { .. } => None,
            #[cfg(feature = "websocket")]
            BrokerInner::Websocket { .. } => None,
        }
    }

    /// Returns the socket path for a Unix broker, `None` for other kinds.
    #[cfg(unix)]
    #[must_use]
    pub fn unix_path(&self) -> Option<&std::path::Path> {
        match &self.inner {
            BrokerInner::Unix { path } => Some(path.as_path()),
            BrokerInner::Tcp { .. } => None,
            #[cfg(feature = "websocket")]
            BrokerInner::Websocket { .. } => None,
        }
    }

    /// Returns the URL for a websocket broker, `None` for other kinds.
    #[cfg(feature = "websocket")]
    #[must_use]
    pub const fn websocket_url(&self) -> Option<&str> {
        match &self.inner {
            BrokerInner::Websocket { url } => Some(url.as_str()),
            BrokerInner::Tcp { .. } => None,
            #[cfg(unix)]
            BrokerInner::Unix { .. } => None,
        }
    }

    /// Transport implied by the address kind: TCP, Unix, or plain websocket.
    /// Used by `MqttOptions::new` as the initial transport.
    pub(crate) const fn default_transport(&self) -> Transport {
        match &self.inner {
            BrokerInner::Tcp { .. } => Transport::tcp(),
            #[cfg(unix)]
            BrokerInner::Unix { .. } => Transport::unix(),
            #[cfg(feature = "websocket")]
            BrokerInner::Websocket { .. } => Transport::Ws,
        }
    }
}
399
400impl From<&str> for Broker {
401 fn from(host: &str) -> Self {
402 Self::tcp(host, DEFAULT_BROKER_PORT)
403 }
404}
405
406impl From<String> for Broker {
407 fn from(host: String) -> Self {
408 Self::tcp(host, DEFAULT_BROKER_PORT)
409 }
410}
411
412impl<S: Into<String>> From<(S, u16)> for Broker {
413 fn from((host, port): (S, u16)) -> Self {
414 Self::tcp(host, port)
415 }
416}
417
/// Local limit on the size of packets accepted from the broker.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum IncomingPacketSizeLimit {
    /// Use the client's built-in default limit (10 KiB unless changed).
    #[default]
    Default,
    /// Do not enforce any local size limit.
    Unlimited,
    /// Enforce the given maximum size in bytes.
    Bytes(u32),
}
429
/// User-supplied handler for multi-step authentication using AUTH packets.
///
/// Installed with `MqttOptions::set_auth_manager` and stored as
/// `Arc<Mutex<dyn AuthManager>>`.
pub trait AuthManager: std::fmt::Debug + Send {
    /// Handles one authentication step and returns the properties to reply
    /// with (or `None` for no properties).
    ///
    /// NOTE(review): presumably called with the properties of an AUTH packet
    /// received from the broker, and an `Err(reason)` presumably aborts the
    /// exchange — confirm against the state machine.
    fn auth_continue(
        &mut self,
        auth_prop: Option<AuthProperties>,
    ) -> Result<Option<AuthProperties>, String>;
}
451
/// A packet the client asks the event loop to send.
///
/// NOTE(review): presumably produced by the client handles and consumed by the
/// event loop over the request channel — confirm in `client`/`eventloop`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Request {
    Publish(Publish),
    PubAck(PubAck),
    PubRec(PubRec),
    PubComp(PubComp),
    PubRel(PubRel),
    PingReq,
    PingResp,
    Subscribe(Subscribe),
    SubAck(SubAck),
    Unsubscribe(Unsubscribe),
    UnsubAck(UnsubAck),
    Auth(Auth),
    Disconnect(Disconnect),
}
470
471impl From<Subscribe> for Request {
472 fn from(subscribe: Subscribe) -> Self {
473 Self::Subscribe(subscribe)
474 }
475}
476
/// Async hook that rewrites the HTTP request used for websocket connections.
#[cfg(feature = "websocket")]
type RequestModifierFn = Arc<
    dyn Fn(http::Request<()>) -> Pin<Box<dyn Future<Output = http::Request<()>> + Send>>
        + Send
        + Sync,
>;

/// Type-erased error returned by a fallible request modifier.
#[cfg(feature = "websocket")]
type RequestModifierError = Box<dyn std::error::Error + Send + Sync>;

/// Like [`RequestModifierFn`], but the hook may fail with a
/// [`RequestModifierError`].
#[cfg(feature = "websocket")]
type FallibleRequestModifierFn = Arc<
    dyn Fn(
            http::Request<()>,
        )
            -> Pin<Box<dyn Future<Output = Result<http::Request<()>, RequestModifierError>> + Send>>
        + Send
        + Sync,
>;
496
/// Client configuration: broker address, transport, session and CONNECT
/// settings, and connection hooks.
///
/// Construct with `MqttOptions::new` and adjust with the `set_*` methods.
#[derive(Clone)]
pub struct MqttOptions {
    /// Broker address.
    broker: Broker,
    /// Transport used to reach `broker`; initialised from the broker kind.
    transport: Transport,
    /// Keep-alive interval (set in whole seconds by `set_keep_alive`).
    keep_alive: Duration,
    /// Clean start flag.
    clean_start: bool,
    /// Client identifier.
    client_id: String,
    /// Username and/or password credentials.
    auth: ConnectAuth,
    /// Capacity of the request channel.
    request_channel_capacity: usize,
    /// Maximum request batch size (default 0).
    max_request_batch: usize,
    /// Read batch size (default 0).
    read_batch_size: usize,
    /// NOTE(review): presumably the delay between consecutive pending
    /// (replayed) packets after a reconnect — confirm in the event loop.
    pending_throttle: Duration,
    /// Last will message, if any.
    last_will: Option<LastWill>,
    /// Connection timeout.
    connect_timeout: Duration,
    /// Byte limit applied when `incoming_packet_size_limit` is `Default`.
    default_max_incoming_size: u32,
    /// Local limit for incoming packet sizes.
    incoming_packet_size_limit: IncomingPacketSizeLimit,
    /// CONNECT properties; created on demand by the property setters.
    connect_properties: Option<ConnectProperties>,
    /// Manual acknowledgement mode flag.
    manual_acks: bool,
    /// Socket-level options passed to the connector.
    network_options: NetworkOptions,
    /// Optional proxy to tunnel the connection through.
    #[cfg(feature = "proxy")]
    proxy: Option<Proxy>,
    /// Optional cap on outgoing in-flight messages.
    outgoing_inflight_upper_limit: Option<u16>,
    /// Infallible websocket request hook (mutually exclusive with the fallible one).
    #[cfg(feature = "websocket")]
    request_modifier: Option<RequestModifierFn>,
    /// Fallible websocket request hook (mutually exclusive with the infallible one).
    #[cfg(feature = "websocket")]
    fallible_request_modifier: Option<FallibleRequestModifierFn>,
    /// Custom connector; the built-in TCP connector is used when `None`.
    socket_connector: Option<SocketConnector>,

    /// Handler for AUTH packet exchanges, if any.
    auth_manager: Option<Arc<Mutex<dyn AuthManager>>>,
}
551
552impl MqttOptions {
553 pub fn new<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> Self {
562 let broker = broker.into();
563 Self {
564 transport: broker.default_transport(),
565 broker,
566 keep_alive: Duration::from_secs(60),
567 clean_start: true,
568 client_id: id.into(),
569 auth: ConnectAuth::None,
570 request_channel_capacity: 10,
571 max_request_batch: 0,
572 read_batch_size: 0,
573 pending_throttle: Duration::from_micros(0),
574 last_will: None,
575 connect_timeout: Duration::from_secs(5),
576 default_max_incoming_size: 10 * 1024,
577 incoming_packet_size_limit: IncomingPacketSizeLimit::Default,
578 connect_properties: None,
579 manual_acks: false,
580 network_options: NetworkOptions::new(),
581 #[cfg(feature = "proxy")]
582 proxy: None,
583 outgoing_inflight_upper_limit: None,
584 #[cfg(feature = "websocket")]
585 request_modifier: None,
586 #[cfg(feature = "websocket")]
587 fallible_request_modifier: None,
588 socket_connector: None,
589 auth_manager: None,
590 }
591 }
592
593 #[cfg(feature = "url")]
594 pub fn parse_url<S: Into<String>>(url: S) -> Result<Self, OptionError> {
628 let url = url::Url::parse(&url.into())?;
629 let options = Self::try_from(url)?;
630
631 Ok(options)
632 }
633
634 pub const fn broker(&self) -> &Broker {
636 &self.broker
637 }
638
639 pub fn set_last_will(&mut self, will: LastWill) -> &mut Self {
640 self.last_will = Some(will);
641 self
642 }
643
644 pub fn last_will(&self) -> Option<LastWill> {
645 self.last_will.clone()
646 }
647
648 #[cfg(feature = "websocket")]
652 pub fn set_request_modifier<F, O>(&mut self, request_modifier: F) -> &mut Self
653 where
654 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
655 O: IntoFuture<Output = http::Request<()>> + 'static,
656 O::IntoFuture: Send,
657 {
658 self.request_modifier = Some(Arc::new(move |request| {
659 let request_modifier = request_modifier(request).into_future();
660 Box::pin(request_modifier)
661 }));
662 self.fallible_request_modifier = None;
663 self
664 }
665
666 #[cfg(feature = "websocket")]
672 pub fn set_fallible_request_modifier<F, O, E>(&mut self, request_modifier: F) -> &mut Self
673 where
674 F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
675 O: IntoFuture<Output = Result<http::Request<()>, E>> + 'static,
676 O::IntoFuture: Send,
677 E: std::error::Error + Send + Sync + 'static,
678 {
679 self.fallible_request_modifier = Some(Arc::new(move |request| {
680 let request_modifier = request_modifier(request).into_future();
681 Box::pin(async move {
682 request_modifier
683 .await
684 .map_err(|error| Box::new(error) as RequestModifierError)
685 })
686 }));
687 self.request_modifier = None;
688 self
689 }
690
691 #[cfg(feature = "websocket")]
692 pub fn request_modifier(&self) -> Option<RequestModifierFn> {
693 self.request_modifier.clone()
694 }
695
696 #[cfg(feature = "websocket")]
697 pub(crate) fn fallible_request_modifier(&self) -> Option<FallibleRequestModifierFn> {
698 self.fallible_request_modifier.clone()
699 }
700
701 #[cfg(not(feature = "websocket"))]
718 pub fn set_socket_connector<F, Fut, S>(&mut self, f: F) -> &mut Self
719 where
720 F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
721 Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
722 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static,
723 {
724 self.socket_connector = Some(Arc::new(move |host, network_options| {
725 let stream_future = f(host, network_options);
726 let future = async move {
727 let stream = stream_future.await?;
728 Ok(Box::new(stream) as Box<dyn crate::framed::AsyncReadWrite>)
729 };
730 Box::pin(future)
731 }));
732 self
733 }
734
735 #[cfg(feature = "websocket")]
752 pub fn set_socket_connector<F, Fut, S>(&mut self, f: F) -> &mut Self
753 where
754 F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
755 Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
756 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static,
757 {
758 self.socket_connector = Some(Arc::new(move |host, network_options| {
759 let stream_future = f(host, network_options);
760 let future = async move {
761 let stream = stream_future.await?;
762 Ok(Box::new(stream) as Box<dyn crate::framed::AsyncReadWrite>)
763 };
764 Box::pin(future)
765 }));
766 self
767 }
768
769 pub fn has_socket_connector(&self) -> bool {
771 self.socket_connector.is_some()
772 }
773
774 pub fn set_client_id(&mut self, client_id: String) -> &mut Self {
775 self.client_id = client_id;
776 self
777 }
778
779 #[cfg(not(any(feature = "use-rustls-no-provider", feature = "use-native-tls")))]
780 pub const fn set_transport(&mut self, transport: Transport) -> &mut Self {
781 self.transport = transport;
782 self
783 }
784
785 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
786 pub fn set_transport(&mut self, transport: Transport) -> &mut Self {
787 self.transport = transport;
788 self
789 }
790
791 pub fn transport(&self) -> Transport {
793 self.transport.clone()
794 }
795
796 pub fn set_keep_alive(&mut self, seconds: u16) -> &mut Self {
800 self.keep_alive = Duration::from_secs(u64::from(seconds));
801 self
802 }
803
804 pub const fn keep_alive(&self) -> Duration {
806 self.keep_alive
807 }
808
809 pub fn client_id(&self) -> String {
811 self.client_id.clone()
812 }
813
814 pub const fn set_clean_start(&mut self, clean_start: bool) -> &mut Self {
821 self.clean_start = clean_start;
822 self
823 }
824
825 pub const fn clean_start(&self) -> bool {
827 self.clean_start
828 }
829
830 pub fn set_auth(&mut self, auth: ConnectAuth) -> &mut Self {
843 self.auth = auth;
844 self
845 }
846
847 pub fn clear_auth(&mut self) -> &mut Self {
849 self.auth = ConnectAuth::None;
850 self
851 }
852
853 pub fn set_username<U: Into<String>>(&mut self, username: U) -> &mut Self {
869 self.auth = ConnectAuth::Username {
870 username: username.into(),
871 };
872 self
873 }
874
875 pub fn set_password<P: Into<Bytes>>(&mut self, password: P) -> &mut Self {
892 self.auth = ConnectAuth::Password {
893 password: password.into(),
894 };
895 self
896 }
897
898 pub fn set_credentials<U: Into<String>, P: Into<Bytes>>(
916 &mut self,
917 username: U,
918 password: P,
919 ) -> &mut Self {
920 self.auth = ConnectAuth::UsernamePassword {
921 username: username.into(),
922 password: password.into(),
923 };
924 self
925 }
926
927 pub const fn auth(&self) -> &ConnectAuth {
941 &self.auth
942 }
943
944 pub const fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self {
946 self.request_channel_capacity = capacity;
947 self
948 }
949
950 pub const fn request_channel_capacity(&self) -> usize {
952 self.request_channel_capacity
953 }
954
955 pub const fn set_max_request_batch(&mut self, max: usize) -> &mut Self {
959 self.max_request_batch = max;
960 self
961 }
962
963 pub const fn max_request_batch(&self) -> usize {
965 self.max_request_batch
966 }
967
968 pub const fn set_read_batch_size(&mut self, size: usize) -> &mut Self {
972 self.read_batch_size = size;
973 self
974 }
975
976 pub const fn read_batch_size(&self) -> usize {
980 self.read_batch_size
981 }
982
983 pub const fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self {
985 self.pending_throttle = duration;
986 self
987 }
988
989 pub const fn pending_throttle(&self) -> Duration {
991 self.pending_throttle
992 }
993
994 pub const fn set_connect_timeout(&mut self, timeout: Duration) -> &mut Self {
996 self.connect_timeout = timeout;
997 self
998 }
999
1000 pub const fn connect_timeout(&self) -> Duration {
1002 self.connect_timeout
1003 }
1004
1005 pub fn set_connect_properties(&mut self, properties: ConnectProperties) -> &mut Self {
1007 self.incoming_packet_size_limit = properties.max_packet_size.map_or(
1008 IncomingPacketSizeLimit::Default,
1009 IncomingPacketSizeLimit::Bytes,
1010 );
1011 self.connect_properties = Some(properties);
1012 self
1013 }
1014
1015 pub fn connect_properties(&self) -> Option<ConnectProperties> {
1017 self.connect_properties.clone()
1018 }
1019
1020 pub fn set_session_expiry_interval(&mut self, interval: Option<u32>) -> &mut Self {
1022 if let Some(conn_props) = &mut self.connect_properties {
1023 conn_props.session_expiry_interval = interval;
1024 self
1025 } else {
1026 let mut conn_props = ConnectProperties::new();
1027 conn_props.session_expiry_interval = interval;
1028 self.set_connect_properties(conn_props)
1029 }
1030 }
1031
1032 pub const fn session_expiry_interval(&self) -> Option<u32> {
1034 if let Some(conn_props) = &self.connect_properties {
1035 conn_props.session_expiry_interval
1036 } else {
1037 None
1038 }
1039 }
1040
1041 pub fn set_receive_maximum(&mut self, recv_max: Option<u16>) -> &mut Self {
1043 if let Some(conn_props) = &mut self.connect_properties {
1044 conn_props.receive_maximum = recv_max;
1045 self
1046 } else {
1047 let mut conn_props = ConnectProperties::new();
1048 conn_props.receive_maximum = recv_max;
1049 self.set_connect_properties(conn_props)
1050 }
1051 }
1052
1053 pub const fn receive_maximum(&self) -> Option<u16> {
1055 if let Some(conn_props) = &self.connect_properties {
1056 conn_props.receive_maximum
1057 } else {
1058 None
1059 }
1060 }
1061
1062 pub fn set_max_packet_size(&mut self, max_size: Option<u32>) -> &mut Self {
1064 self.incoming_packet_size_limit = max_size.map_or(
1065 IncomingPacketSizeLimit::Default,
1066 IncomingPacketSizeLimit::Bytes,
1067 );
1068
1069 if let Some(conn_props) = &mut self.connect_properties {
1070 conn_props.max_packet_size = max_size;
1071 self
1072 } else {
1073 let mut conn_props = ConnectProperties::new();
1074 conn_props.max_packet_size = max_size;
1075 self.set_connect_properties(conn_props)
1076 }
1077 }
1078
1079 pub const fn max_packet_size(&self) -> Option<u32> {
1081 if let Some(conn_props) = &self.connect_properties {
1082 conn_props.max_packet_size
1083 } else {
1084 None
1085 }
1086 }
1087
1088 pub fn set_incoming_packet_size_limit(&mut self, limit: IncomingPacketSizeLimit) -> &mut Self {
1095 self.incoming_packet_size_limit = limit;
1096
1097 if let Some(conn_props) = &mut self.connect_properties {
1098 conn_props.max_packet_size = match limit {
1099 IncomingPacketSizeLimit::Bytes(max_size) => Some(max_size),
1100 IncomingPacketSizeLimit::Default | IncomingPacketSizeLimit::Unlimited => None,
1101 };
1102 return self;
1103 }
1104
1105 if let IncomingPacketSizeLimit::Bytes(max_size) = limit {
1106 let mut conn_props = ConnectProperties::new();
1107 conn_props.max_packet_size = Some(max_size);
1108 self.set_connect_properties(conn_props)
1109 } else {
1110 self
1111 }
1112 }
1113
1114 pub fn set_unlimited_incoming_packet_size(&mut self) -> &mut Self {
1116 self.set_incoming_packet_size_limit(IncomingPacketSizeLimit::Unlimited)
1117 }
1118
1119 pub const fn incoming_packet_size_limit(&self) -> IncomingPacketSizeLimit {
1121 self.incoming_packet_size_limit
1122 }
1123
1124 pub(crate) const fn max_incoming_packet_size(&self) -> Option<u32> {
1125 match self.incoming_packet_size_limit {
1126 IncomingPacketSizeLimit::Default => Some(self.default_max_incoming_size),
1127 IncomingPacketSizeLimit::Unlimited => None,
1128 IncomingPacketSizeLimit::Bytes(max_size) => Some(max_size),
1129 }
1130 }
1131
1132 pub fn set_topic_alias_max(&mut self, topic_alias_max: Option<u16>) -> &mut Self {
1134 if let Some(conn_props) = &mut self.connect_properties {
1135 conn_props.topic_alias_max = topic_alias_max;
1136 self
1137 } else {
1138 let mut conn_props = ConnectProperties::new();
1139 conn_props.topic_alias_max = topic_alias_max;
1140 self.set_connect_properties(conn_props)
1141 }
1142 }
1143
1144 pub const fn topic_alias_max(&self) -> Option<u16> {
1146 if let Some(conn_props) = &self.connect_properties {
1147 conn_props.topic_alias_max
1148 } else {
1149 None
1150 }
1151 }
1152
1153 pub fn set_request_response_info(&mut self, request_response_info: Option<u8>) -> &mut Self {
1155 if let Some(conn_props) = &mut self.connect_properties {
1156 conn_props.request_response_info = request_response_info;
1157 self
1158 } else {
1159 let mut conn_props = ConnectProperties::new();
1160 conn_props.request_response_info = request_response_info;
1161 self.set_connect_properties(conn_props)
1162 }
1163 }
1164
1165 pub const fn request_response_info(&self) -> Option<u8> {
1167 if let Some(conn_props) = &self.connect_properties {
1168 conn_props.request_response_info
1169 } else {
1170 None
1171 }
1172 }
1173
1174 pub fn set_request_problem_info(&mut self, request_problem_info: Option<u8>) -> &mut Self {
1176 if let Some(conn_props) = &mut self.connect_properties {
1177 conn_props.request_problem_info = request_problem_info;
1178 self
1179 } else {
1180 let mut conn_props = ConnectProperties::new();
1181 conn_props.request_problem_info = request_problem_info;
1182 self.set_connect_properties(conn_props)
1183 }
1184 }
1185
1186 pub const fn request_problem_info(&self) -> Option<u8> {
1188 if let Some(conn_props) = &self.connect_properties {
1189 conn_props.request_problem_info
1190 } else {
1191 None
1192 }
1193 }
1194
1195 pub fn set_user_properties(&mut self, user_properties: Vec<(String, String)>) -> &mut Self {
1197 if let Some(conn_props) = &mut self.connect_properties {
1198 conn_props.user_properties = user_properties;
1199 self
1200 } else {
1201 let mut conn_props = ConnectProperties::new();
1202 conn_props.user_properties = user_properties;
1203 self.set_connect_properties(conn_props)
1204 }
1205 }
1206
1207 pub fn user_properties(&self) -> Vec<(String, String)> {
1209 self.connect_properties
1210 .as_ref()
1211 .map_or_else(Vec::new, |conn_props| conn_props.user_properties.clone())
1212 }
1213
1214 pub fn set_authentication_method(
1216 &mut self,
1217 authentication_method: Option<String>,
1218 ) -> &mut Self {
1219 if let Some(conn_props) = &mut self.connect_properties {
1220 conn_props.authentication_method = authentication_method;
1221 self
1222 } else {
1223 let mut conn_props = ConnectProperties::new();
1224 conn_props.authentication_method = authentication_method;
1225 self.set_connect_properties(conn_props)
1226 }
1227 }
1228
1229 pub fn authentication_method(&self) -> Option<String> {
1231 self.connect_properties
1232 .as_ref()
1233 .and_then(|conn_props| conn_props.authentication_method.clone())
1234 }
1235
1236 pub fn set_authentication_data(&mut self, authentication_data: Option<Bytes>) -> &mut Self {
1238 if let Some(conn_props) = &mut self.connect_properties {
1239 conn_props.authentication_data = authentication_data;
1240 self
1241 } else {
1242 let mut conn_props = ConnectProperties::new();
1243 conn_props.authentication_data = authentication_data;
1244 self.set_connect_properties(conn_props)
1245 }
1246 }
1247
1248 pub fn authentication_data(&self) -> Option<Bytes> {
1250 self.connect_properties
1251 .as_ref()
1252 .map_or_else(|| None, |conn_props| conn_props.authentication_data.clone())
1253 }
1254
1255 pub const fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self {
1257 self.manual_acks = manual_acks;
1258 self
1259 }
1260
1261 pub const fn manual_acks(&self) -> bool {
1263 self.manual_acks
1264 }
1265
1266 pub fn network_options(&self) -> NetworkOptions {
1267 self.network_options.clone()
1268 }
1269
1270 pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
1271 self.network_options = network_options;
1272 self
1273 }
1274
1275 #[cfg(feature = "proxy")]
1276 pub fn set_proxy(&mut self, proxy: Proxy) -> &mut Self {
1277 self.proxy = Some(proxy);
1278 self
1279 }
1280
1281 #[cfg(feature = "proxy")]
1282 pub fn proxy(&self) -> Option<Proxy> {
1283 self.proxy.clone()
1284 }
1285
1286 pub(crate) fn effective_socket_connector(&self) -> SocketConnector {
1287 self.socket_connector
1288 .clone()
1289 .unwrap_or_else(default_socket_connector)
1290 }
1291
1292 pub(crate) async fn socket_connect(
1293 &self,
1294 host: String,
1295 network_options: NetworkOptions,
1296 ) -> std::io::Result<Box<dyn crate::framed::AsyncReadWrite>> {
1297 let connector = self.effective_socket_connector();
1298 connector(host, network_options).await
1299 }
1300
1301 pub const fn set_outgoing_inflight_upper_limit(&mut self, limit: u16) -> &mut Self {
1304 self.outgoing_inflight_upper_limit = Some(limit);
1305 self
1306 }
1307
1308 pub const fn get_outgoing_inflight_upper_limit(&self) -> Option<u16> {
1311 self.outgoing_inflight_upper_limit
1312 }
1313
1314 pub fn set_auth_manager(&mut self, auth_manager: Arc<Mutex<dyn AuthManager>>) -> &mut Self {
1315 self.auth_manager = Some(auth_manager);
1316 self
1317 }
1318
1319 pub fn auth_manager(&self) -> Option<Arc<Mutex<dyn AuthManager>>> {
1320 self.auth_manager.as_ref()?;
1321
1322 self.auth_manager.clone()
1323 }
1324}
1325
/// Errors produced while building [`MqttOptions`] or a [`Broker`].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum OptionError {
    /// The URL/broker scheme is not supported.
    #[error("Unsupported URL scheme.")]
    Scheme,

    /// `mqtts://` / `ssl://` URLs were given; TLS must be configured
    /// explicitly with `MqttOptions::set_transport`.
    #[error(
        "Secure MQTT URL schemes require explicit TLS transport configuration via MqttOptions::set_transport(...)."
    )]
    SecureUrlRequiresExplicitTransport,

    /// The URL has no `client_id` query parameter.
    #[error("Missing client ID.")]
    ClientId,

    /// NOTE(review): presumably raised when a `unix://` URL cannot be turned
    /// into a socket path — confirm in `parse_unix_socket_path`.
    #[error("Invalid Unix socket path.")]
    UnixSocketPath,

    /// The websocket URL could not be parsed or split.
    #[cfg(feature = "websocket")]
    #[error("Invalid websocket url.")]
    WebsocketUrl,

    /// A `wss://` URL was given; use a `ws://` broker plus an explicit TLS
    /// websocket transport.
    #[cfg(feature = "websocket")]
    #[error(
        "Secure websocket URLs require Broker::websocket(\"ws://...\") plus MqttOptions::set_transport(Transport::wss_with_config(...))."
    )]
    WssRequiresExplicitTransport,

    /// `keep_alive_secs` is not a valid `u16`.
    #[error("Invalid keep-alive value.")]
    KeepAlive,

    /// `clean_start` is not a valid `bool`.
    #[error("Invalid clean-start value.")]
    CleanStart,

    /// `max_incoming_packet_size_bytes` is not a valid `u32`.
    #[error("Invalid max-incoming-packet-size value.")]
    MaxIncomingPacketSize,

    /// An invalid maximum outgoing packet size value (not produced by the
    /// visible parsing code).
    #[error("Invalid max-outgoing-packet-size value.")]
    MaxOutgoingPacketSize,

    /// `request_channel_capacity_num` is not a valid `usize`.
    #[error("Invalid request-channel-capacity value.")]
    RequestChannelCapacity,

    /// `max_request_batch_num` is not a valid `usize`.
    #[error("Invalid max-request-batch value.")]
    MaxRequestBatch,

    /// `read_batch_size_num` is not a valid `usize`.
    #[error("Invalid read-batch-size value.")]
    ReadBatchSize,

    /// `pending_throttle_usecs` is not a valid `u64`.
    #[error("Invalid pending-throttle value.")]
    PendingThrottle,

    /// An invalid in-flight value (not produced by the visible parsing code).
    #[error("Invalid inflight value.")]
    Inflight,

    /// An invalid connect-timeout value (not produced by the visible parsing
    /// code).
    #[error("Invalid conn-timeout value.")]
    ConnTimeout,

    /// An unrecognised option; carries its name. NOTE(review): presumably an
    /// unknown URL query key — confirm in the URL conversion.
    #[error("Unknown option: {0}")]
    Unknown(String),

    /// The input string is not a valid URL.
    #[cfg(feature = "url")]
    #[error("Couldn't parse option from url: {0}")]
    Parse(#[from] url::ParseError),
}
1389
#[cfg(feature = "url")]
impl std::convert::TryFrom<url::Url> for MqttOptions {
    type Error = OptionError;

    /// Builds [`MqttOptions`] from a broker URL such as
    /// `mqtt://user:pw@host:1883?client_id=foo&keep_alive_secs=30`.
    ///
    /// Accepted schemes are `mqtt`/`tcp`, `unix` (Unix targets only) and `ws` (with the
    /// `websocket` feature). Secure schemes (`mqtts`, `ssl`, and `wss` with the
    /// `websocket` feature) are rejected: TLS has to be configured explicitly through
    /// `MqttOptions::set_transport`.
    ///
    /// The `client_id` query parameter is required. The optional query parameters are
    /// `keep_alive_secs`, `clean_start`, `max_incoming_packet_size_bytes`,
    /// `request_channel_capacity_num`, `max_request_batch_num`, `read_batch_size_num`,
    /// `pending_throttle_usecs`, `inflight_num` and `conn_timeout_secs`. The URL
    /// userinfo becomes the username (and password, when one is present).
    ///
    /// # Errors
    ///
    /// Returns the [`OptionError`] variant matching the first problem found: an
    /// unsupported or secure scheme, an invalid unix socket path or websocket URL, a
    /// missing `client_id`, an unparsable option value, a zero
    /// `max_incoming_packet_size_bytes` or `inflight_num`, or a query parameter that
    /// is not recognised.
    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        use std::collections::HashMap;

        let broker = match url.scheme() {
            "mqtts" | "ssl" => return Err(OptionError::SecureUrlRequiresExplicitTransport),
            "mqtt" | "tcp" => Broker::tcp(
                url.host_str().unwrap_or_default(),
                url.port().unwrap_or(DEFAULT_BROKER_PORT),
            ),
            #[cfg(unix)]
            "unix" => Broker::unix(parse_unix_socket_path(&url)?),
            // NOTE(review): the full URL, including the option query string and any
            // userinfo, is handed to the websocket broker — confirm that is intended.
            #[cfg(feature = "websocket")]
            "ws" => Broker::websocket(url.as_str().to_owned())?,
            #[cfg(feature = "websocket")]
            "wss" => return Err(OptionError::WssRequiresExplicitTransport),
            _ => return Err(OptionError::Scheme),
        };

        // Every recognised option is removed as it is consumed; anything still present
        // at the end is reported as unknown.
        let mut queries = url.query_pairs().collect::<HashMap<_, _>>();

        let id = queries
            .remove("client_id")
            .ok_or(OptionError::ClientId)?
            .into_owned();

        let mut options = Self::new(id, broker);
        let mut connect_props = ConnectProperties::new();

        if let Some(keep_alive) = queries
            .remove("keep_alive_secs")
            .map(|v| v.parse::<u16>().map_err(|_| OptionError::KeepAlive))
            .transpose()?
        {
            options.set_keep_alive(keep_alive);
        }

        if let Some(clean_start) = queries
            .remove("clean_start")
            .map(|v| v.parse::<bool>().map_err(|_| OptionError::CleanStart))
            .transpose()?
        {
            options.set_clean_start(clean_start);
        }

        // NOTE(review): `url::Url::username`/`password` return percent-encoded text and
        // it is passed through here undecoded — confirm whether credentials containing
        // reserved characters should be percent-decoded first.
        let username = url.username();
        if let Some(password) = url.password() {
            options.set_credentials(username, password.to_owned());
        } else if !username.is_empty() {
            options.set_username(username);
        }

        // MQTT 5 makes a zero Maximum Packet Size in the CONNECT properties a protocol
        // error, so refuse it at parse time.
        connect_props.max_packet_size = queries
            .remove("max_incoming_packet_size_bytes")
            .map(|v| {
                v.parse::<u32>()
                    .ok()
                    .filter(|&size| size != 0)
                    .ok_or(OptionError::MaxIncomingPacketSize)
            })
            .transpose()?;

        if let Some(request_channel_capacity) = queries
            .remove("request_channel_capacity_num")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::RequestChannelCapacity)
            })
            .transpose()?
        {
            options.request_channel_capacity = request_channel_capacity;
        }

        if let Some(max_request_batch) = queries
            .remove("max_request_batch_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::MaxRequestBatch))
            .transpose()?
        {
            options.max_request_batch = max_request_batch;
        }

        if let Some(read_batch_size) = queries
            .remove("read_batch_size_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::ReadBatchSize))
            .transpose()?
        {
            options.read_batch_size = read_batch_size;
        }

        if let Some(pending_throttle) = queries
            .remove("pending_throttle_usecs")
            .map(|v| v.parse::<u64>().map_err(|_| OptionError::PendingThrottle))
            .transpose()?
        {
            options.set_pending_throttle(Duration::from_micros(pending_throttle));
        }

        // MQTT 5 likewise makes a zero Receive Maximum a protocol error.
        connect_props.receive_maximum = queries
            .remove("inflight_num")
            .map(|v| {
                v.parse::<u16>()
                    .ok()
                    .filter(|&inflight| inflight != 0)
                    .ok_or(OptionError::Inflight)
            })
            .transpose()?;

        if let Some(conn_timeout) = queries
            .remove("conn_timeout_secs")
            .map(|v| v.parse::<u64>().map_err(|_| OptionError::ConnTimeout))
            .transpose()?
        {
            options.set_connect_timeout(Duration::from_secs(conn_timeout));
        }

        // `queries` is a HashMap with unspecified iteration order; look the leftovers up
        // in URL order so the reported option is deterministic.
        if let Some((opt, _)) = url
            .query_pairs()
            .find(|(key, _)| queries.contains_key(&**key))
        {
            return Err(OptionError::Unknown(opt.into_owned()));
        }

        options.set_connect_properties(connect_props);
        Ok(options)
    }
}
1509
/// Turns the path component of a `unix:` URL into a socket [`PathBuf`].
///
/// The URL must not name a host: `unix:///tmp/mqtt.sock` is accepted,
/// `unix://localhost/tmp/mqtt.sock` is not. The path is percent-decoded into raw
/// bytes without UTF-8 validation, so escapes such as `%20` or `%FF` are kept
/// byte-for-byte in the resulting path.
///
/// # Errors
///
/// Returns [`OptionError::UnixSocketPath`] when a host is present, or when the
/// decoded path is empty or exactly `/`.
#[cfg(all(feature = "url", unix))]
fn parse_unix_socket_path(url: &url::Url) -> Result<PathBuf, OptionError> {
    // The socket location lives entirely in the path; any host component is refused.
    let decoded: Vec<u8> = match url.host_str() {
        None => percent_decode_str(url.path()).collect(),
        Some(_) => return Err(OptionError::UnixSocketPath),
    };

    // An empty path or the bare root cannot name a socket file.
    if matches!(decoded.as_slice(), [] | [b'/']) {
        return Err(OptionError::UnixSocketPath);
    }

    Ok(PathBuf::from(OsString::from_vec(decoded)))
}
1523
1524impl Debug for MqttOptions {
1527 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
1528 f.debug_struct("MqttOptions")
1529 .field("broker", &self.broker)
1530 .field("keep_alive", &self.keep_alive)
1531 .field("clean_start", &self.clean_start)
1532 .field("client_id", &self.client_id)
1533 .field("auth", &self.auth)
1534 .field("request_channel_capacity", &self.request_channel_capacity)
1535 .field("max_request_batch", &self.max_request_batch)
1536 .field("read_batch_size", &self.read_batch_size)
1537 .field("pending_throttle", &self.pending_throttle)
1538 .field("last_will", &self.last_will)
1539 .field("connect_timeout", &self.connect_timeout)
1540 .field("manual_acks", &self.manual_acks)
1541 .field("connect properties", &self.connect_properties)
1542 .finish_non_exhaustive()
1543 }
1544}
1545
// Unit tests for connection dialing (staggered / fixed-bind-port candidate racing),
// custom socket connectors, websocket request modifiers, packet-size limits, and
// `MqttOptions::parse_url` handling.
#[cfg(test)]
mod test {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::net::{TcpListener, TcpSocket};
    use tokio::runtime::Builder;
    use tokio::sync::Notify;

    /// Builds a single-threaded Tokio runtime with I/O and timers enabled; each async
    /// test drives its body with `runtime().block_on(...)`.
    fn runtime() -> tokio::runtime::Runtime {
        Builder::new_current_thread().enable_all().build().unwrap()
    }

    /// The first attempt is slow (sleeps 200 ms) and then fails. With a 10 ms stagger
    /// the second attempt must start early and its success must win, well before the
    /// first attempt would have finished.
    #[test]
    fn staggered_attempts_allow_later_success_to_win() {
        runtime().block_on(async {
            let started = Arc::new(AtomicUsize::new(0));
            let started_for_connect = Arc::clone(&started);
            let begin = std::time::Instant::now();

            let result = first_success_with_stagger(
                [0_u8, 1_u8],
                std::time::Duration::from_millis(10),
                move |attempt| {
                    let started = Arc::clone(&started_for_connect);
                    async move {
                        // Counts how many attempts were actually launched.
                        started.fetch_add(1, Ordering::SeqCst);
                        if attempt == 0 {
                            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                            Err(std::io::Error::other("slow failure"))
                        } else {
                            Ok(42_u8)
                        }
                    }
                },
            )
            .await
            .unwrap();

            assert_eq!(result, 42);
            assert_eq!(started.load(Ordering::SeqCst), 2);
            // Well under the 200 ms the first attempt takes: the second one won.
            assert!(begin.elapsed() < std::time::Duration::from_millis(150));
        });
    }

    /// An empty candidate list is reported as `InvalidInput` with a fixed message.
    #[test]
    fn staggered_connect_returns_invalid_input_for_empty_candidates() {
        runtime().block_on(async {
            let err = connect_resolved_addrs_staggered(Vec::new(), NetworkOptions::new())
                .await
                .unwrap_err();

            assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
            assert_eq!(err.to_string(), "could not resolve to any address");
        });
    }

    /// A failing first address must not stop the dialer from reaching a later,
    /// reachable address.
    #[test]
    fn staggered_connect_tries_later_candidates() {
        runtime().block_on(async {
            let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let good_addr = listener.local_addr().unwrap();

            // Reserve a port and immediately release it, so nothing is listening there
            // and connecting is expected to fail.
            let unused_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let bad_addr = unused_listener.local_addr().unwrap();
            drop(unused_listener);

            let accept_task = tokio::spawn(async move {
                let (_stream, _) = listener.accept().await.unwrap();
            });

            let stream =
                connect_resolved_addrs_staggered(vec![bad_addr, good_addr], NetworkOptions::new())
                    .await
                    .unwrap();
            assert_eq!(stream.peer_addr().unwrap(), good_addr);

            accept_task.await.unwrap();
        });
    }

    /// With a fixed local bind port, the first candidate binds that port and never
    /// completes. The dialer must keep waiting on it rather than capping it at the
    /// stagger delay, and must not launch the second candidate in the meantime.
    #[test]
    fn fixed_bind_port_retry_mode_keeps_slow_first_candidate_alive() {
        runtime().block_on(async {
            // Pick a currently free local port to use as the fixed bind port.
            let reserved = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let bind_port = reserved.local_addr().unwrap().port();
            drop(reserved);

            let mut network_options = NetworkOptions::new();
            network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(
                Ipv4Addr::LOCALHOST,
                bind_port,
            )));

            let first_attempt_started = Arc::new(Notify::new());
            let second_attempt_started = Arc::new(AtomicBool::new(false));

            let mut connect_task = tokio::spawn({
                let first_attempt_started = Arc::clone(&first_attempt_started);
                let second_attempt_started = Arc::clone(&second_attempt_started);
                let network_options = network_options.clone();
                async move {
                    connect_with_retry_mode_and_delay(
                        [0_u8, 1_u8],
                        network_options,
                        Duration::from_millis(10),
                        move |attempt, network_options| {
                            let first_attempt_started = Arc::clone(&first_attempt_started);
                            let second_attempt_started = Arc::clone(&second_attempt_started);
                            async move {
                                if attempt == 0 {
                                    // Hold the fixed bind port, signal the test, then
                                    // never finish.
                                    let bind_addr = network_options.bind_addr().unwrap();
                                    let socket = match bind_addr {
                                        SocketAddr::V4(_) => TcpSocket::new_v4()?,
                                        SocketAddr::V6(_) => TcpSocket::new_v6()?,
                                    };
                                    socket.bind(bind_addr)?;
                                    first_attempt_started.notify_one();
                                    std::future::pending::<io::Result<()>>().await
                                } else {
                                    // Reaching this branch means a later candidate was
                                    // started, which the test asserts does not happen.
                                    second_attempt_started.store(true, Ordering::SeqCst);
                                    let _ = network_options;
                                    Ok(())
                                }
                            }
                        },
                    )
                    .await
                }
            });

            first_attempt_started.notified().await;

            // A timeout (Err) means the dial is still pending on the first candidate.
            assert!(
                tokio::time::timeout(Duration::from_millis(50), &mut connect_task)
                    .await
                    .is_err(),
                "fixed-port dialing should keep the first slow candidate alive instead of capping it to the stagger delay"
            );
            assert!(
                !second_attempt_started.load(Ordering::SeqCst),
                "fixed-port dialing should not start later same-family candidates while the first is still pending"
            );
            connect_task.abort();
        });
    }

    /// With a fixed bind port, a failing first address must still fall through to a
    /// later reachable one. The peer address seen by the server proves the configured
    /// local port was actually used.
    #[test]
    fn fixed_bind_port_resolved_addrs_try_later_candidates() {
        runtime().block_on(async {
            let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let good_addr = listener.local_addr().unwrap();

            // Released port: nothing listens here, so this candidate should fail.
            let unused_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let bad_addr = unused_listener.local_addr().unwrap();
            drop(unused_listener);

            // Free local port used as the client's fixed bind port.
            let reserved = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let bind_port = reserved.local_addr().unwrap().port();
            drop(reserved);

            let mut network_options = NetworkOptions::new();
            network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(
                Ipv4Addr::LOCALHOST,
                bind_port,
            )));

            let accept_task = tokio::spawn(async move {
                let (stream, peer_addr) = listener.accept().await.unwrap();
                drop(stream);
                peer_addr
            });

            let stream =
                connect_resolved_addrs_staggered(vec![bad_addr, good_addr], network_options)
                    .await
                    .unwrap();
            assert_eq!(stream.peer_addr().unwrap(), good_addr);
            drop(stream);

            let peer_addr = accept_task.await.unwrap();
            assert_eq!(peer_addr.port(), bind_port);
            assert!(peer_addr.ip().is_loopback());
        });
    }

    /// A user-supplied socket connector replaces default resolution. The target host
    /// uses the reserved `.invalid` TLD, so success can only come from the custom
    /// connector.
    #[test]
    fn socket_connect_uses_custom_connector_over_default() {
        runtime().block_on(async {
            let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            let good_addr = listener.local_addr().unwrap();
            let used_custom = Arc::new(AtomicUsize::new(0));
            let used_custom_for_connector = Arc::clone(&used_custom);

            let accept_task = tokio::spawn(async move {
                let (_stream, _) = listener.accept().await.unwrap();
            });

            let mut options = MqttOptions::new("test-client", "localhost");
            options.set_socket_connector(move |_host, _network_options| {
                let used_custom = Arc::clone(&used_custom_for_connector);
                async move {
                    // Ignores the requested host and dials the local listener directly.
                    used_custom.fetch_add(1, Ordering::SeqCst);
                    TcpStream::connect(good_addr).await
                }
            });

            assert!(options.has_socket_connector());
            options
                .socket_connect("invalid.invalid:1883".to_owned(), NetworkOptions::new())
                .await
                .unwrap();

            assert_eq!(used_custom.load(Ordering::SeqCst), 1);
            accept_task.await.unwrap();
        });
    }

    // Websocket request modifiers: the infallible and fallible setters are mutually
    // exclusive, and the most recent setter call replaces the other.
    #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
    mod request_modifier_tests {
        use super::{Broker, MqttOptions};

        // Minimal error type for the fallible modifier's `Err` branch.
        #[derive(Debug)]
        struct TestError;

        impl std::fmt::Display for TestError {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "test error")
            }
        }

        impl std::error::Error for TestError {}

        #[test]
        fn infallible_modifier_is_set() {
            let mut options = MqttOptions::new(
                "test",
                Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
            );
            options.set_request_modifier(|req| async move { req });
            assert!(options.request_modifier().is_some());
            assert!(options.fallible_request_modifier().is_none());
        }

        #[test]
        fn fallible_modifier_is_set() {
            let mut options = MqttOptions::new(
                "test",
                Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
            );
            options.set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });
            assert!(options.request_modifier().is_none());
            assert!(options.fallible_request_modifier().is_some());
        }

        #[test]
        fn last_setter_call_wins() {
            let mut options = MqttOptions::new(
                "test",
                Broker::websocket("ws://localhost:8080").expect("valid websocket broker"),
            );

            // Fallible then infallible: only the infallible modifier remains.
            options
                .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) })
                .set_request_modifier(|req| async move { req });
            assert!(options.request_modifier().is_some());
            assert!(options.fallible_request_modifier().is_none());

            // Infallible then fallible: only the fallible modifier remains.
            options
                .set_request_modifier(|req| async move { req })
                .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });
            assert!(options.request_modifier().is_none());
            assert!(options.fallible_request_modifier().is_some());
        }
    }

    /// Fresh options use the `Default` incoming limit, which resolves to
    /// `default_max_incoming_size`.
    #[test]
    fn incoming_packet_size_limit_defaults_to_default_policy() {
        let mqtt_opts = MqttOptions::new("client", "127.0.0.1");
        assert_eq!(
            mqtt_opts.incoming_packet_size_limit(),
            IncomingPacketSizeLimit::Default
        );
        assert_eq!(
            mqtt_opts.max_incoming_packet_size(),
            Some(mqtt_opts.default_max_incoming_size)
        );
    }

    /// The legacy `set_max_packet_size` API maps `Some(n)` to `Bytes(n)` and `None`
    /// back to the `Default` policy.
    #[test]
    fn set_max_packet_size_remains_backward_compatible() {
        let mut mqtt_opts = MqttOptions::new("client", "127.0.0.1");

        mqtt_opts.set_max_packet_size(Some(2048));
        assert_eq!(
            mqtt_opts.incoming_packet_size_limit(),
            IncomingPacketSizeLimit::Bytes(2048)
        );
        assert_eq!(mqtt_opts.max_packet_size(), Some(2048));
        assert_eq!(mqtt_opts.max_incoming_packet_size(), Some(2048));

        mqtt_opts.set_max_packet_size(None);
        assert_eq!(
            mqtt_opts.incoming_packet_size_limit(),
            IncomingPacketSizeLimit::Default
        );
        assert_eq!(mqtt_opts.max_packet_size(), None);
        assert_eq!(
            mqtt_opts.max_incoming_packet_size(),
            Some(mqtt_opts.default_max_incoming_size)
        );
    }

    /// `Unlimited` reports no incoming maximum and leaves `connect_properties` unset.
    #[test]
    fn incoming_packet_size_limit_unlimited_disables_local_check() {
        let mut mqtt_opts = MqttOptions::new("client", "127.0.0.1");
        mqtt_opts.set_unlimited_incoming_packet_size();

        assert_eq!(
            mqtt_opts.incoming_packet_size_limit(),
            IncomingPacketSizeLimit::Unlimited
        );
        assert_eq!(mqtt_opts.max_incoming_packet_size(), None);
        assert_eq!(mqtt_opts.max_packet_size(), None);
        assert!(mqtt_opts.connect_properties.is_none());
    }

    /// A `ws://` broker starts on `Transport::Ws`. Switching to a WSS transport
    /// stores the TLS config and keeps the original websocket URL untouched.
    #[test]
    #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
    fn websocket_transport_can_be_explicitly_upgraded_to_wss() {
        use crate::{TlsConfiguration, Transport};
        let broker = Broker::websocket(
            "ws://a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host",
        )
        .expect("valid websocket broker");
        let mut mqttoptions = MqttOptions::new("client_a", broker);

        assert!(matches!(mqttoptions.transport(), Transport::Ws));
        mqttoptions.set_transport(Transport::wss(Vec::from("Test CA"), None, None));

        if let Transport::Wss(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) = mqttoptions.transport()
        {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        } else {
            panic!("Unexpected transport!");
        }

        assert_eq!(
            mqttoptions.broker().websocket_url(),
            Some(
                "ws://a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"
            )
        );
    }

    /// `Broker::websocket` refuses `wss://` URLs; TLS must be selected via the
    /// transport instead.
    #[test]
    #[cfg(feature = "websocket")]
    fn wss_websocket_urls_require_explicit_transport() {
        assert_eq!(
            Broker::websocket("wss://example.com/mqtt"),
            Err(OptionError::WssRequiresExplicitTransport)
        );
    }

    /// A `ws://` URL parsed by `parse_url` can be upgraded to WSS through
    /// `set_transport`.
    #[test]
    #[cfg(all(
        feature = "url",
        feature = "use-rustls-no-provider",
        feature = "websocket"
    ))]
    fn parse_url_ws_transport_can_be_explicitly_upgraded_to_wss() {
        use crate::{TlsConfiguration, Transport};
        let mut mqttoptions =
            MqttOptions::parse_url("ws://example.com:443/mqtt?client_id=client_a")
                .expect("valid websocket options");

        assert!(matches!(mqttoptions.transport(), Transport::Ws));
        mqttoptions.set_transport(Transport::wss(Vec::from("Test CA"), None, None));

        if let Transport::Wss(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) = mqttoptions.transport()
        {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        } else {
            panic!("Unexpected transport!");
        }
    }

    /// An `mqtt://` URL parses to plain TCP and can be upgraded to TLS through
    /// `set_transport`.
    #[test]
    #[cfg(all(feature = "url", feature = "use-rustls-no-provider"))]
    fn parse_url_mqtt_transport_can_be_explicitly_upgraded_to_tls() {
        use crate::{TlsConfiguration, Transport};
        let mut mqttoptions = MqttOptions::parse_url("mqtt://example.com:8883?client_id=client_a")
            .expect("valid tls options");

        assert!(matches!(mqttoptions.transport(), Transport::Tcp));
        mqttoptions.set_transport(Transport::tls(Vec::from("Test CA"), None, None));

        if let Transport::Tls(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) = mqttoptions.transport()
        {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        } else {
            panic!("Unexpected transport!");
        }
    }

    /// `mqtts`, `ssl` and (with websockets) `wss` schemes are all rejected by
    /// `parse_url`.
    #[test]
    #[cfg(feature = "url")]
    fn parse_url_rejects_secure_url_schemes() {
        assert!(matches!(
            MqttOptions::parse_url("mqtts://example.com:8883?client_id=client_a"),
            Err(OptionError::SecureUrlRequiresExplicitTransport)
        ));
        assert!(matches!(
            MqttOptions::parse_url("ssl://example.com:8883?client_id=client_a"),
            Err(OptionError::SecureUrlRequiresExplicitTransport)
        ));

        #[cfg(feature = "websocket")]
        assert!(matches!(
            MqttOptions::parse_url("wss://example.com:443/mqtt?client_id=client_a"),
            Err(OptionError::WssRequiresExplicitTransport)
        ));
    }

    /// Table-style coverage of `parse_url`: accepted options and credential forms
    /// first, then one malformed value per option mapped to its `OptionError`.
    #[test]
    #[cfg(feature = "url")]
    fn from_url() {
        // Local shorthands: parse, parse-and-expect-ok, parse-and-expect-err.
        fn opt(s: &str) -> Result<MqttOptions, OptionError> {
            MqttOptions::parse_url(s)
        }
        fn ok(s: &str) -> MqttOptions {
            opt(s).expect("valid options")
        }
        fn err(s: &str) -> OptionError {
            opt(s).expect_err("invalid options")
        }

        let v = ok("mqtt://host:42?client_id=foo");
        assert_eq!(v.broker().tcp_address(), Some(("host", 42)));
        assert_eq!(v.client_id(), "foo".to_owned());

        let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5");
        assert_eq!(v.keep_alive, Duration::from_secs(5));
        let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=0");
        assert_eq!(v.keep_alive, Duration::from_secs(0));
        let v = ok("mqtt://host:42?client_id=foo&read_batch_size_num=32");
        assert_eq!(v.read_batch_size(), 32);
        let v = ok("mqtt://host:42?client_id=foo&conn_timeout_secs=7");
        assert_eq!(v.connect_timeout(), Duration::from_secs(7));
        // Username only, username + password, and password with an empty username.
        let v = ok("mqtt://user@host:42?client_id=foo");
        assert_eq!(
            v.auth(),
            &ConnectAuth::Username {
                username: "user".to_owned(),
            }
        );
        let v = ok("mqtt://user:pw@host:42?client_id=foo");
        assert_eq!(
            v.auth(),
            &ConnectAuth::UsernamePassword {
                username: "user".to_owned(),
                password: Bytes::from_static(b"pw"),
            }
        );
        let v = ok("mqtt://:pw@host:42?client_id=foo");
        assert_eq!(
            v.auth(),
            &ConnectAuth::UsernamePassword {
                username: String::new(),
                password: Bytes::from_static(b"pw"),
            }
        );

        assert_eq!(err("mqtt://host:42"), OptionError::ClientId);
        assert_eq!(
            err("mqtt://host:42?client_id=foo&foo=bar"),
            OptionError::Unknown("foo".to_owned())
        );
        assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme);
        assert_eq!(
            err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"),
            OptionError::KeepAlive
        );
        // 65536 does not fit the u16 keep-alive.
        assert_eq!(
            err("mqtt://host:42?client_id=foo&keep_alive_secs=65536"),
            OptionError::KeepAlive
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&clean_start=foo"),
            OptionError::CleanStart
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"),
            OptionError::MaxIncomingPacketSize
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"),
            OptionError::RequestChannelCapacity
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"),
            OptionError::MaxRequestBatch
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&read_batch_size_num=foo"),
            OptionError::ReadBatchSize
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"),
            OptionError::PendingThrottle
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&inflight_num=foo"),
            OptionError::Inflight
        );
        assert_eq!(
            err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"),
            OptionError::ConnTimeout
        );
    }

    /// `Broker::unix` selects the Unix transport and keeps the socket path. Every
    /// other option matches a TCP-broker baseline built with the same client id.
    #[test]
    #[cfg(unix)]
    fn unix_broker_sets_unix_transport_and_preserves_defaults() {
        let options = MqttOptions::new("client", Broker::unix("/tmp/mqtt.sock"));
        let baseline = MqttOptions::new("client", "127.0.0.1");

        assert!(matches!(options.transport(), Transport::Unix));
        assert_eq!(
            options.broker().unix_path(),
            Some(std::path::Path::new("/tmp/mqtt.sock"))
        );
        assert_eq!(options.keep_alive, baseline.keep_alive);
        assert_eq!(options.clean_start, baseline.clean_start);
        assert_eq!(options.client_id, baseline.client_id);
        assert_eq!(
            options.request_channel_capacity,
            baseline.request_channel_capacity
        );
        assert_eq!(options.max_request_batch, baseline.max_request_batch);
        assert_eq!(options.read_batch_size, baseline.read_batch_size);
        assert_eq!(options.pending_throttle, baseline.pending_throttle);
        assert_eq!(options.connect_timeout, baseline.connect_timeout);
        assert_eq!(
            options.default_max_incoming_size,
            baseline.default_max_incoming_size
        );
        assert_eq!(
            options.incoming_packet_size_limit,
            baseline.incoming_packet_size_limit
        );
        assert_eq!(options.manual_acks, baseline.manual_acks);
        assert_eq!(
            options.outgoing_inflight_upper_limit,
            baseline.outgoing_inflight_upper_limit
        );
        assert!(options.auth_manager.is_none());
    }

    /// A host-less `unix:///path` URL yields a Unix broker and still applies the
    /// query options.
    #[test]
    #[cfg(all(feature = "url", unix))]
    fn from_url_supports_unix_socket_paths() {
        let options = MqttOptions::parse_url(
            "unix:///tmp/mqtt.sock?client_id=foo&keep_alive_secs=5&read_batch_size_num=32",
        )
        .expect("valid unix socket options");

        assert!(matches!(options.transport(), Transport::Unix));
        assert_eq!(
            options.broker().unix_path(),
            Some(std::path::Path::new("/tmp/mqtt.sock"))
        );
        assert_eq!(options.client_id(), "foo");
        assert_eq!(options.keep_alive, Duration::from_secs(5));
        assert_eq!(options.read_batch_size(), 32);
    }

    /// Percent escapes in the socket path are decoded (`%20` becomes a space).
    #[test]
    #[cfg(all(feature = "url", unix))]
    fn from_url_decodes_percent_escaped_unix_socket_paths() {
        let options =
            MqttOptions::parse_url("unix:///tmp/mqtt%20broker.sock?client_id=foo").unwrap();

        assert_eq!(
            options.broker().unix_path(),
            Some(std::path::Path::new("/tmp/mqtt broker.sock"))
        );
    }

    /// A decoded byte that is not valid UTF-8 (`%FF`) survives verbatim in the path.
    #[test]
    #[cfg(all(feature = "url", unix))]
    fn from_url_preserves_percent_decoded_unix_socket_bytes() {
        use std::os::unix::ffi::OsStrExt;

        let options = MqttOptions::parse_url("unix:///tmp/mqtt%FF.sock?client_id=foo").unwrap();

        assert_eq!(
            options.broker().unix_path().unwrap().as_os_str().as_bytes(),
            b"/tmp/mqtt\xff.sock"
        );
    }

    /// Rejections for unix URLs: a missing client id, a host component, and a path
    /// that is only `/`.
    #[test]
    #[cfg(all(feature = "url", unix))]
    fn from_url_rejects_invalid_unix_socket_paths() {
        fn err(s: &str) -> OptionError {
            MqttOptions::parse_url(s).expect_err("invalid unix socket url")
        }

        assert_eq!(err("unix:///tmp/mqtt.sock"), OptionError::ClientId);
        assert_eq!(
            err("unix://localhost/tmp/mqtt.sock?client_id=foo"),
            OptionError::UnixSocketPath
        );
        assert_eq!(err("unix:///?client_id=foo"), OptionError::UnixSocketPath);
    }

    /// Constructing options with an empty client id is accepted (no panic).
    #[test]
    fn allow_empty_client_id() {
        let _mqtt_opts = MqttOptions::new("", "127.0.0.1").set_clean_start(true);
    }

    /// The default read batch size is 0, which the test name calls the adaptive
    /// setting.
    #[test]
    fn read_batch_size_defaults_to_adaptive() {
        let options = MqttOptions::new("client", "127.0.0.1");
        assert_eq!(options.read_batch_size(), 0);
    }

    /// `set_read_batch_size` is reflected by the `read_batch_size` getter.
    #[test]
    fn set_read_batch_size() {
        let mut options = MqttOptions::new("client", "127.0.0.1");
        options.set_read_batch_size(48);
        assert_eq!(options.read_batch_size(), 48);
    }

    /// Without `max_incoming_packet_size_bytes`, a parsed URL keeps the `Default`
    /// incoming limit.
    #[test]
    #[cfg(feature = "url")]
    fn from_url_uses_default_incoming_limit_when_unspecified() {
        let mqtt_opts = MqttOptions::parse_url("mqtt://host:42?client_id=foo").unwrap();
        assert_eq!(
            mqtt_opts.incoming_packet_size_limit(),
            IncomingPacketSizeLimit::Default
        );
        assert_eq!(
            mqtt_opts.max_incoming_packet_size(),
            Some(mqtt_opts.default_max_incoming_size)
        );
    }
}