1use std::{
5 collections::HashMap,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use backon::{ExponentialBuilder, Retryable};
12use displaydoc::Display;
13use futures::Stream;
14#[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
15use rumqttc::TlsConfiguration;
16use rumqttc::{
17 Transport,
18 v5::{
19 AsyncClient, ClientError, ConnectionError, Event, EventLoop, MqttOptions,
20 mqttbytes::{
21 QoS,
22 v5::{ConnectReturnCode, LastWill, PublishProperties},
23 },
24 },
25};
26use tokio::{
27 sync::{
28 Mutex,
29 broadcast::{self, Receiver, Sender},
30 watch,
31 },
32 time::Duration,
33};
34use tokio_stream::wrappers::WatchStream;
35use tracing::{error, info};
36
37use crate::{
38 Device, Error, Message, PublishOptions, Publishable, Publisher, Security, error,
39 publish_options::PublishOptionsResolved,
40};
41
/// Maximum MQTT packet size override, forwarded to `MqttOptions::set_max_packet_size`.
///
/// The wrapped value is public so that code outside this module can actually construct
/// one when filling in `MQTTOptionsOverrides::max_packet_size`; with a private field and
/// no constructor that override could never be set by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MQTTMaxPacketSize(pub u32);
44
45#[derive(Default)]
47pub struct MQTTOptionsOverrides {
48 pub clean_session: Option<bool>,
50 pub session_expiry_interval: Option<Duration>,
52 pub keep_alive: Option<Duration>,
54 pub max_packet_size: Option<MQTTMaxPacketSize>,
56 pub request_channel_capacity: Option<usize>,
58 pub pending_throttle: Option<Duration>,
60 pub inflight: Option<u16>,
62 pub last_will: Option<LastWill>,
64 pub transport: Option<Transport>,
66}
67
68pub struct ClientSettings {
70 security: Security,
72 device: Device,
74 endpoint: String,
76 mqtt_options_overrides: Option<MQTTOptionsOverrides>,
78}
79
80impl ClientSettings {
81 #[must_use]
91 pub const fn new(
92 security: Security,
93 device: Device,
94 endpoint: String,
95 mqtt_options_overrides: Option<MQTTOptionsOverrides>,
96 ) -> Self {
97 Self {
98 security,
99 device,
100 endpoint,
101 mqtt_options_overrides,
102 }
103 }
104
105 pub async fn to_mqtt_options(&self) -> Result<MqttOptions, error::Error> {
119 let mut mqtt_options =
120 MqttOptions::new(self.device.client_id(), self.endpoint.clone(), 8883);
121
122 if let Some(ref overrides) = self.mqtt_options_overrides {
123 if let Some(clean_session) = overrides.clean_session {
124 mqtt_options.set_clean_start(clean_session);
125 }
126 if let Some(session_expiry_interval) = overrides.session_expiry_interval {
127 let mut connect_properties = mqtt_options.connect_properties().unwrap_or_default();
128 connect_properties.session_expiry_interval =
129 Some(session_expiry_interval.as_secs().try_into()?);
130 mqtt_options.set_connect_properties(connect_properties);
131 }
132 if let Some(transport) = overrides.transport.clone() {
133 mqtt_options.set_transport(transport);
134 } else {
135 let transport = get_default_transport();
136 mqtt_options.set_transport(transport);
137 }
138 if let Some(keep_alive) = overrides.keep_alive {
139 mqtt_options.set_keep_alive(keep_alive);
140 }
141 if let Some(ref packet_size) = overrides.max_packet_size {
142 mqtt_options.set_max_packet_size(Some(packet_size.0));
143 }
144 if let Some(request_channel_capacity) = overrides.request_channel_capacity {
145 mqtt_options.set_request_channel_capacity(request_channel_capacity);
146 }
147 if let Some(pending_throttle) = overrides.pending_throttle {
148 mqtt_options.set_pending_throttle(pending_throttle);
149 }
150 if let Some(inflight) = overrides.inflight {
151 mqtt_options.set_outgoing_inflight_upper_limit(inflight);
152 }
153 if let Some(last_will) = overrides.last_will.clone() {
154 mqtt_options.set_last_will(last_will);
155 }
156 }
157
158 let token = self.security.generate_token().await?;
160 mqtt_options.set_credentials("unused", token);
161
162 Ok(mqtt_options)
163 }
164}
165
166const fn get_default_transport() -> Transport {
167 #[cfg(all(feature = "use-native-tls", not(feature = "use-rustls")))]
168 let transport = Transport::Tls(TlsConfiguration::Native);
169 #[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
170 let transport = Transport::Tls(TlsConfiguration::default());
171 #[cfg(all(feature = "use-native-tls", feature = "use-rustls"))]
172 let transport = Transport::Tls(TlsConfiguration::Native);
173 #[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
174 let transport = Transport::Tcp;
175
176 transport
177}
178
179#[derive(Clone)]
180struct EventLoopManager {
181 event_loop: Arc<Mutex<EventLoop>>,
182 client_status: watch::Sender<ClientStatus>,
183}
184
185impl EventLoopManager {
186 fn new(event_loop: EventLoop, client_status: watch::Sender<ClientStatus>) -> Self {
187 Self {
188 event_loop: Arc::new(Mutex::new(event_loop)),
189 client_status,
190 }
191 }
192
193 async fn poll(&self) -> Result<Event, ConnectionError> {
194 let mut in_error = false;
195
196 let polling_result = (|| async { self.event_loop.lock().await.poll().await })
197 .retry(ExponentialBuilder::default().without_max_times())
198 .when(is_backoff_error)
199 .notify(|err, dur: Duration| {
200 in_error = true;
201 self.set_client_status(ClientStatus::InError(ClientStatusError::MqttConnection(
202 err.to_string(),
203 )));
204
205 let dur = dur.as_secs_f32();
206 error!("Error while polling MQTT event loop: {err}\n -> Retrying in {dur:.1}s...");
207 })
208 .await;
209
210 match polling_result.as_ref() {
211 Ok(_) => {
212 self.set_client_status(ClientStatus::Connected);
213 if in_error {
214 info!("MQTT connection restored");
215 }
216 }
217 Err(err) => self.set_client_status(ClientStatus::InError(
218 ClientStatusError::MqttConnection(err.to_string()),
219 )),
220 }
221
222 polling_result
223 }
224
225 fn set_client_status(&self, status: ClientStatus) {
226 self.client_status.send_if_modified(|current_status| {
227 let notify = current_status != &status;
228 *current_status = status;
229 notify
230 });
231 }
232}
233
234const fn is_backoff_error(err: &ConnectionError) -> bool {
235 !matches!(
236 err,
237 &ConnectionError::ConnectionRefused(
238 ConnectReturnCode::ProtocolError
239 | ConnectReturnCode::UnsupportedProtocolVersion
240 | ConnectReturnCode::ClientIdentifierNotValid
241 | ConnectReturnCode::BadUserNamePassword
242 | ConnectReturnCode::NotAuthorized
243 | ConnectReturnCode::Banned
244 | ConnectReturnCode::BadAuthenticationMethod
245 | ConnectReturnCode::UseAnotherServer
246 | ConnectReturnCode::ServerMoved
247 )
248 )
249}
250
251#[derive(Display, Debug, Clone, PartialEq, Eq, thiserror::Error)]
253pub enum ClientStatusError {
254 MqttConnection(String),
256}
257
258#[derive(Clone, PartialEq, Eq)]
260pub enum ClientStatus {
261 Connected,
263 Disconnected,
265 InError(ClientStatusError),
268}
269
270pub struct ClientStatusStream {
275 receiver: watch::Receiver<ClientStatus>,
276 stream: WatchStream<ClientStatus>,
277}
278
279impl ClientStatusStream {
280 #[must_use]
281 fn new(receiver: watch::Receiver<ClientStatus>) -> Self {
282 let stream = WatchStream::new(receiver.clone());
283 Self { receiver, stream }
284 }
285
286 #[must_use]
292 pub fn current(&self) -> ClientStatus {
293 self.receiver.borrow().clone()
294 }
295}
296
297impl Stream for ClientStatusStream {
298 type Item = ClientStatus;
299
300 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301 let this = self.get_mut();
302 Pin::new(&mut this.stream).poll_next(cx)
303 }
304}
305
306impl Unpin for ClientStatusStream {}
307
308#[derive(Clone)]
310pub struct Client {
311 client: AsyncClient,
312 manager: EventLoopManager,
313 device: Device,
314 incoming_event_sender: Sender<Message>,
315 status_rx: watch::Receiver<ClientStatus>,
316}
317
318impl Client {
319 pub async fn new(settings: ClientSettings) -> Result<Self, error::Error> {
327 let mqtt_options = settings.to_mqtt_options().await?;
328
329 let (client, eventloop) = AsyncClient::new(mqtt_options, 10);
330 let (request_tx, _) = broadcast::channel(50);
331 let (status_tx, status_rx) = watch::channel(ClientStatus::Disconnected);
332 let manager = EventLoopManager::new(eventloop, status_tx);
333
334 Ok(Self {
335 client,
336 manager,
337 device: settings.device.clone(),
338 incoming_event_sender: request_tx,
339 status_rx,
340 })
341 }
342
343 pub async fn run(&self) -> Result<(), error::Error> {
369 loop {
370 let event = self.manager.poll().await?;
371 if let Err(err) = self.handle_event(event) {
372 error!("Error while handling MQTT event: {err}");
373 }
374 }
375 }
376
377 fn handle_event(&self, event: Event) -> Result<(), error::Error> {
378 match event {
379 Event::Incoming(packet) => {
380 let message = Message::try_from_packet(packet)?;
381 self.incoming_event_sender.send(message)?;
382 Ok(())
383 }
384 Event::Outgoing(_) => Ok(()), }
386 }
387
388 async fn subscribe<S: Into<String> + Send>(
394 &self,
395 topic: S,
396 qos: QoS,
397 ) -> Result<(), ClientError> {
398 self.client.subscribe(topic, qos).await
399 }
400
401 pub async fn subscribe_config(&self, qos: QoS) -> Result<(), ClientError> {
411 self.subscribe(
412 format!(
413 "registries/{}/devices/{}/config",
414 self.device.registry_id(),
415 self.device.serial()
416 ),
417 qos,
418 )
419 .await
420 }
421
422 pub async fn subscribe_portforward(&self, qos: QoS) -> Result<(), ClientError> {
432 self.subscribe(
433 format!(
434 "registries/{}/devices/{}/pfwd/tx",
435 self.device.registry_id(),
436 self.device.serial()
437 ),
438 qos,
439 )
440 .await
441 }
442
443 pub async fn subscribe_commands(&self, qos: QoS) -> Result<(), ClientError> {
453 self.subscribe(
454 format!(
455 "registries/{}/devices/{}/commands",
456 self.device.registry_id(),
457 self.device.serial()
458 ),
459 qos,
460 )
461 .await
462 }
463
464 pub async fn subscribe_actions(&self, qos: QoS) -> Result<(), ClientError> {
474 self.subscribe(
475 format!(
476 "registries/{}/devices/{}/actions",
477 self.device.registry_id(),
478 self.device.serial(),
479 ),
480 qos,
481 )
482 .await
483 }
484
485 #[must_use]
496 pub fn get_receiver(&self) -> Receiver<Message> {
497 self.incoming_event_sender.subscribe()
498 }
499
500 #[must_use]
503 #[allow(clippy::nursery)]
504 pub fn get_client(self) -> AsyncClient {
505 self.client
506 }
507
508 #[must_use]
518 pub fn status_stream(&self) -> ClientStatusStream {
519 ClientStatusStream::new(self.status_rx.clone())
520 }
521}
522
523impl Publisher for Client {
524 type Error = Error;
525
526 async fn publish_with<P: Publishable + Send>(
527 &self,
528 publishable: P,
529 options: PublishOptions,
530 ) -> Result<(), Self::Error> {
531 let options = PublishOptionsResolved {
532 qos: QoS::ExactlyOnce,
533 retain: false,
534 user_properties: HashMap::new(),
535 content_type: None,
536 }
537 .override_with(&publishable.publish_overrides())
538 .override_with(&options);
539
540 let properties = PublishProperties {
541 user_properties: options.user_properties.into_iter().collect(),
542 content_type: options.content_type,
543 ..Default::default()
544 };
545
546 self.client
547 .publish_with_properties(
548 publishable.topic(&self.device),
549 options.qos,
550 options.retain,
551 publishable.payload().map_err(Into::into)?,
552 properties,
553 )
554 .await?;
555
556 Ok(())
557 }
558}