use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use rumqttc::QoS;
use tokio::sync::RwLock;
use crate::device::Device;
use crate::error::{Error, ProtocolError};
use crate::protocol::{MqttBroker, SharedMqttClient};
use crate::state::DeviceState;
/// Discovery window used when [`DiscoveryOptions`] carries no explicit timeout.
const DEFAULT_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);

/// Optional settings for MQTT device discovery.
///
/// Every field is optional. The accessor methods substitute a default for
/// anything left unset: a 5 second timeout, port 1883 and no credentials.
///
/// `Debug` is implemented by hand (not derived) so that the password is never
/// written to logs or panic messages.
#[derive(Clone, Default)]
pub struct DiscoveryOptions {
    /// How long discovery should listen; `None` selects the default.
    timeout: Option<Duration>,
    /// `(username, password)` pair, if authentication is wanted.
    credentials: Option<(String, String)>,
    /// Broker port; `None` selects the default.
    port: Option<u16>,
}

impl std::fmt::Debug for DiscoveryOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the username (handy when diagnosing connection problems) but
        // never print the password itself.
        let credentials = self
            .credentials
            .as_ref()
            .map(|(username, _password)| (username.as_str(), "<redacted>"));

        f.debug_struct("DiscoveryOptions")
            .field("timeout", &self.timeout)
            .field("credentials", &credentials)
            .field("port", &self.port)
            .finish()
    }
}

impl DiscoveryOptions {
    /// Port reported by [`DiscoveryOptions::port`] when none was configured
    /// (1883, the standard MQTT port).
    const DEFAULT_PORT: u16 = 1883;

    /// Creates options with nothing set; identical to `Self::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how long discovery should listen.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the username and password.
    #[must_use]
    pub fn with_credentials(
        mut self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        self.credentials = Some((username.into(), password.into()));
        self
    }

