arkflow-plugin 0.1.0

High-performance Rust flow processing engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
//! MQTT input component
//!
//! Receive data from the MQTT broker

use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder};
use arkflow_core::{Error, MessageBatch};

use async_trait::async_trait;
use flume::{Receiver, Sender};
use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, Mutex};
use tracing::error;

/// Settings for [`MqttInput`], deserialized by the builder from the input's JSON config.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct MqttInputConfig {
    /// Hostname or IP address of the MQTT broker.
    pub host: String,
    /// TCP port the broker listens on.
    pub port: u16,
    /// Client identifier presented to the broker.
    pub client_id: String,
    /// Optional login username.
    pub username: Option<String>,
    /// Optional login password.
    pub password: Option<String>,
    /// Topic filters to subscribe to when connecting.
    pub topics: Vec<String>,
    /// Requested QoS level (0, 1 or 2); when absent QoS 1 is used.
    pub qos: Option<u8>,
    /// Clean-session flag; when absent the client library default is kept.
    pub clean_session: Option<bool>,
    /// Keep-alive interval in seconds; when absent the client library default is kept.
    pub keep_alive: Option<u64>,
}

/// Input component that consumes messages published to an MQTT broker.
pub struct MqttInput {
    /// Configuration captured at construction time.
    config: MqttInputConfig,
    /// Client handle; `None` until `connect` stores one.
    client: Arc<Mutex<Option<AsyncClient>>>,
    /// Producer side of the internal message queue, fed by the event loop task.
    sender: Arc<Sender<MqttMsg>>,
    /// Consumer side of the internal message queue, drained by `read`.
    receiver: Arc<Receiver<MqttMsg>>,
    /// Broadcast channel used to announce shutdown to the event loop task and pending reads.
    close_tx: broadcast::Sender<()>,
}

/// Item handed from the event loop task to `read` through the internal queue.
enum MqttMsg {
    /// A publish received from the broker.
    Publish(Publish),
    /// An error to surface to the reader.
    Err(Error),
}

impl MqttInput {
    /// Creates an MQTT input that is not yet connected.
    ///
    /// `read` reports `Error::Disconnection` until [`Input::connect`] has stored a client.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn new(config: MqttInputConfig) -> Result<Self, Error> {
        // Maximum number of items buffered between the event loop task and `read`.
        const QUEUE_CAPACITY: usize = 1000;

        let (sender, receiver) = flume::bounded::<MqttMsg>(QUEUE_CAPACITY);
        // The channel only carries the unit shutdown signal sent by `close`.
        let (close_tx, _) = broadcast::channel(1);
        Ok(Self {
            // `config` is owned already; no clone needed.
            config,
            client: Arc::new(Mutex::new(None)),
            sender: Arc::new(sender),
            receiver: Arc::new(receiver),
            close_tx,
        })
    }
}

#[async_trait]
impl Input for MqttInput {
    async fn connect(&self) -> Result<(), Error> {
        // Create MQTT options
        let mut mqtt_options =
            MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
        mqtt_options.set_manual_acks(true);
        // Set the authentication information
        if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
            mqtt_options.set_credentials(username, password);
        }

        // Set the keep-alive time
        if let Some(keep_alive) = self.config.keep_alive {
            mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
        }

        // Set up a clean session
        if let Some(clean_session) = self.config.clean_session {
            mqtt_options.set_clean_session(clean_session);
        }

        // Create an MQTT client
        let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
        // Subscribe to topics
        let qos_level = match self.config.qos {
            Some(0) => QoS::AtMostOnce,
            Some(1) => QoS::AtLeastOnce,
            Some(2) => QoS::ExactlyOnce,
            _ => QoS::AtLeastOnce, // Default is QoS 1
        };

        for topic in &self.config.topics {
            client.subscribe(topic, qos_level).await.map_err(|e| {
                Error::Connection(format!(
                    "Unable to subscribe to MQTT topics {}: {}",
                    topic, e
                ))
            })?;
        }

        let client_arc = self.client.clone();
        let mut client_guard = client_arc.lock().await;
        *client_guard = Some(client);

        let sender_arc = self.sender.clone();
        let mut rx = self.close_tx.subscribe();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    result = eventloop.poll() => {
                        match result {
                            Ok(event) => {
                                if let Event::Incoming(Packet::Publish(publish)) = event {
                                    // Add messages to the queue
                                    match sender_arc.send_async(MqttMsg::Publish(publish)).await {
                                        Ok(_) => {}
                                        Err(e) => {
                                            error!("{}",e)
                                        }
                                    };
                                }
                            }
                            Err(e) => {
                               // Log the error and wait a short time before continuing
                                error!("MQTT event loop error: {}", e);
                                match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
                                        Ok(_) => {}
                                        Err(e) => {
                                            error!("{}",e)
                                        }
                                };
                                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                            }
                        }
                    }
                    _ = rx.recv() => {
                        break;
                    }
                }
            }
        });

        Ok(())
    }

    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
        {
            let client_arc = self.client.clone();
            if client_arc.lock().await.is_none() {
                return Err(Error::Disconnection);
            }
        }

        let mut close_rx = self.close_tx.subscribe();
        tokio::select! {
            result = self.receiver.recv_async() =>{
                match result {
                    Ok(msg) => {
                        match msg{
                            MqttMsg::Publish(publish) => {
                                 let payload = publish.payload.to_vec();
                            let msg = MessageBatch::new_binary(vec![payload]);
                            Ok((msg, Arc::new(MqttAck {
                                client: self.client.clone(),
                                publish,
                            })))
                            },
                            MqttMsg::Err(e) => {
                                  Err(e)
                            }
                        }
                    }
                    Err(_) => {
                        Err(Error::EOF)
                    }
                }
            },
            _ = close_rx.recv()=>{
                Err(Error::EOF)
            }
        }
    }

    async fn close(&self) -> Result<(), Error> {
        // Send a shutdown signal
        let _ = self.close_tx.send(());

        // Disconnect the MQTT connection
        let client_arc = self.client.clone();
        let client_guard = client_arc.lock().await;
        if let Some(client) = &*client_guard {
            // Try to disconnect, but don't wait for the result
            let _ = client.disconnect().await;
        }

        Ok(())
    }
}

