1#![allow(dead_code)]
2mod client;
3mod conn;
4mod manager;
5use anyhow::{Result, anyhow};
6use bytes::Bytes;
7pub use client::*;
8pub(crate) use conn::*;
9pub use manager::*;
10pub use rumqttc::v5::mqttbytes::QoS;
11pub use rumqttc::v5::mqttbytes::v5::{ConnectProperties, Publish as IncomeMessage};
12pub use rumqttc::v5::{AsyncClient, MqttOptions as Config};
13use serde::Serializer;
14use serde::de::Deserializer;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::fmt::Display;
19use std::time::Duration;
20use tokio::sync::{mpsc, watch};
21use tokio::time;
22pub struct InitTopics(HashMap<String, QoS>);
24
25impl InitTopics {
26 pub fn new() -> Self {
27 InitTopics(HashMap::new())
28 }
29 pub fn add<T: AsRef<str> + Sync + Send>(&mut self, topic: T, qos: QoS) -> Result<()> {
30 let topic_ref = topic.as_ref();
31 if !self.0.contains_key(topic_ref) {
32 self.0.insert(topic_ref.to_string(), qos);
33 }
34 Ok(())
35 }
36 pub fn get_topics(&self) -> HashMap<String, QoS> {
37 self.0.clone()
38 }
39 pub fn remove_topic<T: AsRef<str>>(&mut self, topic: T) {
40 let topic_ref = topic.as_ref();
41 self.0.remove(topic_ref);
42 }
43}
44
45pub type OnEventCallback = Box<dyn Fn(MqttEvent) + Send + Sync + 'static>;
47pub type OnMessageCallback = Box<dyn Fn(IncomeMessage) + Send + Sync + 'static>;
49
/// Connection state published by the manager over a `watch` channel.
///
/// [`start_with_cfg`] creates the channel with [`State::Pending`] as the initial value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Initial state, before any connection outcome is known.
    Pending,
    Connected,
    Disconnected,
    Closed,
    /// Failure, carrying a human-readable reason.
    Error(String),
}

/// Renders the state as a lowercase label, or `Error: <reason>` for [`State::Error`].
impl Display for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            State::Pending => "pending",
            State::Connected => "connected",
            State::Disconnected => "disconnected",
            State::Closed => "closed",
            State::Error(reason) => return write!(f, "Error: {}", reason),
        };
        f.write_str(label)
    }
}
71
/// Connection lifecycle event, the argument type of [`OnEventCallback`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum MqttEvent {
    Connected,
    Disconnected,
    Closed,
    /// Failure, carrying a human-readable reason.
    Error(String),
}

