use std::convert::TryFrom;
use std::fmt;
use std::str::FromStr;

use async_channel::Sender;
use log::{debug, error, info};
use rumqttc::{
    ConnAck, Connect, Event, MqttOptions, Packet, PubAck, PubComp, PubRec, PubRel, Publish,
    Request, SubAck, Subscribe, UnsubAck, Unsubscribe,
};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::UnboundedReceiver;

use super::*;
use crate::{AccountId, Addressable, AgentId, Authenticable, Error, SharedGroup};

#[cfg(feature = "queue-counter")]
use crate::queue_counter::QueueCounterHandle;
/// Capacity of the rumqttc request channel used when the configuration does not
/// provide `requests_channel_size` (see `default_mqtt_requests_chan_size`).
const DEFAULT_MQTT_REQUESTS_CHAN_SIZE: Option<usize> = Some(10_000);
/// Deserializable settings for connecting an agent to an MQTT broker.
///
/// Consumed by `AgentBuilder::start`, which turns it into rumqttc `MqttOptions`.
#[derive(Debug, Clone, Deserialize)]
pub struct AgentConfig {
    /// Broker URI; it must contain both a host and a port.
    uri: String,
    /// Forwarded to `MqttOptions::set_clean_session` when set.
    clean_session: Option<bool>,
    /// Forwarded to `MqttOptions::set_keep_alive` when set.
    keep_alive_interval: Option<u64>,
    /// Seconds to sleep before polling the event loop again after a poll error.
    /// When `None`, the event loop stops after the first error.
    reconnect_interval: Option<u64>,
    /// Forwarded to `MqttOptions::set_inflight` when set.
    outgoing_message_queue_size: Option<usize>,
    // NOTE(review): despite the "incoming" name, this sizes rumqttc's request
    // channel capacity — confirm the naming is intentional.
    /// Forwarded to `MqttOptions::set_request_channel_capacity` when set.
    incoming_message_queue_size: Option<usize>,
    /// MQTT password; an empty string is sent when unset.
    password: Option<String>,
    /// Passed as both arguments of `MqttOptions::set_max_packet_size` when set.
    max_message_size: Option<usize>,
    /// Capacity of the request channel of the rumqttc `EventLoop`.
    /// Falls back to `DEFAULT_MQTT_REQUESTS_CHAN_SIZE` when the key is absent.
    #[serde(default = "default_mqtt_requests_chan_size")]
    requests_channel_size: Option<usize>,
}
/// Serde default for `AgentConfig::requests_channel_size`.
///
/// Serde only calls this when the key is missing from the input; an explicit
/// `null` still deserializes to `None`.
fn default_mqtt_requests_chan_size() -> Option<usize> {
    DEFAULT_MQTT_REQUESTS_CHAN_SIZE
}
impl AgentConfig {
    /// Sets the MQTT password, overwriting any previously configured one.
    ///
    /// Returns `self` so calls can be chained.
    pub fn set_password(&mut self, value: &str) -> &mut Self {
        let password = String::from(value);
        self.password = Some(password);
        self
    }
}
/// Builder that configures the MQTT connection identity and then starts an [`Agent`].
#[derive(Debug)]
pub struct AgentBuilder {
    /// Agent id plus the connection version and mode that form the MQTT username.
    connection: Connection,
    /// API version stored in the started agent's [`Address`].
    api_version: String,
}
impl AgentBuilder {
pub fn new(agent_id: AgentId, api_version: &str) -> Self {
Self {
connection: Connection::new(agent_id),
api_version: api_version.to_owned(),
}
}
pub fn connection_version(self, version: &str) -> Self {
let mut connection = self.connection;
connection.set_version(version);
Self { connection, ..self }
}
pub fn connection_mode(self, mode: ConnectionMode) -> Self {
let mut connection = self.connection;
connection.set_mode(mode);
Self { connection, ..self }
}
pub fn start(
self,
config: &AgentConfig,
) -> Result<(Agent, UnboundedReceiver<AgentNotification>), Error> {
{
let options = Self::mqtt_options(&self.connection, config)?;
let channel_size = config
.requests_channel_size
.expect("requests_channel_size is not specified");
let mut eventloop = rumqttc::EventLoop::new(options, channel_size);
let mqtt_tx = eventloop.handle();
let reconnect_interval = config.reconnect_interval.to_owned();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<AgentNotification>();
#[cfg(feature = "queue-counter")]
let queue_counter = QueueCounterHandle::start();
#[cfg(feature = "queue-counter")]
let queue_counter_ = queue_counter.clone();
tokio::spawn(async move {
let mut recovering_connection = false;
loop {
match eventloop.poll().await {
Ok(packet) => {
if recovering_connection {
recovering_connection = false;
if let Err(e) = tx.send(AgentNotification::Reconnection) {
error!("Failed to notify about reconnection: {}", e);
}
}
match packet {
Event::Outgoing(content) => {
info!("Outgoing message = '{:?}'", content);
}
Event::Incoming(message) => {
debug!("Incoming item = {:?}", message);
let mut msg: AgentNotification = message.into();
if let AgentNotification::Message(Ok(ref mut content), _) = msg
{
if let IncomingMessage::Request(req) = content {
let method = req.properties().method().to_owned();
req.properties_mut().set_method(&method);
}
#[cfg(feature = "queue-counter")]
queue_counter_.add_incoming_message(content);
}
if let Err(e) = tx.send(msg) {
error!("Failed to transmit message, reason = {}", e);
};
}
}
}
Err(err) => {
error!("Failed to poll, reason = {}", err);
recovering_connection = true;
if let Err(e) = tx.send(AgentNotification::ConnectionError) {
error!("Failed to notify about connection error: {}", e);
}
match reconnect_interval {
Some(value) => {
tokio::time::sleep(std::time::Duration::from_secs(value)).await
}
None => break,
}
}
}
}
});
let agent = Agent::new(
self.connection.agent_id,
&self.api_version,
mqtt_tx,
#[cfg(feature = "queue-counter")]
queue_counter,
);
Ok((agent, rx))
}
}
fn mqtt_options(connection: &Connection, config: &AgentConfig) -> Result<MqttOptions, Error> {
let uri = config
.uri
.parse::<http::Uri>()
.map_err(|e| Error::new(&format!("error parsing MQTT connection URL, {}", e)))?;
let host = uri.host().ok_or_else(|| Error::new("missing MQTT host"))?;
let port = uri
.port_u16()
.ok_or_else(|| Error::new("missing MQTT port"))?;
let username = format!("{}::{}", connection.version, connection.mode);
let password = config
.password
.to_owned()
.unwrap_or_else(|| String::from(""));
let mut opts = MqttOptions::new(connection.agent_id.to_string(), host, port);
opts.set_credentials(username, password);
if let Some(value) = config.clean_session {
opts.set_clean_session(value);
}
if let Some(value) = config.keep_alive_interval {
opts.set_keep_alive(value as u16);
}
if let Some(value) = config.incoming_message_queue_size {
opts.set_request_channel_capacity(value);
}
if let Some(value) = config.outgoing_message_queue_size {
opts.set_inflight(value as u16);
}
if let Some(value) = config.max_message_size {
opts.set_max_packet_size(value, value);
};
Ok(opts)
}
}
/// Identity an [`Agent`] publishes under: its agent id and API version.
#[derive(Clone, Debug)]
pub struct Address {
    /// Id of the agent.
    id: AgentId,
    /// API version; also passed to `SubscriptionTopic::subscription_topic`.
    version: String,
}
impl Address {
pub fn new(id: AgentId, version: &str) -> Self {
Self {
id,
version: version.to_owned(),
}
}
pub fn id(&self) -> &AgentId {
&self.id
}
pub fn version(&self) -> &str {
&self.version
}
}
/// Cloneable handle used to publish messages and manage subscriptions.
///
/// Requests are handed to the rumqttc event loop spawned by `AgentBuilder::start`
/// through a bounded request channel.
#[derive(Clone)]
pub struct Agent {
    /// Identity used when dumping outgoing messages and building topics.
    address: Address,
    /// Sending half of the event loop's request channel.
    tx: Sender<Request>,
    /// Records outgoing (and, in the event loop, incoming) messages; only present
    /// with the `queue-counter` feature.
    #[cfg(feature = "queue-counter")]
    queue_counter: QueueCounterHandle,
}
impl Agent {
#[cfg(feature = "queue-counter")]
fn new(
id: AgentId,
api_version: &str,
tx: Sender<Request>,
queue_counter: QueueCounterHandle,
) -> Self {
Self {
address: Address::new(id, api_version),
tx,
queue_counter,
}
}
#[cfg(not(feature = "queue-counter"))]
fn new(id: AgentId, api_version: &str, tx: Sender<Request>) -> Self {
Self {
address: Address::new(id, api_version),
tx,
}
}
pub fn address(&self) -> &Address {
&self.address
}
pub fn id(&self) -> &AgentId {
self.address.id()
}
pub fn publish<T: serde::Serialize>(
&mut self,
message: OutgoingMessage<T>,
) -> Result<(), Error> {
let dump = Box::new(message).into_dump(&self.address)?;
self.publish_dump(dump)
}
pub fn publish_publishable(
&mut self,
message: Box<dyn IntoPublishableMessage>,
) -> Result<(), Error> {
let dump = message.into_dump(&self.address)?;
self.publish_dump(dump)
}
pub fn publish_dump(&mut self, dump: PublishableMessage) -> Result<(), Error> {
#[cfg(feature = "queue-counter")]
self.queue_counter.add_outgoing_message(&dump);
let dump = match dump {
PublishableMessage::Event(dump) => dump,
PublishableMessage::Request(dump) => dump,
PublishableMessage::Response(dump) => dump,
};
info!(
"Outgoing message = '{}' sending to the topic = '{}'",
dump.payload(),
dump.topic(),
);
let publish = Publish::new(dump.topic(), dump.qos(), dump.payload());
self.tx.try_send(Request::Publish(publish)).map_err(|e| {
if e.is_full() {
error!(
"Rumq Requests channel reached maximum capacity, no space to publish, {:?}",
&e
)
}
Error::new(&format!("error publishing MQTT message, {}", &e))
})
}
pub fn subscribe<S>(
&mut self,
subscription: &S,
qos: QoS,
maybe_group: Option<&SharedGroup>,
) -> Result<(), Error>
where
S: SubscriptionTopic,
{
let topic = self.get_topic(subscription, maybe_group)?;
self.tx
.try_send(Request::Subscribe(Subscribe::new(topic, qos)))
.map_err(|e| Error::new(&format!("error creating MQTT subscription, {}", e)))?;
Ok(())
}
pub fn unsubscribe<S>(
&mut self,
subscription: &S,
maybe_group: Option<&SharedGroup>,
) -> Result<(), Error>
where
S: SubscriptionTopic,
{
let topic = self.get_topic(subscription, maybe_group)?;
self.tx
.try_send(Request::Unsubscribe(Unsubscribe::new(topic)))
.map_err(|e| Error::new(&format!("error creating MQTT subscription, {}", e)))?;
Ok(())
}
fn get_topic<S>(
&self,
subscription: &S,
maybe_group: Option<&SharedGroup>,
) -> Result<String, Error>
where
S: SubscriptionTopic,
{
let mut topic = subscription.subscription_topic(self.id(), self.address.version())?;
if let Some(ref group) = maybe_group {
topic = format!("$share/{group}/{topic}", group = group, topic = topic);
};
Ok(topic)
}
#[cfg(feature = "queue-counter")]
pub fn get_queue_counter(&self) -> QueueCounterHandle {
self.queue_counter.clone()
}
}
/// Connection mode, rendered in lowercase as the part of the MQTT username after `::`.
// NOTE(review): `ConnectionProperties` (de)serializes this type, but no serde derive
// is visible here — presumably implemented elsewhere in the crate; confirm.
#[derive(Debug, Clone)]
pub enum ConnectionMode {
    Default,
    Service,
    Observer,
    Bridge,
}
impl fmt::Display for ConnectionMode {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(
fmt,
"{}",
match self {
ConnectionMode::Default => "default",
ConnectionMode::Service => "service",
ConnectionMode::Observer => "observer",
ConnectionMode::Bridge => "bridge",
}
)
}
}
impl FromStr for ConnectionMode {
type Err = Error;
fn from_str(val: &str) -> Result<Self, Self::Err> {
match val {
"default" => Ok(ConnectionMode::Default),
"service" => Ok(ConnectionMode::Service),
"observer" => Ok(ConnectionMode::Observer),
"bridge" => Ok(ConnectionMode::Bridge),
_ => Err(Error::new(&format!(
"invalid value for the connection mode: {}",
val
))),
}
}
}
/// MQTT connection identity: protocol version, connection mode and agent id.
///
/// Displayed and parsed as `<version>/<mode>/<agent_id>`.
#[derive(Debug, Clone)]
pub struct Connection {
    /// Agent the connection belongs to; also used as the MQTT client id.
    agent_id: AgentId,
    /// Connection version, `"v2"` unless overridden.
    version: String,
    /// Connection mode, [`ConnectionMode::Default`] unless overridden.
    mode: ConnectionMode,
}
impl Connection {
    /// Creates a connection for `agent_id` with version `"v2"` and the default mode.
    fn new(agent_id: AgentId) -> Self {
        let version = "v2".to_string();
        let mode = ConnectionMode::Default;
        Self {
            agent_id,
            version,
            mode,
        }
    }

