use std::collections::HashMap;
use std::sync::{Arc, Weak};
use parking_lot::RwLock;
use crate::state::StateChange;
use crate::subscription::CallbackRegistry;
use crate::telemetry::{SensorData, TelemetryState};
use crate::types::PowerState;
/// Routes incoming MQTT messages to per-device callback registries.
///
/// Registries are keyed by device topic and held through [`Weak`] references,
/// so the router never keeps a device's callbacks alive by itself. Entries whose
/// registry has been dropped stay in the map until `cleanup` or `unregister`
/// removes them.
#[derive(Default, Debug)]
pub struct TopicRouter {
    /// Device topic (the middle segment of `<prefix>/<device>/<subtopic>`)
    /// mapped to that device's callbacks.
    subscribers: RwLock<HashMap<String, Weak<CallbackRegistry>>>,
}
impl TopicRouter {
    /// Creates an empty router with no registered devices.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `callbacks` as the handler for messages addressed to `device_topic`.
    ///
    /// Only a [`Weak`] reference is stored, so registration does not extend the
    /// registry's lifetime. Registering a topic that is already present replaces
    /// the previous registration.
    pub fn register(&self, device_topic: impl Into<String>, callbacks: &Arc<CallbackRegistry>) {
        let topic = device_topic.into();
        tracing::debug!(topic = %topic, "Registering device for routing");
        self.subscribers
            .write()
            .insert(topic, Arc::downgrade(callbacks));
    }

    /// Removes the registration for `device_topic`.
    ///
    /// Returns `true` if an entry was present, whether its registry was still
    /// alive or already dropped.
    pub fn unregister(&self, device_topic: &str) -> bool {
        tracing::debug!(topic = %device_topic, "Unregistering device from routing");
        self.subscribers.write().remove(device_topic).is_some()
    }

    /// Routes one MQTT message to the device it addresses.
    ///
    /// `topic` is split as `<prefix>/<device>/<subtopic>`. Returns `true` when the
    /// topic parsed and a live registry exists for `<device>`; the message was then
    /// handed to the dispatcher, even if its payload turns out to be ignored.
    /// Returns `false` for unparseable topics and for unknown or dropped devices.
    pub fn route(&self, topic: &str, payload: &str) -> bool {
        let Some(parsed) = ParsedTopic::parse(topic) else {
            tracing::trace!(topic = %topic, "Ignoring unparseable topic");
            return false;
        };
        // Upgrade inside a short scope so the read lock is released before any
        // callback runs.
        let callbacks = {
            let subscribers = self.subscribers.read();
            subscribers.get(parsed.device_topic).and_then(Weak::upgrade)
        };
        let Some(callbacks) = callbacks else {
            tracing::trace!(
                topic = %topic,
                device = %parsed.device_topic,
                "No registered device for topic"
            );
            return false;
        };
        dispatch_message(&callbacks, &parsed, payload);
        true
    }

    /// Removes every registration whose callback registry has been dropped.
    pub fn cleanup(&self) {
        self.subscribers.write().retain(|topic, weak| {
            let alive = weak.strong_count() > 0;
            if !alive {
                tracing::debug!(topic = %topic, "Cleaning up dropped device");
            }
            alive
        });
    }

    /// Number of registrations, including ones whose registry has been dropped
    /// but not yet removed by `cleanup`.
    #[must_use]
    pub fn device_count(&self) -> usize {
        self.subscribers.read().len()
    }

    /// Number of registrations whose callback registry is still alive.
    #[must_use]
    pub fn active_device_count(&self) -> usize {
        self.subscribers
            .read()
            .values()
            .filter(|weak| weak.strong_count() > 0)
            .count()
    }

    /// Invokes the reconnected callbacks of every live registered device.
    ///
    /// The callbacks run after the subscriber lock is released. A callback may
    /// therefore call back into this router (for example `register` or
    /// `unregister`) without deadlocking.
    pub fn dispatch_reconnected_all(&self) {
        for (topic, callbacks) in self.live_subscribers() {
            tracing::debug!(device = %topic, "Dispatching reconnected event");
            callbacks.dispatch_reconnected();
        }
    }

