use std::fs::{self};
use std::io::BufReader;
use std::sync::Arc;
use tokio::{
sync::broadcast::{self, Receiver, Sender},
time::Duration,
};
use tokio_rustls::rustls::{self, RootCertStore};
use rumqttc::{
self, ConnectionError, Event, Incoming, LastWill, MqttOptions, QoS, TlsConfiguration, Transport,
};
#[cfg(feature = "async")]
use rumqttc::{AsyncClient as RumqttcAsyncClient, EventLoop};
use rumqttc::{ClientError, Request, Sender as MqttSender};
use crate::error::IoTError;
/// Keep-alive interval, in seconds, passed to `MqttOptions::set_keep_alive`.
const KEEP_ALIVE_SECONDS: u64 = 10;
/// Capacity passed as the `cap` argument to `rumqttc::AsyncClient::new`.
const MESSAGE_QUEUE_SIZE: usize = 10;
/// Transport and authentication scheme used to reach the broker.
///
/// A plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionType {
    /// Plain TCP transport authenticated with a username and password.
    TcpCredentials,
    /// TLS transport authenticated with a CA certificate plus a client certificate/key pair.
    TlsCertificates,
}
/// Everything needed to build `MqttOptions` for one broker connection.
///
/// Normally created via `ConnectionSettings::new_tcp` or `ConnectionSettings::new_tls`.
/// Those constructors leave the fields that do not apply to the chosen
/// `ConnectionType` as empty strings.
#[derive(Debug)]
pub struct ConnectionSettings {
/// Selects the transport/authentication scheme (dispatched on in `build_mqtt_options`).
connection_type: ConnectionType,
/// MQTT client identifier.
client_id: String,
/// Broker host name or IP address.
endpoint: String,
/// Broker port. For TLS connections, port 443 also enables the
/// `x-amzn-mqtt-ca` ALPN protocol (see `configure_aws_tls`).
port: u16,
/// Path to a PEM file with the trusted CA certificate(s); TLS only.
ca_path: String,
/// Path to the PEM client certificate chain; TLS only.
client_cert_path: String,
/// Path to the PEM client private key; TLS only.
client_key_path: String,
/// Username for credential authentication; TCP only.
username: String,
/// Password for credential authentication; TCP only.
password: String,
/// Optional last-will message, passed to `MqttOptions::set_last_will` when present.
last_will: Option<LastWill>,
}
impl ConnectionSettings {
    /// Creates settings for a plain-TCP connection that authenticates with
    /// `username` / `password`.
    ///
    /// The certificate paths are left empty; only the TLS builder reads them.
    pub fn new_tcp(
        client_id: String,
        endpoint: String,
        port: u16,
        username: String,
        password: String,
        last_will: Option<LastWill>,
    ) -> ConnectionSettings {
        Self {
            connection_type: ConnectionType::TcpCredentials,
            client_id,
            endpoint,
            port,
            username,
            password,
            last_will,
            // Certificate material is irrelevant for credential-based TCP.
            ca_path: String::default(),
            client_cert_path: String::default(),
            client_key_path: String::default(),
        }
    }

