use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder};
use arkflow_core::{Error, MessageBatch};
use async_trait::async_trait;
use flume::{Receiver, Sender};
use rumqttc::{AsyncClient, Event, MqttOptions, Packet, Publish, QoS};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, Mutex};
use tracing::error;
/// Configuration for the MQTT input, deserialized from the stream's JSON config.
#[derive(Clone, Serialize, Deserialize)]
pub struct MqttInputConfig {
    /// Broker host name or address.
    pub host: String,
    /// Broker port.
    pub port: u16,
    /// MQTT client identifier presented to the broker.
    pub client_id: String,
    /// User name; credentials are only applied when `password` is also set.
    pub username: Option<String>,
    /// Password; credentials are only applied when `username` is also set.
    pub password: Option<String>,
    /// Topic filters subscribed to on `connect`.
    pub topics: Vec<String>,
    /// Subscription QoS: 0, 1 or 2. Any other value, or leaving it unset, selects QoS 1.
    pub qos: Option<u8>,
    /// Clean-session flag; when unset the MQTT client library's default is kept.
    pub clean_session: Option<bool>,
    /// Keep-alive interval in seconds; when unset the library default is kept.
    pub keep_alive: Option<u64>,
}

// Hand-written so that `{:?}` (e.g. in logs) never exposes the password.
impl std::fmt::Debug for MqttInputConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MqttInputConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("client_id", &self.client_id)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("topics", &self.topics)
            .field("qos", &self.qos)
            .field("clean_session", &self.clean_session)
            .field("keep_alive", &self.keep_alive)
            .finish()
    }
}
/// MQTT input that subscribes to the configured topics and yields each
/// received publish payload as a binary [`MessageBatch`].
pub struct MqttInput {
    /// Connection and subscription settings.
    config: MqttInputConfig,
    /// Client handle; `None` until `connect` stores one.
    client: Arc<Mutex<Option<AsyncClient>>>,
    /// Producer side of the message channel, fed by the event-loop task spawned in `connect`.
    sender: Arc<Sender<MqttMsg>>,
    /// Consumer side of the message channel, drained by `read`.
    receiver: Arc<Receiver<MqttMsg>>,
    /// Shutdown signal sent by `close`; `connect` and `read` subscribe to it.
    close_tx: broadcast::Sender<()>,
}
/// Item passed from the event-loop task to `read`.
enum MqttMsg {
    /// An incoming PUBLISH packet.
    Publish(Publish),
    /// An event-loop error, returned as-is from `read`.
    Err(Error),
}
impl MqttInput {
    /// Creates an MQTT input that is not yet connected.
    ///
    /// Received publishes are buffered in a bounded channel of 1000 messages
    /// between the event-loop task started by `connect` and `read`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept so the signature can stay stable.
    pub fn new(config: MqttInputConfig) -> Result<Self, Error> {
        let (sender, receiver) = flume::bounded::<MqttMsg>(1000);
        // Only the sender is kept; receivers are created on demand via `subscribe()`.
        let (close_tx, _) = broadcast::channel(1);
        Ok(Self {
            // `config` is already owned, so it is moved in rather than cloned.
            config,
            client: Arc::new(Mutex::new(None)),
            sender: Arc::new(sender),
            receiver: Arc::new(receiver),
            close_tx,
        })
    }
}
#[async_trait]
impl Input for MqttInput {
async fn connect(&self) -> Result<(), Error> {
let mut mqtt_options =
MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
mqtt_options.set_manual_acks(true);
if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
mqtt_options.set_credentials(username, password);
}
if let Some(keep_alive) = self.config.keep_alive {
mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
}
if let Some(clean_session) = self.config.clean_session {
mqtt_options.set_clean_session(clean_session);
}
let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
let qos_level = match self.config.qos {
Some(0) => QoS::AtMostOnce,
Some(1) => QoS::AtLeastOnce,
Some(2) => QoS::ExactlyOnce,
_ => QoS::AtLeastOnce, };
for topic in &self.config.topics {
client.subscribe(topic, qos_level).await.map_err(|e| {
Error::Connection(format!(
"Unable to subscribe to MQTT topics {}: {}",
topic, e
))
})?;
}
let client_arc = self.client.clone();
let mut client_guard = client_arc.lock().await;
*client_guard = Some(client);
let sender_arc = self.sender.clone();
let mut rx = self.close_tx.subscribe();
tokio::spawn(async move {
loop {
tokio::select! {
result = eventloop.poll() => {
match result {
Ok(event) => {
if let Event::Incoming(Packet::Publish(publish)) = event {
match sender_arc.send_async(MqttMsg::Publish(publish)).await {
Ok(_) => {}
Err(e) => {
error!("{}",e)
}
};
}
}
Err(e) => {
error!("MQTT event loop error: {}", e);
match sender_arc.send_async(MqttMsg::Err(Error::Disconnection)).await {
Ok(_) => {}
Err(e) => {
error!("{}",e)
}
};
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
_ = rx.recv() => {
break;
}
}
}
});
Ok(())
}
async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
{
let client_arc = self.client.clone();
if client_arc.lock().await.is_none() {
return Err(Error::Disconnection);
}
}
let mut close_rx = self.close_tx.subscribe();
tokio::select! {
result = self.receiver.recv_async() =>{
match result {
Ok(msg) => {
match msg{
MqttMsg::Publish(publish) => {
let payload = publish.payload.to_vec();
let msg = MessageBatch::new_binary(vec![payload]);
Ok((msg, Arc::new(MqttAck {
client: self.client.clone(),
publish,
})))
},
MqttMsg::Err(e) => {
Err(e)
}
}
}
Err(_) => {
Err(Error::EOF)
}
}
},
_ = close_rx.recv()=>{
Err(Error::EOF)
}
}
}
async fn close(&self) -> Result<(), Error> {
let _ = self.close_tx.send(());
let client_arc = self.client.clone();
let client_guard = client_arc.lock().await;
if let Some(client) = &*client_guard {
let _ = client.disconnect().await;
}
Ok(())
}
}
/// Acknowledgement handle for a single publish returned by `MqttInput::read`.
struct MqttAck {
    /// Client shared with the owning `MqttInput`; no ack is sent while it is `None`.
    client: Arc<Mutex<Option<AsyncClient>>>,
    /// The publish to acknowledge.
    publish: Publish,
}
#[async_trait]
impl Ack for MqttAck {
    /// Sends the manual MQTT acknowledgement for the wrapped publish.
    ///
    /// Does nothing if no client is stored. Failures are logged rather than
    /// returned because `ack` has no error channel.
    async fn ack(&self) {
        // Clone the client handle out of the lock so the mutex is not held while
        // the ack request waits for room in the client's request channel; holding
        // it would stall `read`'s connection check and `close` in the meantime.
        let client = self.client.lock().await.clone();
        if let Some(client) = client {
            if let Err(e) = client.ack(&self.publish).await {
                error!("{}", e);
            }
        }
    }
}
pub(crate) struct MqttInputBuilder;
impl InputBuilder for MqttInputBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
if config.is_none() {
return Err(Error::Config(
"MQTT input configuration is missing".to_string(),
));
}
let config: MqttInputConfig = serde_json::from_value(config.clone().unwrap())?;
Ok(Arc::new(MqttInput::new(config)?))
}
}
/// Registers the MQTT input builder under the `"mqtt"` input type name.
pub fn init() {
    let builder = Arc::new(MqttInputBuilder);
    register_input_builder("mqtt", builder);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests: local broker, one topic,
    /// and every optional setting left unset.
    fn base_config() -> MqttInputConfig {
        MqttInputConfig {
            host: "localhost".to_string(),
            port: 1883,
            client_id: "test-client".to_string(),
            username: None,
            password: None,
            topics: vec!["test/topic".to_string()],
            qos: None,
            clean_session: None,
            keep_alive: None,
        }
    }

    /// A client whose event loop is discarded straight away, so nothing ever polls it.
    fn offline_client() -> AsyncClient {
        let (client, _eventloop) =
            AsyncClient::new(MqttOptions::new("test-client", "localhost", 1883), 10);
        client
    }

    #[tokio::test]
    async fn test_mqtt_input_new() {
        let config = MqttInputConfig {
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            qos: Some(1),
            clean_session: Some(true),
            keep_alive: Some(60),
            ..base_config()
        };
        let input = MqttInput::new(config).expect("constructing MqttInput should succeed");
        let cfg = &input.config;
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 1883);
        assert_eq!(cfg.client_id, "test-client");
        assert_eq!(cfg.username, Some("user".to_string()));
        assert_eq!(cfg.password, Some("pass".to_string()));
        assert_eq!(cfg.topics, vec!["test/topic".to_string()]);
        assert_eq!(cfg.qos, Some(1));
        assert_eq!(cfg.clean_session, Some(true));
        assert_eq!(cfg.keep_alive, Some(60));
    }

    #[tokio::test]
    async fn test_mqtt_input_read_not_connected() {
        let input = MqttInput::new(base_config()).unwrap();
        let outcome = input.read().await;
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );
    }

    #[tokio::test]
    async fn test_mqtt_input_close() {
        let input = MqttInput::new(base_config()).unwrap();
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_message_processing() {
        let input = MqttInput::new(base_config()).unwrap();
        let publish = Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic: "test/topic".to_string(),
            pkid: 1,
            payload: b"test message".to_vec().into(),
        };
        input
            .sender
            .send_async(MqttMsg::Publish(publish))
            .await
            .unwrap();
        input.client.lock().await.replace(offline_client());

        let (batch, ack) = input
            .read()
            .await
            .expect("a buffered publish should be readable");
        let content = batch.as_string().unwrap();
        assert_eq!(content, vec!["test message"]);
        ack.ack().await;
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_mqtt_input_error_handling() {
        let input = MqttInput::new(base_config()).unwrap();
        input.client.lock().await.replace(offline_client());
        input
            .sender
            .send_async(MqttMsg::Err(Error::Disconnection))
            .await
            .unwrap();

        let outcome = input.read().await;
        assert!(
            matches!(outcome, Err(Error::Disconnection)),
            "Expected Disconnection error"
        );
        assert!(input.close().await.is_ok());
    }
}