use rumqttc::{AsyncClient, EventLoop, Packet};
pub use rumqttc::{LastWill, MqttOptions, QoS};
use xdevs::{
simulation::rt::{input::InputSender, output::OutputReceiver, Handler},
Event,
};
/// Configuration of an MQTT bridge between a real-time xdevs simulation and an MQTT broker.
///
/// Every topic used by the bridge is rooted at a common prefix (`root_topic`); the remaining
/// settings are public so they can be tuned after [`MqttHandler::new`].
#[derive(Clone)]
pub struct MqttHandler {
    /// Connection options handed to the `rumqttc` client.
    pub mqtt_options: MqttOptions,
    /// Capacity passed to `rumqttc::AsyncClient::new` (its request channel capacity).
    pub client_cap: usize,
    /// QoS used when subscribing to the input topics.
    pub input_qos: QoS,
    /// QoS used when publishing output events.
    pub output_qos: QoS,
    /// Whether output events are published with the MQTT retain flag set.
    pub output_retain: bool,
    /// Prefix shared by every topic this handler subscribes or publishes to.
    root_topic: String,
}
impl MqttHandler {
    /// Creates a handler for the broker at `host:port`, identified by the MQTT client `id`,
    /// whose topics are all rooted at `root_topic`.
    ///
    /// Defaults (all adjustable through the public fields):
    /// * client capacity: `10`;
    /// * input QoS: [`QoS::AtMostOnce`];
    /// * output QoS: [`QoS::AtLeastOnce`];
    /// * output messages are retained.
    pub fn new<R: Into<String>, S: Into<String>, T: Into<String>>(
        root_topic: R,
        id: S,
        host: T,
        port: u16,
    ) -> Self {
        // Convert the root topic first, preserving the original evaluation order.
        let root_topic = root_topic.into();
        let mqtt_options = MqttOptions::new(id, host, port);
        Self {
            mqtt_options,
            root_topic,
            output_retain: true,
            output_qos: QoS::AtLeastOnce,
            input_qos: QoS::AtMostOnce,
            client_cap: 10,
        }
    }
}
impl Handler for MqttHandler {
    /// Spawns the tasks that bridge the simulation with the MQTT broker.
    ///
    /// * `input_tx`: when present, the client subscribes to `{root_topic}/input/+` and every
    ///   incoming message is forwarded to the simulation as an input event.
    /// * `output_rx`: when present, every output event is published to
    ///   `{root_topic}/output/{port}`.
    ///
    /// Two tasks are spawned (the client task and the event-loop task) and their join handles
    /// are returned. If neither channel is provided, a warning is logged and nothing is spawned.
    ///
    /// # Safety
    ///
    /// The safety contract is the one declared by xdevs' `Handler::spawn`; this implementation
    /// performs no unsafe operations itself.
    unsafe fn spawn(
        self,
        input_tx: Option<InputSender>,
        output_rx: Option<OutputReceiver>,
    ) -> Vec<tokio::task::JoinHandle<()>> {
        if input_tx.is_none() && output_rx.is_none() {
            tracing::warn!("no input or output handler provided. Exiting.");
            return Vec::new();
        }
        let (client, eventloop) = AsyncClient::new(self.mqtt_options, self.client_cap);
        let client_config = ClientConfig {
            // Only subscribe to the input topics when someone consumes the input events.
            input: input_tx.as_ref().map(|_| self.input_qos),
            output: output_rx.map(|rx| (self.output_qos, self.output_retain, rx)),
        };
        vec![
            // `root_topic` is not needed afterwards, so it is moved instead of cloned.
            tokio::task::spawn(client_thread(client, self.root_topic, client_config)),
            tokio::task::spawn(eventloop_thread(eventloop, input_tx)),
        ]
    }
}
/// Settings and channels that the client task needs from an `MqttHandler`.
struct ClientConfig {
    /// QoS, retain flag, and receiver of simulation output events used for publishing;
    /// `None` when the simulation has no output handler.
    output: Option<(QoS, bool, OutputReceiver)>,
    /// QoS of the subscription to the input topics; `None` when nothing consumes inputs.
    input: Option<QoS>,
}
async fn client_thread(client: AsyncClient, root_topic: String, config: ClientConfig) {
let input_topic = format!("{root_topic}/input");
let output_topic = format!("{root_topic}/output");
if let Some(input_qos) = config.input {
tracing::info!("subscribing to MQTT topic {input_topic}/+");
if let Err(e) = client
.subscribe(format!("{input_topic}/+"), input_qos)
.await
{
tracing::error!("failed to subscribe to MQTT topic {input_topic}/+: {e:?}");
return;
}
}
if let Some((output_qos, output_retain, mut output_rx)) = config.output {
loop {
match output_rx.recv().await {
Ok(event) => {
let port = event.port();
let value = event.value();
tracing::info!("publishing value {value} to MQTT topic {output_topic}/{port}");
if let Err(e) = client
.publish(
format!("{output_topic}/{port}"),
output_qos,
output_retain,
value,
)
.await
{
tracing::warn!("failed to publish to MQTT topic: {e:?}.");
}
}
Err(e) => {
tracing::error!("output handler dropped: {e:?}. Disconnecting MQTT client.");
client.disconnect().await.ok();
break;
}
}
}
}
}
async fn eventloop_thread(mut eventloop: EventLoop, input_tx: Option<InputSender>) {
loop {
match eventloop.poll().await {
Ok(notif) => {
tracing::debug!("MQTT event notification: {notif:?}");
if let Some(input_tx) = &input_tx {
if let rumqttc::Event::Incoming(Packet::Publish(packet)) = notif {
let port = packet.topic.split('/').next_back().unwrap().to_string();
let value = match String::from_utf8(packet.payload.to_vec()) {
Ok(string) => string,
Err(e) => {
tracing::warn!("Failed to convert payload to UTF8 String: {e}.");
continue;
}
};
match input_tx.send(Event::new(port, value)).await {
Ok(_) => {}
Err(e) => {
tracing::error!(
"input handler dropped: {e:?}. Disconnecting MQTT client."
);
break;
}
}
}
}
}
Err(e) => {
tracing::error!("MQTT eventloop error: {:?}", e);
break;
}
}
}
}
#[cfg(feature = "dmt")]
pub use dmt::DmtPropagator;
#[cfg(feature = "dmt")]
mod dmt {
    //! Forwards MQTT messages along the couplings of a DEVS model tree.
    //!
    //! Topic layout for a (sub)model rooted at `{root}`:
    //! * `{root}/input/{port}` and `{root}/output/{port}`: ports of the model itself;
    //! * `{root}/components/{id}/input/{port}` and `{root}/components/{id}/output/{port}`:
    //!   ports of its subcomponent `id`, whose own subtree is rooted at `{root}/components/{id}`.

