mqtt_manager/lib.rs
1//! Convenience class for managing the MQTT connection
2//!
3//! # Examples
4//!
5//! ```no_run
6//! use mqtt_manager::*;
7//! use std::time::Duration;
8//!
9//! async fn handle_msg_a(pubdata: Publish) {
10//! println!("Received msg A: {:?}", pubdata.payload);
11//! }
12//!
13//! async fn handle_msg_b(pubdata: Publish) {
14//! println!("Received msg B: {:?}", pubdata.payload);
15//! }
16//!
17//! #[tokio::main]
18//! async fn main() {
19//! let mut mgr = MqttManager::new("mqtt://localhost:1883?client_id=override_client_id");
20//! mgr.subscribe("msg/a", 0, make_callback!(handle_msg_a)).await;
21//! mgr.subscribe("msg/b", 0, make_callback!(handle_msg_b)).await;
22//! mgr.publish("msg/a", "test", 0).await;
23//! loop {
24//! tokio::select! {
25//! _ = mgr.process() => (),
26//! _ = tokio::signal::ctrl_c() => {
27//! mgr.disconnect().await;
28//! break;
29//! }
30//! }
31//! }
32//! }
33//! ```
34use std::pin::Pin;
35use std::{time::Duration, future::Future};
36use std::collections::HashMap;
37
38pub use rumqttc::Publish;
39use rumqttc::{MqttOptions, AsyncClient, EventLoop, Event, mqttbytes::matches};
40
41/// Type for subscription callbacks. See [`crate::make_callback`]
42pub type CallbackFn = Box<dyn Fn(Publish) -> Pin<Box<dyn Future<Output=()>>>>;
43
/// Turn an async fn (or any callable returning a future) into a callback that
/// can be passed to [`MqttManager::subscribe`].
///
/// The expansion is a boxed `move` closure that forwards its single argument to
/// `$fn` and pins the returned future on the heap. The standard-library items
/// are named by absolute path so the macro keeps working even if the call site
/// has its own item called `Box` in scope.
///
/// # Example
/// ```
/// use mqtt_manager::{make_callback, Publish};
/// async fn callback(pubpkt: Publish) {}
///
/// let cb_handle = make_callback!(callback);
/// ```
#[macro_export]
macro_rules! make_callback {
    ($fn:expr) => {
        ::std::boxed::Box::new(move |publish| ::std::boxed::Box::pin($fn(publish)))
    };
}
59
60/// The main MQTT manager struct
61pub struct MqttManager {
62 client: AsyncClient,
63 eventloop: EventLoop,
64 subscriptions: HashMap<String, CallbackFn>,
65}
66
67#[allow(dead_code)]
68impl MqttManager {
69 /// Create a new MqttManager from a host URL in the form:
70 /// - mqtt://localhost:1883?client_id=client2
71 pub fn new(host_url: &str) -> MqttManager {
72 let mut opts = if host_url.contains("client_id=") {
73 MqttOptions::parse_url(host_url).expect("Error parsing MQTT URL")
74 } else {
75 MqttOptions::parse_url(format!("{host_url}?client_id=busbridge")).expect("Error parsing MQTT URL")
76 };
77 opts.set_keep_alive(Duration::from_secs(60));
78 let (client, eventloop) = AsyncClient::new(opts, 10);
79 MqttManager {
80 client,
81 eventloop,
82 subscriptions: HashMap::new(),
83 }
84 }
85
86 /// Send a DISCONNECT to clean up the connection.
87 /// ## panic
88 /// Panics when the call to disconnect fails.
89 pub async fn disconnect(&mut self) {
90 self.client.disconnect().await.expect("Unable to disconnect");
91 }
92
93 /// Publish data to a topic.
94 ///
95 /// topic is the topic name, payload is the payload bytes, and qos is the MQTT qos in the range 0-2
96 /// ## panic
97 /// This panics if the qos is invalid or if there is an error publishing.
98 pub async fn publish<T, U>(&mut self, topic: T, payload: U, qos: u8)
99 where T: Into<String>,
100 U: Into<Vec<u8>> {
101 self.client.publish(topic, rumqttc::qos(qos).expect("Invalid QoS value"), false, payload).await.expect("Error publishing")
102 }
103
104 /// Subscribe to a topic.
105 ///
106 /// topic is the subscription topic including optional wildcards per the MQTT spec.
107 /// qos is the subscription qos in the range 0-3. callback is a callback async function
108 /// which should be wrapped by calling [`crate::make_callback`].
109 /// ## panic
110 /// This panics if the qos is invalid or if there is an error subscribing.
111 pub async fn subscribe<T: Into<String>>(&mut self, topic: T, qos: u8, callback: CallbackFn) {
112 let t = topic.into();
113 self.client.subscribe(t.clone(), rumqttc::qos(qos).expect("Invalid QoS value")).await.expect("Failed to subscribe");
114 self.subscriptions.insert(t, callback);
115 }
116
117 /// Wait for a single packet and process reconnects, pings, etc.
118 ///
119 /// This should be called reguarly, either in an event loop or in a background thread using tokio::spawn.
120 ///
121 /// If a SUBSCRIBE packet is returned and there's a registered callback it will be await'd.
122 pub async fn process(&mut self) -> std::result::Result<(), rumqttc::ConnectionError> {
123 match self.eventloop.poll().await {
124 Ok(Event::Incoming(rumqttc::Packet::Publish(data))) => {
125 for (filter, callback) in self.subscriptions.iter() {
126 if matches(&data.topic, filter) {
127 callback(data).await;
128 break;
129 }
130 }
131 Ok(())
132 }
133 Ok(_) => Ok(()),
134 Err(err) => Err(err)
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    //! These tests require an MQTT broker.
    //! `docker run eclipse-mosquitto:2.0` will run mosquitto on localhost:1883
    //! Other brokers have not been tested except mqttest which is known to not
    //! work as it doesn't support QOS 2.

    use super::*;
    use bytes::Bytes;

    /// Broker URL prefix; each test appends its own client id.
    const MQTT_URL: &str = "mqtt://localhost:1883?client_id=";

    /// Build a manager for the local test broker using `client_id`.
    fn manager_for(client_id: &str) -> MqttManager {
        let url = format!("{MQTT_URL}{client_id}");
        MqttManager::new(&url)
    }

    /// The first `process` call must succeed against a running broker.
    #[tokio::test]
    async fn connect() {
        let mut mgr = manager_for("rts_connect");
        mgr.process()
            .await
            .expect("Connection refused. Make sure you are running a broker on localhost:1883");
        mgr.disconnect().await;
    }

    /// Publish once at each QoS level and drain the resulting events.
    #[tokio::test]
    async fn publish() {
        let mut mgr = manager_for("rts_publish");
        mgr.publish("test", "bar", 0).await;
        mgr.publish("test", "ack", 1).await;
        mgr.publish("test", "blah", 2).await;
        // We expect exactly 8 packets for the above sequence
        for _ in 0..8 {
            let polled = tokio::time::timeout(Duration::from_secs(5), mgr.process()).await;
            polled.expect("Error, timed out waiting on packet").unwrap();
        }
        mgr.disconnect().await;
    }

    /// Subscribe, publish to the same topic, and let the callback check the payload.
    #[tokio::test]
    async fn subscribe() {
        let mut mgr = manager_for("rts_subscribe");
        mgr.process().await.unwrap();
        let on_message = make_callback!(|pkt: Publish| async move {
            assert_eq!(pkt.payload, Bytes::from("test"));
        });
        mgr.subscribe("test2", 0, on_message).await;
        mgr.process().await.unwrap();
        mgr.publish("test2", "test", 0).await;
        mgr.process().await.unwrap();
        mgr.process().await.unwrap();
        mgr.process().await.unwrap();
        mgr.disconnect().await;
    }
}
182}