    /// Invokes the disconnected callbacks of every live registered device.
    ///
    /// The callbacks run after the subscriber lock is released. A callback may
    /// therefore call back into this router (for example `register` or
    /// `unregister`) without deadlocking.
    pub fn dispatch_disconnected_all(&self) {
        for (topic, callbacks) in self.live_subscribers() {
            tracing::debug!(device = %topic, "Dispatching disconnected event");
            callbacks.dispatch_disconnected();
        }
    }

    /// Snapshots every registration whose registry is still alive.
    ///
    /// The read lock is held only while the snapshot is built; `parking_lot`
    /// locks are not reentrant, so callbacks must never run under it. The
    /// returned strong references keep each registry alive while events are
    /// delivered.
    fn live_subscribers(&self) -> Vec<(String, Arc<CallbackRegistry>)> {
        self.subscribers
            .read()
            .iter()
            .filter_map(|(topic, weak)| {
                weak.upgrade()
                    .map(|callbacks| (topic.clone(), callbacks))
            })
            .collect()
    }
}
/// Decodes one MQTT message for a device and forwards the resulting events to
/// that device's callback registry.
///
/// Handled `<prefix>/<subtopic>` combinations:
/// - `stat` + `POWER` / `POWER<n>`: a single power change.
/// - `stat` + `RESULT`: JSON decoded by `parse_result_payload`.
/// - `tele` + `STATE`: JSON decoded as [`TelemetryState`].
/// - `tele` + `SENSOR`: JSON decoded as [`SensorData`]; nothing is emitted when it
///   yields no changes.
/// - `tele` + `LWT`: `Offline` fires the disconnected callbacks; `Online` is only
///   logged. NOTE(review): the reconnected callbacks are not fired on `Online`;
///   confirm this asymmetry is intended.
///
/// Payloads that fail to parse are dropped silently. Any other topic is logged at
/// trace level and ignored.
fn dispatch_message(callbacks: &CallbackRegistry, parsed: &ParsedTopic<'_>, payload: &str) {
    let device = parsed.device_topic;
    let subtopic = parsed.subtopic;

    match parsed.prefix {
        "stat" if subtopic.starts_with("POWER") => {
            let Some(change) = parse_power_topic(subtopic, payload) else {
                return;
            };
            tracing::debug!(
                device = %device,
                subtopic = %subtopic,
                payload = %payload,
                "Dispatching power change"
            );
            callbacks.dispatch(&change);
        }
        "stat" if subtopic == "RESULT" => {
            let Some(changes) = parse_result_payload(payload) else {
                return;
            };
            tracing::debug!(
                device = %device,
                payload = %payload,
                "Dispatching result changes"
            );
            for change in changes {
                callbacks.dispatch(&change);
            }
        }
        "tele" if subtopic == "STATE" => {
            let Ok(state) = serde_json::from_str::<TelemetryState>(payload) else {
                return;
            };
            let changes = state.to_state_changes();
            tracing::debug!(
                device = %device,
                change_count = changes.len(),
                "Dispatching telemetry state changes"
            );
            for change in changes {
                callbacks.dispatch(&change);
            }
        }
        "tele" if subtopic == "SENSOR" => {
            let Ok(sensor) = serde_json::from_str::<SensorData>(payload) else {
                return;
            };
            let changes = sensor.to_state_changes();
            // Sensor reports often carry nothing the callbacks care about; stay quiet then.
            if changes.is_empty() {
                return;
            }
            tracing::debug!(
                device = %device,
                change_count = changes.len(),
                "Dispatching sensor state changes"
            );
            for change in changes {
                callbacks.dispatch(&change);
            }
        }
        "tele" if subtopic == "LWT" => match payload {
            "Online" => {
                tracing::debug!(device = %device, "Device came online");
            }
            "Offline" => {
                tracing::debug!(device = %device, "Device went offline");
                callbacks.dispatch_disconnected();
            }
            _ => {}
        },
        _ => {
            tracing::trace!(
                device = %device,
                prefix = %parsed.prefix,
                subtopic = %subtopic,
                "Ignoring unhandled topic type"
            );
        }
    }
}
/// An MQTT topic split into its first three `/`-separated segments.
///
/// Every segment borrows from the original topic string. Segments beyond the
/// third are ignored.
#[derive(Debug)]
struct ParsedTopic<'a> {
    /// First segment, e.g. `stat` or `tele`.
    prefix: &'a str,
    /// Second segment; the router uses it as the device lookup key.
    device_topic: &'a str,
    /// Third segment, e.g. `POWER`, `RESULT`, `STATE`.
    subtopic: &'a str,
}