/// Acknowledgement handle for a single received publish.
///
/// Shares the input's client slot, so the ack goes through whichever client is
/// currently stored; if none is stored, acknowledging is a no-op.
struct MqttAck {
    client: Arc<Mutex<Option<AsyncClient>>>,
    publish: Publish,
}

#[async_trait]
impl Ack for MqttAck {
    /// Acknowledges the stored publish through the client, logging any client error.
    async fn ack(&self) {
        // Clone the client out of the lock and release the lock before awaiting.
        // `client.ack` can wait for room in the request channel, and holding the
        // shared lock meanwhile would also block `read` and `close` on the input.
        let client = self.client.lock().await.clone();
        if let Some(client) = client {
            if let Err(e) = client.ack(&self.publish).await {
                error!("{}", e);
            }
        }
    }
}

pub(crate) struct MqttInputBuilder;
impl InputBuilder for MqttInputBuilder {
    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
        if config.is_none() {
            return Err(Error::Config(
                "MQTT input configuration is missing".to_string(),
            ));
        }

        let config: MqttInputConfig = serde_json::from_value(config.clone().unwrap())?;
        Ok(Arc::new(MqttInput::new(config)?))
    }
}

/// Registers the MQTT input builder under the `"mqtt"` input type name.
pub fn init() {
    let builder = Arc::new(MqttInputBuilder);
    register_input_builder("mqtt", builder);
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration for a local broker with one topic and every optional field unset.
    fn base_config() -> MqttInputConfig {
        MqttInputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test-client".to_string(),
            username: None,
            password: None,
            topics: vec!["test/topic".to_string()],
            qos: None,
            clean_session: None,
            keep_alive: None,
        }
    }

    /// Builds an input whose client slot is filled, so `read` passes its "connected" check.
    ///
    /// The event loop half returned by `AsyncClient::new` is discarded and never polled.
    async fn input_with_client() -> MqttInput {
        let input = MqttInput::new(base_config()).unwrap();
        let client = AsyncClient::new(MqttOptions::new("test-client", "localhost", 1883), 10).0;
        input.client.lock().await.replace(client);
        input
    }

    #[tokio::test]
    async fn test_mqtt_input_new() {
        let config = MqttInputConfig {
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            qos: Some(1),
            clean_session: Some(true),
            keep_alive: Some(60),
            ..base_config()
        };

        let input = MqttInput::new(config).expect("constructing an MQTT input should succeed");
        assert_eq!(input.config.host, "localhost");
        assert_eq!(input.config.port, 1883);
        assert_eq!(input.config.client_id, "test-client");
        assert_eq!(input.config.username, Some("user".to_string()));
        assert_eq!(input.config.password, Some("pass".to_string()));
        assert_eq!(input.config.topics, vec!["test/topic".to_string()]);
        assert_eq!(input.config.qos, Some(1));
        assert_eq!(input.config.clean_session, Some(true));
        assert_eq!(input.config.keep_alive, Some(60));
    }

    #[tokio::test]
    async fn test_mqtt_input_read_not_connected() {
        let input = MqttInput::new(base_config()).unwrap();
        // Reading before `connect` must fail fast rather than wait on the queue.
        let result = input.read().await;
        assert!(
            matches!(result, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );
    }

    #[tokio::test]
    async fn test_mqtt_input_close() {
        let input = MqttInput::new(base_config()).unwrap();
        // Closing an input that never connected must still succeed.
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_message_processing() {
        let input = input_with_client().await;

        // Queue a publish directly, bypassing the broker.
        let test_payload = "test message".as_bytes().to_vec();
        let publish = Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic: "test/topic".to_string(),
            pkid: 1,
            payload: test_payload.into(),
        };
        input
            .sender
            .send_async(MqttMsg::Publish(publish))
            .await
            .unwrap();

        let (msg, ack) = input.read().await.expect("queued publish should be returned");
        let content = msg.as_string().unwrap();
        assert_eq!(content, vec!["test message"]);

        // Acknowledging must not panic even though the event loop was never polled.
        ack.ack().await;

        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_error_handling() {
        let input = input_with_client().await;

        // Errors forwarded by the event loop are surfaced to the reader unchanged.
        input
            .sender
            .send_async(MqttMsg::Err(Error::Disconnection))
            .await
            .unwrap();

        let result = input.read().await;
        assert!(
            matches!(result, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );

        assert!(input.close().await.is_ok());
    }
}