use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Arc};
use azure_iot_operations_mqtt::{
aio::cloud_event as aio_cloud_event,
control_packet::{Publish, QoS, TopicFilter},
session::{SessionManagedClient, SessionPubReceiver},
token::AckToken,
};
use tokio_util::sync::CancellationToken;
use crate::{
ProtocolVersion,
application::{ApplicationContext, ApplicationHybridLogicalClock},
common::{
aio_protocol_error::AIOProtocolError,
hybrid_logical_clock::HybridLogicalClock,
payload_serialize::{FormatIndicator, PayloadSerialize},
topic_processor::TopicPattern,
user_properties::UserProperty,
},
telemetry::DEFAULT_TELEMETRY_PROTOCOL_VERSION,
};
/// Major protocol versions accepted for received telemetry. `Message::try_from` rejects a
/// publish whose protocol version is not supported according to this list.
const SUPPORTED_PROTOCOL_VERSIONS: &[u16] = &[1];
/// Cloud Event type exposed for telemetry; an alias of `CloudEvent` from the MQTT crate's
/// `aio::cloud_event` module.
pub type CloudEvent = aio_cloud_event::CloudEvent;
/// Error type returned when a Cloud Event cannot be built from a telemetry message; an alias
/// of `CloudEventParseError` from the MQTT crate's `aio::cloud_event` module.
pub type CloudEventParseError = aio_cloud_event::CloudEventParseError;
/// Builds a [`CloudEvent`] from a received telemetry [`Message`].
///
/// The conversion is given the message's custom user data together with its content type
/// (borrowed as `Option<&str>`).
///
/// # Errors
/// Returns the [`CloudEventParseError`] produced by the underlying `CloudEvent::try_from`
/// conversion.
pub fn cloud_event_from_telemetry<T>(
    telemetry: &Message<T>,
) -> Result<CloudEvent, CloudEventParseError>
where
    T: PayloadSerialize,
{
    let content_type = telemetry.content_type.as_deref();
    CloudEvent::try_from((&telemetry.custom_user_data, content_type))
}
/// A telemetry message built from a received MQTT publish (see the `TryFrom<Publish>`
/// implementation for how each field is populated).
#[derive(Debug)]
pub struct Message<T: PayloadSerialize> {
/// Payload of the publish, deserialized into `T`.
pub payload: T,
/// Content type taken from the publish properties, if any.
pub content_type: Option<String>,
/// Format indicator converted from the publish's payload format indicator.
pub format_indicator: FormatIndicator,
/// User properties of the publish that are not one of the consumed protocol properties
/// (timestamp, protocol version, source id), as key/value pairs.
pub custom_user_data: Vec<(String, String)>,
/// Value of the source id user property, if the publish carried one.
pub sender_id: Option<String>,
/// Hybrid logical clock parsed from the timestamp user property, if the publish carried one.
pub timestamp: Option<HybridLogicalClock>,
/// Topic tokens for this message. Left empty by `TryFrom<Publish>`; `Receiver::recv` extends
/// it with the tokens parsed from `topic` using the receiver's topic pattern.
pub topic_tokens: HashMap<String, String>,
/// Topic name the publish was received on.
pub topic: String,
/// Duplicate flag of the publish when delivered at QoS 1; `None` when delivered at QoS 0.
pub duplicate: Option<bool>,
}
impl<T> TryFrom<Publish> for Message<T>
where
    T: PayloadSerialize,
{
    type Error = String;

    /// Converts a received MQTT [`Publish`] into a telemetry [`Message`].
    ///
    /// The timestamp, protocol version and source id user properties are consumed by the
    /// conversion; every other user property is forwarded in `custom_user_data`.
    /// `topic_tokens` is left empty.
    ///
    /// # Errors
    /// Returns a human-readable description when:
    /// - the protocol version user property cannot be parsed, or names an unsupported version,
    /// - the timestamp user property cannot be parsed,
    /// - the payload cannot be deserialized into `T`,
    /// - the publish was delivered at QoS 2 (this previously panicked).
    fn try_from(value: Publish) -> Result<Message<T>, Self::Error> {
        let publish_properties = value.properties;

        // User properties that belong to the protocol and are consumed here.
        let expected_aio_properties = [
            UserProperty::Timestamp,
            UserProperty::ProtocolVersion,
            UserProperty::SourceId,
        ];
        let mut telemetry_custom_user_data = vec![];
        let mut telemetry_aio_data = HashMap::new();

        for (key, value) in publish_properties.user_properties {
            match UserProperty::from_str(&key) {
                Ok(p) if expected_aio_properties.contains(&p) => {
                    telemetry_aio_data.insert(p, value);
                }
                Ok(_) => {
                    // A known protocol property that telemetry does not use: warn, but still
                    // hand it to the application as custom data.
                    log::warn!(
                        "Telemetry should not contain MQTT user property '{key}'. Value is '{value}'"
                    );
                    telemetry_custom_user_data.push((key, value));
                }
                Err(()) => {
                    telemetry_custom_user_data.push((key, value));
                }
            }
        }

        // A missing protocol version falls back to the telemetry default.
        let protocol_version = match telemetry_aio_data.get(&UserProperty::ProtocolVersion) {
            Some(protocol_version) => {
                let Some(version) = ProtocolVersion::parse_protocol_version(protocol_version)
                else {
                    return Err(format!(
                        "Received a telemetry with an unparsable protocol version number: {protocol_version}"
                    ));
                };
                version
            }
            None => DEFAULT_TELEMETRY_PROTOCOL_VERSION,
        };
        if !protocol_version.is_supported(SUPPORTED_PROTOCOL_VERSIONS) {
            return Err(format!(
                "Unsupported protocol version '{protocol_version}'. Only major protocol versions '{SUPPORTED_PROTOCOL_VERSIONS:?}' are supported"
            ));
        }

        let timestamp = telemetry_aio_data
            .get(&UserProperty::Timestamp)
            .map(|s| HybridLogicalClock::from_str(s))
            .transpose()
            .map_err(|e| e.to_string())?;

        let format_indicator = publish_properties.payload_format_indicator.into();
        let content_type = publish_properties.content_type;
        let payload = T::deserialize(&value.payload, content_type.as_ref(), &format_indicator)
            .map_err(|e| format!("{e:?}"))?;

        let duplicate = match value.qos {
            azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtMostOnce => None,
            azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtLeastOnce(delivery_info) => {
                Some(delivery_info.dup)
            }
            azure_iot_operations_mqtt::control_packet::DeliveryQoS::ExactlyOnce(_) => {
                // This conversion is public, so a QoS 2 publish can reach it; report the
                // problem through the existing error channel rather than panicking.
                return Err(
                    "Received a telemetry delivered with QoS 2; only QoS 0 and QoS 1 are supported"
                        .to_string(),
                );
            }
        };

        Ok(Message {
            payload,
            content_type,
            format_indicator,
            custom_user_data: telemetry_custom_user_data,
            sender_id: telemetry_aio_data.remove(&UserProperty::SourceId),
            timestamp,
            topic_tokens: HashMap::default(),
            topic: value.topic_name.as_str().to_string(),
            duplicate,
        })
    }
}
/// Options used to create a telemetry [`Receiver`]; built through the derived `OptionsBuilder`.
#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct Options {
/// Topic pattern handed to `TopicPattern::new` when the receiver is created. Required.
topic_pattern: String,
/// Optional topic namespace handed to `TopicPattern::new`. Defaults to `None`.
#[builder(default = "None")]
topic_namespace: Option<String>,
/// Topic token map handed to `TopicPattern::new`. Defaults to an empty map.
#[builder(default)]
topic_token_map: HashMap<String, String>,
/// When `true` (the default), `Receiver::recv` discards the ack token before returning, so
/// the caller receives `None` in its place.
/// NOTE(review): presumably dropping the token acknowledges the message — confirm against
/// the `AckToken` documentation.
#[builder(default = "true")]
auto_ack: bool,
// Not read anywhere in this file, hence the `unused` allowance.
/// Optional service group id. Defaults to `None`.
#[allow(unused)]
#[builder(default = "None")]
service_group_id: Option<String>,
}
/// Telemetry receiver that yields [`Message<T>`] values for publishes matching its topic.
pub struct Receiver<T>
where
T: PayloadSerialize + Send + Sync + 'static,
{
/// Application hybrid logical clock; updated with the timestamp of each received message.
application_hlc: Arc<ApplicationHybridLogicalClock>,
/// Client used to subscribe to and unsubscribe from `telemetry_topic`.
mqtt_client: SessionManagedClient,
/// Filtered publish receiver created from `mqtt_client` for `telemetry_topic`.
#[allow(clippy::struct_field_names)]
mqtt_receiver: SessionPubReceiver,
/// Subscribe topic filter derived from `topic_pattern`.
telemetry_topic: TopicFilter,
/// Topic pattern used to parse topic tokens out of a received message's topic.
topic_pattern: TopicPattern,
/// Marker tying the receiver to the payload type `T` without storing a value of it.
message_payload_type: PhantomData<T>,
/// Subscription/shutdown state of the receiver.
state: State,
/// Cancelled when the receiver is dropped; background ack tasks spawned by `recv` stop
/// waiting once it is cancelled.
cancellation_token: CancellationToken,
/// Copied from `Options::auto_ack`; when `true`, `recv` discards ack tokens.
auto_ack: bool,
}
/// Lifecycle state of a [`Receiver`].
#[derive(PartialEq)]
enum State {
/// Initial state: no subscribe has succeeded yet. `recv` subscribes while in this state.
New,
/// The subscribe performed by `recv` succeeded.
Subscribed,
/// `shutdown` completed (either nothing to unsubscribe, or the unsubscribe succeeded).
ShutdownSuccessful,
}
impl<T> Receiver<T>
where
T: PayloadSerialize + Send + Sync + 'static,
{
#[allow(clippy::needless_pass_by_value)]
pub fn new(
application_context: ApplicationContext,
client: SessionManagedClient,
receiver_options: Options,
) -> Result<Self, AIOProtocolError> {
let topic_pattern = TopicPattern::new(
&receiver_options.topic_pattern,
None,
receiver_options.topic_namespace.as_deref(),
&receiver_options.topic_token_map,
)
.map_err(|e| {
AIOProtocolError::config_invalid_from_topic_pattern_error(
e,
"receiver_options.topic_pattern",
)
})?;
let telemetry_topic = topic_pattern.as_subscribe_topic().map_err(|e| {
AIOProtocolError::config_invalid_from_topic_pattern_error(
e,
"receiver_options.topic_pattern",
)
})?;
let mqtt_receiver = client.create_filtered_pub_receiver(telemetry_topic.clone());
Ok(Self {
application_hlc: application_context.application_hlc,
mqtt_client: client,
mqtt_receiver,
telemetry_topic,
topic_pattern,
message_payload_type: PhantomData,
state: State::New,
cancellation_token: CancellationToken::new(),
auto_ack: receiver_options.auto_ack,
})
}
pub async fn shutdown(&mut self) -> Result<(), AIOProtocolError> {
self.mqtt_receiver.close();
match self.state {
State::New | State::ShutdownSuccessful => {
self.state = State::ShutdownSuccessful;
}
State::Subscribed => {
let unsubscribe_result = self
.mqtt_client
.unsubscribe(
self.telemetry_topic.clone(),
azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
)
.await;
match unsubscribe_result {
Ok(unsub_ct) => match unsub_ct.await {
Ok(unsuback) => match unsuback.as_result() {
Ok(()) => {
self.state = State::ShutdownSuccessful;
}
Err(e) => {
log::error!("Telemetry Receiver Unsuback error: {unsuback:?}");
return Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on telemetry receiver unsuback".to_string()),
Box::new(e),
None,
));
}
},
Err(e) => {
log::error!("Telemetry Receiver Unsubscribe completion error: {e}");
return Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on telemetry receiver unsubscribe".to_string()),
Box::new(e),
None,
));
}
},
Err(e) => {
log::error!("Client error while unsubscribing in Telemetry Receiver: {e}");
return Err(AIOProtocolError::new_mqtt_error(
Some("Client error on telemetry receiver unsubscribe".to_string()),
Box::new(e),
None,
));
}
}
}
}
log::info!("Telemetry receiver shutdown");
Ok(())
}
async fn try_subscribe(&mut self) -> Result<(), AIOProtocolError> {
let subscribe_result = self
.mqtt_client
.subscribe(
self.telemetry_topic.clone(),
QoS::AtLeastOnce,
false,
azure_iot_operations_mqtt::control_packet::RetainOptions::default(),
azure_iot_operations_mqtt::control_packet::SubscribeProperties::default(),
)
.await;
match subscribe_result {
Ok(sub_ct) => match sub_ct.await {
Ok(suback) => {
suback.as_result().map_err(|e| {
log::error!("Telemetry Receiver Suback error: {suback:?}");
AIOProtocolError::new_mqtt_error(
Some("MQTT error on telemetry receiver suback".to_string()),
Box::new(e),
None,
)
})?;
}
Err(e) => {
log::error!("Telemetry Receiver Subscribe completion error: {e}");
return Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on telemetry receiver subscribe".to_string()),
Box::new(e),
None,
));
}
},
Err(e) => {
log::error!("Client error while subscribing in Telemetry Receiver: {e}");
return Err(AIOProtocolError::new_mqtt_error(
Some("Client error on telemetry receiver subscribe".to_string()),
Box::new(e),
None,
));
}
}
Ok(())
}
pub async fn recv(
&mut self,
) -> Option<Result<(Message<T>, Option<AckToken>), AIOProtocolError>> {
if self.state == State::New {
if let Err(e) = self.try_subscribe().await {
return Some(Err(e));
}
self.state = State::Subscribed;
}
loop {
match self.mqtt_receiver.recv_manual_ack().await {
Some((m, mut ack_token)) => {
if self.auto_ack {
ack_token.take();
}
let pkid = match m.qos {
azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtMostOnce => {
0
}
azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtLeastOnce(
delivery_info,
) => delivery_info.packet_identifier.get(),
azure_iot_operations_mqtt::control_packet::DeliveryQoS::ExactlyOnce(_) => {
log::warn!("Received QoS 2 telemetry message");
continue;
}
};
log::debug!("[pkid: {pkid}] Received message");
match TryInto::<Message<T>>::try_into(m) {
Ok(mut message) => {
message
.topic_tokens
.extend(self.topic_pattern.parse_tokens(&message.topic));
if let Some(hlc) = &message.timestamp
&& let Err(e) = self.application_hlc.update(hlc)
{
log::warn!(
"[pkid: {pkid}]: Failure updating application HLC against received telemetry HLC {hlc}: {e}"
);
}
return Some(Ok((message, ack_token)));
}
Err(e_string) => {
log::warn!("[pkid: {pkid}] {e_string}");
if let Some(ack_token) = ack_token {
tokio::spawn({
let receiver_cancellation_token_clone =
self.cancellation_token.clone();
async move {
tokio::select! {
() = receiver_cancellation_token_clone.cancelled() => { },
ack_res = ack_token.ack() => {
match ack_res {
Ok(_) => { }
Err(e) => {
log::warn!("[pkid: {pkid}] Telemetry Receiver Ack error {e}");
}
}
}
}
}
});
}
}
}
}
_ => {
return None;
}
}
}
}
}
impl<T> Drop for Receiver<T>
where
    T: PayloadSerialize + Send + Sync + 'static,
{
    /// Best-effort cleanup when the receiver goes out of scope.
    ///
    /// Cancels the receiver's cancellation token, closes the publish receiver and, if the
    /// receiver is still subscribed, sends an unsubscribe from a background task without
    /// waiting for the unsuback.
    ///
    /// The background task needs a Tokio runtime. If the receiver is dropped where none is
    /// available, the unsubscribe is skipped and a warning is logged instead of panicking.
    fn drop(&mut self) {
        self.cancellation_token.cancel();
        self.mqtt_receiver.close();

        if self.state == State::Subscribed {
            // `tokio::spawn` panics outside of a runtime, and a panic inside `drop` can abort
            // the process, so look the runtime up fallibly.
            match tokio::runtime::Handle::try_current() {
                Ok(runtime) => {
                    let telemetry_topic = self.telemetry_topic.clone();
                    let mqtt_client = self.mqtt_client.clone();
                    runtime.spawn(async move {
                        match mqtt_client
                            .unsubscribe(
                                telemetry_topic.clone(),
                                azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
                            )
                            .await
                        {
                            Ok(_) => {
                                log::debug!(
                                    "Telemetry Receiver Unsubscribe sent on topic {telemetry_topic}. Unsuback may still be pending."
                                );
                            }
                            Err(e) => {
                                log::warn!(
                                    "Telemetry Receiver Unsubscribe error on topic {telemetry_topic}: {e}"
                                );
                            }
                        }
                    });
                }
                Err(e) => {
                    log::warn!(
                        "Telemetry Receiver dropped outside of a Tokio runtime, skipping unsubscribe on topic {}: {e}",
                        self.telemetry_topic
                    );
                }
            }
        }

        log::info!("Telemetry receiver dropped");
    }
}
#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::*;
    use crate::{
        application::ApplicationContextBuilder,
        common::{
            aio_protocol_error::{AIOProtocolErrorKind, Value},
            payload_serialize::MockPayload,
        },
        telemetry::receiver::{OptionsBuilder, Receiver},
    };
    use azure_iot_operations_mqtt::{
        aio::connection_settings::MqttConnectionSettingsBuilder,
        session::{Session, SessionOptionsBuilder},
    };

    /// Builds a session configured for `localhost` with client id `test_server`.
    fn get_session() -> Session {
        let session_options = SessionOptionsBuilder::default()
            .connection_settings(
                MqttConnectionSettingsBuilder::default()
                    .hostname("localhost")
                    .client_id("test_server")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        Session::new(session_options).unwrap()
    }

    /// Token map providing a replacement for the `telemetryName` topic token.
    fn create_topic_tokens() -> HashMap<String, String> {
        let mut tokens = HashMap::new();
        tokens.insert("telemetryName".to_string(), "test_telemetry".to_string());
        tokens
    }

    /// A receiver can be created from options that only set the topic pattern.
    #[test]
    fn test_new_defaults() {
        let session = get_session();
        let options = OptionsBuilder::default()
            .topic_pattern("test/receiver")
            .build()
            .unwrap();

        let context = ApplicationContextBuilder::default().build().unwrap();
        Receiver::<MockPayload>::new(context, session.create_managed_client(), options).unwrap();
    }

    /// A receiver can be created with a namespace and a topic token map.
    #[test]
    fn test_new_override_defaults() {
        let session = get_session();
        let options = OptionsBuilder::default()
            .topic_pattern("test/{telemetryName}/receiver")
            .topic_namespace("test_namespace")
            .topic_token_map(create_topic_tokens())
            .build()
            .unwrap();

        let context = ApplicationContextBuilder::default().build().unwrap();
        Receiver::<MockPayload>::new(context, session.create_managed_client(), options).unwrap();
    }

    /// Empty and whitespace-only topic patterns are rejected as invalid configuration.
    #[test_case(""; "new_empty_topic_pattern")]
    #[test_case(" "; "new_whitespace_topic_pattern")]
    fn test_new_empty_topic_pattern(topic_pattern: &str) {
        let session = get_session();
        let options = OptionsBuilder::default()
            .topic_pattern(topic_pattern)
            .build()
            .unwrap();

        let context = ApplicationContextBuilder::default().build().unwrap();
        let Err(e) =
            Receiver::<MockPayload>::new(context, session.create_managed_client(), options)
        else {
            panic!("Expected error");
        };

        assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
        assert!(e.is_shallow);
        assert!(!e.is_remote);
        assert_eq!(
            e.property_name,
            Some("receiver_options.topic_pattern".to_string())
        );
        assert_eq!(
            e.property_value,
            Some(Value::String(topic_pattern.to_string()))
        );
    }

    /// Shutting down a receiver that never subscribed succeeds.
    #[tokio::test]
    async fn test_shutdown_without_subscribe() {
        let session = get_session();
        let options = OptionsBuilder::default()
            .topic_pattern("test/receiver")
            .build()
            .unwrap();

        let context = ApplicationContextBuilder::default().build().unwrap();
        let mut receiver =
            Receiver::<MockPayload>::new(context, session.create_managed_client(), options)
                .unwrap();

        assert!(receiver.shutdown().await.is_ok());
    }
}