#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[cfg(all(feature = "use-rustls-ring", feature = "use-rustls-aws-lc"))]
compile_error!(
"Features `use-rustls-ring` and `use-rustls-aws-lc` are mutually exclusive. Enable only one rustls provider feature."
);
#[macro_use]
extern crate log;
use bytes::Bytes;
use std::fmt::{self, Debug, Formatter};
use std::io;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpStream, lookup_host};
use tokio::task::JoinSet;
#[cfg(all(feature = "url", unix))]
use percent_encoding::percent_decode_str;
#[cfg(all(feature = "url", unix))]
use std::{ffi::OsString, os::unix::ffi::OsStringExt};
mod client;
mod eventloop;
mod framed;
pub mod mqttbytes;
mod notice;
mod state;
mod transport;
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
mod tls;
#[cfg(feature = "websocket")]
mod websockets;
#[cfg(feature = "websocket")]
use std::{
future::{Future, IntoFuture},
pin::Pin,
};
/// Boxed error produced by a fallible websocket request modifier.
#[cfg(feature = "websocket")]
type RequestModifierError = Box<dyn std::error::Error + Send + Sync>;
/// Type-erased async hook that receives an `http::Request<()>` and resolves to
/// the (possibly modified) request; presumably the websocket handshake request —
/// NOTE(review): confirm in `websockets.rs`.
///
/// Kept behind `Arc` so that `MqttOptions` (which derives `Clone`) stays cloneable.
#[cfg(feature = "websocket")]
type RequestModifierFn = Arc<
    dyn Fn(http::Request<()>) -> Pin<Box<dyn Future<Output = http::Request<()>> + Send>>
        + Send
        + Sync,
>;
/// Fallible counterpart of [`RequestModifierFn`]: the future resolves either to
/// the modified request or to a boxed [`RequestModifierError`].
#[cfg(feature = "websocket")]
type FallibleRequestModifierFn = Arc<
    dyn Fn(
            http::Request<()>,
        )
            -> Pin<Box<dyn Future<Output = Result<http::Request<()>, RequestModifierError>> + Send>>
        + Send
        + Sync,
>;
#[cfg(feature = "proxy")]
mod proxy;
pub use client::{
AsyncClient, AsyncClientBuilder, Client, ClientBuilder, ClientError, Connection, InvalidTopic,
Iter, ManualAck, PublishTopic, RecvError, RecvTimeoutError, TryRecvError, ValidatedTopic,
};
pub use eventloop::{ConnectionError, Event, EventLoop};
pub use mqttbytes::v4::*;
pub use mqttbytes::*;
pub use notice::{
NoticeFailureReason, PublishNotice, PublishNoticeError, PublishResult, SubscribeNotice,
SubscribeNoticeError, UnsubscribeNotice, UnsubscribeNoticeError,
};
pub use rumqttc_core::NetworkOptions;
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
pub use rumqttc_core::TlsConfiguration;
pub use rumqttc_core::default_socket_connect;
pub use state::{MqttState, MqttStateBuilder, StateError};
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
pub use tls::Error as TlsError;
#[cfg(feature = "use-native-tls")]
pub use tokio_native_tls;
#[cfg(feature = "use-rustls-no-provider")]
pub use tokio_rustls;
pub use transport::Transport;
#[cfg(feature = "proxy")]
pub use proxy::{Proxy, ProxyAuth, ProxyType};
/// Packets received from the broker; simply an alias for [`Packet`].
pub type Incoming = Packet;
/// Identifies a packet on the outgoing side of the connection.
///
/// NOTE(review): the `u16` payloads look like MQTT packet identifiers — confirm
/// against the state machine in `state.rs`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Outgoing {
    /// PUBLISH packet.
    Publish(u16),
    /// SUBSCRIBE packet.
    Subscribe(u16),
    /// UNSUBSCRIBE packet.
    Unsubscribe(u16),
    /// PUBACK packet.
    PubAck(u16),
    /// PUBREC packet.
    PubRec(u16),
    /// PUBREL packet.
    PubRel(u16),
    /// PUBCOMP packet.
    PubComp(u16),
    /// PINGREQ packet.
    PingReq,
    /// PINGRESP packet.
    PingResp,
    /// DISCONNECT packet.
    Disconnect,
    /// Presumably: waiting for the acknowledgement of this identifier — confirm.
    AwaitAck(u16),
}
/// A unit of work for the connection, one variant per MQTT packet kind plus
/// disconnect variants.
///
/// NOTE(review): presumably queued by `AsyncClient`/`Client` and consumed by
/// `EventLoop` — confirm in `client.rs` / `eventloop.rs`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Request {
    Publish(Publish),
    PubAck(PubAck),
    PubRec(PubRec),
    PubComp(PubComp),
    PubRel(PubRel),
    PingReq(PingReq),
    PingResp(PingResp),
    Subscribe(Subscribe),
    SubAck(SubAck),
    Unsubscribe(Unsubscribe),
    UnsubAck(UnsubAck),
    Disconnect(Disconnect),
    /// Disconnect variant distinct from `Disconnect`; the name suggests it skips
    /// any graceful wait — TODO confirm in the event loop.
    DisconnectNow(Disconnect),
    /// Disconnect carrying a `Duration`; presumably an upper bound on how long a
    /// graceful disconnect may take — TODO confirm.
    DisconnectWithTimeout(Disconnect, Duration),
}
impl From<Publish> for Request {
fn from(publish: Publish) -> Self {
Self::Publish(publish)
}
}
impl From<Subscribe> for Request {
fn from(subscribe: Subscribe) -> Self {
Self::Subscribe(subscribe)
}
}
impl From<Unsubscribe> for Request {
fn from(unsubscribe: Unsubscribe) -> Self {
Self::Unsubscribe(unsubscribe)
}
}
/// Type-erased connector that opens the raw byte stream: it is called with the
/// target host string and the [`NetworkOptions`] and resolves to an
/// `io::Result<Box<dyn AsyncReadWrite>>` (see `MqttOptions::socket_connect`).
pub(crate) type SocketConnector = rumqttc_core::SocketConnector;
/// Spacing used when staggering concurrent connection attempts to several
/// resolved addresses.
const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(100);
/// Races connection attempts over `items`, starting them one after another,
/// and returns the first success.
///
/// The next candidate is started as soon as either `attempt_delay` has elapsed
/// without any attempt finishing, or an attempt has just failed. A run of
/// quickly-refusing candidates therefore does not each cost a full delay
/// (Happy Eyeballs v2 style). Attempts already in flight keep running while
/// later ones start. The first success wins, and every other attempt is aborted.
///
/// # Errors
///
/// - `InvalidInput` ("could not resolve to any address") when `items` is empty.
/// - Otherwise, when every attempt fails, the error of the attempt that
///   finished last. A panicked attempt is reported as an `Other` error.
async fn first_success_with_stagger<T, I, F, Fut>(
    items: I,
    attempt_delay: Duration,
    connect_fn: F,
) -> io::Result<T>
where
    T: Send + 'static,
    I: IntoIterator,
    I::Item: Send + 'static,
    F: Fn(I::Item) -> Fut + Send + Sync + Clone + 'static,
    Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
{
    // Collect up front so that only a `Vec` iterator is held across `.await`
    // points: the returned future stays `Send` whenever the items are.
    let candidates: Vec<I::Item> = items.into_iter().collect();
    if candidates.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        ));
    }

    let mut pending = candidates.into_iter();
    let mut join_set = JoinSet::new();
    let mut last_err = None;
    loop {
        if let Some(item) = pending.next() {
            let connect_fn = connect_fn.clone();
            join_set.spawn(async move { connect_fn(item).await });
        }

        let finished = if pending.as_slice().is_empty() {
            // Nothing left to start: wait for whatever is still in flight.
            match join_set.join_next().await {
                Some(finished) => finished,
                None => break,
            }
        } else {
            // `join_next` is cancel safe, so timing out here never drops a result.
            match tokio::time::timeout(attempt_delay, join_set.join_next()).await {
                Ok(Some(finished)) => finished,
                // Delay elapsed (or, defensively, nothing in flight): start the next one.
                Ok(None) | Err(_) => continue,
            }
        };

        match finished {
            Ok(Ok(stream)) => {
                join_set.abort_all();
                return Ok(stream);
            }
            // A failure falls through to the top of the loop, which immediately
            // starts the next candidate instead of waiting out the delay.
            Ok(Err(err)) => last_err = Some(err),
            Err(err) => {
                last_err = Some(io::Error::other(format!(
                    "concurrent connect task failed: {err}"
                )));
            }
        }
    }

    Err(last_err.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )
    }))
}
/// Tries each candidate strictly in order, one at a time, and returns the first
/// success. A new attempt starts only after the previous one has finished.
///
/// # Errors
///
/// - `InvalidInput` ("could not resolve to any address") when `items` yields
///   nothing.
/// - Otherwise, the error returned by the final attempt.
async fn first_success_sequential<T, I, F, Fut>(items: I, connect_fn: F) -> io::Result<T>
where
    I: IntoIterator,
    F: Fn(I::Item) -> Fut,
    Fut: std::future::Future<Output = io::Result<T>>,
{
    let mut failure: Option<io::Error> = None;
    for candidate in items {
        let outcome = connect_fn(candidate).await;
        match outcome {
            Ok(connected) => return Ok(connected),
            Err(error) => failure = Some(error),
        }
    }
    // `failure` is still `None` exactly when there was no candidate to try.
    Err(failure.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )
    }))
}
/// Decides whether connection attempts may overlap in time.
///
/// Returns `true` when no local bind address is configured, or when the
/// configured one uses port 0 (OS-chosen port). With a fixed local port,
/// concurrent attempts would compete for the same local endpoint, so the caller
/// falls back to sequential dialing.
fn should_stagger_connect_attempts(network_options: &NetworkOptions) -> bool {
    let local_port = network_options.bind_addr().map(|addr| addr.port());
    local_port.is_none_or(|port| port == 0)
}
/// Dials `items` using the crate-wide [`CONNECTION_ATTEMPT_DELAY`] as stagger
/// spacing. `connect_with_retry_mode_and_delay` picks between staggered and
/// sequential dialing.
async fn connect_with_retry_mode<T, I, F, Fut>(
    items: I,
    network_options: NetworkOptions,
    connect_fn: F,
) -> io::Result<T>
where
    T: Send + 'static,
    I: IntoIterator,
    I::Item: Send + 'static,
    F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
    Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
{
    let spacing = CONNECTION_ATTEMPT_DELAY;
    connect_with_retry_mode_and_delay(items, network_options, spacing, connect_fn).await
}
/// Dials `items` with `connect_fn`, choosing the strategy from `network_options`.
///
/// - Staggered, overlapping attempts spaced by `connection_attempt_delay` when
///   the local port is OS-chosen.
/// - Strictly sequential attempts when a fixed local port is bound (see
///   `should_stagger_connect_attempts`).
///
/// Every attempt receives its own clone of `network_options` and of `connect_fn`.
async fn connect_with_retry_mode_and_delay<T, I, F, Fut>(
    items: I,
    network_options: NetworkOptions,
    connection_attempt_delay: Duration,
    connect_fn: F,
) -> io::Result<T>
where
    T: Send + 'static,
    I: IntoIterator,
    I::Item: Send + 'static,
    F: Fn(I::Item, NetworkOptions) -> Fut + Send + Sync + Clone + 'static,
    Fut: std::future::Future<Output = io::Result<T>> + Send + 'static,
{
    // Evaluate the policy before `network_options` moves into the adapter.
    let staggered = should_stagger_connect_attempts(&network_options);
    let attempt = move |item: I::Item| {
        let options = network_options.clone();
        let dial = connect_fn.clone();
        async move { dial(item, options).await }
    };
    if staggered {
        first_success_with_stagger(items, connection_attempt_delay, attempt).await
    } else {
        first_success_sequential(items, attempt).await
    }
}
/// Opens a TCP stream to the first reachable address in `addrs`, using
/// `rumqttc_core::connect_socket_addr` for each attempt and the retry policy of
/// `connect_with_retry_mode`.
///
/// # Errors
///
/// `InvalidInput` when `addrs` is empty; otherwise the failure reported by the
/// retry strategy.
async fn connect_resolved_addrs_staggered(
    addrs: Vec<SocketAddr>,
    network_options: NetworkOptions,
) -> io::Result<TcpStream> {
    let dial_one = |addr: SocketAddr, options: NetworkOptions| async move {
        rumqttc_core::connect_socket_addr(addr, options).await
    };
    connect_with_retry_mode(addrs, network_options, dial_one).await
}
async fn default_socket_connect_staggered(
host: String,
network_options: NetworkOptions,
) -> io::Result<TcpStream> {
let addrs = lookup_host(host).await?.collect::<Vec<_>>();
connect_resolved_addrs_staggered(addrs, network_options).await
}
/// Built-in [`SocketConnector`]: resolves the host and dials it with
/// `default_socket_connect_staggered`, boxing the resulting `TcpStream` as a
/// `dyn AsyncReadWrite`.
fn default_socket_connector() -> SocketConnector {
    let connector: SocketConnector = Arc::new(|host, network_options| {
        let connect = async move {
            let stream = default_socket_connect_staggered(host, network_options).await?;
            Ok(Box::new(stream) as Box<dyn crate::framed::AsyncReadWrite>)
        };
        Box::pin(connect)
    });
    connector
}
/// Port used when a broker is given as a bare host (`From<&str>` / `From<String>`)
/// or when an `mqtt://` / `tcp://` URL has no explicit port. 1883 is the
/// standard unencrypted MQTT port.
const DEFAULT_BROKER_PORT: u16 = 1883;
/// Where the client connects to: a TCP host/port, a Unix socket path (unix
/// only) or a plain `ws://` URL (with the `websocket` feature).
///
/// The representation is private; build values with `Broker::tcp`,
/// `Broker::unix`, `Broker::websocket` or the `From` conversions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Broker {
    inner: BrokerInner,
}
/// Private storage behind [`Broker`]; variants are cfg-gated to the platforms
/// and features that support them.
#[derive(Clone, Debug, PartialEq, Eq)]
enum BrokerInner {
    /// TCP endpoint.
    Tcp {
        host: String,
        port: u16,
    },
    /// Unix domain socket path.
    #[cfg(unix)]
    Unix {
        path: PathBuf,
    },
    /// Full websocket URL; `Broker::websocket` only accepts the `ws` scheme.
    #[cfg(feature = "websocket")]
    Websocket {
        url: String,
    },
}
impl Broker {
    /// Creates a TCP broker address from a host name (or IP literal) and a port.
    #[must_use]
    pub fn tcp<S: Into<String>>(host: S, port: u16) -> Self {
        let host: String = host.into();
        Self {
            inner: BrokerInner::Tcp { host, port },
        }
    }

