arkflow_plugin/input/
mqtt.rs

1//! MQTT input component
2//!
3//! Receive data from the MQTT broker
4
5use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder};
6use arkflow_core::{Error, MessageBatch};
7
8use async_trait::async_trait;
9use flume::{Receiver, Sender};
10use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tokio::sync::{broadcast, Mutex};
14use tracing::error;
15
/// MQTT input configuration
///
/// Deserialized from the JSON configuration handed to the input builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MqttInputConfig {
    /// MQTT broker address
    pub host: String,
    /// MQTT broker port
    pub port: u16,
    /// Client ID
    pub client_id: String,
    /// Username (optional); credentials are only applied when both the
    /// username and the password are set
    pub username: Option<String>,
    /// Password (optional); credentials are only applied when both the
    /// username and the password are set
    pub password: Option<String>,
    /// List of topics to subscribe to
    pub topics: Vec<String>,
    /// Quality of Service (0, 1, 2); a missing or any other value is treated
    /// as QoS 1 when connecting
    pub qos: Option<u8>,
    /// Whether to use clean session; when unset the MQTT client's own
    /// default is left untouched
    pub clean_session: Option<bool>,
    /// Keep alive interval (in seconds); when unset the MQTT client's own
    /// default is left untouched
    pub keep_alive: Option<u64>,
}
38
/// MQTT input component
///
/// Publishes received from the broker are pushed into an internal bounded
/// queue by a background task and handed out one at a time by `read`.
pub struct MqttInput {
    /// Configuration this input was created with
    config: MqttInputConfig,
    /// MQTT client handle; `None` while no client has been stored by `connect`
    client: Arc<Mutex<Option<AsyncClient>>>,
    /// Producer side of the internal message queue (fed by the event loop task)
    sender: Arc<Sender<MqttMsg>>,
    /// Consumer side of the internal message queue (drained by `read`)
    receiver: Arc<Receiver<MqttMsg>>,
    /// Broadcast channel used to signal shutdown to the event loop task and
    /// to pending `read` calls
    close_tx: broadcast::Sender<()>,
}
47
/// Item carried by the internal queue between the event loop task and `read`.
enum MqttMsg {
    /// A publish packet received from the broker
    Publish(Publish),
    /// An error that `read` should return to its caller
    Err(Error),
}
52
53impl MqttInput {
54    /// Create a new MQTT input component
55    pub fn new(config: MqttInputConfig) -> Result<Self, Error> {
56        let (sender, receiver) = flume::bounded::<MqttMsg>(1000);
57        let (close_tx, _) = broadcast::channel(1);
58        Ok(Self {
59            config: config.clone(),
60            client: Arc::new(Mutex::new(None)),
61            sender: Arc::new(sender),
62            receiver: Arc::new(receiver),
63            close_tx,
64        })
65    }
66}
67
68#[async_trait]
69impl Input for MqttInput {
70    async fn connect(&self) -> Result<(), Error> {
71        // Create MQTT options
72        let mut mqtt_options =
73            MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
74        mqtt_options.set_manual_acks(true);
75        // Set the authentication information
76        if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
77            mqtt_options.set_credentials(username, password);
78        }
79
80        // Set the keep-alive time
81        if let Some(keep_alive) = self.config.keep_alive {
82            mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
83        }
84
85        // Set up a clean session
86        if let Some(clean_session) = self.config.clean_session {
87            mqtt_options.set_clean_session(clean_session);
88        }
89
90        // Create an MQTT client
91        let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
92        // Subscribe to topics
93        let qos_level = match self.config.qos {
94            Some(0) => QoS::AtMostOnce,
95            Some(1) => QoS::AtLeastOnce,
96            Some(2) => QoS::ExactlyOnce,
97            _ => QoS::AtLeastOnce, // Default is QoS 1
98        };
99
100        for topic in &self.config.topics {
101            client.subscribe(topic, qos_level).await.map_err(|e| {
102                Error::Connection(format!(
103                    "Unable to subscribe to MQTT topics {}: {}",
104                    topic, e
105                ))
106            })?;
107        }
108
109        let client_arc = self.client.clone();
110        let mut client_guard = client_arc.lock().await;
111        *client_guard = Some(client);
112
113        let sender_arc = self.sender.clone();
114        let mut rx = self.close_tx.subscribe();
115        tokio::spawn(async move {
116            loop {
117                tokio::select! {
118                    result = eventloop.poll() => {
119                        match result {
120                            Ok(event) => {
121                                if let Event::Incoming(Packet::Publish(publish)) = event {
122                                    // Add messages to the queue
123                                    match sender_arc.send_async(MqttMsg::Publish(publish)).await {
124                                        Ok(_) => {}
125                                        Err(e) => {
126                                            error!("{}",e)
127                                        }
128                                    };
129                                }
130                            }
131                            Err(e) => {
132                               // Log the error and wait a short time before continuing
133                                error!("MQTT event loop error: {}", e);
134                                match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
135                                        Ok(_) => {}
136                                        Err(e) => {
137                                            error!("{}",e)
138                                        }
139                                };
140                                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
141                            }
142                        }
143                    }
144                    _ = rx.recv() => {
145                        break;
146                    }
147                }
148            }
149        });
150
151        Ok(())
152    }
153
154    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
155        {
156            let client_arc = self.client.clone();
157            if client_arc.lock().await.is_none() {
158                return Err(Error::Disconnection);
159            }
160        }
161
162        let mut close_rx = self.close_tx.subscribe();
163        tokio::select! {
164            result = self.receiver.recv_async() =>{
165                match result {
166                    Ok(msg) => {
167                        match msg{
168                            MqttMsg::Publish(publish) => {
169                                 let payload = publish.payload.to_vec();
170                            let msg = MessageBatch::new_binary(vec![payload]);
171                            Ok((msg, Arc::new(MqttAck {
172                                client: self.client.clone(),
173                                publish,
174                            })))
175                            },
176                            MqttMsg::Err(e) => {
177                                  Err(e)
178                            }
179                        }
180                    }
181                    Err(_) => {
182                        Err(Error::EOF)
183                    }
184                }
185            },
186            _ = close_rx.recv()=>{
187                Err(Error::EOF)
188            }
189        }
190    }
191
192    async fn close(&self) -> Result<(), Error> {
193        // Send a shutdown signal
194        let _ = self.close_tx.send(());
195
196        // Disconnect the MQTT connection
197        let client_arc = self.client.clone();
198        let client_guard = client_arc.lock().await;
199        if let Some(client) = &*client_guard {
200            // Try to disconnect, but don't wait for the result
201            let _ = client.disconnect().await;
202        }
203
204        Ok(())
205    }
206}
207
208struct MqttAck {
209    client: Arc<Mutex<Option<AsyncClient>>>,
210    publish: Publish,
211}
212#[async_trait]
213impl Ack for MqttAck {
214    async fn ack(&self) {
215        let mutex_guard = self.client.lock().await;
216        if let Some(client) = &*mutex_guard {
217            if let Err(e) = client.ack(&self.publish).await {
218                error!("{}", e);
219            }
220        }
221    }
222}
223
224pub(crate) struct MqttInputBuilder;
225impl InputBuilder for MqttInputBuilder {
226    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
227        if config.is_none() {
228            return Err(Error::Config(
229                "MQTT input configuration is missing".to_string(),
230            ));
231        }
232
233        let config: MqttInputConfig = serde_json::from_value(config.clone().unwrap())?;
234        Ok(Arc::new(MqttInput::new(config)?))
235    }
236}
237
238pub fn init() {
239    register_input_builder("mqtt", Arc::new(MqttInputBuilder));
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by the tests: local broker, no credentials and
    /// every optional setting left unset.
    fn anonymous_config() -> MqttInputConfig {
        MqttInputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test-client".to_string(),
            username: None,
            password: None,
            topics: vec!["test/topic".to_string()],
            qos: None,
            clean_session: None,
            keep_alive: None,
        }
    }

    /// Simulates the connected state by storing a client whose event loop is
    /// dropped immediately, so no network traffic takes place.
    async fn attach_offline_client(input: &MqttInput) {
        let client = AsyncClient::new(MqttOptions::new("test-client", "localhost", 1883), 10).0;
        input.client.lock().await.replace(client);
    }

    /// The constructor keeps every configured value unchanged.
    #[tokio::test]
    async fn test_mqtt_input_new() {
        let config = MqttInputConfig {
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            qos: Some(1),
            clean_session: Some(true),
            keep_alive: Some(60),
            ..anonymous_config()
        };

        let input = MqttInput::new(config);
        assert!(input.is_ok());
        let stored = input.unwrap().config;
        assert_eq!(stored.host, "localhost");
        assert_eq!(stored.port, 1883);
        assert_eq!(stored.client_id, "test-client");
        assert_eq!(stored.username, Some("user".to_string()));
        assert_eq!(stored.password, Some("pass".to_string()));
        assert_eq!(stored.topics, vec!["test/topic".to_string()]);
        assert_eq!(stored.qos, Some(1));
        assert_eq!(stored.clean_session, Some(true));
        assert_eq!(stored.keep_alive, Some(60));
    }

    /// Reading without a connection must report a disconnection.
    #[tokio::test]
    async fn test_mqtt_input_read_not_connected() {
        let input = MqttInput::new(anonymous_config()).unwrap();

        let outcome = input.read().await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );
    }

    /// Closing must succeed even if the input was never connected.
    #[tokio::test]
    async fn test_mqtt_input_close() {
        let input = MqttInput::new(anonymous_config()).unwrap();

        let outcome = input.close().await;
        assert!(outcome.is_ok());
    }

    /// A queued publish is returned by `read` with its payload intact and can
    /// be acknowledged.
    #[tokio::test]
    async fn test_mqtt_input_message_processing() {
        let input = MqttInput::new(anonymous_config()).unwrap();

        // Place a publish directly into the receive queue.
        let publish = Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic: "test/topic".to_string(),
            pkid: 1,
            payload: "test message".as_bytes().to_vec().into(),
        };
        input
            .sender
            .send_async(MqttMsg::Publish(publish))
            .await
            .unwrap();

        attach_offline_client(&input).await;

        // Read the message back and check its content.
        let outcome = input.read().await;
        assert!(outcome.is_ok());
        let (batch, ack) = outcome.unwrap();
        let content = batch.as_string().unwrap();
        assert_eq!(content, vec!["test message"]);

        // Acknowledging must not panic even though no broker is reachable.
        ack.ack().await;

        assert!(input.close().await.is_ok());
    }

    /// A queued error is handed to the caller of `read` unchanged.
    #[tokio::test]
    async fn test_mqtt_input_error_handling() {
        let input = MqttInput::new(anonymous_config()).unwrap();

        attach_offline_client(&input).await;

        // Place an error directly into the receive queue.
        input
            .sender
            .send_async(MqttMsg::Err(Error::Disconnection))
            .await
            .unwrap();

        let outcome = input.read().await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );

        assert!(input.close().await.is_ok());
    }
}