use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::{collections::HashMap, marker::PhantomData, time::Duration};
use azure_iot_operations_mqtt::{
aio::cloud_event as aio_cloud_event,
session::{SessionManagedClient, SessionPubReceiver},
};
use azure_iot_operations_mqtt::{
control_packet::{PublishProperties, QoS, TopicFilter, TopicName},
token::AckToken,
};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use tokio::sync::oneshot;
use tokio::time::{Instant, timeout};
use tokio_util::sync::{CancellationToken, DropGuard};
use crate::{
ProtocolVersion,
application::{ApplicationContext, ApplicationHybridLogicalClock},
common::{
aio_protocol_error::{AIOProtocolError, Value},
cloud_event as protocol_cloud_event,
hybrid_logical_clock::{HLCErrorKind, HybridLogicalClock},
is_invalid_utf8,
payload_serialize::{
DeserializationError, FormatIndicator, PayloadSerialize, SerializedPayload,
},
topic_processor::{TopicPattern, contains_invalid_char, is_valid_replacement},
user_properties::{PARTITION_KEY, UserProperty, validate_user_properties},
},
rpc_command::{
DEFAULT_RPC_COMMAND_PROTOCOL_VERSION, DEFAULT_RPC_RESPONSE_CLOUD_EVENT_EVENT_TYPE,
RPC_COMMAND_PROTOCOL_VERSION, StatusCode,
},
supported_protocol_major_versions_to_string,
};
/// Fallback message expiry (seconds) applied when a request arrives without
/// its own message expiry interval.
const DEFAULT_MESSAGE_EXPIRY_INTERVAL_SECONDS: u32 = 10;
/// Extra time (seconds) a cached response is retained beyond the command's
/// expiration, so late duplicate requests can still be answered from cache.
const CACHE_EXPIRY_BUFFER_SECONDS: u64 = 60;
/// Status message reported when the command expiration time cannot be computed.
const INTERNAL_LOGIC_EXPIRATION_ERROR: &str =
    "Internal logic error, unable to calculate command expiration time";
/// Protocol major versions this executor can service.
const SUPPORTED_PROTOCOL_VERSIONS: &[u16] = &[1];
/// Aggregated state for building and publishing a single command response.
///
/// Populated incrementally while a request is validated in `recv` and then
/// handed to the spawned response-publishing task.
struct ResponseArguments {
    // Command name, used for logging and error construction.
    command_name: String,
    // MQTT topic the response will be published to.
    response_topic: TopicName,
    // Correlation data echoed back to the invoker; `None` if the request had none.
    correlation_data: Option<Bytes>,
    // Status code to report to the invoker.
    status_code: StatusCode,
    // Optional human-readable status detail.
    status_message: Option<String>,
    // True when the failure originated in application code rather than the protocol.
    is_application_error: bool,
    // Name of the request property that failed validation, if any.
    invalid_property_name: Option<String>,
    // Value of the request property that failed validation, if any.
    invalid_property_value: Option<String>,
    // Deadline after which the command is considered expired.
    command_expiration_time: Option<Instant>,
    // Message expiry interval (seconds) taken from the request, if present.
    message_expiry_interval: Option<u32>,
    // Populated on version mismatch so the invoker learns what is supported.
    supported_protocol_major_versions: Option<Vec<u16>>,
    // Protocol version string received on the request, echoed on version errors.
    request_protocol_version: Option<String>,
    // Key for the duplicate-response cache; `None` when correlation data was invalid.
    cached_key: Option<CacheKey>,
    // Result of the duplicate-request cache lookup.
    cache_lookup_result: CacheLookupResult,
}
/// A received command request, surfaced to the application by the executor.
pub struct Request<TReq, TResp>
where
    TReq: PayloadSerialize,
    TResp: PayloadSerialize,
{
    /// Deserialized request payload.
    pub payload: TReq,
    /// Content type of the request, if one was provided.
    pub content_type: Option<String>,
    /// Payload format indicator of the request.
    pub format_indicator: FormatIndicator,
    /// Custom MQTT user properties that arrived with the request.
    pub custom_user_data: Vec<(String, String)>,
    /// Timestamp of the request, if one was provided.
    pub timestamp: Option<HybridLogicalClock>,
    /// ID of the invoker that sent the request, if provided.
    pub invoker_id: Option<String>,
    /// Topic-token values resolved from the request topic.
    pub topic_tokens: HashMap<String, String>,
    // Command name, used when constructing errors for this request.
    command_name: String,
    // Passes the application's response back to the executor's responder task.
    response_tx: oneshot::Sender<Response<TResp>>,
    // Reports the outcome of publishing the response back to the application.
    publish_completion_rx: oneshot::Receiver<Result<(), AIOProtocolError>>,
}
impl<TReq, TResp> Request<TReq, TResp>
where
    TReq: PayloadSerialize,
    TResp: PayloadSerialize,
{
    /// Sends `response` to the executor and waits for the publish outcome.
    ///
    /// # Errors
    /// Returns an [`AIOProtocolError`] if publishing the response failed, or a
    /// cancellation error if the executor shut down before publishing it.
    pub async fn complete(self, response: Response<TResp>) -> Result<(), AIOProtocolError> {
        // A send failure only means the executor task is gone; that case is
        // surfaced below as a cancellation error when the receiver drops.
        drop(self.response_tx.send(response));
        match self.publish_completion_rx.await {
            Ok(publish_result) => publish_result,
            Err(_) => Err(Self::create_cancellation_error(self.command_name)),
        }
    }

    // Builds the error reported when the executor shuts down before a
    // response could be published.
    fn create_cancellation_error(command_name: String) -> AIOProtocolError {
        let detail =
            "Command Executor has been shutdown and can no longer respond to commands".to_string();
        AIOProtocolError::new_cancellation_error(false, None, Some(detail), Some(command_name))
    }

    /// Returns `true` if the executor can no longer accept a response for
    /// this request.
    pub fn is_cancelled(&self) -> bool {
        self.response_tx.is_closed()
    }
}
/// Cloud Event attached to a command request.
pub type RequestCloudEvent = aio_cloud_event::CloudEvent;
/// Error returned when request headers cannot be parsed as a Cloud Event.
pub type RequestCloudEventParseError = aio_cloud_event::CloudEventParseError;

/// Extracts the Cloud Event carried on a command request's custom user
/// properties and content type.
///
/// # Errors
/// Returns a [`RequestCloudEventParseError`] if the request's headers do not
/// form a valid Cloud Event.
pub fn cloud_event_from_request<TReq: PayloadSerialize, TResp: PayloadSerialize>(
    request: &Request<TReq, TResp>,
) -> Result<RequestCloudEvent, RequestCloudEventParseError> {
    RequestCloudEvent::try_from((&request.custom_user_data, request.content_type.as_deref()))
}
/// A command response, returned by the application via [`Request::complete`].
#[derive(Builder, Clone, Debug)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct Response<TResp>
where
    TResp: PayloadSerialize,
{
    // Response payload, already serialized by the custom `payload` setter.
    #[builder(setter(custom))]
    serialized_payload: SerializedPayload,
    // Marker tying the response to its payload type; set by the `payload` setter.
    #[builder(private)]
    payload_type: PhantomData<TResp>,
    // Custom MQTT user properties to attach to the response.
    #[builder(default)]
    custom_user_data: Vec<(String, String)>,
    // Optional Cloud Event metadata to attach to the response.
    #[builder(default = "None")]
    cloud_event: Option<ResponseCloudEvent>,
}
/// Cloud Event metadata for a command response.
///
/// Newtype over the protocol-level Cloud Event; construct one with
/// [`ResponseCloudEventBuilder`].
#[derive(Clone, Debug)]
pub struct ResponseCloudEvent(protocol_cloud_event::CloudEvent);

/// Builder for [`ResponseCloudEvent`], wrapping the protocol-level builder.
#[derive(Clone)]
pub struct ResponseCloudEventBuilder(protocol_cloud_event::CloudEventBuilder);
/// Error produced when a [`ResponseCloudEvent`] fails to build.
#[derive(Debug)]
#[non_exhaustive]
pub enum ResponseCloudEventBuilderError {
    /// A required field was never set.
    UninitializedField(&'static str),
    /// A field value failed validation.
    ValidationError(String),
}

impl std::error::Error for ResponseCloudEventBuilderError {}

impl std::fmt::Display for ResponseCloudEventBuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UninitializedField(field_name) => {
                write!(f, "Uninitialized field: {field_name}")
            }
            Self::ValidationError(err_msg) => write!(f, "Validation error: {err_msg}"),
        }
    }
}
impl From<protocol_cloud_event::CloudEventBuilderError> for ResponseCloudEventBuilderError {
    // Maps each protocol-level builder error variant onto the corresponding
    // public error variant.
    fn from(value: protocol_cloud_event::CloudEventBuilderError) -> Self {
        match value {
            protocol_cloud_event::CloudEventBuilderError::UninitializedField(field_name) => {
                Self::UninitializedField(field_name)
            }
            protocol_cloud_event::CloudEventBuilderError::ValidationError(err_msg) => {
                Self::ValidationError(err_msg)
            }
        }
    }
}
impl Default for ResponseCloudEventBuilder {
    /// Creates a builder whose event type is preset to the default RPC
    /// response Cloud Event type.
    fn default() -> Self {
        let event_type = DEFAULT_RPC_RESPONSE_CLOUD_EVENT_EVENT_TYPE.to_string();
        Self(protocol_cloud_event::CloudEventBuilder::new(event_type))
    }
}
impl ResponseCloudEventBuilder {
    /// Builds a [`ResponseCloudEvent`] from the configured fields.
    ///
    /// # Errors
    /// Returns [`ResponseCloudEventBuilderError`] if a required field was not
    /// set or a field value failed validation.
    pub fn build(&self) -> Result<ResponseCloudEvent, ResponseCloudEventBuilderError> {
        Ok(ResponseCloudEvent(
            protocol_cloud_event::CloudEventBuilder::build(&self.0)?,
        ))
    }