    use super::MqttHandler;
    use crate::dmt::{Component, DevsModelTree};
    use rumqttc::{AsyncClient, EventLoop, Packet, Publish, QoS};
    use std::{
        collections::{HashMap, HashSet},
        ops::{Deref, DerefMut},
    };
    use tokio::sync::mpsc::{channel, Receiver, Sender};

    /// Builds the forwarding map of `component` and, recursively, of all its subcomponents.
    ///
    /// Each coupling of `component` maps the topic of its source port to the topic of its
    /// destination port, following the topic layout described in the module docs.
    fn topic_map<S: Into<String>>(root: S, component: &Component) -> MqttTopicMap {
        let root: String = root.into();
        let mut map = MqttTopicMap::new();
        for (component_id, subcomponent) in component.components().iter() {
            map.extend(topic_map(format!("{root}/components/{component_id}"), subcomponent));
        }
        for (node_from, nodes_to) in component.coupling_map().iter() {
            let port_from = node_from.port();
            let topic_from = match node_from.component() {
                Some(component) => format!("{root}/components/{component}/output/{port_from}"),
                None => format!("{root}/input/{port_from}"),
            };
            for node_to in nodes_to.iter() {
                let port_to = node_to.port();
                let topic_to = match node_to.component() {
                    Some(component) => format!("{root}/components/{component}/input/{port_to}"),
                    None => format!("{root}/output/{port_to}"),
                };
                map.insert(topic_from.clone(), topic_to);
            }
        }
        map
    }

