1use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder};
6use arkflow_core::{Error, MessageBatch};
7
8use async_trait::async_trait;
9use flume::{Receiver, Sender};
10use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tokio::sync::{broadcast, Mutex};
14use tracing::error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MqttInputConfig {
19 pub host: String,
21 pub port: u16,
23 pub client_id: String,
25 pub username: Option<String>,
27 pub password: Option<String>,
29 pub topics: Vec<String>,
31 pub qos: Option<u8>,
33 pub clean_session: Option<bool>,
35 pub keep_alive: Option<u64>,
37}
38
39pub struct MqttInput {
41 config: MqttInputConfig,
42 client: Arc<Mutex<Option<AsyncClient>>>,
43 sender: Arc<Sender<MqttMsg>>,
44 receiver: Arc<Receiver<MqttMsg>>,
45 close_tx: broadcast::Sender<()>,
46}
47
48enum MqttMsg {
49 Publish(Publish),
50 Err(Error),
51}
52
53impl MqttInput {
54 pub fn new(config: MqttInputConfig) -> Result<Self, Error> {
56 let (sender, receiver) = flume::bounded::<MqttMsg>(1000);
57 let (close_tx, _) = broadcast::channel(1);
58 Ok(Self {
59 config: config.clone(),
60 client: Arc::new(Mutex::new(None)),
61 sender: Arc::new(sender),
62 receiver: Arc::new(receiver),
63 close_tx,
64 })
65 }
66}
67
68#[async_trait]
69impl Input for MqttInput {
70 async fn connect(&self) -> Result<(), Error> {
71 let mut mqtt_options =
73 MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
74 mqtt_options.set_manual_acks(true);
75 if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
77 mqtt_options.set_credentials(username, password);
78 }
79
80 if let Some(keep_alive) = self.config.keep_alive {
82 mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
83 }
84
85 if let Some(clean_session) = self.config.clean_session {
87 mqtt_options.set_clean_session(clean_session);
88 }
89
90 let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
92 let qos_level = match self.config.qos {
94 Some(0) => QoS::AtMostOnce,
95 Some(1) => QoS::AtLeastOnce,
96 Some(2) => QoS::ExactlyOnce,
97 _ => QoS::AtLeastOnce, };
99
100 for topic in &self.config.topics {
101 client.subscribe(topic, qos_level).await.map_err(|e| {
102 Error::Connection(format!(
103 "Unable to subscribe to MQTT topics {}: {}",
104 topic, e
105 ))
106 })?;
107 }
108
109 let client_arc = self.client.clone();
110 let mut client_guard = client_arc.lock().await;
111 *client_guard = Some(client);
112
113 let sender_arc = self.sender.clone();
114 let mut rx = self.close_tx.subscribe();
115 tokio::spawn(async move {
116 loop {
117 tokio::select! {
118 result = eventloop.poll() => {
119 match result {
120 Ok(event) => {
121 if let Event::Incoming(Packet::Publish(publish)) = event {
122 match sender_arc.send_async(MqttMsg::Publish(publish)).await {
124 Ok(_) => {}
125 Err(e) => {
126 error!("{}",e)
127 }
128 };
129 }
130 }
131 Err(e) => {
132 error!("MQTT event loop error: {}", e);
134 match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
135 Ok(_) => {}
136 Err(e) => {
137 error!("{}",e)
138 }
139 };
140 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
141 }
142 }
143 }
144 _ = rx.recv() => {
145 break;
146 }
147 }
148 }
149 });
150
151 Ok(())
152 }
153
154 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
155 {
156 let client_arc = self.client.clone();
157 if client_arc.lock().await.is_none() {
158 return Err(Error::Disconnection);
159 }
160 }
161
162 let mut close_rx = self.close_tx.subscribe();
163 tokio::select! {
164 result = self.receiver.recv_async() =>{
165 match result {
166 Ok(msg) => {
167 match msg{
168 MqttMsg::Publish(publish) => {
169 let payload = publish.payload.to_vec();
170 let msg = MessageBatch::new_binary(vec![payload]);
171 Ok((msg, Arc::new(MqttAck {
172 client: self.client.clone(),
173 publish,
174 })))
175 },
176 MqttMsg::Err(e) => {
177 Err(e)
178 }
179 }
180 }
181 Err(_) => {
182 Err(Error::EOF)
183 }
184 }
185 },
186 _ = close_rx.recv()=>{
187 Err(Error::EOF)
188 }
189 }
190 }
191
192 async fn close(&self) -> Result<(), Error> {
193 let _ = self.close_tx.send(());
195
196 let client_arc = self.client.clone();
198 let client_guard = client_arc.lock().await;
199 if let Some(client) = &*client_guard {
200 let _ = client.disconnect().await;
202 }
203
204 Ok(())
205 }
206}
207
208struct MqttAck {
209 client: Arc<Mutex<Option<AsyncClient>>>,
210 publish: Publish,
211}
212#[async_trait]
213impl Ack for MqttAck {
214 async fn ack(&self) {
215 let mutex_guard = self.client.lock().await;
216 if let Some(client) = &*mutex_guard {
217 if let Err(e) = client.ack(&self.publish).await {
218 error!("{}", e);
219 }
220 }
221 }
222}
223
224pub(crate) struct MqttInputBuilder;
225impl InputBuilder for MqttInputBuilder {
226 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
227 if config.is_none() {
228 return Err(Error::Config(
229 "MQTT input configuration is missing".to_string(),
230 ));
231 }
232
233 let config: MqttInputConfig = serde_json::from_value(config.clone().unwrap())?;
234 Ok(Arc::new(MqttInput::new(config)?))
235 }
236}
237
238pub fn init() {
239 register_input_builder("mqtt", Arc::new(MqttInputBuilder));
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for a local broker subscribed to `test/topic`; only the optional
    /// fields vary between tests.
    fn local_config(
        credentials: Option<(&str, &str)>,
        qos: Option<u8>,
        clean_session: Option<bool>,
        keep_alive: Option<u64>,
    ) -> MqttInputConfig {
        let (username, password) = match credentials {
            Some((user, pass)) => (Some(user.to_string()), Some(pass.to_string())),
            None => (None, None),
        };
        MqttInputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test-client".to_string(),
            username,
            password,
            topics: vec!["test/topic".to_string()],
            qos,
            clean_session,
            keep_alive,
        }
    }

    /// Config with every optional field left unset.
    fn bare_config() -> MqttInputConfig {
        local_config(None, None, None, None)
    }

    /// A client whose event loop is dropped immediately; stands in for `connect`.
    fn detached_client() -> AsyncClient {
        AsyncClient::new(MqttOptions::new("test-client", "localhost", 1883), 10).0
    }

    #[tokio::test]
    async fn test_mqtt_input_new() {
        let built = MqttInput::new(local_config(
            Some(("user", "pass")),
            Some(1),
            Some(true),
            Some(60),
        ));
        assert!(built.is_ok());
        let input = built.unwrap();
        let cfg = &input.config;

        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 1883);
        assert_eq!(cfg.client_id, "test-client");
        assert_eq!(cfg.username, Some("user".to_string()));
        assert_eq!(cfg.password, Some("pass".to_string()));
        assert_eq!(cfg.topics, vec!["test/topic".to_string()]);
        assert_eq!(cfg.qos, Some(1));
        assert_eq!(cfg.clean_session, Some(true));
        assert_eq!(cfg.keep_alive, Some(60));
    }

    #[tokio::test]
    async fn test_mqtt_input_read_not_connected() {
        let input = MqttInput::new(bare_config()).unwrap();

        let outcome = input.read().await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );
    }

    #[tokio::test]
    async fn test_mqtt_input_close() {
        let input = MqttInput::new(bare_config()).unwrap();

        let outcome = input.close().await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_message_processing() {
        let input = MqttInput::new(bare_config()).unwrap();

        let publish = Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic: "test/topic".to_string(),
            pkid: 1,
            payload: b"test message".to_vec().into(),
        };
        input
            .sender
            .send_async(MqttMsg::Publish(publish))
            .await
            .unwrap();

        input.client.lock().await.replace(detached_client());

        let outcome = input.read().await;
        assert!(outcome.is_ok());
        let (batch, ack) = outcome.unwrap();

        let rows = batch.as_string().unwrap();
        assert_eq!(rows, vec!["test message"]);

        ack.ack().await;

        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_error_handling() {
        let input = MqttInput::new(bare_config()).unwrap();

        input.client.lock().await.replace(detached_client());

        input
            .sender
            .send_async(MqttMsg::Err(Error::Disconnection))
            .await
            .unwrap();

        let outcome = input.read().await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );

        assert!(input.close().await.is_ok());
    }
}