    /// Replaces the connection version; returns `self` for chaining.
    fn set_version(&mut self, value: &str) -> &mut Self {
        self.version = String::from(value);
        self
    }

    /// Replaces the connection mode; returns `self` for chaining.
    fn set_mode(&mut self, value: ConnectionMode) -> &mut Self {
        self.mode = value;
        self
    }

    /// Returns the agent id.
    pub fn agent_id(&self) -> &AgentId {
        &self.agent_id
    }

    /// Returns the connection version.
    pub fn version(&self) -> &str {
        self.version.as_str()
    }

    /// Returns the connection mode.
    pub fn mode(&self) -> &ConnectionMode {
        &self.mode
    }
}
impl fmt::Display for Connection {
    /// Formats as `<version>/<mode>/<agent_id>`, the shape accepted by `FromStr`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            version,
            mode,
            agent_id,
        } = self;
        write!(fmt, "{}/{}/{}", version, mode, agent_id)
    }
}
impl FromStr for Connection {
type Err = Error;
fn from_str(val: &str) -> Result<Self, Self::Err> {
match val.split('/').collect::<Vec<&str>>().as_slice() {
[version_str, mode_str, agent_id_str] => {
let version = (*version_str).to_string();
let mode = ConnectionMode::from_str(mode_str)?;
let agent_id = AgentId::from_str(agent_id_str)?;
Ok(Self {
version,
mode,
agent_id,
})
}
_ => Err(Error::new(&format!(
"invalid value for connection: {}",
val
))),
}
}
}
/// Serializable form of a [`Connection`].
///
/// The version and mode are (de)serialized under the keys `connection_version` and
/// `connection_mode`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ConnectionProperties {
    agent_id: AgentId,
    #[serde(rename = "connection_version")]
    version: String,
    #[serde(rename = "connection_mode")]
    mode: ConnectionMode,
}
impl ConnectionProperties {
pub(crate) fn to_connection(&self) -> Connection {
let mut connection = Connection::new(self.agent_id.clone());
connection.set_version(&self.version);
connection.set_mode(self.mode.clone());
connection
}
}
/// The account of a connection is the account of its agent.
impl Authenticable for ConnectionProperties {
    fn as_account_id(&self) -> &AccountId {
        self.agent_id.as_account_id()
    }
}