    /// Collapses forwarding chains so that messages reach every destination in one hop.
    ///
    /// A topic that is both a destination of some mapping and a source of another is an
    /// intermediate hop (e.g. the input port of a coupled subcomponent). Afterwards, every
    /// remaining source maps to all the topics transitively reachable from it. Intermediate
    /// topics are included, so they still receive the message. Intermediate topics are no
    /// longer sources, so nobody subscribes to them.
    ///
    /// The traversal keeps a visited set, so the result does not depend on `HashMap`
    /// iteration order, chains of any length are handled, and cyclic couplings terminate.
    fn avoid_multihop(map: &mut MqttTopicMap) {
        let intermediate: HashSet<String> = map
            .0
            .values()
            .flatten()
            .filter(|topic| map.0.contains_key(*topic))
            .cloned()
            .collect();
        let mut collapsed = HashMap::with_capacity(map.0.len());
        for (source, destinations) in map.0.iter() {
            if intermediate.contains(source) {
                continue;
            }
            // Depth-first traversal; `reachable` doubles as the visited set.
            let mut reachable: HashSet<String> = HashSet::new();
            let mut pending: Vec<&String> = destinations.iter().collect();
            while let Some(topic) = pending.pop() {
                if reachable.insert(topic.clone()) {
                    if let Some(next) = map.0.get(topic) {
                        pending.extend(next.iter());
                    }
                }
            }
            collapsed.insert(source.clone(), reachable);
        }
        map.0 = collapsed;
    }

    /// Forwards MQTT messages along the couplings of a DEVS model tree.
    ///
    /// It dereferences to [`MqttHandler`], so the MQTT options, QoS levels, and retain flag
    /// can be adjusted before calling [`DmtPropagator::spawn`].
    pub struct DmtPropagator {
        /// MQTT settings; `root_topic` is the root of the model tree's topics.
        config: MqttHandler,
        /// Optional shared-subscription group (see [`DmtPropagator::set_group_id`]).
        group_id: Option<String>,
        /// Model tree whose couplings define the forwarding map.
        model: DevsModelTree,
    }

    impl DmtPropagator {
        /// Creates a propagator for `model` that connects to the broker at `host:port` with
        /// the MQTT client id `client_id`, using `root_topic` as the root of all topics.
        ///
        /// The remaining settings take the defaults of [`MqttHandler::new`].
        pub fn new<R: Into<String>, S: Into<String>, T: Into<String>>(
            root_topic: R,
            client_id: S,
            host: T,
            port: u16,
            model: DevsModelTree,
        ) -> Self {
            let config = MqttHandler::new(root_topic, client_id, host, port);
            Self {
                config,
                group_id: None,
                model,
            }
        }

        /// Sets the MQTT shared-subscription group used when subscribing to source topics.
        ///
        /// # Panics
        ///
        /// Panics if `group_id` contains `/`, `+`, or `#`. MQTT forbids these characters in
        /// shared-subscription names.
        pub fn set_group_id<S: Into<String>>(&mut self, group_id: S) {
            let group_id: String = group_id.into();
            assert!(
                !group_id.contains(|c: char| matches!(c, '/' | '+' | '#')),
                "group_id cannot contain '/', '+', or '#' characters."
            );
            self.group_id = Some(group_id);
        }

        /// Spawns the forwarding tasks and returns their join handles.
        ///
        /// The forwarding map is computed from the model tree, and multi-hop chains are
        /// collapsed. One task subscribes to every source topic and republishes incoming
        /// messages to their destinations. The other polls the MQTT event loop and hands
        /// incoming publications to the first one.
        pub fn spawn(self) -> Vec<tokio::task::JoinHandle<()>> {
            let mut map = topic_map(&self.config.root_topic, &self.model);
            avoid_multihop(&mut map);
            let config = self.config;
            let (sender, receiver) = channel(10);
            let (client, eventloop) = AsyncClient::new(config.mqtt_options, config.client_cap);
            vec![
                tokio::task::spawn(client_thread(
                    client,
                    config.input_qos,
                    config.output_qos,
                    config.output_retain,
                    receiver,
                    map,
                    self.group_id,
                )),
                tokio::task::spawn(eventloop_thread(eventloop, sender)),
            ]
        }
    }

