use std::borrow::Cow;
use std::time::Duration;
use crate::eventloop::{RequestChannelCapacity, RequestEnvelope};
use crate::mqttbytes::{
QoS,
v4::{Disconnect, PubAck, PubRec, Publish, Subscribe, SubscribeFilter, Unsubscribe},
};
use crate::notice::{PublishNoticeTx, SubscribeNoticeTx, UnsubscribeNoticeTx};
use crate::{
ConnectionError, Event, EventLoop, MqttOptions, PublishNotice, Request, SubscribeNotice,
UnsubscribeNotice, valid_filter, valid_topic,
};
use bytes::Bytes;
use flume::{SendError, Sender, TrySendError};
use futures_util::FutureExt;
use tokio::runtime::{self, Runtime};
use tokio::time::timeout;
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
#[error("Invalid MQTT topic: '{0}'")]
pub struct InvalidTopic(String);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedTopic(String);
impl ValidatedTopic {
pub fn new<S: Into<String>>(topic: S) -> Result<Self, InvalidTopic> {
let topic_string = topic.into();
if valid_publish_topic(&topic_string) {
Ok(Self(topic_string))
} else {
Err(InvalidTopic(topic_string))
}
}
}
impl From<ValidatedTopic> for String {
fn from(topic: ValidatedTopic) -> Self {
topic.0
}
}
/// Topic argument accepted by the publish APIs.
///
/// Plain string forms are checked with `valid_publish_topic` when the publish is built. A
/// [`ValidatedTopic`] skips that check because it was validated when it was constructed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PublishTopic {
    /// Topic that must still be validated before it is published.
    Unvalidated(String),
    /// Topic already validated by [`ValidatedTopic::new`].
    Validated(ValidatedTopic),
}

impl PublishTopic {
    /// Splits the topic into its string form and a flag.
    ///
    /// The flag is `true` when the topic still needs validation.
    fn into_string_and_validation(self) -> (String, bool) {
        match self {
            Self::Unvalidated(topic) => (topic, true),
            Self::Validated(topic) => (topic.0, false),
        }
    }
}

impl From<ValidatedTopic> for PublishTopic {
    fn from(topic: ValidatedTopic) -> Self {
        Self::Validated(topic)
    }
}

impl From<String> for PublishTopic {
    fn from(topic: String) -> Self {
        Self::Unvalidated(topic)
    }
}

impl From<&str> for PublishTopic {
    fn from(topic: &str) -> Self {
        Self::Unvalidated(topic.to_owned())
    }
}

impl From<&String> for PublishTopic {
    fn from(topic: &String) -> Self {
        Self::Unvalidated(topic.clone())
    }
}

impl From<Cow<'_, str>> for PublishTopic {
    fn from(topic: Cow<'_, str>) -> Self {
        Self::Unvalidated(topic.into_owned())
    }
}
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ClientError {
#[error("Failed to send mqtt requests to eventloop")]
Request(Request),
#[error("Failed to send mqtt requests to eventloop")]
TryRequest(Request),
#[error("Tracked request API is unavailable for this client instance")]
TrackingUnavailable,
}
impl From<SendError<Request>> for ClientError {
fn from(e: SendError<Request>) -> Self {
Self::Request(e.into_inner())
}
}
impl From<TrySendError<Request>> for ClientError {
fn from(e: TrySendError<Request>) -> Self {
Self::TryRequest(e.into_inner())
}
}
/// Channel set an [`AsyncClient`] writes its requests into.
#[derive(Debug, Clone)]
enum RequestSender {
    /// One raw channel (clients created via `from_senders`); no tracking notices.
    Plain(Sender<Request>),
    /// Envelope channels created by the builders: publishes, all other requests, and
    /// `disconnect_now` each have their own sender.
    WithNotice {
        requests: Sender<RequestEnvelope>,
        control_requests: Sender<RequestEnvelope>,
        immediate_disconnect: Sender<RequestEnvelope>,
    },
}

/// Returns the request carried by `envelope`, discarding its notice half.
fn into_request(envelope: RequestEnvelope) -> Request {
    envelope.into_parts().0
}

/// Converts a failed blocking/async envelope send into `ClientError::Request`.
fn map_send_envelope_error(err: SendError<RequestEnvelope>) -> ClientError {
    let envelope = err.into_inner();
    ClientError::Request(into_request(envelope))
}

/// Converts a failed `try_send` (full or disconnected) into `ClientError::TryRequest`.
fn map_try_send_envelope_error(err: TrySendError<RequestEnvelope>) -> ClientError {
    let envelope = err.into_inner();
    ClientError::TryRequest(into_request(envelope))
}

/// `true` for publish requests, which are routed to the data channel rather than the control
/// channel.
const fn is_publish_request(request: &Request) -> bool {
    matches!(request, Request::Publish(_))
}
/// Acknowledgement to send for an incoming publish when acknowledging manually.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ManualAck {
    /// Sent for `QoS::AtLeastOnce` publishes.
    PubAck(PubAck),
    /// Sent for `QoS::ExactlyOnce` publishes.
    PubRec(PubRec),
}

impl ManualAck {
    /// Wraps the acknowledgement packet in the matching [`Request`] variant.
    const fn into_request(self) -> Request {
        match self {
            Self::PubAck(packet) => Request::PubAck(packet),
            Self::PubRec(packet) => Request::PubRec(packet),
        }
    }
}
/// Cloneable async handle that submits requests to an [`EventLoop`].
#[derive(Debug, Clone)]
pub struct AsyncClient {
    request_tx: RequestSender,
}

/// Builder for a blocking [`Client`] and its [`Connection`].
#[derive(Debug)]
pub struct ClientBuilder {
    options: MqttOptions,
    capacity: RequestChannelCapacity,
}

/// Builder for an [`AsyncClient`] and its [`EventLoop`].
#[derive(Debug)]
pub struct AsyncClientBuilder {
    options: MqttOptions,
    capacity: RequestChannelCapacity,
}
/// Creates an event loop with the requested channel capacity and an [`AsyncClient`] wired to
/// its publish, control and immediate-disconnect channels.
#[must_use]
fn build_async_client(
    options: MqttOptions,
    capacity: RequestChannelCapacity,
) -> (AsyncClient, EventLoop) {
    let (eventloop, requests, control_requests, immediate_disconnect) =
        EventLoop::new_for_async_client_with_capacity(options, capacity);
    let request_tx = RequestSender::WithNotice {
        requests,
        control_requests,
        immediate_disconnect,
    };
    (AsyncClient { request_tx }, eventloop)
}
impl ClientBuilder {
    /// Starts a builder whose request channel is bounded by
    /// `options.request_channel_capacity()`.
    #[must_use]
    pub const fn new(options: MqttOptions) -> Self {
        let capacity = RequestChannelCapacity::Bounded(options.request_channel_capacity());
        Self { options, capacity }
    }

