use bytes::Bytes;
use crate::error::Result;
/// MQTT delivery guarantee requested for a publish or subscription.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Qos {
    /// QoS 0: fire-and-forget, no acknowledgement.
    AtMostOnce,
    /// QoS 1: acknowledged delivery; duplicates are possible.
    AtLeastOnce,
    /// QoS 2: four-step handshake guaranteeing a single delivery.
    ExactlyOnce,
}
/// PEM-encoded material used to secure the broker connection.
///
/// `Debug` is implemented by hand so the private key never ends up in logs:
/// `client_key_pem` is always rendered as `Some("<redacted>")` / `None`.
#[derive(Clone, Default)]
pub struct TlsConfig {
    /// Trust anchor(s) used to verify the broker certificate.
    pub ca_pem: Option<Vec<u8>>,
    /// Client certificate for mutual TLS; must be paired with `client_key_pem`.
    pub client_cert_pem: Option<Vec<u8>>,
    /// Client private key for mutual TLS. Secret: never printed by `Debug`.
    pub client_key_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("ca_pem", &self.ca_pem)
            .field("client_cert_pem", &self.client_cert_pem)
            // Only reveal whether a key is present, never its bytes.
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// An application message to be sent to the broker (also used as the Will message).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutboundMessage {
    /// Destination topic name.
    pub topic: String,
    /// Delivery guarantee requested for this message.
    pub qos: Qos,
    /// Whether the broker should keep this message as the topic's retained value.
    pub retain: bool,
    /// Raw message body.
    pub payload: Bytes,
}
/// An application message delivered by the broker on a subscribed topic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncomingMessage {
    /// Topic the message was published to, decoded as UTF-8.
    pub topic: String,
    /// Raw message body.
    pub payload: Bytes,
}
/// Everything a transport needs to open a session with the broker.
#[derive(Debug, Clone)]
pub struct ConnectOptions {
    /// MQTT client identifier presented in CONNECT.
    pub client_id: String,
    /// Broker host name or address.
    pub host: String,
    /// Broker TCP port.
    pub port: u16,
    /// Keep-alive interval, in whole seconds.
    pub keep_alive_secs: u16,
    /// Value of the CONNECT "Clean Start" flag.
    pub clean_start: bool,
    /// Optional Will (last-will-and-testament) message registered at connect time.
    pub will: Option<OutboundMessage>,
    /// TLS settings; `None` requests a plaintext connection.
    pub tls: Option<TlsConfig>,
}
/// Abstraction over an async MQTT client connection.
///
/// Every method takes `&mut self`, so a transport is driven by a single caller at a time.
// `async fn` in a public trait triggers `async_fn_in_trait` because callers cannot
// require the returned futures to be `Send`. That limitation is accepted for this API.
#[allow(async_fn_in_trait)]
pub trait MqttTransport {
    /// Opens a session with the broker described by `opts`.
    async fn connect(&mut self, opts: &ConnectOptions) -> Result<()>;

    /// Subscribes to `topic_filter` with the requested delivery guarantee.
    async fn subscribe(&mut self, topic_filter: &str, qos: Qos) -> Result<()>;

    /// Sends `message` to the broker.
    async fn publish(&mut self, message: &OutboundMessage) -> Result<()>;

    /// Ends the session with the broker.
    async fn disconnect(&mut self) -> Result<()>;

