mqtt_service/
lib.rs

1#![doc = include_str!("../README.MD")]
2#![warn(missing_docs)]
3#![deny(warnings)]
4#![allow(clippy::result_large_err)]
5
6use std::sync::Arc;
7
8#[cfg(any(feature = "json", feature = "async_service"))]
9use std::future::Future;
10
11#[cfg(feature = "json")]
12use bytes::Buf;
13
14#[cfg(feature = "async_service")]
15use yaaral::prelude::*;
16
/// Crate-local shorthand for a `std::result::Result` whose error type is this crate's `Error`.
type Result<T> = std::result::Result<T, Error>;
18
19// Structures
20
21#[derive(thiserror::Error, Debug)]
22#[allow(missing_docs)]
23pub enum Error
24{
25  #[error("Future was cancelled.")]
26  CancelledFuture(String),
27  #[error("An error occured in the MQTT channel {0}.")]
28  MqttChannelError(#[from] mqtt_channel::Error),
29  #[error("Channel was closed for topic {0}.")]
30  ChannelClosed(String),
31  #[error("Send error when broadcasting message {0}")]
32  AsyncBroadcastSend(#[from] async_broadcast::SendError<mqtt_channel::RawMessage>),
33  #[error("Recv error when broadcasting message {0}")]
34  AsyncBroadcastRecv(#[from] async_broadcast::RecvError),
35  #[cfg(feature = "json")]
36  #[error("Error during json serialization {0}")]
37  JsonSerialisationError(#[from] serde_json::Error),
38}
39
/// Handle to the connection
///
/// Thin wrapper around an `mqtt_channel::Client` that adds the service (request/response)
/// calls. The handle is `Clone`; the methods that call or serve a service consume it, so
/// clone it first when it is needed afterwards.
#[derive(Clone)]
pub struct Client
{
  /// Underlying MQTT client used for every subscription and publication.
  client: mqtt_channel::Client,
}
46
47impl Client
48{
49  /// Create a new MQTT Client, with the given name to connect at the given and MQTT broker
50  pub fn new(client: mqtt_channel::Client) -> Client
51  {
52    Client { client }
53  }
54  /// Implement call service using the MQTT Request-Response pattern
55  pub async fn call_raw_service(
56    self,
57    topic: impl Into<String>,
58    message: impl Into<bytes::Bytes>,
59    content_type: Option<String>,
60  ) -> Result<bytes::Bytes>
61  {
62    let qos = rumqttc::v5::mqttbytes::QoS::ExactlyOnce;
63    let topic_string = topic.into();
64    let correlation_data = bytes::Bytes::copy_from_slice(uuid::Uuid::new_v4().as_bytes());
65    let response_topic = format!("{}/response/{}", topic_string, uuid::Uuid::new_v4());
66    let client = self.client.clone();
67    let mut receiver_channel = client
68      .clone()
69      .get_or_create_raw_subscription(&response_topic, qos, 5)
70      .await?;
71    client
72      .publish_raw(
73        topic_string,
74        qos,
75        true,
76        message,
77        Some(rumqttc::v5::mqttbytes::v5::PublishProperties {
78          content_type,
79          response_topic: Some(response_topic),
80          correlation_data: Some(correlation_data.clone()),
81          ..Default::default()
82        }),
83      )
84      .await?;
85    loop
86    {
87      let msg = receiver_channel.recv().await?;
88      if let Some(properties) = msg.properties
89      {
90        if let Some(received_correlation_data) = properties.correlation_data
91        {
92          if received_correlation_data == correlation_data
93          {
94            return Ok(msg.payload);
95          }
96        }
97      }
98    }
99  }
100
101  /// Call a service using json enconding
102  #[cfg(feature = "json")]
103  pub fn call_json_service<TRequest, TResponse>(
104    self,
105    topic: impl Into<String>,
106    request: &TRequest,
107  ) -> Result<impl Future<Output = Result<TResponse>>>
108  where
109    TRequest: serde::Serialize,
110    for<'de> TResponse: serde::Deserialize<'de>,
111  {
112    let serialized = serde_json::to_string(&request)?;
113    Ok(async {
114      let data = self
115        .call_raw_service(topic, serialized, Some("application/json".to_string()))
116        .await?;
117      Ok(serde_json::from_slice::<TResponse>(data.chunk())?)
118    })
119  }
120
121  /// Create a new service on the topic, the handler will be called
122  /// `cap` is the capacity of the input channel
123  pub async fn create_raw_service<THandler>(
124    self,
125    topic: impl Into<String>,
126    handler: THandler,
127    capacity: usize,
128  ) -> Result<()>
129  where
130    THandler: Send + Sync + Fn(bytes::Bytes) -> (bytes::Bytes, Option<String>) + 'static,
131  {
132    let handler = Arc::new(handler);
133    let qos = rumqttc::v5::mqttbytes::QoS::ExactlyOnce;
134    let topic = topic.into();
135    let client = self.client.clone();
136
137    let mut receiver = self
138      .client
139      .clone()
140      .create_raw_subscription(&topic, qos, capacity)
141      .await?;
142    while let Ok(msg) = receiver.recv().await
143    {
144      match msg.properties
145      {
146        Some(props) =>
147        {
148          match (props.response_topic, props.correlation_data)
149          {
150            (Some(response_topic), Some(correlation_data)) =>
151            {
152              let (payload, content_type) = handler(msg.payload);
153              let publish_props = rumqttc::v5::mqttbytes::v5::PublishProperties {
154                content_type,
155                correlation_data: Some(correlation_data),
156                ..Default::default()
157              };
158              if let Err(e) = client
159                .clone()
160                .publish_raw(response_topic, qos, true, payload, Some(publish_props))
161                .await
162              {
163                log::error!(
164                  "Error occured when publishing answer to service call: {:?}",
165                  e
166                );
167              }
168            }
169            _ =>
170            {
171              log::error!("Received message without response topic or correlation data for service {topic:?}.");
172            }
173          }
174        }
175        None =>
176        {
177          log::error!("Received message without properties for service {topic:?}.");
178        }
179      }
180    }
181    Ok(())
182  }
183
184  /// Create a new service on the topic, the handler will be called and the data is serialized with json
185  #[cfg(feature = "json")]
186  pub fn create_json_service<TRequest, TResponse, THandler>(
187    self,
188    topic: impl Into<String>,
189    handler: THandler,
190    cap: usize,
191  ) -> impl Future<Output = Result<()>>
192  where
193    TResponse: serde::Serialize,
194    THandler: Send + Sync + Fn(TRequest) -> TResponse + 'static,
195    for<'de> TRequest: serde::Deserialize<'de>,
196  {
197    self.create_raw_service(
198      topic,
199      move |data: bytes::Bytes| {
200        let request = serde_json::from_slice::<TRequest>(data.chunk());
201        match request
202        {
203          Ok(request) =>
204          {
205            let resp = handler(request);
206
207            let json_data = serde_json::to_string(&resp);
208            match json_data
209            {
210              Ok(json_data) => (json_data.into(), Some("application/json".to_string())),
211              Err(err) =>
212              {
213                log::error!("Error during serialization {err:?}");
214                (bytes::Bytes::new(), None)
215              }
216            }
217          }
218          Err(err) =>
219          {
220            log::error!("Error during deserialization {err:?}.");
221            (bytes::Bytes::new(), None)
222          }
223        }
224      },
225      cap,
226    )
227  }
228
229  /// Create a new service on the topic, the handler will be called
230  /// `capacity` is the capacity of the input channel
231  #[cfg(feature = "async_service")]
232  pub async fn create_raw_async_service<THandler, THandlerFuture>(
233    self,
234    runtime: impl yaaral::compat::CompatInterface,
235    topic: impl Into<String>,
236    handler: THandler,
237    capacity: usize,
238  ) -> Result<()>
239  where
240    THandler: Fn(bytes::Bytes) -> THandlerFuture + Send + Sync + 'static,
241    THandlerFuture: Future<Output = (bytes::Bytes, Option<String>)> + Send,
242  {
243    let handler = Arc::new(handler);
244    let qos = rumqttc::v5::mqttbytes::QoS::ExactlyOnce;
245    let topic = topic.into();
246    let client = self.client.clone();
247
248    let mut receiver = self
249      .client
250      .clone()
251      .create_raw_subscription(&topic, qos, capacity)
252      .await?;
253    while let Ok(msg) = receiver.recv().await
254    {
255      match msg.properties
256      {
257        Some(props) =>
258        {
259          match (props.response_topic, props.correlation_data)
260          {
261            (Some(response_topic), Some(correlation_data)) =>
262            {
263              let client = client.clone();
264              let handler = handler.clone();
265              let e = runtime
266                .spawn_tokio_task(async move {
267                  let (payload, content_type) = handler(msg.payload).await;
268                  let publish_props = rumqttc::v5::mqttbytes::v5::PublishProperties {
269                    content_type,
270                    correlation_data: Some(correlation_data),
271                    ..Default::default()
272                  };
273                  if let Err(e) = client
274                    .clone()
275                    .publish_raw(response_topic, qos, true, payload, Some(publish_props))
276                    .await
277                  {
278                    log::error!(
279                      "Error occured when publishing answer to service call: {:?}",
280                      e
281                    );
282                  }
283                })
284                .detach();
285              if let Err(e) = e
286              {
287                log::error!("An error occured while spawning service task: {:?}", e);
288              }
289            }
290            _ =>
291            {
292              log::error!("Received message without response topic or correlation data for service {topic:?}.");
293            }
294          }
295        }
296        None =>
297        {
298          log::error!("Received message without properties for service {topic:?}.");
299        }
300      }
301    }
302    Ok(())
303  }
304
305  /// Create a new service on the topic, the handler will be called and the data is serialized with json
306  #[cfg(all(feature = "json", feature = "async_service"))]
307  pub fn create_json_async_service<TRequest, TResponse, THandler, THandlerFuture>(
308    self,
309    runtime: impl yaaral::compat::CompatInterface,
310    topic: impl Into<String>,
311    handler: THandler,
312    capacity: usize,
313  ) -> impl Future<Output = Result<()>> + Send
314  where
315    TResponse: serde::Serialize,
316    TRequest: Send,
317    THandler: Fn(TRequest) -> THandlerFuture + Send + Sync + 'static,
318    THandlerFuture: Future<Output = TResponse> + Send,
319    for<'de> TRequest: serde::Deserialize<'de>,
320  {
321    let topic = topic.into();
322    let handler = Arc::new(handler);
323    // explicitly box the closure with lifetime 'a
324    let closure = move |data: bytes::Bytes| {
325      let handler = handler.clone();
326      async move {
327        match serde_json::from_slice::<TRequest>(data.chunk())
328        {
329          Ok(request) =>
330          {
331            let resp = handler(request).await;
332            match serde_json::to_string(&resp)
333            {
334              Ok(json_data) => (json_data.into(), Some("application/json".to_string())),
335              Err(err) =>
336              {
337                log::error!("Error during serialization {err:?}");
338                (bytes::Bytes::new(), None)
339              }
340            }
341          }
342          Err(err) =>
343          {
344            log::error!("Error during deserialization {err:?}.");
345            (bytes::Bytes::new(), None)
346          }
347        }
348      }
349    };
350
351    self.create_raw_async_service(runtime, topic, closure, capacity)
352  }
353}
354
#[cfg(test)]
mod tests
{
  // These tests talk to a real MQTT broker on port 1883; its host name is read from the
  // `MQTT_SERVICE_MQTT_SERVER_HOSTNAME` environment variable and defaults to `localhost`.
  use super::Client;
  use bytes::Buf;
  use futures::FutureExt;
  use std::env;

  /// Operands of the addition service used by the tests.
  #[derive(serde::Serialize, serde::Deserialize)]
  struct Request
  {
    a: f32,
    b: f32,
  }

  /// Result of the addition service used by the tests.
  #[derive(serde::Serialize, serde::Deserialize)]
  struct Response
  {
    r: f32,
  }

  /// Build a client named `name` for the test broker and spawn the task returned by `start()`
  /// on `rt`.
  fn connect(rt: &tokio::runtime::Runtime, name: &'static str) -> Client
  {
    let hostname =
      env::var("MQTT_SERVICE_MQTT_SERVER_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
    let (connection, task) = mqtt_channel::Client::build(name, hostname, 1883).start();
    rt.spawn(task);
    Client::new(connection)
  }

  /// Raw addition handler: decode a JSON `Request`, add its operands, encode a JSON `Response`.
  fn add_raw(data: bytes::Bytes) -> (bytes::Bytes, Option<String>)
  {
    let request = serde_json::from_slice::<Request>(data.chunk()).unwrap();
    let response = Response {
      r: request.a + request.b,
    };
    let json = serde_json::to_string(&response).unwrap();
    (json.into(), Some("application/json".to_string()))
  }

  /// Raw service with a synchronous handler, called through `call_raw_service`.
  #[test]
  fn test_raw_services()
  {
    let rt = tokio::runtime::Runtime::new().unwrap();
    let connection = connect(&rt, "name-of-the-client-raw");
    let service = connection
      .clone()
      .create_raw_service("mqtt-service/test_raw/addition", add_raw, 10)
      .map(|r| r.unwrap());
    rt.spawn(service);

    let request = serde_json::to_string(&Request { a: 1.0, b: 2.0 }).unwrap();
    let call = connection.call_raw_service(
      "mqtt-service/test_raw/addition",
      request,
      Some("application/json".to_string()),
    );

    let data = rt.block_on(call).unwrap();
    let res = serde_json::from_slice::<Response>(data.chunk()).unwrap();
    assert_eq!(res.r, 3.0);
  }

  /// Raw service with an asynchronous handler, called through `call_raw_service`.
  #[test]
  #[cfg(feature = "async_service")]
  fn test_raw_services_async()
  {
    let rt = tokio::runtime::Runtime::new().unwrap();
    let connection = connect(&rt, "name-of-the-client-raw-async");
    let yrt: yaaral::tokio::Runtime = rt.handle().into();
    let service = connection
      .clone()
      .create_raw_async_service(
        yrt,
        "mqtt-service/test_raw_async/addition",
        |data| async move { add_raw(data) },
        10,
      )
      .map(|r| r.unwrap());
    rt.spawn(service);

    let request = serde_json::to_string(&Request { a: 1.0, b: 2.0 }).unwrap();
    let call = connection.call_raw_service(
      "mqtt-service/test_raw_async/addition",
      request,
      Some("application/json".to_string()),
    );

    let data = rt.block_on(call).unwrap();
    let res = serde_json::from_slice::<Response>(data.chunk()).unwrap();
    assert_eq!(res.r, 3.0);
  }

  /// JSON service with a synchronous handler, called through `call_json_service`.
  #[test]
  #[cfg(feature = "json")]
  fn test_json_services()
  {
    let rt = tokio::runtime::Runtime::new().unwrap();
    let connection = connect(&rt, "name-of-the-client-json");
    let service = connection
      .clone()
      .create_json_service(
        "mqtt-service/test_json/addition",
        |request: Request| Response {
          r: request.a + request.b,
        },
        10,
      )
      .map(|r| r.unwrap());
    rt.spawn(service);

    let call = connection
      .call_json_service::<Request, Response>(
        "mqtt-service/test_json/addition",
        &Request { a: 1.0, b: 2.0 },
      )
      .unwrap();

    let res = rt.block_on(call).unwrap();
    assert_eq!(res.r, 3.0);
  }

  /// JSON service with an asynchronous handler, called through `call_json_service`.
  #[test]
  #[cfg(all(feature = "json", feature = "async_service"))]
  fn test_async_json_services()
  {
    let rt = tokio::runtime::Runtime::new().unwrap();
    let connection = connect(&rt, "name-of-the-client-json-async");
    let yrt: yaaral::tokio::Runtime = rt.handle().into();
    let service = connection
      .clone()
      .create_json_async_service(
        yrt,
        "mqtt-service/test_json_async/addition",
        |request: Request| async move {
          Response {
            r: request.a + request.b,
          }
        },
        10,
      )
      .map(|r| r.unwrap());
    rt.spawn(service);

    let call = connection
      .call_json_service::<Request, Response>(
        "mqtt-service/test_json_async/addition",
        &Request { a: 1.0, b: 2.0 },
      )
      .unwrap();

    let res = rt.block_on(call).unwrap();
    assert_eq!(res.r, 3.0);
  }
}