    /// Overrides the request channel with a bounded channel of `cap` slots.
    #[must_use]
    pub const fn capacity(mut self, cap: usize) -> Self {
        self.capacity = RequestChannelCapacity::Bounded(cap);
        self
    }

    /// Uses an unbounded request channel.
    #[must_use]
    pub const fn unbounded(mut self) -> Self {
        self.capacity = RequestChannelCapacity::Unbounded;
        self
    }

    /// Builds the blocking client and the [`Connection`] that drives it.
    ///
    /// The connection owns a current-thread Tokio runtime with all drivers enabled.
    ///
    /// # Panics
    ///
    /// Panics if that Tokio runtime cannot be created.
    #[must_use]
    pub fn build(self) -> (Client, Connection) {
        let (client, eventloop) = build_async_client(self.options, self.capacity);
        let client = Client { client };
        let runtime = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build the current-thread Tokio runtime for the MQTT connection");
        let connection = Connection::new(eventloop, runtime);
        (client, connection)
    }
}
impl AsyncClientBuilder {
#[must_use]
pub const fn new(options: MqttOptions) -> Self {
let capacity = RequestChannelCapacity::Bounded(options.request_channel_capacity());
Self { options, capacity }
}
#[must_use]
pub const fn capacity(mut self, cap: usize) -> Self {
self.capacity = RequestChannelCapacity::Bounded(cap);
self
}
#[must_use]
pub const fn unbounded(mut self) -> Self {
self.capacity = RequestChannelCapacity::Unbounded;
self
}
#[must_use]
pub fn build(self) -> (AsyncClient, EventLoop) {
build_async_client(self.options, self.capacity)
}
}
impl AsyncClient {
#[must_use]
pub const fn builder(options: MqttOptions) -> AsyncClientBuilder {
AsyncClientBuilder::new(options)
}
#[must_use]
pub const fn from_senders(request_tx: Sender<Request>) -> Self {
Self {
request_tx: RequestSender::Plain(request_tx),
}
}
async fn send_request_async(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.send_async(request).await.map_err(ClientError::from),
RequestSender::WithNotice {
requests,
control_requests,
..
} => {
let tx = if is_publish_request(&request) {
requests
} else {
control_requests
};
tx.send_async(RequestEnvelope::plain(request))
.await
.map_err(map_send_envelope_error)
}
}
}
fn try_send_request(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.try_send(request).map_err(ClientError::from),
RequestSender::WithNotice {
requests,
control_requests,
..
} => {
let tx = if is_publish_request(&request) {
requests
} else {
control_requests
};
tx.try_send(RequestEnvelope::plain(request))
.map_err(map_try_send_envelope_error)
}
}
}
fn send_request(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.send(request).map_err(ClientError::from),
RequestSender::WithNotice {
requests,
control_requests,
..
} => {
let tx = if is_publish_request(&request) {
requests
} else {
control_requests
};
tx.send(RequestEnvelope::plain(request))
.map_err(map_send_envelope_error)
}
}
}
async fn send_immediate_disconnect_async(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.send_async(request).await.map_err(ClientError::from),
RequestSender::WithNotice {
immediate_disconnect,
..
} => immediate_disconnect
.send_async(RequestEnvelope::plain(request))
.await
.map_err(map_send_envelope_error),
}
}
fn try_send_immediate_disconnect(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.try_send(request).map_err(ClientError::from),
RequestSender::WithNotice {
immediate_disconnect,
..
} => immediate_disconnect
.try_send(RequestEnvelope::plain(request))
.map_err(map_try_send_envelope_error),
}
}
fn send_immediate_disconnect(&self, request: Request) -> Result<(), ClientError> {
match &self.request_tx {
RequestSender::Plain(tx) => tx.send(request).map_err(ClientError::from),
RequestSender::WithNotice {
immediate_disconnect,
..
} => immediate_disconnect
.send(RequestEnvelope::plain(request))
.map_err(map_send_envelope_error),
}
}
async fn send_tracked_publish_async(
&self,
publish: Publish,
) -> Result<PublishNotice, ClientError> {
let RequestSender::WithNotice {
requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = PublishNoticeTx::new();
request_tx
.send_async(RequestEnvelope::tracked_publish(publish, notice_tx))
.await
.map_err(map_send_envelope_error)?;
Ok(notice)
}
fn try_send_tracked_publish(&self, publish: Publish) -> Result<PublishNotice, ClientError> {
let RequestSender::WithNotice {
requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = PublishNoticeTx::new();
request_tx
.try_send(RequestEnvelope::tracked_publish(publish, notice_tx))
.map_err(map_try_send_envelope_error)?;
Ok(notice)
}
async fn send_tracked_subscribe_async(
&self,
subscribe: Subscribe,
) -> Result<SubscribeNotice, ClientError> {
let RequestSender::WithNotice {
control_requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = SubscribeNoticeTx::new();
request_tx
.send_async(RequestEnvelope::tracked_subscribe(subscribe, notice_tx))
.await
.map_err(map_send_envelope_error)?;
Ok(notice)
}
fn try_send_tracked_subscribe(
&self,
subscribe: Subscribe,
) -> Result<SubscribeNotice, ClientError> {
let RequestSender::WithNotice {
control_requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = SubscribeNoticeTx::new();
request_tx
.try_send(RequestEnvelope::tracked_subscribe(subscribe, notice_tx))
.map_err(map_try_send_envelope_error)?;
Ok(notice)
}
async fn send_tracked_unsubscribe_async(
&self,
unsubscribe: Unsubscribe,
) -> Result<UnsubscribeNotice, ClientError> {
let RequestSender::WithNotice {
control_requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = UnsubscribeNoticeTx::new();
request_tx
.send_async(RequestEnvelope::tracked_unsubscribe(unsubscribe, notice_tx))
.await
.map_err(map_send_envelope_error)?;
Ok(notice)
}
fn try_send_tracked_unsubscribe(
&self,
unsubscribe: Unsubscribe,
) -> Result<UnsubscribeNotice, ClientError> {
let RequestSender::WithNotice {
control_requests: request_tx,
..
} = &self.request_tx
else {
return Err(ClientError::TrackingUnavailable);
};
let (notice_tx, notice) = UnsubscribeNoticeTx::new();
request_tx
.try_send(RequestEnvelope::tracked_unsubscribe(unsubscribe, notice_tx))
.map_err(map_try_send_envelope_error)?;
Ok(notice)
}
async fn handle_publish<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::new(topic, qos, payload);
publish.retain = retain;
let publish = Request::Publish(publish);
if invalid_topic {
return Err(ClientError::Request(publish));
}
self.send_request_async(publish).await?;
Ok(())
}
async fn handle_publish_tracked<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::new(topic, qos, payload);
publish.retain = retain;
let request = Request::Publish(publish.clone());
if invalid_topic {
return Err(ClientError::Request(request));
}
self.send_tracked_publish_async(publish).await
}
pub async fn publish<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
self.handle_publish(topic, qos, retain, payload).await
}
pub async fn publish_tracked<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
self.handle_publish_tracked(topic, qos, retain, payload)
.await
}
fn handle_try_publish<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::new(topic, qos, payload);
publish.retain = retain;
let publish = Request::Publish(publish);
if invalid_topic {
return Err(ClientError::TryRequest(publish));
}
self.try_send_request(publish)?;
Ok(())
}
fn handle_try_publish_tracked<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::new(topic, qos, payload);
publish.retain = retain;
let request = Request::Publish(publish.clone());
if invalid_topic {
return Err(ClientError::TryRequest(request));
}
self.try_send_tracked_publish(publish)
}
pub fn try_publish<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
self.handle_try_publish(topic, qos, retain, payload)
}
pub fn try_publish_tracked<T, V>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: V,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
V: Into<Vec<u8>>,
{
self.handle_try_publish_tracked(topic, qos, retain, payload)
}
pub const fn prepare_ack(&self, publish: &Publish) -> Option<ManualAck> {
prepare_ack(publish)
}
pub async fn manual_ack(&self, ack: ManualAck) -> Result<(), ClientError> {
self.send_request_async(ack.into_request()).await?;
Ok(())
}
pub fn try_manual_ack(&self, ack: ManualAck) -> Result<(), ClientError> {
self.try_send_request(ack.into_request())?;
Ok(())
}
pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> {
if let Some(ack) = self.prepare_ack(publish) {
self.manual_ack(ack).await?;
}
Ok(())
}
pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> {
if let Some(ack) = self.prepare_ack(publish) {
self.try_manual_ack(ack)?;
}
Ok(())
}
async fn handle_publish_bytes<T>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: Bytes,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::from_bytes(topic, qos, payload);
publish.retain = retain;
let publish = Request::Publish(publish);
if invalid_topic {
return Err(ClientError::Request(publish));
}
self.send_request_async(publish).await?;
Ok(())
}
async fn handle_publish_bytes_tracked<T>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: Bytes,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
{
let (topic, needs_validation) = topic.into().into_string_and_validation();
let invalid_topic = needs_validation && !valid_publish_topic(&topic);
let mut publish = Publish::from_bytes(topic, qos, payload);
publish.retain = retain;
let request = Request::Publish(publish.clone());
if invalid_topic {
return Err(ClientError::Request(request));
}
self.send_tracked_publish_async(publish).await
}
pub async fn publish_bytes<T>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: Bytes,
) -> Result<(), ClientError>
where
T: Into<PublishTopic>,
{
self.handle_publish_bytes(topic, qos, retain, payload).await
}
pub async fn publish_bytes_tracked<T>(
&self,
topic: T,
qos: QoS,
retain: bool,
payload: Bytes,
) -> Result<PublishNotice, ClientError>
where
T: Into<PublishTopic>,
{
self.handle_publish_bytes_tracked(topic, qos, retain, payload)
.await
}
pub async fn subscribe<S: Into<String>>(&self, topic: S, qos: QoS) -> Result<(), ClientError> {
let subscribe = Subscribe::new(topic, qos);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::Request(subscribe.into()));
}
self.send_request_async(subscribe.into()).await?;
Ok(())
}
pub async fn subscribe_tracked<S: Into<String>>(
&self,
topic: S,
qos: QoS,
) -> Result<SubscribeNotice, ClientError> {
let subscribe = Subscribe::new(topic, qos);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::Request(subscribe.into()));
}
self.send_tracked_subscribe_async(subscribe).await
}
pub fn try_subscribe<S: Into<String>>(&self, topic: S, qos: QoS) -> Result<(), ClientError> {
let subscribe = Subscribe::new(topic, qos);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::TryRequest(subscribe.into()));
}
self.try_send_request(subscribe.into())?;
Ok(())
}
pub fn try_subscribe_tracked<S: Into<String>>(
&self,
topic: S,
qos: QoS,
) -> Result<SubscribeNotice, ClientError> {
let subscribe = Subscribe::new(topic, qos);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::TryRequest(subscribe.into()));
}
self.try_send_tracked_subscribe(subscribe)
}
pub async fn subscribe_many<T>(&self, topics: T) -> Result<(), ClientError>
where
T: IntoIterator<Item = SubscribeFilter>,
{
let subscribe = Subscribe::new_many(topics);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::Request(subscribe.into()));
}
self.send_request_async(subscribe.into()).await?;
Ok(())
}
pub async fn subscribe_many_tracked<T>(&self, topics: T) -> Result<SubscribeNotice, ClientError>
where
T: IntoIterator<Item = SubscribeFilter>,
{
let subscribe = Subscribe::new_many(topics);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::Request(subscribe.into()));
}
self.send_tracked_subscribe_async(subscribe).await
}
pub fn try_subscribe_many<T>(&self, topics: T) -> Result<(), ClientError>
where
T: IntoIterator<Item = SubscribeFilter>,
{
let subscribe = Subscribe::new_many(topics);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::TryRequest(subscribe.into()));
}
self.try_send_request(subscribe.into())?;
Ok(())
}
pub fn try_subscribe_many_tracked<T>(&self, topics: T) -> Result<SubscribeNotice, ClientError>
where
T: IntoIterator<Item = SubscribeFilter>,
{
let subscribe = Subscribe::new_many(topics);
if !subscribe_has_valid_filters(&subscribe) {
return Err(ClientError::TryRequest(subscribe.into()));
}
self.try_send_tracked_subscribe(subscribe)
}
pub async fn unsubscribe<S: Into<String>>(&self, topic: S) -> Result<(), ClientError> {
let unsubscribe = Unsubscribe::new(topic.into());
let request = Request::Unsubscribe(unsubscribe);
self.send_request_async(request).await?;
Ok(())
}
pub async fn unsubscribe_tracked<S: Into<String>>(
&self,
topic: S,
) -> Result<UnsubscribeNotice, ClientError> {
let unsubscribe = Unsubscribe::new(topic.into());
self.send_tracked_unsubscribe_async(unsubscribe).await
}
pub fn try_unsubscribe<S: Into<String>>(&self, topic: S) -> Result<(), ClientError> {
let unsubscribe = Unsubscribe::new(topic.into());
let request = Request::Unsubscribe(unsubscribe);
self.try_send_request(request)?;
Ok(())
}
pub fn try_unsubscribe_tracked<S: Into<String>>(
&self,
topic: S,
) -> Result<UnsubscribeNotice, ClientError> {
let unsubscribe = Unsubscribe::new(topic.into());
self.try_send_tracked_unsubscribe(unsubscribe)
}
pub async fn disconnect(&self) -> Result<(), ClientError> {
let request = Request::Disconnect(Disconnect);
self.send_request_async(request).await?;
Ok(())
}
pub async fn disconnect_with_timeout(&self, timeout: Duration) -> Result<(), ClientError> {
let request = Request::DisconnectWithTimeout(Disconnect, timeout);
self.send_request_async(request).await?;
Ok(())
}
pub async fn disconnect_now(&self) -> Result<(), ClientError> {
let request = Request::DisconnectNow(Disconnect);
self.send_immediate_disconnect_async(request).await?;
Ok(())
}
pub fn try_disconnect(&self) -> Result<(), ClientError> {
let request = Request::Disconnect(Disconnect);
self.try_send_request(request)?;
Ok(())
}
pub fn try_disconnect_with_timeout(&self, timeout: Duration) -> Result<(), ClientError> {
let request = Request::DisconnectWithTimeout(Disconnect, timeout);
self.try_send_request(request)?;
Ok(())
}
pub fn try_disconnect_now(&self) -> Result<(), ClientError> {
let request = Request::DisconnectNow(Disconnect);
self.try_send_immediate_disconnect(request)?;
Ok(())
}
}
/// Maps an incoming publish to the acknowledgement it requires.
///
/// - `AtMostOnce` needs no acknowledgement.
/// - `AtLeastOnce` gets a `PubAck`.
/// - `ExactlyOnce` gets a `PubRec`.
///
/// Both packets reuse the publish's packet id.
const fn prepare_ack(publish: &Publish) -> Option<ManualAck> {
    match publish.qos {
        QoS::AtMostOnce => None,
        QoS::AtLeastOnce => Some(ManualAck::PubAck(PubAck::new(publish.pkid))),
        QoS::ExactlyOnce => Some(ManualAck::PubRec(PubRec::new(publish.pkid))),
    }
}
/// Blocking MQTT client handle; a thin wrapper that drives an [`AsyncClient`] synchronously.
///
/// `Debug` is derived so the handle can be logged or embedded in user types that derive
/// `Debug`; the wrapped `AsyncClient` already implements it.
#[derive(Clone, Debug)]
pub struct Client {
    client: AsyncClient,
}
impl Client {
    /// Returns a builder for a blocking [`Client`] and its [`Connection`].
    #[must_use]
    pub const fn builder(options: MqttOptions) -> ClientBuilder {
        ClientBuilder::new(options)
    }

