use std::collections::HashMap;
use asyn_rs::error::{AsynError, AsynResult};
use asyn_rs::param::ParamType;
use asyn_rs::port::{PortDriver, PortDriverBase, PortFlags};
use asyn_rs::user::AsynUser;
use tokio::sync::mpsc;
use crate::address::TopicAddress;
use crate::config::{MqttConfig, QoS};
use crate::error::MqttError;
use crate::payload::{DecodedValue, encode_payload};
#[derive(Debug, Clone)]
pub struct PublishRequest {
pub topic: String,
pub payload: String,
pub qos: QoS,
pub retained: bool,
}
pub const PARAM_CONNECTED: &str = "_MQTT_CONNECTED";
pub struct MqttDriver {
base: PortDriverBase,
registry: HashMap<String, (usize, TopicAddress)>,
topic_map: HashMap<String, Vec<(usize, TopicAddress)>>,
reason_to_addr: Vec<Option<TopicAddress>>,
publish_tx: mpsc::UnboundedSender<PublishRequest>,
default_qos: QoS,
pub connected_param: usize,
}
impl MqttDriver {
pub fn new(
port_name: &str,
config: &MqttConfig,
topics: Vec<TopicAddress>,
publish_tx: mpsc::UnboundedSender<PublishRequest>,
) -> Self {
let flags = PortFlags {
can_block: true,
..PortFlags::default()
};
let mut base = PortDriverBase::new(port_name, 1, flags);
let mut registry = HashMap::new();
let mut topic_map: HashMap<String, Vec<(usize, TopicAddress)>> = HashMap::new();
let mut reason_to_addr = Vec::new();
let connected_param = base
.create_param(PARAM_CONNECTED, ParamType::Int32)
.expect("failed to create connected param");
base.set_int32_param(connected_param, 0, 0).unwrap();
for addr in topics {
let drv_info = addr.to_drv_info();
let param_type = addr.param_type();
let idx = base
.create_param(&drv_info, param_type)
.expect("failed to create param");
if reason_to_addr.len() <= idx {
reason_to_addr.resize_with(idx + 1, || None);
}
reason_to_addr[idx] = Some(addr.clone());
topic_map
.entry(addr.topic.clone())
.or_default()
.push((idx, addr.clone()));
registry.insert(drv_info, (idx, addr));
}
Self {
base,
registry,
topic_map,
reason_to_addr,
publish_tx,
default_qos: config.qos,
connected_param,
}
}
pub fn subscribed_topics(&self) -> Vec<String> {
self.topic_map.keys().cloned().collect()
}
pub fn topic_map(&self) -> &HashMap<String, Vec<(usize, TopicAddress)>> {
&self.topic_map
}
fn publish_value(&self, reason: usize, value: &DecodedValue) -> AsynResult<()> {
let addr = self
.reason_to_addr
.get(reason)
.and_then(|a| a.as_ref())
.ok_or_else(|| AsynError::ParamNotFound(format!("reason {reason}")))?;
let payload = encode_payload(value, addr);
self.publish_tx
.send(PublishRequest {
topic: addr.topic.clone(),
payload,
qos: self.default_qos,
retained: false,
})
.map_err(|_| MqttError::PublishChannelClosed)?;
Ok(())
}
}
impl PortDriver for MqttDriver {
fn base(&self) -> &PortDriverBase {
&self.base
}
fn base_mut(&mut self) -> &mut PortDriverBase {
&mut self.base
}
fn drv_user_create(&self, drv_info: &str) -> AsynResult<usize> {
if let Some((idx, _)) = self.registry.get(drv_info) {
return Ok(*idx);
}
self.base()
.params
.find_param(drv_info)
.ok_or_else(|| AsynError::ParamNotFound(drv_info.to_string()))
}
fn write_int32(&mut self, user: &mut AsynUser, value: i32) -> AsynResult<()> {
self.publish_value(user.reason, &DecodedValue::Int32(value))?;
self.base.params.set_int32(user.reason, user.addr, value)?;
self.base.call_param_callbacks(user.addr)
}
fn write_float64(&mut self, user: &mut AsynUser, value: f64) -> AsynResult<()> {
self.publish_value(user.reason, &DecodedValue::Float64(value))?;
self.base
.params
.set_float64(user.reason, user.addr, value)?;
self.base.call_param_callbacks(user.addr)
}
fn write_octet(&mut self, user: &mut AsynUser, data: &[u8]) -> AsynResult<()> {
let s = String::from_utf8_lossy(data).into_owned();
self.publish_value(user.reason, &DecodedValue::String(s.clone()))?;
self.base.params.set_string(user.reason, user.addr, s)?;
self.base.call_param_callbacks(user.addr)
}
fn write_uint32_digital(
&mut self,
user: &mut AsynUser,
value: u32,
mask: u32,
) -> AsynResult<()> {
if mask != 0xFFFF_FFFF {
self.base.params.get_uint32_strict(user.reason, user.addr)?;
}
self.base
.params
.set_uint32(user.reason, user.addr, value, mask, 0)?;
let full_val = self
.base
.params
.get_uint32(user.reason, user.addr)
.unwrap_or(value & mask);
self.publish_value(user.reason, &DecodedValue::UInt32(full_val))?;
self.base.call_param_callbacks(user.addr)
}
fn write_int32_array(&mut self, user: &AsynUser, data: &[i32]) -> AsynResult<()> {
self.publish_value(user.reason, &DecodedValue::Int32Array(data.to_vec()))?;
self.base
.params
.set_int32_array(user.reason, user.addr, data.to_vec())?;
self.base.call_param_callbacks(user.addr)
}
fn read_int32_array(&mut self, user: &AsynUser, buf: &mut [i32]) -> AsynResult<usize> {
let data = self.base.params.get_int32_array(user.reason, user.addr)?;
let n = data.len().min(buf.len());
buf[..n].copy_from_slice(&data[..n]);
Ok(n)
}
fn write_float64_array(&mut self, user: &AsynUser, data: &[f64]) -> AsynResult<()> {
self.publish_value(user.reason, &DecodedValue::Float64Array(data.to_vec()))?;
self.base
.params
.set_float64_array(user.reason, user.addr, data.to_vec())?;
self.base.call_param_callbacks(user.addr)
}
fn read_float64_array(&mut self, user: &AsynUser, buf: &mut [f64]) -> AsynResult<usize> {
let data = self.base.params.get_float64_array(user.reason, user.addr)?;
let n = data.len().min(buf.len());
buf[..n].copy_from_slice(&data[..n]);
Ok(n)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_driver(topics: &[&str]) -> (MqttDriver, mpsc::UnboundedReceiver<PublishRequest>) {
let (tx, rx) = mpsc::unbounded_channel();
let config = MqttConfig::default();
let addrs: Vec<TopicAddress> = topics
.iter()
.map(|s| TopicAddress::parse(s).unwrap())
.collect();
let driver = MqttDriver::new("TEST", &config, addrs, tx);
(driver, rx)
}
#[test]
fn drv_user_create_finds_registered_topics() {
let (driver, _rx) = make_driver(&[
"FLAT:INT test/int_topic",
"FLAT:FLOAT test/float_topic",
"JSON:FLOAT sensors/data humidity",
]);
assert!(driver.drv_user_create("FLAT:INT test/int_topic").is_ok());
assert!(
driver
.drv_user_create("FLAT:FLOAT test/float_topic")
.is_ok()
);
assert!(
driver
.drv_user_create("JSON:FLOAT sensors/data humidity")
.is_ok()
);
}
#[test]
fn drv_user_create_rejects_unknown() {
let (driver, _rx) = make_driver(&["FLAT:INT test/topic"]);
assert!(driver.drv_user_create("FLAT:FLOAT other/topic").is_err());
}
#[test]
fn subscribed_topics_returns_unique_mqtt_topics() {
let (driver, _rx) = make_driver(&[
"FLAT:INT test/topic",
"FLAT:FLOAT test/topic",
"FLAT:STRING other/topic",
]);
let mut topics = driver.subscribed_topics();
topics.sort();
assert_eq!(topics, vec!["other/topic", "test/topic"]);
}
#[test]
fn write_int32_sends_publish_request() {
let (mut driver, mut rx) = make_driver(&["FLAT:INT test/int_topic"]);
let reason = driver.drv_user_create("FLAT:INT test/int_topic").unwrap();
let mut user = AsynUser::new(reason);
driver.write_int32(&mut user, 42).unwrap();
let req = rx.try_recv().unwrap();
assert_eq!(req.topic, "test/int_topic");
assert_eq!(req.payload, "42");
}
#[test]
fn write_float64_sends_publish_request() {
let (mut driver, mut rx) = make_driver(&["FLAT:FLOAT test/float_topic"]);
let reason = driver
.drv_user_create("FLAT:FLOAT test/float_topic")
.unwrap();
let mut user = AsynUser::new(reason);
driver.write_float64(&mut user, 3.15).unwrap();
let req = rx.try_recv().unwrap();
assert_eq!(req.topic, "test/float_topic");
assert_eq!(req.payload, "3.150000");
}
#[test]
fn write_octet_sends_publish_request() {
let (mut driver, mut rx) = make_driver(&["FLAT:STRING test/str_topic"]);
let reason = driver
.drv_user_create("FLAT:STRING test/str_topic")
.unwrap();
let mut user = AsynUser::new(reason);
driver.write_octet(&mut user, b"hello").unwrap();
let req = rx.try_recv().unwrap();
assert_eq!(req.topic, "test/str_topic");
assert_eq!(req.payload, "hello");
}
#[test]
fn masked_digital_write_rejected_when_value_uninitialized() {
let (mut driver, mut rx) = make_driver(&["FLAT:DIGITAL test/bits"]);
let reason = driver.drv_user_create("FLAT:DIGITAL test/bits").unwrap();
let mut user = AsynUser::new(reason);
let r = driver.write_uint32_digital(&mut user, 0x0005, 0x000f);
assert!(
matches!(r, Err(AsynError::ParamUndefined(_))),
"masked write on uninitialized value must be rejected, got {r:?}"
);
assert!(
rx.try_recv().is_err(),
"a rejected masked write must not publish"
);
}
#[test]
fn masked_digital_write_merges_after_current_value_known() {
let (mut driver, mut rx) = make_driver(&["FLAT:DIGITAL test/bits"]);
let reason = driver.drv_user_create("FLAT:DIGITAL test/bits").unwrap();
let mut user = AsynUser::new(reason);
driver
.write_uint32_digital(&mut user, 0x00f0, 0xffff_ffff)
.unwrap();
let req = rx.try_recv().unwrap();
assert_eq!(req.topic, "test/bits");
assert_eq!(req.payload, "240");
driver
.write_uint32_digital(&mut user, 0x0005, 0x000f)
.unwrap();
let req = rx.try_recv().unwrap();
assert_eq!(req.payload, "245");
}
#[test]
fn topic_map_groups_by_mqtt_topic() {
let (driver, _rx) = make_driver(&[
"FLAT:INT test/shared",
"FLAT:FLOAT test/shared",
"FLAT:STRING test/other",
]);
assert_eq!(driver.topic_map()["test/shared"].len(), 2);
assert_eq!(driver.topic_map()["test/other"].len(), 1);
}
}