// rust-mqtt 0.5.1
//
// MQTT client for embedded and non-embedded environments.
// Documentation
use std::{net::SocketAddr, sync::Mutex};

use embedded_io_adapters::tokio_1::FromTokio;
use log::{info, warn};
use rust_mqtt::{
    Bytes,
    buffer::AllocBuffer,
    client::{
        Client, MqttError,
        event::{Event, Publish, Suback},
        options::{ConnectOptions, DisconnectOptions, PublicationOptions, SubscriptionOptions},
    },
    types::{
        IdentifiedQoS, MqttBinary, MqttString, PacketIdentifier, QoS, ReasonCode, TopicFilter,
        TopicName,
    },
};
use tokio::net::TcpStream;

use crate::common::{
    FailingClient, TestClient,
    failing::{ByteLimit, FailingTcp},
    fmt::warn_inspect,
};

/// Counter read and incremented by `unique_number()`; starts at 0 and is guarded by a mutex.
static UNIQUE_TOPIC_COUNTER: Mutex<u64> = Mutex::new(0);
/// Shared `AllocBuffer` wrapper; `ALLOC.get()` is what the helpers in this file pass to
/// `Client::new`.
pub static ALLOC: StaticAlloc = StaticAlloc(AllocBuffer);

/// Returns the current value of `UNIQUE_TOPIC_COUNTER` and advances the counter by one.
///
/// # Panics
///
/// Panics if the counter's mutex is poisoned.
fn unique_number() -> u64 {
    let mut guard = UNIQUE_TOPIC_COUNTER.lock().unwrap();
    let next = *guard + 1;

    // Store the successor and hand back the value that was there before.
    std::mem::replace(&mut *guard, next)
}
/// Builds a fresh topic of the form `rust/mqtt/is-amazing/<n>`, where `<n>` comes from
/// `unique_number()`, and returns it both as a topic name and as the filter converted from a
/// clone of that name.
///
/// # Panics
///
/// Panics if any of the conversions (`MqttBinary::from_bytes`, `try_into`, `TopicName::new`)
/// reports an error.
pub fn unique_topic() -> (TopicName<'static>, TopicFilter<'static>) {
    let raw = format!("rust/mqtt/is-amazing/{}", unique_number()).into_bytes();
    let binary = MqttBinary::from_bytes(Bytes::Owned(raw.into_boxed_slice())).unwrap();

    let name = TopicName::new(binary.try_into().unwrap()).unwrap();
    let filter = name.clone().into();

    (name, filter)
}

/// Newtype around an [`AllocBuffer`] that can hand out `'static` mutable references to it.
pub struct StaticAlloc(AllocBuffer);
impl StaticAlloc {
    /// Returns a `'static` mutable reference to the wrapped [`AllocBuffer`].
    ///
    /// Every call points at the same inner value, so callers receive aliasing references.
    pub fn get(&self) -> &'static mut AllocBuffer {
        let inner: *const AllocBuffer = &raw const self.0;

        // SAFETY: NOTE(review) - this turns a pointer obtained through `&self` into a
        // `&'static mut`. That presumably relies on `AllocBuffer` being a zero-sized,
        // stateless type and on `self` living in a `static` - confirm before reusing this
        // with any other buffer type.
        unsafe { &mut *inner.cast_mut() }
    }
}

/// Opens a plain tokio `TcpStream` to `broker`.
///
/// The connect result is passed through `warn_inspect!` with the message
/// "Error while connecting TCP session" (presumably logging failures - see the macro).
///
/// # Errors
///
/// Any connect failure is reported as `MqttError::RecoveryRequired`; the underlying error is
/// discarded.
pub async fn tcp_connection_raw(broker: SocketAddr) -> Result<TcpStream, MqttError<'static>> {
    let attempt = warn_inspect!(
        TcpStream::connect(broker).await,
        "Error while connecting TCP session"
    );

    match attempt {
        Ok(stream) => Ok(stream),
        Err(_) => Err(MqttError::RecoveryRequired),
    }
}

/// Connects to `broker` via `tcp_connection_raw` and wraps the stream in `FromTokio`.
///
/// # Errors
///
/// Propagates the error returned by `tcp_connection_raw` unchanged.
pub async fn tcp_connection(
    broker: SocketAddr,
) -> Result<FromTokio<TcpStream>, MqttError<'static>> {
    let stream = tcp_connection_raw(broker).await?;

    Ok(FromTokio::new(stream))
}

/// Creates a client on top of `ALLOC`, opens a TCP connection to `broker` and performs
/// `Client::connect` with the given `options` and optional `client_identifier`.
///
/// # Errors
///
/// Returns the error from `tcp_connection` or from `Client::connect`; the latter is first
/// passed through `warn_inspect!` with the message "Client::connect() failed".
pub async fn connected_client(
    broker: SocketAddr,
    options: &ConnectOptions<'static>,
    client_identifier: Option<MqttString<'_>>,
) -> Result<TestClient<'static>, MqttError<'static>> {
    let mut client = Client::new(ALLOC.get());
    let transport = tcp_connection(broker).await?;

    // The value produced by a successful connect is not needed; only the client is returned.
    warn_inspect!(
        client.connect(transport, options, client_identifier).await,
        "Client::connect() failed"
    )?;

    Ok(client)
}

