use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use paho_mqtt::{AsyncClient, QoS};
use tokio::sync::{RwLock, mpsc};
use crate::error::ProtocolError;
use crate::protocol::TopicRouter;
use crate::protocol::response_collector::MqttMessage;
use super::builder::MqttBrokerBuilder;
use super::config::MqttBrokerConfig;
/// Routing state the broker keeps for one registered device topic.
///
/// Entries are created by `MqttBroker::add_device_subscription` and stored in
/// the broker's subscription map, keyed by the device topic.
pub(crate) struct DeviceSubscription {
/// Sending half of the bounded channel whose receiver is handed back by
/// `add_device_subscription`; `route_message` pushes `stat` responses into it.
pub response_tx: mpsc::Sender<MqttMessage>,
/// Topic router shared (via `Arc`) with the caller of `add_device_subscription`;
/// it receives every routed message plus reconnect/disconnect notifications.
pub router: Arc<TopicRouter>,
}
/// Shared state behind every `MqttBroker` handle.
pub(super) struct MqttBrokerInner {
/// Underlying MQTT client used for subscribe, unsubscribe and disconnect calls.
pub(super) client: AsyncClient,
/// Registered devices, keyed by device topic (the middle segment of
/// `stat/<device>/...` and `tele/<device>/...`).
pub(super) subscriptions: RwLock<HashMap<String, DeviceSubscription>>,
/// Broker configuration (host, port, credentials, command timeout).
pub(super) config: MqttBrokerConfig,
/// Connection flag; written with `Release`/`AcqRel` and read with `Acquire`.
pub(super) connected: AtomicBool,
/// Sender for discovered device topics; `Some` only between
/// `start_discovery` and `stop_discovery`.
pub(super) discovery_tx: RwLock<Option<mpsc::Sender<String>>>,
}
/// Handle to a shared MQTT broker connection.
///
/// Cloning is cheap: the state lives behind an `Arc`, so every clone refers to
/// the same client, subscription map and connection flag.
#[derive(Clone)]
pub struct MqttBroker {
/// Reference-counted state shared by all clones of this handle.
pub(super) inner: Arc<MqttBrokerInner>,
}
impl MqttBroker {
    /// Capacity of the bounded per-device response channel.
    const RESPONSE_CHANNEL_CAPACITY: usize = 20;

    /// Capacity of the bounded discovery channel.
    const DISCOVERY_CHANNEL_CAPACITY: usize = 100;

    /// Builds the `stat/<device>/+` and `tele/<device>/+` topic filters for a device.
    fn device_topic_filters(device_topic: &str) -> (String, String) {
        (
            format!("stat/{device_topic}/+"),
            format!("tele/{device_topic}/+"),
        )
    }

    /// Returns a default-initialised builder for configuring a broker connection.
    #[must_use]
    pub fn builder() -> MqttBrokerBuilder {
        MqttBrokerBuilder::default()
    }