    /// Waits for the next inbound application message.
    ///
    /// NOTE(review): the meaning of `Ok(None)` is implementation-defined
    /// (presumably end-of-stream); confirm per implementation.
    async fn recv(&mut self) -> Result<Option<IncomingMessage>>;
}
#[cfg(feature = "transport-rumqttc")]
mod rumqtt_impl {
use std::time::{Duration, Instant};
use rumqttc::v5::mqttbytes::QoS as RumqttQos;
use rumqttc::v5::mqttbytes::v5::{ConnectProperties, LastWill};
use rumqttc::v5::{AsyncClient, ConnectionError, Event, EventLoop, Incoming, MqttOptions};
use super::{ConnectOptions, IncomingMessage, MqttTransport, OutboundMessage, Qos};
use crate::error::{Result, SparkplugError};
/// Maps the crate's [`Qos`] onto the equivalent rumqttc QoS level.
const fn to_rumqtt_qos(qos: Qos) -> RumqttQos {
    match qos {
        Qos::ExactlyOnce => RumqttQos::ExactlyOnce,
        Qos::AtLeastOnce => RumqttQos::AtLeastOnce,
        Qos::AtMostOnce => RumqttQos::AtMostOnce,
    }
}
/// Wraps any displayable rumqttc error as a `SparkplugError::Transport` message.
fn transport_err(e: impl ToString) -> SparkplugError {
    let message = e.to_string();
    SparkplugError::Transport(message)
}
/// Switches `options` to a TLS transport when `tls` is provided.
///
/// # Errors
/// Returns `SparkplugError::Transport` when `tls` is present but:
/// * `ca_pem` is missing (server verification is mandatory), or
/// * exactly one of `client_cert_pem` / `client_key_pem` is set.
#[cfg(feature = "tls")]
fn apply_tls(options: &mut MqttOptions, tls: Option<&super::TlsConfig>) -> Result<()> {
    use rumqttc::{TlsConfiguration, Transport};

    // No TLS section: keep the default (plain TCP) transport.
    let cfg = match tls {
        Some(cfg) => cfg,
        None => return Ok(()),
    };

    let ca = cfg.ca_pem.clone().ok_or_else(|| {
        SparkplugError::Transport(
            "TLS requested without a CA certificate (TlsConfig.ca_pem is None)".to_owned(),
        )
    })?;

    // Mutual TLS needs both halves of the client identity, or neither.
    let client_auth = match (cfg.client_cert_pem.as_ref(), cfg.client_key_pem.as_ref()) {
        (None, None) => None,
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        (Some(_), None) | (None, Some(_)) => {
            return Err(SparkplugError::Transport(
                "mTLS requires BOTH client_cert_pem and client_key_pem".to_owned(),
            ));
        }
    };

    let tls_config = TlsConfiguration::Simple {
        ca,
        alpn: None,
        client_auth,
    };
    options.set_transport(Transport::tls_with_config(tls_config));
    Ok(())
}
/// Fallback used when the crate is built without the `tls` feature.
///
/// # Errors
/// Refuses any TLS request rather than silently connecting in plaintext.
#[cfg(not(feature = "tls"))]
fn apply_tls(_options: &mut MqttOptions, tls: Option<&super::TlsConfig>) -> Result<()> {
    match tls {
        None => Ok(()),
        Some(_) => Err(SparkplugError::Transport(
            "TLS was requested but the `tls` feature is disabled; would connect in plaintext"
                .to_owned(),
        )),
    }
}
/// [`MqttTransport`] implementation backed by rumqttc's MQTT v5 async client.
///
/// Both halves are `None` until `connect` succeeds.
pub struct RumqttcTransport {
    /// Request handle used to publish/subscribe/disconnect.
    client: Option<AsyncClient>,
    /// Event loop that must be polled for packets to move over the network.
    eventloop: Option<EventLoop>,
    /// Overall budget for receiving CONNACK in `connect`.
    connect_timeout: Duration,
    /// Capacity of the request channel passed to `AsyncClient::new`.
    channel_capacity: usize,
}
impl RumqttcTransport {
    /// Creates an unconnected transport with a 10 s connect timeout and a
    /// request-channel capacity of 256.
    #[must_use]
    pub fn new() -> Self {
        Self {
            client: None,
            eventloop: None,
            connect_timeout: Duration::from_secs(10),
            channel_capacity: 256,
        }
    }

