Skip to main content

sparkplug_b/
transport.rs

//! MQTT transport abstraction (Phase 2+).
//!
//! The Edge/Host engines drive an [`MqttTransport`] so the library can sit on
//! any MQTT client. A `rumqttc`-backed implementation (with TLS/HA) lands in
//! Phase 4; tests use an in-memory transport. The QoS / retain / will rules are
//! enforced by the edge/host layers, not the transport.
7
8use bytes::Bytes;
9
10use crate::error::Result;
11
/// MQTT Quality of Service level.
///
/// A client-agnostic mirror of the three MQTT delivery levels. The `rumqttc`
/// backend in this file maps each variant one-to-one onto that crate's own
/// QoS type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Qos {
    /// At most once (`0`) — Sparkplug data/cmd/birth/death.
    AtMostOnce,
    /// At least once (`1`) — Sparkplug will (NDEATH) and STATE.
    AtLeastOnce,
    /// Exactly once (`2`).
    ExactlyOnce,
}
22
/// TLS configuration for a transport (PEM-encoded material).
///
/// `Debug` is implemented by hand rather than derived so that key material
/// can never leak into logs (including through the derived `Debug` of any type
/// that embeds a `TlsConfig`): the private key is always rendered as
/// `<redacted>`, and the certificates are summarised by their byte length.
#[derive(Clone, Default)]
pub struct TlsConfig {
    /// Trusted CA chain (PEM).
    pub ca_pem: Option<Vec<u8>>,
    /// Client certificate for mTLS (PEM).
    pub client_cert_pem: Option<Vec<u8>>,
    /// Client private key for mTLS (PEM).
    pub client_key_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Public PEM material: show only how many bytes are held.
        struct Len(usize);
        impl std::fmt::Debug for Len {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "<{} bytes>", self.0)
            }
        }

        // Secret PEM material: show only that a value is present.
        struct Redacted;
        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("<redacted>")
            }
        }

        f.debug_struct("TlsConfig")
            .field("ca_pem", &self.ca_pem.as_ref().map(|pem| Len(pem.len())))
            .field(
                "client_cert_pem",
                &self.client_cert_pem.as_ref().map(|pem| Len(pem.len())),
            )
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| Redacted),
            )
            .finish()
    }
}
33
/// A message to publish.
///
/// Used both as the argument to [`MqttTransport::publish`] and as the
/// Last-Will-and-Testament carried in [`ConnectOptions::will`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OutboundMessage {
    /// The MQTT topic.
    pub topic: String,
    /// The QoS level.
    pub qos: Qos,
    /// The retain flag.
    pub retain: bool,
    /// The raw payload bytes.
    pub payload: Bytes,
}
46
/// A message received from the broker.
///
/// Only the topic and payload are surfaced; the delivery QoS and retain flag
/// of the inbound publish are not part of this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IncomingMessage {
    /// The MQTT topic the message arrived on.
    pub topic: String,
    /// The raw payload bytes.
    pub payload: Bytes,
}
55
/// Connection options, including the Last-Will-and-Testament.
///
/// Passed by reference to [`MqttTransport::connect`]; the transport copies out
/// whatever it needs, so the same options can be reused for a later reconnect.
#[derive(Clone, Debug)]
pub struct ConnectOptions {
    /// MQTT client id.
    pub client_id: String,
    /// Broker host. With TLS this is also matched against the server
    /// certificate's SubjectAltName, so it must be covered by the cert (e.g.
    /// `127.0.0.1` needs an IP SAN, a DNS name needs a matching DNS SAN).
    pub host: String,
    /// Broker port.
    pub port: u16,
    /// Keep-alive interval, seconds.
    // NOTE(review): some MQTT clients reject or assert on very small
    // keep-alive values — confirm the accepted range for the backend in use.
    pub keep_alive_secs: u16,
    /// MQTT 3.1.1 Clean Session / MQTT 5.0 Clean Start (Sparkplug requires `true`).
    pub clean_start: bool,
    /// The Last-Will-and-Testament (the Edge Node's NDEATH, QoS 1, retain=false).
    /// `None` registers no will.
    pub will: Option<OutboundMessage>,
    /// Optional TLS configuration. `None` requests no TLS.
    pub tls: Option<TlsConfig>,
}
76
/// The MQTT transport the edge/host engines drive.
///
/// Implementations are used via static dispatch (the engines are generic over
/// `T: MqttTransport`), so the auto-trait-bound caveat of `async fn` in traits
/// does not apply here.
///
/// Every method takes `&mut self`, so operations on a single transport are
/// issued one at a time; an implementation never has to serve two of these
/// calls concurrently.
#[allow(async_fn_in_trait)]
pub trait MqttTransport {
    /// Connect to the broker with the given options (registering the will).
    ///
    /// # Errors
    /// Returns an error if the connection cannot be established.
    async fn connect(&mut self, opts: &ConnectOptions) -> Result<()>;

