use std::str::FromStr;
use std::sync::Arc;
use std::{collections::HashMap, marker::PhantomData, time::Duration};
use azure_iot_operations_mqtt::aio::cloud_event as aio_cloud_event;
use azure_iot_operations_mqtt::control_packet::{PublishProperties, QoS};
use azure_iot_operations_mqtt::session::SessionManagedClient;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::{
application::{ApplicationContext, ApplicationHybridLogicalClock},
common::{
aio_protocol_error::{AIOProtocolError, Value},
cloud_event as protocol_cloud_event, is_invalid_utf8,
payload_serialize::{PayloadSerialize, SerializedPayload},
topic_processor::TopicPattern,
user_properties::{PERSIST_KEY, UserProperty, validate_user_properties},
},
telemetry::{DEFAULT_TELEMETRY_CLOUD_EVENT_EVENT_TYPE, TELEMETRY_PROTOCOL_VERSION},
};
/// A telemetry message to be published by a [`Sender`].
///
/// Constructed through [`MessageBuilder`]; `build` runs `MessageBuilder::validate`, so every
/// `Message` has passed those checks.
#[derive(Builder, Clone, Debug)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct Message<T: PayloadSerialize> {
    /// Serialized payload; populated by the custom `MessageBuilder::payload` setter.
    #[builder(setter(custom))]
    serialized_payload: SerializedPayload,
    /// Marker for the payload type `T`; populated by `MessageBuilder::payload`.
    #[builder(private)]
    payload_type: PhantomData<T>,
    /// QoS to publish with. Defaults to `AtLeastOnce`; only `AtMostOnce` and `AtLeastOnce`
    /// pass validation.
    #[builder(default = "QoS::AtLeastOnce")]
    qos: QoS,
    /// Extra MQTT user properties. Keys must not be CloudEvent field names and must pass
    /// `validate_user_properties`.
    #[builder(default)]
    custom_user_data: Vec<(String, String)>,
    /// Values for the tokens in the sender's topic pattern, applied when the message is sent.
    #[builder(default)]
    topic_tokens: HashMap<String, String>,
    /// Message expiry, sent as the MQTT message expiry interval in whole seconds.
    /// Defaults to 10 seconds. The custom setter rounds fractional seconds up, and validation
    /// requires the number of seconds to fit in a `u32`.
    #[builder(default = "Duration::from_secs(10)")]
    #[builder(setter(custom))]
    // The field name intentionally starts with the struct name (`message`).
    #[allow(clippy::struct_field_names)]
    message_expiry: Duration,
    /// Optional CloudEvent whose headers are appended to the user properties when sent.
    #[builder(default = "None")]
    cloud_event: Option<CloudEvent>,
    /// MQTT retain flag. Defaults to `true` when `persist` was explicitly set to `true`,
    /// otherwise `false`.
    #[builder(default = "self.persist == Some(true)")]
    retain: bool,
    /// When `true`, a `PERSIST_KEY` user property is added when sent. Combining
    /// `persist(true)` with `retain(false)` fails validation. Defaults to `false`.
    #[builder(default = "false")]
    persist: bool,
}
/// CloudEvent metadata that can be attached to a telemetry [`Message`].
///
/// Thin wrapper around the protocol-wide `CloudEvent`. When a message carries one, its
/// headers are appended to the MQTT user properties at send time.
#[derive(Clone, Debug)]
pub struct CloudEvent(protocol_cloud_event::CloudEvent);
/// Builder for a telemetry [`CloudEvent`].
///
/// Wraps the protocol-wide `CloudEventBuilder`. Its `Default` implementation seeds the event
/// type with `DEFAULT_TELEMETRY_CLOUD_EVENT_EVENT_TYPE`.
#[derive(Clone)]
pub struct CloudEventBuilder(protocol_cloud_event::CloudEventBuilder);
/// Error produced when building a telemetry CloudEvent fails.
#[derive(Debug)]
#[non_exhaustive]
pub enum CloudEventBuilderError {
    /// A required field was never set; carries the field's name.
    UninitializedField(&'static str),
    /// A field value was rejected; carries a description of the problem.
    ValidationError(String),
}

impl std::error::Error for CloudEventBuilderError {}

impl std::fmt::Display for CloudEventBuilderError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::UninitializedField(field_name) => ("Uninitialized field", *field_name),
            Self::ValidationError(err_msg) => ("Validation error", err_msg.as_str()),
        };
        write!(f, "{prefix}: {detail}")
    }
}
impl From<protocol_cloud_event::CloudEventBuilderError> for CloudEventBuilderError {
    /// Maps each variant of the protocol-wide builder error onto the same-named variant here,
    /// carrying its payload over unchanged.
    fn from(value: protocol_cloud_event::CloudEventBuilderError) -> Self {
        type Source = protocol_cloud_event::CloudEventBuilderError;
        match value {
            Source::ValidationError(err_msg) => Self::ValidationError(err_msg),
            Source::UninitializedField(field_name) => Self::UninitializedField(field_name),
        }
    }
}
impl Default for CloudEventBuilder {
fn default() -> Self {
Self(protocol_cloud_event::CloudEventBuilder::new(
DEFAULT_TELEMETRY_CLOUD_EVENT_EVENT_TYPE.to_string(),
))
}
}
impl CloudEventBuilder {
pub fn build(&self) -> Result<CloudEvent, CloudEventBuilderError> {
Ok(CloudEvent(protocol_cloud_event::CloudEventBuilder::build(
&self.0,
)?))
}
pub fn source<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.source(value);
self
}
pub fn spec_version<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.spec_version(value);
self
}
pub fn event_type<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.event_type(value);
self
}
pub fn data_schema<VALUE: Into<Option<String>>>(&mut self, value: VALUE) -> &mut Self {
self.0.data_schema(value);
self
}
pub fn id<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.id(value);
self
}
pub fn time<VALUE: Into<Option<DateTime<Utc>>>>(&mut self, value: VALUE) -> &mut Self {
self.0.time(value);
self
}
pub fn subject<VALUE: Into<protocol_cloud_event::CloudEventSubject>>(
&mut self,
value: VALUE,
) -> &mut Self {
self.0.subject(value);
self
}
}
impl<T: PayloadSerialize> MessageBuilder<T> {
    /// Serializes `payload` and stores the result on the builder.
    ///
    /// # Errors
    /// - [`AIOProtocolError`] of kind `PayloadInvalid` if `payload.serialize()` fails.
    /// - [`AIOProtocolError`] of kind `ConfigurationInvalid` (property `content_type`) if the
    ///   serialized content type is rejected by `is_invalid_utf8`.
    pub fn payload(&mut self, payload: T) -> Result<&mut Self, AIOProtocolError> {
        let serialized_payload = payload.serialize().map_err(|e| {
            AIOProtocolError::new_payload_invalid_error(
                true,
                false,
                Some(e.into()),
                Some("Payload serialization error".to_string()),
                None,
            )
        })?;

        let content_type = &serialized_payload.content_type;
        if is_invalid_utf8(content_type) {
            return Err(AIOProtocolError::new_configuration_invalid_error(
                None,
                "content_type",
                Value::String(content_type.clone()),
                Some(format!(
                    "Content type '{content_type}' of telemetry message type is not valid UTF-8"
                )),
                None,
            ));
        }

        self.serialized_payload = Some(serialized_payload);
        self.payload_type = Some(PhantomData);
        Ok(self)
    }

