use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Arc, time::Duration};
use azure_iot_operations_mqtt::{
aio::cloud_event as aio_cloud_event,
control_packet::{Publish, PublishProperties, QoS, TopicFilter},
session::{SessionManagedClient, SessionPubReceiver},
};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use iso8601_duration;
use tokio::{
sync::{
Mutex, Notify,
broadcast::{Sender, error::RecvError},
},
task::{self, JoinHandle},
time,
};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::common::{
cloud_event as protocol_cloud_event,
user_properties::{PARTITION_KEY, validate_invoker_user_properties},
};
use crate::{
ProtocolVersion,
application::{ApplicationContext, ApplicationHybridLogicalClock},
common::{
aio_protocol_error::{AIOProtocolError, AIOProtocolErrorKind, Value},
hybrid_logical_clock::HybridLogicalClock,
is_invalid_utf8,
payload_serialize::{
DeserializationError, FormatIndicator, PayloadSerialize, SerializedPayload,
},
topic_processor::{TopicPattern, contains_invalid_char},
user_properties::UserProperty,
},
parse_supported_protocol_major_versions,
rpc_command::{
DEFAULT_RPC_COMMAND_PROTOCOL_VERSION, DEFAULT_RPC_REQUEST_CLOUD_EVENT_EVENT_TYPE,
RPC_COMMAND_PROTOCOL_VERSION, StatusCode, StatusCodeParseError,
},
};
/// Major protocol versions the invoker accepts in command responses.
///
/// Responses whose protocol version is not supported by this list are rejected with an
/// unsupported-version error, and the list is reported back in that error.
const SUPPORTED_PROTOCOL_VERSIONS: &[u16] = &[1];
/// A command request to be sent by an [`Invoker`].
///
/// Constructed through [`RequestBuilder`] (generated by the `Builder` derive). The builder's
/// `validate` function runs on `build`.
#[derive(Builder, Clone, Debug)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct Request<TReq>
where
    TReq: PayloadSerialize,
{
    /// Serialized request payload, set through the custom `payload` setter, which serializes a `TReq`.
    #[builder(setter(custom))]
    serialized_payload: SerializedPayload,
    /// Type marker for the payload type; filled in by the `payload` setter.
    #[builder(private)]
    payload_type: PhantomData<TReq>,
    /// Custom user data sent as MQTT user properties. Keys must not be reserved Cloud Event
    /// field names (checked by `validate`). Defaults to empty.
    #[builder(default)]
    custom_user_data: Vec<(String, String)>,
    /// Values for the replaceable tokens in the request and response topic patterns.
    /// Defaults to empty.
    #[builder(default)]
    topic_tokens: HashMap<String, String>,
    /// Command timeout. The custom setter rounds any fractional second up to the next whole
    /// second. It must be non-zero and fit in a `u32` number of seconds, because it is also
    /// used as the MQTT message expiry interval.
    #[builder(setter(custom))]
    timeout: Duration,
    /// Optional Cloud Event attached to the request as user properties. Defaults to `None`.
    #[builder(default = "None")]
    cloud_event: Option<RequestCloudEvent>,
}
/// Cloud Event metadata attached to a command request.
///
/// Wraps the protocol-level Cloud Event; build one with [`RequestCloudEventBuilder`].
#[derive(Clone, Debug)]
pub struct RequestCloudEvent(protocol_cloud_event::CloudEvent);
/// Builder for [`RequestCloudEvent`].
///
/// Its `Default` impl starts with `DEFAULT_RPC_REQUEST_CLOUD_EVENT_EVENT_TYPE` as the event type.
#[derive(Clone)]
pub struct RequestCloudEventBuilder(protocol_cloud_event::CloudEventBuilder);
/// Error returned by [`RequestCloudEventBuilder::build`].
#[derive(Debug)]
#[non_exhaustive]
pub enum RequestCloudEventBuilderError {
    /// A required field was never set; carries the field name.
    UninitializedField(&'static str),
    /// A field value failed validation; carries a description of the problem.
    ValidationError(String),
}

impl std::error::Error for RequestCloudEventBuilderError {}

impl std::fmt::Display for RequestCloudEventBuilderError {
    /// Formats as `"<label>: <detail>"`, where the label names the kind of failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail): (&str, &str) = match self {
            Self::UninitializedField(field_name) => ("Uninitialized field", *field_name),
            Self::ValidationError(err_msg) => ("Validation error", err_msg.as_str()),
        };
        write!(f, "{label}: {detail}")
    }
}
/// Maps the protocol-level Cloud Event builder error onto the public request-builder error,
/// variant for variant.
impl From<protocol_cloud_event::CloudEventBuilderError> for RequestCloudEventBuilderError {
    fn from(value: protocol_cloud_event::CloudEventBuilderError) -> Self {
        match value {
            protocol_cloud_event::CloudEventBuilderError::ValidationError(message) => {
                Self::ValidationError(message)
            }
            protocol_cloud_event::CloudEventBuilderError::UninitializedField(field) => {
                Self::UninitializedField(field)
            }
        }
    }
}
/// Starts a request Cloud Event builder whose event type is
/// `DEFAULT_RPC_REQUEST_CLOUD_EVENT_EVENT_TYPE`.
impl Default for RequestCloudEventBuilder {
    fn default() -> Self {
        let default_event_type = DEFAULT_RPC_REQUEST_CLOUD_EVENT_EVENT_TYPE.to_string();
        let inner = protocol_cloud_event::CloudEventBuilder::new(default_event_type);
        Self(inner)
    }
}
impl RequestCloudEventBuilder {
pub fn build(&self) -> Result<RequestCloudEvent, RequestCloudEventBuilderError> {
Ok(RequestCloudEvent(
protocol_cloud_event::CloudEventBuilder::build(&self.0)?,
))
}
pub fn source<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.source(value);
self
}
pub fn spec_version<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.spec_version(value);
self
}
pub fn event_type<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.event_type(value);
self
}
pub fn data_schema<VALUE: Into<Option<String>>>(&mut self, value: VALUE) -> &mut Self {
self.0.data_schema(value);
self
}
pub fn id<VALUE: Into<String>>(&mut self, value: VALUE) -> &mut Self {
self.0.id(value);
self
}
pub fn time<VALUE: Into<Option<DateTime<Utc>>>>(&mut self, value: VALUE) -> &mut Self {
self.0.time(value);
self
}
pub fn subject<VALUE: Into<protocol_cloud_event::CloudEventSubject>>(
&mut self,
value: VALUE,
) -> &mut Self {
self.0.subject(value);
self
}
}
impl<TReq: PayloadSerialize> RequestBuilder<TReq> {
    /// Serializes `payload` and stores it as the request payload.
    ///
    /// # Errors
    /// - A payload-invalid [`AIOProtocolError`] if `payload.serialize()` fails.
    /// - A configuration-invalid [`AIOProtocolError`] if the serialized content type is not
    ///   valid UTF-8.
    pub fn payload(&mut self, payload: TReq) -> Result<&mut Self, AIOProtocolError> {
        let serialized_payload = payload.serialize().map_err(|e| {
            AIOProtocolError::new_payload_invalid_error(
                true,
                false,
                Some(e.into()),
                Some("Payload serialization error".to_string()),
                None,
            )
        })?;

        let content_type = &serialized_payload.content_type;
        if is_invalid_utf8(content_type) {
            return Err(AIOProtocolError::new_configuration_invalid_error(
                None,
                "content_type",
                Value::String(content_type.clone()),
                Some(format!(
                    "Content type '{content_type}' of command request is not valid UTF-8"
                )),
                None,
            ));
        }

        self.serialized_payload = Some(serialized_payload);
        self.payload_type = Some(PhantomData);
        Ok(self)
    }