    /// Subscribe to a topic filter at the given QoS.
    ///
    /// # Errors
    /// Returns an error if the subscription fails.
    async fn subscribe(&mut self, topic_filter: &str, qos: Qos) -> Result<()>;

    /// Publish a message.
    ///
    /// The topic, QoS, retain flag and payload are all taken from `message`.
    ///
    /// # Errors
    /// Returns an error if publishing fails.
    async fn publish(&mut self, message: &OutboundMessage) -> Result<()>;

    /// Disconnect gracefully (the broker must NOT deliver the will).
    ///
    /// # Errors
    /// Returns an error if the disconnect fails.
    async fn disconnect(&mut self) -> Result<()>;

    /// Await the next inbound message, or `None` once the stream is closed.
    ///
    /// # Errors
    /// Returns an error if the transport fails while receiving. An error may be
    /// **transient** (e.g. a reconnecting client): a run-loop should generally
    /// log/back-off and call `recv` again rather than treat it as terminal.
    async fn recv(&mut self) -> Result<Option<IncomingMessage>>;
}
116
117#[cfg(feature = "transport-rumqttc")]
118mod rumqtt_impl {
119    use std::time::{Duration, Instant};
120
121    use rumqttc::v5::mqttbytes::QoS as RumqttQos;
122    use rumqttc::v5::mqttbytes::v5::{ConnectProperties, LastWill};
123    use rumqttc::v5::{AsyncClient, ConnectionError, Event, EventLoop, Incoming, MqttOptions};
124
125    use super::{ConnectOptions, IncomingMessage, MqttTransport, OutboundMessage, Qos};
126    use crate::error::{Result, SparkplugError};
127
    /// Translate the crate's [`Qos`] into the equivalent `rumqttc` QoS level.
    ///
    /// The match is deliberately exhaustive (no catch-all arm) so that adding a
    /// `Qos` variant is a compile error here until it is mapped.
    const fn to_rumqtt_qos(qos: Qos) -> RumqttQos {
        match qos {
            Qos::AtMostOnce => RumqttQos::AtMostOnce,
            Qos::AtLeastOnce => RumqttQos::AtLeastOnce,
            Qos::ExactlyOnce => RumqttQos::ExactlyOnce,
        }
    }
135
136    fn transport_err(e: impl ToString) -> SparkplugError {
137        SparkplugError::Transport(e.to_string())
138    }
139
140    /// Apply a [`super::TlsConfig`] to the MQTT options (server-only TLS when only
141    /// a CA is given; mTLS when a client cert + key are both present).
142    ///
143    /// Server-certificate chain validation against the supplied CA and
144    /// hostname/SAN verification are always on (rumqttc's default rustls
145    /// verifier) — there is no opt-out, and we never load native roots.
146    #[cfg(feature = "tls")]
147    fn apply_tls(options: &mut MqttOptions, tls: Option<&super::TlsConfig>) -> Result<()> {
148        use rumqttc::{TlsConfiguration, Transport};
149
150        let Some(tls) = tls else {
151            return Ok(());
152        };
153        // A server-trust CA is mandatory; we never load native roots (that path
154        // panics on a bad cert).
155        let Some(ca) = tls.ca_pem.clone() else {
156            return Err(SparkplugError::Transport(
157                "TLS requested without a CA certificate (TlsConfig.ca_pem is None)".to_owned(),
158            ));
159        };
160        // mTLS iff BOTH a client cert and key are present; exactly one is a
161        // misconfiguration we reject rather than silently downgrade.
162        let client_auth = match (&tls.client_cert_pem, &tls.client_key_pem) {
163            (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
164            (None, None) => None,
165            _ => {
166                return Err(SparkplugError::Transport(
167                    "mTLS requires BOTH client_cert_pem and client_key_pem".to_owned(),
168                ));
169            }
170        };
171        options.set_transport(Transport::tls_with_config(TlsConfiguration::Simple {
172            ca,
173            alpn: None,
174            client_auth,
175        }));
176        Ok(())
177    }
178
179    /// Without the `tls` feature, a TLS request fails loud rather than silently
180    /// connecting in the clear.
181    #[cfg(not(feature = "tls"))]
182    fn apply_tls(_options: &mut MqttOptions, tls: Option<&super::TlsConfig>) -> Result<()> {
183        if tls.is_some() {
184            return Err(SparkplugError::Transport(
185                "TLS was requested but the `tls` feature is disabled; would connect in plaintext"
186                    .to_owned(),
187            ));
188        }
189        Ok(())
190    }
191
    /// A [`MqttTransport`] backed by the `rumqttc` MQTT v5 async client.
    ///
    /// **The event loop must be pumped.** `subscribe`/`publish` only *enqueue*
    /// requests onto a bounded channel; nothing reaches the broker until
    /// [`MqttTransport::recv`] polls the event loop. A caller MUST therefore drive
    /// `recv` continuously (e.g. the edge/host engines' `recv_and_handle` loop) —
    /// otherwise queued messages are never sent and, once the channel fills,
    /// `publish`/`subscribe` will block. `connect` only polls until the CONNACK.
    ///
    /// TLS/mTLS: with the `tls` feature, [`super::ConnectOptions::tls`] is honored
    /// (server-only TLS when only `ca_pem` is set; mTLS when a client cert + key
    /// are also present). Without the `tls` feature, a TLS request fails loudly
    /// rather than silently connecting in plaintext.
    pub struct RumqttcTransport {
        /// Request handle; `None` until `connect` has succeeded.
        client: Option<AsyncClient>,
        /// Event loop paired with `client`; `None` until `connect` has succeeded.
        eventloop: Option<EventLoop>,
        /// How long `connect` waits for the CONNACK before giving up.
        connect_timeout: Duration,
        /// Capacity handed to `AsyncClient::new` for the bounded request channel.
        channel_capacity: usize,
    }
211
212    impl RumqttcTransport {
213        /// A new, unconnected transport (call [`MqttTransport::connect`]).
214        #[must_use]
215        pub fn new() -> Self {
216            Self {
217                client: None,
218                eventloop: None,
219                connect_timeout: Duration::from_secs(10),
220                channel_capacity: 256,
221            }
222        }
223
224        /// Override how long [`MqttTransport::connect`] waits for the CONNACK.
225        #[must_use]
226        pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
227            self.connect_timeout = timeout;
228            self
229        }
230
231        fn client(&self) -> Result<&AsyncClient> {
232            self.client
233                .as_ref()
234                .ok_or_else(|| SparkplugError::Transport("not connected".to_owned()))
235        }
236    }
237
    impl Default for RumqttcTransport {
        /// Equivalent to [`RumqttcTransport::new`]: an unconnected transport with
        /// the default connect timeout and channel capacity.
        fn default() -> Self {
            Self::new()
        }
    }
243
244    impl MqttTransport for RumqttcTransport {
245        async fn connect(&mut self, opts: &ConnectOptions) -> Result<()> {
246            let mut options =
247                MqttOptions::new(opts.client_id.clone(), opts.host.clone(), opts.port);
248            options.set_keep_alive(Duration::from_secs(u64::from(opts.keep_alive_secs)));
249            options.set_clean_start(opts.clean_start);
250            // MQTT 5.0: Sparkplug requires Clean Start = true AND Session Expiry
251            // Interval = 0 (tck-id-principles-persistence-clean-session-50). Set the
252            // property explicitly so it is present on the wire for strict brokers/TCK.
253            let mut props = ConnectProperties::new();
254            props.session_expiry_interval = Some(0);
255            options.set_connect_properties(props);
256            if let Some(will) = &opts.will {
257                options.set_last_will(LastWill::new(
258                    will.topic.clone(),
259                    will.payload.to_vec(),
260                    to_rumqtt_qos(will.qos),
261                    will.retain,
262                    None,
263                ));
264            }
265            apply_tls(&mut options, opts.tls.as_ref())?;
266
267            let (client, mut eventloop) = AsyncClient::new(options, self.channel_capacity);
268
269            // Drive the event loop until the CONNACK; the broker may still be
270            // binding, so a poll error is retried until the deadline.
271            let deadline = Instant::now() + self.connect_timeout;
272            loop {
273                match tokio::time::timeout(Duration::from_secs(1), eventloop.poll()).await {
274                    Ok(Ok(Event::Incoming(Incoming::ConnAck(_)))) => break,
275                    Ok(Ok(_)) => {}
276                    // A refused CONNACK or non-CONNACK first packet is fatal —
277                    // fail fast with the reason instead of spinning to the deadline.
278                    Ok(Err(
279                        e
280                        @ (ConnectionError::ConnectionRefused(_) | ConnectionError::NotConnAck(_)),
281                    )) => return Err(transport_err(e)),
282                    Ok(Err(_)) => tokio::time::sleep(Duration::from_millis(50)).await,
283                    Err(_elapsed) => {}
284                }
285                if Instant::now() >= deadline {
286                    return Err(SparkplugError::Transport(
287                        "timed out waiting for CONNACK".to_owned(),
288                    ));
289                }
290            }
291
292            self.client = Some(client);
293            self.eventloop = Some(eventloop);
294            Ok(())
295        }
296
297        async fn subscribe(&mut self, topic_filter: &str, qos: Qos) -> Result<()> {
298            self.client()?
299                .subscribe(topic_filter, to_rumqtt_qos(qos))
300                .await
301                .map_err(transport_err)
302        }
303
304        async fn publish(&mut self, message: &OutboundMessage) -> Result<()> {
305            self.client()?
306                .publish(
307                    message.topic.clone(),
308                    to_rumqtt_qos(message.qos),
309                    message.retain,
310                    message.payload.to_vec(),
311                )
312                .await
313                .map_err(transport_err)
314        }
315
316        async fn disconnect(&mut self) -> Result<()> {
317            if let Some(client) = &self.client {
318                client.disconnect().await.map_err(transport_err)?;
319            }
320            Ok(())
321        }
322
323        async fn recv(&mut self) -> Result<Option<IncomingMessage>> {
324            let eventloop = self
325                .eventloop
326                .as_mut()
327                .ok_or_else(|| SparkplugError::Transport("not connected".to_owned()))?;
328            loop {
329                match eventloop.poll().await {
330                    Ok(Event::Incoming(Incoming::Publish(publish))) => {
331                        let topic = String::from_utf8(publish.topic.to_vec())
332                            .map_err(|_| SparkplugError::InvalidUtf8)?;
333                        return Ok(Some(IncomingMessage {
334                            topic,
335                            payload: bytes::Bytes::from(publish.payload.to_vec()),
336                        }));
337                    }
338                    Ok(_) => {}
339                    Err(e) => return Err(transport_err(e)),
340                }
341            }
342        }
343    }
344}
345
346#[cfg(feature = "transport-rumqttc")]
347pub use rumqtt_impl::RumqttcTransport;