    /// Sets the message expiry.
    ///
    /// Any fractional second is rounded up to the next whole second (saturating at
    /// `u64::MAX` seconds), because the value is sent as a whole-second MQTT interval.
    pub fn message_expiry(&mut self, message_expiry: Duration) -> &mut Self {
        let whole_seconds = message_expiry.as_secs();
        let rounded = if message_expiry.subsec_nanos() == 0 {
            message_expiry
        } else {
            Duration::from_secs(whole_seconds.saturating_add(1))
        };
        self.message_expiry = Some(rounded);
        self
    }

    /// Validation hook run by `build`. Checks, in this order:
    /// 1. no custom user data key parses as a CloudEvent field name, and the user data passes
    ///    `validate_user_properties`;
    /// 2. the message expiry, in seconds, fits in a `u32`;
    /// 3. the QoS is `AtMostOnce` or `AtLeastOnce`;
    /// 4. when both a CloudEvent and a payload are set, the payload's content type is valid
    ///    for the CloudEvent's spec version;
    /// 5. `persist(true)` is not combined with an explicit `retain(false)`.
    ///
    /// # Errors
    /// A description of the first check that fails.
    fn validate(&self) -> Result<(), String> {
        if let Some(user_data) = &self.custom_user_data {
            let reserved_key = user_data
                .iter()
                .map(|(key, _)| key)
                .find(|key| aio_cloud_event::CloudEventFields::from_str(key).is_ok());
            if let Some(key) = reserved_key {
                return Err(format!(
                    "Invalid user data property '{key}' is a reserved Cloud Event key"
                ));
            }
            validate_user_properties(user_data)?;
        }

        if let Some(expiry) = &self.message_expiry
            && u32::try_from(expiry.as_secs()).is_err()
        {
            return Err("Timeout in seconds must be less than or equal to u32::max to be used as message_expiry_interval".to_string());
        }

        if let Some(qos) = &self.qos
            && !matches!(qos, QoS::AtMostOnce | QoS::AtLeastOnce)
        {
            return Err("QoS must be AtMostOnce or AtLeastOnce".to_string());
        }

        if let (Some(Some(cloud_event)), Some(serialized_payload)) =
            (&self.cloud_event, &self.serialized_payload)
        {
            aio_cloud_event::CloudEventFields::DataContentType.validate(
                &serialized_payload.content_type,
                &cloud_event.0.spec_version,
            )?;
        }

        if matches!((self.persist, self.retain), (Some(true), Some(false))) {
            return Err("Persist cannot be used without retain".to_string());
        }

        Ok(())
    }
}
/// Configuration used to create a telemetry [`Sender`].
#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
// Every field intentionally shares the `topic_` prefix.
#[allow(clippy::struct_field_names)]
pub struct Options {
    /// Topic pattern to publish on; may contain `{token}` placeholders.
    topic_pattern: String,
    /// Optional topic namespace, passed to `TopicPattern::new`.
    #[builder(default = "None")]
    topic_namespace: Option<String>,
    /// Token replacements passed to `TopicPattern::new` when the sender is created.
    #[builder(default)]
    topic_token_map: HashMap<String, String>,
}
/// Publishes telemetry messages whose payload type is `T`.
pub struct Sender<T>
where
    T: PayloadSerialize,
{
    /// Hybrid logical clock updated on every send to produce the timestamp user property.
    application_hlc: Arc<ApplicationHybridLogicalClock>,
    /// MQTT client used for publishing; its client id is sent as the source id.
    mqtt_client: SessionManagedClient,
    /// Marker tying the sender to payload type `T`.
    message_payload_type: PhantomData<T>,
    /// Topic pattern built from the `Options`; completed with each message's topic tokens.
    topic_pattern: TopicPattern,
}
impl<T> Sender<T>
where
T: PayloadSerialize,
{
#[allow(clippy::needless_pass_by_value)]
pub fn new(
application_context: ApplicationContext,
client: SessionManagedClient,
sender_options: Options,
) -> Result<Self, AIOProtocolError> {
let topic_pattern = TopicPattern::new(
&sender_options.topic_pattern,
None,
sender_options.topic_namespace.as_deref(),
&sender_options.topic_token_map,
)
.map_err(|e| {
AIOProtocolError::config_invalid_from_topic_pattern_error(
e,
"sender_options.topic_pattern",
)
})?;
Ok(Self {
application_hlc: application_context.application_hlc,
mqtt_client: client,
message_payload_type: PhantomData,
topic_pattern,
})
}
pub async fn send(&self, mut message: Message<T>) -> Result<(), AIOProtocolError> {
let message_expiry_interval: u32 = match message.message_expiry.as_secs().try_into() {
Ok(val) => val,
Err(_) => {
unreachable!();
}
};
let message_topic = self
.topic_pattern
.as_publish_topic(&message.topic_tokens)
.map_err(|e| {
AIOProtocolError::config_invalid_from_topic_pattern_error(e, "message_topic")
})?;
let timestamp_str = self.application_hlc.update_now()?;
let correlation_id = Uuid::new_v4();
let correlation_data = Bytes::from(correlation_id.as_bytes().to_vec());
if let Some(cloud_event) = message.cloud_event {
let cloud_event_headers = cloud_event.0.into_headers(message_topic.as_str());
for (key, value) in cloud_event_headers {
message.custom_user_data.push((key, value));
}
}
if message.persist {
message
.custom_user_data
.push((PERSIST_KEY.to_string(), true.to_string()));
}
message
.custom_user_data
.push((UserProperty::Timestamp.to_string(), timestamp_str));
message.custom_user_data.push((
UserProperty::ProtocolVersion.to_string(),
TELEMETRY_PROTOCOL_VERSION.to_string(),
));
message.custom_user_data.push((
UserProperty::SourceId.to_string(),
self.mqtt_client.client_id().to_string(),
));
let publish_properties = PublishProperties {
correlation_data: Some(correlation_data),
response_topic: None,
payload_format_indicator: message.serialized_payload.format_indicator.into(),
content_type: Some(message.serialized_payload.content_type.clone()),
message_expiry_interval: Some(message_expiry_interval),
user_properties: message.custom_user_data,
topic_alias: None,
subscription_identifiers: Vec::new(),
};
match message.qos {
azure_iot_operations_mqtt::control_packet::QoS::AtMostOnce => {
let publish_result = self
.mqtt_client
.publish_qos0(
message_topic,
message.retain,
message.serialized_payload.payload,
publish_properties,
)
.await;
match publish_result {
Ok(publish_completion_token) => publish_completion_token.await.map_err(|e| {
log::error!("Telemetry Publish completion error: {e}");
AIOProtocolError::new_mqtt_error(
Some("MQTT Error on telemetry send publish".to_string()),
Box::new(e),
None,
)
}),
Err(e) => {
log::error!("Telemetry Publish error: {e}");
Err(AIOProtocolError::new_mqtt_error(
Some("MQTT Error on telemetry send publish".to_string()),
Box::new(e),
None,
))
}
}
}
azure_iot_operations_mqtt::control_packet::QoS::AtLeastOnce => {
let publish_result = self
.mqtt_client
.publish_qos1(
message_topic,
message.retain,
message.serialized_payload.payload,
publish_properties,
)
.await;
match publish_result {
Ok(publish_completion_token) => {
match publish_completion_token.await {
Ok(puback) => puback.as_result().map_err(|e| {
AIOProtocolError::new_mqtt_error(
Some("MQTT Puback indicated failure".to_string()),
Box::new(e),
None,
)
}),
Err(e) => {
log::error!("Telemetry Publish completion error: {e}");
Err(AIOProtocolError::new_mqtt_error(
Some("MQTT Error on telemetry send publish".to_string()),
Box::new(e),
None,
))
}
}
}
Err(e) => {
log::error!("Telemetry Publish error: {e}");
Err(AIOProtocolError::new_mqtt_error(
Some("MQTT Error on telemetry send publish".to_string()),
Box::new(e),
None,
))
}
}
}
azure_iot_operations_mqtt::control_packet::QoS::ExactlyOnce => unreachable!(
"QoS::ExactlyOnce is not supported for telemetry sending and isn't possible to set on Message"
),
}
}
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use test_case::test_case;

    use crate::{
        application::ApplicationContextBuilder,
        common::{
            aio_protocol_error::{AIOProtocolErrorKind, Value},
            payload_serialize::{FormatIndicator, MockPayload, SerializedPayload},
        },
        telemetry::sender::{OptionsBuilder, Sender},
    };
    use azure_iot_operations_mqtt::{
        aio::connection_settings::MqttConnectionSettingsBuilder,
        session::{Session, SessionOptionsBuilder},
    };

    use super::MessageBuilder;

    /// Creates a session configured for `localhost` (the session is never run in these tests).
    fn get_session() -> Session {
        let connection_settings = MqttConnectionSettingsBuilder::default()
            .hostname("localhost")
            .client_id("test_client")
            .build()
            .unwrap();
        let session_options = SessionOptionsBuilder::default()
            .connection_settings(connection_settings)
            .build()
            .unwrap();
        Session::new(session_options).unwrap()
    }

    /// Builds a `MockPayload` whose `serialize` must be called exactly once and returns an
    /// empty UTF-8 payload with the given content type.
    fn mock_payload_with_content_type(content_type: &str) -> MockPayload {
        let content_type = content_type.to_string();
        let mut mock_payload = MockPayload::new();
        mock_payload
            .expect_serialize()
            .returning(move || {
                Ok(SerializedPayload {
                    payload: Vec::new(),
                    content_type: content_type.clone(),
                    format_indicator: FormatIndicator::Utf8EncodedCharacterData,
                })
            })
            .times(1);
        mock_payload
    }

    /// Shorthand for a mock payload with a valid `application/json` content type.
    fn json_mock_payload() -> MockPayload {
        mock_payload_with_content_type("application/json")
    }

    #[test]
    fn test_new_defaults() {
        let session = get_session();
        let sender_options = OptionsBuilder::default()
            .topic_pattern("test/test_telemetry")
            .build()
            .unwrap();

        Sender::<MockPayload>::new(
            ApplicationContextBuilder::default().build().unwrap(),
            session.create_managed_client(),
            sender_options,
        )
        .unwrap();
    }

    #[test]
    fn test_new_override_defaults() {
        let session = get_session();
        let sender_options = OptionsBuilder::default()
            .topic_pattern("test/{telemetryName}")
            .topic_namespace("test_namespace")
            .topic_token_map(HashMap::from([(
                "telemetryName".to_string(),
                "test_telemetry".to_string(),
            )]))
            .build()
            .unwrap();

        Sender::<MockPayload>::new(
            ApplicationContextBuilder::default().build().unwrap(),
            session.create_managed_client(),
            sender_options,
        )
        .unwrap();
    }

    #[test_case(""; "new_empty_topic_pattern")]
    #[test_case(" "; "new_whitespace_topic_pattern")]
    fn test_new_empty_topic_pattern(property_value: &str) {
        let session = get_session();
        let sender_options = OptionsBuilder::default()
            .topic_pattern(property_value)
            .build()
            .unwrap();

        let sender: Result<Sender<MockPayload>, _> = Sender::new(
            ApplicationContextBuilder::default().build().unwrap(),
            session.create_managed_client(),
            sender_options,
        );
        match sender {
            Ok(_) => panic!("Expected error"),
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(
                    e.property_name,
                    Some("sender_options.topic_pattern".to_string())
                );
                assert!(e.property_value == Some(Value::String(property_value.to_string())));
            }
        }
    }

    #[test]
    fn test_message_serialization_error() {
        let mut mock_telemetry_payload = MockPayload::new();
        mock_telemetry_payload
            .expect_serialize()
            .returning(|| Err("dummy error".to_string()))
            .times(1);

        let mut binding = MessageBuilder::default();
        let message_builder = binding.payload(mock_telemetry_payload);
        match message_builder {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::PayloadInvalid);
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    #[test]
    fn test_response_serialization_bad_content_type_error() {
        let mock_telemetry_payload = mock_payload_with_content_type("application/json\u{0000}");

        let mut binding = MessageBuilder::default();
        let message_builder = binding.payload(mock_telemetry_payload);
        match message_builder {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(e.property_name, Some("content_type".to_string()));
                assert!(
                    e.property_value
                        == Some(Value::String("application/json\u{0000}".to_string()))
                );
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    #[test_case(Duration::from_secs(u64::from(u32::MAX) + 1); "send_timeout_u32_max")]
    fn test_send_timeout_invalid_value(timeout: Duration) {
        let message_builder_result = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .message_expiry(timeout)
            .build();

        assert!(message_builder_result.is_err());
    }

    #[test_case(Duration::from_millis(1500), Duration::from_secs(2); "fractional_rounds_up")]
    #[test_case(Duration::from_secs(5), Duration::from_secs(5); "whole_seconds_unchanged")]
    #[test_case(Duration::from_nanos(1), Duration::from_secs(1); "sub_second_rounds_up_to_one")]
    fn test_message_expiry_rounding(input: Duration, expected: Duration) {
        let message = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .message_expiry(input)
            .build()
            .unwrap();

        assert_eq!(message.message_expiry, expected);
    }

    #[test]
    fn test_send_qos_invalid_value() {
        let message_builder_result = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .qos(azure_iot_operations_mqtt::control_packet::QoS::ExactlyOnce)
            .build();

        assert!(message_builder_result.is_err());
    }

    #[test]
    fn test_send_invalid_custom_user_data_cloud_event_header() {
        let message_builder_result = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .custom_user_data(vec![("source".to_string(), "test".to_string())])
            .build();

        assert!(message_builder_result.is_err());
    }

    #[test]
    fn test_invalid_persist_retain() {
        let message_builder_result = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .persist(true)
            .retain(false)
            .build();

        assert!(message_builder_result.is_err());
    }

    #[test]
    fn test_persist_defaults_retain_to_true() {
        let message = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .persist(true)
            .build()
            .unwrap();

        assert!(message.persist);
        assert!(message.retain);
    }

    #[test]
    fn test_message_defaults() {
        let m = MessageBuilder::default()
            .payload(json_mock_payload())
            .unwrap()
            .build()
            .unwrap();

        assert!(!m.persist);
        assert!(!m.retain);
        assert_eq!(
            m.qos,
            azure_iot_operations_mqtt::control_packet::QoS::AtLeastOnce
        );
        assert_eq!(m.message_expiry, Duration::from_secs(10));
        assert!(m.custom_user_data.is_empty());
        assert!(m.topic_tokens.is_empty());
        assert!(m.cloud_event.is_none());
        assert!(m.serialized_payload.payload.is_empty());
    }
}