    /// Creates a blocking client that writes plain [`Request`]s into `request_tx`.
    #[must_use]
    pub const fn from_sender(request_tx: Sender<Request>) -> Self {
        let client = AsyncClient::from_senders(request_tx);
        Self { client }
    }

    /// Publishes `payload` to `topic`, blocking while the request channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Request`] carrying the publish when:
    /// - a non-[`ValidatedTopic`] topic fails validation, or
    /// - the channel is closed.
    pub fn publish<T, V>(
        &self,
        topic: T,
        qos: QoS,
        retain: bool,
        payload: V,
    ) -> Result<(), ClientError>
    where
        T: Into<PublishTopic>,
        V: Into<Vec<u8>>,
    {
        let (topic_string, must_check) = topic.into().into_string_and_validation();
        let rejected = must_check && !valid_publish_topic(&topic_string);
        let request = Request::Publish({
            let mut packet = Publish::new(topic_string, qos, payload);
            packet.retain = retain;
            packet
        });
        if rejected {
            Err(ClientError::Request(request))
        } else {
            self.client.send_request(request)
        }
    }

    /// Publishes without blocking; see [`AsyncClient::try_publish`].
    pub fn try_publish<T, V>(
        &self,
        topic: T,
        qos: QoS,
        retain: bool,
        payload: V,
    ) -> Result<(), ClientError>
    where
        T: Into<PublishTopic>,
        V: Into<Vec<u8>>,
    {
        self.client.try_publish(topic, qos, retain, payload)
    }