/// Like `connected_client`, but the TCP stream is wrapped in a `FailingTcp` configured with
/// `read_limit` and `write_limit` before being handed to the client.
///
/// # Errors
///
/// Returns the error from `tcp_connection_raw` or from `Client::connect`; the latter is first
/// passed through `warn_inspect!` with the message "Client::connect() failed".
pub async fn connected_failing_client(
    broker: SocketAddr,
    options: &ConnectOptions<'static>,
    client_identifier: Option<MqttString<'_>>,
    read_limit: ByteLimit,
    write_limit: ByteLimit,
) -> Result<FailingClient<'static>, MqttError<'static>> {
    let mut client = Client::new(ALLOC.get());

    // Raw stream -> byte-limited stream -> embedded-io adapter.
    let raw = tcp_connection_raw(broker).await?;
    let transport = FromTokio::new(FailingTcp::new(raw, read_limit, write_limit));

    warn_inspect!(
        client.connect(transport, options, client_identifier).await,
        "Client::connect() failed"
    )?;

    Ok(client)
}

/// Calls `Client::disconnect` with `options` on a best-effort basis.
///
/// The outcome is passed through `warn_inspect!` with the message
/// "Client::disconnect() failed" and then deliberately discarded, so this helper never fails.
pub async fn disconnect(client: &mut TestClient<'_>, options: &DisconnectOptions) {
    let outcome = warn_inspect!(
        client.disconnect(options).await,
        "Client::disconnect() failed"
    );

    // Intentionally ignored: callers have nothing to do with a failed disconnect.
    let _ = outcome;
}

/// Subscribes to `topic` with `options` and polls the client until the SUBACK carrying the
/// packet identifier of that subscription arrives.
///
/// Events that are not the awaited SUBACK (other event kinds, or a SUBACK for a different
/// packet identifier) are logged with `warn!` and skipped.
///
/// Returns the QoS derived from the SUBACK reason code: `Success` maps to
/// `QoS::AtMostOnce`, `GrantedQoS1` to `QoS::AtLeastOnce` and `GrantedQoS2` to
/// `QoS::ExactlyOnce`.
///
/// # Errors
///
/// Propagates any error from `Client::subscribe` or `Client::poll`.
///
/// # Panics
///
/// Hits `unreachable!()` if the awaited SUBACK carries any other reason code.
pub async fn subscribe<'c>(
    client: &mut TestClient<'c>,
    options: SubscriptionOptions,
    topic: TopicFilter<'_>,
) -> Result<QoS, MqttError<'c>> {
    let pid = warn_inspect!(
        client.subscribe(topic, options).await,
        "Client::subscribe() failed"
    )?;

    loop {
        let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;

        match event {
            Event::Suback(Suback {
                packet_identifier,
                reason_code,
            }) => {
                if packet_identifier == pid {
                    info!("Subscribed with reason code {:?}", reason_code);
                    return match reason_code {
                        ReasonCode::Success => Ok(QoS::AtMostOnce),
                        ReasonCode::GrantedQoS1 => Ok(QoS::AtLeastOnce),
                        ReasonCode::GrantedQoS2 => Ok(QoS::ExactlyOnce),
                        _ => unreachable!(),
                    };
                }

                // A SUBACK, but for some other subscription: keep waiting.
                warn!(
                    "Expected SUBACK for packet identifier {:?}, but received SUBACK for packet identifier {:?}",
                    pid, packet_identifier
                );
            }
            other => warn!("Expected Event::Suback, but received {:?}", other),
        }
    }
}

/// Unsubscribes from `topic` and polls the client until the UNSUBACK carrying the packet
/// identifier of that request arrives.
///
/// Events that are not the awaited UNSUBACK (other event kinds, or an UNSUBACK for a
/// different packet identifier) are logged with `warn!` and skipped. The reason code of the
/// awaited UNSUBACK is only logged; it does not influence the result.
///
/// # Errors
///
/// Propagates any error from `Client::unsubscribe` or `Client::poll`.
pub async fn unsubscribe<'c>(
    client: &mut TestClient<'c>,
    topic: TopicFilter<'_>,
) -> Result<(), MqttError<'c>> {
    let pid = warn_inspect!(
        client.unsubscribe(topic).await,
        "Client::unsubscribe() failed"
    )?;

    loop {
        let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;

        match event {
            Event::Unsuback(Suback {
                packet_identifier,
                reason_code,
            }) => {
                if packet_identifier == pid {
                    info!("Unsubscribed with reason code {:?}", reason_code);
                    return Ok(());
                }

                // An UNSUBACK, but for some other request: keep waiting.
                warn!(
                    "Expected UNSUBACK for packet identifier {:?}, but received UNSUBACK for packet identifier {:?}",
                    pid, packet_identifier
                );
            }
            other => warn!("Expected Event::Unsuback, but received {:?}", other),
        }
    }
}