    /// Creates a broker address reached through the Unix domain socket at `path`.
    #[cfg(unix)]
    #[must_use]
    pub fn unix<P: Into<PathBuf>>(path: P) -> Self {
        let path: PathBuf = path.into();
        Self {
            inner: BrokerInner::Unix { path },
        }
    }

    /// Creates a plain-text websocket broker address from a `ws://` URL.
    ///
    /// # Errors
    ///
    /// - [`OptionError::WebsocketUrl`] if `url` does not parse as an
    ///   `http::Uri` or is rejected by `rumqttc_core::split_url`.
    /// - [`OptionError::WssRequiresExplicitTransport`] for `wss://` URLs.
    /// - [`OptionError::Scheme`] for any other (or missing) scheme.
    #[cfg(feature = "websocket")]
    pub fn websocket<S: Into<String>>(url: S) -> Result<Self, OptionError> {
        let url: String = url.into();
        let Ok(uri) = url.parse::<http::Uri>() else {
            return Err(OptionError::WebsocketUrl);
        };
        match uri.scheme_str() {
            Some("wss") => Err(OptionError::WssRequiresExplicitTransport),
            Some("ws") => {
                if rumqttc_core::split_url(&url).is_err() {
                    return Err(OptionError::WebsocketUrl);
                }
                Ok(Self {
                    inner: BrokerInner::Websocket { url },
                })
            }
            _ => Err(OptionError::Scheme),
        }
    }

    /// Returns `(host, port)` for a TCP broker, `None` for any other kind.
    #[must_use]
    pub const fn tcp_address(&self) -> Option<(&str, u16)> {
        // A `match` (not `if let`) keeps this refutable-pattern-free on builds
        // where `Tcp` is the only variant.
        match &self.inner {
            #[cfg(unix)]
            BrokerInner::Unix { .. } => None,
            #[cfg(feature = "websocket")]
            BrokerInner::Websocket { .. } => None,
            BrokerInner::Tcp { host, port } => Some((host.as_str(), *port)),
        }
    }

    /// Returns the socket path for a Unix broker, `None` for any other kind.
    #[cfg(unix)]
    #[must_use]
    pub fn unix_path(&self) -> Option<&std::path::Path> {
        if let BrokerInner::Unix { path } = &self.inner {
            Some(path.as_path())
        } else {
            None
        }
    }

    /// Returns the URL for a websocket broker, `None` for any other kind.
    #[cfg(feature = "websocket")]
    #[must_use]
    pub const fn websocket_url(&self) -> Option<&str> {
        if let BrokerInner::Websocket { url } = &self.inner {
            Some(url.as_str())
        } else {
            None
        }
    }