    /// Returns the acknowledgement `publish` needs, if any.
    pub const fn prepare_ack(&self, publish: &Publish) -> Option<ManualAck> {
        self.client.prepare_ack(publish)
    }

    /// Sends a manual acknowledgement, blocking while the channel is full.
    pub fn manual_ack(&self, ack: ManualAck) -> Result<(), ClientError> {
        let request = ack.into_request();
        self.client.send_request(request)
    }

    /// Sends a manual acknowledgement without blocking.
    pub fn try_manual_ack(&self, ack: ManualAck) -> Result<(), ClientError> {
        self.client.try_manual_ack(ack)
    }

    /// Acknowledges `publish` if its QoS requires it; QoS 0 is a no-op.
    pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> {
        match self.prepare_ack(publish) {
            Some(ack) => self.manual_ack(ack),
            None => Ok(()),
        }
    }

    /// Non-blocking variant of [`Self::ack`].
    pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> {
        match self.prepare_ack(publish) {
            Some(ack) => self.try_manual_ack(ack),
            None => Ok(()),
        }
    }

    /// Subscribes to `topic`, blocking while the channel is full.
    pub fn subscribe<S: Into<String>>(&self, topic: S, qos: QoS) -> Result<(), ClientError> {
        let packet = Subscribe::new(topic, qos);
        if subscribe_has_valid_filters(&packet) {
            self.client.send_request(packet.into())
        } else {
            Err(ClientError::Request(packet.into()))
        }
    }