    /// Sets the command timeout, rounding any fractional second up to the next whole second.
    ///
    /// The rounding exists because the timeout is also sent as the MQTT message expiry
    /// interval, which is expressed in whole seconds.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        let whole_seconds = if timeout.subsec_nanos() == 0 {
            timeout
        } else {
            Duration::from_secs(timeout.as_secs().saturating_add(1))
        };
        self.timeout = Some(whole_seconds);
        self
    }

    /// Validates the builder before `build` produces a [`Request`].
    ///
    /// Checks that:
    /// - no custom user data key is a reserved Cloud Event field name, and the user
    ///   properties pass `validate_invoker_user_properties`;
    /// - the timeout is non-zero and its seconds fit in a `u32`;
    /// - when both a Cloud Event and a payload are set, the payload content type is a valid
    ///   Cloud Event data content type for the event's spec version.
    fn validate(&self) -> Result<(), String> {
        if let Some(custom_user_data) = &self.custom_user_data {
            let reserved_key = custom_user_data
                .iter()
                .map(|(key, _)| key)
                .find(|key| aio_cloud_event::CloudEventFields::from_str(key).is_ok());
            if let Some(key) = reserved_key {
                return Err(format!(
                    "Invalid user data property '{key}' is a reserved Cloud Event key"
                ));
            }
            validate_invoker_user_properties(custom_user_data)?;
        }

        if let Some(timeout) = &self.timeout {
            let seconds = timeout.as_secs();
            if seconds == 0 {
                return Err("Timeout must not be 0".to_string());
            }
            if u32::try_from(seconds).is_err() {
                return Err("Timeout in seconds must be less than or equal to u32::max to be used as message_expiry_interval".to_string());
            }
        }

        if let (Some(Some(cloud_event)), Some(serialized_payload)) =
            (&self.cloud_event, &self.serialized_payload)
        {
            aio_cloud_event::CloudEventFields::DataContentType.validate(
                &serialized_payload.content_type,
                &cloud_event.0.spec_version,
            )?;
        }

        Ok(())
    }
}
/// A successful command response (status 200 or 204) received by an [`Invoker`].
#[derive(Debug)]
pub struct Response<TResp>
where
    TResp: PayloadSerialize,
{
    /// Deserialized response payload.
    pub payload: TResp,
    /// Content type from the response's MQTT properties, if present.
    pub content_type: Option<String>,
    /// Payload format indicator from the response's MQTT properties.
    pub format_indicator: FormatIndicator,
    /// User properties from the response that the protocol does not consume. Known reserved
    /// properties that should not appear on a response are passed through here and logged.
    pub custom_user_data: Vec<(String, String)>,
    /// Hybrid logical clock timestamp reported in the response, if present.
    pub timestamp: Option<HybridLogicalClock>,
    /// Value of the response's `SourceId` user property, if present.
    pub executor_id: Option<String>,
}
/// Cloud Event carried by a command response (alias of the MQTT crate's `CloudEvent`).
pub type ResponseCloudEvent = aio_cloud_event::CloudEvent;
/// Error produced when a [`ResponseCloudEvent`] cannot be parsed from a response.
pub type ResponseCloudEventParseError = aio_cloud_event::CloudEventParseError;