    /// Transport implied by the broker kind, used as the initial
    /// `MqttOptions` transport.
    pub(crate) const fn default_transport(&self) -> Transport {
        match &self.inner {
            #[cfg(feature = "websocket")]
            BrokerInner::Websocket { .. } => Transport::Ws,
            #[cfg(unix)]
            BrokerInner::Unix { .. } => Transport::unix(),
            BrokerInner::Tcp { .. } => Transport::tcp(),
        }
    }
}
impl From<&str> for Broker {
    /// Interprets the string as a TCP host on the default MQTT port (1883).
    fn from(host: &str) -> Self {
        Self::from((host, DEFAULT_BROKER_PORT))
    }
}
impl From<String> for Broker {
    /// Interprets the string as a TCP host on the default MQTT port (1883).
    fn from(host: String) -> Self {
        Self::from((host, DEFAULT_BROKER_PORT))
    }
}
impl<S: Into<String>> From<(S, u16)> for Broker {
    /// Builds a TCP broker from a `(host, port)` pair.
    fn from(endpoint: (S, u16)) -> Self {
        let (host, port) = endpoint;
        Self::tcp(host, port)
    }
}
/// Connection and client configuration; build with `MqttOptions::new`,
/// `MqttOptions::builder` or (with the `url` feature) `MqttOptions::parse_url`.
#[derive(Clone)]
pub struct MqttOptions {
    /// Broker to connect to.
    broker: Broker,
    /// Transport; initialised from `Broker::default_transport`, overridable
    /// with `set_transport`.
    transport: Transport,
    /// Keep-alive interval (set in whole seconds).
    keep_alive: Duration,
    /// Clean-session flag; `set_clean_session` refuses `false` with an empty client id.
    clean_session: bool,
    client_id: String,
    /// Username/password configuration (`ConnectAuth`).
    auth: ConnectAuth,
    /// Incoming packet size limit (bytes, per the URL parameter name).
    max_incoming_packet_size: usize,
    /// Outgoing packet size limit (bytes, per the URL parameter name).
    max_outgoing_packet_size: usize,
    /// NOTE(review): presumably the bound of the client → event-loop request channel.
    request_channel_capacity: usize,
    /// Defaults to 0; semantics defined by the event loop — TODO confirm.
    max_request_batch: usize,
    /// Defaults to 0; semantics defined by the event loop — TODO confirm.
    read_batch_size: usize,
    /// Defaults to zero (URL parameter is in microseconds).
    pending_throttle: Duration,
    /// In-flight limit; never zero (`set_inflight` asserts this).
    inflight: u16,
    last_will: Option<LastWill>,
    manual_acks: bool,
    #[cfg(feature = "proxy")]
    proxy: Option<Proxy>,
    /// At most one of `request_modifier` / `fallible_request_modifier` is set:
    /// each setter clears the other.
    #[cfg(feature = "websocket")]
    request_modifier: Option<RequestModifierFn>,
    #[cfg(feature = "websocket")]
    fallible_request_modifier: Option<FallibleRequestModifierFn>,
    /// Custom dialer; `None` means `default_socket_connector` is used.
    socket_connector: Option<SocketConnector>,
}
impl MqttOptions {
    /// Creates options for client `id` connecting to `broker`.
    ///
    /// Defaults:
    /// - transport derived from the broker kind;
    /// - 60 s keep-alive and a clean session;
    /// - no auth and no last will;
    /// - 10 KiB incoming and outgoing packet limits;
    /// - request channel capacity 10, request batch 0, read batch 0;
    /// - zero pending throttle and 100 in-flight messages;
    /// - automatic acks;
    /// - no proxy, no request modifiers and no custom socket connector.
    pub fn new<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> Self {
        let broker = broker.into();
        Self {
            // Computed before `broker` is moved into the struct below.
            transport: broker.default_transport(),
            broker,
            keep_alive: Duration::from_secs(60),
            clean_session: true,
            client_id: id.into(),
            auth: ConnectAuth::None,
            max_incoming_packet_size: 10 * 1024,
            max_outgoing_packet_size: 10 * 1024,
            request_channel_capacity: 10,
            max_request_batch: 0,
            read_batch_size: 0,
            pending_throttle: Duration::from_micros(0),
            inflight: 100,
            last_will: None,
            manual_acks: false,
            #[cfg(feature = "proxy")]
            proxy: None,
            #[cfg(feature = "websocket")]
            request_modifier: None,
            #[cfg(feature = "websocket")]
            fallible_request_modifier: None,
            socket_connector: None,
        }
    }
    /// Starts an [`MqttOptionsBuilder`] with the same defaults as `new`.
    #[must_use]
    pub fn builder<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> MqttOptionsBuilder {
        MqttOptionsBuilder::new(id, broker)
    }
    /// Parses a URL string and converts it with `TryFrom<url::Url>`.
    ///
    /// # Errors
    ///
    /// [`OptionError::Parse`] for a malformed URL, otherwise any error from the
    /// `TryFrom<url::Url>` conversion.
    #[cfg(feature = "url")]
    pub fn parse_url<S: Into<String>>(url: S) -> Result<Self, OptionError> {
        let url = url::Url::parse(&url.into())?;
        let options = Self::try_from(url)?;
        Ok(options)
    }
    /// Returns the configured broker.
    pub const fn broker(&self) -> &Broker {
        &self.broker
    }
    /// Sets (replaces) the last-will message.
    pub fn set_last_will(&mut self, will: LastWill) -> &mut Self {
        self.last_will = Some(will);
        self
    }
    /// Returns a clone of the last-will message, if any.
    pub fn last_will(&self) -> Option<LastWill> {
        self.last_will.clone()
    }
    /// Replaces the client id.
    ///
    /// Note: unlike `set_clean_session`, this does not re-check the
    /// "empty client id requires clean session" rule.
    pub fn set_client_id(&mut self, client_id: String) -> &mut Self {
        self.client_id = client_id;
        self
    }
    /// Overrides the transport chosen from the broker kind.
    ///
    /// This is a `const fn` only without TLS features. NOTE(review): likely
    /// because with TLS enabled `Transport` has drop glue, and assigning (which
    /// drops the old value) is not allowed in a const context.
    #[cfg(not(any(feature = "use-rustls-no-provider", feature = "use-native-tls")))]
    pub const fn set_transport(&mut self, transport: Transport) -> &mut Self {
        self.transport = transport;
        self
    }
    /// Overrides the transport chosen from the broker kind (TLS-enabled build).
    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
    pub fn set_transport(&mut self, transport: Transport) -> &mut Self {
        self.transport = transport;
        self
    }
    /// Returns a clone of the configured transport.
    pub fn transport(&self) -> Transport {
        self.transport.clone()
    }
    /// Sets the keep-alive interval in whole seconds.
    pub fn set_keep_alive(&mut self, seconds: u16) -> &mut Self {
        self.keep_alive = Duration::from_secs(u64::from(seconds));
        self
    }
    /// Returns the keep-alive interval.
    pub const fn keep_alive(&self) -> Duration {
        self.keep_alive
    }
    /// Returns a clone of the client id.
    pub fn client_id(&self) -> String {
        self.client_id.clone()
    }
    /// Sets the incoming and outgoing packet size limits.
    pub const fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self {
        self.max_incoming_packet_size = incoming;
        self.max_outgoing_packet_size = outgoing;
        self
    }
    /// Returns only the *incoming* packet size limit.
    pub const fn max_packet_size(&self) -> usize {
        self.max_incoming_packet_size
    }
    /// Sets the clean-session flag.
    ///
    /// # Panics
    ///
    /// If `clean_session` is `false` while the client id is empty.
    pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self {
        assert!(
            !self.client_id.is_empty() || clean_session,
            "Cannot unset clean session when client id is empty"
        );
        self.clean_session = clean_session;
        self
    }
    /// Returns the clean-session flag.
    pub const fn clean_session(&self) -> bool {
        self.clean_session
    }
    /// Replaces the authentication configuration.
    pub fn set_auth(&mut self, auth: ConnectAuth) -> &mut Self {
        self.auth = auth;
        self
    }
    /// Removes any authentication (`ConnectAuth::None`).
    pub fn clear_auth(&mut self) -> &mut Self {
        self.auth = ConnectAuth::None;
        self
    }
    /// Sets a username without a password, replacing any previous credentials.
    pub fn set_username<U: Into<String>>(&mut self, username: U) -> &mut Self {
        self.auth = ConnectAuth::Username {
            username: username.into(),
        };
        self
    }
    /// Sets a username plus a binary password, replacing any previous credentials.
    pub fn set_credentials<U: Into<String>, P: Into<Bytes>>(
        &mut self,
        username: U,
        password: P,
    ) -> &mut Self {
        self.auth = ConnectAuth::UsernamePassword {
            username: username.into(),
            password: password.into(),
        };
        self
    }
    /// Returns the authentication configuration.
    pub const fn auth(&self) -> &ConnectAuth {
        &self.auth
    }
    /// Sets the request channel capacity (default 10).
    pub const fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self {
        self.request_channel_capacity = capacity;
        self
    }
    /// Returns the request channel capacity.
    pub const fn request_channel_capacity(&self) -> usize {
        self.request_channel_capacity
    }
    /// Sets the maximum request batch (default 0); interpreted by the event loop.
    pub const fn set_max_request_batch(&mut self, max: usize) -> &mut Self {
        self.max_request_batch = max;
        self
    }
    /// Returns the maximum request batch.
    pub const fn max_request_batch(&self) -> usize {
        self.max_request_batch
    }
    /// Sets the read batch size (default 0); interpreted by the event loop.
    pub const fn set_read_batch_size(&mut self, size: usize) -> &mut Self {
        self.read_batch_size = size;
        self
    }
    /// Returns the read batch size.
    pub const fn read_batch_size(&self) -> usize {
        self.read_batch_size
    }
    /// Sets the pending throttle duration (default zero).
    pub const fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self {
        self.pending_throttle = duration;
        self
    }
    /// Returns the pending throttle duration.
    pub const fn pending_throttle(&self) -> Duration {
        self.pending_throttle
    }
    /// Sets the in-flight limit (default 100).
    ///
    /// # Panics
    ///
    /// If `inflight` is zero.
    pub fn set_inflight(&mut self, inflight: u16) -> &mut Self {
        assert!(inflight != 0, "zero in flight is not allowed");
        self.inflight = inflight;
        self
    }
    /// Returns the in-flight limit.
    pub const fn inflight(&self) -> u16 {
        self.inflight
    }
    /// Enables or disables manual acknowledgements (default `false`).
    pub const fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self {
        self.manual_acks = manual_acks;
        self
    }
    /// Returns whether manual acknowledgements are enabled.
    pub const fn manual_acks(&self) -> bool {
        self.manual_acks
    }
    /// Routes the connection through `proxy`.
    #[cfg(feature = "proxy")]
    pub fn set_proxy(&mut self, proxy: Proxy) -> &mut Self {
        self.proxy = Some(proxy);
        self
    }
    /// Returns a clone of the configured proxy, if any.
    #[cfg(feature = "proxy")]
    pub fn proxy(&self) -> Option<Proxy> {
        self.proxy.clone()
    }
    /// Installs an infallible async request modifier, clearing any fallible one.
    #[cfg(feature = "websocket")]
    pub fn set_request_modifier<F, O>(&mut self, request_modifier: F) -> &mut Self
    where
        F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
        O: IntoFuture<Output = http::Request<()>> + 'static,
        O::IntoFuture: Send,
    {
        self.request_modifier = Some(Arc::new(move |request| {
            let request_modifier = request_modifier(request).into_future();
            Box::pin(request_modifier)
        }));
        self.fallible_request_modifier = None;
        self
    }
    /// Installs a fallible async request modifier, clearing any infallible one.
    /// The error is boxed into a `RequestModifierError`.
    #[cfg(feature = "websocket")]
    pub fn set_fallible_request_modifier<F, O, E>(&mut self, request_modifier: F) -> &mut Self
    where
        F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
        O: IntoFuture<Output = Result<http::Request<()>, E>> + 'static,
        O::IntoFuture: Send,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.fallible_request_modifier = Some(Arc::new(move |request| {
            let request_modifier = request_modifier(request).into_future();
            Box::pin(async move {
                request_modifier
                    .await
                    .map_err(|error| Box::new(error) as RequestModifierError)
            })
        }));
        self.request_modifier = None;
        self
    }
    /// Returns the infallible request modifier, if set.
    #[cfg(feature = "websocket")]
    pub fn request_modifier(&self) -> Option<RequestModifierFn> {
        self.request_modifier.clone()
    }
    /// Returns the fallible request modifier, if set.
    #[cfg(feature = "websocket")]
    pub(crate) fn fallible_request_modifier(&self) -> Option<FallibleRequestModifierFn> {
        self.fallible_request_modifier.clone()
    }
    /// Replaces the default dialer with `f`, which is called with the host
    /// string and the `NetworkOptions`; its stream is boxed as `dyn AsyncReadWrite`.
    pub fn set_socket_connector<F, Fut, S>(&mut self, f: F) -> &mut Self
    where
        F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
        S: crate::framed::AsyncReadWrite + 'static,
    {
        self.socket_connector = Some(Arc::new(move |host, network_options| {
            let stream_future = f(host, network_options);
            let future = async move {
                let stream = stream_future.await?;
                Ok(Box::new(stream) as Box<dyn crate::framed::AsyncReadWrite>)
            };
            Box::pin(future)
        }));
        self
    }
    /// Returns `true` if a custom socket connector was installed.
    pub fn has_socket_connector(&self) -> bool {
        self.socket_connector.is_some()
    }
    /// Returns the custom connector, or `default_socket_connector()` if none is set.
    pub(crate) fn effective_socket_connector(&self) -> SocketConnector {
        self.socket_connector
            .clone()
            .unwrap_or_else(default_socket_connector)
    }
    /// Opens the raw stream to `host` with the effective socket connector.
    pub(crate) async fn socket_connect(
        &self,
        host: String,
        network_options: NetworkOptions,
    ) -> std::io::Result<Box<dyn crate::framed::AsyncReadWrite>> {
        let connector = self.effective_socket_connector();
        connector(host, network_options).await
    }
}
/// By-value builder for [`MqttOptions`]; each method delegates to the matching
/// `MqttOptions::set_*` method.
pub struct MqttOptionsBuilder {
    options: MqttOptions,
}
impl MqttOptionsBuilder {
    /// Starts from `MqttOptions::new(id, broker)` defaults.
    #[must_use]
    pub fn new<S: Into<String>, B: Into<Broker>>(id: S, broker: B) -> Self {
        Self {
            options: MqttOptions::new(id, broker),
        }
    }
    /// Returns the accumulated options.
    #[must_use]
    pub fn build(self) -> MqttOptions {
        self.options
    }
    /// See `MqttOptions::set_last_will`.
    #[must_use]
    pub fn last_will(mut self, will: LastWill) -> Self {
        self.options.set_last_will(will);
        self
    }
    /// See `MqttOptions::set_client_id`.
    #[must_use]
    pub fn client_id(mut self, client_id: String) -> Self {
        self.options.set_client_id(client_id);
        self
    }
    /// See `MqttOptions::set_transport` (const without TLS features).
    #[cfg(not(any(feature = "use-rustls-no-provider", feature = "use-native-tls")))]
    #[must_use]
    pub const fn transport(mut self, transport: Transport) -> Self {
        self.options.set_transport(transport);
        self
    }
    /// See `MqttOptions::set_transport` (TLS-enabled build).
    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
    #[must_use]
    pub fn transport(mut self, transport: Transport) -> Self {
        self.options.set_transport(transport);
        self
    }
    /// See `MqttOptions::set_keep_alive` (whole seconds).
    #[must_use]
    pub fn keep_alive(mut self, seconds: u16) -> Self {
        self.options.set_keep_alive(seconds);
        self
    }
    /// See `MqttOptions::set_max_packet_size`.
    #[must_use]
    pub const fn max_packet_size(mut self, incoming: usize, outgoing: usize) -> Self {
        self.options.set_max_packet_size(incoming, outgoing);
        self
    }
    /// See `MqttOptions::set_clean_session`.
    ///
    /// # Panics
    ///
    /// If `clean_session` is `false` while the client id is empty.
    #[must_use]
    pub fn clean_session(mut self, clean_session: bool) -> Self {
        self.options.set_clean_session(clean_session);
        self
    }
    /// See `MqttOptions::set_auth`.
    #[must_use]
    pub fn auth(mut self, auth: ConnectAuth) -> Self {
        self.options.set_auth(auth);
        self
    }
    /// See `MqttOptions::clear_auth`.
    #[must_use]
    pub fn clear_auth(mut self) -> Self {
        self.options.clear_auth();
        self
    }
    /// See `MqttOptions::set_username`.
    #[must_use]
    pub fn username<U: Into<String>>(mut self, username: U) -> Self {
        self.options.set_username(username);
        self
    }
    /// See `MqttOptions::set_credentials`.
    #[must_use]
    pub fn credentials<U: Into<String>, P: Into<Bytes>>(
        mut self,
        username: U,
        password: P,
    ) -> Self {
        self.options.set_credentials(username, password);
        self
    }
    /// See `MqttOptions::set_request_channel_capacity`.
    #[must_use]
    pub const fn request_channel_capacity(mut self, capacity: usize) -> Self {
        self.options.set_request_channel_capacity(capacity);
        self
    }
    /// See `MqttOptions::set_max_request_batch`.
    #[must_use]
    pub const fn max_request_batch(mut self, max: usize) -> Self {
        self.options.set_max_request_batch(max);
        self
    }
    /// See `MqttOptions::set_read_batch_size`.
    #[must_use]
    pub const fn read_batch_size(mut self, size: usize) -> Self {
        self.options.set_read_batch_size(size);
        self
    }
    /// See `MqttOptions::set_pending_throttle`.
    #[must_use]
    pub const fn pending_throttle(mut self, duration: Duration) -> Self {
        self.options.set_pending_throttle(duration);
        self
    }
    /// See `MqttOptions::set_inflight`.
    ///
    /// # Panics
    ///
    /// If `inflight` is zero.
    #[must_use]
    pub fn inflight(mut self, inflight: u16) -> Self {
        self.options.set_inflight(inflight);
        self
    }
    /// See `MqttOptions::set_manual_acks`.
    #[must_use]
    pub const fn manual_acks(mut self, manual_acks: bool) -> Self {
        self.options.set_manual_acks(manual_acks);
        self
    }
    /// See `MqttOptions::set_proxy`.
    #[cfg(feature = "proxy")]
    #[must_use]
    pub fn proxy(mut self, proxy: Proxy) -> Self {
        self.options.set_proxy(proxy);
        self
    }
    /// See `MqttOptions::set_request_modifier` (clears any fallible modifier).
    #[cfg(feature = "websocket")]
    #[must_use]
    pub fn request_modifier<F, O>(mut self, request_modifier: F) -> Self
    where
        F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
        O: IntoFuture<Output = http::Request<()>> + 'static,
        O::IntoFuture: Send,
    {
        self.options.set_request_modifier(request_modifier);
        self
    }
    /// See `MqttOptions::set_fallible_request_modifier` (clears any infallible modifier).
    #[cfg(feature = "websocket")]
    #[must_use]
    pub fn fallible_request_modifier<F, O, E>(mut self, request_modifier: F) -> Self
    where
        F: Fn(http::Request<()>) -> O + Send + Sync + 'static,
        O: IntoFuture<Output = Result<http::Request<()>, E>> + 'static,
        O::IntoFuture: Send,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.options.set_fallible_request_modifier(request_modifier);
        self
    }
    /// See `MqttOptions::set_socket_connector`.
    #[must_use]
    pub fn socket_connector<F, Fut, S>(mut self, f: F) -> Self
    where
        F: Fn(String, NetworkOptions) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<S, std::io::Error>> + Send + 'static,
        S: crate::framed::AsyncReadWrite + 'static,
    {
        self.options.set_socket_connector(f);
        self
    }
}
/// Errors from building [`MqttOptions`] / [`Broker`] from URLs and option values.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum OptionError {
    /// URL scheme is not supported (by the URL conversion or `Broker::websocket`).
    #[error("Unsupported URL scheme.")]
    Scheme,
    /// `mqtts://` / `ssl://` URLs: TLS must be configured explicitly.
    #[error(
        "Secure MQTT URL schemes require explicit TLS transport configuration via MqttOptions::set_transport(...)."
    )]
    SecureUrlRequiresExplicitTransport,
    /// The URL has no `client_id` query parameter.
    #[error("Missing client ID.")]
    ClientId,
    /// A `unix://` URL has a host, or an empty or `/` path.
    #[error("Invalid Unix socket path.")]
    UnixSocketPath,
    /// A websocket URL failed to parse or was rejected by `rumqttc_core::split_url`.
    #[cfg(feature = "websocket")]
    #[error("Invalid websocket url.")]
    WebsocketUrl,
    /// `wss://` was given where only `ws://` plus an explicit TLS transport is accepted.
    #[cfg(feature = "websocket")]
    #[error(
        "Secure websocket URLs require Broker::websocket(\"ws://...\") plus MqttOptions::set_transport(Transport::wss_with_config(...))."
    )]
    WssRequiresExplicitTransport,
    /// Invalid `keep_alive_secs` value.
    #[error("Invalid keep-alive value.")]
    KeepAlive,
    /// Invalid `clean_session` value.
    #[error("Invalid clean-session value.")]
    CleanSession,
    /// Invalid `max_incoming_packet_size_bytes` value.
    #[error("Invalid max-incoming-packet-size value.")]
    MaxIncomingPacketSize,
    /// Invalid `max_outgoing_packet_size_bytes` value.
    #[error("Invalid max-outgoing-packet-size value.")]
    MaxOutgoingPacketSize,
    /// Invalid `request_channel_capacity_num` value.
    #[error("Invalid request-channel-capacity value.")]
    RequestChannelCapacity,
    /// Invalid `max_request_batch_num` value.
    #[error("Invalid max-request-batch value.")]
    MaxRequestBatch,
    /// Invalid `read_batch_size_num` value.
    #[error("Invalid read-batch-size value.")]
    ReadBatchSize,
    /// Invalid `pending_throttle_usecs` value.
    #[error("Invalid pending-throttle value.")]
    PendingThrottle,
    /// Invalid `inflight_num` value.
    #[error("Invalid inflight value.")]
    Inflight,
    /// An unrecognised query parameter (its name is carried).
    #[error("Unknown option: {0}")]
    Unknown(String),
    /// The input string is not a valid URL.
    #[cfg(feature = "url")]
    #[error("Couldn't parse option from url: {0}")]
    Parse(#[from] url::ParseError),
}
#[cfg(feature = "url")]
impl std::convert::TryFrom<url::Url> for MqttOptions {
    type Error = OptionError;

