rmqttc/
lib.rs

1#![allow(dead_code)]
2mod client;
3mod conn;
4mod manager;
5use anyhow::{Result, anyhow};
6use bytes::Bytes;
7pub use client::*;
8pub(crate) use conn::*;
9pub use manager::*;
10pub use rumqttc::v5::mqttbytes::QoS;
11pub use rumqttc::v5::mqttbytes::v5::{ConnectProperties, Publish as IncomeMessage};
12pub use rumqttc::v5::{AsyncClient, MqttOptions as Config};
13use serde::Serializer;
14use serde::de::Deserializer;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::fmt::Display;
19use std::time::Duration;
20use tokio::sync::{mpsc, watch};
21use tokio::time;
22//初始化路由
23pub struct InitTopics(HashMap<String, QoS>);
24
25impl InitTopics {
26    pub fn new() -> Self {
27        InitTopics(HashMap::new())
28    }
29    pub fn add<T: AsRef<str> + Sync + Send>(&mut self, topic: T, qos: QoS) -> Result<()> {
30        let topic_ref = topic.as_ref();
31        if !self.0.contains_key(topic_ref) {
32            self.0.insert(topic_ref.to_string(), qos);
33        }
34        Ok(())
35    }
36    pub fn get_topics(&self) -> HashMap<String, QoS> {
37        self.0.clone()
38    }
39    pub fn remove_topic<T: AsRef<str>>(&mut self, topic: T) {
40        let topic_ref = topic.as_ref();
41        self.0.remove(topic_ref);
42    }
43}
44
45//事件回调
46pub type OnEventCallback = Box<dyn Fn(MqttEvent) + Send + Sync + 'static>;
47//消息回调
48pub type OnMessageCallback = Box<dyn Fn(IncomeMessage) + Send + Sync + 'static>;
49
/// Connection state of the MQTT client, shared through a `watch` channel.
///
/// `start_with_cfg` initialises that channel with [`State::Pending`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum State {
    Pending,
    Connected,
    Disconnected,
    Closed,
    /// Error state carrying a description of the failure.
    Error(String),
}

impl Display for State {
    /// Renders the lowercase state name, or `Error: <reason>` for [`State::Error`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Pending => "pending",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
            Self::Closed => "closed",
            Self::Error(reason) => return write!(f, "Error: {}", reason),
        };
        f.write_str(label)
    }
}
71
/// Connection lifecycle event passed to the [`OnEventCallback`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum MqttEvent {
    Connected,
    Disconnected,
    Closed,
    /// Error event carrying a description of the failure.
    Error(String),
}

