use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use rumqttc::{AsyncClient, EventLoop, MqttOptions, QoS};
use tokio::sync::{RwLock, mpsc, oneshot};
use crate::error::ProtocolError;
use crate::protocol::TopicRouter;
use crate::protocol::response_collector::MqttMessage;
/// Per-process counter appended to generated client ids (`tasmor_<pid>_<counter>`).
static BROKER_CLIENT_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Command timeout used when the builder is not given one explicitly.
pub const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(5);

/// Connection settings collected by [`MqttBrokerBuilder`].
#[derive(Clone, Debug)]
pub struct MqttBrokerConfig {
    /// Broker host name or IP address; empty means "not configured".
    host: String,
    /// Broker TCP port.
    port: u16,
    /// Optional `(username, password)` pair.
    credentials: Option<(String, String)>,
    /// MQTT keep-alive interval.
    keep_alive: Duration,
    /// How long `build` waits for the first CONNACK.
    connection_timeout: Duration,
    /// Timeout exposed through `MqttBroker::command_timeout`.
    command_timeout: Duration,
}

impl Default for MqttBrokerConfig {
    /// No host, port 1883, no credentials, 30 s keep-alive, 10 s connect timeout and
    /// [`DEFAULT_COMMAND_TIMEOUT`] for commands.
    fn default() -> Self {
        let keep_alive = Duration::from_secs(30);
        let connection_timeout = Duration::from_secs(10);
        Self {
            host: String::new(),
            port: 1883,
            credentials: None,
            keep_alive,
            connection_timeout,
            command_timeout: DEFAULT_COMMAND_TIMEOUT,
        }
    }
}
/// Routing state registered for one device topic.
pub(crate) struct DeviceSubscription {
    /// Channel that receives `stat/<topic>/RESULT` and `stat/<topic>/STATUS*` payloads.
    pub response_tx: mpsc::Sender<MqttMessage>,
    /// Router that is handed every `stat/`/`tele/` message for the device.
    pub router: Arc<TopicRouter>,
}

/// Cheaply clonable handle to a shared MQTT broker connection.
///
/// All clones share one client, one subscription table and the same connection flags.
#[derive(Clone)]
pub struct MqttBroker {
    inner: Arc<MqttBrokerInner>,
}

/// State shared by every [`MqttBroker`] clone.
struct MqttBrokerInner {
    /// rumqttc request handle; its event loop is driven by a spawned task.
    client: AsyncClient,
    /// Registered devices, keyed by device topic.
    subscriptions: RwLock<HashMap<String, DeviceSubscription>>,
    /// Settings the broker was built with.
    config: MqttBrokerConfig,
    /// Whether the connection is currently considered up.
    connected: AtomicBool,
    /// Set after the first CONNACK so later ones are treated as reconnections.
    initial_connection_done: AtomicBool,
    /// Discovery channel, present only while discovery is active.
    discovery_tx: RwLock<Option<mpsc::Sender<String>>>,
}
impl MqttBroker {
    /// Creates a builder for configuring and connecting an [`MqttBroker`].
    #[must_use]
    pub fn builder() -> MqttBrokerBuilder {
        MqttBrokerBuilder::default()
    }