    /// Sets `source` on the underlying Cloud Event builder.
    pub fn source<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
        self.0.source(value);
        self
    }

    /// Sets `spec_version` on the underlying Cloud Event builder.
    pub fn spec_version<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
        self.0.spec_version(value);
        self
    }

    /// Sets `event_type` on the underlying Cloud Event builder.
    pub fn event_type<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
        self.0.event_type(value);
        self
    }

    /// Sets `data_schema` on the underlying Cloud Event builder.
    pub fn data_schema<VALUE: Into<Option<String>>>(&mut self, value: VALUE) -> &mut Self {
        self.0.data_schema(value);
        self
    }

    /// Sets `id` on the underlying Cloud Event builder.
    pub fn id<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
        self.0.id(value);
        self
    }

    /// Sets `time` on the underlying Cloud Event builder.
    pub fn time<VALUE: Into<Option<DateTime<Utc>>>>(&mut self, value: VALUE) -> &mut Self {
        self.0.time(value);
        self
    }

    /// Sets `subject` on the underlying Cloud Event builder.
    pub fn subject<VALUE: Into<protocol_cloud_event::CloudEventSubject>>(
        &mut self,
        value: VALUE,
    ) -> &mut Self {
        self.0.subject(value);
        self
    }
}
impl<TResp: PayloadSerialize> ResponseBuilder<TResp> {
    /// Serializes `payload` and stores the result on the builder.
    ///
    /// # Errors
    /// Returns [`AIOProtocolError`] of kind `PayloadInvalid` if serialization
    /// fails, or `ConfigurationInvalid` if the serializer reports a content
    /// type that is not valid UTF-8.
    pub fn payload(&mut self, payload: TResp) -> Result<&mut Self, AIOProtocolError> {
        match payload.serialize() {
            Err(e) => Err(AIOProtocolError::new_payload_invalid_error(
                true,
                false,
                Some(e.into()),
                Some("Payload serialization error".to_string()),
                None,
            )),
            Ok(serialized_payload) => {
                // The content type ends up as an MQTT property, so it must be
                // valid UTF-8.
                if is_invalid_utf8(&serialized_payload.content_type) {
                    return Err(AIOProtocolError::new_configuration_invalid_error(
                        None,
                        "content_type",
                        Value::String(serialized_payload.content_type.clone()),
                        Some(format!(
                            "Content type '{}' of command response is not valid UTF-8",
                            serialized_payload.content_type
                        )),
                        None,
                    ));
                }
                self.serialized_payload = Some(serialized_payload);
                self.payload_type = Some(PhantomData);
                Ok(self)
            }
        }
    }

    /// Pre-build validation invoked by the derived `build` method.
    ///
    /// Rejects custom user data that collides with reserved Cloud Event keys
    /// or fails general user-property validation, and checks the payload's
    /// content type against the Cloud Event spec version when one is attached.
    fn validate(&self) -> Result<(), String> {
        if let Some(custom_user_data) = &self.custom_user_data {
            for (key, _) in custom_user_data {
                // Cloud Event attributes must be set via `cloud_event`, not
                // smuggled in as user data.
                if aio_cloud_event::CloudEventFields::from_str(key).is_ok() {
                    return Err(format!(
                        "Invalid user data property '{key}' is a reserved Cloud Event key"
                    ));
                }
            }
            validate_user_properties(custom_user_data)?;
        }
        // Only check content type compatibility when both a Cloud Event and a
        // serialized payload have been supplied.
        if let Some(Some(cloud_event)) = &self.cloud_event
            && let Some(serialized_payload) = &self.serialized_payload
        {
            aio_cloud_event::CloudEventFields::DataContentType.validate(
                &serialized_payload.content_type,
                &cloud_event.0.spec_version,
            )?;
        }
        Ok(())
    }
}
/// Appends application-error headers to `custom_user_data`.
///
/// Always adds the `AppErrCode` header; adds the `AppErrPayload` header only
/// when the payload is non-blank.
///
/// # Errors
/// Returns an error if `application_error_code` is empty or whitespace-only;
/// in that case `custom_user_data` is left unchanged.
pub fn application_error_headers(
    custom_user_data: &mut Vec<(String, String)>,
    application_error_code: String,
    application_error_payload: String,
) -> Result<(), String> {
    const APPLICATION_ERROR_CODE_HEADER: &str = "AppErrCode";
    const APPLICATION_ERROR_PAYLOAD_HEADER: &str = "AppErrPayload";

    // Reject blank codes before touching the caller's vector.
    if application_error_code.trim().is_empty() {
        return Err("application_error_code cannot be empty".into());
    }

    let mut headers = vec![(
        APPLICATION_ERROR_CODE_HEADER.to_string(),
        application_error_code,
    )];
    // A blank payload is treated as absent.
    if !application_error_payload.trim().is_empty() {
        headers.push((
            APPLICATION_ERROR_PAYLOAD_HEADER.to_string(),
            application_error_payload,
        ));
    }
    custom_user_data.extend(headers);
    Ok(())
}
/// Key identifying a request in the duplicate-response cache.
///
/// Requests with the same response topic and correlation data are treated as
/// the same logical command invocation.
#[derive(Eq, Hash, PartialEq, Clone)]
struct CacheKey {
    response_topic: TopicName,
    correlation_data: Bytes,
}
/// State of a request in the duplicate-response cache.
#[derive(Clone, Debug)]
#[allow(clippy::large_enum_variant)]
enum CacheEntry {
    /// A completed response, replayable for duplicate requests.
    Cached {
        // Serialized response payload to republish.
        serialized_payload: SerializedPayload,
        // Publish properties captured from the original response.
        properties: PublishProperties,
        // Point after which this entry may be evicted.
        expiration_time: Instant,
    },
    /// The original request is still being processed.
    InProgress {
        // Cancelled (via the drop guard held by the processing task) when
        // processing ends for any reason.
        processing_cancellation_token: CancellationToken,
    },
}
/// Outcome of a duplicate-request cache lookup.
#[derive(Debug)]
enum CacheLookupResult {
    /// An unexpired completed response exists for this request.
    Cached {
        serialized_payload: SerializedPayload,
        properties: PublishProperties,
        // Remaining validity (seconds) for the replayed response.
        response_message_expiry_interval: u32,
    },
    /// The original request is still in flight; the token signals when its
    /// processing ends.
    InProgress(CancellationToken),
    /// No usable entry was found (including entries that had expired).
    NotFound,
}
/// Shared duplicate-response cache; cheap to clone across executor tasks.
#[derive(Clone)]
struct Cache(Arc<Mutex<HashMap<CacheKey, CacheEntry>>>);
impl Cache {
    /// Looks up `key`, translating the stored entry into a lookup result.
    /// Expired cached entries are reported as `NotFound`.
    ///
    /// Panics if the cache mutex has been poisoned.
    fn get(&self, key: &CacheKey) -> CacheLookupResult {
        let guard = self.0.lock().unwrap();
        let Some(found) = guard.get(key) else {
            return CacheLookupResult::NotFound;
        };
        match found {
            CacheEntry::Cached {
                serialized_payload,
                properties,
                expiration_time,
            } => match get_response_message_expiry_interval(*expiration_time) {
                Some(response_message_expiry_interval) => CacheLookupResult::Cached {
                    serialized_payload: serialized_payload.clone(),
                    properties: properties.clone(),
                    response_message_expiry_interval,
                },
                // No remaining validity: the entry is stale, treat as absent.
                None => CacheLookupResult::NotFound,
            },
            CacheEntry::InProgress {
                processing_cancellation_token,
            } => CacheLookupResult::InProgress(processing_cancellation_token.clone()),
        }
    }

