use futures_util::future;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::task::{Context, Poll};
use rustls::{
client::{
danger::{ServerCertVerified, ServerCertVerifier},
WebPkiServerVerifier,
},
crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider},
pki_types::{
pem::{self, PemObject},
CertificateDer, PrivateKeyDer, ServerName, UnixTime,
},
CertificateError, ClientConfig, ClientConnection, Error as TlsError, RootCertStore,
};
use crate::error::Error;
use crate::io::ReadBuf;
use crate::net::tls::util::StdSocket;
use crate::net::tls::TlsConfig;
use crate::net::Socket;
pub struct RustlsSocket<S: Socket> {
inner: StdSocket<S>,
state: ClientConnection,
close_notify_sent: bool,
}
impl<S: Socket> RustlsSocket<S> {
fn poll_complete_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match self.state.complete_io(&mut self.inner) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
futures_util::ready!(self.inner.poll_ready(cx))?;
}
ready => return Poll::Ready(ready.map(|_| ())),
}
}
}
async fn complete_io(&mut self) -> io::Result<()> {
future::poll_fn(|cx| self.poll_complete_io(cx)).await
}
}
impl<S: Socket> Socket for RustlsSocket<S> {
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
self.state.reader().read(buf.init_mut())
}
fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.state.writer().write(buf) {
Ok(0) => Err(io::ErrorKind::WouldBlock.into()),
other => other,
}
}
fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_complete_io(cx)
}
fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_complete_io(cx)
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_complete_io(cx)
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if !self.close_notify_sent {
self.state.send_close_notify();
self.close_notify_sent = true;
}
futures_util::ready!(self.poll_complete_io(cx))?;
let _ = futures_util::ready!(self.inner.socket.poll_shutdown(cx));
Poll::Ready(Ok(()))
}
}
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
where
S: Socket,
{
#[cfg(all(
feature = "_tls-rustls-aws-lc-rs",
not(feature = "_tls-rustls-ring-webpki"),
not(feature = "_tls-rustls-ring-native-roots")
))]
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
#[cfg(any(
feature = "_tls-rustls-ring-webpki",
feature = "_tls-rustls-ring-native-roots"
))]
let provider = Arc::new(rustls::crypto::ring::default_provider());
let config = ClientConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions()
.unwrap();
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_chain = certs_from_pem(cert_path.data().await?)?;
let key_der = private_key_from_pem(key_path.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};
let config = if tls_config.accept_invalid_certs {
if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_no_client_auth()
}
} else {
#[cfg(any(feature = "_tls-rustls-aws-lc-rs", feature = "_tls-rustls-ring-webpki"))]
let mut cert_store = certs_from_webpki();
#[cfg(feature = "_tls-rustls-ring-native-roots")]
let mut cert_store = certs_from_native_store();
if let Some(ca) = tls_config.root_cert_path {
let data = ca.data().await?;
for result in CertificateDer::pem_slice_iter(&data) {
let Ok(cert) = result else {
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
};
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
}
}
if tls_config.accept_invalid_hostnames {
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
.build()
.map_err(|err| Error::Tls(err.into()))?;
if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
}
} else if let Some(user_auth) = user_auth {
config
.with_root_certificates(cert_store)
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_root_certificates(cert_store)
.with_no_client_auth()
}
};
let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
let mut socket = RustlsSocket {
inner: StdSocket::new(socket),
state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?,
close_notify_sent: false,
};
socket.complete_io().await?;
Ok(socket)
}
fn certs_from_pem(pem: Vec<u8>) -> Result<Vec<CertificateDer<'static>>, Error> {
CertificateDer::pem_slice_iter(&pem)
.map(|result| result.map_err(|err| Error::Tls(err.into())))
.collect()
}
fn private_key_from_pem(pem: Vec<u8>) -> Result<PrivateKeyDer<'static>, Error> {
match PrivateKeyDer::from_pem_slice(&pem) {
Ok(key) => Ok(key),
Err(pem::Error::NoItemsFound) => Err(Error::Configuration("no keys found pem file".into())),
Err(e) => Err(Error::Configuration(e.to_string().into())),
}
}
#[cfg(any(feature = "_tls-rustls-aws-lc-rs", feature = "_tls-rustls-ring-webpki"))]
fn certs_from_webpki() -> RootCertStore {
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned())
}
#[cfg(feature = "_tls-rustls-ring-native-roots")]
fn certs_from_native_store() -> RootCertStore {
let mut root_cert_store = RootCertStore::empty();
let load_results = rustls_native_certs::load_native_certs();
for e in load_results.errors {
log::warn!("Error loading native certificates: {e:?}");
}
for cert in load_results.certs {
if let Err(e) = root_cert_store.add(cert.into()) {
log::warn!("rustls failed to parse native certificate: {e:?}");
}
}
root_cert_store
}
#[derive(Debug)]
struct DummyTlsVerifier {
provider: Arc<CryptoProvider>,
}
impl ServerCertVerifier for DummyTlsVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, TlsError> {
verify_tls12_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, TlsError> {
verify_tls13_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.provider
.signature_verification_algorithms
.supported_schemes()
}
}
#[derive(Debug)]
pub struct NoHostnameTlsVerifier {
verifier: Arc<WebPkiServerVerifier>,
}
impl ServerCertVerifier for NoHostnameTlsVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, TlsError> {
match self.verifier.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Err(TlsError::InvalidCertificate(CertificateError::NotValidForName)) => {
Ok(ServerCertVerified::assertion())
}
res => res,
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, TlsError> {
self.verifier.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, TlsError> {
self.verifier.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.verifier.supported_verify_schemes()
}
}