    /// Returns whether the broker connection is currently considered up.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.inner.connected.load(Ordering::Acquire)
    }

    /// Returns the configured broker host name or address.
    #[must_use]
    pub fn host(&self) -> &str {
        &self.inner.config.host
    }

    /// Returns the configured broker port.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.inner.config.port
    }

    /// Returns `true` if a username/password pair was configured.
    #[must_use]
    pub fn has_credentials(&self) -> bool {
        self.inner.config.credentials.is_some()
    }

    /// Returns the configured command timeout.
    #[must_use]
    pub fn command_timeout(&self) -> Duration {
        self.inner.config.command_timeout
    }

    /// Returns the underlying rumqttc client.
    pub(crate) fn client(&self) -> &AsyncClient {
        &self.inner.client
    }

    /// Starts building a device that talks through this broker under `topic`.
    #[must_use]
    pub fn device(&self, topic: impl Into<String>) -> crate::device::BrokerDeviceBuilder<'_> {
        crate::device::BrokerDeviceBuilder::new(self, topic)
    }

    /// Subscribes to `stat/<device_topic>/+` and `tele/<device_topic>/+` and registers
    /// the device, replacing any previous registration for the same topic.
    ///
    /// Returns the receiver for the device's `stat/.../RESULT` and `stat/.../STATUS*`
    /// payloads, and the router that is handed every message for the device.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::Mqtt`] if the client rejects either subscribe request.
    pub(crate) async fn add_device_subscription(
        &self,
        device_topic: String,
    ) -> Result<(mpsc::Receiver<MqttMessage>, Arc<TopicRouter>), ProtocolError> {
        let stat_topic = format!("stat/{device_topic}/+");
        self.inner
            .client
            .subscribe(&stat_topic, QoS::AtLeastOnce)
            .await
            .map_err(ProtocolError::Mqtt)?;
        let tele_topic = format!("tele/{device_topic}/+");
        self.inner
            .client
            .subscribe(&tele_topic, QoS::AtLeastOnce)
            .await
            .map_err(ProtocolError::Mqtt)?;
        tracing::debug!(
            stat = %stat_topic,
            tele = %tele_topic,
            "Subscribed to device topics"
        );
        let (response_tx, response_rx) = mpsc::channel::<MqttMessage>(20);
        let router = Arc::new(TopicRouter::new());
        let subscription = DeviceSubscription {
            response_tx,
            router: Arc::clone(&router),
        };
        self.inner
            .subscriptions
            .write()
            .await
            .insert(device_topic, subscription);
        Ok((response_rx, router))
    }

    /// Unregisters the device and unsubscribes from its `stat/` and `tele/` topics.
    ///
    /// Unsubscribe failures are logged, not returned.
    pub(crate) async fn remove_device_subscription(&self, device_topic: &str) {
        self.inner.subscriptions.write().await.remove(device_topic);
        let stat_topic = format!("stat/{device_topic}/+");
        let tele_topic = format!("tele/{device_topic}/+");
        if let Err(e) = self.inner.client.unsubscribe(&stat_topic).await {
            tracing::warn!(topic = %stat_topic, error = %e, "Failed to unsubscribe from stat topic");
        }
        if let Err(e) = self.inner.client.unsubscribe(&tele_topic).await {
            tracing::warn!(topic = %tele_topic, error = %e, "Failed to unsubscribe from tele topic");
        }
        tracing::debug!(
            stat = %stat_topic,
            tele = %tele_topic,
            "Unsubscribed from device topics"
        );
    }

    /// Dispatches one incoming publish to discovery and to the owning device.
    ///
    /// Topics must look like `<prefix>/<device_topic>/<suffix>` with `prefix` being
    /// `stat` or `tele`; anything else is ignored. The device topic is everything between
    /// the first and the last `/`, mirroring the filters built in `add_device_subscription`.
    ///
    /// No lock is held while awaiting a bounded-channel `send`. Holding one would make
    /// `stop_discovery`, `add_device_subscription` and `remove_device_subscription` wait
    /// on whichever consumer is slow to drain its channel.
    async fn route_message(&self, topic: &str, payload: String) {
        let Some((prefix, rest)) = topic.split_once('/') else {
            return;
        };
        let Some((device_topic, suffix)) = rest.rsplit_once('/') else {
            return;
        };
        if prefix != "stat" && prefix != "tele" {
            return;
        }

        let is_discovery_topic = (prefix == "tele" && (suffix == "LWT" || suffix == "STATE"))
            || (prefix == "stat" && suffix == "STATUS");
        if is_discovery_topic {
            // Clone the sender so the read guard is dropped before the send is awaited.
            let discovery_tx = self.inner.discovery_tx.read().await.clone();
            if let Some(discovery_tx) = discovery_tx {
                tracing::debug!(
                    topic = %topic,
                    device = %device_topic,
                    "Discovered device topic"
                );
                let _ = discovery_tx.send(device_topic.to_string()).await;
            }
        }

        let is_json_response =
            prefix == "stat" && (suffix == "RESULT" || suffix.starts_with("STATUS"));
        let response_tx = {
            let subscriptions = self.inner.subscriptions.read().await;
            let Some(sub) = subscriptions.get(device_topic) else {
                return;
            };
            sub.router.route(topic, &payload);
            // Only take a sender handle when there is something to forward.
            is_json_response.then(|| sub.response_tx.clone())
        };
        let Some(response_tx) = response_tx else {
            return;
        };

        tracing::debug!(
            topic = %topic,
            device = %device_topic,
            suffix = %suffix,
            "Routing response to device"
        );
        let msg = MqttMessage::new(suffix.to_string(), payload);
        let _ = response_tx.send(msg).await;
    }

    /// Restores every device's subscriptions after a reconnect and notifies its router.
    ///
    /// This is called from the event-loop task. `AsyncClient::subscribe` waits for room
    /// in the client's bounded request channel, and only the event loop drains that
    /// channel. Awaiting it here therefore deadlocks once the channel fills: two requests
    /// per device against a capacity of 10 in `build`. The registered devices are
    /// snapshotted here, and the actual re-subscribing runs in a separate task so the
    /// event loop keeps polling.
    async fn handle_reconnection(&self) {
        let devices: Vec<(String, Arc<TopicRouter>)> = self
            .inner
            .subscriptions
            .read()
            .await
            .iter()
            .map(|(device_topic, sub)| (device_topic.clone(), Arc::clone(&sub.router)))
            .collect();
        let client = self.inner.client.clone();

        tokio::spawn(async move {
            for (device_topic, router) in &devices {
                let stat_topic = format!("stat/{device_topic}/+");
                let tele_topic = format!("tele/{device_topic}/+");
                if let Err(e) = client.subscribe(&stat_topic, QoS::AtLeastOnce).await {
                    tracing::error!(topic = %stat_topic, error = %e, "Failed to resubscribe to stat topic");
                }
                if let Err(e) = client.subscribe(&tele_topic, QoS::AtLeastOnce).await {
                    tracing::error!(topic = %tele_topic, error = %e, "Failed to resubscribe to tele topic");
                }
                tracing::debug!(
                    device = %device_topic,
                    "Resubscribed to device topics"
                );
                router.dispatch_reconnected_all();
            }
            tracing::info!(
                device_count = devices.len(),
                "Reconnection complete, all devices notified"
            );
        });
    }

    /// Tells every registered device's router that the connection was lost.
    async fn dispatch_disconnected_all(&self) {
        let subscriptions = self.inner.subscriptions.read().await;
        for (device_topic, subscription) in subscriptions.iter() {
            tracing::debug!(device = %device_topic, "Notifying device of disconnection");
            subscription.router.dispatch_disconnected_all();
        }
    }

    /// Drops all device registrations and asks the client to disconnect from the broker.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::Mqtt`] if the client rejects the disconnect request.
    pub async fn disconnect(&self) -> Result<(), ProtocolError> {
        tracing::info!(
            host = %self.inner.config.host,
            port = %self.inner.config.port,
            "Disconnecting from MQTT broker"
        );
        self.inner.subscriptions.write().await.clear();
        self.inner
            .client
            .disconnect()
            .await
            .map_err(ProtocolError::Mqtt)?;
        self.inner.connected.store(false, Ordering::Release);
        Ok(())
    }

    /// Returns the number of currently registered devices.
    pub async fn subscription_count(&self) -> usize {
        self.inner.subscriptions.read().await.len()
    }

    /// Installs a fresh discovery channel and returns its receiver.
    ///
    /// Any previous discovery channel is replaced. Device topics seen on `tele/<t>/LWT`,
    /// `tele/<t>/STATE` and `stat/<t>/STATUS` are sent to it (buffer of 100).
    pub(crate) async fn start_discovery(&self) -> mpsc::Receiver<String> {
        let (tx, rx) = mpsc::channel::<String>(100);
        *self.inner.discovery_tx.write().await = Some(tx);
        rx
    }

    /// Removes the discovery channel; later discovery topics are no longer forwarded.
    pub(crate) async fn stop_discovery(&self) {
        *self.inner.discovery_tx.write().await = None;
    }
}
impl std::fmt::Debug for MqttBroker {
    /// Shows the target address and the live connection flag; credentials are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let config = &self.inner.config;
        let mut out = f.debug_struct("MqttBroker");
        out.field("host", &config.host);
        out.field("port", &config.port);
        out.field("connected", &self.is_connected());
        out.finish()
    }
}
/// Builder for [`MqttBroker`]; obtain one with `MqttBroker::builder()`.
///
/// Only the host is mandatory; every other setting falls back to
/// `MqttBrokerConfig::default()`.
#[derive(Default, Debug)]
pub struct MqttBrokerBuilder {
    /// Settings accumulated by the setter methods.
    config: MqttBrokerConfig,
}
impl MqttBrokerBuilder {
#[must_use]
pub fn host(mut self, host: impl Into<String>) -> Self {
self.config.host = host.into();
self
}
#[must_use]
pub fn port(mut self, port: u16) -> Self {
self.config.port = port;
self
}
#[must_use]
pub fn credentials(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
self.config.credentials = Some((username.into(), password.into()));
self
}
#[must_use]
pub fn keep_alive(mut self, duration: Duration) -> Self {
self.config.keep_alive = duration;
self
}
#[must_use]
pub fn connection_timeout(mut self, duration: Duration) -> Self {
self.config.connection_timeout = duration;
self
}
#[must_use]
pub fn command_timeout(mut self, duration: Duration) -> Self {
self.config.command_timeout = duration;
self
}
pub async fn build(self) -> Result<MqttBroker, ProtocolError> {
if self.config.host.is_empty() {
return Err(ProtocolError::InvalidAddress(
"MQTT broker host is required".to_string(),
));
}
let counter = BROKER_CLIENT_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
let client_id = format!("tasmor_{}_{}", std::process::id(), counter);
let mut mqtt_options = MqttOptions::new(&client_id, &self.config.host, self.config.port);
mqtt_options.set_keep_alive(self.config.keep_alive);
mqtt_options.set_clean_session(true);
if let Some((ref username, ref password)) = self.config.credentials {
mqtt_options.set_credentials(username, password);
}
let (client, event_loop) = AsyncClient::new(mqtt_options, 10);
let inner = MqttBrokerInner {
client,
subscriptions: RwLock::new(HashMap::new()),
config: self.config.clone(),
connected: AtomicBool::new(false),
initial_connection_done: AtomicBool::new(false),
discovery_tx: RwLock::new(None),
};
let broker = MqttBroker {
inner: Arc::new(inner),
};
let broker_clone = broker.clone();
let (connack_tx, connack_rx) = oneshot::channel();
tokio::spawn(async move {
handle_broker_events(event_loop, broker_clone, Some(connack_tx)).await;
});
let timeout = self.config.connection_timeout;
match tokio::time::timeout(timeout, connack_rx).await {
Ok(Ok(())) => {
broker.inner.connected.store(true, Ordering::Release);
tracing::info!(
host = %self.config.host,
port = %self.config.port,
"Connected to MQTT broker"
);
}
Ok(Err(_)) => {
return Err(ProtocolError::ConnectionFailed(
"MQTT event loop terminated unexpectedly".to_string(),
));
}
Err(_) => {
return Err(ProtocolError::ConnectionFailed(format!(
"MQTT connection timeout after {}s",
timeout.as_secs()
)));
}
}
Ok(broker)
}
}
async fn handle_broker_events(
mut event_loop: EventLoop,
broker: MqttBroker,
connack_tx: Option<oneshot::Sender<()>>,
) {
use rumqttc::{Event, Packet};
let mut connack_tx = connack_tx;
loop {
match event_loop.poll().await {
Ok(Event::Incoming(Packet::ConnAck(connack))) => {
tracing::debug!(?connack, "MQTT broker connected");
broker.inner.connected.store(true, Ordering::Release);
if let Some(tx) = connack_tx.take() {
let _ = tx.send(());
}
if broker.inner.initial_connection_done.load(Ordering::Acquire) {
tracing::info!("MQTT broker reconnected, restoring subscriptions");
broker.handle_reconnection().await;
} else {
broker
.inner
.initial_connection_done
.store(true, Ordering::Release);
}
}
Ok(Event::Incoming(Packet::SubAck(suback))) => {
tracing::debug!(?suback, "MQTT subscription acknowledged");
}
Ok(Event::Incoming(Packet::Publish(publish))) => {
if let Ok(payload) = String::from_utf8(publish.payload.to_vec()) {
tracing::debug!(
topic = %publish.topic,
payload = %payload,
"MQTT message received"
);
broker.route_message(&publish.topic, payload).await;
}
}
Ok(Event::Incoming(Packet::Disconnect)) => {
tracing::info!("MQTT broker disconnected by server");
broker.inner.connected.store(false, Ordering::Release);
broker.dispatch_disconnected_all().await;
}
Ok(_) => {}
Err(e) => {
let was_connected = broker.inner.connected.swap(false, Ordering::AcqRel);
if was_connected {
tracing::warn!(error = %e, "MQTT connection lost, waiting for reconnection");
broker.dispatch_disconnected_all().await;
} else {
tracing::debug!(error = %e, "MQTT connection error during reconnection attempt");
}
}
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the builder/config defaults and the host validation in `build`.
    use super::*;

    #[test]
    fn builder_default_values() {
        let config = MqttBrokerBuilder::default().config;
        assert_eq!(config.port, 1883);
        assert!(config.host.is_empty());
        assert!(config.credentials.is_none());
        assert_eq!(config.keep_alive, Duration::from_secs(30));
        assert_eq!(config.connection_timeout, Duration::from_secs(10));
    }

    #[test]
    fn builder_with_host() {
        let config = MqttBrokerBuilder::default().host("192.168.1.50").config;
        assert_eq!(config.host, "192.168.1.50");
    }

    #[test]
    fn builder_with_port() {
        let config = MqttBrokerBuilder::default().port(8883).config;
        assert_eq!(config.port, 8883);
    }

    #[test]
    fn builder_with_credentials() {
        let (username, password) = MqttBrokerBuilder::default()
            .credentials("user", "pass")
            .config
            .credentials
            .expect("credentials were just set");
        assert_eq!(username, "user");
        assert_eq!(password, "pass");
    }

    #[test]
    fn builder_with_keep_alive() {
        let config = MqttBrokerBuilder::default()
            .keep_alive(Duration::from_secs(60))
            .config;
        assert_eq!(config.keep_alive, Duration::from_secs(60));
    }

    #[test]
    fn builder_with_connection_timeout() {
        let config = MqttBrokerBuilder::default()
            .connection_timeout(Duration::from_secs(5))
            .config;
        assert_eq!(config.connection_timeout, Duration::from_secs(5));
    }

    #[test]
    fn builder_with_command_timeout() {
        let config = MqttBrokerBuilder::default()
            .command_timeout(Duration::from_secs(15))
            .config;
        assert_eq!(config.command_timeout, Duration::from_secs(15));
    }

    #[test]
    fn builder_default_command_timeout() {
        let config = MqttBrokerBuilder::default().config;
        assert_eq!(config.command_timeout, Duration::from_secs(5));
    }

    #[test]
    fn builder_chain() {
        let config = MqttBrokerBuilder::default()
            .host("192.168.1.50")
            .port(8883)
            .credentials("admin", "secret")
            .keep_alive(Duration::from_secs(45))
            .connection_timeout(Duration::from_secs(15))
            .command_timeout(Duration::from_secs(10))
            .config;
        assert_eq!(config.host, "192.168.1.50");
        assert_eq!(config.port, 8883);
        assert!(config.credentials.is_some());
        assert_eq!(config.keep_alive, Duration::from_secs(45));
        assert_eq!(config.connection_timeout, Duration::from_secs(15));
        assert_eq!(config.command_timeout, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn builder_missing_host_fails() {
        let outcome = MqttBrokerBuilder::default().build().await;
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ProtocolError::InvalidAddress(_))));
    }

    #[test]
    fn config_default() {
        let MqttBrokerConfig {
            host,
            port,
            credentials,
            ..
        } = MqttBrokerConfig::default();
        assert!(host.is_empty());
        assert_eq!(port, 1883);
        assert!(credentials.is_none());
    }
}