mqtt-rs 0.16.2

MQTT driver for epics-rs — publish/subscribe MQTT topics as EPICS records
Documentation
use std::collections::HashMap;

use asyn_rs::error::{AsynError, AsynResult};
use asyn_rs::param::ParamType;
use asyn_rs::port::{PortDriver, PortDriverBase, PortFlags};
use asyn_rs::user::AsynUser;
use tokio::sync::mpsc;

use crate::address::TopicAddress;
use crate::config::{MqttConfig, QoS};
use crate::error::MqttError;
use crate::payload::{DecodedValue, encode_payload};

/// Request to publish a message to the MQTT broker.
#[derive(Debug, Clone)]
pub struct PublishRequest {
    pub topic: String,
    pub payload: String,
    pub qos: QoS,
    pub retained: bool,
}

/// MQTT PortDriver implementation.
///
/// Maps MQTT topics to asyn parameters. Incoming MQTT messages update the param
/// cache and fire I/O Intr callbacks. EPICS writes are published to the broker
/// via an async channel.
/// Parameter index for the MQTT connection status.
pub const PARAM_CONNECTED: &str = "_MQTT_CONNECTED";

pub struct MqttDriver {
    base: PortDriverBase,
    /// drvInfo string -> (param index, address)
    registry: HashMap<String, (usize, TopicAddress)>,
    /// MQTT topic -> list of (param index, address)
    topic_map: HashMap<String, Vec<(usize, TopicAddress)>>,
    /// param index -> topic address (for O(1) lookup on writes)
    reason_to_addr: Vec<Option<TopicAddress>>,
    /// Channel to send publish requests to the event loop
    publish_tx: mpsc::UnboundedSender<PublishRequest>,
    /// Default QoS for publishing
    default_qos: QoS,
    /// Param index for connection status (0=disconnected, 1=connected)
    pub connected_param: usize,
}

impl MqttDriver {
    /// Create a new MQTT driver with pre-declared topic addresses.
    ///
    /// All topics must be declared upfront because `drv_user_create(&self)`
    /// cannot mutate the driver to create new parameters at runtime.
    pub fn new(
        port_name: &str,
        config: &MqttConfig,
        topics: Vec<TopicAddress>,
        publish_tx: mpsc::UnboundedSender<PublishRequest>,
    ) -> Self {
        let flags = PortFlags {
            can_block: true,
            ..PortFlags::default()
        };
        let mut base = PortDriverBase::new(port_name, 1, flags);
        let mut registry = HashMap::new();
        let mut topic_map: HashMap<String, Vec<(usize, TopicAddress)>> = HashMap::new();
        let mut reason_to_addr = Vec::new();

        // Create connection status param (0=disconnected, 1=connected)
        let connected_param = base
            .create_param(PARAM_CONNECTED, ParamType::Int32)
            .expect("failed to create connected param");
        base.set_int32_param(connected_param, 0, 0).unwrap();

        for addr in topics {
            let drv_info = addr.to_drv_info();
            let param_type = addr.param_type();
            let idx = base
                .create_param(&drv_info, param_type)
                .expect("failed to create param");

            // Grow reason_to_addr to accommodate this index
            if reason_to_addr.len() <= idx {
                reason_to_addr.resize_with(idx + 1, || None);
            }
            reason_to_addr[idx] = Some(addr.clone());

            topic_map
                .entry(addr.topic.clone())
                .or_default()
                .push((idx, addr.clone()));
            registry.insert(drv_info, (idx, addr));
        }

        Self {
            base,
            registry,
            topic_map,
            reason_to_addr,
            publish_tx,
            default_qos: config.qos,
            connected_param,
        }
    }

    /// Get the set of MQTT topics this driver subscribes to.
    pub fn subscribed_topics(&self) -> Vec<String> {
        self.topic_map.keys().cloned().collect()
    }

    /// Get a clone of the topic map for the event loop.
    pub fn topic_map(&self) -> &HashMap<String, Vec<(usize, TopicAddress)>> {
        &self.topic_map
    }

