use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use rumqttc::QoS;
use tokio::sync::{Mutex, mpsc};
use crate::command::Command;
use crate::error::ProtocolError;
use crate::protocol::response_collector::{MqttMessage, ResponseSpec, collect_responses};
use crate::protocol::{CommandResponse, Protocol};
use crate::subscription::CallbackRegistry;
use super::mqtt_broker::MqttBroker;
use super::topic_router::TopicRouter;
/// Per-device handle that sends commands over an MQTT connection and waits
/// for the matching responses.
///
/// Commands are published to `cmnd/<topic>/<command>`; responses arrive on an
/// internal channel and are collected with a per-command timeout.
pub struct SharedMqttClient {
/// Async MQTT client used to publish commands.
client: rumqttc::AsyncClient,
/// Device topic inserted into the `cmnd/<topic>/<command>` publish topic.
topic: String,
/// Channel of incoming response messages, guarded by an async mutex so it
/// can be read through `&self`.
response_rx: Arc<Mutex<mpsc::Receiver<MqttMessage>>>,
/// Router with which subscription callbacks for this topic are registered.
router: Arc<TopicRouter>,
/// Broker handle used to remove this device's subscription on disconnect
/// or drop.
broker: MqttBroker,
/// Set once `disconnect` has been called, so cleanup is not repeated.
disconnected: AtomicBool,
/// Timeout handed to the response collector for each command.
command_timeout: Duration,
}
impl SharedMqttClient {
pub(crate) fn new(
client: rumqttc::AsyncClient,
topic: String,
response_rx: mpsc::Receiver<MqttMessage>,
router: Arc<TopicRouter>,
broker: MqttBroker,
command_timeout: Duration,
) -> Self {
Self {
client,
topic,
response_rx: Arc::new(Mutex::new(response_rx)),
router,
broker,
disconnected: AtomicBool::new(false),
command_timeout,
}
}
#[must_use]
pub fn topic(&self) -> &str {
&self.topic
}
pub async fn disconnect(&self) {
if self.disconnected.swap(true, Ordering::SeqCst) {
return; }
self.broker.remove_device_subscription(&self.topic).await;
tracing::debug!(topic = %self.topic, "Device disconnected");
}
#[must_use]
pub fn is_disconnected(&self) -> bool {
self.disconnected.load(Ordering::SeqCst)
}
pub fn register_callbacks(&self, callbacks: &Arc<CallbackRegistry>) {
self.router.register(&self.topic, callbacks);
}
async fn publish_command(&self, command: &str, payload: &str) -> Result<(), ProtocolError> {
let topic = format!("cmnd/{}/{command}", self.topic);
tracing::debug!(topic = %topic, payload = %payload, "Publishing shared MQTT command");
self.client
.publish(&topic, QoS::AtLeastOnce, false, payload)
.await
.map_err(ProtocolError::Mqtt)
}
async fn drain_stale_responses(&self) {
let mut rx = self.response_rx.lock().await;
let mut count = 0;
while rx.try_recv().is_ok() {
count += 1;
}
if count > 0 {
tracing::debug!(count, "Drained stale MQTT responses");
}
}
async fn collect_command_responses(
&self,
spec: &ResponseSpec,
) -> Result<String, ProtocolError> {
let mut rx = self.response_rx.lock().await;
collect_responses(&mut rx, spec, self.command_timeout).await
}
}
impl Protocol for SharedMqttClient {
async fn send_command<C: Command + Sync>(
&self,
command: &C,
) -> Result<CommandResponse, ProtocolError> {
let cmd_name = command.mqtt_topic_suffix();
let payload = command.mqtt_payload();
let response_spec = command.response_spec();
self.drain_stale_responses().await;
self.publish_command(&cmd_name, &payload).await?;
let body = self.collect_command_responses(&response_spec).await?;
Ok(CommandResponse::new(body))
}
async fn send_raw(&self, command: &str) -> Result<CommandResponse, ProtocolError> {
let parts: Vec<&str> = command.splitn(2, ' ').collect();
let (cmd_name, payload) = match parts.as_slice() {
[name] => (*name, ""),
[name, payload] => (*name, *payload),
_ => {
return Err(ProtocolError::InvalidAddress(
"Invalid command format".to_string(),
));
}
};
self.drain_stale_responses().await;
self.publish_command(cmd_name, payload).await?;
let body = self
.collect_command_responses(&ResponseSpec::Single)
.await?;
Ok(CommandResponse::new(body))
}
}
impl Drop for SharedMqttClient {
    /// Best-effort cleanup when the client is dropped without an explicit
    /// `disconnect`: the broker subscription is removed on a spawned task if
    /// a tokio runtime is available, otherwise a warning is logged.
    fn drop(&mut self) {
        // `drop` has exclusive access, so the flag can be read directly.
        let already_disconnected = *self.disconnected.get_mut();
        if already_disconnected {
            return;
        }

        // Owned copies are needed because the spawned task may outlive `self`.
        let device_topic = self.topic.clone();
        let broker_handle = self.broker.clone();

        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                runtime.spawn(async move {
                    broker_handle
                        .remove_device_subscription(&device_topic)
                        .await;
                    tracing::debug!(topic = %device_topic, "Device cleanup via Drop");
                });
            }
            Err(_) => {
                tracing::warn!(
                    topic = %self.topic,
                    "No tokio runtime available for async cleanup in Drop"
                );
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `SharedMqttClient` satisfies the `Protocol`
    /// trait bound; the test body has nothing to assert at runtime.
    #[test]
    fn shared_client_implements_protocol() {
        fn requires_protocol<P: Protocol>() {}

        requires_protocol::<SharedMqttClient>();
    }
}