use tokio_rustls::{client::TlsStream, TlsConnector};
use tokio_rustls::rustls::{ClientConfig, TLSError};
use tokio_rustls::webpki::{self, DNSNameRef, InvalidDNSNameError};
use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio::net::TcpStream;
use crate::MqttOptions;
use std::net::AddrParseError;
use std::io;
use std::sync::Arc;
use std::io::{Cursor, BufReader};
/// Errors that can occur while establishing a TCP or TLS connection to the broker.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A socket address could not be parsed.
    #[error("Addr")]
    Addr(#[from] AddrParseError),
    /// An I/O failure, e.g. while opening the TCP socket or during the TLS handshake.
    #[error("I/O")]
    Io(#[from] io::Error),
    /// An error reported by `webpki`.
    #[error("Web Pki")]
    WebPki(#[from] webpki::Error),
    /// The broker address is not acceptable as a DNS name for server-name verification.
    #[error("DNS name")]
    DNSName(#[from] InvalidDNSNameError),
    /// An error reported by rustls, e.g. when installing the client certificate and key.
    #[error("TLS error")]
    TLS(#[from] TLSError),
    /// No usable certificate was found in the supplied PEM data.
    #[error("No valid cert in chain")]
    NoValidCertInChain,
}
impl From<()> for Error {
fn from(_: ()) -> Self {
Error::NoValidCertInChain
}
}
pub async fn tls_connect(options: &MqttOptions) -> Result<TlsStream<TcpStream>, Error> {
let addr = format!("{}:{}", options.broker_addr, options.port);
let tcp = TcpStream::connect(addr).await?;
let mut config = ClientConfig::new();
let ca = options.ca.as_ref().unwrap();
if config.root_store.add_pem_file(&mut BufReader::new(Cursor::new(ca)))?.0 == 0 {
return Err(Error::NoValidCertInChain)
}
if let Some(client) = options.client_auth.as_ref() {
let certs = certs(&mut BufReader::new(Cursor::new(client.0.clone())))?;
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(client.1.clone())))?;
config.set_single_client_cert(certs, keys.remove(0))?;
}
if let Some(alpn) = options.alpn.as_ref() {
config.set_protocols(&alpn);
}
let connector = TlsConnector::from(Arc::new(config));
let domain = DNSNameRef::try_from_ascii_str(&options.broker_addr)?;
let tls = connector.connect(domain, tcp).await?;
Ok(tls)
}
pub async fn tcp_connect(options: &MqttOptions) -> Result<TcpStream, Error> {
let addr = format!("{}:{}", options.broker_addr, options.port);
let tcp = TcpStream::connect(addr).await?;
Ok(tcp)
}