    /// Encode and publish a value for the given parameter reason.
    /// Uses FLAT or JSON encoding depending on the topic address format.
    fn publish_value(&self, reason: usize, value: &DecodedValue) -> AsynResult<()> {
        let addr = self
            .reason_to_addr
            .get(reason)
            .and_then(|a| a.as_ref())
            .ok_or_else(|| AsynError::ParamNotFound(format!("reason {reason}")))?;

        let payload = encode_payload(value, addr);

        self.publish_tx
            .send(PublishRequest {
                topic: addr.topic.clone(),
                payload,
                qos: self.default_qos,
                retained: false,
            })
            .map_err(|_| MqttError::PublishChannelClosed)?;

        Ok(())
    }
}

impl PortDriver for MqttDriver {
    fn base(&self) -> &PortDriverBase {
        &self.base
    }

    fn base_mut(&mut self) -> &mut PortDriverBase {
        &mut self.base
    }

    fn drv_user_create(&self, drv_info: &str) -> AsynResult<usize> {
        // Check topic registry first, then fall back to param name lookup
        // (for internal params like _MQTT_CONNECTED)
        if let Some((idx, _)) = self.registry.get(drv_info) {
            return Ok(*idx);
        }
        self.base()
            .params
            .find_param(drv_info)
            .ok_or_else(|| AsynError::ParamNotFound(drv_info.to_string()))
    }

    fn write_int32(&mut self, user: &mut AsynUser, value: i32) -> AsynResult<()> {
        self.publish_value(user.reason, &DecodedValue::Int32(value))?;
        self.base.params.set_int32(user.reason, user.addr, value)?;
        self.base.call_param_callbacks(user.addr)
    }

    fn write_float64(&mut self, user: &mut AsynUser, value: f64) -> AsynResult<()> {
        self.publish_value(user.reason, &DecodedValue::Float64(value))?;
        self.base
            .params
            .set_float64(user.reason, user.addr, value)?;
        self.base.call_param_callbacks(user.addr)
    }

    fn write_octet(&mut self, user: &mut AsynUser, data: &[u8]) -> AsynResult<()> {
        let s = String::from_utf8_lossy(data).into_owned();
        self.publish_value(user.reason, &DecodedValue::String(s.clone()))?;
        self.base.params.set_string(user.reason, user.addr, s)?;
        self.base.call_param_callbacks(user.addr)
    }

    fn write_uint32_digital(
        &mut self,
        user: &mut AsynUser,
        value: u32,
        mask: u32,
    ) -> AsynResult<()> {
        self.base
            .params
            .set_uint32(user.reason, user.addr, value, mask)?;
        let full_val = self
            .base
            .params
            .get_uint32(user.reason, user.addr)
            .unwrap_or(value & mask);
        self.publish_value(user.reason, &DecodedValue::UInt32(full_val))?;
        self.base.call_param_callbacks(user.addr)
    }

    fn write_int32_array(&mut self, user: &AsynUser, data: &[i32]) -> AsynResult<()> {
        self.publish_value(user.reason, &DecodedValue::Int32Array(data.to_vec()))?;
        self.base
            .params
            .set_int32_array(user.reason, user.addr, data.to_vec())?;
        self.base.call_param_callbacks(user.addr)
    }

    fn read_int32_array(&mut self, user: &AsynUser, buf: &mut [i32]) -> AsynResult<usize> {
        let data = self.base.params.get_int32_array(user.reason, user.addr)?;
        let n = data.len().min(buf.len());
        buf[..n].copy_from_slice(&data[..n]);
        Ok(n)
    }

    fn write_float64_array(&mut self, user: &AsynUser, data: &[f64]) -> AsynResult<()> {
        self.publish_value(user.reason, &DecodedValue::Float64Array(data.to_vec()))?;
        self.base
            .params
            .set_float64_array(user.reason, user.addr, data.to_vec())?;
        self.base.call_param_callbacks(user.addr)
    }