    /// Subscribes without blocking; see [`AsyncClient::try_subscribe`].
    pub fn try_subscribe<S: Into<String>>(&self, topic: S, qos: QoS) -> Result<(), ClientError> {
        self.client.try_subscribe(topic, qos)
    }

    /// Subscribes to several filters in one request, blocking while the channel is full.
    pub fn subscribe_many<T>(&self, topics: T) -> Result<(), ClientError>
    where
        T: IntoIterator<Item = SubscribeFilter>,
    {
        let packet = Subscribe::new_many(topics);
        if subscribe_has_valid_filters(&packet) {
            self.client.send_request(packet.into())
        } else {
            Err(ClientError::Request(packet.into()))
        }
    }

    /// Non-blocking variant of [`Self::subscribe_many`].
    pub fn try_subscribe_many<T>(&self, topics: T) -> Result<(), ClientError>
    where
        T: IntoIterator<Item = SubscribeFilter>,
    {
        self.client.try_subscribe_many(topics)
    }

    /// Unsubscribes from `topic`, blocking while the channel is full.
    pub fn unsubscribe<S: Into<String>>(&self, topic: S) -> Result<(), ClientError> {
        let packet = Unsubscribe::new(topic.into());
        self.client.send_request(Request::Unsubscribe(packet))
    }

    /// Non-blocking variant of [`Self::unsubscribe`].
    pub fn try_unsubscribe<S: Into<String>>(&self, topic: S) -> Result<(), ClientError> {
        self.client.try_unsubscribe(topic)
    }

    /// Queues a `Disconnect` request, blocking while the channel is full.
    pub fn disconnect(&self) -> Result<(), ClientError> {
        self.client.send_request(Request::Disconnect(Disconnect))
    }

    /// Queues a `DisconnectWithTimeout` request carrying `timeout`.
    pub fn disconnect_with_timeout(&self, timeout: Duration) -> Result<(), ClientError> {
        self.client
            .send_request(Request::DisconnectWithTimeout(Disconnect, timeout))
    }

    /// Sends a `DisconnectNow` request on the immediate-disconnect channel.
    pub fn disconnect_now(&self) -> Result<(), ClientError> {
        self.client
            .send_immediate_disconnect(Request::DisconnectNow(Disconnect))
    }

    /// Non-blocking variant of [`Self::disconnect`].
    pub fn try_disconnect(&self) -> Result<(), ClientError> {
        self.client.try_disconnect()
    }

    /// Non-blocking variant of [`Self::disconnect_with_timeout`].
    pub fn try_disconnect_with_timeout(&self, timeout: Duration) -> Result<(), ClientError> {
        self.client.try_disconnect_with_timeout(timeout)
    }

    /// Non-blocking variant of [`Self::disconnect_now`].
    pub fn try_disconnect_now(&self) -> Result<(), ClientError> {
        self.client.try_disconnect_now()
    }
}
/// `true` when `topic` is non-empty and accepted by `valid_topic`.
#[must_use]
fn valid_publish_topic(topic: &str) -> bool {
    if topic.is_empty() {
        return false;
    }
    valid_topic(topic)
}

/// `true` when `subscribe` carries at least one filter and every filter path passes
/// `valid_filter`.
#[must_use]
fn subscribe_has_valid_filters(subscribe: &Subscribe) -> bool {
    let filters = &subscribe.filters;
    if filters.is_empty() {
        return false;
    }
    filters.iter().all(|entry| valid_filter(&entry.path))
}
/// Shared message for the "event loop finished" condition reported by all receive errors.
const EVENT_LOOP_DONE_MSG: &str = "event loop has finished processing requests";

/// Returned by `Connection::recv` when the event loop reports that all requests are done.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct RecvError;

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(EVENT_LOOP_DONE_MSG)
    }
}

impl std::error::Error for RecvError {}

/// Returned by `Connection::try_recv`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TryRecvError {
    /// The event loop reported that all requests are done.
    Disconnected,
    /// No event was ready without waiting.
    Empty,
}

impl std::fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disconnected => f.write_str(EVENT_LOOP_DONE_MSG),
            Self::Empty => f.write_str("no event is ready yet"),
        }
    }
}

impl std::error::Error for TryRecvError {}

/// Returned by `Connection::recv_timeout`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum RecvTimeoutError {
    /// The event loop reported that all requests are done.
    Disconnected,
    /// No event arrived before the timeout elapsed.
    Timeout,
}

impl std::fmt::Display for RecvTimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disconnected => f.write_str(EVENT_LOOP_DONE_MSG),
            Self::Timeout => f.write_str("timed out waiting for an event"),
        }
    }
}

impl std::error::Error for RecvTimeoutError {}
/// Blocking driver for an [`EventLoop`], backed by its own Tokio runtime.
pub struct Connection {
    pub eventloop: EventLoop,
    runtime: Runtime,
}

impl Connection {
    const fn new(eventloop: EventLoop, runtime: Runtime) -> Self {
        Self { eventloop, runtime }
    }