    /// Builds [`MqttOptions`] from a broker URL, for example
    /// `mqtt://user:pass@host:1883?client_id=foo&keep_alive_secs=30`.
    ///
    /// Supported schemes are `mqtt`/`tcp`, `unix` (unix only) and `ws` (with the
    /// `websocket` feature). `mqtts`, `ssl` and `wss` are rejected because they
    /// need an explicitly configured TLS transport.
    ///
    /// `client_id` is required. The optional query parameters are:
    /// `keep_alive_secs`, `clean_session`, `max_incoming_packet_size_bytes`,
    /// `max_outgoing_packet_size_bytes`, `request_channel_capacity_num`,
    /// `max_request_batch_num`, `read_batch_size_num`, `pending_throttle_usecs`
    /// and `inflight_num`. Either packet-size parameter may be given on its own;
    /// the other limit keeps its default. URL userinfo becomes the credentials.
    ///
    /// # Errors
    ///
    /// Returns the matching [`OptionError`] for:
    /// - an unsupported scheme or a missing `client_id`;
    /// - a malformed value;
    /// - `clean_session=false` with an empty client id;
    /// - `inflight_num=0`;
    /// - any unrecognised parameter.
    ///
    /// Unlike the corresponding setters, invalid input never panics.
    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        use std::collections::HashMap;
        use std::num::NonZeroU16;

