1use std::collections::HashMap;
2
3use bytes::Bytes;
4use rumqttc::{
5 mqttbytes::matches as matches_topic, AsyncClient, Event, EventLoop, MqttOptions, Publish,
6 Subscribe,
7};
8use tokio::{
9 select,
10 sync::mpsc::{self, channel, Receiver, Sender},
11};
12use tracing::{debug, info, warn};
13
14#[derive(Debug)]
15pub struct Payload {
16 pub bytes: Bytes,
17 pub topic: String,
18}
19
20#[derive(Debug, Clone)]
21pub enum Message {
22 Subscribe(Subscribe, Sender<Payload>),
23 Publish(Publish),
24 Shutdown,
25}
26
27pub(crate) async fn new(options: MqttOptions) -> Connection {
28 let (client, event_loop) = AsyncClient::new(options, 32);
29
30 let (tx, rx) = channel(32);
31 Connection {
32 client,
33 event_loop,
34 subscriptions: HashMap::new(),
35 tx,
36 rx,
37 }
38}
39
/// Owns the MQTT client and event loop and routes messages between the broker
/// and any number of [`Handle`]s.
///
/// Field order is deliberate to keep: it also determines drop order.
pub(crate) struct Connection {
    /// Subscription filter -> senders that receive messages matching it.
    subscriptions: HashMap<String, Vec<Sender<Payload>>>,
    /// Sending half of the request queue; cloned into every `Handle`.
    tx: Sender<Message>,
    /// Receiving half of the request queue, drained by `run`.
    rx: Receiver<Message>,
    /// rumqttc client used to enqueue publishes and subscriptions.
    client: AsyncClient,
    /// rumqttc event loop; polled by `run` to drive network I/O.
    event_loop: EventLoop,
}
49
50impl Connection {
51 pub async fn run(&mut self) -> crate::Result<()> {
52 loop {
53 select! {
54 event = self.event_loop.poll() => {
55 self.handle_event(event?).await?
56 }
57 request = self.rx.recv() => {
58 match request {
59 None => return Ok(()),
60 Some(Message::Shutdown) => {
61 info!("MQTT connection shutting down");
62 break;
63 }
64 Some(req) => self.handle_request(req).await?,
65 }
66 }
67 }
68 }
69
70 Ok(())
71 }
72
73 pub fn handle(&self, prefix: String) -> Handle {
74 Handle {
75 prefix,
76 tx: self.tx.clone(),
77 }
78 }
79
80 async fn handle_event(&mut self, event: Event) -> crate::Result<()> {
81 use rumqttc::Incoming;
82
83 #[allow(clippy::single_match)]
84 match event {
85 Event::Incoming(Incoming::Publish(Publish { topic, payload, .. })) => {
86 debug!(%topic, ?payload, "publish");
87 self.handle_data(topic, payload).await?;
88 }
89 _ => {}
91 }
92
93 Ok(())
94 }
95
96 #[tracing::instrument(level = "debug", skip(self), fields(subscriptions = ?self.subscriptions.keys()))]
97 async fn handle_data(&mut self, topic: String, bytes: Bytes) -> crate::Result<()> {
98 let mut targets = vec![];
99
100 self.subscriptions.retain(|filter, channels| {
102 if matches_topic(&topic, filter) {
103 channels.retain(|channel| {
104 if channel.is_closed() {
105 warn!(?channel, "closed");
106 false
107 } else {
108 targets.push(channel.clone());
109 true
110 }
111 });
112 !channels.is_empty()
113 } else {
114 true
115 }
116 });
117
118 for target in targets {
119 if target
120 .send(Payload {
121 topic: topic.clone(),
122 bytes: bytes.clone(),
123 })
124 .await
125 .is_err()
126 {
127 }
129 }
130 Ok(())
131 }
132
133 async fn handle_request(&mut self, request: Message) -> crate::Result<()> {
134 debug!(?request);
135 match request {
136 Message::Publish(Publish {
137 topic,
138 payload,
139 qos,
140 retain,
141 ..
142 }) => {
143 self.client
144 .publish_bytes(topic, qos, retain, payload)
145 .await?
146 }
147 Message::Subscribe(Subscribe { filters, .. }, channel) => {
148 for filter in &filters {
149 let channel = channel.clone();
150
151 match self.subscriptions.get_mut(&filter.path) {
155 Some(channels) => channels.push(channel),
156 None => {
157 self.subscriptions
158 .insert(filter.path.clone(), vec![channel]);
159 }
160 }
161 }
162
163 self.client.subscribe_many(filters).await?
164 }
165 Message::Shutdown => panic!("Handled by the caller"),
166 }
167 Ok(())
168 }
169}
170
/// Cloneable front-end to a [`Connection`], bound to a topic prefix.
#[derive(Debug, Clone)]
pub struct Handle {
    /// Topic this handle publishes to and subscribes to; extended by `scoped`.
    prefix: String,
    /// Request queue into the owning `Connection`.
    tx: Sender<Message>,
}
176
177impl Handle {
182 pub async fn subscribe(&self) -> crate::Result<Receiver<Payload>> {
183 let (tx_bytes, rx) = mpsc::channel(8);
184
185 let msg = Message::Subscribe(
186 Subscribe::new(&self.prefix, rumqttc::QoS::AtLeastOnce),
187 tx_bytes,
188 );
189 self.tx
190 .send(msg)
191 .await
192 .map_err(|_| crate::Error::SendError)?;
193 Ok(rx)
194 }
195
196 pub async fn subscribe_under<S: Into<String>>(
198 &self,
199 topic: S,
200 ) -> crate::Result<Receiver<Payload>> {
201 self.scoped(topic).subscribe().await
202 }
203
204 pub async fn publish<B: Into<Bytes>>(&self, payload: B) -> crate::Result<()> {
205 let msg = Message::Publish(Publish::new(
206 &self.prefix,
207 rumqttc::QoS::AtLeastOnce,
208 payload.into(),
209 ));
210 self.tx
211 .send(msg)
212 .await
213 .map_err(|_| crate::Error::SendError)?;
214 Ok(())
215 }
216
217 pub async fn publish_under<S: Into<String>, B: Into<Bytes>>(
219 &self,
220 topic: S,
221 payload: B,
222 ) -> crate::Result<()> {
223 self.scoped(topic).publish(payload).await
224 }
225
226 pub async fn shutdown(self) -> crate::Result<()> {
227 self.tx
228 .send(Message::Shutdown)
229 .await
230 .map_err(|_| crate::Error::SendError)
231 }
232}
233
/// Values that can derive a child of themselves nested under a sub-path.
pub(crate) trait Scopable {
    /// Returns a new value of the same type narrowed by `prefix`.
    /// Implementors define how the segments are joined.
    fn scoped<S>(&self, prefix: S) -> Self
    where
        S: Into<String>;
}
237
238impl Scopable for Handle {
239 fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
240 Self {
241 prefix: format!("{}/{}", self.prefix, prefix.into()),
242 ..self.clone()
243 }
244 }
245}
246
247impl From<Payload> for Bytes {
248 fn from(payload: Payload) -> Self {
249 payload.bytes
250 }
251}
252
253impl std::ops::Deref for Payload {
254 type Target = Bytes;
255
256 fn deref(&self) -> &Self::Target {
257 &self.bytes
258 }
259}