mqtt_manager/
lib.rs

1//! Convenience class for managing the MQTT connection
2//! 
3//! # Examples
4//! 
5//! ```no_run
6//! use mqtt_manager::*;
7//! use std::time::Duration;
8//! 
9//! async fn handle_msg_a(pubdata: Publish) {
10//!     println!("Received msg A: {:?}", pubdata.payload);
11//! }
12//! 
13//! async fn handle_msg_b(pubdata: Publish) {
//!     println!("Received msg B: {:?}", pubdata.payload);
15//! }
16//! 
17//! #[tokio::main]
18//! async fn main() {
//!     let mut mgr = MqttManager::new("mqtt://localhost:1883?client_id=override_client_id");
20//!     mgr.subscribe("msg/a", 0, make_callback!(handle_msg_a)).await;
21//!     mgr.subscribe("msg/b", 0, make_callback!(handle_msg_b)).await;
22//!     mgr.publish("msg/a", "test", 0).await;
23//!     loop {
24//!         tokio::select! {
25//!             _ = mgr.process() => (),
26//!             _ = tokio::signal::ctrl_c() => {
27//!                 mgr.disconnect().await;
28//!                 break;
29//!             }
30//!         }
31//!     }
32//! }
33//! ```
34use std::pin::Pin;
35use std::{time::Duration, future::Future};
36use std::collections::HashMap;
37
38pub use rumqttc::Publish;
39use rumqttc::{MqttOptions, AsyncClient, EventLoop, Event, mqttbytes::matches};
40
41/// Type for subscription callbacks. See [`crate::make_callback`]
42pub type CallbackFn = Box<dyn Fn(Publish) -> Pin<Box<dyn Future<Output=()>>>>;
43
/// Wrap an async function so it can be passed to [`MqttManager::subscribe`].
///
/// The expansion is a boxed `move` closure that forwards its single argument to the
/// given expression and pins the returned future on the heap.
///
/// # Example
/// ```
/// use mqtt_manager::{make_callback, Publish};
/// async fn callback(pubpkt: Publish) {}
///
/// let cb_handle = make_callback!(callback);
/// ```
#[macro_export]
macro_rules! make_callback {
    ($handler:expr) => {
        Box::new(move |packet| Box::pin($handler(packet)))
    };
}
59
60/// The main MQTT manager struct
61pub struct MqttManager {
62    client: AsyncClient,
63    eventloop: EventLoop,
64    subscriptions: HashMap<String, CallbackFn>,
65}
66
67#[allow(dead_code)]
68impl MqttManager {
69    /// Create a new MqttManager from a host URL in the form:
70    /// - mqtt://localhost:1883?client_id=client2
71    pub fn new(host_url: &str) -> MqttManager {
72        let mut opts = if host_url.contains("client_id=") {
73            MqttOptions::parse_url(host_url).expect("Error parsing MQTT URL")
74        } else {
75            MqttOptions::parse_url(format!("{host_url}?client_id=busbridge")).expect("Error parsing MQTT URL")
76        };
77        opts.set_keep_alive(Duration::from_secs(60));
78        let (client, eventloop) = AsyncClient::new(opts, 10);
79        MqttManager {
80            client,
81            eventloop,
82            subscriptions: HashMap::new(),
83        }
84    }
85
86    /// Send a DISCONNECT to clean up the connection.
87    /// ## panic
88    /// Panics when the call to disconnect fails.
89    pub async fn disconnect(&mut self) {
90        self.client.disconnect().await.expect("Unable to disconnect");
91    }
92    
93    /// Publish data to a topic.
94    /// 
95    /// topic is the topic name, payload is the payload bytes, and qos is the MQTT qos in the range 0-2
96    /// ## panic
97    /// This panics if the qos is invalid or if there is an error publishing.
98    pub async fn publish<T, U>(&mut self, topic: T, payload: U, qos: u8)
99    where T: Into<String>,
100          U: Into<Vec<u8>> {
101        self.client.publish(topic, rumqttc::qos(qos).expect("Invalid QoS value"), false, payload).await.expect("Error publishing")
102    }
103
104    /// Subscribe to a topic.
105    /// 
106    /// topic is the subscription topic including optional wildcards per the MQTT spec.
107    /// qos is the subscription qos in the range 0-3. callback is a callback async function
108    /// which should be wrapped by calling [`crate::make_callback`].
109    /// ## panic
110    /// This panics if the qos is invalid or if there is an error subscribing.
111    pub async fn subscribe<T: Into<String>>(&mut self, topic: T, qos: u8, callback: CallbackFn) {
112        let t = topic.into();
113        self.client.subscribe(t.clone(), rumqttc::qos(qos).expect("Invalid QoS value")).await.expect("Failed to subscribe");
114        self.subscriptions.insert(t, callback);
115    }
116
117    /// Wait for a single packet and process reconnects, pings, etc.
118    /// 
119    /// This should be called reguarly, either in an event loop or in a background thread using tokio::spawn.
120    /// 
121    /// If a SUBSCRIBE packet is returned and there's a registered callback it will be await'd.
122    pub async fn process(&mut self) -> std::result::Result<(), rumqttc::ConnectionError> {
123        match self.eventloop.poll().await {
124            Ok(Event::Incoming(rumqttc::Packet::Publish(data))) => {
125                for (filter, callback) in self.subscriptions.iter() {
126                    if matches(&data.topic, filter) {
127                        callback(data).await;
128                        break;
129                    }
130                }
131                Ok(())
132            }
133            Ok(_) => Ok(()),
134            Err(err) => Err(err)
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    //! These tests require an MQTT broker.
    //! `docker run eclipse-mosquitto:2.0` will run mosquitto on localhost:1883
    //! Other brokers have not been tested except mqttest which is known to not
    //! work as it doesn't support QOS 2.

    use super::*;
    use bytes::Bytes;

    const MQTT_URL: &str = "mqtt://localhost:1883?client_id=";

    /// Build a manager for the local broker using the given client id.
    fn manager(client_id: &str) -> MqttManager {
        MqttManager::new(&format!("{MQTT_URL}{client_id}"))
    }

    #[tokio::test]
    async fn connect() {
        let mut mgr = manager("rts_connect");
        mgr.process()
            .await
            .expect("Connection refused. Make sure you are running a broker on localhost:1883");
        mgr.disconnect().await;
    }

    #[tokio::test]
    async fn publish() {
        let mut mgr = manager("rts_publish");
        // One publish per QoS level.
        for (payload, qos) in [("bar", 0), ("ack", 1), ("blah", 2)] {
            mgr.publish("test", payload, qos).await;
        }
        // We expect exactly 8 packets for the above sequence
        for _ in 0..8 {
            let polled = tokio::time::timeout(Duration::from_secs(5), mgr.process()).await;
            polled.expect("Error, timed out waiting on packet").unwrap();
        }
        mgr.disconnect().await;
    }

    #[tokio::test]
    async fn subscribe() {
        let mut mgr = manager("rts_subscribe");
        mgr.process().await.unwrap();

        let check_payload = |pkt: Publish| async move {
            assert_eq!(pkt.payload, Bytes::from("test"));
        };
        mgr.subscribe("test2", 0, make_callback!(check_payload)).await;
        mgr.process().await.unwrap();

        mgr.publish("test2", "test", 0).await;
        for _ in 0..3 {
            mgr.process().await.unwrap();
        }
        mgr.disconnect().await;
    }
}
182}