    /// Inserts `entry` under `key`, first evicting entries that are no longer
    /// useful (expired cached responses and finished in-progress markers).
    ///
    /// Panics if the cache mutex has been poisoned.
    fn set(&self, key: CacheKey, entry: CacheEntry) {
        let mut guard = self.0.lock().unwrap();
        guard.retain(|_, existing| match existing {
            // Keep cached responses only while their expiration is in the future.
            CacheEntry::Cached {
                expiration_time, ..
            } => expiration_time.elapsed().is_zero(),
            // Keep in-progress markers only while processing is still running.
            CacheEntry::InProgress {
                processing_cancellation_token,
            } => !processing_cancellation_token.is_cancelled(),
        });
        guard.insert(key, entry);
    }
}
/// Configuration options for a command [`Executor`].
#[allow(unused)]
#[derive(Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct Options {
    /// Topic pattern on which command requests are received.
    request_topic_pattern: String,
    /// Name of the command this executor services.
    command_name: String,
    /// Optional topic namespace applied to the request topic pattern.
    #[builder(default = "None")]
    topic_namespace: Option<String>,
    /// Static topic-token replacement values for the topic pattern.
    #[builder(default)]
    topic_token_map: HashMap<String, String>,
    /// Whether the command is idempotent.
    #[builder(default = "false")]
    is_idempotent: bool,
    /// Optional service group ID used when resolving the topic pattern.
    #[builder(default = "None")]
    service_group_id: Option<String>,
}
/// Receives command requests and publishes command responses over MQTT.
#[allow(unused)]
pub struct Executor<TReq, TResp>
where
    TReq: PayloadSerialize + Send + 'static,
    TResp: PayloadSerialize + Send + 'static,
{
    // Application-wide hybrid logical clock, updated from request timestamps.
    application_hlc: Arc<ApplicationHybridLogicalClock>,
    // MQTT client used for subscribe/unsubscribe and response publishes.
    mqtt_client: SessionManagedClient,
    // Receiver yielding publishes that match the request topic filter.
    mqtt_receiver: SessionPubReceiver,
    // Whether the command is idempotent.
    is_idempotent: bool,
    // Parsed request topic pattern, used to extract topic tokens.
    request_topic_pattern: TopicPattern,
    // Concrete subscribe filter derived from the pattern.
    request_topic_filter: TopicFilter,
    // Command name, used in logs and errors.
    command_name: String,
    request_payload_type: PhantomData<TReq>,
    response_payload_type: PhantomData<TResp>,
    // Duplicate-request response cache.
    cache: Cache,
    // Subscription lifecycle state.
    state: State,
    // Watched by spawned response tasks so they can be stopped early;
    // presumably cancelled when the executor is torn down — not shown here.
    cancellation_token: CancellationToken,
}
/// Subscription lifecycle state of an [`Executor`].
#[derive(PartialEq)]
enum State {
    /// Created but not yet subscribed to the request topic.
    New,
    /// Subscribed; requests are being received.
    Subscribed,
    /// Shutdown completed (unsubscribed if previously subscribed).
    ShutdownSuccessful,
}
impl<TReq, TResp> Executor<TReq, TResp>
where
TReq: PayloadSerialize + Send + 'static,
TResp: PayloadSerialize + Send + 'static,
{
/// Creates a new command [`Executor`].
///
/// # Errors
/// Returns [`AIOProtocolError`] of kind `ConfigurationInvalid` if
/// `command_name` is empty or contains invalid characters, or if the request
/// topic pattern is invalid or cannot produce a subscribe filter.
pub fn new(
    application_context: ApplicationContext,
    client: SessionManagedClient,
    executor_options: Options,
) -> Result<Self, AIOProtocolError> {
    // Validate the command name up front; it is embedded in logs and errors.
    if executor_options.command_name.is_empty()
        || contains_invalid_char(&executor_options.command_name)
    {
        return Err(AIOProtocolError::new_configuration_invalid_error(
            None,
            "command_name",
            Value::String(executor_options.command_name.clone()),
            None,
            Some(executor_options.command_name),
        ));
    }
    // Resolve the request topic pattern (service group, namespace, token map).
    let request_topic_pattern = TopicPattern::new(
        &executor_options.request_topic_pattern,
        executor_options.service_group_id,
        executor_options.topic_namespace.as_deref(),
        &executor_options.topic_token_map,
    )
    .map_err(|e| {
        AIOProtocolError::config_invalid_from_topic_pattern_error(
            e,
            "executor_options.request_topic_pattern",
        )
    })?;
    // Derive the concrete MQTT subscribe filter from the pattern.
    let request_topic_filter = request_topic_pattern.as_subscribe_topic().map_err(|e| {
        AIOProtocolError::config_invalid_from_topic_pattern_error(
            e,
            "executor_options.request_topic_pattern",
        )
    })?;
    // The filtered receiver delivers only publishes matching the request filter.
    let mqtt_receiver = client.create_filtered_pub_receiver(request_topic_filter.clone());
    Ok(Executor {
        application_hlc: application_context.application_hlc,
        mqtt_client: client,
        mqtt_receiver,
        is_idempotent: executor_options.is_idempotent,
        request_topic_pattern,
        request_topic_filter,
        command_name: executor_options.command_name,
        request_payload_type: PhantomData,
        response_payload_type: PhantomData,
        cache: Cache(Arc::new(Mutex::new(HashMap::new()))),
        state: State::New,
        cancellation_token: CancellationToken::new(),
    })
}
pub async fn shutdown(&mut self) -> Result<(), AIOProtocolError> {
self.mqtt_receiver.close();
match self.state {
State::New | State::ShutdownSuccessful => {
self.state = State::ShutdownSuccessful;
}
State::Subscribed => {
let unsubscribe_result = self
.mqtt_client
.unsubscribe(
self.request_topic_filter.clone(),
azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
)
.await;
match unsubscribe_result {
Ok(unsub_ct) => match unsub_ct.await {
Ok(unsuback) => match unsuback.as_result() {
Ok(()) => {
self.state = State::ShutdownSuccessful;
}
Err(e) => {
log::error!(
"[{}] Executor nsuback error: {unsuback:?}",
self.command_name
);
return Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on command executor unsuback".to_string()),
Box::new(e),
Some(self.command_name.clone()),
));
}
},
Err(e) => {
log::error!(
"[{}] Executor unsubscribe completion error: {e}",
self.command_name
);
return Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on command executor unsubscribe".to_string()),
Box::new(e),
Some(self.command_name.clone()),
));
}
},
Err(e) => {
log::error!(
"[{}] Client error while unsubscribing in Executor: {e}",
self.command_name
);
return Err(AIOProtocolError::new_mqtt_error(
Some("Client error on command executor unsubscribe".to_string()),
Box::new(e),
Some(self.command_name.clone()),
));
}
}
}
}
log::info!("[{}] Executor Shutdown", self.command_name);
Ok(())
}
/// Subscribes to the request topic filter at QoS 1.
///
/// # Errors
/// Returns [`AIOProtocolError`] of kind `MqttError` if the subscribe request,
/// its completion, or the suback indicates failure.
async fn try_subscribe(&mut self) -> Result<(), AIOProtocolError> {
    let subscribe_result = self
        .mqtt_client
        .subscribe(
            self.request_topic_filter.clone(),
            QoS::AtLeastOnce,
            false,
            azure_iot_operations_mqtt::control_packet::RetainOptions::default(),
            azure_iot_operations_mqtt::control_packet::SubscribeProperties::default(),
        )
        .await;
    match subscribe_result {
        Ok(sub_ct) => match sub_ct.await {
            Ok(suback) => {
                // SUBSCRIBE completed; check the suback result for failure codes.
                suback.as_result().map_err(|e| {
                    log::error!("[{}] Executor suback error: {suback:?}", self.command_name);
                    AIOProtocolError::new_mqtt_error(
                        Some("MQTT error on command executor suback".to_string()),
                        Box::new(e),
                        Some(self.command_name.clone()),
                    )
                })?;
            }
            Err(e) => {
                // The completion token failed before a suback was received.
                log::error!(
                    "[{}] Executor subscribe completion error: {e}",
                    self.command_name
                );
                return Err(AIOProtocolError::new_mqtt_error(
                    Some("MQTT error on command executor subscribe".to_string()),
                    Box::new(e),
                    Some(self.command_name.clone()),
                ));
            }
        },
        Err(e) => {
            // The client rejected the subscribe request outright.
            log::error!(
                "[{}] Client error while subscribing in Executor: {e}",
                self.command_name
            );
            return Err(AIOProtocolError::new_mqtt_error(
                Some("Client error on command executor subscribe".to_string()),
                Box::new(e),
                Some(self.command_name.clone()),
            ));
        }
    }
    Ok(())
}
/// Receives the next command request, subscribing to the request topic on
/// first use.
///
/// Returns `None` when the underlying MQTT receiver is closed. Requests that
/// cannot be responded to at all (missing ack token, wrong QoS, missing or
/// invalid response topic) are logged, acked where possible, and skipped.
/// Requests that fail validation are answered with an error response
/// internally and are not surfaced to the caller; duplicates are answered
/// from the cache or acked when the original completes.
///
/// # Errors
/// Returns [`AIOProtocolError`] if the initial subscribe fails, or an
/// internal logic error if the command expiration time cannot be calculated.
pub async fn recv(&mut self) -> Option<Result<Request<TReq, TResp>, AIOProtocolError>> {
    // Lazily subscribe on first call.
    if State::New == self.state {
        if let Err(e) = self.try_subscribe().await {
            return Some(Err(e));
        }
        self.state = State::Subscribed;
    }
    loop {
        match self.mqtt_receiver.recv_manual_ack().await {
            Some((m, ack_token)) => {
                // Without an ack token the request cannot be acknowledged; skip.
                let Some(ack_token) = ack_token else {
                    log::warn!(
                        "[{}] Received command request without ack token",
                        self.command_name
                    );
                    continue;
                };
                // Commands must arrive at QoS 1; extract the packet ID for
                // logging and acking.
                let pkid = match m.qos {
                    azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtMostOnce
                    | azure_iot_operations_mqtt::control_packet::DeliveryQoS::ExactlyOnce(_) => {
                        log::warn!(
                            "[{}] Received non QoS 1 command request",
                            self.command_name
                        );
                        continue;
                    }
                    azure_iot_operations_mqtt::control_packet::DeliveryQoS::AtLeastOnce(
                        delivery_info,
                    ) => delivery_info.packet_identifier.get(),
                };
                log::debug!("[{}][pkid: {}] Received request", self.command_name, pkid);
                let message_received_time = Instant::now();
                // Token marking this request as in-progress in the cache; the
                // drop guard cancels it when processing ends for any reason.
                let processing_cancellation_token = CancellationToken::new();
                let processing_cancellation_token_clone = processing_cancellation_token.clone();
                let processing_drop_guard = processing_cancellation_token.drop_guard();
                let properties = m.properties;
                // A usable response topic is required to answer at all;
                // otherwise ack in the background and drop the request.
                let response_topic = if let Some(rt) = properties.response_topic {
                    if !is_valid_replacement(rt.as_str()) {
                        log::warn!(
                            "[{}][pkid: {}] Response topic invalid, command response will not be published",
                            self.command_name,
                            pkid
                        );
                        tokio::task::spawn({
                            let executor_cancellation_token_clone =
                                self.cancellation_token.clone();
                            async move {
                                handle_ack(ack_token, executor_cancellation_token_clone, pkid)
                                    .await;
                            }
                        });
                        continue;
                    }
                    rt
                } else {
                    log::warn!(
                        "[{}][pkid: {}] Response topic missing, command response will not be published",
                        self.command_name,
                        pkid
                    );
                    tokio::task::spawn({
                        let executor_cancellation_token_clone = self.cancellation_token.clone();
                        async move {
                            handle_ack(ack_token, executor_cancellation_token_clone, pkid)
                                .await;
                        }
                    });
                    continue;
                };
                let mut command_expiration_time_calculated = false;
                // Start from a success response; validation failures below
                // overwrite the status fields.
                let mut response_arguments = ResponseArguments {
                    command_name: self.command_name.clone(),
                    response_topic,
                    correlation_data: None,
                    status_code: StatusCode::Ok,
                    status_message: None,
                    is_application_error: false,
                    invalid_property_name: None,
                    invalid_property_value: None,
                    message_expiry_interval: None,
                    command_expiration_time: None,
                    supported_protocol_major_versions: None,
                    request_protocol_version: None,
                    cached_key: None,
                    cache_lookup_result: CacheLookupResult::NotFound,
                };
                // Compute the deadline from the request's expiry interval, or
                // the default when none was supplied. `checked_add` can fail
                // near the end of the Instant range.
                let command_expiration_time = match properties.message_expiry_interval {
                    Some(ct) => {
                        response_arguments.message_expiry_interval = Some(ct);
                        message_received_time.checked_add(Duration::from_secs(ct.into()))
                    }
                    _ => message_received_time.checked_add(Duration::from_secs(u64::from(
                        DEFAULT_MESSAGE_EXPIRY_INTERVAL_SECONDS,
                    ))),
                };
                if let Some(command_expiration_time) = command_expiration_time {
                    response_arguments.command_expiration_time = Some(command_expiration_time);
                    command_expiration_time_calculated = true;
                }
                // Correlation data must be a 16-byte GUID; it doubles as part
                // of the duplicate-request cache key.
                if let Some(correlation_data) = properties.correlation_data {
                    if correlation_data.len() == 16 {
                        response_arguments.correlation_data = Some(correlation_data.clone());
                        response_arguments.cached_key = Some(CacheKey {
                            response_topic: response_arguments.response_topic.clone(),
                            correlation_data,
                        });
                    } else {
                        response_arguments.status_code = StatusCode::BadRequest;
                        response_arguments.status_message =
                            Some("Correlation data bytes do not conform to a GUID".to_string());
                        response_arguments.invalid_property_name =
                            Some("Correlation Data".to_string());
                        if let Ok(correlation_data_str) =
                            String::from_utf8(correlation_data.to_vec())
                        {
                            response_arguments.invalid_property_value =
                                Some(correlation_data_str);
                        } else {
                            // Not valid UTF-8; the invalid value cannot be
                            // echoed back to the invoker.
                        }
                        response_arguments.correlation_data = Some(correlation_data);
                    }
                } else {
                    response_arguments.status_code = StatusCode::BadRequest;
                    response_arguments.status_message =
                        Some("Correlation data missing".to_string());
                    response_arguments.invalid_property_name =
                        Some("Correlation Data".to_string());
                }
                // Remaining validation; any failure breaks out of this labeled
                // block with the error recorded in `response_arguments`.
                'process_request: {
                    // No cache key means correlation data was missing/invalid,
                    // which is already recorded above.
                    let Some(cache_key) = &response_arguments.cached_key else {
                        break 'process_request;
                    };
                    let Some(command_expiration_time) = command_expiration_time else {
                        response_arguments.status_code = StatusCode::InternalServerError;
                        response_arguments.status_message =
                            Some(INTERNAL_LOGIC_EXPIRATION_ERROR.to_string());
                        break 'process_request;
                    };
                    if properties.message_expiry_interval.is_none() {
                        response_arguments.status_code = StatusCode::BadRequest;
                        response_arguments.status_message =
                            Some("Message expiry interval missing".to_string());
                        response_arguments.invalid_property_name =
                            Some("Message Expiry".to_string());
                        break 'process_request;
                    }
                    // Duplicate request? Reuse the cached or in-flight outcome.
                    response_arguments.cache_lookup_result = self.cache.get(cache_key);
                    if !matches!(
                        response_arguments.cache_lookup_result,
                        CacheLookupResult::NotFound
                    ) {
                        break 'process_request;
                    }
                    // Mark this request as in-flight before handing it to the
                    // application, so duplicates arriving meanwhile wait on it.
                    self.cache.set(
                        cache_key.clone(),
                        CacheEntry::InProgress {
                            processing_cancellation_token: processing_cancellation_token_clone,
                        },
                    );
                    // Protocol version check; a missing property implies the
                    // default version.
                    let mut request_protocol_version = DEFAULT_RPC_COMMAND_PROTOCOL_VERSION;
                    if let Some((_, protocol_version)) =
                        properties.user_properties.iter().find(|(key, _)| {
                            UserProperty::from_str(key) == Ok(UserProperty::ProtocolVersion)
                        })
                    {
                        if let Some(request_version) =
                            ProtocolVersion::parse_protocol_version(protocol_version)
                        {
                            request_protocol_version = request_version;
                        } else {
                            response_arguments.status_code = StatusCode::VersionNotSupported;
                            response_arguments.status_message = Some(format!(
                                "Unparsable protocol version value provided: {protocol_version}."
                            ));
                            response_arguments.supported_protocol_major_versions =
                                Some(SUPPORTED_PROTOCOL_VERSIONS.to_vec());
                            response_arguments.request_protocol_version =
                                Some(protocol_version.clone());
                            break 'process_request;
                        }
                    }
                    if !request_protocol_version.is_supported(SUPPORTED_PROTOCOL_VERSIONS) {
                        response_arguments.status_code = StatusCode::VersionNotSupported;
                        response_arguments.status_message = Some(format!(
                            "The command executor that received the request only supports major protocol versions '{SUPPORTED_PROTOCOL_VERSIONS:?}', but '{request_protocol_version}' was sent on the request."
                        ));
                        response_arguments.supported_protocol_major_versions =
                            Some(SUPPORTED_PROTOCOL_VERSIONS.to_vec());
                        response_arguments.request_protocol_version =
                            Some(request_protocol_version.to_string());
                        break 'process_request;
                    }
                    // Partition user properties into protocol-reserved values
                    // and application data.
                    let mut user_data = Vec::new();
                    let mut timestamp = None;
                    let mut invoker_id = None;
                    for (key, value) in properties.user_properties {
                        match UserProperty::from_str(&key) {
                            Ok(UserProperty::Timestamp) => {
                                match HybridLogicalClock::from_str(&value) {
                                    Ok(ts) => {
                                        // Sync the application HLC against the
                                        // request timestamp.
                                        if let Err(e) = self.application_hlc.update(&ts) {
                                            response_arguments.status_message = Some(format!(
                                                "Failure updating application HLC against {value}: {e}"
                                            ));
                                            response_arguments.invalid_property_name =
                                                Some(UserProperty::Timestamp.to_string());
                                            response_arguments.invalid_property_value =
                                                Some(value);
                                            match e.kind() {
                                                HLCErrorKind::ClockDrift => {
                                                    response_arguments.status_code =
                                                        StatusCode::ServiceUnavailable;
                                                }
                                                HLCErrorKind::OverflowWarning => {
                                                    response_arguments.status_code =
                                                        StatusCode::InternalServerError;
                                                }
                                            }
                                            break 'process_request;
                                        }
                                        timestamp = Some(ts);
                                    }
                                    Err(e) => {
                                        response_arguments.status_code = StatusCode::BadRequest;
                                        response_arguments.status_message =
                                            Some(format!("Timestamp invalid: {e}"));
                                        response_arguments.invalid_property_name =
                                            Some(UserProperty::Timestamp.to_string());
                                        response_arguments.invalid_property_value = Some(value);
                                        break 'process_request;
                                    }
                                }
                            }
                            Ok(UserProperty::SourceId) => {
                                invoker_id = Some(value);
                            }
                            Ok(UserProperty::ProtocolVersion) => {
                                // Already handled above.
                            }
                            Err(()) => {
                                // The partition key is routing metadata, not
                                // application data.
                                if key == PARTITION_KEY {
                                    continue;
                                }
                                user_data.push((key, value));
                            }
                            _ => {
                                // Other reserved properties are unexpected on a
                                // request but still passed through to the app.
                                log::warn!(
                                    "[{}] Command request should not contain MQTT user property {key}. Value is {value}",
                                    self.command_name
                                );
                                user_data.push((key, value));
                            }
                        }
                    }
                    let topic_tokens = self
                        .request_topic_pattern
                        .parse_tokens(m.topic_name.as_str());
                    let format_indicator = properties.payload_format_indicator.into();
                    let payload = match TReq::deserialize(
                        &m.payload,
                        properties.content_type.as_ref(),
                        &format_indicator,
                    ) {
                        Ok(payload) => payload,
                        Err(e) => match e {
                            DeserializationError::InvalidPayload(deserialization_e) => {
                                response_arguments.status_code = StatusCode::BadRequest;
                                response_arguments.status_message = Some(format!(
                                    "Error deserializing payload: {deserialization_e:?}"
                                ));
                                break 'process_request;
                            }
                            DeserializationError::UnsupportedContentType(message) => {
                                response_arguments.status_code =
                                    StatusCode::UnsupportedMediaType;
                                response_arguments.status_message = Some(message);
                                response_arguments.invalid_property_name =
                                    Some("Content Type".to_string());
                                response_arguments.invalid_property_value =
                                    Some(properties.content_type.unwrap_or("None".to_string()));
                                break 'process_request;
                            }
                        },
                    };
                    // Request is valid: wire up the response channels and hand
                    // it to the application if it has not already expired.
                    let (response_tx, response_rx) = oneshot::channel();
                    let (publish_completion_tx, publish_completion_rx) = oneshot::channel();
                    let command_request = Request {
                        payload,
                        content_type: properties.content_type,
                        format_indicator,
                        custom_user_data: user_data,
                        timestamp,
                        invoker_id,
                        topic_tokens,
                        command_name: self.command_name.clone(),
                        response_tx,
                        publish_completion_rx,
                    };
                    if command_expiration_time.elapsed().is_zero() {
                        // Spawn the responder task, then surface the request
                        // to the application.
                        tokio::task::spawn({
                            let app_hlc_clone = self.application_hlc.clone();
                            let client_clone = self.mqtt_client.clone();
                            let cache_clone = self.cache.clone();
                            let executor_cancellation_token_clone =
                                self.cancellation_token.clone();
                            async move {
                                tokio::select! {
                                    () = executor_cancellation_token_clone.cancelled() => { },
                                    () = Self::process_command(
                                        app_hlc_clone,
                                        client_clone,
                                        pkid,
                                        response_arguments,
                                        (Some(response_rx), Some(publish_completion_tx)),
                                        cache_clone,
                                        processing_drop_guard,
                                    ) => {
                                        handle_ack(ack_token, executor_cancellation_token_clone, pkid).await;
                                    },
                                }
                            }
                        });
                        return Some(Ok(command_request));
                    }
                }
                // Error, duplicate, and expired paths all end up here.
                let Some(command_expiration_time) = command_expiration_time else {
                    continue;
                };
                match response_arguments.cache_lookup_result {
                    // Duplicate of a completed request: replay the cached response.
                    CacheLookupResult::Cached {
                        serialized_payload,
                        properties,
                        response_message_expiry_interval,
                    } => {
                        tokio::task::spawn({
                            let client_clone = self.mqtt_client.clone();
                            let executor_cancellation_token_clone =
                                self.cancellation_token.clone();
                            async move {
                                tokio::select! {
                                    () = executor_cancellation_token_clone.cancelled() => { },
                                    () = Self::process_duplicate_command(
                                        client_clone,
                                        response_arguments.response_topic,
                                        serialized_payload,
                                        properties,
                                        response_message_expiry_interval,
                                        response_arguments.command_name,
                                        pkid,
                                    ) => {
                                        handle_ack(ack_token, executor_cancellation_token_clone, pkid).await;
                                    },
                                }
                            }
                        });
                    }
                    // Duplicate of an in-flight request: ack once the original
                    // finishes processing.
                    CacheLookupResult::InProgress(cancellation_token) => {
                        tokio::task::spawn(handle_in_progress_duplicate_ack(
                            ack_token,
                            cancellation_token.clone(),
                            self.cancellation_token.clone(),
                            pkid,
                        ));
                    }
                    // Validation failed (or the request expired): publish the
                    // recorded error response if there is still time.
                    CacheLookupResult::NotFound => {
                        if command_expiration_time.elapsed().is_zero() {
                            tokio::task::spawn({
                                let app_hlc_clone = self.application_hlc.clone();
                                let client_clone = self.mqtt_client.clone();
                                let cache_clone = self.cache.clone();
                                let executor_cancellation_token_clone =
                                    self.cancellation_token.clone();
                                async move {
                                    tokio::select! {
                                        () = executor_cancellation_token_clone.cancelled() => { },
                                        () = Self::process_command(
                                            app_hlc_clone,
                                            client_clone,
                                            pkid,
                                            response_arguments,
                                            (None, None),
                                            cache_clone,
                                            processing_drop_guard,
                                        ) => {
                                            handle_ack(ack_token, executor_cancellation_token_clone, pkid).await;
                                        },
                                    }
                                }
                            });
                        }
                    }
                }
                // Expiration could not be computed at all: surface an internal
                // logic error to the caller.
                if !command_expiration_time_calculated {
                    return Some(Err(AIOProtocolError::new_internal_logic_error(
                        true,
                        false,
                        None,
                        "command_expiration_time",
                        None,
                        Some(INTERNAL_LOGIC_EXPIRATION_ERROR.to_string()),
                        Some(self.command_name.clone()),
                    )));
                }
            }
            // Receiver closed: no more requests will arrive.
            _ => {
                return None;
            }
        }
    }
}
/// Replays a cached response for a duplicate request, refreshing its message
/// expiry interval. All failures are logged and swallowed; this is a
/// best-effort republish.
async fn process_duplicate_command(
    client: SessionManagedClient,
    response_topic: TopicName,
    serialized_payload: SerializedPayload,
    mut publish_properties: PublishProperties,
    response_message_expiry_interval: u32,
    command_name: String,
    pkid: u16,
) {
    log::debug!(
        "[{command_name}][pkid: {pkid}] Duplicate request, responding with cached response"
    );
    // Refresh the expiry so the replayed response reflects remaining validity.
    publish_properties.message_expiry_interval = Some(response_message_expiry_interval);
    let publish_result = client
        .publish_qos1(
            response_topic,
            false,
            serialized_payload.payload,
            publish_properties,
        )
        .await;
    let completion_token = match publish_result {
        Ok(token) => token,
        Err(e) => {
            log::warn!(
                "[{command_name}][pkid: {pkid}] Client error on cached command response publish: {e}"
            );
            return;
        }
    };
    match completion_token.await {
        Ok(puback) if !puback.is_success() => {
            log::warn!(
                "[{command_name}][pkid: {pkid}] Puback reported failure for cached command response: {puback:?}"
            );
        }
        Ok(_) => {}
        Err(e) => {
            log::warn!(
                "[{command_name}][pkid: {pkid}] Publish completion error for cached command response: {e}"
            );
        }
    }
}
#[allow(clippy::type_complexity)]
async fn process_command(
application_hlc: Arc<ApplicationHybridLogicalClock>,
client: SessionManagedClient,
pkid: u16,
mut response_arguments: ResponseArguments,
application_channels: (
Option<oneshot::Receiver<Response<TResp>>>,
Option<oneshot::Sender<Result<(), AIOProtocolError>>>,
), cache: Cache,
_processing_drop_guard: DropGuard,
) {
let (response_rx, completion_tx) = application_channels;
let mut serialized_payload = SerializedPayload::default();
let mut publish_properties = PublishProperties::default();
let mut user_properties: Vec<(String, String)> = Vec::new();
'process_response: {
let Some(command_expiration_time) = response_arguments.command_expiration_time else {
break 'process_response;
};
if let Some(response_rx) = response_rx {
let response = if let Ok(response_timer) = timeout(
command_expiration_time.duration_since(Instant::now()),
response_rx,
)
.await
{
if let Ok(response_app) = response_timer {
response_app
} else {
response_arguments.status_code = StatusCode::InternalServerError;
response_arguments.status_message =
Some("Request has been dropped by the application".to_string());
response_arguments.is_application_error = true;
break 'process_response;
}
} else {
log::warn!(
"[{}][pkid: {}] Command request timed out",
response_arguments.command_name,
pkid
);
if let Some(completion_tx) = completion_tx {
let _ = completion_tx.send(Err(AIOProtocolError::new_timeout_error(
false,
None,
&response_arguments.command_name,
Duration::from_secs(
response_arguments
.message_expiry_interval
.unwrap_or_default()
.into(),
),
None,
Some(response_arguments.command_name.clone()),
)));
}
return;
};
user_properties = response.custom_user_data;
if let Some(cloud_event) = response.cloud_event {
let cloud_event_headers = cloud_event
.0
.into_headers(response_arguments.response_topic.as_str());
for (key, value) in cloud_event_headers {
user_properties.push((key, value));
}
}
serialized_payload = response.serialized_payload;
if serialized_payload.payload.is_empty() {
response_arguments.status_code = StatusCode::NoContent;
}
} else {
}
}
if response_arguments.status_code != StatusCode::Ok
|| response_arguments.status_code != StatusCode::NoContent
{
user_properties.push((
UserProperty::IsApplicationError.to_string(),
response_arguments.is_application_error.to_string(),
));
}
user_properties.push((
UserProperty::Status.to_string(),
(response_arguments.status_code as u16).to_string(),
));
user_properties.push((
UserProperty::ProtocolVersion.to_string(),
RPC_COMMAND_PROTOCOL_VERSION.to_string(),
));
user_properties.push((
UserProperty::SourceId.to_string(),
client.client_id().to_string(),
));
if let Ok(timestamp_str) = application_hlc.update_now() {
user_properties.push((UserProperty::Timestamp.to_string(), timestamp_str));
}
if let Some(status_message) = response_arguments.status_message {
log::warn!(
"[{}][pkid: {}] sending error reponse to invoker: {}",
response_arguments.command_name,
pkid,
status_message
);
user_properties.push((UserProperty::StatusMessage.to_string(), status_message));
}
if let Some(name) = response_arguments.invalid_property_name {
user_properties.push((UserProperty::InvalidPropertyName.to_string(), name));
}
if let Some(value) = response_arguments.invalid_property_value {
user_properties.push((UserProperty::InvalidPropertyValue.to_string(), value));
}
if let Some(supported_protocol_major_versions) =
response_arguments.supported_protocol_major_versions
{
user_properties.push((
UserProperty::SupportedMajorVersions.to_string(),
supported_protocol_major_versions_to_string(&supported_protocol_major_versions),
));
}
if let Some(request_protocol_version) = response_arguments.request_protocol_version {
user_properties.push((
UserProperty::RequestProtocolVersion.to_string(),
request_protocol_version,
));
}
publish_properties.payload_format_indicator = serialized_payload.format_indicator.into();
publish_properties.topic_alias = None;
publish_properties.response_topic = None;
publish_properties.correlation_data = response_arguments.correlation_data;
publish_properties.user_properties = user_properties;
publish_properties.subscription_identifiers = Vec::new();
publish_properties.content_type = Some(serialized_payload.content_type.clone());
match response_arguments.command_expiration_time {
Some(command_expiration_time) => {
let response_message_expiry_interval =
get_response_message_expiry_interval(command_expiration_time);
if let Some(response_message_expiry_interval) = response_message_expiry_interval {
publish_properties.message_expiry_interval =
Some(response_message_expiry_interval);
} else {
log::warn!(
"[{}][pkid: {}] Command request timed out",
response_arguments.command_name,
pkid
);
if let Some(completion_tx) = completion_tx {
let _ = completion_tx.send(Err(AIOProtocolError::new_timeout_error(
false,
None,
&response_arguments.command_name,
Duration::from_secs(
response_arguments
.message_expiry_interval
.unwrap_or_default()
.into(),
),
None,
Some(response_arguments.command_name.clone()),
)));
}
return;
}
if let Some(cached_key) = response_arguments.cached_key {
let cache_entry = CacheEntry::Cached {
serialized_payload: serialized_payload.clone(),
properties: publish_properties.clone(),
expiration_time: command_expiration_time
+ Duration::from_secs(CACHE_EXPIRY_BUFFER_SECONDS),
};
log::debug!(
"[{}][pkid: {}] Caching response",
response_arguments.command_name,
pkid
);
cache.set(cached_key, cache_entry);
}
}
_ => {
publish_properties.message_expiry_interval =
Some(DEFAULT_MESSAGE_EXPIRY_INTERVAL_SECONDS);
}
}
match client
.publish_qos1(
response_arguments.response_topic,
false,
serialized_payload.payload,
publish_properties,
)
.await
{
Ok(publish_completion_token) => {
match publish_completion_token.await {
Ok(puback) => {
match puback.as_result() {
Ok(()) => {
if let Some(completion_tx) = completion_tx {
let _ = completion_tx.send(Ok(()));
}
}
Err(e) => {
log::error!(
"[{}][pkid: {}] Command response Puback error: {puback:?}",
response_arguments.command_name,
pkid
);
if let Some(completion_tx) = completion_tx {
let _ =
completion_tx.send(Err(AIOProtocolError::new_mqtt_error(
Some(
"MQTT error on command executor response puback"
.to_string(),
),
Box::new(e),
Some(response_arguments.command_name.clone()),
)));
}
}
}
}
Err(e) => {
log::error!(
"[{}][pkid: {}] Command response Publish completion error: {e}",
response_arguments.command_name,
pkid
);
if let Some(completion_tx) = completion_tx {
let _ = completion_tx.send(Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on command executor response publish".to_string()),
Box::new(e),
Some(response_arguments.command_name.clone()),
)));
}
}
}
}
Err(e) => {
log::error!(
"[{}][pkid: {}] Client error on command executor response publish: {e}",
response_arguments.command_name,
pkid
);
if let Some(completion_tx) = completion_tx {
let _ = completion_tx.send(Err(AIOProtocolError::new_mqtt_error(
Some("MQTT error on command executor response publish".to_string()),
Box::new(e),
Some(response_arguments.command_name.clone()),
)));
}
}
}
}
}
/// Cleanup when the executor goes out of scope.
///
/// Cancels outstanding executor work, closes the MQTT receiver, and — if a
/// subscribe previously completed — fires off a best-effort unsubscribe on a
/// detached task (drop cannot be async, so the unsuback is not awaited).
impl<TReq, TResp> Drop for Executor<TReq, TResp>
where
    TReq: PayloadSerialize + Send + 'static,
    TResp: PayloadSerialize + Send + 'static,
{
    fn drop(&mut self) {
        // Stop any in-flight executor tasks and stop receiving new requests.
        self.cancellation_token.cancel();
        self.mqtt_receiver.close();

        // Only bother unsubscribing if the subscribe actually happened.
        if self.state == State::Subscribed {
            let request_topic = self.request_topic_filter.clone();
            let mqtt_client = self.mqtt_client.clone();
            // Errors are logged rather than surfaced — there is no caller
            // left to report them to.
            tokio::spawn(async move {
                let unsubscribe_result = mqtt_client
                    .unsubscribe(
                        request_topic.clone(),
                        azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
                    )
                    .await;
                match unsubscribe_result {
                    Err(e) => {
                        log::warn!("Executor Unsubscribe error on topic {request_topic}: {e}");
                    }
                    Ok(_) => {
                        log::debug!(
                            "Executor Unsubscribe sent on topic {request_topic}. Unsuback may still be pending."
                        );
                    }
                }
            });
        }
        log::info!("[{}] Command Executor has been dropped", self.command_name);
    }
}
/// Converts an absolute command expiration time into an MQTT message expiry
/// interval (in whole seconds) for the response publish.
///
/// Returns `None` if the command has already expired (no time remains).
/// Otherwise the remaining duration is rounded *up* to the next whole second
/// so the response never expires before the command deadline. A remaining
/// duration that does not fit in the MQTT property is clamped to `u32::MAX`
/// instead of panicking: the previous `unreachable!()` was in fact reachable
/// when the remaining time exceeded `u32::MAX` seconds, or equaled it with a
/// fractional part (the round-up pushed the `u64` count past `u32::MAX`).
fn get_response_message_expiry_interval(command_expiration_time: Instant) -> Option<u32> {
    let remaining = command_expiration_time.saturating_duration_since(Instant::now());
    if remaining.is_zero() {
        // Already expired; the caller treats this as a command timeout.
        return None;
    }
    // Round up so a partial second still keeps the response alive.
    let mut seconds = remaining.as_secs();
    if remaining.subsec_nanos() != 0 {
        seconds = seconds.saturating_add(1);
    }
    // Clamp rather than panic: oversized durations are valid input, just not
    // representable in the u32 MQTT message expiry interval property.
    Some(u32::try_from(seconds).unwrap_or(u32::MAX))
}
async fn handle_ack(
ack_token: AckToken,
executor_cancellation_token: CancellationToken,
pkid: u16,
) {
tokio::select! {
() = executor_cancellation_token.cancelled() => { },
ack_res = ack_token.ack() => {
match ack_res {
Ok(ack_ct) => {
match ack_ct.await {
Ok(()) => log::debug!("[pkid: {pkid}] Command Request Acknowledged"),
Err(e) => {
match e {
azure_iot_operations_mqtt::error::CompletionError::Detached => {
log::warn!("[pkid: {pkid}] Command Request Ack error: {e}");
},
azure_iot_operations_mqtt::error::CompletionError::Canceled(_) => {
log::warn!("[pkid: {pkid}] Command Request ack cancelled due to disconnect, request will be redelivered");
},
}
}
}
},
Err(e) => {
log::warn!("[pkid: {pkid}] Command Request Ack error: {e}");
}
}
}
}
}
/// Defers acknowledging a duplicate request while the original is in flight.
///
/// The duplicate's ack is held until the in-progress request finishes (its
/// cancellation token fires), at which point the normal ack flow runs. An
/// executor-wide cancellation abandons the ack entirely.
async fn handle_in_progress_duplicate_ack(
    ack_token: AckToken,
    in_progress_cancellation_token: CancellationToken,
    executor_cancellation_token: CancellationToken,
    pkid: u16,
) {
    tokio::select! {
        // Original request completed — release the duplicate's ack now.
        () = in_progress_cancellation_token.cancelled() => {
            handle_ack(ack_token, executor_cancellation_token, pkid).await;
        }
        // Executor shut down — abandon the ack.
        () = executor_cancellation_token.cancelled() => { },
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering executor construction/validation, response
    //! building, the response cache, and response expiry-interval math.
    use azure_iot_operations_mqtt::session::{Session, SessionOptionsBuilder};
    use test_case::test_case;
    use azure_iot_operations_mqtt::aio::connection_settings::MqttConnectionSettingsBuilder;
    use super::*;
    use crate::application::ApplicationContextBuilder;
    use crate::common::{aio_protocol_error::AIOProtocolErrorKind, payload_serialize::MockPayload};

    /// Builds a throwaway session pointed at localhost; the tests never connect it.
    fn create_session() -> Session {
        let connection_settings = MqttConnectionSettingsBuilder::default()
            .hostname("localhost")
            .client_id("test_server")
            .build()
            .unwrap();
        let session_options = SessionOptionsBuilder::default()
            .connection_settings(connection_settings)
            .build()
            .unwrap();
        Session::new(session_options).unwrap()
    }

    /// Topic tokens used by the request-topic-pattern tests.
    fn create_topic_tokens() -> HashMap<String, String> {
        HashMap::from([
            ("executorId".to_string(), "test_executor_id".to_string()),
            ("commandName".to_string(), "test_command_name".to_string()),
        ])
    }

    /// Default options: tokens resolve into the subscribe topic; idempotency off.
    #[tokio::test]
    async fn test_new_defaults() {
        let session = create_session();
        let managed_client = session.create_managed_client();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern("test/{commandName}/{executorId}/request")
            .command_name("test_command_name")
            .topic_token_map(create_topic_tokens())
            .build()
            .unwrap();
        let executor: Executor<MockPayload, MockPayload> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            managed_client,
            executor_options,
        )
        .unwrap();
        assert_eq!(
            executor
                .request_topic_pattern
                .as_subscribe_topic()
                .unwrap()
                .as_str(),
            "test/test_command_name/test_executor_id/request"
        );
        assert!(!executor.is_idempotent);
    }

    /// Optional settings (namespace, idempotency) are honored when provided.
    #[tokio::test]
    async fn test_new_override_defaults() {
        let session = create_session();
        let managed_client = session.create_managed_client();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern("test/{commandName}/{executorId}/request")
            .command_name("test_command_name")
            .topic_namespace("test_namespace")
            .topic_token_map(create_topic_tokens())
            .is_idempotent(true)
            .build()
            .unwrap();
        let executor: Executor<MockPayload, MockPayload> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            managed_client,
            executor_options,
        )
        .unwrap();
        assert_eq!(
            executor
                .request_topic_pattern
                .as_subscribe_topic()
                .unwrap()
                .as_str(),
            "test_namespace/test/test_command_name/test_executor_id/request"
        );
        assert!(executor.is_idempotent);
    }

    /// An empty or whitespace-only command name is a configuration error.
    #[test_case(""; "empty command name")]
    #[test_case(" "; "whitespace command name")]
    #[tokio::test]
    async fn test_new_empty_and_whitespace_command_name(command_name: &str) {
        let session = create_session();
        let managed_client = session.create_managed_client();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern("test/{commandName}/request")
            .command_name(command_name.to_string())
            .topic_token_map(create_topic_tokens())
            .build()
            .unwrap();
        let executor: Result<Executor<MockPayload, MockPayload>, AIOProtocolError> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            managed_client,
            executor_options,
        );
        match executor {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(e.property_name, Some("command_name".to_string()));
                assert!(e.property_value == Some(Value::String(command_name.to_string())));
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    /// An empty, whitespace-only, or NUL-containing topic pattern is rejected.
    #[test_case(""; "empty request topic pattern")]
    #[test_case(" "; "whitespace request topic pattern")]
    #[test_case("test/{commandName}/\u{0}/request"; "invalid request topic pattern")]
    #[tokio::test]
    async fn test_invalid_request_topic_string(request_topic: &str) {
        let session = create_session();
        let managed_client = session.create_managed_client();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern(request_topic.to_string())
            .command_name("test_command_name")
            .topic_token_map(create_topic_tokens())
            .build()
            .unwrap();
        let executor: Result<Executor<MockPayload, MockPayload>, AIOProtocolError> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            managed_client,
            executor_options,
        );
        match executor {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(
                    e.property_name,
                    Some("executor_options.request_topic_pattern".to_string())
                );
                assert!(e.property_value == Some(Value::String(request_topic.to_string())));
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    /// An empty, whitespace-only, or NUL-containing topic namespace is rejected.
    #[test_case(""; "empty topic namespace")]
    #[test_case(" "; "whitespace topic namespace")]
    #[test_case("test/\u{0}"; "invalid topic namespace")]
    #[tokio::test]
    async fn test_invalid_topic_namespace(topic_namespace: &str) {
        let session = create_session();
        let managed_client = session.create_managed_client();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern("test/{commandName}/request")
            .command_name("test_command_name")
            .topic_namespace(topic_namespace.to_string())
            .topic_token_map(create_topic_tokens())
            .build()
            .unwrap();
        let executor: Result<Executor<MockPayload, MockPayload>, AIOProtocolError> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            managed_client,
            executor_options,
        );
        match executor {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(e.property_name, Some("topic_namespace".to_string()));
                assert!(e.property_value == Some(Value::String(topic_namespace.to_string())));
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    /// Shutdown succeeds even when the executor never subscribed.
    #[tokio::test]
    async fn test_shutdown_without_subscribe() {
        let session = create_session();
        let executor_options = OptionsBuilder::default()
            .request_topic_pattern("test/request")
            .command_name("test_command_name")
            .build()
            .unwrap();
        let mut executor: Executor<MockPayload, MockPayload> = Executor::new(
            ApplicationContextBuilder::default().build().unwrap(),
            session.create_managed_client(),
            executor_options,
        )
        .unwrap();
        assert!(executor.shutdown().await.is_ok());
    }

    /// A payload serializer error surfaces as PayloadInvalid at build time.
    #[test]
    fn test_response_serialization_error() {
        let mut mock_response_payload = MockPayload::new();
        mock_response_payload
            .expect_serialize()
            .returning(|| Err("dummy error".to_string()))
            .times(1);
        let mut binding = ResponseBuilder::default();
        let resp_builder = binding.payload(mock_response_payload);
        match resp_builder {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::PayloadInvalid);
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    /// A NUL byte in the serialized content type is a configuration error.
    #[test]
    fn test_response_serialization_bad_content_type_error() {
        let mut mock_response_payload = MockPayload::new();
        mock_response_payload
            .expect_serialize()
            .returning(|| {
                Ok(SerializedPayload {
                    payload: Vec::new(),
                    content_type: "application/json\u{0000}".to_string(),
                    format_indicator: FormatIndicator::Utf8EncodedCharacterData,
                })
            })
            .times(1);
        let mut binding = ResponseBuilder::default();
        let resp_builder = binding.payload(mock_response_payload);
        match resp_builder {
            Err(e) => {
                assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
                assert!(e.is_shallow);
                assert!(!e.is_remote);
                assert_eq!(e.property_name, Some("content_type".to_string()));
                assert!(
                    e.property_value == Some(Value::String("application/json\u{0000}".to_string()))
                );
            }
            Ok(_) => {
                panic!("Expected error");
            }
        }
    }

    /// Custom user data must not shadow reserved cloud-event headers ("source").
    #[test]
    fn test_response_invalid_custom_user_data_cloud_event_header() {
        let mut mock_response_payload = MockPayload::new();
        mock_response_payload
            .expect_serialize()
            .returning(|| {
                Ok(SerializedPayload {
                    payload: Vec::new(),
                    content_type: "application/json".to_string(),
                    format_indicator: FormatIndicator::Utf8EncodedCharacterData,
                })
            })
            .times(1);
        let response_builder_result = ResponseBuilder::default()
            .payload(mock_response_payload)
            .unwrap()
            .custom_user_data(vec![("source".to_string(), "test".to_string())])
            .build();
        assert!(response_builder_result.is_err());
    }

    /// Builder defaults: no user data, no cloud event, empty payload.
    #[test]
    fn test_response_defaults() {
        let mut mock_response_payload = MockPayload::new();
        mock_response_payload
            .expect_serialize()
            .returning(|| {
                Ok(SerializedPayload {
                    payload: Vec::new(),
                    content_type: "application/json".to_string(),
                    format_indicator: FormatIndicator::Utf8EncodedCharacterData,
                })
            })
            .times(1);
        let response_builder_result = ResponseBuilder::default()
            .payload(mock_response_payload)
            .unwrap()
            .build();
        let r = response_builder_result.unwrap();
        assert!(r.custom_user_data.is_empty());
        assert!(r.cloud_event.is_none());
        assert!(r.serialized_payload.payload.is_empty());
    }

    /// Lookup of a key that was never set reports NotFound.
    // Note: this was previously an async #[tokio::test] with no awaits; a
    // plain #[test] matches the other cache tests and needs no runtime.
    #[test]
    fn test_cache_not_found() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let status = cache.get(&key);
        assert!(matches!(status, CacheLookupResult::NotFound));
    }

    /// A completed entry is returned with payload, properties, and a
    /// remaining-expiry interval within the entry's lifetime.
    #[test]
    fn test_cache_found_complete() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let entered_serialized_payload = SerializedPayload {
            payload: Bytes::from("test_payload").to_vec(),
            content_type: "application/json".to_string(),
            format_indicator: FormatIndicator::Utf8EncodedCharacterData,
        };
        let entry = CacheEntry::Cached {
            serialized_payload: entered_serialized_payload.clone(),
            properties: PublishProperties::default(),
            expiration_time: Instant::now() + Duration::from_secs(60),
        };
        cache.set(key.clone(), entry.clone());
        let status = cache.get(&key);
        match status {
            CacheLookupResult::Cached {
                serialized_payload,
                properties,
                response_message_expiry_interval,
            } => {
                assert_eq!(serialized_payload, entered_serialized_payload);
                assert_eq!(properties, PublishProperties::default());
                let range = 1..=60;
                assert!(range.contains(&response_message_expiry_interval));
            }
            _ => {
                panic!("Expected cached entry");
            }
        }
    }

    /// An in-progress entry is reported as InProgress.
    #[test]
    fn test_cache_found_in_progress() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let entry = CacheEntry::InProgress {
            processing_cancellation_token: CancellationToken::new(),
        };
        cache.set(key.clone(), entry.clone());
        match cache.get(&key) {
            CacheLookupResult::InProgress(_) => { }
            _ => {
                panic!("Expected in progress entry");
            }
        }
    }

    /// An expired entry behaves as NotFound and can be replaced in place.
    #[test]
    fn test_cache_expired_entry_not_found() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let entry = CacheEntry::Cached {
            serialized_payload: SerializedPayload {
                payload: Bytes::from("test_payload").to_vec(),
                content_type: "application/json".to_string(),
                format_indicator: FormatIndicator::Utf8EncodedCharacterData,
            },
            properties: PublishProperties::default(),
            // Already expired when inserted.
            expiration_time: Instant::now() - Duration::from_secs(60),
        };
        cache.set(key.clone(), entry);
        let status = cache.get(&key);
        assert!(matches!(status, CacheLookupResult::NotFound));
        let new_serialized_payload = SerializedPayload {
            payload: Bytes::from("new_test_payload").to_vec(),
            content_type: "application/json".to_string(),
            format_indicator: FormatIndicator::Utf8EncodedCharacterData,
        };
        let new_entry = CacheEntry::Cached {
            serialized_payload: new_serialized_payload.clone(),
            properties: PublishProperties::default(),
            expiration_time: Instant::now() + Duration::from_secs(60),
        };
        cache.set(key.clone(), new_entry.clone());
        let new_status = cache.get(&key);
        match new_status {
            CacheLookupResult::Cached {
                serialized_payload,
                properties,
                response_message_expiry_interval,
            } => {
                assert_eq!(serialized_payload, new_serialized_payload);
                assert_eq!(properties, PublishProperties::default());
                let range = 1..=60;
                assert!(range.contains(&response_message_expiry_interval));
            }
            _ => {
                panic!("Expected cached entry");
            }
        }
    }

    /// Setting an unrelated key does not resurrect an expired entry.
    #[test]
    fn test_cache_expired_entry_not_found_with_different_key_set() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let old_key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let old_entry = CacheEntry::Cached {
            serialized_payload: SerializedPayload {
                payload: Bytes::from("test_payload").to_vec(),
                content_type: "application/json".to_string(),
                format_indicator: FormatIndicator::Utf8EncodedCharacterData,
            },
            properties: PublishProperties::default(),
            expiration_time: Instant::now() - Duration::from_secs(60),
        };
        cache.set(old_key.clone(), old_entry);
        let status = cache.get(&old_key);
        assert!(matches!(status, CacheLookupResult::NotFound));
        let new_key = CacheKey {
            response_topic: TopicName::new("new_test_response_topic").unwrap(),
            correlation_data: Bytes::from("new_test_correlation_data"),
        };
        let new_serialized_payload = SerializedPayload {
            payload: Bytes::from("new_test_payload").to_vec(),
            content_type: "application/json".to_string(),
            format_indicator: FormatIndicator::Utf8EncodedCharacterData,
        };
        let new_entry = CacheEntry::Cached {
            serialized_payload: new_serialized_payload.clone(),
            properties: PublishProperties::default(),
            expiration_time: Instant::now() + Duration::from_secs(60),
        };
        cache.set(new_key.clone(), new_entry.clone());
        let old_status = cache.get(&old_key);
        assert!(matches!(old_status, CacheLookupResult::NotFound));
        let new_status = cache.get(&new_key);
        match new_status {
            CacheLookupResult::Cached {
                serialized_payload,
                properties,
                response_message_expiry_interval,
            } => {
                assert_eq!(serialized_payload, new_serialized_payload);
                assert_eq!(properties, PublishProperties::default());
                let range = 1..=60;
                assert!(range.contains(&response_message_expiry_interval));
            }
            _ => {
                panic!("Expected cached entry");
            }
        }
    }

    /// Setting an unrelated key leaves an in-progress entry untouched.
    #[test]
    fn test_cache_in_progress_found_with_different_key_set() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let old_key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let old_entry = CacheEntry::InProgress {
            processing_cancellation_token: CancellationToken::new(),
        };
        cache.set(old_key.clone(), old_entry);
        let status = cache.get(&old_key);
        assert!(matches!(status, CacheLookupResult::InProgress(..)));
        let new_key = CacheKey {
            response_topic: TopicName::new("new_test_response_topic").unwrap(),
            correlation_data: Bytes::from("new_test_correlation_data"),
        };
        let new_serialized_payload = SerializedPayload {
            payload: Bytes::from("new_test_payload").to_vec(),
            content_type: "application/json".to_string(),
            format_indicator: FormatIndicator::Utf8EncodedCharacterData,
        };
        let new_entry = CacheEntry::Cached {
            serialized_payload: new_serialized_payload.clone(),
            properties: PublishProperties::default(),
            expiration_time: Instant::now() + Duration::from_secs(60),
        };
        cache.set(new_key.clone(), new_entry.clone());
        let old_status = cache.get(&old_key);
        assert!(matches!(old_status, CacheLookupResult::InProgress(..)));
        let new_status = cache.get(&new_key);
        match new_status {
            CacheLookupResult::Cached {
                serialized_payload,
                properties,
                response_message_expiry_interval,
            } => {
                assert_eq!(serialized_payload, new_serialized_payload);
                assert_eq!(properties, PublishProperties::default());
                let range = 1..=60;
                assert!(range.contains(&response_message_expiry_interval));
            }
            _ => {
                panic!("Expected cached entry");
            }
        }
    }

    /// Dropping the processing drop-guard cancels the token observed through
    /// subsequent InProgress lookups.
    #[test]
    fn test_cache_in_progress_notified_completion() {
        let cache = Cache(Arc::new(Mutex::new(HashMap::new())));
        let processing_cancellation_token = CancellationToken::new();
        let key = CacheKey {
            response_topic: TopicName::new("test_response_topic").unwrap(),
            correlation_data: Bytes::from("test_correlation_data"),
        };
        let entry = CacheEntry::InProgress {
            processing_cancellation_token: processing_cancellation_token.clone(),
        };
        cache.set(key.clone(), entry.clone());
        {
            let _processing_drop_guard = processing_cancellation_token.drop_guard();
            match cache.get(&key) {
                CacheLookupResult::InProgress(_) => { }
                _ => {
                    panic!("Expected in progress entry");
                }
            }
        }
        match cache.get(&key) {
            CacheLookupResult::InProgress(cancellation_token) => {
                assert!(cancellation_token.is_cancelled());
            }
            _ => {
                panic!("Expected in progress entry");
            }
        }
    }

    /// A whitespace-only error payload produces only the error-code header.
    #[test]
    fn test_response_add_empty_error_payload_success() {
        let mut mock_response_payload = MockPayload::new();
        mock_response_payload
            .expect_serialize()
            .returning(|| {
                Ok(SerializedPayload {
                    payload: Vec::new(),
                    content_type: "application/json".to_string(),
                    format_indicator: FormatIndicator::Utf8EncodedCharacterData,
                })
            })
            .times(1);
        let mut custom_user_data = Vec::new();
        assert!(
            application_error_headers(&mut custom_user_data, "500".into(), " ".into()).is_ok()
        );
        let response = ResponseBuilder::default()
            .custom_user_data(custom_user_data)
            .payload(mock_response_payload)
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(response.custom_user_data.len(), 1);
        let mut app_error_code_header_found = false;
        let mut app_error_payload_header_found = false;
        for (key, value) in response.custom_user_data {
            if key == "AppErrCode" {
                app_error_code_header_found = true;
                assert_eq!(value, "500");
            }
            if key == "AppErrPayload" {
                app_error_payload_header_found = true;
            }
        }
        assert!(app_error_code_header_found);
        assert!(!app_error_payload_header_found);
    }

    /// A whitespace-only error code is rejected and adds no headers.
    #[test]
    fn test_response_add_empty_error_code_error() {
        let mut custom_user_data = Vec::new();
        assert!(
            application_error_headers(&mut custom_user_data, " ".into(), "Some error".into())
                .is_err()
        );
        assert_eq!(custom_user_data.len(), 0);
    }

    /// A future deadline yields a rounded-up interval within the window.
    #[test]
    fn test_get_response_message_expiry_interval_not_expired() {
        let response_message_expiry_interval =
            get_response_message_expiry_interval(Instant::now() + Duration::from_secs(10));
        let range = 1..=10;
        assert!(range.contains(&response_message_expiry_interval.unwrap()));
    }

    /// A past deadline yields None.
    // Renamed from "..._inteval_expired" to fix the typo.
    #[test]
    fn test_get_response_message_expiry_interval_expired() {
        let response_message_expiry_interval =
            get_response_message_expiry_interval(Instant::now() - Duration::from_secs(10));
        assert!(response_message_expiry_interval.is_none());
    }

    /// A deadline at the u32 boundary stays representable.
    #[test]
    fn test_get_response_message_expiry_interval_at_limit() {
        let response_message_expiry_interval = get_response_message_expiry_interval(
            Instant::now() + Duration::from_secs(u64::from(u32::MAX)),
        );
        let range = 1..=u32::MAX;
        assert!(range.contains(&response_message_expiry_interval.unwrap()));
    }
}