/// Renders the event as `connected`, `disconnected`, `closed` or `Error: <reason>`.
///
/// This replaces an inherent `to_string` method that shadowed `ToString`. Existing
/// `event.to_string()` calls keep working through the blanket `ToString` impl, and events
/// can now also be used with `{}` in format strings.
impl Display for MqttEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MqttEvent::Connected => f.write_str("connected"),
            MqttEvent::Disconnected => f.write_str("disconnected"),
            MqttEvent::Closed => f.write_str("closed"),
            MqttEvent::Error(s) => write!(f, "Error: {}", s),
        }
    }
}
90
91enum MqttEventData {
92 Error(String),
93 Connected,
94 Disconnected,
95 IncomeMsg(IncomeMessage),
96}
97
98fn qos_to_u8(qos: &QoS) -> u8 {
99 match qos {
100 QoS::AtMostOnce => 0,
101 QoS::AtLeastOnce => 1,
102 QoS::ExactlyOnce => 2,
103 }
104}
105
106#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
107pub struct MqttPubCmd {
108 pub topic: String,
109 #[serde(serialize_with = "serialize_qos", deserialize_with = "deserialize_qos")]
110 pub qos: QoS,
111 pub retain: bool,
112 pub last_will: Option<bool>,
113 pub data: Value,
114}
115
116impl Default for MqttPubCmd {
117 fn default() -> Self {
118 MqttPubCmd {
119 topic: "".to_string(),
120 qos: QoS::AtMostOnce,
121 retain: false,
122 last_will: None,
123 data: Value::Null,
124 }
125 }
126}
127
128#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
129pub struct MqttSubCmd {
130 pub topic: String,
131 #[serde(serialize_with = "serialize_qos", deserialize_with = "deserialize_qos")]
132 pub qos: QoS,
133}
134
135fn serialize_qos<S>(qos: &QoS, serializer: S) -> Result<S::Ok, S::Error>
136where
137 S: Serializer,
138{
139 let num = qos_to_u8(qos);
140 serializer.serialize_u8(num)
141}
142
143fn deserialize_qos<'de, D>(deserializer: D) -> Result<QoS, D::Error>
144where
145 D: Deserializer<'de>,
146{
147 let v = u8::deserialize(deserializer)?;
148 let q = rumqttc::v5::mqttbytes::qos(v).unwrap_or(QoS::AtMostOnce);
149 Ok(q)
150}
151
152#[derive(Debug, PartialEq)]
153pub enum MqttMsg {
154 Pub(MqttPubCmd),
155 Sub(MqttSubCmd),
156 UnSub(String),
157 Closed,
158}
159
160pub async fn start_with_cfg(
161 cfg: Config,
162 on_msg: OnMessageCallback,
163 on_event: OnEventCallback,
164 topics: InitTopics,
165 timeout: Duration,
166) -> Result<Client> {
167 let (state_tx, state_rx) = watch::channel(State::Pending);
169 let (cmd_sender, cmd_receiver) = mpsc::channel::<MqttMsg>(128);
170 let client = Client::new(state_rx, cmd_sender);
171
172 let topics = topics.get_topics();
174 let mut man = Manager::new(cfg, state_tx, on_msg, on_event, topics);
175 tokio::spawn(async move { man.run(cmd_receiver).await });
176
177 let mut timeout = timeout.as_secs();
178 if timeout <= 0 {
179 timeout = 10;
180 }
181 let mut re_count = 0;
182 while !client.connected() {
183 if re_count > timeout {
184 log::error!("connect timeout {} s", timeout);
185 client.close().await?;
186 time::sleep(Duration::from_millis(100)).await;
187 return Err(anyhow!("连接超时,请检查网络是否正常.."));
188 }
189 if let Some(s) = client.state_is_error() {
190 return Err(anyhow!("连接失败: {}", s));
191 }
192 log::debug!("wait connect..");
193 time::sleep(Duration::from_secs(1)).await;
194 re_count += 1;
195 }
196
197 Ok(client)
198}
199
/// Expands the `{skuid}` and `{uuid}` placeholders in a topic template.
///
/// `{skuid}` is substituted first, then `{uuid}` is substituted in the result.
pub fn to_topic(topic: &str, skuid: &str, uuid: &str) -> String {
    let with_sku = topic.replace("{skuid}", skuid);
    with_sku.replace("{uuid}", uuid)
}
203
/// Compares one topic level with one literal filter level, ignoring case.
///
/// The matchers below have always accepted a whole topic that differs from the filter only
/// in case (they compared `to_lowercase()` strings). Previously, though, the per-level
/// comparison used for wildcard filters was case-sensitive, so `Foo/bar` matched `foo/bar`
/// but not `foo/+`. Every level is now compared this same way.
/// NOTE(review): MQTT brokers treat topics as case-sensitive; this leniency is kept for
/// backward compatibility with existing callers.
fn topic_level_eq(topic_level: &str, filter_level: &str) -> bool {
    topic_level == filter_level
        || topic_level
            .chars()
            .flat_map(char::to_lowercase)
            .eq(filter_level.chars().flat_map(char::to_lowercase))
}

/// Returns `true` if `topic` matches `topic_filter`, where a `+` level matches exactly one
/// topic level.
///
/// `#` has no special meaning here (see [`topic_match_all`]). Both strings must have the same
/// number of `/`-separated levels. Literal levels are compared case-insensitively.
pub fn topic_match_one(topic: &str, topic_filter: &str) -> bool {
    let mut topic_levels = topic.split('/');
    let mut filter_levels = topic_filter.split('/');
    loop {
        match (topic_levels.next(), filter_levels.next()) {
            (None, None) => return true,
            (Some(t), Some(f)) if f == "+" || topic_level_eq(t, f) => {}
            _ => return false,
        }
    }
}