/// A connection is addressed by its agent id.
impl Addressable for ConnectionProperties {
    fn as_agent_id(&self) -> &AgentId {
        &self.agent_id
    }
}

/// Borrowed properties authenticate exactly like owned ones.
impl Authenticable for &ConnectionProperties {
    fn as_account_id(&self) -> &AccountId {
        <ConnectionProperties as Authenticable>::as_account_id(self)
    }
}

/// Borrowed properties are addressed exactly like owned ones.
impl Addressable for &ConnectionProperties {
    fn as_agent_id(&self) -> &AgentId {
        <ConnectionProperties as Addressable>::as_agent_id(self)
    }
}
/// Notification emitted by the background event loop to the receiver returned from
/// `AgentBuilder::start`.
#[derive(Debug)]
// Variant sizes differ widely (unit variants vs. `Message`); boxing the large ones
// would change this public enum's shape, so the lint is silenced instead.
#[allow(clippy::large_enum_variant)]
pub enum AgentNotification {
    /// An incoming publish: the converted envelope, or a description of why parsing or
    /// conversion failed, together with the packet's MQTT metadata.
    Message(Result<IncomingMessage<String>, String>, MessageData),
    /// Sent on the first successful poll after a failed one.
    Reconnection,
    /// Sent every time polling the event loop fails.
    ConnectionError,
    // The remaining variants wrap the corresponding rumqttc `Packet` one-to-one
    // (see `From<Packet> for AgentNotification`).
    Puback(PubAck),
    Pubrec(PubRec),
    Pubcomp(PubComp),
    Suback(SubAck),
    Unsuback(UnsubAck),
    Connect(Connect),
    Connack(ConnAck),
    Pubrel(PubRel),
    Subscribe(Subscribe),
    Unsubscribe(Unsubscribe),
    PingReq,
    PingResp,
    Disconnect,
}
/// MQTT-level metadata copied from an incoming publish packet.
#[derive(Debug, Clone, PartialEq)]
pub struct MessageData {
    /// Duplicate-delivery flag of the packet.
    pub dup: bool,
    /// Quality of service the packet was published with.
    pub qos: QoS,
    /// Retain flag of the packet.
    pub retain: bool,
    /// Topic the packet was published to.
    pub topic: String,
    /// Packet identifier.
    pub pkid: u16,
}
impl From<Packet> for AgentNotification {
    /// Converts an incoming rumqttc packet into a notification.
    ///
    /// A `Publish` payload is parsed as a JSON envelope and converted into a request,
    /// response or event according to its properties. Failures are carried as
    /// `Err(String)` inside [`AgentNotification::Message`] instead of being dropped.
    fn from(packet: Packet) -> Self {
        match packet {
            Packet::Publish(publish) => {
                let parsed = serde_json::from_slice::<compat::IncomingEnvelope>(&publish.payload)
                    .map_err(|err| format!("Failed to parse incoming envelope: {}", err))
                    .and_then(|env| match env.properties() {
                        compat::IncomingEnvelopeProperties::Request(_) => {
                            compat::into_request(env)
                                .map_err(|e| format!("Failed to convert into request: {}", e))
                        }
                        compat::IncomingEnvelopeProperties::Response(_) => {
                            compat::into_response(env)
                                .map_err(|e| format!("Failed to convert into response: {}", e))
                        }
                        compat::IncomingEnvelopeProperties::Event(_) => compat::into_event(env)
                            .map_err(|e| format!("Failed to convert into event: {}", e)),
                    });
                let data = MessageData {
                    dup: publish.dup,
                    qos: publish.qos,
                    retain: publish.retain,
                    topic: publish.topic,
                    pkid: publish.pkid,
                };
                Self::Message(parsed, data)
            }
            Packet::PubAck(ack) => Self::Puback(ack),
            Packet::PubRec(rec) => Self::Pubrec(rec),
            Packet::PubComp(comp) => Self::Pubcomp(comp),
            Packet::SubAck(ack) => Self::Suback(ack),
            Packet::UnsubAck(ack) => Self::Unsuback(ack),
            Packet::Connect(connect) => Self::Connect(connect),
            Packet::ConnAck(ack) => Self::Connack(ack),
            Packet::PubRel(rel) => Self::Pubrel(rel),
            Packet::Subscribe(subscribe) => Self::Subscribe(subscribe),
            Packet::Unsubscribe(unsubscribe) => Self::Unsubscribe(unsubscribe),
            Packet::PingReq => Self::PingReq,
            Packet::PingResp => Self::PingResp,
            Packet::Disconnect => Self::Disconnect,
        }
    }
}