        let broker = match url.scheme() {
            "mqtts" | "ssl" => return Err(OptionError::SecureUrlRequiresExplicitTransport),
            "mqtt" | "tcp" => Broker::tcp(
                url.host_str().unwrap_or_default(),
                url.port().unwrap_or(DEFAULT_BROKER_PORT),
            ),
            #[cfg(unix)]
            "unix" => Broker::unix(parse_unix_socket_path(&url)?),
            #[cfg(feature = "websocket")]
            "ws" => Broker::websocket(url.as_str().to_owned())?,
            #[cfg(feature = "websocket")]
            "wss" => return Err(OptionError::WssRequiresExplicitTransport),
            _ => return Err(OptionError::Scheme),
        };

        let mut queries = url.query_pairs().collect::<HashMap<_, _>>();
        let id = queries
            .remove("client_id")
            .ok_or(OptionError::ClientId)?
            .into_owned();
        let mut options = Self::new(id, broker);

        if let Some(keep_alive) = queries
            .remove("keep_alive_secs")
            .map(|v| v.parse::<u16>().map_err(|_| OptionError::KeepAlive))
            .transpose()?
        {
            options.set_keep_alive(keep_alive);
        }

        if let Some(clean_session) = queries
            .remove("clean_session")
            .map(|v| v.parse::<bool>().map_err(|_| OptionError::CleanSession))
            .transpose()?
        {
            // `set_clean_session` asserts on this combination; URL input must
            // produce an error instead of a panic.
            if !clean_session && options.client_id.is_empty() {
                return Err(OptionError::CleanSession);
            }
            options.set_clean_session(clean_session);
        }

        set_url_credentials(&mut options, &url);

