mqtt_service/
lib.rs

1#![doc = include_str!("../README.MD")]
2#![warn(missing_docs)]
3#![deny(warnings)]
4
5#[cfg(feature = "json")]
6use bytes::Buf;
7use std::sync::Arc;
8
9type Result<T> = std::result::Result<T, Error>;
10
11// Structures
12
13#[derive(thiserror::Error, Debug)]
14#[allow(missing_docs)]
15pub enum Error
16{
17  #[error("Future was cancelled.")]
18  CancelledFuture(String),
19  #[error("An error occured in the MQTT channel {0}.")]
20  MqttChannelError(#[from] mqtt_channel::Error),
21  #[error("Channel was closed for topic {0}.")]
22  ChannelClosed(String),
23  #[error("Send error when broadcasting message {0}")]
24  AsyncBroadcastSend(#[from] async_broadcast::SendError<mqtt_channel::RawMessage>),
25  #[error("Recv error when broadcasting message {0}")]
26  AsyncBroadcastRecv(#[from] async_broadcast::RecvError),
27  #[cfg(feature = "json")]
28  #[error("Error during json serialization {0}")]
29  JsonSerialisationError(#[from] serde_json::Error),
30}
31
32/// Handle to the connection
33#[derive(Clone)]
34pub struct Client
35{
36  client: mqtt_channel::Client,
37}
38
39impl Client
40{
41  /// Create a new MQTT Client, with the given name to connect at the given and MQTT broker
42  pub fn new(client: mqtt_channel::Client) -> Client
43  {
44    Client { client }
45  }
46  /// Implement call service using the MQTT Request-Response pattern
47  pub async fn call_raw_service(
48    self,
49    topic: impl Into<String>,
50    message: impl Into<bytes::Bytes>,
51    content_type: Option<String>,
52  ) -> Result<bytes::Bytes>
53  {
54    let qos = rumqttc::v5::mqttbytes::QoS::ExactlyOnce;
55    let topic_string = topic.into();
56    let correlation_data = bytes::Bytes::copy_from_slice(uuid::Uuid::new_v4().as_bytes());
57    let response_topic = format!(
58      "{}/response/{}",
59      topic_string,
60      uuid::Uuid::new_v4().to_string()
61    );
62    let client = self.client.clone();
63    let mut receiver_channel = client
64      .clone()
65      .get_or_create_raw_subscription(&response_topic, qos, 5)
66      .await?;
67    client
68      .publish_raw(
69        topic_string,
70        qos,
71        true,
72        message,
73        Some(rumqttc::v5::mqttbytes::v5::PublishProperties {
74          content_type,
75          response_topic: Some(response_topic),
76          correlation_data: Some(correlation_data.clone()),
77          ..Default::default()
78        }),
79      )
80      .await?;
81    loop
82    {
83      let msg = receiver_channel.recv().await?;
84      if let Some(properties) = msg.properties
85      {
86        if let Some(received_correlation_data) = properties.correlation_data
87        {
88          if received_correlation_data == correlation_data
89          {
90            return Ok(msg.payload);
91          }
92        }
93      }
94    }
95  }
96
97  /// Call a service using json enconding
98  #[cfg(feature = "json")]
99  pub fn call_json_service<TRequest, TResponse>(
100    self,
101    topic: impl Into<String>,
102    request: &TRequest,
103  ) -> Result<impl std::future::Future<Output = Result<TResponse>>>
104  where
105    TRequest: serde::Serialize,
106    for<'de> TResponse: serde::Deserialize<'de>,
107  {
108    let serialized = serde_json::to_string(&request)?;
109    Ok(async {
110      let data = self
111        .call_raw_service(topic, serialized, Some("application/json".to_string()))
112        .await?;
113      Ok(serde_json::from_slice::<TResponse>(data.chunk())?)
114    })
115  }
116
117  /// Create a new service on the topic, the handler will be called
118  /// `cap` is the capacity of the input channel
119  pub async fn create_raw_service(
120    self,
121    topic: impl Into<String>,
122    handler: Box<dyn Send + Sync + Fn(bytes::Bytes) -> (bytes::Bytes, Option<String>)>,
123    cap: usize,
124  ) -> Result<()>
125  {
126    let handler = Arc::new(handler);
127    let qos = rumqttc::v5::mqttbytes::QoS::ExactlyOnce;
128    let topic = topic.into();
129    let client = self.client.clone();
130
131    let mut receiver = self
132      .client
133      .clone()
134      .create_raw_subscription(&topic, qos, cap)
135      .await?;
136    loop
137    {
138      let msg = receiver.recv().await?;
139      if let Some(properties) = msg.properties
140      {
141        if let Some(correlation_data) = properties.correlation_data
142        {
143          if let Some(response_topic) = properties.response_topic
144          {
145            let (payload, content_type) = handler(msg.payload);
146            if let Err(e) = client
147              .clone()
148              .publish_raw(
149                response_topic,
150                qos,
151                true,
152                payload,
153                Some(rumqttc::v5::mqttbytes::v5::PublishProperties {
154                  content_type,
155                  correlation_data: Some(correlation_data.clone()),
156                  ..Default::default()
157                }),
158              )
159              .await
160            {
161              log::error!(
162                "Error occured when publishing answer to service call: {:?}",
163                e
164              );
165            }
166          }
167          else
168          {
169            log::error!("Received message without response topic for service {topic:?}.");
170          }
171        }
172        else
173        {
174          log::error!("Received message without correlation data for service {topic:?}.");
175        }
176      }
177      else
178      {
179        log::error!("Received message without properties for service {topic:?}.");
180      }
181    }
182  }
183
184  /// Create a new service on the topic, the handler will be called and the data is serialized with json
185  #[cfg(feature = "json")]
186  pub fn create_json_service<TRequest, TResponse>(
187    self,
188    topic: impl Into<String>,
189    handler: Box<dyn Send + Sync + Fn(&TRequest) -> TResponse>,
190    cap: usize,
191  ) -> impl std::future::Future<Output = Result<()>>
192  where
193    TResponse: serde::Serialize + 'static,
194    for<'de> TRequest: serde::Deserialize<'de> + 'static,
195  {
196    self.create_raw_service(
197      topic,
198      Box::new(move |data: bytes::Bytes| {
199        let request = serde_json::from_slice::<TRequest>(data.chunk());
200        match request
201        {
202          Ok(request) =>
203          {
204            let resp = handler(&request);
205
206            let json_data = serde_json::to_string(&resp);
207            match json_data
208            {
209              Ok(json_data) => (json_data.into(), Some("application/json".to_string())),
210              Err(err) =>
211              {
212                log::error!("Error during serialization {err:?}");
213                (bytes::Bytes::new(), None)
214              }
215            }
216          }
217          Err(err) =>
218          {
219            log::error!("Error during deserialization {err:?}.");
220            (bytes::Bytes::new(), None)
221          }
222        }
223      }),
224      cap,
225    )
226  }
227}
228
#[cfg(test)]
mod tests
{
  use super::Client;
  use bytes::Buf;
  use futures::FutureExt;
  use std::env;

  /// Request payload of the addition service exercised by the tests.
  #[derive(serde::Serialize, serde::Deserialize)]
  struct Request
  {
    a: f32,
    b: f32,
  }

  /// Response payload of the addition service exercised by the tests.
  #[derive(serde::Serialize, serde::Deserialize)]
  struct Response
  {
    r: f32,
  }

  /// Hostname of the MQTT broker used by the tests: `MQTT_SERVICE_MQTT_SERVER_HOSTNAME`,
  /// falling back to `localhost` when the variable is not set.
  fn broker_hostname() -> String
  {
    env::var("MQTT_SERVICE_MQTT_SERVER_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
  }

  /// Round-trip through `create_raw_service` / `call_raw_service` with hand-encoded JSON.
  #[test]
  fn test_raw_services()
  {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let (channel, event_loop) =
      mqtt_channel::Client::build("name-of-the-client-raw", broker_hostname(), 1883).start();
    let client = Client::new(channel);
    runtime.spawn(event_loop);

    let service = client.clone().create_raw_service(
      "mqtt-service/test_raw/addition",
      Box::new(|payload| {
        let request = serde_json::from_slice::<Request>(payload.chunk()).unwrap();
        let sum = Response {
          r: request.a + request.b,
        };
        let encoded = serde_json::to_string(&sum).unwrap().into();
        (encoded, Some("application/json".to_string()))
      }),
      10,
    );
    runtime.spawn(service.map(|outcome| outcome.unwrap()));

    let call = client.call_raw_service(
      "mqtt-service/test_raw/addition",
      serde_json::to_string(&Request { a: 1.0, b: 2.0 }).unwrap(),
      Some("application/json".to_string()),
    );
    let answer = runtime.block_on(call).unwrap();
    let decoded = serde_json::from_slice::<Response>(answer.chunk()).unwrap();
    assert_eq!(decoded.r, 3.0);
  }

  /// Round-trip through `create_json_service` / `call_json_service`.
  #[test]
  #[cfg(feature = "json")]
  fn test_json_services()
  {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let (channel, event_loop) =
      mqtt_channel::Client::build("name-of-the-client-json", broker_hostname(), 1883).start();
    let client = Client::new(channel);
    runtime.spawn(event_loop);

    let service = client.clone().create_json_service(
      "mqtt-service/test_json/addition",
      Box::new(|request: &Request| Response {
        r: request.a + request.b,
      }),
      10,
    );
    runtime.spawn(service.map(|outcome| outcome.unwrap()));

    let call = client
      .call_json_service::<Request, Response>(
        "mqtt-service/test_json/addition",
        &Request { a: 1.0, b: 2.0 },
      )
      .unwrap();
    let sum = runtime.block_on(call).unwrap();
    assert_eq!(sum.r, 3.0);
  }
}