    /// Returns an iterator that yields events until the event loop is done with requests.
    #[must_use = "Connection should be iterated over a loop to make progress"]
    pub const fn iter(&mut self) -> Iter<'_> {
        Iter { connection: self }
    }

    /// Blocks until the next event.
    ///
    /// Returns [`RecvError`] once the event loop reports `RequestsDone`.
    pub fn recv(&mut self) -> Result<Result<Event, ConnectionError>, RecvError> {
        let polled = self.runtime.block_on(self.eventloop.poll());
        resolve_event(polled).ok_or(RecvError)
    }

    /// Polls the event loop once inside the runtime context without waiting.
    pub fn try_recv(&mut self) -> Result<Result<Event, ConnectionError>, TryRecvError> {
        let pending = self.eventloop.poll();
        let _runtime_ctx = self.runtime.enter();
        match pending.now_or_never() {
            Some(event) => resolve_event(event).ok_or(TryRecvError::Disconnected),
            None => Err(TryRecvError::Empty),
        }
    }

    /// Waits up to `duration` for the next event.
    pub fn recv_timeout(
        &mut self,
        duration: Duration,
    ) -> Result<Result<Event, ConnectionError>, RecvTimeoutError> {
        let pending = self.eventloop.poll();
        // `timeout` is created inside the async block so it is constructed within the runtime.
        let outcome = self
            .runtime
            .block_on(async move { timeout(duration, pending).await });
        let event = outcome.map_err(|_elapsed| RecvTimeoutError::Timeout)?;
        resolve_event(event).ok_or(RecvTimeoutError::Disconnected)
    }
}
/// Converts `ConnectionError::RequestsDone` into `None` (end of stream).
///
/// Every other event or error is passed through in `Some`.
fn resolve_event(event: Result<Event, ConnectionError>) -> Option<Result<Event, ConnectionError>> {
    if matches!(event, Err(ConnectionError::RequestsDone)) {
        trace!("Done with requests");
        return None;
    }
    Some(event)
}
/// Blocking iterator over a [`Connection`]'s events.
///
/// Ends when `recv` reports that the event loop is done.
pub struct Iter<'a> {
    connection: &'a mut Connection,
}

