1use crate::input::Ack;
6use crate::{input::Input, Error, MessageBatch};
7use async_trait::async_trait;
8use flume::{Receiver, Sender};
9use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use tokio::sync::{broadcast, Mutex};
13use tracing::error;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MqttInputConfig {
18 pub host: String,
20 pub port: u16,
22 pub client_id: String,
24 pub username: Option<String>,
26 pub password: Option<String>,
28 pub topics: Vec<String>,
30 pub qos: Option<u8>,
32 pub clean_session: Option<bool>,
34 pub keep_alive: Option<u64>,
36}
37
38pub struct MqttInput {
40 config: MqttInputConfig,
41 client: Arc<Mutex<Option<AsyncClient>>>,
42 sender: Arc<Sender<MqttMsg>>,
43 receiver: Arc<Receiver<MqttMsg>>,
44 close_tx: broadcast::Sender<()>,
45}
46enum MqttMsg {
47 Publish(Publish),
48 Err(Error),
49}
50impl MqttInput {
51 pub fn new(config: &MqttInputConfig) -> Result<Self, Error> {
53 let (sender, receiver) = flume::bounded::<MqttMsg>(1000);
54 let (close_tx, _) = broadcast::channel(1);
55 Ok(Self {
56 config: config.clone(),
57 client: Arc::new(Mutex::new(None)),
58 sender: Arc::new(sender),
59 receiver: Arc::new(receiver),
60 close_tx,
61 })
62 }
63}
64
65#[async_trait]
66impl Input for MqttInput {
67 async fn connect(&self) -> Result<(), Error> {
68 let mut mqtt_options =
70 MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
71 mqtt_options.set_manual_acks(true);
72 if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
74 mqtt_options.set_credentials(username, password);
75 }
76
77 if let Some(keep_alive) = self.config.keep_alive {
79 mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
80 }
81
82 if let Some(clean_session) = self.config.clean_session {
84 mqtt_options.set_clean_session(clean_session);
85 }
86
87 let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
89 let qos_level = match self.config.qos {
91 Some(0) => QoS::AtMostOnce,
92 Some(1) => QoS::AtLeastOnce,
93 Some(2) => QoS::ExactlyOnce,
94 _ => QoS::AtLeastOnce, };
96
97 for topic in &self.config.topics {
98 client
99 .subscribe(topic, qos_level)
100 .await
101 .map_err(|e| Error::Connection(format!("无法订阅MQTT主题 {}: {}", topic, e)))?;
102 }
103
104 let client_arc = self.client.clone();
106 let mut client_guard = client_arc.lock().await;
107 *client_guard = Some(client);
108
109 let sender_arc = self.sender.clone();
111 let mut rx = self.close_tx.subscribe();
112 tokio::spawn(async move {
113 loop {
114 tokio::select! {
115 result = eventloop.poll() => {
116 match result {
117 Ok(event) => {
118 if let Event::Incoming(Packet::Publish(publish)) = event {
119 match sender_arc.send_async(MqttMsg::Publish(publish)).await {
121 Ok(_) => {}
122 Err(e) => {
123 error!("{}",e)
124 }
125 };
126 }
127 }
128 Err(e) => {
129 error!("MQTT事件循环错误: {}", e);
131 match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
132 Ok(_) => {}
133 Err(e) => {
134 error!("{}",e)
135 }
136 };
137 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
138 }
139 }
140 }
141 _ = rx.recv() => {
142 break;
143 }
144 }
145 }
146 });
147
148 Ok(())
149 }
150
151 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
152 {
153 let client_arc = self.client.clone();
154 if client_arc.lock().await.is_none() {
155 return Err(Error::Disconnection);
156 }
157 }
158
159 let mut close_rx = self.close_tx.subscribe();
160 tokio::select! {
161 result = self.receiver.recv_async() =>{
162 match result {
163 Ok(msg) => {
164 match msg{
165 MqttMsg::Publish(publish) => {
166 let payload = publish.payload.to_vec();
167 let msg = MessageBatch::new_binary(vec![payload]);
168 Ok((msg, Arc::new(MqttAck {
169 client: self.client.clone(),
170 publish,
171 })))
172 },
173 MqttMsg::Err(e) => {
174 Err(e)
175 }
176 }
177 }
178 Err(_) => {
179 Err(Error::Done)
180 }
181 }
182 },
183 _ = close_rx.recv()=>{
184 Err(Error::Done)
185 }
186 }
187 }
188
189 async fn close(&self) -> Result<(), Error> {
190 let _ = self.close_tx.send(());
192
193 let client_arc = self.client.clone();
195 let client_guard = client_arc.lock().await;
196 if let Some(client) = &*client_guard {
197 let _ = client.disconnect().await;
199 }
200
201 Ok(())
202 }
203}
204
205struct MqttAck {
206 client: Arc<Mutex<Option<AsyncClient>>>,
207 publish: Publish,
208}
209#[async_trait]
210impl Ack for MqttAck {
211 async fn ack(&self) {
212 let mutex_guard = self.client.lock().await;
213 if let Some(client) = &*mutex_guard {
214 if let Err(e) = client.ack(&self.publish).await {
215 error!("{}", e);
216 }
217 }
218 }
219}