use async_std::net::TcpStream;
use rustls::internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls::{ClientConfig, TLSError};
use webpki::{self, InvalidDNSNameError};
use async_tls::{client::TlsStream, TlsConnector};
use crate::{Key, MqttOptions, TlsConfiguration};
use std::io;
use std::io::{BufReader, Cursor};
use std::net::AddrParseError;
use std::sync::Arc;
/// Errors produced while building a TLS connector or while connecting to the
/// broker over TLS.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A network address could not be parsed.
#[error("Addr")]
Addr(#[from] AddrParseError),
/// An I/O failure, e.g. while opening the TCP connection to the broker.
#[error("I/O")]
Io(#[from] io::Error),
/// An error reported by the `webpki` crate.
#[error("Web Pki")]
WebPki(#[from] webpki::Error),
/// A name was rejected as an invalid DNS name by `webpki`.
#[error("DNS name")]
DNSName(#[from] InvalidDNSNameError),
/// An error reported by rustls, e.g. when it rejects the configured client
/// certificate / private key pair.
#[error("TLS error")]
TLS(#[from] TLSError),
/// No usable certificate (or private key) could be read from the supplied PEM
/// data. Also used for every `()` error converted through `From<()>`.
#[error("No valid cert in chain")]
NoValidCertInChain,
}
impl From<()> for Error {
fn from(_: ()) -> Self {
Error::NoValidCertInChain
}
}
pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result<TlsConnector, Error> {
let config = match tls_config {
TlsConfiguration::Simple {
ca,
alpn,
client_auth,
} => {
let mut config = ClientConfig::new();
if config
.root_store
.add_pem_file(&mut BufReader::new(Cursor::new(ca)))?
.0
== 0
{
return Err(Error::NoValidCertInChain);
}
if let Some(client) = client_auth.as_ref() {
let certs = certs(&mut BufReader::new(Cursor::new(client.0.clone())))?;
let read_keys = match &client.1 {
Key::RSA(k) => rsa_private_keys(&mut BufReader::new(Cursor::new(k.clone()))),
Key::ECC(k) => pkcs8_private_keys(&mut BufReader::new(Cursor::new(k.clone()))),
};
let keys = match read_keys {
Ok(v) => v,
Err(_e) => return Err(Error::NoValidCertInChain),
};
let key = match keys.first() {
Some(k) => k.clone(),
None => return Err(Error::NoValidCertInChain),
};
config.set_single_client_cert(certs, key)?;
}
if let Some(alpn) = alpn.as_ref() {
config.set_protocols(&alpn);
}
Arc::new(config)
}
TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(),
};
Ok(TlsConnector::from(config))
}
pub async fn tls_connect(
options: &MqttOptions,
tls_config: &TlsConfiguration,
) -> Result<TlsStream<TcpStream>, Error> {
let addr = options.broker_addr.as_str();
let port = options.port;
let connector = tls_connector(tls_config).await?;
let domain = &options.broker_addr;
let tcp = TcpStream::connect((addr, port)).await?;
let tls = connector.connect(domain, tcp).await?;
Ok(tls)
}