impl Iterator for Iter<'_> {
    type Item = Result<Event, ConnectionError>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.connection.recv() {
            Ok(item) => Some(item),
            Err(RecvError) => None,
        }
    }
}
#[cfg(test)]
mod test {
use super::*;
use crate::LastWill;
#[test]
fn calling_iter_twice_on_connection_shouldnt_panic() {
let mut mqttoptions = MqttOptions::new("test-1", "localhost");
let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false);
mqttoptions.set_keep_alive(5).set_last_will(will);
let (_, mut connection) = Client::builder(mqttoptions).capacity(10).build();
let _ = connection.iter();
let _ = connection.iter();
}
#[test]
fn builder_uses_options_request_channel_capacity_by_default() {
let mut mqttoptions = MqttOptions::new("test-1", "localhost");
mqttoptions.set_request_channel_capacity(1);
let builder: AsyncClientBuilder = AsyncClient::builder(mqttoptions);
let (client, _eventloop) = builder.build();
client
.try_publish("hello/world", QoS::AtMostOnce, false, "one")
.expect("first request should fit configured capacity");
assert!(matches!(
client.try_publish("hello/world", QoS::AtMostOnce, false, "two"),
Err(ClientError::TryRequest(Request::Publish(_)))
));
}
#[test]
fn sync_and_async_entry_points_return_distinct_builder_types() {
let sync_builder = Client::builder(MqttOptions::new("test-sync", "localhost"));
let async_builder = AsyncClient::builder(MqttOptions::new("test-async", "localhost"));
let _: ClientBuilder = sync_builder;
let _: AsyncClientBuilder = async_builder;
}
#[test]
fn builder_capacity_overrides_options_request_channel_capacity() {
let mut mqttoptions = MqttOptions::new("test-1", "localhost");
mqttoptions.set_request_channel_capacity(1);
let (client, _eventloop) = Client::builder(mqttoptions).capacity(2).build();
client
.try_publish("hello/world", QoS::AtMostOnce, false, "one")
.expect("first request should fit overridden capacity");
client
.try_publish("hello/world", QoS::AtMostOnce, false, "two")
.expect("second request should fit overridden capacity");
assert!(matches!(
client.try_publish("hello/world", QoS::AtMostOnce, false, "three"),
Err(ClientError::TryRequest(Request::Publish(_)))
));
}
#[test]
fn builder_capacity_zero_is_bounded_rendezvous() {
let mqttoptions = MqttOptions::new("test-1", "localhost");
let (client, _eventloop) = AsyncClient::builder(mqttoptions).capacity(0).build();
assert!(matches!(
client.try_publish("hello/world", QoS::AtMostOnce, false, "one"),
Err(ClientError::TryRequest(Request::Publish(_)))
));
}
#[test]
fn unbounded_builder_allows_try_publish_without_polling() {
let mqttoptions = MqttOptions::new("test-1", "localhost");
let (client, _eventloop) = AsyncClient::builder(mqttoptions).unbounded().build();
for i in 0..128 {
client
.try_publish("hello/world", QoS::AtMostOnce, false, vec![i])
.expect("unbounded channel should accept requests without polling");
}
}
#[tokio::test]
async fn bounded_publish_blocks_when_channel_is_full_without_polling() {
let mqttoptions = MqttOptions::new("test-1", "localhost");
let (client, _eventloop) = AsyncClient::builder(mqttoptions).capacity(1).build();
client
.publish("hello/world", QoS::AtMostOnce, false, "one")
.await
.expect("first request should fit bounded channel");
let result = tokio::time::timeout(
std::time::Duration::from_millis(25),
client.publish("hello/world", QoS::AtMostOnce, false, "two"),
)
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn unbounded_publish_completes_without_polling() {
let mqttoptions = MqttOptions::new("test-1", "localhost");
let (client, _eventloop) = AsyncClient::builder(mqttoptions).unbounded().build();
for i in 0..128 {
client
.publish("hello/world", QoS::AtMostOnce, false, vec![i])
.await
.expect("unbounded channel should accept requests without polling");
}
}
#[test]
fn should_be_able_to_build_test_client_from_channel() {
let (tx, rx) = flume::bounded(1);
let client = Client::from_sender(tx);
client
.publish("hello/world", QoS::ExactlyOnce, false, "good bye")
.expect("Should be able to publish");
let _ = rx.try_recv().expect("Should have message");
}
#[test]
fn prepare_ack_maps_qos_to_manual_ack_packets_v4() {
let (tx, _) = flume::bounded(1);
let client = Client::from_sender(tx);
let qos0 = Publish::new("hello/world", QoS::AtMostOnce, vec![1]);
assert!(client.prepare_ack(&qos0).is_none());
let mut qos1 = Publish::new("hello/world", QoS::AtLeastOnce, vec![1]);
qos1.pkid = 7;
match client.prepare_ack(&qos1) {
Some(ManualAck::PubAck(ack)) => assert_eq!(ack.pkid, 7),
ack => panic!("expected QoS1 PubAck, got {ack:?}"),
}
let mut qos2 = Publish::new("hello/world", QoS::ExactlyOnce, vec![1]);
qos2.pkid = 9;
match client.prepare_ack(&qos2) {
Some(ManualAck::PubRec(ack)) => assert_eq!(ack.pkid, 9),
ack => panic!("expected QoS2 PubRec, got {ack:?}"),
}
}
#[test]
fn manual_ack_sends_puback_request_v4() {
let (tx, rx) = flume::bounded(1);
let client = Client::from_sender(tx);
client
.manual_ack(ManualAck::PubAck(PubAck::new(42)))
.expect("manual_ack should send request");
let request = rx.try_recv().expect("Should have ack request");
match request {
Request::PubAck(ack) => assert_eq!(ack.pkid, 42),
request => panic!("Expected PubAck request, got {request:?}"),
}
}
#[test]
fn try_manual_ack_sends_pubrec_request_v4() {
    // `try_manual_ack` must forward the given PubRec as a `Request::PubRec` with the same pkid.
    let (tx, rx) = flume::bounded(1);
    let client = Client::from_sender(tx);
    let pubrec = ManualAck::PubRec(PubRec::new(51));
    client
        .try_manual_ack(pubrec)
        .expect("try_manual_ack should send request");
    match rx.try_recv().expect("Should have ack request") {
        Request::PubRec(ack) => assert_eq!(ack.pkid, 51),
        request => panic!("Expected PubRec request, got {request:?}"),
    }
}
#[test]
fn ack_and_try_ack_use_manual_ack_flow_v4() {
    // Room for exactly the two acknowledgements produced below.
    let (tx, rx) = flume::bounded(2);
    let client = Client::from_sender(tx);

    let mut at_least_once = Publish::new("hello/world", QoS::AtLeastOnce, vec![1]);
    at_least_once.pkid = 11;
    let mut exactly_once = Publish::new("hello/world", QoS::ExactlyOnce, vec![1]);
    exactly_once.pkid = 13;

    client.ack(&at_least_once).expect("ack should send PubAck");
    client
        .try_ack(&exactly_once)
        .expect("try_ack should send PubRec request");

    // The acknowledgements must come out in the order they were issued.
    match rx.try_recv().expect("Should receive first ack request") {
        Request::PubAck(ack) => assert_eq!(ack.pkid, 11),
        request => panic!("Expected PubAck request, got {request:?}"),
    }
    match rx.try_recv().expect("Should receive second ack request") {
        Request::PubRec(ack) => assert_eq!(ack.pkid, 13),
        request => panic!("Expected PubRec request, got {request:?}"),
    }
}
#[test]
fn can_publish_with_validated_topic() {
    // A pre-validated topic is accepted by the blocking publish path and enqueued.
    let (sender, receiver) = flume::bounded(1);
    let client = Client::from_sender(sender);
    let topic = ValidatedTopic::new("hello/world").unwrap();
    let outcome = client.publish(topic, QoS::ExactlyOnce, false, "good bye");
    outcome.expect("Should be able to publish");
    let _ = receiver.try_recv().expect("Should have message");
}
#[test]
fn publish_accepts_borrowed_string_topic() {
    // Both `publish` and `try_publish` accept a `&String` topic without an explicit conversion.
    let (tx, rx) = flume::bounded(2);
    let client = Client::from_sender(tx);
    let owned_topic = String::from("hello/world");
    let topic_ref = &owned_topic;

    client
        .publish(topic_ref, QoS::ExactlyOnce, false, "good bye")
        .expect("Should be able to publish");
    client
        .try_publish(topic_ref, QoS::ExactlyOnce, false, "good bye")
        .expect("Should be able to publish");

    for _ in 0..2 {
        let _ = rx.try_recv().expect("Should have message");
    }
}
#[test]
fn publish_accepts_cow_topic_variants() {
    // Borrowed and owned `Cow<str>` topics are both accepted, via `publish` and `try_publish`.
    let (tx, rx) = flume::bounded(2);
    let client = Client::from_sender(tx);
    let borrowed: std::borrow::Cow<'_, str> = std::borrow::Cow::Borrowed("hello/world");
    let owned: std::borrow::Cow<'_, str> = std::borrow::Cow::Owned("hello/world".to_owned());

    client
        .publish(borrowed, QoS::ExactlyOnce, false, "good bye")
        .expect("Should be able to publish");
    client
        .try_publish(owned, QoS::ExactlyOnce, false, "good bye")
        .expect("Should be able to publish");

    for _ in 0..2 {
        let _ = rx.try_recv().expect("Should have message");
    }
}
#[test]
fn publishing_invalid_cow_topic_fails() {
    // The receiver is deliberately kept alive. If it were dropped, a send on the
    // now-disconnected channel would fail too (flume hands the request back in its
    // `SendError`), possibly surfacing as the same `ClientError::Request`. The test
    // would then pass even with topic validation removed. With a live receiver and
    // spare capacity, an unvalidated publish would succeed and `expect_err` would
    // catch it.
    let (tx, rx) = flume::bounded(1);
    let client = Client::from_sender(tx);
    let err = client
        .publish(
            std::borrow::Cow::Borrowed("a/+/b"),
            QoS::ExactlyOnce,
            false,
            "good bye",
        )
        .expect_err("Invalid publish topic should fail");
    assert!(matches!(err, ClientError::Request(Request::Publish(_))));
    // A rejected publish must never reach the request channel.
    assert!(rx.is_empty(), "invalid publish must not be enqueued");
}
#[test]
fn validated_topic_ergonomics() {
    // `ValidatedTopic` is `Clone`, and a clone compares equal to its source.
    let original = ValidatedTopic::new("hello/world").unwrap();
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
#[test]
fn creating_invalid_validated_topic_fails() {
    // Wildcard and empty topics are rejected, and the error hands back the rejected input.
    for raw in ["a/+/b", ""] {
        assert_eq!(
            ValidatedTopic::new(raw),
            Err(InvalidTopic(raw.to_string()))
        );
    }
}
#[test]
fn publishing_invalid_raw_topic_fails() {
    // The receiver is deliberately kept alive. If it were dropped, a send on the
    // now-disconnected channel would fail too (flume hands the request back in its
    // `SendError`), possibly surfacing as the same `ClientError::Request`. The test
    // would then pass even with topic validation removed. With a live receiver and
    // spare capacity, an unvalidated publish would succeed and `expect_err` would
    // catch it before a second send could block.
    let (tx, rx) = flume::bounded(1);
    let client = Client::from_sender(tx);
    let cases = [
        ("a/+/b", "Invalid publish topic should fail"),
        ("", "Empty publish topic should fail"),
    ];
    for (topic, reason) in cases {
        let err = client
            .publish(topic, QoS::ExactlyOnce, false, "good bye")
            .expect_err(reason);
        assert!(matches!(err, ClientError::Request(Request::Publish(_))));
    }
    // Rejected publishes must never reach the request channel.
    assert!(rx.is_empty(), "invalid publishes must not be enqueued");
}
#[test]
fn async_publish_paths_accept_validated_topic() {
    // Both async publish entry points (`publish` and `publish_bytes`) accept a `ValidatedTopic`.
    let (tx, rx) = flume::bounded(2);
    let client = AsyncClient::from_senders(tx);
    let topic = || ValidatedTopic::new("hello/world").unwrap();
    let rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(async {
        client
            .publish(topic(), QoS::ExactlyOnce, false, "good bye")
            .await
            .expect("Should be able to publish");
        let payload = Bytes::from_static(b"good bye");
        client
            .publish_bytes(topic(), QoS::ExactlyOnce, false, payload)
            .await
            .expect("Should be able to publish");
    });

    for _ in 0..2 {
        let _ = rx.try_recv().expect("Should have message");
    }
}
#[test]
fn async_publishing_invalid_raw_topic_fails() {
    // The receiver is deliberately kept alive. If it were dropped, a send on the
    // now-disconnected channel would fail too (flume hands the request back in its
    // `SendError`), possibly surfacing as the same `ClientError::Request`. The test
    // would then pass even with topic validation removed. With a live receiver and
    // spare capacity, an unvalidated publish would succeed and `expect_err` would
    // catch it.
    let (tx, rx) = flume::bounded(2);
    let client = AsyncClient::from_senders(tx);
    let runtime = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    runtime.block_on(async {
        let err = client
            .publish("a/+/b", QoS::ExactlyOnce, false, "good bye")
            .await
            .expect_err("Invalid publish topic should fail");
        assert!(matches!(err, ClientError::Request(Request::Publish(_))));
        let err = client
            .publish_bytes(
                "a/+/b",
                QoS::ExactlyOnce,
                false,
                Bytes::from_static(b"good bye"),
            )
            .await
            .expect_err("Invalid publish topic should fail");
        assert!(matches!(err, ClientError::Request(Request::Publish(_))));
        let err = client
            .publish("", QoS::ExactlyOnce, false, "good bye")
            .await
            .expect_err("Empty publish topic should fail");
        assert!(matches!(err, ClientError::Request(Request::Publish(_))));
        let err = client
            .publish_bytes("", QoS::ExactlyOnce, false, Bytes::from_static(b"good bye"))
            .await
            .expect_err("Empty publish topic should fail");
        assert!(matches!(err, ClientError::Request(Request::Publish(_))));
    });
    // Rejected publishes must never reach the request channel.
    assert!(rx.is_empty(), "invalid publishes must not be enqueued");
}
#[test]
fn tracked_publish_requires_tracking_channel() {
    // A client built from a single plain sender has no tracked channel, so every
    // `*_tracked` entry point (async and `try_`) must refuse with `TrackingUnavailable`.
    let (tx, _) = flume::bounded(2);
    let client = AsyncClient::from_senders(tx);
    let is_unavailable = |err: ClientError| matches!(err, ClientError::TrackingUnavailable);
    let one_filter = || {
        vec![SubscribeFilter::new(
            "hello/world".to_string(),
            QoS::AtLeastOnce,
        )]
    };
    let rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(async {
        assert!(is_unavailable(
            client
                .publish_tracked("hello/world", QoS::AtLeastOnce, false, "good bye")
                .await
                .expect_err("tracked publish should fail without tracked channel"),
        ));
        assert!(is_unavailable(
            client
                .publish_bytes_tracked(
                    "hello/world",
                    QoS::AtLeastOnce,
                    false,
                    Bytes::from_static(b"good bye"),
                )
                .await
                .expect_err("tracked publish bytes should fail without tracked channel"),
        ));
        assert!(is_unavailable(
            client
                .subscribe_tracked("hello/world", QoS::AtLeastOnce)
                .await
                .expect_err("tracked subscribe should fail without tracked channel"),
        ));
        assert!(is_unavailable(
            client
                .subscribe_many_tracked(one_filter())
                .await
                .expect_err("tracked subscribe many should fail without tracked channel"),
        ));
        assert!(is_unavailable(
            client
                .unsubscribe_tracked("hello/world")
                .await
                .expect_err("tracked unsubscribe should fail without tracked channel"),
        ));
    });

    assert!(is_unavailable(
        client
            .try_subscribe_tracked("hello/world", QoS::AtLeastOnce)
            .expect_err("tracked try_subscribe should fail without tracked channel"),
    ));
    assert!(is_unavailable(
        client
            .try_subscribe_many_tracked(one_filter())
            .expect_err("tracked try_subscribe_many should fail without tracked channel"),
    ));
    assert!(is_unavailable(
        client
            .try_unsubscribe_tracked("hello/world")
            .expect_err("tracked try_unsubscribe should fail without tracked channel"),
    ));
}
#[test]
fn tracked_unsubscribe_uses_control_request_channel() {
    // A tracked unsubscribe must skip the ordinary request queue and land on the
    // control-request queue instead.
    let (requests, requests_rx) = flume::bounded(1);
    let (control_requests, control_requests_rx) = flume::bounded(1);
    // The disconnect receiver is bound (not `_`) so it stays alive for the whole test.
    let (immediate_disconnect, _immediate_disconnect_rx) = flume::unbounded();
    let request_tx = RequestSender::WithNotice {
        requests,
        control_requests,
        immediate_disconnect,
    };
    let client = AsyncClient { request_tx };
    let rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    rt.block_on(client.unsubscribe_tracked("hello/world"))
        .expect("tracked unsubscribe should enqueue");

    assert!(requests_rx.is_empty());
    let envelope = control_requests_rx
        .try_recv()
        .expect("tracked unsubscribe should use control channel");
    let request = envelope.into_parts().0;
    assert!(matches!(request, Request::Unsubscribe(_)));
}
}