use std::pin::Pin;
use openssl::{
error::ErrorStack,
ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_openssl::SslStream;
use crate::{
client::options::TlsOptions,
error::{Error, ErrorKind, Result},
};
/// The TLS stream type returned by `tls_connect`: an OpenSSL session layered over the
/// transport `T`.
pub(super) type TlsStream<T> = SslStream<T>;
/// TLS settings built from [`TlsOptions`] by `TlsConfig::new`.
///
/// `tls_connect` derives a fresh per-connection `Ssl` from this value for each stream it opens.
#[derive(Clone)]
pub(crate) struct TlsConfig {
    /// OpenSSL connector carrying the trust store and, if configured, the client certificate/key.
    connector: SslConnector,
    /// Whether the server certificate's hostname is checked against the host being dialed;
    /// `false` only when `allow_invalid_hostnames` was `Some(true)`.
    verify_hostname: bool,
}
impl TlsConfig {
    /// Resolves user-facing [`TlsOptions`] into a [`TlsConfig`].
    ///
    /// Hostname verification stays enabled unless `allow_invalid_hostnames` is `Some(true)`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while building the OpenSSL connector.
    pub(crate) fn new(options: TlsOptions) -> Result<TlsConfig> {
        // The flag is read (copied) first because `options` is moved into the connector builder.
        let skip_hostname_check = options.allow_invalid_hostnames.unwrap_or(false);
        Ok(TlsConfig {
            verify_hostname: !skip_hostname_check,
            connector: make_openssl_connector(options)?,
        })
    }
}
pub(super) async fn tls_connect<T: AsyncRead + AsyncWrite + Unpin>(
host: &str,
tcp_stream: T,
cfg: &TlsConfig,
) -> Result<TlsStream<T>> {
let mut stream = make_ssl_stream(host, tcp_stream, cfg).map_err(|err| {
Error::from(ErrorKind::InvalidTlsConfig {
message: err.to_string(),
})
})?;
Pin::new(&mut stream).connect().await.map_err(|err| {
use std::io;
match err.into_io_error() {
Ok(err) => err,
Err(err) => io::Error::other(err),
}
})?;
Ok(stream)
}
/// Builds the OpenSSL connector described by `cfg`.
///
/// The connector trusts OpenSSL's default verify paths and any certificate locations found by
/// `openssl_probe`. It also trusts the CA file in `ca_file_path` when one is given.
/// When `cert_key_file_path` is set, the PEM file at that path supplies both the client
/// certificate and its private key. With the `cert-key-password` feature enabled and a password
/// provided, the key is decrypted with `tls_certificate_key_file_password`.
/// Setting `allow_invalid_certificates` to `Some(true)` disables peer verification.
///
/// # Errors
///
/// Returns an `InvalidTlsConfig` error when OpenSSL rejects any part of the configuration.
/// Propagates the I/O error if a password-protected key file cannot be read.
fn make_openssl_connector(cfg: TlsOptions) -> Result<SslConnector> {
    let openssl_err = |e: ErrorStack| {
        Error::from(ErrorKind::InvalidTlsConfig {
            message: e.to_string(),
        })
    };

    let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?;

    // `SslConnector::builder` already installs OpenSSL's default verify paths, so the probe only
    // adds extra locations. Skip the call when the probe found nothing. OpenSSL rejects a load
    // with neither a file nor a directory without pushing an error, which would otherwise
    // surface as an `InvalidTlsConfig` error with an empty message.
    let probe = openssl_probe::probe();
    if probe.cert_file.is_some() || probe.cert_dir.is_some() {
        builder
            .load_verify_locations(probe.cert_file.as_deref(), probe.cert_dir.as_deref())
            .map_err(openssl_err)?;
    }

    let TlsOptions {
        allow_invalid_certificates,
        ca_file_path,
        cert_key_file_path,
        allow_invalid_hostnames: _,
        #[cfg(feature = "cert-key-password")]
        tls_certificate_key_file_password,
    } = cfg;

    if allow_invalid_certificates == Some(true) {
        builder.set_verify(SslVerifyMode::NONE);
    }

    if let Some(path) = ca_file_path {
        builder.set_ca_file(path).map_err(openssl_err)?;
    }

    if let Some(path) = cert_key_file_path {
        // The same PEM file provides both the client certificate and its private key.
        builder
            .set_certificate_file(&path, SslFiletype::PEM)
            .map_err(openssl_err)?;
        // A closure lets the feature-gated branch return early while sharing the
        // unencrypted-key fallback between builds with and without `cert-key-password`.
        let handle_private_key = || -> Result<()> {
            #[cfg(feature = "cert-key-password")]
            if let Some(key_pw) = tls_certificate_key_file_password {
                let contents = std::fs::read(&path)?;
                let key = openssl::pkey::PKey::private_key_from_pem_passphrase(&contents, &key_pw)
                    .map_err(openssl_err)?;
                builder.set_private_key(&key).map_err(openssl_err)?;
                return Ok(());
            }
            builder
                .set_private_key_file(path, SslFiletype::PEM)
                .map_err(openssl_err)
        };
        handle_private_key()?;
    }

    Ok(builder.build())
}
/// Wraps `tcp_stream` in an `SslStream` configured for `host`; no handshake is performed here.
///
/// SNI is always enabled, and hostname verification follows `cfg.verify_hostname`.
fn make_ssl_stream<T: AsyncRead + AsyncWrite>(
    host: &str,
    tcp_stream: T,
    cfg: &TlsConfig,
) -> std::result::Result<SslStream<T>, ErrorStack> {
    let mut per_conn = cfg.connector.configure()?;
    per_conn.set_use_server_name_indication(true);
    per_conn.set_verify_hostname(cfg.verify_hostname);
    // `host` serves as the SNI name and, when verification is on, the name checked on the cert.
    let ssl = per_conn.into_ssl(host)?;
    SslStream::new(ssl, tcp_stream)
}