    /// Creates settings for a TLS connection that authenticates with the CA
    /// file at `ca_path` and the client certificate/key at `client_cert_path`
    /// / `client_key_path`.
    ///
    /// The username and password are left empty; the TLS builder never sets credentials.
    pub fn new_tls(
        client_id: String,
        endpoint: String,
        port: u16,
        ca_path: String,
        client_cert_path: String,
        client_key_path: String,
        last_will: Option<LastWill>,
    ) -> ConnectionSettings {
        Self {
            connection_type: ConnectionType::TlsCertificates,
            client_id,
            endpoint,
            port,
            ca_path,
            client_cert_path,
            client_key_path,
            last_will,
            // Credentials are unused when authenticating by certificate.
            username: String::default(),
            password: String::default(),
        }
    }
}
fn tcp_build_mqtt_options(
settings: ConnectionSettings,
) -> Result<MqttOptions, Box<dyn std::error::Error>> {
let mut mqtt_options = MqttOptions::new(settings.client_id, settings.endpoint, settings.port);
mqtt_options.set_keep_alive(Duration::from_secs(KEEP_ALIVE_SECONDS));
mqtt_options.set_credentials(settings.username, settings.password);
mqtt_options.set_transport(Transport::Tcp);
match settings.last_will {
Some(last_will) => {
mqtt_options.set_last_will(last_will);
}
None => (),
}
Ok(mqtt_options)
}
/// Reads the PEM-encoded X.509 certificates in `filename` and returns them as
/// DER `rustls::Certificate`s, in the order `rustls_pemfile::certs` yields them.
///
/// # Panics
/// Panics if the file cannot be opened or parsed. The message names the
/// offending file, because this helper is used for both the CA file and the
/// client certificate chain.
fn load_certs(filename: &str) -> Vec<rustls::Certificate> {
    let certfile = fs::File::open(filename)
        .unwrap_or_else(|e| panic!("cannot open certificate file {:?}: {}", filename, e));
    let mut reader = BufReader::new(certfile);
    rustls_pemfile::certs(&mut reader)
        .unwrap_or_else(|e| panic!("cannot parse certificate file {:?}: {}", filename, e))
        // Take ownership of each DER blob instead of cloning it.
        .into_iter()
        .map(rustls::Certificate)
        .collect()
}
/// Returns the first unencrypted private key found in the PEM file `filename`.
///
/// RSA (PKCS#1), PKCS#8 and SEC1 EC key items are accepted. Any other PEM
/// items, such as certificates bundled in the same file, are skipped.
///
/// # Panics
/// Panics, naming the file, if it cannot be opened, is not valid PEM, or
/// contains no supported key.
fn load_private_key(filename: &str) -> rustls::PrivateKey {
    let keyfile = fs::File::open(filename)
        .unwrap_or_else(|e| panic!("cannot open private key file {:?}: {}", filename, e));
    let mut reader = BufReader::new(keyfile);
    while let Some(item) = rustls_pemfile::read_one(&mut reader)
        .unwrap_or_else(|e| panic!("cannot parse private key .pem file {:?}: {}", filename, e))
    {
        match item {
            // SEC1 EC keys were previously skipped and then reported as "encrypted".
            rustls_pemfile::Item::RSAKey(key)
            | rustls_pemfile::Item::PKCS8Key(key)
            | rustls_pemfile::Item::ECKey(key) => return rustls::PrivateKey(key),
            // Not a key (e.g. a certificate); keep scanning the file.
            _ => {}
        }
    }
    panic!(
        "no keys found in {:?} (encrypted keys not supported)",
        filename
    );
}
/// Builds a rustls client configuration for mutual TLS against AWS IoT.
///
/// Only the CA certificates read from `settings.ca_path` are trusted. The
/// client presents the chain from `settings.client_cert_path` and the key from
/// `settings.client_key_path`. On port 443 the `x-amzn-mqtt-ca` ALPN protocol
/// is advertised, which AWS IoT uses to route MQTT traffic on that port.
///
/// # Panics
/// Panics if any file cannot be read or parsed, if a CA certificate is
/// rejected by the root store, or if the client certificate and key do not
/// form a usable identity.
fn configure_aws_tls(settings: &ConnectionSettings) -> Arc<rustls::ClientConfig> {
    let mut root_store = RootCertStore::empty();
    for ca_cert in &load_certs(&settings.ca_path) {
        root_store.add(ca_cert).expect("cannot add root certificate");
    }

    let client_chain = load_certs(&settings.client_cert_path);
    let client_key = load_private_key(&settings.client_key_path);
    let mut config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_store)
        .with_single_cert(client_chain, client_key)
        // Name both files: a mismatched cert/key pair is the usual cause.
        .unwrap_or_else(|e| {
            panic!(
                "invalid client certificate/key pair ({:?}, {:?}): {}",
                settings.client_cert_path, settings.client_key_path, e
            )
        });

    if settings.port == 443 {
        config.alpn_protocols.push(b"x-amzn-mqtt-ca".to_vec());
    }
    Arc::new(config)
}
fn tls_build_mqtt_options(
settings: ConnectionSettings,
) -> Result<MqttOptions, Box<dyn std::error::Error>> {
let mut mqtt_options = MqttOptions::new(&settings.client_id, &settings.endpoint, settings.port);
let transport = Transport::Tls(TlsConfiguration::Rustls(configure_aws_tls(&settings)));
mqtt_options
.set_transport(transport)
.set_keep_alive(Duration::from_secs(10));
match settings.last_will {
Some(last_will) => {
mqtt_options.set_last_will(last_will);
}
None => (),
}
Ok(mqtt_options)
}
fn build_mqtt_options(
settings: ConnectionSettings,
) -> Result<MqttOptions, Box<dyn std::error::Error>> {
match settings.connection_type {
ConnectionType::TcpCredentials => return tcp_build_mqtt_options(settings),
ConnectionType::TlsCertificates => return tls_build_mqtt_options(settings),
}
}
/// Polls the rumqttc event loop forever and re-broadcasts every incoming packet.
///
/// Each `Event::Incoming` is sent on `incoming_event_broadcaster`; outgoing
/// events are ignored. Having no live receivers is not an error. This is the
/// case before anyone has called `AsyncClient::get_receiver`, or after every
/// receiver has been dropped. The packet is then simply discarded.
///
/// # Errors
/// Returns the first `ConnectionError` reported by `EventLoop::poll`; it never
/// returns `Ok`.
pub async fn eventloop_monitor(
    (mut eventloop, incoming_event_broadcaster): (EventLoop, Sender<Incoming>),
) -> Result<(), ConnectionError> {
    loop {
        if let Event::Incoming(packet) = eventloop.poll().await? {
            // `broadcast::Sender::send` fails only when there are zero receivers.
            // `AsyncClient::new` drops the initial receiver, so unwrapping here
            // would panic on the very first packet (ConnAck) unless someone had
            // already subscribed.
            let _ = incoming_event_broadcaster.send(packet);
        }
    }
}
/// Thin wrapper around rumqttc's `AsyncClient`.
///
/// It also holds the sending half of the tokio broadcast channel on which
/// `eventloop_monitor` re-publishes incoming packets.
pub struct AsyncClient {
/// Underlying rumqttc client used for publish/subscribe/unsubscribe/disconnect requests.
client: RumqttcAsyncClient,
/// Request sender obtained from `EventLoop::handle()` at construction.
eventloop_handle: MqttSender<Request>,
/// Broadcast sender; every `get_receiver` call subscribes a new receiver to it.
incoming_event_broadcaster: Sender<Incoming>,
}
impl AsyncClient {
    /// Capacity of the broadcast channel that carries incoming packets to receivers.
    const INCOMING_BROADCAST_CAPACITY: usize = 16;