    /// Sets the broker port.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = Some(port);
        self
    }

    /// Returns the configured timeout, or 5 seconds when none was set.
    #[must_use]
    pub fn timeout(&self) -> Duration {
        self.timeout.unwrap_or(DEFAULT_DISCOVERY_TIMEOUT)
    }

    /// Returns the configured `(username, password)` pair as borrowed strings,
    /// or `None` when no credentials were set.
    #[must_use]
    pub fn credentials(&self) -> Option<(&str, &str)> {
        self.credentials
            .as_ref()
            .map(|(username, password)| (username.as_str(), password.as_str()))
    }

    /// Returns the configured port, or 1883 when none was set.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.port.unwrap_or(Self::DEFAULT_PORT)
    }
}
impl MqttBroker {
pub async fn discover_devices(
&self,
timeout: Duration,
) -> Result<Vec<(Device<SharedMqttClient>, DeviceState)>, Error> {
tracing::info!(
host = %self.host(),
port = %self.port(),
timeout_secs = timeout.as_secs(),
"Starting MQTT device discovery"
);
self.client()
.subscribe("tele/+/LWT", QoS::AtMostOnce)
.await
.map_err(ProtocolError::Mqtt)?;
self.client()
.subscribe("tele/+/STATE", QoS::AtMostOnce)
.await
.map_err(ProtocolError::Mqtt)?;
self.client()
.subscribe("stat/+/STATUS", QoS::AtMostOnce)
.await
.map_err(ProtocolError::Mqtt)?;
tracing::debug!("Subscribed to discovery topics");
self.client()
.publish("cmnd/tasmotas/Status", QoS::AtMostOnce, false, "0")
.await
.map_err(ProtocolError::Mqtt)?;
tracing::debug!("Sent broadcast Status command to trigger device responses");
let topics = self.collect_device_topics(timeout).await;
let _ = self.client().unsubscribe("tele/+/LWT").await;
let _ = self.client().unsubscribe("tele/+/STATE").await;
let _ = self.client().unsubscribe("stat/+/STATUS").await;
tracing::info!(count = topics.len(), "Discovered device topics");
if topics.is_empty() {
return Ok(Vec::new());
}
let mut devices = Vec::with_capacity(topics.len());
for topic in topics {
tracing::debug!(topic = %topic, "Creating device for discovered topic");
match self.create_device_for_topic(&topic).await {
Ok(device_and_state) => {
tracing::info!(topic = %topic, "Successfully created device");
devices.push(device_and_state);
}
Err(e) => {
tracing::warn!(topic = %topic, error = %e, "Failed to create device, skipping");
}
}
}
tracing::info!(
discovered = devices.len(),
"MQTT device discovery completed"
);
Ok(devices)
}
async fn collect_device_topics(&self, timeout: Duration) -> HashSet<String> {
let mut discovery_rx = self.start_discovery().await;
let discovered_topics: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
let topics_clone = discovered_topics.clone();
let collector = tokio::spawn(async move {
while let Some(topic) = discovery_rx.recv().await {
topics_clone.write().await.insert(topic);
}
});
tokio::time::sleep(timeout).await;
self.stop_discovery().await;
collector.abort();
discovered_topics.read().await.clone()
}
async fn create_device_for_topic(
&self,
topic: &str,
) -> Result<(Device<SharedMqttClient>, DeviceState), Error> {
self.device(topic).build().await
}
}
/// Builds an [`MqttBroker`] for `host` and runs device discovery on it.
///
/// `options` supplies the port, optional credentials and the discovery
/// timeout; passing `None` uses `DiscoveryOptions::default()`.
///
/// On success the broker is returned together with the discovered devices.
///
/// # Errors
///
/// Propagates any error from building the broker or from
/// [`MqttBroker::discover_devices`].
pub async fn discover_devices(
    host: &str,
    options: Option<DiscoveryOptions>,
) -> Result<(MqttBroker, Vec<(Device<SharedMqttClient>, DeviceState)>), Error> {
    let opts = options.unwrap_or_default();

    let setup = MqttBroker::builder().host(host).port(opts.port());
    // Credentials are only handed to the builder when the caller set them.
    let setup = match opts.credentials() {
        Some((username, password)) => setup.credentials(username, password),
        None => setup,
    };

    let broker = setup.build().await?;
    let found = broker.discover_devices(opts.timeout()).await?;
    Ok((broker, found))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unset options report the built-in defaults.
    #[test]
    fn discovery_options_default() {
        let defaults = DiscoveryOptions::default();

        assert_eq!(defaults.timeout(), Duration::from_secs(5));
        assert_eq!(defaults.credentials(), None);
        assert_eq!(defaults.port(), 1883);
    }

    /// `new()` starts from the same default timeout.
    #[test]
    fn discovery_options_new() {
        let fresh = DiscoveryOptions::new();

        assert_eq!(fresh.timeout(), Duration::from_secs(5));
    }

    /// An explicit timeout replaces the default.
    #[test]
    fn discovery_options_with_timeout() {
        let ten_seconds = Duration::from_secs(10);
        let configured = DiscoveryOptions::new().with_timeout(ten_seconds);

        assert_eq!(configured.timeout(), Duration::from_secs(10));
    }

    /// Credentials are stored and handed back as borrowed strings.
    #[test]
    fn discovery_options_with_credentials() {
        let configured = DiscoveryOptions::new().with_credentials("user", "pass");

        assert_eq!(configured.credentials(), Some(("user", "pass")));
    }

    /// An explicit port replaces the default.
    #[test]
    fn discovery_options_with_port() {
        let configured = DiscoveryOptions::new().with_port(8883);

        assert_eq!(configured.port(), 8883);
    }

    /// All builder methods can be chained and each setting is retained.
    #[test]
    fn discovery_options_chained() {
        let configured = DiscoveryOptions::new()
            .with_timeout(Duration::from_secs(15))
            .with_credentials("mqtt_user", "mqtt_pass")
            .with_port(1884);

        assert_eq!(configured.timeout(), Duration::from_secs(15));
        assert_eq!(configured.credentials(), Some(("mqtt_user", "mqtt_pass")));
        assert_eq!(configured.port(), 1884);
    }
}