// arkflow/input/mqtt.rs

//! MQTT input component (MQTT输入组件).
//!
//! Receives data from an MQTT broker (从MQTT代理接收数据).

5use crate::input::Ack;
6use crate::{input::Input, Error, MessageBatch};
7use async_trait::async_trait;
8use flume::{Receiver, Sender};
9use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use tokio::sync::{broadcast, Mutex};
13use tracing::error;
14
15/// MQTT输入配置
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MqttInputConfig {
18    /// MQTT代理地址
19    pub host: String,
20    /// MQTT代理端口
21    pub port: u16,
22    /// 客户端ID
23    pub client_id: String,
24    /// 用户名(可选)
25    pub username: Option<String>,
26    /// 密码(可选)
27    pub password: Option<String>,
28    /// 订阅的主题列表
29    pub topics: Vec<String>,
30    /// 服务质量等级 (0, 1, 2)
31    pub qos: Option<u8>,
32    /// 是否清除会话
33    pub clean_session: Option<bool>,
34    /// 保持连接的时间间隔(秒)
35    pub keep_alive: Option<u64>,
36}
37
38/// MQTT输入组件
39pub struct MqttInput {
40    config: MqttInputConfig,
41    client: Arc<Mutex<Option<AsyncClient>>>,
42    sender: Arc<Sender<MqttMsg>>,
43    receiver: Arc<Receiver<MqttMsg>>,
44    close_tx: broadcast::Sender<()>,
45}
46enum MqttMsg {
47    Publish(Publish),
48    Err(Error),
49}
50impl MqttInput {
51    /// 创建一个新的MQTT输入组件
52    pub fn new(config: &MqttInputConfig) -> Result<Self, Error> {
53        let (sender, receiver) = flume::bounded::<MqttMsg>(1000);
54        let (close_tx, _) = broadcast::channel(1);
55        Ok(Self {
56            config: config.clone(),
57            client: Arc::new(Mutex::new(None)),
58            sender: Arc::new(sender),
59            receiver: Arc::new(receiver),
60            close_tx,
61        })
62    }
63}
64
65#[async_trait]
66impl Input for MqttInput {
67    async fn connect(&self) -> Result<(), Error> {
68        // 创建MQTT选项
69        let mut mqtt_options =
70            MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
71        mqtt_options.set_manual_acks(true);
72        // 设置认证信息
73        if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
74            mqtt_options.set_credentials(username, password);
75        }
76
77        // 设置保持连接时间
78        if let Some(keep_alive) = self.config.keep_alive {
79            mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
80        }
81
82        // 设置清除会话
83        if let Some(clean_session) = self.config.clean_session {
84            mqtt_options.set_clean_session(clean_session);
85        }
86
87        // 创建MQTT客户端
88        let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
89        // 订阅主题
90        let qos_level = match self.config.qos {
91            Some(0) => QoS::AtMostOnce,
92            Some(1) => QoS::AtLeastOnce,
93            Some(2) => QoS::ExactlyOnce,
94            _ => QoS::AtLeastOnce, // 默认为QoS 1
95        };
96
97        for topic in &self.config.topics {
98            client
99                .subscribe(topic, qos_level)
100                .await
101                .map_err(|e| Error::Connection(format!("无法订阅MQTT主题 {}: {}", topic, e)))?;
102        }
103
104        // 保存客户端
105        let client_arc = self.client.clone();
106        let mut client_guard = client_arc.lock().await;
107        *client_guard = Some(client);
108
109        // 启动事件循环处理线程
110        let sender_arc = self.sender.clone();
111        let mut rx = self.close_tx.subscribe();
112        tokio::spawn(async move {
113            loop {
114                tokio::select! {
115                    result = eventloop.poll() => {
116                        match result {
117                            Ok(event) => {
118                                if let Event::Incoming(Packet::Publish(publish)) = event {
119                                    // 将消息添加到队列
120                                    match sender_arc.send_async(MqttMsg::Publish(publish)).await {
121                                        Ok(_) => {}
122                                        Err(e) => {
123                                            error!("{}",e)
124                                        }
125                                    };
126                                }
127                            }
128                            Err(e) => {
129                               // 记录错误并尝试短暂等待后继续
130                                error!("MQTT事件循环错误: {}", e);
131                                match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
132                                        Ok(_) => {}
133                                        Err(e) => {
134                                            error!("{}",e)
135                                        }
136                                };
137                                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
138                            }
139                        }
140                    }
141                    _ = rx.recv() => {
142                        break;
143                    }
144                }
145            }
146        });
147
148        Ok(())
149    }
150
151    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
152        {
153            let client_arc = self.client.clone();
154            if client_arc.lock().await.is_none() {
155                return Err(Error::Disconnection);
156            }
157        }
158
159        let mut close_rx = self.close_tx.subscribe();
160        tokio::select! {
161            result = self.receiver.recv_async() =>{
162                match result {
163                    Ok(msg) => {
164                        match msg{
165                            MqttMsg::Publish(publish) => {
166                                 let payload = publish.payload.to_vec();
167                            let msg = MessageBatch::new_binary(vec![payload]);
168                            Ok((msg, Arc::new(MqttAck {
169                                client: self.client.clone(),
170                                publish,
171                            })))
172                            },
173                            MqttMsg::Err(e) => {
174                                  Err(e)
175                            }
176                        }
177                    }
178                    Err(_) => {
179                        Err(Error::Done)
180                    }
181                }
182            },
183            _ = close_rx.recv()=>{
184                Err(Error::Done)
185            }
186        }
187    }
188
189    async fn close(&self) -> Result<(), Error> {
190        // 发送关闭信号
191        let _ = self.close_tx.send(());
192
193        // 断开MQTT连接
194        let client_arc = self.client.clone();
195        let client_guard = client_arc.lock().await;
196        if let Some(client) = &*client_guard {
197            // 尝试断开连接,但不等待结果
198            let _ = client.disconnect().await;
199        }
200
201        Ok(())
202    }
203}
204
205struct MqttAck {
206    client: Arc<Mutex<Option<AsyncClient>>>,
207    publish: Publish,
208}
209#[async_trait]
210impl Ack for MqttAck {
211    async fn ack(&self) {
212        let mutex_guard = self.client.lock().await;
213        if let Some(client) = &*mutex_guard {
214            if let Err(e) = client.ack(&self.publish).await {
215                error!("{}", e);
216            }
217        }
218    }
219}