1use arkflow_core::output::{register_output_builder, Output, OutputBuilder};
6use arkflow_core::{Error, MessageBatch};
7use async_trait::async_trait;
8use rumqttc::{AsyncClient, ClientError, MqttOptions, QoS};
9use serde::{Deserialize, Serialize};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use tracing::info;
14
/// Configuration for the MQTT output, (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MqttOutputConfig {
    /// Broker host passed to `MqttOptions::new`.
    pub host: String,
    /// Broker port passed to `MqttOptions::new`.
    pub port: u16,
    /// Client identifier passed to `MqttOptions::new`.
    pub client_id: String,
    /// Optional username; credentials are only applied when `password` is also set.
    pub username: Option<String>,
    /// Optional password; credentials are only applied when `username` is also set.
    pub password: Option<String>,
    /// Topic every message is published to.
    pub topic: String,
    /// QoS level: 0 = at most once, 1 = at least once, 2 = exactly once.
    /// Any other value (or `None`) is treated as at least once.
    pub qos: Option<u8>,
    /// Clean-session flag; left at the rumqttc default when `None`.
    pub clean_session: Option<bool>,
    /// Keep-alive interval in seconds; left at the rumqttc default when `None`.
    pub keep_alive: Option<u64>,
    /// Retain flag for published messages; `false` when `None`.
    pub retain: Option<bool>,
}
39
/// Output component that publishes message batches to an MQTT topic.
///
/// Generic over the client implementation `T` so the client can be swapped
/// (the unit tests in this file use a mock).
struct MqttOutput<T: MqttClient> {
    /// Connection and publishing settings.
    config: MqttOutputConfig,
    /// The MQTT client; `None` until `connect` has stored one.
    client: Arc<Mutex<Option<T>>>,
    /// Set to `true` by `connect` and back to `false` by `close`.
    connected: AtomicBool,
    /// Handle of the spawned task that polls the rumqttc event loop.
    eventloop_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
47
48impl<T: MqttClient> MqttOutput<T> {
49 pub fn new(config: MqttOutputConfig) -> Result<Self, Error> {
51 Ok(Self {
52 config,
53 client: Arc::new(Mutex::new(None)),
54 connected: AtomicBool::new(false),
55 eventloop_handle: Arc::new(Mutex::new(None)),
56 })
57 }
58}
59
60#[async_trait]
61impl<T: MqttClient> Output for MqttOutput<T> {
62 async fn connect(&self) -> Result<(), Error> {
63 let mut mqtt_options =
65 MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
66
67 if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
69 mqtt_options.set_credentials(username, password);
70 }
71
72 if let Some(keep_alive) = self.config.keep_alive {
74 mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
75 }
76
77 if let Some(clean_session) = self.config.clean_session {
79 mqtt_options.set_clean_session(clean_session);
80 }
81
82 let (client, mut eventloop) = T::create(mqtt_options, 10).await?;
84 let client_arc = self.client.clone();
86 let mut client_guard = client_arc.lock().await;
87 *client_guard = Some(client);
88
89 let eventloop_handle = tokio::spawn(async move {
91 while let Ok(_) = eventloop.poll().await {
92 }
94 });
95
96 let eventloop_handle_arc = self.eventloop_handle.clone();
98 let mut eventloop_handle_guard = eventloop_handle_arc.lock().await;
99 *eventloop_handle_guard = Some(eventloop_handle);
100
101 self.connected.store(true, Ordering::SeqCst);
102 Ok(())
103 }
104
105 async fn write(&self, msg: &MessageBatch) -> Result<(), Error> {
106 if !self.connected.load(Ordering::SeqCst) {
107 return Err(Error::Connection("The output is not connected".to_string()));
108 }
109
110 let client_arc = self.client.clone();
111 let client_guard = client_arc.lock().await;
112 let client = client_guard
113 .as_ref()
114 .ok_or_else(|| Error::Connection("The MQTT client is not initialized".to_string()))?;
115
116 let payloads = match msg.as_string() {
118 Ok(v) => v.to_vec(),
119 Err(e) => {
120 return Err(e);
121 }
122 };
123
124 for payload in payloads {
125 info!(
126 "Send message: {}",
127 &String::from_utf8_lossy((&payload).as_ref())
128 );
129
130 let qos_level = match self.config.qos {
132 Some(0) => QoS::AtMostOnce,
133 Some(1) => QoS::AtLeastOnce,
134 Some(2) => QoS::ExactlyOnce,
135 _ => QoS::AtLeastOnce, };
137
138 let retain = self.config.retain.unwrap_or(false);
140
141 client
143 .publish(&self.config.topic, qos_level, retain, payload)
144 .await
145 .map_err(|e| Error::Process(format!("MQTT publishing failed: {}", e)))?;
146 }
147
148 Ok(())
149 }
150
151 async fn close(&self) -> Result<(), Error> {
152 let mut eventloop_handle_guard = self.eventloop_handle.lock().await;
154 if let Some(handle) = eventloop_handle_guard.take() {
155 handle.abort();
156 }
157
158 let client_arc = self.client.clone();
160 let client_guard = client_arc.lock().await;
161 if let Some(client) = &*client_guard {
162 let _ = client.disconnect().await;
164 }
165
166 self.connected.store(false, Ordering::SeqCst);
167 Ok(())
168 }
169}
170
171pub(crate) struct MqttOutputBuilder;
172impl OutputBuilder for MqttOutputBuilder {
173 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Output>, Error> {
174 if config.is_none() {
175 return Err(Error::Config(
176 "HTTP output configuration is missing".to_string(),
177 ));
178 }
179 let config: MqttOutputConfig = serde_json::from_value(config.clone().unwrap())?;
180 Ok(Arc::new(MqttOutput::<AsyncClient>::new(config)?))
181 }
182}
183
184pub fn init() {
185 register_output_builder("mqtt", Arc::new(MqttOutputBuilder));
186}
187
/// The client operations the MQTT output relies on.
///
/// Implemented for rumqttc's `AsyncClient` and for a mock in the tests
/// of this file.
#[async_trait]
trait MqttClient: Send + Sync {
    /// Creates a client together with the rumqttc event loop that goes with it.
    ///
    /// `cap` is forwarded as the capacity argument of the underlying client
    /// constructor.
    async fn create(
        mqtt_options: MqttOptions,
        cap: usize,
    ) -> Result<(Self, rumqttc::EventLoop), Error>
    where
        Self: Sized;

