arkflow_plugin/output/
mqtt.rs

1//! MQTT output component
2//!
3//! Send the processed data to the MQTT broker
4
5use arkflow_core::output::{register_output_builder, Output, OutputBuilder};
6use arkflow_core::{Error, MessageBatch};
7use async_trait::async_trait;
8use rumqttc::{AsyncClient, ClientError, MqttOptions, QoS};
9use serde::{Deserialize, Serialize};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use tracing::info;
14
15/// MQTT output configuration
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MqttOutputConfig {
18    /// MQTT broker address
19    pub host: String,
20    /// MQTT broker port
21    pub port: u16,
22    /// Client ID
23    pub client_id: String,
24    /// Username (optional)
25    pub username: Option<String>,
26    /// Password (optional)
27    pub password: Option<String>,
28    /// Published topics
29    pub topic: String,
30    /// Quality of Service (0, 1, 2)
31    pub qos: Option<u8>,
32    /// Whether to use clean session
33    pub clean_session: Option<bool>,
34    /// Keep alive interval (seconds)
35    pub keep_alive: Option<u64>,
36    /// Whether to retain the message
37    pub retain: Option<bool>,
38}
39
40/// MQTT output component
41struct MqttOutput<T: MqttClient> {
42    config: MqttOutputConfig,
43    client: Arc<Mutex<Option<T>>>,
44    connected: AtomicBool,
45    eventloop_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
46}
47
48impl<T: MqttClient> MqttOutput<T> {
49    /// Create a new MQTT output component
50    pub fn new(config: MqttOutputConfig) -> Result<Self, Error> {
51        Ok(Self {
52            config,
53            client: Arc::new(Mutex::new(None)),
54            connected: AtomicBool::new(false),
55            eventloop_handle: Arc::new(Mutex::new(None)),
56        })
57    }
58}
59
60#[async_trait]
61impl<T: MqttClient> Output for MqttOutput<T> {
62    async fn connect(&self) -> Result<(), Error> {
63        // Create MQTT options
64        let mut mqtt_options =
65            MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
66
67        // Set the authentication information
68        if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
69            mqtt_options.set_credentials(username, password);
70        }
71
72        // Set the keep-alive time
73        if let Some(keep_alive) = self.config.keep_alive {
74            mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
75        }
76
77        // Set up a purge session
78        if let Some(clean_session) = self.config.clean_session {
79            mqtt_options.set_clean_session(clean_session);
80        }
81
82        // Create an MQTT client
83        let (client, mut eventloop) = T::create(mqtt_options, 10).await?;
84        // Save the client
85        let client_arc = self.client.clone();
86        let mut client_guard = client_arc.lock().await;
87        *client_guard = Some(client);
88
89        // Start an event loop processing thread (keep the connection active)
90        let eventloop_handle = tokio::spawn(async move {
91            while let Ok(_) = eventloop.poll().await {
92                // Just keep the event loop running and don't need to process the event
93            }
94        });
95
96        // Holds the event loop processing thread handle
97        let eventloop_handle_arc = self.eventloop_handle.clone();
98        let mut eventloop_handle_guard = eventloop_handle_arc.lock().await;
99        *eventloop_handle_guard = Some(eventloop_handle);
100
101        self.connected.store(true, Ordering::SeqCst);
102        Ok(())
103    }
104
105    async fn write(&self, msg: &MessageBatch) -> Result<(), Error> {
106        if !self.connected.load(Ordering::SeqCst) {
107            return Err(Error::Connection("The output is not connected".to_string()));
108        }
109
110        let client_arc = self.client.clone();
111        let client_guard = client_arc.lock().await;
112        let client = client_guard
113            .as_ref()
114            .ok_or_else(|| Error::Connection("The MQTT client is not initialized".to_string()))?;
115
116        // Get the message content
117        let payloads = match msg.as_string() {
118            Ok(v) => v.to_vec(),
119            Err(e) => {
120                return Err(e);
121            }
122        };
123
124        for payload in payloads {
125            info!(
126                "Send message: {}",
127                &String::from_utf8_lossy((&payload).as_ref())
128            );
129
130            // Determine the QoS level
131            let qos_level = match self.config.qos {
132                Some(0) => QoS::AtMostOnce,
133                Some(1) => QoS::AtLeastOnce,
134                Some(2) => QoS::ExactlyOnce,
135                _ => QoS::AtLeastOnce, // The default is QoS 1
136            };
137
138            // Decide whether to keep the message
139            let retain = self.config.retain.unwrap_or(false);
140
141            // Post a message
142            client
143                .publish(&self.config.topic, qos_level, retain, payload)
144                .await
145                .map_err(|e| Error::Process(format!("MQTT publishing failed: {}", e)))?;
146        }
147
148        Ok(())
149    }
150
151    async fn close(&self) -> Result<(), Error> {
152        // Stop the event loop processing thread
153        let mut eventloop_handle_guard = self.eventloop_handle.lock().await;
154        if let Some(handle) = eventloop_handle_guard.take() {
155            handle.abort();
156        }
157
158        // Disconnect the MQTT connection
159        let client_arc = self.client.clone();
160        let client_guard = client_arc.lock().await;
161        if let Some(client) = &*client_guard {
162            // Try to disconnect, but don't wait for the result
163            let _ = client.disconnect().await;
164        }
165
166        self.connected.store(false, Ordering::SeqCst);
167        Ok(())
168    }
169}
170
171pub(crate) struct MqttOutputBuilder;
172impl OutputBuilder for MqttOutputBuilder {
173    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Output>, Error> {
174        if config.is_none() {
175            return Err(Error::Config(
176                "HTTP output configuration is missing".to_string(),
177            ));
178        }
179        let config: MqttOutputConfig = serde_json::from_value(config.clone().unwrap())?;
180        Ok(Arc::new(MqttOutput::<AsyncClient>::new(config)?))
181    }
182}
183
184pub fn init() {
185    register_output_builder("mqtt", Arc::new(MqttOutputBuilder));
186}
187
188#[async_trait]
189trait MqttClient: Send + Sync {
190    async fn create(
191        mqtt_options: MqttOptions,
192        cap: usize,
193    ) -> Result<(Self, rumqttc::EventLoop), Error>
194    where
195        Self: Sized;
196
197    async fn publish<S, V>(
198        &self,
199        topic: S,
200        qos: QoS,
201        retain: bool,
202        payload: V,
203    ) -> Result<(), ClientError>
204    where
205        S: Into<String> + Send,
206        V: Into<Vec<u8>> + Send;
207
208    // Add the disconnect method to the trait
209    async fn disconnect(&self) -> Result<(), ClientError>;
210}
211
212#[async_trait]
213impl MqttClient for AsyncClient {
214    async fn create(
215        mqtt_options: MqttOptions,
216        cap: usize,
217    ) -> Result<(Self, rumqttc::EventLoop), Error>
218    where
219        Self: Sized,
220    {
221        let (client, eventloop) = AsyncClient::new(mqtt_options, cap);
222        Ok((client, eventloop))
223    }
224
225    async fn publish<S, V>(
226        &self,
227        topic: S,
228        qos: QoS,
229        retain: bool,
230        payload: V,
231    ) -> Result<(), ClientError>
232    where
233        S: Into<String> + Send,
234        V: Into<Vec<u8>> + Send,
235    {
236        AsyncClient::publish(self, topic, qos, retain, payload).await
237    }
238
239    async fn disconnect(&self) -> Result<(), ClientError> {
240        AsyncClient::disconnect(self).await
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory stand-in for `AsyncClient` that records publishes and disconnects.
    struct MockMqttClient {
        /// Starts `true`; flipped to `false` by `disconnect`.
        connected: Arc<AtomicBool>,
        /// Every `(topic, payload)` pair passed to `publish`, in call order.
        published_messages: Arc<Mutex<Vec<(String, Vec<u8>)>>>,
    }

    impl MockMqttClient {
        fn new() -> Self {
            Self {
                connected: Arc::new(AtomicBool::new(true)),
                published_messages: Arc::default(),
            }
        }
    }

    #[async_trait]
    impl MqttClient for MockMqttClient {
        async fn create(
            _mqtt_options: MqttOptions,
            _cap: usize,
        ) -> Result<(Self, rumqttc::EventLoop), Error> {
            // The trait signature needs a real `EventLoop`; take one from a
            // throwaway `AsyncClient` that is dropped immediately.
            let (_, eventloop) = AsyncClient::new(MqttOptions::new("", "", 0), 10);
            Ok((Self::new(), eventloop))
        }

        async fn publish<S, V>(
            &self,
            topic: S,
            _qos: QoS,
            _retain: bool,
            payload: V,
        ) -> Result<(), ClientError>
        where
            S: Into<String> + Send,
            V: Into<Vec<u8>> + Send,
        {
            self.published_messages
                .lock()
                .await
                .push((topic.into(), payload.into()));
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), ClientError> {
            self.connected.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Configuration shared by most tests, with every optional field unset.
    fn minimal_config() -> MqttOutputConfig {
        MqttOutputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test_client".to_string(),
            username: None,
            password: None,
            topic: "test/topic".to_string(),
            qos: None,
            clean_session: None,
            keep_alive: None,
            retain: None,
        }
    }

    /// Creates an output over the mock client and connects it.
    async fn connected_output() -> MqttOutput<MockMqttClient> {
        let output = MqttOutput::<MockMqttClient>::new(minimal_config()).unwrap();
        output.connect().await.unwrap();
        output
    }

    /// Test creating a new MQTT output component
    #[tokio::test]
    async fn test_mqtt_output_new() {
        let config = MqttOutputConfig {
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            qos: Some(1),
            clean_session: Some(true),
            keep_alive: Some(60),
            retain: Some(false),
            ..minimal_config()
        };

        assert!(MqttOutput::<MockMqttClient>::new(config).is_ok());
    }

    /// Test MQTT output connection
    #[tokio::test]
    async fn test_mqtt_output_connect() {
        let output = MqttOutput::<MockMqttClient>::new(minimal_config()).unwrap();
        assert!(output.connect().await.is_ok());
    }

    /// Test MQTT message publishing
    #[tokio::test]
    async fn test_mqtt_output_write() {
        let output = connected_output().await;

        let msg = MessageBatch::from_string("test message");
        assert!(output.write(&msg).await.is_ok());

        // The mock must have recorded exactly one publish to the configured topic.
        let guard = output.client.lock().await;
        let published = guard.as_ref().unwrap().published_messages.lock().await;
        assert_eq!(published.len(), 1);
        let (topic, payload) = &published[0];
        assert_eq!(topic, "test/topic");
        assert_eq!(payload, b"test message");
    }

    /// Test MQTT output disconnection
    #[tokio::test]
    async fn test_mqtt_output_close() {
        let output = connected_output().await;
        assert!(output.close().await.is_ok());

        // `close` must have forwarded a disconnect to the client.
        let guard = output.client.lock().await;
        let mock = guard.as_ref().unwrap();
        assert!(!mock.connected.load(Ordering::SeqCst));
    }

    /// Test error handling when writing to disconnected client
    #[tokio::test]
    async fn test_mqtt_output_write_disconnected() {
        let output = connected_output().await;
        output.close().await.unwrap();

        let msg = MessageBatch::from_string("test message");
        assert!(output.write(&msg).await.is_err());
    }
}