    /// Creates a client for `settings`. No network I/O happens here.
    ///
    /// Returns the client plus an `(EventLoop, Sender<Incoming>)` pair shaped
    /// exactly as `eventloop_monitor` expects. rumqttc only talks to the broker
    /// while that `EventLoop` is being polled, so the caller must drive it.
    ///
    /// # Errors
    /// Currently never returns `Err`.
    ///
    /// # Panics
    /// Panics if the MQTT options cannot be built. For TLS settings, that
    /// includes unreadable or invalid certificate/key files.
    pub async fn new(
        settings: ConnectionSettings,
    ) -> Result<(AsyncClient, (EventLoop, Sender<Incoming>)), IoTError> {
        // NOTE(review): the build error is not converted into `IoTError` because no
        // suitable variant is visible from this module. Propagate it instead of
        // panicking once such a variant exists.
        let mqtt_options = build_mqtt_options(settings).expect("failed to build MQTT options");
        let (client, eventloop) = RumqttcAsyncClient::new(mqtt_options, MESSAGE_QUEUE_SIZE);
        // The initial receiver is discarded; receivers are created later via `get_receiver`.
        let (event_broadcaster, _) = broadcast::channel(Self::INCOMING_BROADCAST_CAPACITY);
        let eventloop_handle = eventloop.handle();
        let async_client = AsyncClient {
            client,
            eventloop_handle,
            incoming_event_broadcaster: event_broadcaster.clone(),
        };
        Ok((async_client, (eventloop, event_broadcaster)))
    }

    /// Publishes `payload` to `topic` at `qos` with the retain flag set to `false`.
    ///
    /// # Errors
    /// Propagates any `ClientError` from rumqttc's `publish`.
    pub async fn publish<S, V>(&self, topic: S, qos: QoS, payload: V) -> Result<(), ClientError>
    where
        S: Into<String>,
        V: Into<Vec<u8>>,
    {
        self.client.publish(topic, qos, false, payload).await?;
        Ok(())
    }

    /// Subscribes to `topic` at `qos`.
    ///
    /// # Errors
    /// Propagates any `ClientError` from rumqttc's `subscribe`.
    pub async fn subscribe<S: Into<String>>(&self, topic: S, qos: QoS) -> Result<(), ClientError> {
        self.client.subscribe(topic, qos).await?;
        Ok(())
    }

    /// Unsubscribes from `topic`.
    ///
    /// # Errors
    /// Propagates any `ClientError` from rumqttc's `unsubscribe`.
    pub async fn unsubscribe<S: Into<String>>(&self, topic: S) -> Result<(), ClientError> {
        self.client.unsubscribe(topic).await?;
        Ok(())
    }

    /// Returns a new receiver for incoming packets.
    ///
    /// It only observes packets broadcast after this call.
    pub async fn get_receiver(&self) -> Receiver<Incoming> {
        self.incoming_event_broadcaster.subscribe()
    }

    /// Consumes the wrapper and returns the underlying rumqttc client.
    pub async fn get_client(self) -> RumqttcAsyncClient {
        self.client
    }

    /// Returns a clone of the request sender obtained from `EventLoop::handle()`.
    pub fn get_eventloop_handle(&self) -> MqttSender<Request> {
        self.eventloop_handle.clone()
    }