    fn read_float64_array(&mut self, user: &AsynUser, buf: &mut [f64]) -> AsynResult<usize> {
        let data = self.base.params.get_float64_array(user.reason, user.addr)?;
        let n = data.len().min(buf.len());
        buf[..n].copy_from_slice(&data[..n]);
        Ok(n)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_driver(topics: &[&str]) -> (MqttDriver, mpsc::UnboundedReceiver<PublishRequest>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let config = MqttConfig::default();
        let addrs: Vec<TopicAddress> = topics
            .iter()
            .map(|s| TopicAddress::parse(s).unwrap())
            .collect();
        let driver = MqttDriver::new("TEST", &config, addrs, tx);
        (driver, rx)
    }

    #[test]
    fn drv_user_create_finds_registered_topics() {
        let (driver, _rx) = make_driver(&[
            "FLAT:INT test/int_topic",
            "FLAT:FLOAT test/float_topic",
            "JSON:FLOAT sensors/data humidity",
        ]);

        assert!(driver.drv_user_create("FLAT:INT test/int_topic").is_ok());
        assert!(
            driver
                .drv_user_create("FLAT:FLOAT test/float_topic")
                .is_ok()
        );
        assert!(
            driver
                .drv_user_create("JSON:FLOAT sensors/data humidity")
                .is_ok()
        );
    }

    #[test]
    fn drv_user_create_rejects_unknown() {
        let (driver, _rx) = make_driver(&["FLAT:INT test/topic"]);
        assert!(driver.drv_user_create("FLAT:FLOAT other/topic").is_err());
    }

    #[test]
    fn subscribed_topics_returns_unique_mqtt_topics() {
        let (driver, _rx) = make_driver(&[
            "FLAT:INT test/topic",
            "FLAT:FLOAT test/topic",
            "FLAT:STRING other/topic",
        ]);

        let mut topics = driver.subscribed_topics();
        topics.sort();
        assert_eq!(topics, vec!["other/topic", "test/topic"]);
    }

    #[test]
    fn write_int32_sends_publish_request() {
        let (mut driver, mut rx) = make_driver(&["FLAT:INT test/int_topic"]);
        let reason = driver.drv_user_create("FLAT:INT test/int_topic").unwrap();
        let mut user = AsynUser::new(reason);

        driver.write_int32(&mut user, 42).unwrap();

        let req = rx.try_recv().unwrap();
        assert_eq!(req.topic, "test/int_topic");
        assert_eq!(req.payload, "42");
    }

    #[test]
    fn write_float64_sends_publish_request() {
        let (mut driver, mut rx) = make_driver(&["FLAT:FLOAT test/float_topic"]);
        let reason = driver
            .drv_user_create("FLAT:FLOAT test/float_topic")
            .unwrap();
        let mut user = AsynUser::new(reason);

        driver.write_float64(&mut user, 3.15).unwrap();

        let req = rx.try_recv().unwrap();
        assert_eq!(req.topic, "test/float_topic");
        assert_eq!(req.payload, "3.15");
    }

    #[test]
    fn write_octet_sends_publish_request() {
        let (mut driver, mut rx) = make_driver(&["FLAT:STRING test/str_topic"]);
        let reason = driver
            .drv_user_create("FLAT:STRING test/str_topic")
            .unwrap();
        let mut user = AsynUser::new(reason);

        driver.write_octet(&mut user, b"hello").unwrap();

        let req = rx.try_recv().unwrap();
        assert_eq!(req.topic, "test/str_topic");
        assert_eq!(req.payload, "hello");
    }

    #[test]
    fn topic_map_groups_by_mqtt_topic() {
        let (driver, _rx) = make_driver(&[
            "FLAT:INT test/shared",
            "FLAT:FLOAT test/shared",
            "FLAT:STRING test/other",
        ]);

        assert_eq!(driver.topic_map()["test/shared"].len(), 2);
        assert_eq!(driver.topic_map()["test/other"].len(), 1);
    }
}