use std::{
convert::TryFrom,
fs::File,
io::{BufReader, Seek},
sync::Arc,
};
#[cfg(feature = "rustls-tls-aws-lc")]
use rustls::crypto::aws_lc_rs as provider;
#[cfg(all(feature = "rustls-tls", not(feature = "rustls-tls-aws-lc")))]
use rustls::crypto::ring as provider;
use rustls::{
client::ClientConfig,
pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer, ServerName},
Error as TlsError,
RootCertStore,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
use webpki_roots::TLS_SERVER_ROOTS;
use crate::{
client::options::TlsOptions,
error::{ErrorKind, Result},
};
pub(super) type TlsStream<T> = tokio_rustls::client::TlsStream<T>;
#[derive(Clone)]
pub(crate) struct TlsConfig {
connector: TlsConnector,
}
impl TlsConfig {
pub(crate) fn new(options: TlsOptions) -> Result<TlsConfig> {
let mut tls_config = make_rustls_config(options)?;
tls_config.enable_sni = true;
let connector: TlsConnector = Arc::new(tls_config).into();
Ok(TlsConfig { connector })
}
}
pub(super) async fn tls_connect<T: AsyncRead + AsyncWrite + Unpin>(
host: &str,
tcp_stream: T,
cfg: &TlsConfig,
) -> Result<TlsStream<T>> {
let name = ServerName::try_from(host)
.map_err(|e| ErrorKind::DnsResolve {
message: format!("could not resolve {host:?}: {e}"),
})?
.to_owned();
let conn = cfg
.connector
.connect_with(name, tcp_stream, |c| {
c.set_buffer_limit(None);
})
.await?;
Ok(conn)
}
fn make_rustls_config(cfg: TlsOptions) -> Result<rustls::ClientConfig> {
let mut store = RootCertStore::empty();
if let Some(path) = cfg.ca_file_path {
let ders = CertificateDer::pem_file_iter(&path)
.map_err(|err| ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded root certificate from {}: {err}",
path.display()
),
})?
.flatten();
store.add_parsable_certificates(ders);
} else {
store.extend(TLS_SERVER_ROOTS.iter().cloned());
}
let config_builder = ClientConfig::builder_with_provider(provider::default_provider().into())
.with_safe_default_protocol_versions()
.map_err(|e| ErrorKind::InvalidTlsConfig {
message: format!("built-in provider should support default protocol versions: {e}"),
})?
.with_root_certificates(store);
let mut config = if let Some(path) = cfg.cert_key_file_path {
let mut file = BufReader::new(File::open(&path)?);
let mut certs = vec![];
for cert in CertificateDer::pem_reader_iter(&mut file) {
let cert = cert.map_err(|error| ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded client certificate from {}: {error}",
path.display(),
),
})?;
certs.push(cert);
}
file.rewind()?;
let key = 'key: {
#[cfg(feature = "cert-key-password")]
if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() {
use rustls::pki_types::PrivatePkcs8KeyDer;
use std::io::Read;
let mut contents = vec![];
file.read_to_end(&mut contents)?;
break 'key PrivatePkcs8KeyDer::from(super::pem::decrypt_private_key(
&contents, key_pw,
)?)
.into();
}
match PrivateKeyDer::from_pem_reader(&mut file) {
Ok(key) => break 'key key,
Err(err) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded item from {}: {err}",
path.display(),
),
}
.into())
}
}
};
config_builder
.with_client_auth_cert(certs, key)
.map_err(|error| ErrorKind::InvalidTlsConfig {
message: error.to_string(),
})?
} else {
config_builder.with_no_client_auth()
};
if let Some(true) = cfg.allow_invalid_certificates {
config .dangerous()
.set_certificate_verifier(Arc::new(danger::NoCertVerifier(
provider::default_provider(),
)));
}
Ok(config)
}
mod danger {
use super::*;
use rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider},
pki_types::UnixTime,
DigitallySignedStruct,
SignatureScheme,
};
#[derive(Debug)]
pub(super) struct NoCertVerifier(pub(super) CryptoProvider);
impl ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp: &[u8],
_now: UnixTime,
) -> std::result::Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, TlsError> {
verify_tls12_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, TlsError> {
verify_tls13_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}
}