    /// Forwards to rumqttc's `AsyncClient::cancel`.
    ///
    /// # Errors
    /// Propagates any `ClientError` from rumqttc.
    pub async fn cancel(&self) -> Result<(), ClientError> {
        self.client.cancel().await?;
        Ok(())
    }

    /// Forwards to rumqttc's `AsyncClient::disconnect`.
    ///
    /// # Errors
    /// Propagates any `ClientError` from rumqttc.
    pub async fn disconnect(&self) -> Result<(), ClientError> {
        self.client.disconnect().await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::CONFIG_DIRNAME;
    use find_folder::Search;
    use std::env;

    /// Locates the project's configuration directory using find_folder's
    /// `ParentsThenKids(5, 5)` search, starting from the test executable's folder.
    ///
    /// Panics if the directory is not found or its path is not valid UTF-8.
    fn find_config_dir() -> String {
        let mut exe_folder = env::current_exe().unwrap();
        println!("EXE_FOLDER: {:#?}", exe_folder);
        // Drop the executable's file name so the search starts at its directory.
        exe_folder.pop();
        let config_dir: std::path::PathBuf = Search::ParentsThenKids(5, 5)
            .of(exe_folder)
            .for_folder(CONFIG_DIRNAME)
            .expect("Config directory not found");
        config_dir.into_os_string().into_string().unwrap()
    }

    /// Values the builders are expected to put into `MqttOptions`.
    struct ExpectedMqttOptions {
        broker_addr: (String, u16),
        client_id: String,
        credentials: Option<(String, String)>,
        last_will: Option<LastWill>,
    }

    /// Needs real PEM files under `<config dir>/certs/`, because the TLS
    /// builder loads them eagerly.
    #[test]
    fn tls_build_mqtt_options_test() {
        let config_dir: String = find_config_dir();
        let settings = ConnectionSettings {
            connection_type: ConnectionType::TlsCertificates,
            client_id: "16A8_99998".to_string(),
            endpoint: "ENDPOINTID-ats.iot.eu-central-1.amazonaws.com".to_string(),
            port: 8883,
            ca_path: format!("{}{}", config_dir, "/certs/AmazonRootCA1.pem"),
            client_cert_path: format!("{}{}", config_dir, "/certs/IotCertificate.pem"),
            client_key_path: format!("{}{}", config_dir, "/certs/IotPrivateKey.pem"),
            username: "".to_string(),
            password: "".to_string(),
            last_will: None,
        };
        let expected_mqtt_options = ExpectedMqttOptions {
            broker_addr: (
                "ENDPOINTID-ats.iot.eu-central-1.amazonaws.com".to_string(),
                8883,
            ),
            client_id: "16A8_99998".to_string(),
            credentials: None,
            last_will: None,
        };

        let returned_mqtt_options = tls_build_mqtt_options(settings).unwrap();

        assert_eq!(
            returned_mqtt_options.broker_address(),
            expected_mqtt_options.broker_addr
        );
        assert_eq!(
            returned_mqtt_options.client_id(),
            expected_mqtt_options.client_id
        );
        assert_eq!(
            returned_mqtt_options.credentials(),
            expected_mqtt_options.credentials
        );
        assert_eq!(
            returned_mqtt_options.last_will(),
            expected_mqtt_options.last_will
        );
        assert_eq!(
            returned_mqtt_options.keep_alive(),
            Duration::from_secs(KEEP_ALIVE_SECONDS)
        );
    }

    #[test]
    fn tcp_build_mqtt_options_test() {
        let settings = ConnectionSettings {
            connection_type: ConnectionType::TcpCredentials,
            client_id: "adapterClient".to_string(),
            endpoint: "127.0.0.1".to_string(),
            port: 1883,
            ca_path: "".to_string(),
            client_cert_path: "".to_string(),
            client_key_path: "".to_string(),
            username: "username".to_string(),
            password: "password".to_string(),
            last_will: None,
        };
        let expected_mqtt_options = ExpectedMqttOptions {
            broker_addr: ("127.0.0.1".to_string(), 1883),
            client_id: "adapterClient".to_string(),
            credentials: Some(("username".to_string(), "password".to_string())),
            last_will: None,
        };

        let returned_mqtt_options = tcp_build_mqtt_options(settings).unwrap();

        assert_eq!(
            returned_mqtt_options.broker_address(),
            expected_mqtt_options.broker_addr
        );
        assert_eq!(
            returned_mqtt_options.client_id(),
            expected_mqtt_options.client_id
        );
        assert_eq!(
            returned_mqtt_options.credentials(),
            expected_mqtt_options.credentials
        );
        assert_eq!(
            returned_mqtt_options.last_will(),
            expected_mqtt_options.last_will
        );
        assert_eq!(
            returned_mqtt_options.keep_alive(),
            Duration::from_secs(KEEP_ALIVE_SECONDS)
        );
    }
}