/// Like [`topic_match_one`], but on success returns the topic levels captured by each `+`
/// in the filter, in order. Returns `None` if the topic does not match.
pub fn topic_get_match_one(topic: &str, topic_filter: &str) -> Option<Vec<String>> {
    let mut topic_levels = topic.split('/');
    let mut filter_levels = topic_filter.split('/');
    let mut captured = Vec::new();
    loop {
        match (topic_levels.next(), filter_levels.next()) {
            (None, None) => return Some(captured),
            (Some(t), Some("+")) => captured.push(t.to_string()),
            (Some(t), Some(f)) if topic_level_eq(t, f) => {}
            _ => return None,
        }
    }
}

/// Returns `true` if `topic` matches `filter`, supporting both `+` (one level) and `#`
/// (the remaining levels, including none, so `a/#` matches `a`).
///
/// Matching stops at the first `#` level, and anything after it is ignored. Literal levels
/// are compared case-insensitively.
pub fn topic_match_all(topic: &str, filter: &str) -> bool {
    let mut topic_levels = topic.split('/');
    for filter_level in filter.split('/') {
        if filter_level == "#" {
            return true;
        }
        match topic_levels.next() {
            Some(t) if filter_level == "+" || topic_level_eq(t, filter_level) => {}
            _ => return false,
        }
    }
    // Without `#`, the topic must not have extra levels beyond the filter.
    topic_levels.next().is_none()
}
282
283pub fn bytes_to_string(b: &Bytes) -> Option<String> {
285 std::str::from_utf8(b).ok().map(|s| s.to_string())
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // These tests previously only printed their results and could never fail; they now assert.

    #[test]
    fn test_topic_is_match() {
        assert!(topic_match_one("test/topic/1/21", "test/topic/+/+"));
        assert!(!topic_match_one("test/topic/1", "test/topic/+/+"));
        assert!(!topic_match_one("test/other/1/21", "test/topic/+/+"));
    }

    #[test]
    fn test_topic_get_match() {
        assert_eq!(
            topic_get_match_one("test/topic/1/21", "test/topic/+/+"),
            Some(vec!["1".to_string(), "21".to_string()])
        );
        assert_eq!(topic_get_match_one("test/topic/1", "test/topic/+/+"), None);
    }

    #[test]
    fn test_topic_match_all() {
        assert!(!topic_match_all("test/topic/1/21/2232", "test/topic/11/#"));
        assert!(topic_match_all("test/topic/1/21/2232", "test/topic/+/#"));
        assert!(topic_match_all("test/topic", "test/topic/#"));
        assert!(!topic_match_all("test", "test/topic/#"));
    }

    #[test]
    fn test_serialize() {
        let d = MqttPubCmd {
            topic: "test".to_string(),
            qos: QoS::AtMostOnce,
            retain: false,
            last_will: None,
            data: json!("test"),
        };

        let s = serde_json::to_string(&d).unwrap();
        assert_eq!(
            s,
            r#"{"topic":"test","qos":0,"retain":false,"last_will":null,"data":"test"}"#
        );
    }

    #[test]
    fn test_deserialize() {
        let json_str = r#"{"topic":"test","qos":1,"retain":false,"last_will":null,"data":"test"}"#;
        let s: MqttPubCmd = serde_json::from_str(json_str).unwrap();
        assert_eq!(
            s,
            MqttPubCmd {
                topic: "test".to_string(),
                qos: QoS::AtLeastOnce,
                retain: false,
                last_will: None,
                data: json!("test"),
            }
        );
    }

    #[test]
    fn test_qos_round_trip() {
        let cmd = MqttSubCmd {
            topic: "a/+".to_string(),
            qos: QoS::ExactlyOnce,
        };
        let s = serde_json::to_string(&cmd).unwrap();
        let back: MqttSubCmd = serde_json::from_str(&s).unwrap();
        assert_eq!(back, cmd);
    }

    #[test]
    fn test_deserialize_invalid_qos_falls_back() {
        let s: MqttSubCmd = serde_json::from_str(r#"{"topic":"t","qos":9}"#).unwrap();
        assert_eq!(s.qos, QoS::AtMostOnce);
    }

    #[test]
    fn test_event_to_string() {
        assert_eq!(MqttEvent::Connected.to_string(), "connected");
        assert_eq!(MqttEvent::Error("x".to_string()).to_string(), "Error: x");
    }
}