// modbus_mqtt/mqtt.rs

1use std::collections::HashMap;
2
3use bytes::Bytes;
4use rumqttc::{
5    mqttbytes::matches as matches_topic, AsyncClient, Event, EventLoop, MqttOptions, Publish,
6    Subscribe,
7};
8use tokio::{
9    select,
10    sync::mpsc::{self, channel, Receiver, Sender},
11};
12use tracing::{debug, info, warn};
13
14#[derive(Debug)]
15pub struct Payload {
16    pub bytes: Bytes,
17    pub topic: String,
18}
19
20#[derive(Debug, Clone)]
21pub enum Message {
22    Subscribe(Subscribe, Sender<Payload>),
23    Publish(Publish),
24    Shutdown,
25}
26
27pub(crate) async fn new(options: MqttOptions) -> Connection {
28    let (client, event_loop) = AsyncClient::new(options, 32);
29
30    let (tx, rx) = channel(32);
31    Connection {
32        client,
33        event_loop,
34        subscriptions: HashMap::new(),
35        tx,
36        rx,
37    }
38}
39
40// Maintain internal subscriptions as well as MQTT subscriptions. Relay all received messages on MQTT subscribed topics
41// to internal components who have a matching topic. Unsubscribe topics when no one is listening anymore.
42pub(crate) struct Connection {
43    subscriptions: HashMap<String, Vec<Sender<Payload>>>,
44    tx: Sender<Message>,
45    rx: Receiver<Message>,
46    client: AsyncClient,
47    event_loop: EventLoop,
48}
49
50impl Connection {
51    pub async fn run(&mut self) -> crate::Result<()> {
52        loop {
53            select! {
54                event = self.event_loop.poll() => {
55                    self.handle_event(event?).await?
56                }
57                request = self.rx.recv() => {
58                    match request {
59                        None => return Ok(()),
60                        Some(Message::Shutdown) => {
61                            info!("MQTT connection shutting down");
62                            break;
63                        }
64                        Some(req) => self.handle_request(req).await?,
65                    }
66                }
67            }
68        }
69
70        Ok(())
71    }
72
73    pub fn handle(&self, prefix: String) -> Handle {
74        Handle {
75            prefix,
76            tx: self.tx.clone(),
77        }
78    }
79
80    async fn handle_event(&mut self, event: Event) -> crate::Result<()> {
81        use rumqttc::Incoming;
82
83        #[allow(clippy::single_match)]
84        match event {
85            Event::Incoming(Incoming::Publish(Publish { topic, payload, .. })) => {
86                debug!(%topic, ?payload, "publish");
87                self.handle_data(topic, payload).await?;
88            }
89            // event => debug!(?event),
90            _ => {}
91        }
92
93        Ok(())
94    }
95
96    #[tracing::instrument(level = "debug", skip(self), fields(subscriptions = ?self.subscriptions.keys()))]
97    async fn handle_data(&mut self, topic: String, bytes: Bytes) -> crate::Result<()> {
98        let mut targets = vec![];
99
100        // Remove subscriptions whose channels are closed, adding matching channels to the `targets` vec.
101        self.subscriptions.retain(|filter, channels| {
102            if matches_topic(&topic, filter) {
103                channels.retain(|channel| {
104                    if channel.is_closed() {
105                        warn!(?channel, "closed");
106                        false
107                    } else {
108                        targets.push(channel.clone());
109                        true
110                    }
111                });
112                !channels.is_empty()
113            } else {
114                true
115            }
116        });
117
118        for target in targets {
119            if target
120                .send(Payload {
121                    topic: topic.clone(),
122                    bytes: bytes.clone(),
123                })
124                .await
125                .is_err()
126            {
127                // These will be removed above next time a matching payload is removed
128            }
129        }
130        Ok(())
131    }
132
133    async fn handle_request(&mut self, request: Message) -> crate::Result<()> {
134        debug!(?request);
135        match request {
136            Message::Publish(Publish {
137                topic,
138                payload,
139                qos,
140                retain,
141                ..
142            }) => {
143                self.client
144                    .publish_bytes(topic, qos, retain, payload)
145                    .await?
146            }
147            Message::Subscribe(Subscribe { filters, .. }, channel) => {
148                for filter in &filters {
149                    let channel = channel.clone();
150
151                    // NOTE: Curently allows multiple components to watch the same topic filter, but if there is no need
152                    // for this, it might make more sense to have it _replace_ the channel, so that old (stale)
153                    // components automatically finish running.
154                    match self.subscriptions.get_mut(&filter.path) {
155                        Some(channels) => channels.push(channel),
156                        None => {
157                            self.subscriptions
158                                .insert(filter.path.clone(), vec![channel]);
159                        }
160                    }
161                }
162
163                self.client.subscribe_many(filters).await?
164            }
165            Message::Shutdown => panic!("Handled by the caller"),
166        }
167        Ok(())
168    }
169}
170
171#[derive(Debug, Clone)]
172pub struct Handle {
173    prefix: String,
174    tx: Sender<Message>,
175}
176
177// IDEA: make subscribe+publish _generic_ over the payload type, as long as it implements a Payload trait we define,
178// which allows them to perform the serialization/deserialization to Bytes. For most domain types, the trait would be
179// implemented to use serde_json but for Bytes and Vec<u8> it would just return itself.
180// The return values may need to be crate::Result<Receiver<Option<T>> or crate::Result<Receiver<crate::Result<T>>>.
181impl Handle {
182    pub async fn subscribe(&self) -> crate::Result<Receiver<Payload>> {
183        let (tx_bytes, rx) = mpsc::channel(8);
184
185        let msg = Message::Subscribe(
186            Subscribe::new(&self.prefix, rumqttc::QoS::AtLeastOnce),
187            tx_bytes,
188        );
189        self.tx
190            .send(msg)
191            .await
192            .map_err(|_| crate::Error::SendError)?;
193        Ok(rx)
194    }
195
196    /// subscribe_under is a convenience method for subscribing to a topic underneath our topic prefix
197    pub async fn subscribe_under<S: Into<String>>(
198        &self,
199        topic: S,
200    ) -> crate::Result<Receiver<Payload>> {
201        self.scoped(topic).subscribe().await
202    }
203
204    pub async fn publish<B: Into<Bytes>>(&self, payload: B) -> crate::Result<()> {
205        let msg = Message::Publish(Publish::new(
206            &self.prefix,
207            rumqttc::QoS::AtLeastOnce,
208            payload.into(),
209        ));
210        self.tx
211            .send(msg)
212            .await
213            .map_err(|_| crate::Error::SendError)?;
214        Ok(())
215    }
216
217    /// publish_under is a convenience method for publishing to a topic underneath our topic prefix
218    pub async fn publish_under<S: Into<String>, B: Into<Bytes>>(
219        &self,
220        topic: S,
221        payload: B,
222    ) -> crate::Result<()> {
223        self.scoped(topic).publish(payload).await
224    }
225
226    pub async fn shutdown(self) -> crate::Result<()> {
227        self.tx
228            .send(Message::Shutdown)
229            .await
230            .map_err(|_| crate::Error::SendError)
231    }
232}
233
/// Types that can produce a copy of themselves nested under an additional `prefix`.
pub(crate) trait Scopable {
    /// Returns a new value scoped under `prefix`, leaving `self` unchanged.
    fn scoped<S>(&self, prefix: S) -> Self
    where
        S: Into<String>;
}

238impl Scopable for Handle {
239    fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
240        Self {
241            prefix: format!("{}/{}", self.prefix, prefix.into()),
242            ..self.clone()
243        }
244    }
245}
246
247impl From<Payload> for Bytes {
248    fn from(payload: Payload) -> Self {
249        payload.bytes
250    }
251}
252
253impl std::ops::Deref for Payload {
254    type Target = Bytes;
255
256    fn deref(&self) -> &Self::Target {
257        &self.bytes
258    }
259}