    impl Deref for DmtPropagator {
        type Target = MqttHandler;

        fn deref(&self) -> &Self::Target {
            &self.config
        }
    }

    impl DerefMut for DmtPropagator {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.config
        }
    }

    /// Subscribes to every source topic of `topic_map` and republishes each received message
    /// to all the destination topics mapped to its topic.
    ///
    /// With `group_id` set, the `n`-th source topic (in sorted order) is subscribed through
    /// the shared subscription `$share/{group_id}{n}/{topic}`. Sorting makes `n`, and thus
    /// the group names, independent of `HashMap` iteration order. Subscription failures end
    /// the task. Publishing failures are logged and skipped. When `input_rx` closes, a
    /// disconnect request is sent and the task ends.
    async fn client_thread(
        client: AsyncClient,
        input_qos: QoS,
        output_qos: QoS,
        output_retain: bool,
        mut input_rx: Receiver<Publish>,
        topic_map: MqttTopicMap,
        group_id: Option<String>,
    ) {
        let mut sources: Vec<&String> = topic_map.iter().map(|(topic, _)| topic).collect();
        sources.sort_unstable();
        for (n, source) in sources.into_iter().enumerate() {
            let topic = match &group_id {
                Some(group_id) => format!("$share/{group_id}{n}/{source}"),
                None => source.clone(),
            };
            tracing::info!("subscribing to MQTT topic {topic}.");
            if let Err(e) = client.subscribe(topic, input_qos).await {
                tracing::error!("failed to subscribe to MQTT topic: {e:?}.");
                return;
            }
        }
        while let Some(publish) = input_rx.recv().await {
            let topic = publish.topic;
            let payload = publish.payload;
            tracing::info!("received value from MQTT topic {topic}.");
            match topic_map.get(&topic) {
                Some(destinations) => {
                    for destination in destinations.iter() {
                        tracing::info!("publishing value to MQTT topic {destination}");
                        if let Err(e) = client
                            .publish(destination, output_qos, output_retain, payload.clone())
                            .await
                        {
                            tracing::warn!("failed to publish to MQTT topic: {e:?}.");
                        }
                    }
                }
                None => tracing::error!("no mapping found for MQTT topic {topic}."),
            }
        }
        tracing::error!("input handler dropped. Disconnecting MQTT client.");
        client.disconnect().await.ok();
    }

    /// Polls the MQTT event loop and hands every incoming publication to `sender`.
    ///
    /// The task ends when polling fails or when the receiving side of `sender` is dropped.
    async fn eventloop_thread(mut eventloop: EventLoop, sender: Sender<Publish>) {
        loop {
            let notif = match eventloop.poll().await {
                Ok(notif) => notif,
                Err(e) => {
                    tracing::error!("MQTT eventloop error: {:?}", e);
                    break;
                }
            };
            tracing::debug!("MQTT event notification: {notif:?}");
            if let rumqttc::Event::Incoming(Packet::Publish(packet)) = notif {
                if let Err(e) = sender.send(packet).await {
                    tracing::error!("output handler dropped: {e:?}. Disconnecting MQTT client.");
                    break;
                }
            }
        }
    }

    /// Forwarding map from a source topic to the set of destination topics it feeds.
    struct MqttTopicMap(HashMap<String, HashSet<String>>);

    impl MqttTopicMap {
        /// Creates an empty map.
        fn new() -> Self {
            Self(HashMap::new())
        }

        /// Returns the destinations of `topic`, or `None` if it is not a source topic.
        fn get(&self, topic: &str) -> Option<&HashSet<String>> {
            self.0.get(topic)
        }

        /// Adds `destination` to the destinations of `topic`.
        fn insert(&mut self, topic: String, destination: String) {
            self.0.entry(topic).or_default().insert(destination);
        }

        /// Merges `other` into `self`; topics present in both keep the union of destinations.
        fn extend(&mut self, other: MqttTopicMap) {
            for (topic, destinations) in other.0 {
                self.0.entry(topic).or_default().extend(destinations);
            }
        }

        /// Iterates over `(source topic, destinations)` pairs in arbitrary order.
        fn iter(&self) -> impl Iterator<Item = (&String, &HashSet<String>)> {
            self.0.iter()
        }
    }
}