    /// Returns the current value of the connection flag.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.inner.connected.load(Ordering::Acquire)
    }

    /// Returns the configured broker host.
    #[must_use]
    pub fn host(&self) -> &str {
        &self.inner.config.host
    }

    /// Returns the configured broker port.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.inner.config.port
    }

    /// Returns `true` when credentials are present in the configuration.
    #[must_use]
    pub fn has_credentials(&self) -> bool {
        self.inner.config.credentials.is_some()
    }

    /// Returns the configured command timeout.
    #[must_use]
    pub fn command_timeout(&self) -> Duration {
        self.inner.config.command_timeout
    }

    /// Borrows the underlying MQTT client.
    pub(crate) fn client(&self) -> &AsyncClient {
        &self.inner.client
    }

    /// Starts a device builder bound to this broker for the given device topic.
    #[must_use]
    pub fn device(&self, topic: impl Into<String>) -> crate::device::BrokerDeviceBuilder<'_> {
        crate::device::BrokerDeviceBuilder::new(self, topic)
    }

    /// Subscribes to a device's `stat` and `tele` topics and registers it for routing.
    ///
    /// Returns the receiving half of the device's response channel together with
    /// the router that `route_message` feeds for this device. An existing entry
    /// for the same `device_topic` is replaced.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] if either subscribe call fails. In that case
    /// nothing is registered in the subscription map.
    pub(crate) async fn add_device_subscription(
        &self,
        device_topic: String,
    ) -> Result<(mpsc::Receiver<MqttMessage>, Arc<TopicRouter>), ProtocolError> {
        let (stat_topic, tele_topic) = Self::device_topic_filters(&device_topic);

        self.inner
            .client
            .subscribe(&stat_topic, QoS::AtLeastOnce)
            .await
            .map_err(ProtocolError::from)?;
        self.inner
            .client
            .subscribe(&tele_topic, QoS::AtLeastOnce)
            .await
            .map_err(ProtocolError::from)?;

        tracing::debug!(
            stat = %stat_topic,
            tele = %tele_topic,
            "Subscribed to device topics"
        );

        let (response_tx, response_rx) =
            mpsc::channel::<MqttMessage>(Self::RESPONSE_CHANNEL_CAPACITY);
        let router = Arc::new(TopicRouter::new());
        let subscription = DeviceSubscription {
            response_tx,
            router: Arc::clone(&router),
        };

        self.inner
            .subscriptions
            .write()
            .await
            .insert(device_topic, subscription);

        Ok((response_rx, router))
    }

    /// Unregisters a device and unsubscribes from its `stat` and `tele` topics.
    ///
    /// The map entry is removed first so routing stops immediately. Unsubscribe
    /// failures are logged as warnings and otherwise ignored (best effort).
    pub(crate) async fn remove_device_subscription(&self, device_topic: &str) {
        self.inner.subscriptions.write().await.remove(device_topic);

        let (stat_topic, tele_topic) = Self::device_topic_filters(device_topic);
        let mut unsubscribed = true;

        if let Err(e) = self.inner.client.unsubscribe(&stat_topic).await {
            tracing::warn!(topic = %stat_topic, error = %e, "Failed to unsubscribe from stat topic");
            unsubscribed = false;
        }
        if let Err(e) = self.inner.client.unsubscribe(&tele_topic).await {
            tracing::warn!(topic = %tele_topic, error = %e, "Failed to unsubscribe from tele topic");
            unsubscribed = false;
        }

        // Only report success when both unsubscribe calls actually succeeded.
        if unsubscribed {
            tracing::debug!(
                stat = %stat_topic,
                tele = %tele_topic,
                "Unsubscribed from device topics"
            );
        }
    }

    /// Routes one incoming MQTT message to discovery and to the owning device.
    ///
    /// Topics are expected to look like `<prefix>/<device>/<suffix>[/...]` with a
    /// prefix of `stat` or `tele`; anything else is ignored.
    ///
    /// * `tele/.../LWT`, `tele/.../STATE` and `stat/.../STATUS` additionally
    ///   publish the device topic on the discovery channel while discovery is on.
    /// * Every message for a registered device is passed to that device's router.
    /// * `stat/.../RESULT` and `stat/.../STATUS*` are also forwarded on the
    ///   device's response channel.
    ///
    /// Send errors (closed receivers) are deliberately ignored.
    pub(super) async fn route_message(&self, topic: &str, payload: String) {
        // Only the first three segments matter, so read them straight from the
        // iterator instead of collecting into a Vec on every message.
        let mut segments = topic.split('/');
        let (Some(prefix), Some(device_topic), Some(suffix)) =
            (segments.next(), segments.next(), segments.next())
        else {
            return;
        };
        if prefix != "stat" && prefix != "tele" {
            return;
        }

        let is_discovery_topic = matches!(
            (prefix, suffix),
            ("tele", "LWT" | "STATE") | ("stat", "STATUS")
        );
        if is_discovery_topic {
            // Clone the sender so the read guard is released before awaiting on
            // the bounded channel; otherwise a full channel would stall
            // `start_discovery` / `stop_discovery`.
            let discovery_tx = self.inner.discovery_tx.read().await.as_ref().cloned();
            if let Some(discovery_tx) = discovery_tx {
                tracing::debug!(
                    topic = %topic,
                    device = %device_topic,
                    "Discovered device topic"
                );
                let _ = discovery_tx.send(device_topic.to_string()).await;
            }
        }

        let wants_response =
            prefix == "stat" && (suffix == "RESULT" || suffix.starts_with("STATUS"));

        // Copy what is needed out of the map and drop the read guard right away.
        // Awaiting `send` on a full response channel while holding the guard
        // would block every writer of the subscription map.
        let target = {
            let subscriptions = self.inner.subscriptions.read().await;
            subscriptions.get(device_topic).map(|sub| {
                (
                    Arc::clone(&sub.router),
                    wants_response.then(|| sub.response_tx.clone()),
                )
            })
        };
        let Some((router, response_tx)) = target else {
            return;
        };

        router.route(topic, &payload);

        if let Some(response_tx) = response_tx {
            tracing::debug!(
                topic = %topic,
                device = %device_topic,
                suffix = %suffix,
                "Routing response to device"
            );
            let msg = MqttMessage::new(suffix.to_string(), payload);
            let _ = response_tx.send(msg).await;
        }
    }

    /// Re-subscribes every registered device after a reconnect and notifies its router.
    ///
    /// Subscribe failures are logged as errors; the device's router is notified
    /// regardless. The read guard is held for the whole pass, so writers of the
    /// subscription map wait until it finishes.
    pub(super) async fn handle_reconnection(&self) {
        let subscriptions = self.inner.subscriptions.read().await;

        for (device_topic, subscription) in subscriptions.iter() {
            let (stat_topic, tele_topic) = Self::device_topic_filters(device_topic);
            let mut resubscribed = true;

            if let Err(e) = self
                .inner
                .client
                .subscribe(&stat_topic, QoS::AtLeastOnce)
                .await
            {
                tracing::error!(topic = %stat_topic, error = %e, "Failed to resubscribe to stat topic");
                resubscribed = false;
            }
            if let Err(e) = self
                .inner
                .client
                .subscribe(&tele_topic, QoS::AtLeastOnce)
                .await
            {
                tracing::error!(topic = %tele_topic, error = %e, "Failed to resubscribe to tele topic");
                resubscribed = false;
            }

            // Do not claim success right after logging a failure.
            if resubscribed {
                tracing::debug!(device = %device_topic, "Resubscribed to device topics");
            }
            subscription.router.dispatch_reconnected_all();
        }

        tracing::info!(
            device_count = subscriptions.len(),
            "Reconnection complete, all devices notified"
        );
    }

    /// Notifies the router of every registered device that the connection was lost.
    pub(super) async fn dispatch_disconnected_all(&self) {
        let subscriptions = self.inner.subscriptions.read().await;
        for (device_topic, subscription) in subscriptions.iter() {
            tracing::debug!(device = %device_topic, "Notifying device of disconnection");
            subscription.router.dispatch_disconnected_all();
        }
    }

    /// Drops all device registrations and disconnects the client.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] if the client fails to disconnect. The
    /// subscription map has already been cleared at that point and the
    /// connection flag is left unchanged.
    pub async fn disconnect(&self) -> Result<(), ProtocolError> {
        tracing::info!(
            host = %self.inner.config.host,
            port = %self.inner.config.port,
            "Disconnecting from MQTT broker"
        );

        self.inner.subscriptions.write().await.clear();
        self.inner
            .client
            .disconnect(None)
            .await
            .map_err(ProtocolError::from)?;
        self.inner.connected.store(false, Ordering::Release);

        Ok(())
    }

    /// Returns the number of currently registered devices.
    #[must_use]
    pub async fn subscription_count(&self) -> usize {
        self.inner.subscriptions.read().await.len()
    }

    /// Enables discovery and returns the receiver for discovered device topics.
    ///
    /// Any previously installed discovery sender is replaced.
    pub(crate) async fn start_discovery(&self) -> mpsc::Receiver<String> {
        let (tx, rx) = mpsc::channel::<String>(Self::DISCOVERY_CHANNEL_CAPACITY);
        *self.inner.discovery_tx.write().await = Some(tx);
        rx
    }

    /// Disables discovery by dropping the stored discovery sender.
    pub(crate) async fn stop_discovery(&self) {
        *self.inner.discovery_tx.write().await = None;
    }

    /// Stores a new value in the connection flag.
    pub(super) fn set_connected(&self, val: bool) {
        self.inner.connected.store(val, Ordering::Release);
    }

    /// Stores a new value in the connection flag and returns the previous one.
    pub(super) fn swap_connected(&self, val: bool) -> bool {
        self.inner.connected.swap(val, Ordering::AcqRel)
    }
}
impl std::fmt::Debug for MqttBroker {
    /// Writes `MqttBroker { host, port, connected }`.
    ///
    /// Only the host, port and connection flag are printed; no other
    /// configuration fields appear in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let config = &self.inner.config;
        let mut out = f.debug_struct("MqttBroker");
        out.field("host", &config.host);
        out.field("port", &config.port);
        out.field("connected", &self.is_connected());
        out.finish()
    }
}