/// Polls the client until an `Event::Publish` arrives and returns its payload.
///
/// Every other event is logged with `warn!` and skipped.
///
/// # Errors
///
/// Propagates any error from `Client::poll`.
pub async fn receive_publish<'c>(
    client: &mut TestClient<'c>,
) -> Result<Publish<'c, 1>, MqttError<'c>> {
    loop {
        let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;

        match event {
            Event::Publish(publish) => return Ok(publish),
            unexpected => warn!("Expected PUBLISH, but received {:?}", unexpected),
        }
    }
}

/// Receives the next PUBLISH via `receive_publish` and, when it was delivered with
/// `IdentifiedQoS::ExactlyOnce`, keeps polling until the `Event::PublishReleased` with the
/// same packet identifier has been observed.
///
/// For `AtMostOnce` and `AtLeastOnce` publishes nothing further is awaited. While waiting
/// for the release, any other event (including a release for a different packet
/// identifier) is logged with `warn!` and skipped.
///
/// # Errors
///
/// Propagates any error from `receive_publish` or `Client::poll`.
pub async fn receive_and_complete<'c>(
    client: &mut TestClient<'c>,
) -> Result<Publish<'c, 1>, MqttError<'c>> {
    let publish = receive_publish(client).await?;

    match publish.identified_qos {
        IdentifiedQoS::AtMostOnce | IdentifiedQoS::AtLeastOnce(_) => {}
        IdentifiedQoS::ExactlyOnce(pid) => loop {
            match warn_inspect!(client.poll().await, "Client::poll() failed")? {
                Event::PublishReleased(p) if p.packet_identifier == pid => break,
                Event::PublishReleased(p) => warn!(
                    "Expected PUBREL with packet identifier {:?}, but received PUBREL with packet identifier {:?}",
                    pid, p.packet_identifier
                ),
                // This loop waits for PUBREL, so the diagnostic names PUBREL (the previous
                // text said PUBLISH, copied from `receive_publish`).
                e => warn!("Expected PUBREL, but received {:?}", e),
            }
        },
    }

    Ok(publish)
}

/// Publishes `message` with `options` and polls the client until the acknowledgements that
/// belong to the chosen QoS have been observed.
///
/// * `QoS::AtMostOnce`: returns right after publishing.
/// * `QoS::AtLeastOnce`: waits for the `Event::PublishAcknowledged` with the publish's
///   packet identifier.
/// * `QoS::ExactlyOnce`: waits for the matching `Event::PublishReceived`, then for the
///   matching `Event::PublishComplete`.
///
/// Events that are not the one currently awaited are logged with `warn!` and skipped.
/// Returns whatever `Client::publish` returned as the packet identifier.
///
/// # Errors
///
/// Propagates any error from `Client::publish` or `Client::poll`.
///
/// # Panics
///
/// Panics when an acknowledgement of the awaited kind arrives while `Client::publish`
/// returned `None` as the packet identifier.
pub async fn publish_and_complete<'c>(
    client: &mut TestClient<'c>,
    options: &PublicationOptions<'_>,
    message: Bytes<'_>,
) -> Result<Option<PacketIdentifier>, MqttError<'c>> {
    let pid = warn_inspect!(
        client.publish(options, message).await,
        "Client::publish() failed"
    )?;

    match options.qos {
        QoS::AtMostOnce => {}
        QoS::AtLeastOnce => loop {
            let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;
            match event {
                Event::PublishAcknowledged(ack) => {
                    if ack.packet_identifier == pid.unwrap() {
                        break;
                    }
                    warn!(
                        "Expected PUBACK with packet identifier {:?}, but received PUBACK with packet identifier {:?}",
                        pid, ack.packet_identifier
                    );
                }
                other => warn!("Expected PUBACK, but received {:?}", other),
            }
        },
        QoS::ExactlyOnce => {
            // First half of the handshake: PUBREC for our packet identifier.
            loop {
                let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;
                match event {
                    Event::PublishReceived(rec) => {
                        if rec.packet_identifier == pid.unwrap() {
                            break;
                        }
                        warn!(
                            "Expected PUBREC with packet identifier {:?}, but received PUBREC with packet identifier {:?}",
                            pid, rec.packet_identifier
                        );
                    }
                    other => warn!("Expected PUBREC, but received {:?}", other),
                }
            }

            // Second half: PUBCOMP for the same packet identifier.
            loop {
                let event = warn_inspect!(client.poll().await, "Client::poll() failed")?;
                match event {
                    Event::PublishComplete(comp) => {
                        if comp.packet_identifier == pid.unwrap() {
                            break;
                        }
                        warn!(
                            "Expected PUBCOMP with packet identifier {:?}, but received PUBCOMP with packet identifier {:?}",
                            pid, comp.packet_identifier
                        );
                    }
                    other => warn!("Expected PUBCOMP, but received {:?}", other),
                }
            }
        }
    }

    Ok(pid)
}