/// Renders the lowercase event name, or `Error: <reason>` for [`MqttEvent::Error`].
///
/// `to_string()` is provided by the standard blanket `ToString` impl, so callers keep getting
/// the same text (e.g. `"connected"`); implementing `Display` instead of an inherent
/// `to_string` also makes the event usable directly in `format!`/logging macros.
impl Display for MqttEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MqttEvent::Connected => f.write_str("connected"),
            MqttEvent::Disconnected => f.write_str("disconnected"),
            MqttEvent::Closed => f.write_str("closed"),
            MqttEvent::Error(s) => write!(f, "Error: {}", s),
        }
    }
}
90
91enum MqttEventData {
92    Error(String),
93    Connected,
94    Disconnected,
95    IncomeMsg(IncomeMessage),
96}
97
98fn qos_to_u8(qos: &QoS) -> u8 {
99    match qos {
100        QoS::AtMostOnce => 0,
101        QoS::AtLeastOnce => 1,
102        QoS::ExactlyOnce => 2,
103    }
104}
105
106#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
107pub struct MqttPubCmd {
108    pub topic: String,
109    #[serde(serialize_with = "serialize_qos", deserialize_with = "deserialize_qos")]
110    pub qos: QoS,
111    pub retain: bool,
112    pub last_will: Option<bool>,
113    pub data: Value,
114}
115
116impl Default for MqttPubCmd {
117    fn default() -> Self {
118        MqttPubCmd {
119            topic: "".to_string(),
120            qos: QoS::AtMostOnce,
121            retain: false,
122            last_will: None,
123            data: Value::Null,
124        }
125    }
126}
127
128#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
129pub struct MqttSubCmd {
130    pub topic: String,
131    #[serde(serialize_with = "serialize_qos", deserialize_with = "deserialize_qos")]
132    pub qos: QoS,
133}
134
135fn serialize_qos<S>(qos: &QoS, serializer: S) -> Result<S::Ok, S::Error>
136where
137    S: Serializer,
138{
139    let num = qos_to_u8(qos);
140    serializer.serialize_u8(num)
141}
142
143fn deserialize_qos<'de, D>(deserializer: D) -> Result<QoS, D::Error>
144where
145    D: Deserializer<'de>,
146{
147    let v = u8::deserialize(deserializer)?;
148    let q = rumqttc::v5::mqttbytes::qos(v).unwrap_or(QoS::AtMostOnce);
149    Ok(q)
150}
151
152#[derive(Debug, PartialEq)]
153pub enum MqttMsg {
154    Pub(MqttPubCmd),
155    Sub(MqttSubCmd),
156    UnSub(String),
157    Closed,
158}
159
160pub async fn start_with_cfg(
161    cfg: Config,
162    on_msg: OnMessageCallback,
163    on_event: OnEventCallback,
164    topics: InitTopics,
165    timeout: Duration,
166) -> Result<Client> {
167    //init
168    let (state_tx, state_rx) = watch::channel(State::Pending);
169    let (cmd_sender, cmd_receiver) = mpsc::channel::<MqttMsg>(128);
170    let client = Client::new(state_rx, cmd_sender);
171
172    //mqtt
173    let topics = topics.get_topics();
174    let mut man = Manager::new(cfg, state_tx, on_msg, on_event, topics);
175    tokio::spawn(async move { man.run(cmd_receiver).await });
176
177    let mut timeout = timeout.as_secs();
178    if timeout <= 0 {
179        timeout = 10;
180    }
181    let mut re_count = 0;
182    while !client.connected() {
183        if re_count > timeout {
184            log::error!("connect timeout {} s", timeout);
185            client.close().await?;
186            time::sleep(Duration::from_millis(100)).await;
187            return Err(anyhow!("连接超时,请检查网络是否正常.."));
188        }
189        if let Some(s) = client.state_is_error() {
190            return Err(anyhow!("连接失败: {}", s));
191        }
192        log::debug!("wait connect..");
193        time::sleep(Duration::from_secs(1)).await;
194        re_count += 1;
195    }
196
197    Ok(client)
198}
199
/// Fills the `{skuid}` and `{uuid}` placeholders of a topic template.
///
/// `{skuid}` is substituted first, so a `{uuid}` placeholder that appears inside `skuid` is
/// replaced as well.
pub fn to_topic(topic: &str, skuid: &str, uuid: &str) -> String {
    let with_skuid = topic.replace("{skuid}", skuid);
    with_skuid.replace("{uuid}", uuid)
}
203
/// Checks whether `topic` matches `topic_filter`, where the filter may use single-level `+`
/// wildcards.
///
/// e.g. `/aam/sub/request/2928/10002` matches `/aam/sub/request/+/+`.
///
/// Both strings must have the same number of `/`-separated levels. Every non-`+` level must
/// match exactly; comparison is case-sensitive, as MQTT topic names are. `#` is compared
/// literally here — use `topic_match_all` for multi-level wildcards.
pub fn topic_match_one(topic: &str, topic_filter: &str) -> bool {
    let mut topic_levels = topic.split('/');
    let mut filter_levels = topic_filter.split('/');
    loop {
        match (topic_levels.next(), filter_levels.next()) {
            // Both ran out together: same depth and every level matched.
            (None, None) => return true,
            (Some(level), Some(pattern)) if pattern == "+" || pattern == level => {}
            // Depth mismatch or a literal level that differs.
            _ => return false,
        }
    }
}
223
/// Matches `topic` against a filter using single-level `+` wildcards and returns the topic
/// levels captured by each `+`, in order.
///
/// e.g. `/aam/sub/request/2928/10002` against `/aam/sub/request/+/+` yields
/// `Some(["2928", "10002"])`.
///
/// Returns `None` when the level counts differ or a literal level does not match
/// (case-sensitive comparison).
pub fn topic_get_match_one(topic: &str, topic_filter: &str) -> Option<Vec<String>> {
    let topic_levels: Vec<&str> = topic.split('/').collect();
    let filter_levels: Vec<&str> = topic_filter.split('/').collect();
    if filter_levels.len() != topic_levels.len() {
        return None;
    }

    let mut captured = Vec::new();
    for (&pattern, &level) in filter_levels.iter().zip(&topic_levels) {
        match pattern {
            "+" => captured.push(level.to_string()),
            literal if literal != level => return None,
            _ => {}
        }
    }
    Some(captured)
}
243
/// Checks whether `topic` matches an MQTT topic `filter` containing `+` and/or `#` wildcards.
///
/// e.g. `test/topic/1/21/2232` matches `test/topic/#`.
///
/// * `+` matches exactly one level.
/// * `#` matches the remaining levels, including none at all, so `test/topic/#` also
///   matches `test/topic`. It is only valid as the last level; a filter with `#` anywhere
///   else matches nothing.
/// * Every other level must match exactly (case-sensitive, as MQTT topic names are).
pub fn topic_match_all(topic: &str, filter: &str) -> bool {
    let mut topic_levels = topic.split('/');
    let mut filter_levels = filter.split('/');
    loop {
        match (filter_levels.next(), topic_levels.next()) {
            // Multi-level wildcard: matches whatever is left, provided it ends the filter.
            (Some("#"), _) => return filter_levels.next().is_none(),
            (Some(pattern), Some(level)) if pattern == "+" || pattern == level => {}
            // Both ran out together: same depth and every level matched.
            (None, None) => return true,
            // Depth mismatch or a literal level that differs.
            _ => return false,
        }
    }
}
282
283// bytes to string
284pub fn bytes_to_string(b: &Bytes) -> Option<String> {
285    std::str::from_utf8(b).ok().map(|s| s.to_string())
286}
287
#[cfg(test)]
mod test {
    //! Unit tests for the topic helpers and the command (de)serialization format.
    //! Previously these only printed results; they now assert them.
    use super::*;
    use serde_json::json;