impl<'a> ParsedTopic<'a> {
    /// Splits `topic` on `/`, returning `None` when it has fewer than three segments.
    ///
    /// Empty segments are kept verbatim, so `"//"` yields three empty strings.
    fn parse(topic: &'a str) -> Option<Self> {
        // Pull segments lazily; running out early means the topic is too short.
        let mut segments = topic.split('/');
        let prefix = segments.next()?;
        let device_topic = segments.next()?;
        let subtopic = segments.next()?;
        Some(Self {
            prefix,
            device_topic,
            subtopic,
        })
    }
}
/// Converts a `POWER` / `POWER<n>` subtopic and its payload into a power change.
///
/// A bare `POWER` maps to index 1; `POWER<n>` uses the numeric suffix `n`.
/// Returns `None` in any of these cases:
/// - the subtopic does not start with `POWER`;
/// - the suffix does not parse as a number;
/// - the payload does not parse as a [`PowerState`].
fn parse_power_topic(subtopic: &str, payload: &str) -> Option<StateChange> {
    let suffix = subtopic.strip_prefix("POWER")?;
    let index = match suffix {
        "" => 1,
        digits => digits.parse().ok()?,
    };
    let state: PowerState = payload.parse().ok()?;
    Some(StateChange::Power { index, state })
}
/// Decodes a `stat/<device>/RESULT` JSON payload into state changes.
///
/// The payload is interpreted as a [`TelemetryState`]. Returns `None` when the
/// JSON does not parse or produces no changes, so callers never receive an
/// empty list.
fn parse_result_payload(payload: &str) -> Option<Vec<StateChange>> {
    serde_json::from_str::<TelemetryState>(payload)
        .ok()
        .map(|state| state.to_state_changes())
        .filter(|changes| !changes.is_empty())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a registry whose power-changed callback increments the returned counter.
    fn counting_power_registry() -> (Arc<CallbackRegistry>, Arc<AtomicU32>) {
        let registry = Arc::new(CallbackRegistry::new());
        let counter = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&counter);
        registry.on_power_changed(move |_, _| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        (registry, counter)
    }

    /// Reads a test counter using the same ordering the callbacks write with.
    fn hits_of(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::SeqCst)
    }

    /// Parses `topic` and checks each of its three segments.
    fn assert_parsed(topic: &str, prefix: &str, device: &str, subtopic: &str) {
        let parsed = ParsedTopic::parse(topic).unwrap();
        assert_eq!(parsed.prefix, prefix);
        assert_eq!(parsed.device_topic, device);
        assert_eq!(parsed.subtopic, subtopic);
    }

    #[test]
    fn parse_topic_valid() {
        assert_parsed("stat/tasmota_bedroom/POWER", "stat", "tasmota_bedroom", "POWER");
    }

    #[test]
    fn parse_topic_tele() {
        assert_parsed("tele/living_room/STATE", "tele", "living_room", "STATE");
    }

    #[test]
    fn parse_topic_invalid() {
        assert!(ParsedTopic::parse("invalid").is_none());
        assert!(ParsedTopic::parse("only/two").is_none());
    }

    #[test]
    fn parse_power_topic_simple() {
        let change = parse_power_topic("POWER", "ON").expect("bare POWER should parse");
        assert!(matches!(
            change,
            StateChange::Power {
                index: 1,
                state: PowerState::On
            }
        ));
    }

    #[test]
    fn parse_power_topic_indexed() {
        let change = parse_power_topic("POWER3", "OFF").expect("POWER3 should parse");
        assert!(matches!(
            change,
            StateChange::Power {
                index: 3,
                state: PowerState::Off
            }
        ));
    }

    #[test]
    fn parse_power_topic_invalid() {
        assert!(parse_power_topic("POWER", "INVALID").is_none());
        assert!(parse_power_topic("INVALID", "ON").is_none());
    }

    #[test]
    fn router_register_and_route() {
        let router = TopicRouter::new();
        let (registry, counter) = counting_power_registry();
        router.register("bedroom", &registry);
        assert_eq!(router.device_count(), 1);
        assert!(router.route("stat/bedroom/POWER", "ON"));
        assert_eq!(hits_of(&counter), 1);
    }

    #[test]
    fn router_unregistered_device() {
        let router = TopicRouter::new();
        assert!(!router.route("stat/unknown/POWER", "ON"));
    }

    #[test]
    fn router_unregister() {
        let router = TopicRouter::new();
        let registry = Arc::new(CallbackRegistry::new());
        router.register("bedroom", &registry);
        assert_eq!(router.device_count(), 1);
        assert!(router.unregister("bedroom"));
        assert_eq!(router.device_count(), 0);
        assert!(!router.route("stat/bedroom/POWER", "ON"));
    }

    #[test]
    fn router_cleanup_dropped_device() {
        let router = TopicRouter::new();
        let registry = Arc::new(CallbackRegistry::new());
        router.register("temporary", &registry);
        assert_eq!(router.active_device_count(), 1);
        // Dropping the only strong reference leaves a dead entry behind.
        drop(registry);
        assert_eq!(router.device_count(), 1);
        assert_eq!(router.active_device_count(), 0);
        router.cleanup();
        assert_eq!(router.device_count(), 0);
    }

    #[test]
    fn router_route_telemetry_state() {
        let router = TopicRouter::new();
        let registry = Arc::new(CallbackRegistry::new());
        let changes_seen = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&changes_seen);
        registry.on_state_changed(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        router.register("living_room", &registry);
        assert!(router.route(
            "tele/living_room/STATE",
            r#"{"POWER":"ON","Dimmer":75}"#
        ));
        assert!(hits_of(&changes_seen) >= 1);
    }

    #[test]
    fn router_route_lwt_offline() {
        let router = TopicRouter::new();
        let registry = Arc::new(CallbackRegistry::new());
        let disconnects = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&disconnects);
        registry.on_disconnected(move || {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        router.register("device", &registry);
        assert!(router.route("tele/device/LWT", "Offline"));
        assert_eq!(hits_of(&disconnects), 1);
    }

    #[test]
    fn router_multiple_devices() {
        let router = TopicRouter::new();
        let (first, first_hits) = counting_power_registry();
        let (second, second_hits) = counting_power_registry();
        router.register("device1", &first);
        router.register("device2", &second);

        router.route("stat/device1/POWER", "ON");
        assert_eq!((hits_of(&first_hits), hits_of(&second_hits)), (1, 0));

        router.route("stat/device2/POWER", "OFF");
        assert_eq!((hits_of(&first_hits), hits_of(&second_hits)), (1, 1));
    }

    #[test]
    fn router_replace_registration() {
        let router = TopicRouter::new();
        let (stale, stale_hits) = counting_power_registry();
        let (fresh, fresh_hits) = counting_power_registry();
        router.register("device", &stale);
        // A second registration under the same topic replaces the first.
        router.register("device", &fresh);
        router.route("stat/device/POWER", "ON");
        assert_eq!(hits_of(&stale_hits), 0);
        assert_eq!(hits_of(&fresh_hits), 1);
    }
}