        let max_incoming = queries
            .remove("max_incoming_packet_size_bytes")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::MaxIncomingPacketSize)
            })
            .transpose()?;
        let max_outgoing = queries
            .remove("max_outgoing_packet_size_bytes")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::MaxOutgoingPacketSize)
            })
            .transpose()?;
        // Either limit may be supplied alone; the missing one keeps its current value.
        if max_incoming.is_some() || max_outgoing.is_some() {
            let incoming = max_incoming.unwrap_or(options.max_incoming_packet_size);
            let outgoing = max_outgoing.unwrap_or(options.max_outgoing_packet_size);
            options.set_max_packet_size(incoming, outgoing);
        }

        if let Some(request_channel_capacity) = queries
            .remove("request_channel_capacity_num")
            .map(|v| {
                v.parse::<usize>()
                    .map_err(|_| OptionError::RequestChannelCapacity)
            })
            .transpose()?
        {
            options.set_request_channel_capacity(request_channel_capacity);
        }

        if let Some(max_request_batch) = queries
            .remove("max_request_batch_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::MaxRequestBatch))
            .transpose()?
        {
            options.set_max_request_batch(max_request_batch);
        }

        if let Some(read_batch_size) = queries
            .remove("read_batch_size_num")
            .map(|v| v.parse::<usize>().map_err(|_| OptionError::ReadBatchSize))
            .transpose()?
        {
            options.set_read_batch_size(read_batch_size);
        }

        if let Some(pending_throttle) = queries
            .remove("pending_throttle_usecs")
            .map(|v| v.parse::<u64>().map_err(|_| OptionError::PendingThrottle))
            .transpose()?
        {
            options.set_pending_throttle(Duration::from_micros(pending_throttle));
        }

        // Parsing as `NonZeroU16` rejects "0" up front, so `set_inflight`'s
        // non-zero assertion can never fire on URL input.
        if let Some(inflight) = queries
            .remove("inflight_num")
            .map(|v| v.parse::<NonZeroU16>().map_err(|_| OptionError::Inflight))
            .transpose()?
        {
            options.set_inflight(inflight.get());
        }

        // Every recognised key has been removed; anything left is unknown.
        if let Some((opt, _)) = queries.into_iter().next() {
            return Err(OptionError::Unknown(opt.into_owned()));
        }
        Ok(options)
    }
}
/// Copies credentials from the URL's userinfo into `options`.
///
/// - A password (even an empty one after `:`) yields username + password.
/// - A non-empty username alone yields username-only auth.
/// - Otherwise auth is left untouched.
///
/// `url::Url::username` / `url::Url::password` return percent-encoded text
/// (e.g. `p%40ss` for `p@ss`), so both are decoded first. The username is
/// decoded lossily to UTF-8; the password is kept as raw bytes.
#[cfg(feature = "url")]
fn set_url_credentials(options: &mut MqttOptions, url: &url::Url) {
    let username = percent_encoding::percent_decode_str(url.username())
        .decode_utf8_lossy()
        .into_owned();
    if let Some(password) = url.password() {
        let password: Vec<u8> = percent_encoding::percent_decode_str(password).collect();
        options.set_credentials(username, password);
    } else if !username.is_empty() {
        options.set_username(username);
    }
}
/// Extracts the socket path from a `unix://` URL, percent-decoding it into raw
/// OS bytes.
///
/// # Errors
///
/// [`OptionError::UnixSocketPath`] if the URL has a host component, or if the
/// decoded path is empty or just `/`.
#[cfg(all(feature = "url", unix))]
fn parse_unix_socket_path(url: &url::Url) -> Result<PathBuf, OptionError> {
    if url.host_str().is_none() {
        let raw: Vec<u8> = percent_decode_str(url.path()).collect();
        let usable = !raw.is_empty() && raw.as_slice() != b"/";
        if usable {
            return Ok(PathBuf::from(OsString::from_vec(raw)));
        }
    }
    Err(OptionError::UnixSocketPath)
}
impl Debug for MqttOptions {
    /// Formats the plain-data options.
    ///
    /// Closure-valued fields (request modifiers, socket connector) have no
    /// `Debug` impl. `transport` and `proxy` are also left out, and
    /// `finish_non_exhaustive` marks the output as partial. Both packet-size
    /// limits are shown; the incoming one keeps its historical
    /// `max_packet_size` label.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("MqttOptions")
            .field("broker", &self.broker)
            .field("keep_alive", &self.keep_alive)
            .field("clean_session", &self.clean_session)
            .field("client_id", &self.client_id)
            .field("auth", &self.auth)
            .field("max_packet_size", &self.max_incoming_packet_size)
            .field("max_outgoing_packet_size", &self.max_outgoing_packet_size)
            .field("request_channel_capacity", &self.request_channel_capacity)
            .field("max_request_batch", &self.max_request_batch)
            .field("read_batch_size", &self.read_batch_size)
            .field("pending_throttle", &self.pending_throttle)
            .field("inflight", &self.inflight)
            .field("last_will", &self.last_will)
            .field("manual_acks", &self.manual_acks)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod test {
use super::*;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::net::{TcpListener, TcpSocket};
use tokio::runtime::Builder;
use tokio::sync::Notify;
fn runtime() -> tokio::runtime::Runtime {
Builder::new_current_thread().enable_all().build().unwrap()
}
#[test]
fn staggered_attempts_allow_later_success_to_win() {
runtime().block_on(async {
let started = Arc::new(AtomicUsize::new(0));
let started_for_connect = Arc::clone(&started);
let begin = std::time::Instant::now();
let result = first_success_with_stagger(
[0_u8, 1_u8],
std::time::Duration::from_millis(10),
move |attempt| {
let started = Arc::clone(&started_for_connect);
async move {
started.fetch_add(1, Ordering::SeqCst);
if attempt == 0 {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
Err(std::io::Error::other("slow failure"))
} else {
Ok(42_u8)
}
}
},
)
.await
.unwrap();
assert_eq!(result, 42);
assert_eq!(started.load(Ordering::SeqCst), 2);
assert!(begin.elapsed() < std::time::Duration::from_millis(150));
});
}
#[test]
fn staggered_connect_returns_invalid_input_for_empty_candidates() {
runtime().block_on(async {
let err = connect_resolved_addrs_staggered(Vec::new(), NetworkOptions::new())
.await
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
assert_eq!(err.to_string(), "could not resolve to any address");
});
}
#[test]
fn staggered_connect_tries_later_candidates() {
runtime().block_on(async {
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
let good_addr = listener.local_addr().unwrap();
let unused_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
let bad_addr = unused_listener.local_addr().unwrap();
drop(unused_listener);
let accept_task = tokio::spawn(async move {
let (_stream, _) = listener.accept().await.unwrap();
});
let stream =
connect_resolved_addrs_staggered(vec![bad_addr, good_addr], NetworkOptions::new())
.await
.unwrap();
assert_eq!(stream.peer_addr().unwrap(), good_addr);
accept_task.await.unwrap();
});
}
#[test]
fn fixed_bind_port_retry_mode_keeps_slow_first_candidate_alive() {
runtime().block_on(async {
let reserved = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
let bind_port = reserved.local_addr().unwrap().port();
drop(reserved);
let mut network_options = NetworkOptions::new();
network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::LOCALHOST,
bind_port,
)));
let first_attempt_started = Arc::new(Notify::new());
let second_attempt_started = Arc::new(AtomicBool::new(false));
let mut connect_task = tokio::spawn({
let first_attempt_started = Arc::clone(&first_attempt_started);
let second_attempt_started = Arc::clone(&second_attempt_started);
let network_options = network_options.clone();
async move {
connect_with_retry_mode_and_delay(
[0_u8, 1_u8],
network_options,
Duration::from_millis(10),
move |attempt, network_options| {
let first_attempt_started = Arc::clone(&first_attempt_started);
let second_attempt_started = Arc::clone(&second_attempt_started);
async move {
if attempt == 0 {
let bind_addr = network_options.bind_addr().unwrap();
let socket = match bind_addr {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
socket.bind(bind_addr)?;
first_attempt_started.notify_one();
std::future::pending::<io::Result<()>>().await
} else {
second_attempt_started.store(true, Ordering::SeqCst);
let _ = network_options;
Ok(())
}
}
},
)
.await
}
});
first_attempt_started.notified().await;
assert!(
tokio::time::timeout(Duration::from_millis(50), &mut connect_task)
.await
.is_err(),
"fixed-port dialing should keep the first slow candidate alive instead of capping it to the stagger delay"
);
assert!(
!second_attempt_started.load(Ordering::SeqCst),
"fixed-port dialing should not start later same-family candidates while the first is still pending"
);
connect_task.abort();
});
}
#[test]
fn fixed_bind_port_resolved_addrs_try_later_candidates() {
    // With a fixed local bind port, a refused first candidate must not end the
    // dial: the next resolved address is tried and connects from that same port.
    runtime().block_on(async {
        let live_listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let live_addr = live_listener.local_addr().unwrap();

        // Bind an ephemeral port and release it right away so nothing listens there.
        let dead_addr = {
            let probe = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            probe.local_addr().unwrap()
        };

        // Same trick to pick a local port that should be free for the client to bind.
        let fixed_port = {
            let reservation = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
            reservation.local_addr().unwrap().port()
        };

        let mut dial_options = NetworkOptions::new();
        let local_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, fixed_port);
        dial_options.set_bind_addr(SocketAddr::V4(local_addr));

        let accepted_peer = tokio::spawn(async move {
            let (conn, peer) = live_listener.accept().await.unwrap();
            drop(conn);
            peer
        });

        let candidates = vec![dead_addr, live_addr];
        let client = connect_resolved_addrs_staggered(candidates, dial_options)
            .await
            .unwrap();
        assert_eq!(client.peer_addr().unwrap(), live_addr);
        drop(client);

        // The server side must see the connection coming from the fixed local port.
        let observed_peer = accepted_peer.await.unwrap();
        assert_eq!(observed_peer.port(), fixed_port);
        assert!(observed_peer.ip().is_loopback());
    });
}
#[test]
fn socket_connect_uses_custom_connector_over_default() {
    // `socket_connect` must go through the configured connector. The address handed
    // in uses the reserved `.invalid` TLD, and the connector ignores it and dials
    // the local listener instead.
    runtime().block_on(async {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let target = listener.local_addr().unwrap();

        let connector_calls = Arc::new(AtomicUsize::new(0));
        let calls_seen_by_connector = Arc::clone(&connector_calls);

        let accept_task = tokio::spawn(async move {
            let (_accepted, _peer) = listener.accept().await.unwrap();
        });

        let mut options = MqttOptions::new("test-client", "localhost");
        options.set_socket_connector(move |_host, _network_options| {
            let calls = Arc::clone(&calls_seen_by_connector);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                TcpStream::connect(target).await
            }
        });
        assert!(options.has_socket_connector());

        let unresolvable = "invalid.invalid:1883".to_owned();
        options
            .socket_connect(unresolvable, NetworkOptions::new())
            .await
            .unwrap();

        // Exactly one invocation of the custom connector.
        assert_eq!(connector_calls.load(Ordering::SeqCst), 1);
        accept_task.await.unwrap();
    });
}
#[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
mod request_modifier_tests {
    //! Checks that the infallible and fallible websocket request modifiers on
    //! `MqttOptions` replace each other rather than coexisting.
    use super::{Broker, MqttOptions};

