1use std::{
2 collections::HashMap,
3 error::Error as StdError,
4 str::FromStr,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use regex::Regex;
11use rumqttc::{
12 AsyncClient as RumqttConnection, ClientError, Event as RumqttEvent,
13 MqttOptions as RumqttOption, NetworkOptions, Packet, Publish, TlsConfiguration, Transport,
14};
15use tokio::{
16 task::{self, JoinHandle},
17 time,
18};
19
20use super::uri::{MQTTScheme, MQTTUri};
21use crate::{
22 ID_SIZE,
23 connection::{EventHandler, GmqConnection, Status},
24 randomstring,
25};
26
27#[derive(Clone)]
29pub struct MqttConnection {
30 opts: InnerOptions,
32 status: Arc<Mutex<Status>>,
34 conn: Arc<Mutex<Option<RumqttConnection>>>,
36 handlers: Arc<Mutex<HashMap<String, Arc<dyn EventHandler>>>>,
38 packet_handlers: Arc<Mutex<HashMap<String, Arc<dyn PacketHandler>>>>,
43 ev_loop: Arc<Mutex<Option<JoinHandle<()>>>>,
45}
46
/// Options used to build an `MqttConnection`.
pub struct MqttConnectionOptions {
    /// Broker URI such as `mqtt://localhost` (the default). It is parsed by
    /// `MQTTUri::from_str`. The `mqtts` scheme switches the transport to TLS, and the
    /// username/password are taken from the URI.
    pub uri: String,
    /// Connection timeout in milliseconds. Defaults to `3000`; `0` also selects the default.
    pub connect_timeout_millis: u64,
    /// Delay in milliseconds before reconnecting after a session ends. Defaults to `1000`;
    /// `0` also selects the default.
    pub reconnect_millis: u64,
    /// MQTT client identifier.
    ///
    /// `None` generates `general-mq-` plus a random suffix. An explicit value must match
    /// `^[0-9A-Za-z-]{1,23}$`, otherwise `MqttConnection::new` fails.
    pub client_id: Option<String>,
    /// Clean-session flag handed to rumqttc. Defaults to `true`.
    pub clean_session: bool,
}
68
69pub(super) trait PacketHandler: Send + Sync {
71 fn on_publish(&self, packet: Publish);
73}
74
75#[derive(Clone)]
77struct InnerOptions {
78 uri: MQTTUri,
80 connect_timeout_millis: u64,
82 reconnect_millis: u64,
84 client_id: String,
86 clean_session: bool,
88}
89
/// Default connection timeout in milliseconds, used when `connect_timeout_millis` is `0`.
const DEF_CONN_TIMEOUT_MS: u64 = 3000;
/// Default reconnection delay in milliseconds, used when `reconnect_millis` is `0`.
const DEF_RECONN_TIME_MS: u64 = 1000;
/// Accepted format of a caller-supplied client identifier: 1 to 23 ASCII letters, digits or
/// hyphens. The 23-character bound matches the client-id length that MQTT 3.1.1 servers are
/// required to accept.
const CLIENT_ID_PATTERN: &str = "^[0-9A-Za-z-]{1,23}$";
96
97impl MqttConnection {
98 pub fn new(opts: MqttConnectionOptions) -> Result<MqttConnection, String> {
100 let uri = MQTTUri::from_str(opts.uri.as_str())?;
101
102 Ok(MqttConnection {
103 opts: InnerOptions {
104 uri,
105 connect_timeout_millis: match opts.connect_timeout_millis {
106 0 => DEF_CONN_TIMEOUT_MS,
107 _ => opts.connect_timeout_millis,
108 },
109 reconnect_millis: match opts.reconnect_millis {
110 0 => DEF_RECONN_TIME_MS,
111 _ => opts.reconnect_millis,
112 },
113 client_id: match opts.client_id {
114 None => format!("general-mq-{}", randomstring(12)),
115 Some(client_id) => {
116 let re = Regex::new(CLIENT_ID_PATTERN).unwrap();
117 if !re.is_match(client_id.as_str()) {
118 return Err(format!("client_id is not match {}", CLIENT_ID_PATTERN));
119 }
120 client_id
121 }
122 },
123 clean_session: opts.clean_session,
124 },
125 status: Arc::new(Mutex::new(Status::Closed)),
126 conn: Arc::new(Mutex::new(None)),
127 handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn EventHandler>>::new())),
128 packet_handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn PacketHandler>>::new())),
129 ev_loop: Arc::new(Mutex::new(None)),
130 })
131 }
132
133 pub(super) fn add_packet_handler(&mut self, name: &str, handler: Arc<dyn PacketHandler>) {
135 self.packet_handlers
136 .lock()
137 .unwrap()
138 .insert(name.to_string(), handler);
139 }
140
141 pub(super) fn remove_packet_handler(&mut self, name: &str) {
143 self.packet_handlers.lock().unwrap().remove(name);
144 }
145
146 pub(super) fn get_raw_connection(&self) -> Option<RumqttConnection> {
148 match self.conn.lock().unwrap().as_ref() {
149 None => None,
150 Some(conn) => Some(conn.clone()),
151 }
152 }
153}
154
155#[async_trait]
156impl GmqConnection for MqttConnection {
157 fn status(&self) -> Status {
158 *self.status.lock().unwrap()
159 }
160
161 fn add_handler(&mut self, handler: Arc<dyn EventHandler>) -> String {
162 let id = randomstring(ID_SIZE);
163 self.handlers.lock().unwrap().insert(id.clone(), handler);
164 id
165 }
166
167 fn remove_handler(&mut self, id: &str) {
168 self.handlers.lock().unwrap().remove(id);
169 }
170
171 fn connect(&mut self) -> Result<(), Box<dyn StdError>> {
172 {
173 let mut task_handle_mutex = self.ev_loop.lock().unwrap();
174 if (*task_handle_mutex).is_some() {
175 return Ok(());
176 }
177 *self.status.lock().unwrap() = Status::Connecting;
178 *task_handle_mutex = Some(create_event_loop(self));
179 }
180 Ok(())
181 }
182
183 async fn close(&mut self) -> Result<(), Box<dyn StdError + Send + Sync>> {
184 match { self.ev_loop.lock().unwrap().take() } {
185 None => return Ok(()),
186 Some(handle) => handle.abort(),
187 }
188 {
189 *self.status.lock().unwrap() = Status::Closing;
190 }
191
192 let conn = { self.conn.lock().unwrap().take() };
193 let mut result: Result<(), ClientError> = Ok(());
194 if let Some(conn) = conn {
195 result = conn.disconnect().await;
196 }
197
198 {
199 *self.status.lock().unwrap() = Status::Closed;
200 }
201 let handlers = { (*self.handlers.lock().unwrap()).clone() };
202 for (id, handler) in handlers {
203 let conn = Arc::new(self.clone());
204 task::spawn(async move {
205 handler.on_status(id.clone(), conn, Status::Closed).await;
206 });
207 }
208
209 result?;
210 Ok(())
211 }
212}
213
214impl Default for MqttConnectionOptions {
215 fn default() -> Self {
216 MqttConnectionOptions {
217 uri: "mqtt://localhost".to_string(),
218 connect_timeout_millis: DEF_CONN_TIMEOUT_MS,
219 reconnect_millis: DEF_RECONN_TIME_MS,
220 client_id: None,
221 clean_session: true,
222 }
223 }
224}
225
226fn create_event_loop(conn: &MqttConnection) -> JoinHandle<()> {
228 let this = Arc::new(conn.clone());
229 task::spawn(async move {
230 loop {
231 match this.status() {
232 Status::Closing | Status::Closed => task::yield_now().await,
233 Status::Connecting | Status::Connected => {
234 let mut opts = RumqttOption::new(
235 this.opts.client_id.as_str(),
236 this.opts.uri.host.as_str(),
237 this.opts.uri.port,
238 );
239 opts.set_clean_session(this.opts.clean_session)
240 .set_credentials(
241 this.opts.uri.username.as_str(),
242 this.opts.uri.password.as_str(),
243 );
244 if this.opts.uri.scheme == MQTTScheme::MQTTS {
245 opts.set_transport(Transport::Tls(TlsConfiguration::default()));
246 }
247
248 let mut to_disconnected = false;
249 let (client, mut event_loop) = RumqttConnection::new(opts, 10);
250 let mut net_opts = NetworkOptions::new();
251 net_opts.set_connection_timeout(this.opts.connect_timeout_millis);
252 event_loop.set_network_options(net_opts);
253 loop {
254 match event_loop.poll().await {
255 Err(_) => {
256 if this.status() == Status::Connected {
257 to_disconnected = true;
258 }
259 break;
260 }
261 Ok(event) => {
262 let packet = match event {
263 RumqttEvent::Incoming(packet) => packet,
264 _ => continue,
265 };
266 match packet {
267 Packet::Publish(packet) => {
268 if this.status() != Status::Connected {
269 continue;
270 }
271 let handler = {
272 let topic = packet.topic.as_str();
273 match this.packet_handlers.lock().unwrap().get(topic) {
274 None => continue,
275 Some(handler) => handler.clone(),
276 }
277 };
278 handler.on_publish(packet);
279 }
280 Packet::ConnAck(_) => {
281 let mut to_connected = false;
282 {
283 let mut status_mutex = this.status.lock().unwrap();
284 let status = *status_mutex;
285 if status == Status::Closing || status == Status::Closed
286 {
287 break;
288 } else if status != Status::Connected {
289 *this.conn.lock().unwrap() = Some(client.clone());
290 *status_mutex = Status::Connected;
291 to_connected = true;
292 }
293 }
294
295 if to_connected {
296 let handlers =
297 { (*this.handlers.lock().unwrap()).clone() };
298 for (id, handler) in handlers {
299 let conn = this.clone();
300 task::spawn(async move {
301 handler
302 .on_status(
303 id.clone(),
304 conn,
305 Status::Connected,
306 )
307 .await;
308 });
309 }
310 }
311 }
312 _ => {}
313 }
314 }
315 }
316 }
317
318 {
319 let mut status_mutex = this.status.lock().unwrap();
320 if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
321 continue;
322 }
323 let _ = this.conn.lock().unwrap().take();
324 *status_mutex = Status::Disconnected;
325 }
326
327 if to_disconnected {
328 let handlers = { (*this.handlers.lock().unwrap()).clone() };
329 for (id, handler) in handlers {
330 let conn = this.clone();
331 task::spawn(async move {
332 handler
333 .on_status(id.clone(), conn, Status::Disconnected)
334 .await;
335 });
336 }
337 }
338 time::sleep(Duration::from_millis(this.opts.reconnect_millis)).await;
339 {
340 let mut status_mutex = this.status.lock().unwrap();
341 if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
342 continue;
343 }
344 *status_mutex = Status::Connecting;
345 }
346 if to_disconnected {
347 let handlers = { (*this.handlers.lock().unwrap()).clone() };
348 for (id, handler) in handlers {
349 let conn = this.clone();
350 task::spawn(async move {
351 handler
352 .on_status(id.clone(), conn, Status::Connecting)
353 .await;
354 });
355 }
356 }
357 }
358 Status::Disconnected => {
359 *this.status.lock().unwrap() = Status::Connecting;
360 }
361 }
362 }
363 })
364}