    #[test]
    fn test_topic_is_match() {
        assert!(topic_match_one("test/topic/1/21", "test/topic/+/+"));
        assert!(!topic_match_one("test/topic/1", "test/topic/+/+"));
        assert!(!topic_match_one("test/other/1/21", "test/topic/+/+"));
    }

    #[test]
    fn test_topic_get_match() {
        assert_eq!(
            topic_get_match_one("test/topic/1/21", "test/topic/+/+"),
            Some(vec!["1".to_string(), "21".to_string()])
        );
        assert_eq!(topic_get_match_one("test/other/1/21", "test/topic/+/+"), None);
    }

    #[test]
    fn test_topic_match_all() {
        assert!(!topic_match_all("test/topic/1/21/2232", "test/topic/11/#"));
        assert!(topic_match_all("test/topic/1/21/2232", "test/topic/#"));
        assert!(topic_match_all("test/topic/1/21/2232", "test/+/1/#"));
    }

    #[test]
    fn test_to_topic() {
        assert_eq!(to_topic("/aam/{skuid}/{uuid}", "2928", "10002"), "/aam/2928/10002");
    }

    #[test]
    fn test_state_and_event_labels() {
        assert_eq!(State::Connected.to_string(), "connected");
        assert_eq!(State::Error("x".to_string()).to_string(), "Error: x");
        assert_eq!(MqttEvent::Closed.to_string(), "closed");
        assert_eq!(MqttEvent::Error("x".to_string()).to_string(), "Error: x");
    }

    #[test]
    fn test_serialize() {
        let d = MqttPubCmd {
            topic: "test".to_string(),
            qos: QoS::AtMostOnce,
            retain: false,
            last_will: None,
            data: json!("test"),
        };

        let s = serde_json::to_string(&d).unwrap();
        // QoS is written as its numeric level; fields keep declaration order.
        assert_eq!(
            s,
            r#"{"topic":"test","qos":0,"retain":false,"last_will":null,"data":"test"}"#
        );
    }

    #[test]
    fn test_deserialize() {
        let json_str = r#"{"topic":"test","qos":1,"retain":false,"last_will":null,"data":"test"}"#;
        let s: MqttPubCmd = serde_json::from_str(json_str).unwrap();
        assert_eq!(
            s,
            MqttPubCmd {
                topic: "test".to_string(),
                qos: QoS::AtLeastOnce,
                retain: false,
                last_will: None,
                data: json!("test"),
            }
        );
    }

    #[test]
    fn test_deserialize_unknown_qos_falls_back() {
        let json_str = r#"{"topic":"t","qos":7,"retain":true,"last_will":true,"data":null}"#;
        let s: MqttPubCmd = serde_json::from_str(json_str).unwrap();
        assert_eq!(s.qos, QoS::AtMostOnce);
        assert!(s.retain);
        assert_eq!(s.last_will, Some(true));
    }

    #[test]
    fn test_sub_cmd_round_trip() {
        let cmd = MqttSubCmd {
            topic: "a/b".to_string(),
            qos: QoS::ExactlyOnce,
        };
        let s = serde_json::to_string(&cmd).unwrap();
        assert_eq!(s, r#"{"topic":"a/b","qos":2}"#);
        let back: MqttSubCmd = serde_json::from_str(&s).unwrap();
        assert_eq!(back, cmd);
    }
}