    /// Placeholder error type used as the failure type of fallible modifiers.
    #[derive(Debug)]
    struct TestError;

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("test error")
        }
    }

    impl std::error::Error for TestError {}

    /// Builds options that point at a plain `ws://` broker on localhost.
    fn websocket_options() -> MqttOptions {
        let broker = Broker::websocket("ws://localhost:8080").expect("valid websocket broker");
        MqttOptions::new("test", broker)
    }

    #[test]
    fn infallible_modifier_is_set() {
        let mut options = websocket_options();
        options.set_request_modifier(|req| async move { req });

        assert!(options.request_modifier().is_some());
        assert!(options.fallible_request_modifier().is_none());
    }

    #[test]
    fn fallible_modifier_is_set() {
        let mut options = websocket_options();
        options.set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });

        assert!(options.request_modifier().is_none());
        assert!(options.fallible_request_modifier().is_some());
    }

    #[test]
    fn last_setter_call_wins() {
        let mut options = websocket_options();

        // Fallible first, infallible second: only the infallible one remains.
        options
            .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) })
            .set_request_modifier(|req| async move { req });
        assert!(options.request_modifier().is_some());
        assert!(options.fallible_request_modifier().is_none());

        // And the reverse order leaves only the fallible one.
        options
            .set_request_modifier(|req| async move { req })
            .set_fallible_request_modifier(|req| async move { Ok::<_, TestError>(req) });
        assert!(options.request_modifier().is_none());
        assert!(options.fallible_request_modifier().is_some());
    }
}
#[test]
#[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
fn websocket_transport_can_be_explicitly_upgraded_to_wss() {
    // A presigned-style URL whose query string must come back unchanged after the
    // transport is switched from plain `ws` to TLS-backed `wss`.
    let url = "ws://a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host";
    let broker = Broker::websocket(url).expect("valid websocket broker");
    let mut mqttoptions = MqttOptions::new("client_a", broker);
    assert!(matches!(mqttoptions.transport(), crate::Transport::Ws));

    let upgraded = crate::Transport::wss(Vec::from("Test CA"), None, None);
    mqttoptions.set_transport(upgraded);
    match mqttoptions.transport() {
        crate::Transport::Wss(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) => {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        }
        _ => panic!("Unexpected transport!"),
    }

    assert_eq!(mqttoptions.broker().websocket_url(), Some(url));
}
#[test]
#[cfg(feature = "websocket")]
fn wss_websocket_urls_require_explicit_transport() {
    // A `wss://` URL alone is refused; the caller has to select the TLS transport
    // explicitly instead.
    let outcome = Broker::websocket("wss://example.com/mqtt");
    assert_eq!(outcome, Err(OptionError::WssRequiresExplicitTransport));
}
#[test]
#[cfg(all(
    feature = "url",
    feature = "use-rustls-no-provider",
    feature = "websocket"
))]
fn parse_url_ws_transport_can_be_explicitly_upgraded_to_wss() {
    // A `ws://` URL parses to the plain websocket transport, which the caller can
    // then replace with a `wss` transport carrying TLS settings.
    let url = "ws://example.com:443/mqtt?client_id=client_a";
    let mut mqttoptions = MqttOptions::parse_url(url).expect("valid websocket options");
    assert!(matches!(mqttoptions.transport(), crate::Transport::Ws));

    let upgraded = crate::Transport::wss(Vec::from("Test CA"), None, None);
    mqttoptions.set_transport(upgraded);
    match mqttoptions.transport() {
        crate::Transport::Wss(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) => {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        }
        _ => panic!("Unexpected transport!"),
    }
}
#[test]
#[cfg(all(feature = "url", feature = "use-rustls-no-provider"))]
fn parse_url_mqtt_transport_can_be_explicitly_upgraded_to_tls() {
    // An `mqtt://` URL parses to plain TCP even on port 8883; TLS only appears
    // once the caller installs it explicitly.
    let url = "mqtt://example.com:8883?client_id=client_a";
    let mut mqttoptions = MqttOptions::parse_url(url).expect("valid tls options");
    assert!(matches!(mqttoptions.transport(), crate::Transport::Tcp));

    let upgraded = crate::Transport::tls(Vec::from("Test CA"), None, None);
    mqttoptions.set_transport(upgraded);
    match mqttoptions.transport() {
        crate::Transport::Tls(TlsConfiguration::Simple {
            ca,
            client_auth,
            alpn,
        }) => {
            assert_eq!(ca.as_slice(), b"Test CA");
            assert_eq!(client_auth, None);
            assert_eq!(alpn, None);
        }
        _ => panic!("Unexpected transport!"),
    }
}
#[test]
#[cfg(feature = "url")]
fn parse_url_rejects_secure_url_schemes() {
    // Secure schemes are refused by `parse_url`; callers parse a plain scheme and
    // set the secure transport explicitly afterwards.
    for secure_url in [
        "mqtts://example.com:8883?client_id=client_a",
        "ssl://example.com:8883?client_id=client_a",
    ] {
        assert!(matches!(
            MqttOptions::parse_url(secure_url),
            Err(OptionError::SecureUrlRequiresExplicitTransport)
        ));
    }

    // Secure websockets report their own dedicated error.
    #[cfg(feature = "websocket")]
    assert!(matches!(
        MqttOptions::parse_url("wss://example.com:443/mqtt?client_id=client_a"),
        Err(OptionError::WssRequiresExplicitTransport)
    ));
}
#[test]
#[cfg(feature = "url")]
fn from_url() {
    /// Parses `s`, panicking if the URL is rejected.
    fn ok(s: &str) -> MqttOptions {
        MqttOptions::parse_url(s).expect("valid options")
    }

    /// Parses `s`, panicking if the URL is accepted, and returns the rejection.
    fn err(s: &str) -> OptionError {
        MqttOptions::parse_url(s).expect_err("invalid options")
    }

    // Broker address and client id.
    let basic = ok("mqtt://host:42?client_id=foo");
    assert_eq!(basic.broker().tcp_address(), Some(("host", 42)));
    assert_eq!(basic.client_id(), "foo");

    // Keep-alive is given in whole seconds, and zero is accepted.
    for secs in [5, 0] {
        let url = format!("mqtt://host:42?client_id=foo&keep_alive_secs={}", secs);
        assert_eq!(ok(&url).keep_alive, Duration::from_secs(secs));
    }

    assert_eq!(
        ok("mqtt://host:42?client_id=foo&read_batch_size_num=32").read_batch_size(),
        32
    );

    // Credentials come from the URL's userinfo section.
    let auth_cases = [
        (
            "mqtt://user@host:42?client_id=foo",
            ConnectAuth::Username {
                username: "user".to_owned(),
            },
        ),
        (
            "mqtt://user:pw@host:42?client_id=foo",
            ConnectAuth::UsernamePassword {
                username: "user".to_owned(),
                password: Bytes::from_static(b"pw"),
            },
        ),
        (
            "mqtt://:pw@host:42?client_id=foo",
            ConnectAuth::UsernamePassword {
                username: String::new(),
                password: Bytes::from_static(b"pw"),
            },
        ),
    ];
    for (url, expected) in auth_cases {
        assert_eq!(ok(url).auth(), &expected, "auth parsed from {}", url);
    }

    // Each malformed URL maps to a specific rejection reason.
    let rejections = [
        ("mqtt://host:42", OptionError::ClientId),
        (
            "mqtt://host:42?client_id=foo&foo=bar",
            OptionError::Unknown("foo".to_owned()),
        ),
        ("mqt://host:42?client_id=foo", OptionError::Scheme),
        (
            "mqtt://host:42?client_id=foo&keep_alive_secs=foo",
            OptionError::KeepAlive,
        ),
        (
            "mqtt://host:42?client_id=foo&keep_alive_secs=65536",
            OptionError::KeepAlive,
        ),
        (
            "mqtt://host:42?client_id=foo&clean_session=foo",
            OptionError::CleanSession,
        ),
        (
            "mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo",
            OptionError::MaxIncomingPacketSize,
        ),
        (
            "mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo",
            OptionError::MaxOutgoingPacketSize,
        ),
        (
            "mqtt://host:42?client_id=foo&request_channel_capacity_num=foo",
            OptionError::RequestChannelCapacity,
        ),
        (
            "mqtt://host:42?client_id=foo&max_request_batch_num=foo",
            OptionError::MaxRequestBatch,
        ),
        (
            "mqtt://host:42?client_id=foo&read_batch_size_num=foo",
            OptionError::ReadBatchSize,
        ),
        (
            "mqtt://host:42?client_id=foo&pending_throttle_usecs=foo",
            OptionError::PendingThrottle,
        ),
        (
            "mqtt://host:42?client_id=foo&inflight_num=foo",
            OptionError::Inflight,
        ),
    ];
    for (url, expected) in rejections {
        assert_eq!(err(url), expected, "rejection reason for {}", url);
    }
}
#[test]
#[cfg(unix)]
fn unix_broker_sets_unix_transport_and_preserves_defaults() {
    let baseline = MqttOptions::new("client_id", "127.0.0.1");
    let options = MqttOptions::new("client_id", Broker::unix("/tmp/mqtt.sock"));

    assert!(matches!(options.transport(), Transport::Unix));
    assert_eq!(
        options.broker().unix_path(),
        Some(std::path::Path::new("/tmp/mqtt.sock"))
    );

    // Each listed setting must equal the value a TCP broker gets from `new`.
    macro_rules! assert_same_field {
        ($($field:ident),+ $(,)?) => {
            $(
                assert_eq!(
                    options.$field,
                    baseline.$field,
                    "field `{}` differs from the TCP default",
                    stringify!($field)
                );
            )+
        };
    }
    assert_same_field!(
        keep_alive,
        clean_session,
        client_id,
        max_incoming_packet_size,
        max_outgoing_packet_size,
        request_channel_capacity,
        max_request_batch,
        read_batch_size,
        pending_throttle,
        inflight,
        manual_acks,
    );
}
#[test]
#[cfg(all(feature = "url", unix))]
fn from_url_supports_unix_socket_paths() {
    // The URL path becomes the socket path, and query parameters still apply.
    let url = "unix:///tmp/mqtt.sock?client_id=foo&keep_alive_secs=5&read_batch_size_num=32";
    let parsed = MqttOptions::parse_url(url).expect("valid unix socket options");

    assert!(matches!(parsed.transport(), Transport::Unix));
    let expected_path = std::path::Path::new("/tmp/mqtt.sock");
    assert_eq!(parsed.broker().unix_path(), Some(expected_path));

    assert_eq!(parsed.client_id(), "foo");
    assert_eq!(parsed.keep_alive, Duration::from_secs(5));
    assert_eq!(parsed.read_batch_size(), 32);
}
#[test]
#[cfg(all(feature = "url", unix))]
fn from_url_decodes_percent_escaped_unix_socket_paths() {
    // `%20` in the URL path must turn into a literal space in the socket path.
    let parsed =
        MqttOptions::parse_url("unix:///tmp/mqtt%20broker.sock?client_id=foo").unwrap();
    let socket_path = parsed.broker().unix_path();
    assert_eq!(
        socket_path,
        Some(std::path::Path::new("/tmp/mqtt broker.sock"))
    );
}
#[test]
#[cfg(all(feature = "url", unix))]
fn from_url_preserves_percent_decoded_unix_socket_bytes() {
    use std::os::unix::ffi::OsStrExt;

    // `%FF` decodes to a byte that is not valid UTF-8; the resulting path must
    // contain exactly that raw byte.
    let parsed = MqttOptions::parse_url("unix:///tmp/mqtt%FF.sock?client_id=foo").unwrap();
    let raw_path: &[u8] = parsed.broker().unix_path().unwrap().as_os_str().as_bytes();
    assert_eq!(raw_path, b"/tmp/mqtt\xff.sock");
}
#[test]
#[cfg(all(feature = "url", unix))]
fn from_url_rejects_invalid_unix_socket_paths() {
    let rejection =
        |url: &str| MqttOptions::parse_url(url).expect_err("invalid unix socket url");

    // A well-formed socket path with no client id is rejected for the missing id.
    assert_eq!(rejection("unix:///tmp/mqtt.sock"), OptionError::ClientId);

    // A host component, or a path of just `/`, is not accepted as a socket path.
    for bad_url in [
        "unix://localhost/tmp/mqtt.sock?client_id=foo",
        "unix:///?client_id=foo",
    ] {
        assert_eq!(rejection(bad_url), OptionError::UnixSocketPath);
    }
}
#[test]
fn accept_empty_client_id() {
    // Building options with an empty client id and requesting a clean session
    // must succeed. The options are kept in a named binding rather than borrowing
    // a temporary that is dropped at the end of the statement, so the stored
    // values can actually be checked.
    let mut options = MqttOptions::new("", "127.0.0.1");
    options.set_clean_session(true);
    assert_eq!(options.client_id(), "");
    assert!(options.clean_session());
}
#[test]
fn mqtt_options_builder_matches_setter_configuration() {
    // The builder and the `set_*` methods are two routes to the same options:
    // configure both identically and compare the results value by value.
    let farewell = LastWill::new("hello/world", "good bye", QoS::AtLeastOnce, false);
    let throttle = Duration::from_micros(250);

    let mut expected = MqttOptions::new("client", ("localhost", 1884));
    expected.set_keep_alive(5);
    expected.set_last_will(farewell.clone());
    expected.set_clean_session(false);
    expected.set_credentials("user", Bytes::from_static(b"password"));
    expected.set_request_channel_capacity(16);
    expected.set_max_request_batch(8);
    expected.set_read_batch_size(32);
    expected.set_pending_throttle(throttle);
    expected.set_inflight(4);
    expected.set_manual_acks(true);
    expected.set_max_packet_size(4096, 2048);

    let actual = MqttOptions::builder("client", ("localhost", 1884))
        .keep_alive(5)
        .last_will(farewell)
        .clean_session(false)
        .credentials("user", Bytes::from_static(b"password"))
        .request_channel_capacity(16)
        .max_request_batch(8)
        .read_batch_size(32)
        .pending_throttle(throttle)
        .inflight(4)
        .manual_acks(true)
        .max_packet_size(4096, 2048)
        .build();

    assert_eq!(
        actual.broker().tcp_address(),
        expected.broker().tcp_address()
    );

    macro_rules! assert_same_getter {
        ($($getter:ident),+ $(,)?) => {
            $(
                assert_eq!(
                    actual.$getter(),
                    expected.$getter(),
                    "`{}()` differs between builder and setters",
                    stringify!($getter)
                );
            )+
        };
    }
    assert_same_getter!(
        keep_alive,
        last_will,
        clean_session,
        auth,
        request_channel_capacity,
        max_request_batch,
        read_batch_size,
        pending_throttle,
        inflight,
        manual_acks,
    );

    // Packet size limits are compared through the fields directly.
    assert_eq!(
        actual.max_incoming_packet_size,
        expected.max_incoming_packet_size
    );
    assert_eq!(
        actual.max_outgoing_packet_size,
        expected.max_outgoing_packet_size
    );
}
#[test]
fn mqtt_options_builder_can_replace_and_clear_auth() {
    // After clearing, a later `auth` call supplies the final credentials.
    let next_user = || ConnectAuth::Username {
        username: "next".to_owned(),
    };

    let options = MqttOptions::builder("client", "localhost")
        .username("user")
        .clear_auth()
        .auth(next_user())
        .build();

    assert_eq!(options.auth(), &next_user());
}
#[test]
fn mqtt_options_builder_request_capacity_feeds_client_builder_default() {
    // With a request channel capacity of one, the first non-blocking publish fits
    // and the second is handed back as a `TryRequest` error.
    let options = MqttOptions::builder("test-1", "localhost")
        .request_channel_capacity(1)
        .build();
    let (client, _eventloop) = AsyncClient::builder(options).build();

    let topic = "hello/world";
    let first = client.try_publish(topic, QoS::AtMostOnce, false, "one");
    first.expect("first request should fit configured capacity");

    let second = client.try_publish(topic, QoS::AtMostOnce, false, "two");
    assert!(matches!(
        second,
        Err(ClientError::TryRequest(Request::Publish(_)))
    ));
}
#[test]
fn set_clean_session_when_client_id_present() {
    // With a non-empty client id, clean_session can be switched both ways, and
    // each call must actually take effect. The original test only checked that
    // neither call panicked.
    let mut options = MqttOptions::new("client_id", "127.0.0.1");

    options.set_clean_session(false);
    assert!(!options.clean_session());

    options.set_clean_session(true);
    assert!(options.clean_session());
}
#[test]
fn read_batch_size_defaults_to_adaptive() {
    // Freshly built options report a read batch size of 0. NOTE(review): the test
    // name treats 0 as "adaptive"; confirm against the `read_batch_size` docs.
    let default_batch = MqttOptions::new("client_id", "127.0.0.1").read_batch_size();
    assert_eq!(default_batch, 0);
}
#[test]
fn set_read_batch_size() {
    // The setter's value is reported back unchanged by the getter.
    let batch = 48;
    let mut options = MqttOptions::new("client_id", "127.0.0.1");
    options.set_read_batch_size(batch);
    assert_eq!(options.read_batch_size(), batch);
}
}