/// Extracts the Cloud Event carried by a command response.
///
/// The event is read from the response's custom user data together with its content type.
///
/// # Errors
/// Returns [`ResponseCloudEventParseError`] when `ResponseCloudEvent::try_from` rejects the
/// response's user properties and content type.
pub fn cloud_event_from_response<TResp: PayloadSerialize>(
    response: &Response<TResp>,
) -> Result<ResponseCloudEvent, ResponseCloudEventParseError> {
    let user_properties = &response.custom_user_data;
    let content_type = response.content_type.as_deref();
    ResponseCloudEvent::try_from((user_properties, content_type))
}
/// Extracts the application error headers from a response's custom user data.
///
/// Looks for the `AppErrCode` and `AppErrPayload` user properties. Keys are compared exactly
/// and case-sensitively. If a header appears more than once, the last occurrence wins.
///
/// Accepts any slice of key/value pairs, so existing callers passing
/// `&Vec<(String, String)>` (for example `&response.custom_user_data`) still compile
/// through deref coercion.
///
/// # Returns
/// `(app_error_code, app_error_payload)`. Each element is `None` when its header is absent.
#[must_use]
pub fn application_error_headers(
    custom_user_data: &[(String, String)],
) -> (Option<String>, Option<String>) {
    const APPLICATION_ERROR_CODE_HEADER: &str = "AppErrCode";
    const APPLICATION_ERROR_PAYLOAD_HEADER: &str = "AppErrPayload";
    let mut app_error_code: Option<String> = None;
    let mut app_error_payload: Option<String> = None;
    for (key, value) in custom_user_data {
        match key.as_str() {
            APPLICATION_ERROR_CODE_HEADER => app_error_code = Some(value.clone()),
            APPLICATION_ERROR_PAYLOAD_HEADER => app_error_payload = Some(value.clone()),
            _ => {}
        }
    }
    (app_error_code, app_error_payload)
}
/// Error reported by the remote executor in a response whose status is not 200/204.
///
/// Built from the reserved user properties of the response, then converted into an
/// [`AIOProtocolError`] (see the `From<RemoteError>` impl).
#[derive(thiserror::Error, Debug, Clone)]
#[error("Remote Error status code: {status_code:?}")]
struct RemoteError {
    /// Status code parsed from the response's `Status` user property.
    status_code: StatusCode,
    /// Protocol version of the response (the default RPC protocol version when absent).
    protocol_version: ProtocolVersion,
    /// Value of the `StatusMessage` user property, if present.
    status_message: Option<String>,
    /// `true` only when the `IsApplicationError` user property is exactly `"true"`.
    is_application_error: bool,
    /// Value of the `InvalidPropertyName` user property, if present.
    invalid_property_name: Option<String>,
    /// Value of the `InvalidPropertyValue` user property, if present.
    invalid_property_value: Option<String>,
    /// Parsed `SupportedMajorVersions` user property, if present.
    supported_protocol_major_versions: Option<Vec<u16>>,
    /// Hybrid logical clock timestamp from the response, if present.
    timestamp: Option<HybridLogicalClock>,
}
impl From<RemoteError> for AIOProtocolError {
fn from(value: RemoteError) -> Self {
let remote_error_clone = value.clone();
let mut aio_error = AIOProtocolError {
kind: AIOProtocolErrorKind::UnknownError,
message: value.status_message,
is_shallow: false, is_remote: true, nested_error: Some(Box::new(remote_error_clone)),
header_name: None,
header_value: None,
timeout_name: None,
timeout_value: None,
property_name: None,
property_value: None,
command_name: None, protocol_version: Some(value.protocol_version.to_string()),
supported_protocol_major_versions: value.supported_protocol_major_versions,
};
match value.status_code {
StatusCode::Ok | StatusCode::NoContent => {
unreachable!("Invalid status code for RemoteError")
}
StatusCode::BadRequest => {
if value.invalid_property_name.is_some() && value.invalid_property_value.is_some() {
aio_error.kind = AIOProtocolErrorKind::HeaderInvalid;
aio_error.header_name = value.invalid_property_name;
aio_error.header_value = value.invalid_property_value;
} else if value.invalid_property_name.is_some() {
aio_error.kind = AIOProtocolErrorKind::HeaderMissing;
aio_error.header_name = value.invalid_property_name;
} else {
aio_error.kind = AIOProtocolErrorKind::PayloadInvalid;
}
}
StatusCode::RequestTimeout => {
aio_error.kind = AIOProtocolErrorKind::Timeout;
aio_error.timeout_name = value.invalid_property_name;
aio_error.timeout_value = value.invalid_property_value.and_then(|timeout_s| {
match timeout_s.parse::<iso8601_duration::Duration>() {
Ok(d) => d.to_std(),
Err(_) => None,
}
});
}
StatusCode::UnsupportedMediaType => {
aio_error.kind = AIOProtocolErrorKind::HeaderInvalid;
aio_error.header_name = value.invalid_property_name;
aio_error.header_value = value.invalid_property_value;
}
StatusCode::InternalServerError => {
if value.is_application_error {
aio_error.kind = AIOProtocolErrorKind::ExecutionException;
aio_error.property_name = value.invalid_property_name;
aio_error.property_value = value.invalid_property_value.map(Value::String);
} else if value.invalid_property_name.is_some() {
aio_error.kind = AIOProtocolErrorKind::InternalLogicError;
aio_error.property_name = value.invalid_property_name;
aio_error.property_value = value.invalid_property_value.map(Value::String);
} else {
aio_error.kind = AIOProtocolErrorKind::UnknownError;
}
}
StatusCode::ServiceUnavailable => {
aio_error.kind = AIOProtocolErrorKind::StateInvalid;
aio_error.property_name = value.invalid_property_name;
aio_error.property_value = value.invalid_property_value.map(Value::String);
}
StatusCode::VersionNotSupported => {
aio_error.kind = AIOProtocolErrorKind::UnsupportedVersion;
}
}
aio_error
}
}
/// Outcome of interpreting a response publish: either a successful [`Response`] or a
/// [`RemoteError`] reported by the executor.
enum CommandResult<TResp>
where
    TResp: PayloadSerialize,
{
    /// Status 200 or 204 with a deserialized payload.
    Ok(Response<TResp>),
    /// Any other recognized status code.
    Err(RemoteError),
}
/// Interprets a publish received on the response topic as a [`CommandResult`].
///
/// User properties listed in `expected_aio_properties` are pulled out as protocol data. All
/// other user properties are kept as custom user data. A status of 200 (`Ok`) or 204
/// (`NoContent`) yields [`CommandResult::Ok`] with the deserialized payload. Any other
/// recognized status yields [`CommandResult::Err`] with a [`RemoteError`].
///
/// # Errors
/// Returns an [`AIOProtocolError`] (rather than `CommandResult::Err`) when the response
/// itself cannot be interpreted:
/// - its protocol version is unparsable or not in `SUPPORTED_PROTOCOL_VERSIONS`;
/// - the status property is missing, is not an integer, or is an unknown status code (the
///   last case carries the response's status message and invalid property name/value);
/// - the timestamp property fails to parse (the error is converted with `?`);
/// - a 204 response has a non-empty payload, the payload fails to deserialize, or the
///   content type is not supported by `TResp`.
impl<TResp> TryFrom<Publish> for CommandResult<TResp>
where
    TResp: PayloadSerialize,
{
    type Error = AIOProtocolError;
    fn try_from(value: Publish) -> Result<CommandResult<TResp>, Self::Error> {
        let publish_properties = value.properties;
        // Reserved protocol properties that are meaningful on a response.
        let expected_aio_properties = [
            UserProperty::Timestamp,
            UserProperty::Status,
            UserProperty::StatusMessage,
            UserProperty::SourceId,
            UserProperty::IsApplicationError,
            UserProperty::InvalidPropertyName,
            UserProperty::InvalidPropertyValue,
            UserProperty::ProtocolVersion,
            UserProperty::SupportedMajorVersions,
            UserProperty::RequestProtocolVersion,
        ];
        let mut response_custom_user_data = vec![];
        let mut response_aio_data = HashMap::new();
        // Split user properties into protocol data and custom user data.
        for (key, value) in publish_properties.user_properties {
            match UserProperty::from_str(&key) {
                Ok(p) if expected_aio_properties.contains(&p) => {
                    response_aio_data.insert(p, value);
                }
                // A reserved property that does not belong on a response: log it, but still
                // surface it to the caller as custom user data.
                Ok(_) => {
                    log::warn!(
                        "Response should not contain MQTT user property '{key}'. Value is '{value}'"
                    );
                    response_custom_user_data.push((key, value));
                }
                Err(()) => {
                    response_custom_user_data.push((key, value));
                }
            }
        }
        // A missing protocol version means the default RPC protocol version.
        let protocol_version = {
            match response_aio_data.get(&UserProperty::ProtocolVersion) {
                Some(protocol_version) => {
                    if let Some(version) = ProtocolVersion::parse_protocol_version(protocol_version)
                    {
                        version
                    } else {
                        return Err(AIOProtocolError::new_unsupported_version_error(
                            Some(format!(
                                "Received a response with an unparsable protocol version number: {protocol_version}"
                            )),
                            protocol_version.clone(),
                            SUPPORTED_PROTOCOL_VERSIONS.to_vec(),
                            None,
                            false,
                            false,
                        ));
                    }
                }
                None => DEFAULT_RPC_COMMAND_PROTOCOL_VERSION,
            }
        };
        if !protocol_version.is_supported(SUPPORTED_PROTOCOL_VERSIONS) {
            return Err(AIOProtocolError::new_unsupported_version_error(
                None,
                protocol_version.to_string(),
                SUPPORTED_PROTOCOL_VERSIONS.to_vec(),
                None,
                false,
                false,
            ));
        }
        // The status property is mandatory; unparsable and unknown codes are errors here.
        let status_code = {
            match response_aio_data.get(&UserProperty::Status) {
                Some(s) => match StatusCode::from_str(s) {
                    Ok(code) => code,
                    Err(StatusCodeParseError::UnparsableStatusCode(s)) => {
                        return Err(AIOProtocolError::new_header_invalid_error(
                            &UserProperty::Status.to_string(),
                            &s,
                            false,
                            Some(format!(
                                "Could not parse status in response '{s}' as an integer"
                            )),
                            None,
                        ));
                    }
                    // An integer status this version does not recognize: report it as an
                    // unknown error with whatever detail the executor supplied.
                    Err(StatusCodeParseError::UnknownStatusCode(_)) => {
                        let status_message = response_aio_data
                            .remove(&UserProperty::StatusMessage)
                            .unwrap_or(String::from("Unknown"));
                        let mut unknown_err = AIOProtocolError::new_unknown_error(
                            true,
                            false,
                            None,
                            Some(status_message),
                            None,
                        );
                        unknown_err.property_name =
                            response_aio_data.remove(&UserProperty::InvalidPropertyName);
                        unknown_err.property_value = response_aio_data
                            .remove(&UserProperty::InvalidPropertyValue)
                            .map(Value::String);
                        return Err(unknown_err);
                    }
                },
                None => {
                    return Err(AIOProtocolError::new_header_missing_error(
                        &UserProperty::Status.to_string(),
                        false,
                        Some(format!(
                            "Response missing MQTT user property '{}'",
                            UserProperty::Status
                        )),
                        None,
                    ));
                }
            }
        };
        // Optional timestamp; a present but unparsable value fails the whole conversion.
        let timestamp = response_aio_data
            .get(&UserProperty::Timestamp)
            .map(|s| HybridLogicalClock::from_str(s))
            .transpose()?;
        let command_result = match status_code {
            StatusCode::Ok | StatusCode::NoContent => {
                let content_type = publish_properties.content_type;
                let format_indicator = publish_properties.payload_format_indicator.into();
                if matches!(status_code, StatusCode::NoContent) && !value.payload.is_empty() {
                    return Err(AIOProtocolError::new_payload_invalid_error(
                        false,
                        false,
                        None,
                        Some("Status code 204 (No Content) should not have a payload".to_string()),
                        None,
                    ));
                }
                let payload = match TResp::deserialize(
                    &value.payload,
                    content_type.as_ref(),
                    &format_indicator,
                ) {
                    Ok(payload) => payload,
                    Err(DeserializationError::InvalidPayload(e)) => {
                        return Err(AIOProtocolError::new_payload_invalid_error(
                            false,
                            false,
                            Some(e.into()),
                            None,
                            None,
                        ));
                    }
                    Err(DeserializationError::UnsupportedContentType(message)) => {
                        return Err(AIOProtocolError::new_header_invalid_error(
                            "Content Type",
                            &content_type.unwrap_or("None".to_string()),
                            false,
                            Some(message),
                            None,
                        ));
                    }
                };
                Self::Ok(Response {
                    payload,
                    content_type,
                    format_indicator,
                    custom_user_data: response_custom_user_data,
                    timestamp,
                    executor_id: response_aio_data.remove(&UserProperty::SourceId),
                })
            }
            // Every other recognized status is an executor-reported error.
            _ => Self::Err(RemoteError {
                status_code,
                protocol_version,
                status_message: response_aio_data.remove(&UserProperty::StatusMessage),
                is_application_error: response_aio_data
                    .get(&UserProperty::IsApplicationError)
                    .is_some_and(|v| v == "true"),
                invalid_property_name: response_aio_data.remove(&UserProperty::InvalidPropertyName),
                invalid_property_value: response_aio_data
                    .remove(&UserProperty::InvalidPropertyValue),
                timestamp,
                supported_protocol_major_versions: response_aio_data
                    .get(&UserProperty::SupportedMajorVersions)
                    .map(|s| parse_supported_protocol_major_versions(s)),
            }),
        };
        Ok(command_result)
    }
}
/// Configuration for creating an [`Invoker`]. Built with `OptionsBuilder`.
#[derive(Builder, Clone)]
#[builder(setter(into))]
pub struct Options {
    /// Topic pattern that requests are published to.
    request_topic_pattern: String,
    /// Explicit response topic pattern. When set, `response_topic_prefix` and
    /// `response_topic_suffix` are ignored. When `None`, the response pattern is derived
    /// from `request_topic_pattern` (see `Invoker::new`).
    #[builder(default = "None")]
    response_topic_pattern: Option<String>,
    /// Name of the command. Must be non-empty and contain no invalid topic characters.
    command_name: String,
    /// Optional namespace applied to both topic patterns.
    #[builder(default = "None")]
    topic_namespace: Option<String>,
    /// Token values substituted into the topic patterns when the invoker is created.
    #[builder(default)]
    topic_token_map: HashMap<String, String>,
    /// Prefix, joined with `/`, for a response pattern derived from the request pattern.
    #[builder(default = "None")]
    response_topic_prefix: Option<String>,
    /// Suffix, joined with `/`, for a response pattern derived from the request pattern.
    #[builder(default = "None")]
    response_topic_suffix: Option<String>,
}
/// Command invoker: publishes command requests and waits for the matching responses.
///
/// The response topic subscription is made lazily on the first `invoke`. Responses are
/// matched to requests by MQTT correlation data.
pub struct Invoker<TReq, TResp>
where
    TReq: PayloadSerialize + 'static,
    TResp: PayloadSerialize + 'static,
{
    /// Application clock used to timestamp requests and updated from response timestamps.
    application_hlc: Arc<ApplicationHybridLogicalClock>,
    /// MQTT client used to subscribe, publish and unsubscribe.
    mqtt_client: SessionManagedClient,
    /// Command name, used in log messages and attached to returned errors.
    command_name: String,
    /// Pattern resolved per request into the publish topic.
    request_topic_pattern: TopicPattern,
    /// Pattern resolved per request into the MQTT response topic property.
    response_topic_pattern: TopicPattern,
    /// Filter that is subscribed to (lazily) and unsubscribed on shutdown/drop.
    response_topic_filter: TopicFilter,
    request_payload_type: PhantomData<TReq>,
    response_payload_type: PhantomData<TResp>,
    /// Lifecycle state; the lock is also held across subscribe/unsubscribe calls.
    state_mutex: Arc<Mutex<State>>,
    /// Signals the background receive loop to close its MQTT receiver.
    shutdown_notifier: Arc<Notify>,
    /// Broadcasts each received response publish to pending invocations. `None` means no
    /// further responses will arrive.
    response_tx: Sender<Option<Publish>>,
}
/// Lifecycle state of an [`Invoker`]'s response subscription.
enum State {
    /// Created; the response topic filter has not been subscribed to yet.
    New,
    /// The response topic filter is subscribed.
    Subscribed,
    /// Shutdown has started. The unsubscribe may still be pending or may have failed,
    /// in which case shutdown can be retried.
    ShutdownInitiated,
    /// Shutdown has completed; no further invocations are accepted.
    ShutdownSuccessful,
}
impl<TReq, TResp> Invoker<TReq, TResp>
where
    TReq: PayloadSerialize + 'static,
    TResp: PayloadSerialize + 'static,
{
    /// Creates a command invoker and starts its background response-receive task.
    ///
    /// The response topic pattern is `invoker_options.response_topic_pattern` when that is
    /// set. Otherwise it is derived from the request topic pattern in one of two ways:
    /// - when neither `response_topic_prefix` nor `response_topic_suffix` is given, it is
    ///   prefixed with `clients/{client_id}/`;
    /// - otherwise it is wrapped with whichever of the two are set, each joined with `/`.
    ///
    /// The MQTT subscription itself is deferred to the first [`invoke`](Self::invoke).
    ///
    /// # Errors
    /// Returns a configuration-invalid [`AIOProtocolError`] if the command name is empty or
    /// contains invalid characters, or if either topic pattern (or the derived response
    /// subscribe filter) is invalid.
    ///
    /// # Panics
    /// Spawns a Tokio task, so it must be called from within a Tokio runtime
    /// (`tokio::task::spawn` panics otherwise).
    pub fn new(
        application_context: ApplicationContext,
        client: SessionManagedClient,
        invoker_options: Options,
    ) -> Result<Self, AIOProtocolError> {
        if invoker_options.command_name.is_empty()
            || contains_invalid_char(&invoker_options.command_name)
        {
            return Err(AIOProtocolError::new_configuration_invalid_error(
                None,
                "command_name",
                Value::String(invoker_options.command_name.clone()),
                None,
                Some(invoker_options.command_name),
            ));
        }
        // Resolve the (unparsed) response topic pattern string.
        let mut response_topic_pattern;
        if let Some(pattern) = invoker_options.response_topic_pattern {
            response_topic_pattern = pattern;
        } else {
            response_topic_pattern = invoker_options.request_topic_pattern.clone();
            if invoker_options.response_topic_prefix.is_none()
                && invoker_options.response_topic_suffix.is_none()
            {
                // Default: scope responses under this client's id.
                response_topic_pattern =
                    "clients/".to_owned() + client.client_id() + "/" + &response_topic_pattern;
            } else {
                if let Some(prefix) = invoker_options.response_topic_prefix {
                    response_topic_pattern = prefix + "/" + &response_topic_pattern;
                }
                if let Some(suffix) = invoker_options.response_topic_suffix {
                    response_topic_pattern = response_topic_pattern + "/" + &suffix;
                }
            }
        }
        let request_topic_pattern = TopicPattern::new(
            &invoker_options.request_topic_pattern,
            None,
            invoker_options.topic_namespace.as_deref(),
            &invoker_options.topic_token_map,
        )
        .map_err(|e| {
            AIOProtocolError::config_invalid_from_topic_pattern_error(
                e,
                "invoker_options.request_topic_pattern",
            )
        })?;
        let response_topic_pattern = TopicPattern::new(
            &response_topic_pattern,
            None,
            invoker_options.topic_namespace.as_deref(),
            &invoker_options.topic_token_map,
        )
        .map_err(|e| {
            AIOProtocolError::config_invalid_from_topic_pattern_error(e, "response_topic_pattern")
        })?;
        let invoker_state_mutex = Arc::new(Mutex::new(State::New));
        let response_topic_filter = response_topic_pattern.as_subscribe_topic().map_err(|e| {
            AIOProtocolError::config_invalid_from_topic_pattern_error(e, "response_topic_pattern")
        })?;
        // The receiver is registered now, even though the MQTT subscribe happens lazily.
        let mqtt_receiver = client.create_filtered_pub_receiver(response_topic_filter.clone());
        // Broadcast capacity is 5; slower receivers observe `RecvError::Lagged`.
        let response_tx = Sender::new(5);
        let shutdown_notifier = Arc::new(Notify::new());
        task::spawn({
            let response_tx_clone = response_tx.clone();
            let shutdown_notifier_clone = shutdown_notifier.clone();
            let command_name_clone = invoker_options.command_name.clone();
            async move {
                Self::receive_response_loop(
                    mqtt_receiver,
                    response_tx_clone,
                    shutdown_notifier_clone,
                    command_name_clone,
                )
                .await;
            }
        });
        Ok(Self {
            application_hlc: application_context.application_hlc,
            mqtt_client: client,
            command_name: invoker_options.command_name,
            request_topic_pattern,
            response_topic_pattern,
            response_topic_filter,
            request_payload_type: PhantomData,
            response_payload_type: PhantomData,
            state_mutex: invoker_state_mutex,
            shutdown_notifier,
            response_tx,
        })
    }
    /// Invokes the command and waits for its response.
    ///
    /// The whole operation is bounded by `request.timeout`: the lazy subscribe, the publish
    /// and puback, and waiting for the matching response.
    ///
    /// # Errors
    /// - A timeout error if the operation does not finish within `request.timeout`.
    /// - Any error from subscribing, publishing, or interpreting the response.
    /// - The executor-reported error (converted from [`RemoteError`]) for non-success statuses.
    /// - A cancellation error if the invoker has been shut down.
    pub async fn invoke(
        &self,
        request: Request<TReq>,
    ) -> Result<Response<TResp>, AIOProtocolError> {
        let command_timeout = request.timeout;
        let invoke_result = time::timeout(request.timeout, self.invoke_internal(request)).await;
        match invoke_result {
            Ok(result) => match result {
                Ok(response) => Ok(response),
                Err(e) => Err(e),
            },
            // Elapsed: `invoke_internal` was dropped, which cancels its spawned tasks.
            Err(e) => {
                log::error!(
                    "[{command_name}] Command invoke timed out after {command_timeout:?}",
                    command_name = self.command_name,
                );
                Err(AIOProtocolError::new_timeout_error(
                    false,
                    Some(Box::new(e)),
                    &self.command_name,
                    command_timeout,
                    None,
                    Some(self.command_name.clone()),
                ))
            }
        }
    }
    /// Subscribes to the response topic filter with QoS 1 and checks the SUBACK.
    ///
    /// # Errors
    /// Returns an MQTT [`AIOProtocolError`] if the subscribe call fails, its completion
    /// fails, or the SUBACK reports a failure.
    async fn subscribe_to_response_filter(&self) -> Result<(), AIOProtocolError> {
        let subscribe_result = self
            .mqtt_client
            .subscribe(
                self.response_topic_filter.clone(),
                QoS::AtLeastOnce,
                false,
                azure_iot_operations_mqtt::control_packet::RetainOptions::default(),
                azure_iot_operations_mqtt::control_packet::SubscribeProperties::default(),
            )
            .await;
        match subscribe_result {
            Ok(sub_ct) => {
                match sub_ct.await {
                    Ok(suback) => {
                        suback.as_result().map_err(|e| {
                            log::error!("[{}] Invoker suback error: {suback:?}", self.command_name);
                            AIOProtocolError::new_mqtt_error(
                                Some("MQTT Error on command invoker suback".to_string()),
                                Box::new(e),
                                Some(self.command_name.clone()),
                            )
                        })?;
                    }
                    Err(e) => {
                        log::error!(
                            "[{}] Invoker subscribe completion error: {e}",
                            self.command_name
                        );
                        return Err(AIOProtocolError::new_mqtt_error(
                            Some("MQTT Error on command invoker subscribe".to_string()),
                            Box::new(e),
                            Some(self.command_name.clone()),
                        ));
                    }
                }
            }
            Err(e) => {
                log::error!(
                    "[{}] Client error while subscribing in Invoker: {e}",
                    self.command_name
                );
                return Err(AIOProtocolError::new_mqtt_error(
                    Some("Client error on command invoker subscribe".to_string()),
                    Box::new(e),
                    Some(self.command_name.clone()),
                ));
            }
        }
        Ok(())
    }
    /// Publishes one request and waits for the response with matching correlation data.
    ///
    /// Has no timeout of its own; [`invoke`](Self::invoke) bounds it.
    async fn invoke_internal(
        &self,
        mut request: Request<TReq>,
    ) -> Result<Response<TResp>, AIOProtocolError> {
        // Cancelled when this future completes or is dropped (e.g. by the timeout in
        // `invoke`). This stops the spawned publish and response tasks below.
        let cancellation_token = CancellationToken::new();
        let _drop_guard = cancellation_token.clone().drop_guard();
        // `RequestBuilder::validate` rejects timeouts whose seconds exceed u32, so a request
        // built through the builder cannot hit the error arm.
        let message_expiry_interval: u32 = match request.timeout.as_secs().try_into() {
            Ok(val) => val,
            Err(_) => {
                unreachable!();
            }
        };
        let request_topic = self
            .request_topic_pattern
            .as_publish_topic(&request.topic_tokens)
            .map_err(|e| {
                AIOProtocolError::config_invalid_from_topic_pattern_error(
                    e,
                    "request_topic_pattern",
                )
            })?;
        let response_topic = self
            .response_topic_pattern
            .as_publish_topic(&request.topic_tokens)
            .map_err(|e| {
                AIOProtocolError::config_invalid_from_topic_pattern_error(
                    e,
                    "response_topic_pattern",
                )
            })?;
        // A random UUID identifies this request; the response must echo it back.
        let correlation_id = Uuid::new_v4();
        let correlation_data = Bytes::from(correlation_id.as_bytes().to_vec());
        let timestamp_str = self.application_hlc.update_now()?;
        // Append the reserved protocol user properties after the caller's custom data.
        request.custom_user_data.push((
            UserProperty::SourceId.to_string(),
            self.mqtt_client.client_id().to_string(),
        ));
        request
            .custom_user_data
            .push((UserProperty::Timestamp.to_string(), timestamp_str));
        request.custom_user_data.push((
            UserProperty::ProtocolVersion.to_string(),
            RPC_COMMAND_PROTOCOL_VERSION.to_string(),
        ));
        request.custom_user_data.push((
            PARTITION_KEY.to_string(),
            self.mqtt_client.client_id().to_string(),
        ));
        if let Some(cloud_event) = request.cloud_event {
            let cloud_event_headers = cloud_event.0.into_headers(request_topic.as_str());
            for (key, value) in cloud_event_headers {
                request.custom_user_data.push((key, value));
            }
        }
        let publish_properties = PublishProperties {
            correlation_data: Some(correlation_data.clone()),
            response_topic: Some(response_topic),
            payload_format_indicator: request.serialized_payload.format_indicator.into(),
            content_type: Some(request.serialized_payload.content_type.clone()),
            message_expiry_interval: Some(message_expiry_interval),
            user_properties: request.custom_user_data,
            topic_alias: None,
            subscription_identifiers: Vec::new(),
        };
        // Subscribe lazily on first use. The state lock is held across the subscribe, so
        // concurrent invocations wait for it instead of subscribing twice.
        {
            let mut invoker_state = self.state_mutex.lock().await;
            match *invoker_state {
                State::New => {
                    self.subscribe_to_response_filter().await?;
                    *invoker_state = State::Subscribed;
                }
                State::Subscribed => { }
                State::ShutdownInitiated | State::ShutdownSuccessful => {
                    return Err(AIOProtocolError::new_cancellation_error(
                        false,
                        None,
                        Some(
                            "Command Invoker has been shutdown and can no longer invoke commands"
                                .to_string(),
                        ),
                        Some(self.command_name.clone()),
                    ));
                }
            }
        }
        // Join the broadcast before publishing. A broadcast receiver only sees values sent
        // after it subscribes, so a fast response cannot be missed.
        let mut response_rx = self.response_tx.subscribe();
        let publish_result = self
            .mqtt_client
            .publish_qos1(
                request_topic,
                false,
                request.serialized_payload.payload,
                publish_properties,
            )
            .await;
        // Wait for the PUBACK concurrently with waiting for the response.
        let pub_task = tokio::task::spawn({
            let command_name = self.command_name.clone();
            let ct = cancellation_token.clone();
            async move {
                match publish_result {
                    Ok(publish_completion_token) => {
                        tokio::select! {
                            () = ct.cancelled() => {
                                Err(AIOProtocolError::new_timeout_error(
                                    false,
                                    None,
                                    &command_name,
                                    request.timeout,
                                    None,
                                    Some(command_name.clone()),
                                ))
                            },
                            publish_completion_token_result = publish_completion_token => {
                                match publish_completion_token_result {
                                    Ok(puback) => {
                                        puback.as_result().map_err(|e| {
                                            AIOProtocolError::new_mqtt_error(
                                                Some("MQTT Puback indicated failure".to_string()),
                                                Box::new(e),
                                                Some(command_name),
                                            )
                                        })
                                    },
                                    Err(e) => {
                                        log::error!("[{command_name}] Command Request publish completion error: {e}");
                                        Err(AIOProtocolError::new_mqtt_error(
                                            Some("MQTT Error on command invoke publish".to_string()),
                                            Box::new(e),
                                            Some(command_name),
                                        ))
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        log::error!(
                            "[{command_name}] Client error while publishing Invoker Command Request: {e}"
                        );
                        Err(AIOProtocolError::new_mqtt_error(
                            Some("Client error on command invoker request publish".to_string()),
                            Box::new(e),
                            Some(command_name),
                        ))
                    }
                }
            }
        });
        // Wait for the broadcast response whose correlation data matches this request.
        // Responses for other requests are skipped.
        let response_task = tokio::task::spawn({
            let command_name = self.command_name.clone();
            let ct = cancellation_token.clone();
            async move {
                loop {
                    tokio::select! {
                        () = ct.cancelled() => {
                            return Err(AIOProtocolError::new_timeout_error(
                                false,
                                None,
                                &command_name,
                                request.timeout,
                                None,
                                Some(command_name.clone()),
                            ));
                        },
                        res = response_rx.recv() => {
                            match res {
                                Ok(rsp_pub) => {
                                    if let Some(rsp_pub) = rsp_pub {
                                        if let Some(ref response_correlation_data) =
                                            rsp_pub.properties.correlation_data
                                            && *response_correlation_data == correlation_data {
                                            return Ok(rsp_pub);
                                        }
                                    } else {
                                        // `None` is sent by the receive loop when the MQTT
                                        // receiver has ended.
                                        log::error!(
                                            "[{command_name}] Command Invoker has been shutdown and will no longer receive a response"
                                        );
                                        return Err(AIOProtocolError::new_cancellation_error(
                                            false,
                                            None,
                                            Some(
                                                "Command Invoker has been shutdown and will no longer receive a response"
                                                    .to_string(),
                                            ),
                                            Some(command_name),
                                        ));
                                    }
                                }
                                // Messages were dropped; the matching response may be among
                                // them. Keep waiting until the timeout.
                                Err(RecvError::Lagged(e)) => {
                                    log::warn!(
                                        "[{command_name}] Invoker response receiver lagged. Response may not be received. Number of skipped messages: {e}"
                                    );
                                }
                                Err(RecvError::Closed) => {
                                    log::error!(
                                        "[{command_name}] Invoker MQTT Receiver has been cleaned up and will no longer send a response"
                                    );
                                    return Err(AIOProtocolError::new_cancellation_error(
                                        false,
                                        None,
                                        Some(
                                            "MQTT Receiver has been cleaned up and will no longer send a response"
                                                .to_string(),
                                        ),
                                        Some(command_name),
                                    ));
                                }
                            }
                        }
                    }
                }
            }
        });
        // Both the PUBACK and the response must succeed. The first error wins, and the
        // drop guard then cancels the other task.
        let rsp_pub = match tokio::try_join!(flatten(pub_task), flatten(response_task)) {
            Ok(((), rsp_pub)) => rsp_pub,
            Err(e) => {
                return Err(e);
            }
        };
        let command_result: CommandResult<TResp> =
            rsp_pub.try_into().map_err(|mut e: AIOProtocolError| {
                e.command_name = Some(self.command_name.clone());
                e
            })?;
        // Fold the response timestamp into the application clock, on success and error alike.
        match command_result {
            CommandResult::Ok(response) => {
                if let Some(hlc) = &response.timestamp {
                    self.application_hlc.update(hlc).map_err(|e| {
                        let mut aio_error: AIOProtocolError = e.into();
                        aio_error.command_name = Some(self.command_name.clone());
                        aio_error
                    })?;
                }
                Ok(response)
            }
            CommandResult::Err(remote_e) => {
                if let Some(hlc) = &remote_e.timestamp {
                    self.application_hlc.update(hlc).map_err(|e| {
                        let mut aio_error: AIOProtocolError = e.into();
                        aio_error.command_name = Some(self.command_name.clone());
                        aio_error
                    })?;
                }
                let mut aio_e: AIOProtocolError = remote_e.into();
                aio_e.command_name = Some(self.command_name.clone());
                Err(aio_e)
            }
        }
    }
    /// Background loop that forwards every publish from `mqtt_receiver` to `response_tx`.
    ///
    /// Each message is broadcast (and dropped with a debug log if no invocation is waiting),
    /// then acked in a separately spawned task. When the receiver yields `None`, `None` is
    /// broadcast so waiting invocations fail fast, and the loop exits. A shutdown
    /// notification only closes the receiver; the loop keeps running until the receiver
    /// yields `None` (presumably once any buffered messages are drained).
    async fn receive_response_loop(
        mut mqtt_receiver: SessionPubReceiver,
        response_tx: Sender<Option<Publish>>,
        shutdown_notifier: Arc<Notify>,
        command_name: String,
    ) {
        loop {
            tokio::select! {
                () = shutdown_notifier.notified() => {
                    mqtt_receiver.close();
                    log::info!("[{command_name}] Invoker MQTT Receiver closed");
                },
                recv_result = mqtt_receiver.recv_manual_ack() => {
                    if let Some((m, ack_token)) = recv_result {
                        match response_tx.send(Some(m)) {
                            Ok(_) => { },
                            Err(e) => {
                                log::debug!("[{command_name}] Command Response ignored, no pending commands: {e}");
                            }
                        }
                        // Ack off the loop so a slow ack does not delay later responses.
                        if let Some(ack_token) = ack_token {
                            tokio::task::spawn({
                                let command_name_clone = command_name.clone();
                                async move {
                                    match ack_token.ack().await {
                                        Ok(ack_ct) => {
                                            match ack_ct.await {
                                                Ok(()) => { },
                                                Err(e) => log::warn!("[{command_name_clone}] Error acking command response: {e}"),
                                            }
                                        },
                                        Err(e) => {
                                            log::warn!("[{command_name_clone}] Error acking command response: {e}");
                                        }
                                    }
                                }
                            });
                        }
                    } else {
                        _ = response_tx.send(None);
                        log::info!("[{command_name}] No more command responses will be received.");
                        break;
                    }
                }
            }
        }
    }
    /// Shuts the invoker down: closes the response receiver and unsubscribes if subscribed.
    ///
    /// After a successful shutdown, `invoke` returns a cancellation error. If unsubscribing
    /// fails, the state stays `ShutdownInitiated` and shutdown may be called again.
    ///
    /// # Errors
    /// Returns an MQTT [`AIOProtocolError`] if the unsubscribe call fails, its completion
    /// fails, or the UNSUBACK reports a failure.
    pub async fn shutdown(&self) -> Result<(), AIOProtocolError> {
        self.shutdown_notifier.notify_one();
        let mut invoker_state_mutex_guard = self.state_mutex.lock().await;
        match *invoker_state_mutex_guard {
            State::New | State::ShutdownSuccessful => {
                // Never subscribed, or already unsubscribed: nothing to undo.
            }
            State::ShutdownInitiated | State::Subscribed => {
                *invoker_state_mutex_guard = State::ShutdownInitiated;
                let unsubscribe_result = self
                    .mqtt_client
                    .unsubscribe(
                        self.response_topic_filter.clone(),
                        azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
                    )
                    .await;
                match unsubscribe_result {
                    Ok(unsub_completion_token) => match unsub_completion_token.await {
                        Ok(unsuback) => {
                            unsuback.as_result().map_err(|e| {
                                log::error!(
                                    "[{}] Invoker Unsuback error: {unsuback:?}",
                                    self.command_name
                                );
                                AIOProtocolError::new_mqtt_error(
                                    Some("MQTT error on command invoker unsuback".to_string()),
                                    Box::new(e),
                                    Some(self.command_name.clone()),
                                )
                            })?;
                        }
                        Err(e) => {
                            log::error!(
                                "[{}] Invoker Unsubscribe completion error: {e}",
                                self.command_name
                            );
                            return Err(AIOProtocolError::new_mqtt_error(
                                Some("MQTT error on command invoker unsubscribe".to_string()),
                                Box::new(e),
                                Some(self.command_name.clone()),
                            ));
                        }
                    },
                    Err(e) => {
                        log::error!(
                            "[{}] Client error while unsubscribing in Invoker: {e}",
                            self.command_name
                        );
                        return Err(AIOProtocolError::new_mqtt_error(
                            Some("Client error on command invoker unsubscribe".to_string()),
                            Box::new(e),
                            Some(self.command_name.clone()),
                        ));
                    }
                }
            }
        }
        log::info!("[{}] Command Invoker Shutdown", self.command_name);
        *invoker_state_mutex_guard = State::ShutdownSuccessful;
        Ok(())
    }
}
impl<TReq, TResp> Drop for Invoker<TReq, TResp>
where
    TReq: PayloadSerialize + 'static,
    TResp: PayloadSerialize + 'static,
{
    /// Best-effort cleanup when the invoker is dropped.
    ///
    /// When a Tokio runtime is available, a background task runs `drop_unsubscribe` to
    /// unsubscribe from the response topic filter (if still subscribed). `tokio::spawn`
    /// panics outside a runtime, and a panic inside `drop` aborts the process if it
    /// happens during unwinding. So when no runtime is available, the unsubscribe is
    /// skipped with a warning instead. In every case the receive loop is notified so it
    /// closes its MQTT receiver.
    fn drop(&mut self) {
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                let invoker_state_mutex = self.state_mutex.clone();
                let unsubscribe_filter = self.response_topic_filter.clone();
                let mqtt_client = self.mqtt_client.clone();
                runtime.spawn(async move {
                    drop_unsubscribe(mqtt_client, invoker_state_mutex, unsubscribe_filter).await;
                });
            }
            Err(e) => {
                log::warn!(
                    "[{}] Command Invoker dropped outside of a Tokio runtime; skipping unsubscribe: {e}",
                    self.command_name
                );
            }
        }
        self.shutdown_notifier.notify_one();
        log::info!("[{}] Command Invoker has been dropped", self.command_name);
    }
}
/// Unsubscribes from the response topic filter on behalf of a dropped [`Invoker`].
///
/// Only sends the UNSUBSCRIBE when the state says a subscription may exist. It does not
/// wait for the UNSUBACK, and failures are only logged. The state always ends as
/// `ShutdownSuccessful`.
async fn drop_unsubscribe(
    mqtt_client: SessionManagedClient,
    invoker_state_mutex: Arc<Mutex<State>>,
    unsubscribe_filter: TopicFilter,
) {
    let mut state = invoker_state_mutex.lock().await;
    match *state {
        State::Subscribed | State::ShutdownInitiated => {
            *state = State::ShutdownInitiated;
            let unsubscribe_result = mqtt_client
                .unsubscribe(
                    unsubscribe_filter.clone(),
                    azure_iot_operations_mqtt::control_packet::UnsubscribeProperties::default(),
                )
                .await;
            if let Err(e) = unsubscribe_result {
                log::warn!("Invoker Unsubscribe error on topic {unsubscribe_filter}: {e}");
            } else {
                log::debug!(
                    "Invoker Unsubscribe sent on topic {unsubscribe_filter}. Unsuback may still be pending."
                );
            }
        }
        State::New | State::ShutdownSuccessful => {
            // Nothing is subscribed; only the final state update below applies.
        }
    }
    *state = State::ShutdownSuccessful;
}
/// Awaits a spawned task and returns its own `Result`, hiding the join layer.
///
/// # Panics
/// Panics if the task failed to join (it panicked or was cancelled). The invoker's tasks
/// are not expected to do either.
async fn flatten<T>(
    handle: JoinHandle<Result<T, AIOProtocolError>>,
) -> Result<T, AIOProtocolError> {
    handle.await.unwrap_or_else(|e| {
        unreachable!("Invoker Join Error: {e}. Tasks should not be able to panic")
    })
}
#[cfg(test)]
mod tests {
use test_case::test_case;
use azure_iot_operations_mqtt::aio::connection_settings::MqttConnectionSettingsBuilder;
use azure_iot_operations_mqtt::session::{Session, SessionOptionsBuilder};
use super::*;
use crate::application::ApplicationContextBuilder;
use crate::common::{
aio_protocol_error::AIOProtocolErrorKind,
payload_serialize::{DESERIALIZE_MTX, FormatIndicator, MockPayload},
};
/// Builds an (unconnected) test session for client id `test_client` on `localhost`.
fn create_session() -> Session {
    let settings = MqttConnectionSettingsBuilder::default()
        .client_id("test_client")
        .hostname("localhost")
        .build()
        .unwrap();
    let options = SessionOptionsBuilder::default()
        .connection_settings(settings)
        .build()
        .unwrap();
    Session::new(options).unwrap()
}
fn create_topic_tokens() -> HashMap<String, String> {
HashMap::from([
("commandName".to_string(), "test_command_name".to_string()),
("invokerClientId".to_string(), "test_client".to_string()),
])
}
#[tokio::test]
async fn test_new_defaults() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/{commandName}/{executorId}/request")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
assert_eq!(
invoker
.response_topic_pattern
.as_subscribe_topic()
.unwrap()
.as_str(),
"clients/test_client/test/test_command_name/+/request"
);
}
#[tokio::test]
async fn test_new_override_defaults() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/{commandName}/{executorId}/request")
.response_topic_pattern("test/{commandName}/{executorId}/response".to_string())
.command_name("test_command_name")
.topic_namespace("test_namespace".to_string())
.topic_token_map(create_topic_tokens())
.response_topic_prefix("custom/{invokerClientId}".to_string())
.response_topic_suffix("custom/response".to_string())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
assert_eq!(
invoker
.response_topic_pattern
.as_subscribe_topic()
.unwrap()
.as_str(),
"test_namespace/test/test_command_name/+/response"
);
}
#[test_case("command_name", ""; "new_empty_command_name")]
#[test_case("command_name", " "; "new_whitespace_command_name")]
#[test_case("request_topic_pattern", ""; "new_empty_request_topic_pattern")]
#[test_case("request_topic_pattern", " "; "new_whitespace_request_topic_pattern")]
#[test_case("response_topic_pattern", ""; "new_empty_response_topic_pattern")]
#[test_case("response_topic_pattern", " "; "new_whitespace_response_topic_pattern")]
#[test_case("response_topic_prefix", ""; "new_empty_response_topic_prefix")]
#[test_case("response_topic_prefix", " "; "new_whitespace_response_topic_prefix")]
#[test_case("response_topic_suffix", ""; "new_empty_response_topic_suffix")]
#[test_case("response_topic_suffix", " "; "new_whitespace_response_topic_suffix")]
#[tokio::test]
async fn test_new_empty_args(property_name: &str, property_value: &str) {
let session = create_session();
let managed_client = session.create_managed_client();
let mut command_name = "test_command_name".to_string();
let mut request_topic_pattern = "test/req/topic".to_string();
let mut response_topic_pattern = None;
let mut response_topic_prefix = "custom/prefix".to_string();
let mut response_topic_suffix = "custom/suffix".to_string();
let error_property_name;
let mut error_property_value = property_value.to_string();
match property_name {
"command_name" => {
command_name = property_value.to_string();
error_property_name = "command_name";
}
"request_topic_pattern" => {
request_topic_pattern = property_value.to_string();
error_property_name = "invoker_options.request_topic_pattern";
}
"response_topic_pattern" => {
response_topic_pattern = Some(property_value.to_string());
error_property_name = "response_topic_pattern";
}
"response_topic_prefix" => {
response_topic_prefix = property_value.to_string();
error_property_name = "response_topic_pattern";
error_property_value.push_str("/test/req/topic/custom/suffix");
}
"response_topic_suffix" => {
response_topic_suffix = property_value.to_string();
error_property_name = "response_topic_pattern";
error_property_value = "custom/prefix/test/req/topic/".to_string();
error_property_value.push_str(&response_topic_suffix);
}
_ => panic!("Invalid property_name"),
}
let invoker_options = OptionsBuilder::default()
.request_topic_pattern(request_topic_pattern)
.response_topic_pattern(response_topic_pattern)
.response_topic_prefix(response_topic_prefix)
.response_topic_suffix(response_topic_suffix)
.command_name(command_name)
.build()
.unwrap();
let invoker: Result<Invoker<MockPayload, MockPayload>, AIOProtocolError> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
);
match invoker {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
assert!(e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.property_name, Some(error_property_name.to_string()));
assert!(e.property_value == Some(Value::String(error_property_value.clone())));
}
}
}
#[test_case(Some("custom/prefix".to_string()), Some("custom/suffix".to_string()), "custom/prefix/test/req/topic/custom/suffix"; "new_response_topic_prefix_and_suffix")]
#[test_case(None, Some("custom/suffix".to_string()), "test/req/topic/custom/suffix"; "new_none_response_topic_prefix")]
#[test_case(Some("custom/prefix".to_string()), None, "custom/prefix/test/req/topic"; "new_none_response_topic_suffix")]
#[test_case(None, None, "clients/test_client/test/req/topic"; "new_none_response_topic_prefix_and_suffix")]
#[tokio::test]
async fn test_new_response_pattern_prefix_suffix_args(
response_topic_prefix: Option<String>,
response_topic_suffix: Option<String>,
expected_response_topic_subscribe_pattern: &str,
) {
let session = create_session();
let managed_client = session.create_managed_client();
let command_name = "test_command_name".to_string();
let request_topic_pattern = "test/req/topic".to_string();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern(request_topic_pattern)
.command_name(command_name)
.response_topic_prefix(response_topic_prefix)
.response_topic_suffix(response_topic_suffix)
.build()
.unwrap();
let invoker: Result<Invoker<MockPayload, MockPayload>, AIOProtocolError> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
);
assert!(invoker.is_ok());
assert_eq!(
invoker
.unwrap()
.response_topic_pattern
.as_subscribe_topic()
.unwrap()
.as_str(),
expected_response_topic_subscribe_pattern
);
}
#[tokio::test]
async fn test_new_response_pattern_default_prefix() {
let session = create_session();
let managed_client = session.create_managed_client();
let command_name = "test_command_name";
let request_topic_pattern = "test/req/topic";
let invoker_options = OptionsBuilder::default()
.request_topic_pattern(request_topic_pattern)
.command_name(command_name)
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Result<Invoker<MockPayload, MockPayload>, AIOProtocolError> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
);
assert!(invoker.is_ok());
assert_eq!(
invoker
.unwrap()
.response_topic_pattern
.as_subscribe_topic()
.unwrap()
.as_str(),
"clients/test_client/test/req/topic"
);
}
#[tokio::test]
async fn test_new_response_pattern_only_suffix() {
let session = create_session();
let managed_client = session.create_managed_client();
let command_name = "test_command_name";
let request_topic_pattern = "test/req/topic";
let response_topic_suffix = "custom/suffix";
let invoker_options = OptionsBuilder::default()
.request_topic_pattern(request_topic_pattern)
.command_name(command_name)
.topic_token_map(create_topic_tokens())
.response_topic_suffix(response_topic_suffix.to_string())
.build()
.unwrap();
let invoker: Result<Invoker<MockPayload, MockPayload>, AIOProtocolError> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
);
assert!(invoker.is_ok());
assert_eq!(
invoker
.unwrap()
.response_topic_pattern
.as_subscribe_topic()
.unwrap()
.as_str(),
"test/req/topic/custom/suffix"
);
}
#[tokio::test]
#[ignore = "test ignored because waiting for the suback hangs forever. Leaving the test for now until we have a full testing framework"]
async fn test_invoke_timeout_parameter() {
let _deserialize_mutex = DESERIALIZE_MTX.lock();
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let mock_payload_deserialize_ctx = MockPayload::deserialize_context();
mock_payload_deserialize_ctx
.expect()
.returning(|_, _, _| {
let mut mock_response_payload = MockPayload::default();
mock_response_payload
.expect_clone()
.returning(MockPayload::default)
.times(1);
Ok(mock_response_payload)
})
.once();
let mut invoker_state = invoker.state_mutex.lock().await;
*invoker_state = State::Subscribed;
drop(invoker_state);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(5))
.build()
.unwrap(),
)
.await;
assert!(response.is_ok());
}
#[tokio::test]
async fn test_invoke_times_out() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(1))
.build()
.unwrap(),
)
.await;
match response {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::Timeout);
assert!(!e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.timeout_name, Some("test_command_name".to_string()));
assert!(e.timeout_value == Some(Duration::from_secs(1)));
}
}
}
#[tokio::test]
async fn test_invoke_times_out_timeout_rounded() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_nanos(1))
.build()
.unwrap(),
)
.await;
match response {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::Timeout);
assert!(!e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.timeout_name, Some("test_command_name".to_string()));
assert!(e.timeout_value == Some(Duration::from_secs(1)));
}
}
}
#[tokio::test]
#[ignore = "test ignored because waiting for the suback hangs forever. Leaving the test for now until we have a full testing framework"]
async fn test_invoke_deserialize_error() {
let _deserialize_mutex = DESERIALIZE_MTX.lock();
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let mock_payload_deserialize_ctx = MockPayload::deserialize_context();
mock_payload_deserialize_ctx
.expect()
.returning(|_, _, _| {
Err(DeserializationError::InvalidPayload(
"dummy error".to_string(),
))
})
.once();
let mut invoker_state = invoker.state_mutex.lock().await;
*invoker_state = State::Subscribed;
drop(invoker_state);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_millis(2))
.build()
.unwrap(),
)
.await;
match response {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::PayloadInvalid);
assert!(!e.is_shallow);
assert!(!e.is_remote);
assert!(e.nested_error.is_some());
}
}
}
#[tokio::test]
async fn test_invoke_executor_id_invalid_value() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/{executorId}/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(2))
.topic_tokens(HashMap::from([(
"executorId".to_string(),
"+++".to_string(),
)]))
.build()
.unwrap(),
)
.await;
match response {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
assert!(e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.property_name, Some("executorId".to_string()));
assert!(e.property_value == Some(Value::String("+++".to_string())));
}
}
}
#[tokio::test]
async fn test_invoke_missing_token() {
let session = create_session();
let managed_client = session.create_managed_client();
let invoker_options = OptionsBuilder::default()
.request_topic_pattern("test/req/{executorId}/topic")
.command_name("test_command_name")
.topic_token_map(create_topic_tokens())
.build()
.unwrap();
let invoker: Invoker<MockPayload, MockPayload> = Invoker::new(
ApplicationContextBuilder::default().build().unwrap(),
managed_client,
invoker_options,
)
.unwrap();
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let response = invoker
.invoke(
RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(2))
.topic_tokens(HashMap::new())
.build()
.unwrap(),
)
.await;
match response {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
assert!(e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.property_name, Some("executorId".to_string()));
assert_eq!(e.property_value, Some(Value::String(String::new())));
}
}
}
#[test]
fn test_request_serialization_error() {
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| Err("dummy error".to_string()))
.times(1);
let mut binding = RequestBuilder::default();
let req_builder = binding.payload(mock_request_payload);
match req_builder {
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::PayloadInvalid);
}
Ok(_) => {
panic!("Expected error");
}
}
}
#[test]
fn test_request_serialization_bad_content_type_error() {
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json\u{0000}".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let mut binding = RequestBuilder::default();
let req_builder = binding.payload(mock_request_payload);
match req_builder {
Err(e) => {
assert_eq!(e.kind, AIOProtocolErrorKind::ConfigurationInvalid);
assert!(e.is_shallow);
assert!(!e.is_remote);
assert_eq!(e.property_name, Some("content_type".to_string()));
assert!(
e.property_value == Some(Value::String("application/json\u{0000}".to_string()))
);
}
Ok(_) => {
panic!("Expected error");
}
}
}
#[test_case(Duration::from_secs(0); "invoke_timeout_0")]
#[test_case(Duration::from_secs(u64::from(u32::MAX) + 1); "invoke_timeout_u32_max")]
fn test_request_timeout_invalid_value(timeout: Duration) {
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let request_builder_result = RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(timeout)
.build();
assert!(request_builder_result.is_err());
}
#[test]
fn test_request_invalid_custom_user_data_cloud_event_header() {
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let request_builder_result = RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(2))
.custom_user_data(vec![("source".to_string(), "test".to_string())])
.build();
assert!(request_builder_result.is_err());
}
#[test]
fn test_request_defaults() {
let mut mock_request_payload = MockPayload::new();
mock_request_payload
.expect_serialize()
.returning(|| {
Ok(SerializedPayload {
payload: Vec::new(),
content_type: "application/json".to_string(),
format_indicator: FormatIndicator::Utf8EncodedCharacterData,
})
})
.times(1);
let request_builder_result = RequestBuilder::default()
.payload(mock_request_payload)
.unwrap()
.timeout(Duration::from_secs(2))
.build();
let r = request_builder_result.unwrap();
assert_eq!(r.timeout, Duration::from_secs(2));
assert!(r.custom_user_data.is_empty());
assert!(r.topic_tokens.is_empty());
assert!(r.cloud_event.is_none());
assert!(r.serialized_payload.payload.is_empty());
}
#[tokio::test]
async fn test_no_app_error_code_and_payload() {
let user_data: Vec<(String, String)> = Vec::new();
let (application_error_code, application_error_payload) =
application_error_headers(&user_data);
assert!(application_error_code.is_none());
assert!(application_error_payload.is_none());
}
#[tokio::test]
async fn test_response_with_app_error_code_and_payload() {
let error_code_content = "5888";
let error_payload_content = "5888 is a fictitious error code";
let custom_user_data = vec![
("AppErrCode".into(), error_code_content.into()),
("AppErrPayload".into(), error_payload_content.into()),
];
assert_eq!(custom_user_data.len(), 2);
let (application_error_code, application_error_payload) =
application_error_headers(&custom_user_data);
assert_eq!(application_error_code, Some(error_code_content.into()));
assert_eq!(
application_error_payload,
Some(error_payload_content.into())
);
}
#[tokio::test]
async fn test_response_with_app_error_code_but_no_payload() {
let error_code_content = "5888";
let custom_user_data = vec![("AppErrCode".into(), error_code_content.into())];
assert_eq!(custom_user_data.len(), 1);
let (application_error_code, application_error_payload) =
application_error_headers(&custom_user_data);
assert_eq!(application_error_code, Some(error_code_content.into()));
assert!(application_error_payload.is_none());
}
}