fundamentum_sdk_mqtt/
client.rs

1//! Fundamentum Client
2//!
3
4use std::{
5    collections::HashMap,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use backon::{ExponentialBuilder, Retryable};
12use displaydoc::Display;
13use futures::Stream;
14#[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
15use rumqttc::TlsConfiguration;
16use rumqttc::{
17    Transport,
18    v5::{
19        AsyncClient, ClientError, ConnectionError, Event, EventLoop, MqttOptions,
20        mqttbytes::{
21            QoS,
22            v5::{ConnectReturnCode, LastWill, PublishProperties},
23        },
24    },
25};
26use tokio::{
27    sync::{
28        Mutex,
29        broadcast::{self, Receiver, Sender},
30        watch,
31    },
32    time::Duration,
33};
34use tokio_stream::wrappers::WatchStream;
35use tracing::{error, info};
36
37use crate::{
38    Device, Error, Message, PublishOptions, Publishable, Publisher, Security, error,
39    publish_options::PublishOptionsResolved,
40};
41
/// Maximum MQTT packet size override, in bytes.
///
/// The wrapped value is passed as-is to `rumqttc`'s `set_max_packet_size`.
/// The field is public so that users outside this crate can actually build
/// the override, e.g. `MQTTMaxPacketSize(64 * 1024)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MQTTMaxPacketSize(pub u32);
44
45/// Option Size Override
46#[derive(Default)]
47pub struct MQTTOptionsOverrides {
48    /// Override clean session
49    pub clean_session: Option<bool>,
50    /// Override session expiry interval
51    pub session_expiry_interval: Option<Duration>,
52    /// Override keep alive
53    pub keep_alive: Option<Duration>,
54    /// Override max packet size
55    pub max_packet_size: Option<MQTTMaxPacketSize>,
56    /// Override request channel capacity
57    pub request_channel_capacity: Option<usize>,
58    /// Override pending throttle
59    pub pending_throttle: Option<Duration>,
60    /// Override inflight
61    pub inflight: Option<u16>,
62    /// Override last will
63    pub last_will: Option<LastWill>,
64    /// Override transport
65    pub transport: Option<Transport>,
66}
67
68/// Fundamentum `IoT` Settings
69pub struct ClientSettings {
70    /// Security strategy
71    security: Security,
72    /// Device's representation
73    device: Device,
74    /// Fundamentum endpoint
75    endpoint: String,
76    /// MQTT options overrides
77    mqtt_options_overrides: Option<MQTTOptionsOverrides>,
78}
79
80impl ClientSettings {
81    /// Create a new `FundamentumIoTSettings`
82    ///
83    /// # Params
84    ///
85    /// * `security`: Security generator
86    /// * `device`: Device's definition
87    /// * `iot_endpoint`: Uri endpoint MQTT
88    /// * `mqtt_options_overrides`: MQTT options overrides
89    ///
90    #[must_use]
91    pub const fn new(
92        security: Security,
93        device: Device,
94        endpoint: String,
95        mqtt_options_overrides: Option<MQTTOptionsOverrides>,
96    ) -> Self {
97        Self {
98            security,
99            device,
100            endpoint,
101            mqtt_options_overrides,
102        }
103    }
104
105    /// Converts the `ClientSettings` struct into `MqttOptions`.
106    ///
107    /// This method transforms the current `ClientSettings` instance into a set of
108    /// options that can be used to configure rumqtt client.
109    ///
110    /// # Returns
111    ///
112    /// A struct containing the MQTT options derived from the current `ClientSettings`.
113    ///
114    /// # Errors
115    ///
116    /// Returns an `error::Error` if an error occurs while generating a new signed token.
117    ///
118    pub async fn to_mqtt_options(&self) -> Result<MqttOptions, error::Error> {
119        let mut mqtt_options =
120            MqttOptions::new(self.device.client_id(), self.endpoint.clone(), 8883);
121
122        if let Some(ref overrides) = self.mqtt_options_overrides {
123            if let Some(clean_session) = overrides.clean_session {
124                mqtt_options.set_clean_start(clean_session);
125            }
126            if let Some(session_expiry_interval) = overrides.session_expiry_interval {
127                let mut connect_properties = mqtt_options.connect_properties().unwrap_or_default();
128                connect_properties.session_expiry_interval =
129                    Some(session_expiry_interval.as_secs().try_into()?);
130                mqtt_options.set_connect_properties(connect_properties);
131            }
132            if let Some(transport) = overrides.transport.clone() {
133                mqtt_options.set_transport(transport);
134            } else {
135                let transport = get_default_transport();
136                mqtt_options.set_transport(transport);
137            }
138            if let Some(keep_alive) = overrides.keep_alive {
139                mqtt_options.set_keep_alive(keep_alive);
140            }
141            if let Some(ref packet_size) = overrides.max_packet_size {
142                mqtt_options.set_max_packet_size(Some(packet_size.0));
143            }
144            if let Some(request_channel_capacity) = overrides.request_channel_capacity {
145                mqtt_options.set_request_channel_capacity(request_channel_capacity);
146            }
147            if let Some(pending_throttle) = overrides.pending_throttle {
148                mqtt_options.set_pending_throttle(pending_throttle);
149            }
150            if let Some(inflight) = overrides.inflight {
151                mqtt_options.set_outgoing_inflight_upper_limit(inflight);
152            }
153            if let Some(last_will) = overrides.last_will.clone() {
154                mqtt_options.set_last_will(last_will);
155            }
156        }
157
158        // Generate a new signed token
159        let token = self.security.generate_token().await?;
160        mqtt_options.set_credentials("unused", token);
161
162        Ok(mqtt_options)
163    }
164}
165
166const fn get_default_transport() -> Transport {
167    #[cfg(all(feature = "use-native-tls", not(feature = "use-rustls")))]
168    let transport = Transport::Tls(TlsConfiguration::Native);
169    #[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
170    let transport = Transport::Tls(TlsConfiguration::default());
171    #[cfg(all(feature = "use-native-tls", feature = "use-rustls"))]
172    let transport = Transport::Tls(TlsConfiguration::Native);
173    #[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
174    let transport = Transport::Tcp;
175
176    transport
177}
178
179#[derive(Clone)]
180struct EventLoopManager {
181    event_loop: Arc<Mutex<EventLoop>>,
182    client_status: watch::Sender<ClientStatus>,
183}
184
185impl EventLoopManager {
186    fn new(event_loop: EventLoop, client_status: watch::Sender<ClientStatus>) -> Self {
187        Self {
188            event_loop: Arc::new(Mutex::new(event_loop)),
189            client_status,
190        }
191    }
192
193    async fn poll(&self) -> Result<Event, ConnectionError> {
194        let mut in_error = false;
195
196        let polling_result = (|| async { self.event_loop.lock().await.poll().await })
197            .retry(ExponentialBuilder::default().without_max_times())
198            .when(is_backoff_error)
199            .notify(|err, dur: Duration| {
200                in_error = true;
201                self.set_client_status(ClientStatus::InError(ClientStatusError::MqttConnection(
202                    err.to_string(),
203                )));
204
205                let dur = dur.as_secs_f32();
206                error!("Error while polling MQTT event loop: {err}\n -> Retrying in {dur:.1}s...");
207            })
208            .await;
209
210        match polling_result.as_ref() {
211            Ok(_) => {
212                self.set_client_status(ClientStatus::Connected);
213                if in_error {
214                    info!("MQTT connection restored");
215                }
216            }
217            Err(err) => self.set_client_status(ClientStatus::InError(
218                ClientStatusError::MqttConnection(err.to_string()),
219            )),
220        }
221
222        polling_result
223    }
224
225    fn set_client_status(&self, status: ClientStatus) {
226        self.client_status.send_if_modified(|current_status| {
227            let notify = current_status != &status;
228            *current_status = status;
229            notify
230        });
231    }
232}
233
234const fn is_backoff_error(err: &ConnectionError) -> bool {
235    !matches!(
236        err,
237        &ConnectionError::ConnectionRefused(
238            ConnectReturnCode::ProtocolError
239                | ConnectReturnCode::UnsupportedProtocolVersion
240                | ConnectReturnCode::ClientIdentifierNotValid
241                | ConnectReturnCode::BadUserNamePassword
242                | ConnectReturnCode::NotAuthorized
243                | ConnectReturnCode::Banned
244                | ConnectReturnCode::BadAuthenticationMethod
245                | ConnectReturnCode::UseAnotherServer
246                | ConnectReturnCode::ServerMoved
247        )
248    )
249}
250
251/// Represents the possible errors that can occur with the client connection status.
252#[derive(Display, Debug, Clone, PartialEq, Eq, thiserror::Error)]
253pub enum ClientStatusError {
254    /// An error occurred with the MQTT connection: {0}
255    MqttConnection(String),
256}
257
258/// Represents the status of the client connection.
259#[derive(Clone, PartialEq, Eq)]
260pub enum ClientStatus {
261    /// The client is currently connected.
262    Connected,
263    /// The client is currently disconnected.
264    Disconnected,
265    /// Represents an error state in the connection status.
266    /// Contains a string describing the error.
267    InError(ClientStatusError),
268}
269
270/// A stream for tracking the status of a client.
271///
272/// `ClientStatusStream` provides functionality to observe the current status
273/// of a client and subscribe to status updates via a stream.
274pub struct ClientStatusStream {
275    receiver: watch::Receiver<ClientStatus>,
276    stream: WatchStream<ClientStatus>,
277}
278
279impl ClientStatusStream {
280    #[must_use]
281    fn new(receiver: watch::Receiver<ClientStatus>) -> Self {
282        let stream = WatchStream::new(receiver.clone());
283        Self { receiver, stream }
284    }
285
286    /// Retrieves the current status of the client.
287    ///
288    /// # Returns
289    ///
290    /// The current `ClientStatus` value.
291    #[must_use]
292    pub fn current(&self) -> ClientStatus {
293        self.receiver.borrow().clone()
294    }
295}
296
297impl Stream for ClientStatusStream {
298    type Item = ClientStatus;
299
300    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301        let this = self.get_mut();
302        Pin::new(&mut this.stream).poll_next(cx)
303    }
304}
305
306impl Unpin for ClientStatusStream {}
307
308/// Fundamentum `IoT` Async Client
309#[derive(Clone)]
310pub struct Client {
311    client: AsyncClient,
312    manager: EventLoopManager,
313    device: Device,
314    incoming_event_sender: Sender<Message>,
315    status_rx: watch::Receiver<ClientStatus>,
316}
317
318impl Client {
319    /// Create new `FundamentumIoTAsyncClient`. Input argument should be the `FundamentumIoTSettings`. Returns a tuple where the first element is the
320    /// `FundamentumIoTAsyncClient`, and the second element is a new tuple with the eventloop and incoming
321    /// event sender. This tuple should be sent as an argument to the `async_event_loop_listener`.
322    ///
323    /// # Errors
324    ///
325    /// If the creation failed then the result returns an error `FundamentumIoTError`
326    pub async fn new(settings: ClientSettings) -> Result<Self, error::Error> {
327        let mqtt_options = settings.to_mqtt_options().await?;
328
329        let (client, eventloop) = AsyncClient::new(mqtt_options, 10);
330        let (request_tx, _) = broadcast::channel(50);
331        let (status_tx, status_rx) = watch::channel(ClientStatus::Disconnected);
332        let manager = EventLoopManager::new(eventloop, status_tx);
333
334        Ok(Self {
335            client,
336            manager,
337            device: settings.device.clone(),
338            incoming_event_sender: request_tx,
339            status_rx,
340        })
341    }
342
343    /// Run the client’s background task.
344    ///
345    /// This task continuously polls the underlying even loop which
346    /// ensures messages are exchanged between this client and the broker.
347    ///
348    /// Nothing will be exchanged until this task is run.
349    ///
350    /// # Errors
351    ///
352    /// Returns an `error::Error` if a fatal MQTT connection error occurs that cannot be retried or recovered from.
353    /// The underlying `ConnectionError` is considered fatal if it matches one of the following MQTT connection refusal codes:
354    /// - `ProtocolError`
355    /// - `UnsupportedProtocolVersion`
356    /// - `ClientIdentifierNotValid`
357    /// - `BadUserNamePassword`
358    /// - `NotAuthorized`
359    /// - `Banned`
360    /// - `BadAuthenticationMethod`
361    /// - `UseAnotherServer`
362    /// - `ServerMoved`
363    ///
364    /// In these cases, the client will terminate with an error. For all other connection errors (such as temporary network issues),
365    /// the client will attempt to reconnect automatically (backoff).
366    ///
367    /// Any error returned by this function indicates that the client can no longer function and requires user intervention.
368    pub async fn run(&self) -> Result<(), error::Error> {
369        loop {
370            let event = self.manager.poll().await?;
371            if let Err(err) = self.handle_event(event) {
372                error!("Error while handling MQTT event: {err}");
373            }
374        }
375    }
376
377    fn handle_event(&self, event: Event) -> Result<(), error::Error> {
378        match event {
379            Event::Incoming(packet) => {
380                let message = Message::try_from_packet(packet)?;
381                self.incoming_event_sender.send(message)?;
382                Ok(())
383            }
384            Event::Outgoing(_) => Ok(()), // Silently discarded.
385        }
386    }
387
388    /// Subscribe to a topic
389    ///
390    /// # Errors
391    ///
392    /// Returns `ClientError` if the request failed
393    async fn subscribe<S: Into<String> + Send>(
394        &self,
395        topic: S,
396        qos: QoS,
397    ) -> Result<(), ClientError> {
398        self.client.subscribe(topic, qos).await
399    }
400
401    /// Subscribe to device's config channel
402    ///
403    /// # Reference
404    ///
405    /// * `registries/{REGISTRY_ID}/devices/{DEVICE_SN}/config`
406    ///
407    /// # Errors
408    ///
409    /// Returns `ClientError` if the request failed
410    pub async fn subscribe_config(&self, qos: QoS) -> Result<(), ClientError> {
411        self.subscribe(
412            format!(
413                "registries/{}/devices/{}/config",
414                self.device.registry_id(),
415                self.device.serial()
416            ),
417            qos,
418        )
419        .await
420    }
421
422    /// Subscribe to device's port forward channel
423    ///
424    /// # Reference
425    ///
426    /// * `registries/{REGISTRY_ID}/devices/{DEVICE_SN}/portforward/tx`
427    ///
428    /// # Errors
429    ///
430    /// Returns `ClientError` if the request failed
431    pub async fn subscribe_portforward(&self, qos: QoS) -> Result<(), ClientError> {
432        self.subscribe(
433            format!(
434                "registries/{}/devices/{}/pfwd/tx",
435                self.device.registry_id(),
436                self.device.serial()
437            ),
438            qos,
439        )
440        .await
441    }
442
443    /// Subscribe to device's commands channel
444    ///
445    /// # Reference
446    ///
447    /// * `registries/{REGISTRY_ID}/devices/{DEVICE_SN}/commands`
448    ///
449    /// # Errors
450    ///
451    /// Returns `ClientError` if the request failed
452    pub async fn subscribe_commands(&self, qos: QoS) -> Result<(), ClientError> {
453        self.subscribe(
454            format!(
455                "registries/{}/devices/{}/commands",
456                self.device.registry_id(),
457                self.device.serial()
458            ),
459            qos,
460        )
461        .await
462    }
463
464    /// Subscribe to device's actions channel
465    ///
466    /// # Reference
467    ///
468    /// * `registries/{REGISTRY_ID}/devices/{DEVICE_SN}/actions`
469    ///
470    /// # Errors
471    ///
472    /// Returns `ClientError` if the request failed
473    pub async fn subscribe_actions(&self, qos: QoS) -> Result<(), ClientError> {
474        self.subscribe(
475            format!(
476                "registries/{}/devices/{}/actions",
477                self.device.registry_id(),
478                self.device.serial(),
479            ),
480            qos,
481        )
482        .await
483    }
484
485    /// Get a receiver of the incoming messages. Send this to any function that
486    /// wants to read the incoming messages from `IoT` Core.
487    ///
488    /// Note that it is very important that your retrieve your *receiver* prior
489    /// to any `subscribe_` of interest. Otherwise, you most likely won't
490    /// receive messages sent to you in between this call and prior
491    /// subscriptions.
492    ///
493    /// This is of particular importance for messages returned as a result to
494    /// your subscription.
495    #[must_use]
496    pub fn get_receiver(&self) -> Receiver<Message> {
497        self.incoming_event_sender.subscribe()
498    }
499
500    /// If you want to use the Rumqttc `AsyncClient` and `EventLoop` manually, this method can be used
501    /// to get the `AsyncClient`.
502    #[must_use]
503    #[allow(clippy::nursery)]
504    pub fn get_client(self) -> AsyncClient {
505        self.client
506    }
507
508    /// Monitor our client's statuses over time.
509    ///
510    /// The returned [`ClientStatusStream`] is a [`futures::Stream`] of
511    /// [`ClientStatus`] that can also provide the most recent status through
512    /// its [`current`](ClientStatusStream::current) method.
513    ///
514    /// # Returns
515    ///
516    /// The streamable monitor instance.
517    #[must_use]
518    pub fn status_stream(&self) -> ClientStatusStream {
519        ClientStatusStream::new(self.status_rx.clone())
520    }
521}
522
523impl Publisher for Client {
524    type Error = Error;
525
526    async fn publish_with<P: Publishable + Send>(
527        &self,
528        publishable: P,
529        options: PublishOptions,
530    ) -> Result<(), Self::Error> {
531        let options = PublishOptionsResolved {
532            qos: QoS::ExactlyOnce,
533            retain: false,
534            user_properties: HashMap::new(),
535            content_type: None,
536        }
537        .override_with(&publishable.publish_overrides())
538        .override_with(&options);
539
540        let properties = PublishProperties {
541            user_properties: options.user_properties.into_iter().collect(),
542            content_type: options.content_type,
543            ..Default::default()
544        };
545
546        self.client
547            .publish_with_properties(
548                publishable.topic(&self.device),
549                options.qos,
550                options.retain,
551                publishable.payload().map_err(Into::into)?,
552                properties,
553            )
554            .await?;
555
556        Ok(())
557    }
558}