    /// Overrides how long `connect` waits for the broker's CONNACK.
    #[must_use]
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Overrides the capacity of the request channel between the client handle and
    /// the event loop. It bounds how many publish/subscribe requests can be queued
    /// before the event loop is polled. Takes effect on the next `connect`.
    ///
    /// rumqttc builds this as a bounded channel. A capacity of 0 would presumably make
    /// each request wait for a concurrent `poll()`. That never happens here, because
    /// requests are sent and the event loop is polled from the same `&mut self` caller.
    /// So 0 is rounded up to 1.
    #[must_use]
    pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
        self.channel_capacity = capacity.max(1);
        self
    }

    /// Returns the connected client handle.
    ///
    /// # Errors
    /// `SparkplugError::Transport("not connected")` before a successful `connect`.
    fn client(&self) -> Result<&AsyncClient> {
        self.client
            .as_ref()
            .ok_or_else(|| SparkplugError::Transport("not connected".to_owned()))
    }
}
impl Default for RumqttcTransport {
fn default() -> Self {
Self::new()
}
}
impl MqttTransport for RumqttcTransport {
async fn connect(&mut self, opts: &ConnectOptions) -> Result<()> {
let mut options =
MqttOptions::new(opts.client_id.clone(), opts.host.clone(), opts.port);
options.set_keep_alive(Duration::from_secs(u64::from(opts.keep_alive_secs)));
options.set_clean_start(opts.clean_start);
let mut props = ConnectProperties::new();
props.session_expiry_interval = Some(0);
options.set_connect_properties(props);
if let Some(will) = &opts.will {
options.set_last_will(LastWill::new(
will.topic.clone(),
will.payload.to_vec(),
to_rumqtt_qos(will.qos),
will.retain,
None,
));
}
apply_tls(&mut options, opts.tls.as_ref())?;
let (client, mut eventloop) = AsyncClient::new(options, self.channel_capacity);
let deadline = Instant::now() + self.connect_timeout;
loop {
match tokio::time::timeout(Duration::from_secs(1), eventloop.poll()).await {
Ok(Ok(Event::Incoming(Incoming::ConnAck(_)))) => break,
Ok(Ok(_)) => {}
Ok(Err(
e
@ (ConnectionError::ConnectionRefused(_) | ConnectionError::NotConnAck(_)),
)) => return Err(transport_err(e)),
Ok(Err(_)) => tokio::time::sleep(Duration::from_millis(50)).await,
Err(_elapsed) => {}
}
if Instant::now() >= deadline {
return Err(SparkplugError::Transport(
"timed out waiting for CONNACK".to_owned(),
));
}
}
self.client = Some(client);
self.eventloop = Some(eventloop);
Ok(())
}
async fn subscribe(&mut self, topic_filter: &str, qos: Qos) -> Result<()> {
self.client()?
.subscribe(topic_filter, to_rumqtt_qos(qos))
.await
.map_err(transport_err)
}
async fn publish(&mut self, message: &OutboundMessage) -> Result<()> {
self.client()?
.publish(
message.topic.clone(),
to_rumqtt_qos(message.qos),
message.retain,
message.payload.to_vec(),
)
.await
.map_err(transport_err)
}
async fn disconnect(&mut self) -> Result<()> {
if let Some(client) = &self.client {
client.disconnect().await.map_err(transport_err)?;
}
Ok(())
}
async fn recv(&mut self) -> Result<Option<IncomingMessage>> {
let eventloop = self
.eventloop
.as_mut()
.ok_or_else(|| SparkplugError::Transport("not connected".to_owned()))?;
loop {
match eventloop.poll().await {
Ok(Event::Incoming(Incoming::Publish(publish))) => {
let topic = String::from_utf8(publish.topic.to_vec())
.map_err(|_| SparkplugError::InvalidUtf8)?;
return Ok(Some(IncomingMessage {
topic,
payload: bytes::Bytes::from(publish.payload.to_vec()),
}));
}
Ok(_) => {}
Err(e) => return Err(transport_err(e)),
}
}
}
}
}
#[cfg(feature = "transport-rumqttc")]
pub use rumqtt_impl::RumqttcTransport;