use crate::{output::Output, Error, MessageBatch};
use async_trait::async_trait;
use rumqttc::{AsyncClient, MqttOptions, QoS};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::info;
/// Configuration for [`MqttOutput`]; (de)serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MqttOutputConfig {
/// Broker host name or address, passed to `MqttOptions::new`.
pub host: String,
/// Broker port, passed to `MqttOptions::new`.
pub port: u16,
/// MQTT client identifier, passed to `MqttOptions::new`.
pub client_id: String,
/// Username; credentials are applied only when both `username` and `password` are set.
pub username: Option<String>,
/// Password; ignored unless `username` is also set.
pub password: Option<String>,
/// Topic that every message is published to.
pub topic: String,
/// Publish QoS level: 0, 1 or 2. `None` or any other value is treated as QoS 1
/// (at least once).
pub qos: Option<u8>,
/// Clean-session flag; when `None` the rumqttc default is left untouched.
pub clean_session: Option<bool>,
/// Keep-alive interval in seconds; when `None` the rumqttc default is left untouched.
pub keep_alive: Option<u64>,
/// Retain flag applied to every published message; defaults to `false` when `None`.
pub retain: Option<bool>,
}
/// MQTT [`Output`] that publishes every message of a batch to one configured topic.
pub struct MqttOutput {
/// Settings copied at construction time.
config: MqttOutputConfig,
/// rumqttc client handle; `None` until `connect` has run.
client: Arc<Mutex<Option<AsyncClient>>>,
/// Set by `connect` and cleared by `close`; `write` refuses to publish while it is false.
/// NOTE(review): nothing updates this flag when the background event loop loses the
/// broker connection, so it reflects "connect was called", not live connectivity.
connected: AtomicBool,
/// Handle of the spawned task that polls the rumqttc event loop; `None` when no task
/// has been spawned.
eventloop_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
impl MqttOutput {
pub fn new(config: &MqttOutputConfig) -> Result<Self, Error> {
Ok(Self {
config: config.clone(),
client: Arc::new(Mutex::new(None)),
connected: AtomicBool::new(false),
eventloop_handle: Arc::new(Mutex::new(None)),
})
}
}
#[async_trait]
impl Output for MqttOutput {
async fn connect(&self) -> Result<(), Error> {
let mut mqtt_options =
MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
mqtt_options.set_credentials(username, password);
}
if let Some(keep_alive) = self.config.keep_alive {
mqtt_options.set_keep_alive(std::time::Duration::from_secs(keep_alive));
}
if let Some(clean_session) = self.config.clean_session {
mqtt_options.set_clean_session(clean_session);
}
let (client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
let client_arc = self.client.clone();
let mut client_guard = client_arc.lock().await;
*client_guard = Some(client);
let eventloop_handle = tokio::spawn(async move {
while let Ok(_) = eventloop.poll().await {
}
});
let eventloop_handle_arc = self.eventloop_handle.clone();
let mut eventloop_handle_guard = eventloop_handle_arc.lock().await;
*eventloop_handle_guard = Some(eventloop_handle);
self.connected.store(true, Ordering::SeqCst);
Ok(())
}
async fn write(&self, msg: &MessageBatch) -> Result<(), Error> {
if !self.connected.load(Ordering::SeqCst) {
return Err(Error::Connection("输出未连接".to_string()));
}
let client_arc = self.client.clone();
let client_guard = client_arc.lock().await;
let client = client_guard
.as_ref()
.ok_or_else(|| Error::Connection("MQTT客户端未初始化".to_string()))?;
let payloads = match msg.as_string() {
Ok(v) => v.to_vec(),
Err(e) => {
return Err(e);
}
};
for payload in payloads {
info!(
"Send message: {}",
&String::from_utf8_lossy((&payload).as_ref())
);
let qos_level = match self.config.qos {
Some(0) => QoS::AtMostOnce,
Some(1) => QoS::AtLeastOnce,
Some(2) => QoS::ExactlyOnce,
_ => QoS::AtLeastOnce, };
let retain = self.config.retain.unwrap_or(false);
client
.publish(&self.config.topic, qos_level, retain, payload)
.await
.map_err(|e| Error::Processing(format!("MQTT发布失败: {}", e)))?;
}
Ok(())
}
async fn close(&self) -> Result<(), Error> {
let mut eventloop_handle_guard = self.eventloop_handle.lock().await;
if let Some(handle) = eventloop_handle_guard.take() {
handle.abort();
}
let client_arc = self.client.clone();
let client_guard = client_arc.lock().await;
if let Some(client) = &*client_guard {
let _ = client.disconnect().await;
}
self.connected.store(false, Ordering::SeqCst);
Ok(())
}
}