    /// Publishes `payload` to `topic` with the given QoS and retain flag.
    async fn publish<S, V>(
        &self,
        topic: S,
        qos: QoS,
        retain: bool,
        payload: V,
    ) -> Result<(), ClientError>
    where
        S: Into<String> + Send,
        V: Into<Vec<u8>> + Send;

    /// Requests a disconnect from the broker.
    async fn disconnect(&self) -> Result<(), ClientError>;
}
211
212#[async_trait]
213impl MqttClient for AsyncClient {
214 async fn create(
215 mqtt_options: MqttOptions,
216 cap: usize,
217 ) -> Result<(Self, rumqttc::EventLoop), Error>
218 where
219 Self: Sized,
220 {
221 let (client, eventloop) = AsyncClient::new(mqtt_options, cap);
222 Ok((client, eventloop))
223 }
224
225 async fn publish<S, V>(
226 &self,
227 topic: S,
228 qos: QoS,
229 retain: bool,
230 payload: V,
231 ) -> Result<(), ClientError>
232 where
233 S: Into<String> + Send,
234 V: Into<Vec<u8>> + Send,
235 {
236 AsyncClient::publish(self, topic, qos, retain, payload).await
237 }
238
239 async fn disconnect(&self) -> Result<(), ClientError> {
240 AsyncClient::disconnect(self).await
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory client that records published messages instead of
    /// talking to a broker.
    struct MockMqttClient {
        /// Starts as `true`; cleared by `disconnect`.
        connected: Arc<AtomicBool>,
        /// `(topic, payload)` pairs in publish order.
        published_messages: Arc<Mutex<Vec<(String, Vec<u8>)>>>,
    }

    impl MockMqttClient {
        fn new() -> Self {
            let connected = Arc::new(AtomicBool::new(true));
            let published_messages = Arc::new(Mutex::new(Vec::new()));
            Self {
                connected,
                published_messages,
            }
        }
    }

    #[async_trait]
    impl MqttClient for MockMqttClient {
        async fn create(
            _mqtt_options: MqttOptions,
            _cap: usize,
        ) -> Result<(Self, rumqttc::EventLoop), Error> {
            // Only the event loop of this throwaway client is needed to
            // satisfy the return type.
            let (_, eventloop) = AsyncClient::new(MqttOptions::new("", "", 0), 10);
            Ok((Self::new(), eventloop))
        }

        async fn publish<S, V>(
            &self,
            topic: S,
            _qos: QoS,
            _retain: bool,
            payload: V,
        ) -> Result<(), ClientError>
        where
            S: Into<String> + Send,
            V: Into<Vec<u8>> + Send,
        {
            self.published_messages
                .lock()
                .await
                .push((topic.into(), payload.into()));
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), ClientError> {
            self.connected.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Configuration with every optional setting left unset.
    fn minimal_config() -> MqttOutputConfig {
        MqttOutputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test_client".to_string(),
            username: None,
            password: None,
            topic: "test/topic".to_string(),
            qos: None,
            clean_session: None,
            keep_alive: None,
            retain: None,
        }
    }

    /// Mock-backed output that has already been connected.
    async fn connected_output() -> MqttOutput<MockMqttClient> {
        let output = MqttOutput::<MockMqttClient>::new(minimal_config()).unwrap();
        output.connect().await.unwrap();
        output
    }

    /// Construction succeeds with every optional setting populated.
    #[tokio::test]
    async fn test_mqtt_output_new() {
        let config = MqttOutputConfig {
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            qos: Some(1),
            clean_session: Some(true),
            keep_alive: Some(60),
            retain: Some(false),
            ..minimal_config()
        };

        let output = MqttOutput::<MockMqttClient>::new(config);
        assert!(output.is_ok());
    }

    /// Connecting with the mock client succeeds.
    #[tokio::test]
    async fn test_mqtt_output_connect() {
        let output = MqttOutput::<MockMqttClient>::new(minimal_config()).unwrap();
        assert!(output.connect().await.is_ok());
    }

    /// A written batch reaches the client with the configured topic.
    #[tokio::test]
    async fn test_mqtt_output_write() {
        let output = connected_output().await;

        let msg = MessageBatch::from_string("test message");
        assert!(output.write(&msg).await.is_ok());

        let client = output.client.lock().await;
        let mock_client = client.as_ref().unwrap();
        let messages = mock_client.published_messages.lock().await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].0, "test/topic");
        assert_eq!(messages[0].1, b"test message");
    }

    /// Closing disconnects the underlying client.
    #[tokio::test]
    async fn test_mqtt_output_close() {
        let output = connected_output().await;
        assert!(output.close().await.is_ok());

        let client = output.client.lock().await;
        let mock_client = client.as_ref().unwrap();
        assert!(!mock_client.connected.load(Ordering::SeqCst));
    }

    /// Writing after close is rejected.
    #[tokio::test]
    async fn test_mqtt_output_write_disconnected() {
        let output = connected_output().await;
        output.close().await.unwrap();

        let msg = MessageBatch::from_string("test message");
        assert!(output.write(&msg).await.is_err());
    }
}