general_mq/mqtt/
connection.rs

1use std::{
2    collections::HashMap,
3    error::Error as StdError,
4    str::FromStr,
5    sync::{Arc, Mutex},
6    time::Duration,
7};
8
9use async_trait::async_trait;
10use regex::Regex;
11use rumqttc::{
12    AsyncClient as RumqttConnection, ClientError, Event as RumqttEvent,
13    MqttOptions as RumqttOption, NetworkOptions, Packet, Publish, TlsConfiguration, Transport,
14};
15use tokio::{
16    task::{self, JoinHandle},
17    time,
18};
19
20use super::uri::{MQTTScheme, MQTTUri};
21use crate::{
22    ID_SIZE,
23    connection::{EventHandler, GmqConnection, Status},
24    randomstring,
25};
26
27/// Manages a MQTT connection.
28#[derive(Clone)]
29pub struct MqttConnection {
30    /// Options of the connection.
31    opts: InnerOptions,
32    /// Connection status.
33    status: Arc<Mutex<Status>>,
34    /// Hold the connection instance.
35    conn: Arc<Mutex<Option<RumqttConnection>>>,
36    /// Event handlers.
37    handlers: Arc<Mutex<HashMap<String, Arc<dyn EventHandler>>>>,
38    /// Publish packet handlers. The key is **the queue name**.
39    ///
40    /// Because MQTT is connection-driven, the receiver [`crate::MqttQueue`] queues must register a
41    /// handler to receive Publish packets.
42    packet_handlers: Arc<Mutex<HashMap<String, Arc<dyn PacketHandler>>>>,
43    /// The event loop to manage and monitor the connection instance.
44    ev_loop: Arc<Mutex<Option<JoinHandle<()>>>>,
45}
46
/// User-facing options for creating an [`MqttConnection`].
pub struct MqttConnectionOptions {
    /// Broker URI in `mqtt|mqtts://username:password@host:port` form.
    ///
    /// Default is `mqtt://localhost`.
    pub uri: String,
    /// How long a connection attempt may take, in milliseconds.
    ///
    /// Default or zero value is `3000`.
    pub connect_timeout_millis: u64,
    /// Delay in milliseconds between a disconnection and the next reconnection attempt.
    ///
    /// Default or zero value is `1000`.
    pub reconnect_millis: u64,
    /// Client identifier; `None` lets the connection generate a random one.
    pub client_id: Option<String>,
    /// Whether to request a clean session from the broker.
    ///
    /// **Note**: this is not stable.
    pub clean_session: bool,
}
68
69/// Packet handler definitions.
70pub(super) trait PacketHandler: Send + Sync {
71    /// For **Publish** packets.
72    fn on_publish(&self, packet: Publish);
73}
74
75/// The validated options for management.
76#[derive(Clone)]
77struct InnerOptions {
78    /// The formatted URI resource.
79    uri: MQTTUri,
80    /// Connection timeout in milliseconds.
81    connect_timeout_millis: u64,
82    /// Time in milliseconds from disconnection to reconnection.
83    reconnect_millis: u64,
84    /// Client identifier.
85    client_id: String,
86    /// Clean session flag.
87    clean_session: bool,
88}
89
/// Default connect timeout in milliseconds.
const DEF_CONN_TIMEOUT_MS: u64 = 3000;
/// Default reconnect time in milliseconds.
const DEF_RECONN_TIME_MS: u64 = 1000;
/// The accepted pattern of the client identifier: 1 to 23 ASCII letters, digits or `-`.
const CLIENT_ID_PATTERN: &str = "^[0-9A-Za-z-]{1,23}$";
96
97impl MqttConnection {
98    /// Create a connection instance.
99    pub fn new(opts: MqttConnectionOptions) -> Result<MqttConnection, String> {
100        let uri = MQTTUri::from_str(opts.uri.as_str())?;
101
102        Ok(MqttConnection {
103            opts: InnerOptions {
104                uri,
105                connect_timeout_millis: match opts.connect_timeout_millis {
106                    0 => DEF_CONN_TIMEOUT_MS,
107                    _ => opts.connect_timeout_millis,
108                },
109                reconnect_millis: match opts.reconnect_millis {
110                    0 => DEF_RECONN_TIME_MS,
111                    _ => opts.reconnect_millis,
112                },
113                client_id: match opts.client_id {
114                    None => format!("general-mq-{}", randomstring(12)),
115                    Some(client_id) => {
116                        let re = Regex::new(CLIENT_ID_PATTERN).unwrap();
117                        if !re.is_match(client_id.as_str()) {
118                            return Err(format!("client_id is not match {}", CLIENT_ID_PATTERN));
119                        }
120                        client_id
121                    }
122                },
123                clean_session: opts.clean_session,
124            },
125            status: Arc::new(Mutex::new(Status::Closed)),
126            conn: Arc::new(Mutex::new(None)),
127            handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn EventHandler>>::new())),
128            packet_handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn PacketHandler>>::new())),
129            ev_loop: Arc::new(Mutex::new(None)),
130        })
131    }
132
133    /// To add a packet handler for [`crate::MqttQueue`]. The `name` is **the queue name**.
134    pub(super) fn add_packet_handler(&mut self, name: &str, handler: Arc<dyn PacketHandler>) {
135        self.packet_handlers
136            .lock()
137            .unwrap()
138            .insert(name.to_string(), handler);
139    }
140
141    /// To remove a packet handler. The `name` is **the queue name**.
142    pub(super) fn remove_packet_handler(&mut self, name: &str) {
143        self.packet_handlers.lock().unwrap().remove(name);
144    }
145
146    /// To get the raw MQTT connection instance for topic operations such as subscribe or publish.
147    pub(super) fn get_raw_connection(&self) -> Option<RumqttConnection> {
148        match self.conn.lock().unwrap().as_ref() {
149            None => None,
150            Some(conn) => Some(conn.clone()),
151        }
152    }
153}
154
155#[async_trait]
156impl GmqConnection for MqttConnection {
157    fn status(&self) -> Status {
158        *self.status.lock().unwrap()
159    }
160
161    fn add_handler(&mut self, handler: Arc<dyn EventHandler>) -> String {
162        let id = randomstring(ID_SIZE);
163        self.handlers.lock().unwrap().insert(id.clone(), handler);
164        id
165    }
166
167    fn remove_handler(&mut self, id: &str) {
168        self.handlers.lock().unwrap().remove(id);
169    }
170
171    fn connect(&mut self) -> Result<(), Box<dyn StdError>> {
172        {
173            let mut task_handle_mutex = self.ev_loop.lock().unwrap();
174            if (*task_handle_mutex).is_some() {
175                return Ok(());
176            }
177            *self.status.lock().unwrap() = Status::Connecting;
178            *task_handle_mutex = Some(create_event_loop(self));
179        }
180        Ok(())
181    }
182
183    async fn close(&mut self) -> Result<(), Box<dyn StdError + Send + Sync>> {
184        match { self.ev_loop.lock().unwrap().take() } {
185            None => return Ok(()),
186            Some(handle) => handle.abort(),
187        }
188        {
189            *self.status.lock().unwrap() = Status::Closing;
190        }
191
192        let conn = { self.conn.lock().unwrap().take() };
193        let mut result: Result<(), ClientError> = Ok(());
194        if let Some(conn) = conn {
195            result = conn.disconnect().await;
196        }
197
198        {
199            *self.status.lock().unwrap() = Status::Closed;
200        }
201        let handlers = { (*self.handlers.lock().unwrap()).clone() };
202        for (id, handler) in handlers {
203            let conn = Arc::new(self.clone());
204            task::spawn(async move {
205                handler.on_status(id.clone(), conn, Status::Closed).await;
206            });
207        }
208
209        result?;
210        Ok(())
211    }
212}
213
214impl Default for MqttConnectionOptions {
215    fn default() -> Self {
216        MqttConnectionOptions {
217            uri: "mqtt://localhost".to_string(),
218            connect_timeout_millis: DEF_CONN_TIMEOUT_MS,
219            reconnect_millis: DEF_RECONN_TIME_MS,
220            client_id: None,
221            clean_session: true,
222        }
223    }
224}
225
226/// To create an event loop runtime task.
227fn create_event_loop(conn: &MqttConnection) -> JoinHandle<()> {
228    let this = Arc::new(conn.clone());
229    task::spawn(async move {
230        loop {
231            match this.status() {
232                Status::Closing | Status::Closed => task::yield_now().await,
233                Status::Connecting | Status::Connected => {
234                    let mut opts = RumqttOption::new(
235                        this.opts.client_id.as_str(),
236                        this.opts.uri.host.as_str(),
237                        this.opts.uri.port,
238                    );
239                    opts.set_clean_session(this.opts.clean_session)
240                        .set_credentials(
241                            this.opts.uri.username.as_str(),
242                            this.opts.uri.password.as_str(),
243                        );
244                    if this.opts.uri.scheme == MQTTScheme::MQTTS {
245                        opts.set_transport(Transport::Tls(TlsConfiguration::default()));
246                    }
247
248                    let mut to_disconnected = false;
249                    let (client, mut event_loop) = RumqttConnection::new(opts, 10);
250                    let mut net_opts = NetworkOptions::new();
251                    net_opts.set_connection_timeout(this.opts.connect_timeout_millis);
252                    event_loop.set_network_options(net_opts);
253                    loop {
254                        match event_loop.poll().await {
255                            Err(_) => {
256                                if this.status() == Status::Connected {
257                                    to_disconnected = true;
258                                }
259                                break;
260                            }
261                            Ok(event) => {
262                                let packet = match event {
263                                    RumqttEvent::Incoming(packet) => packet,
264                                    _ => continue,
265                                };
266                                match packet {
267                                    Packet::Publish(packet) => {
268                                        if this.status() != Status::Connected {
269                                            continue;
270                                        }
271                                        let handler = {
272                                            let topic = packet.topic.as_str();
273                                            match this.packet_handlers.lock().unwrap().get(topic) {
274                                                None => continue,
275                                                Some(handler) => handler.clone(),
276                                            }
277                                        };
278                                        handler.on_publish(packet);
279                                    }
280                                    Packet::ConnAck(_) => {
281                                        let mut to_connected = false;
282                                        {
283                                            let mut status_mutex = this.status.lock().unwrap();
284                                            let status = *status_mutex;
285                                            if status == Status::Closing || status == Status::Closed
286                                            {
287                                                break;
288                                            } else if status != Status::Connected {
289                                                *this.conn.lock().unwrap() = Some(client.clone());
290                                                *status_mutex = Status::Connected;
291                                                to_connected = true;
292                                            }
293                                        }
294
295                                        if to_connected {
296                                            let handlers =
297                                                { (*this.handlers.lock().unwrap()).clone() };
298                                            for (id, handler) in handlers {
299                                                let conn = this.clone();
300                                                task::spawn(async move {
301                                                    handler
302                                                        .on_status(
303                                                            id.clone(),
304                                                            conn,
305                                                            Status::Connected,
306                                                        )
307                                                        .await;
308                                                });
309                                            }
310                                        }
311                                    }
312                                    _ => {}
313                                }
314                            }
315                        }
316                    }
317
318                    {
319                        let mut status_mutex = this.status.lock().unwrap();
320                        if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
321                            continue;
322                        }
323                        let _ = this.conn.lock().unwrap().take();
324                        *status_mutex = Status::Disconnected;
325                    }
326
327                    if to_disconnected {
328                        let handlers = { (*this.handlers.lock().unwrap()).clone() };
329                        for (id, handler) in handlers {
330                            let conn = this.clone();
331                            task::spawn(async move {
332                                handler
333                                    .on_status(id.clone(), conn, Status::Disconnected)
334                                    .await;
335                            });
336                        }
337                    }
338                    time::sleep(Duration::from_millis(this.opts.reconnect_millis)).await;
339                    {
340                        let mut status_mutex = this.status.lock().unwrap();
341                        if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
342                            continue;
343                        }
344                        *status_mutex = Status::Connecting;
345                    }
346                    if to_disconnected {
347                        let handlers = { (*this.handlers.lock().unwrap()).clone() };
348                        for (id, handler) in handlers {
349                            let conn = this.clone();
350                            task::spawn(async move {
351                                handler
352                                    .on_status(id.clone(), conn, Status::Connecting)
353                                    .await;
354                            });
355                        }
356                    }
357                }
358                Status::Disconnected => {
359                    *this.status.lock().unwrap() = Status::Connecting;
360                }
361            }
362        }
363    })
364}