#[cfg(feature = "tls-server")]
pub(crate) mod rustls_server;
use crate::StreamOps;
use crate::tls::TlsAcceptorSettings;
use crate::traits::{CertifiedConn, TlsConnector, TlsProvider};
use async_trait::async_trait;
use futures::{AsyncRead, AsyncWrite};
use futures_rustls::rustls::{self, crypto::CryptoProvider};
use rustls::client::danger;
use rustls::crypto::{WebPkiSupportedAlgorithms, verify_tls12_signature, verify_tls13_signature};
use rustls::{CertificateError, Error as TLSError};
use rustls_pki_types::{CertificateDer, ServerName};
use tracing::instrument;
use webpki::EndEntityCert;
use std::borrow::Cow;
use std::{
io::{self, Error as IoError, Result as IoResult},
sync::Arc,
};
/// A [`TlsProvider`] that uses `rustls`.
///
/// Cloning is cheap: the only field is a reference-counted client configuration.
// The `cfg_attr` below only takes effect under the `docsrs` cfg, where it
// attaches a feature-requirement annotation to this type's documentation.
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "rustls",
any(feature = "tokio", feature = "async-std", feature = "smol")
)))
)]
#[derive(Clone)]
#[non_exhaustive]
pub struct RustlsProvider {
/// Client configuration handed to every connector built by this provider.
config: Arc<futures_rustls::rustls::ClientConfig>,
}
impl<S> CertifiedConn for futures_rustls::client::TlsStream<S> {
    /// Return the first certificate in the chain the peer presented, borrowed from the
    /// TLS session, or `None` when the session reports no peer certificates.
    fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
        // `get_ref()` yields `(inner stream, TLS session)`; only the session is needed.
        let session = self.get_ref().1;
        let leaf = match session.peer_certificates() {
            Some(chain) => chain.first(),
            None => None,
        };
        Ok(leaf.map(|cert| Cow::from(cert.as_ref())))
    }

    /// Export `len` bytes of keying material from the TLS session for the given
    /// `label` and optional `context`.
    ///
    /// # Errors
    ///
    /// Any failure reported by rustls is wrapped in an [`io::ErrorKind::InvalidData`]
    /// error.
    fn export_keying_material(
        &self,
        len: usize,
        label: &[u8],
        context: Option<&[u8]>,
    ) -> IoResult<Vec<u8>> {
        let session = self.get_ref().1;
        // The buffer must already be `len` bytes long: rustls writes into the
        // buffer it is given and hands it back.
        let output = vec![0_u8; len];
        match session.export_keying_material(output, label, context) {
            Ok(keying_material) => Ok(keying_material),
            Err(e) => Err(IoError::new(io::ErrorKind::InvalidData, e)),
        }
    }

    /// Always returns `Ok(None)`: this implementation never reports a certificate
    /// of our own for a client-side stream.
    fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
        Ok(None)
    }
}
/// Forward every [`StreamOps`] call to the stream underneath the TLS layer.
impl<S> StreamOps for futures_rustls::client::TlsStream<S>
where
    S: StreamOps,
{
    /// Delegate to the inner stream's `set_tcp_notsent_lowat`.
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
        let (inner, _session) = self.get_ref();
        inner.set_tcp_notsent_lowat(notsent_lowat)
    }

    /// Delegate to the inner stream's `new_handle`.
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
        let (inner, _session) = self.get_ref();
        inner.new_handle()
    }
}
/// A [`TlsConnector`] implemented on top of `futures_rustls`.
///
/// `S` is the type of the underlying stream that the TLS session is layered over.
pub struct RustlsConnector<S> {
/// The `futures_rustls` connector that performs the handshake.
connector: futures_rustls::TlsConnector,
/// Zero-sized marker tying this connector to the stream type `S` without storing one.
_phantom: std::marker::PhantomData<fn(S) -> S>,
}
#[async_trait]
impl<S> TlsConnector<S> for RustlsConnector<S>
where
S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
{
type Conn = futures_rustls::client::TlsStream<S>;
#[instrument(skip_all, level = "trace")]
async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn> {
let name: ServerName<'_> = sni_hostname
.try_into()
.map_err(|e| IoError::new(io::ErrorKind::InvalidInput, e))?;
self.connector.connect(name.to_owned(), stream).await
}
}
impl<S> TlsProvider<S> for RustlsProvider
where
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
{
    type Connector = RustlsConnector<S>;

    type TlsStream = futures_rustls::client::TlsStream<S>;

    /// Build a connector that shares this provider's client configuration.
    fn tls_connector(&self) -> Self::Connector {
        // Only the reference count is bumped; the configuration itself is shared.
        let shared_config = Arc::clone(&self.config);
        RustlsConnector {
            connector: futures_rustls::TlsConnector::from(shared_config),
            _phantom: std::marker::PhantomData,
        }
    }

    /// This provider always reports that keying material export is available.
    fn supports_keying_material_export(&self) -> bool {
        true
    }

    cfg_if::cfg_if! {
        if #[cfg(feature = "tls-server")] {
            // With `tls-server` enabled, the server side comes from `rustls_server`.
            type Acceptor = rustls_server::RustlsAcceptor<S>;

            type TlsServerStream = rustls_server::RustlsServerStream<S>;

            fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
                rustls_server::RustlsAcceptor::new(&settings)
            }
        } else {
            // Without `tls-server`, the server-side types are placeholders and
            // asking for an acceptor always fails.
            type Acceptor = crate::tls::UnimplementedTls;

            type TlsServerStream = crate::tls::UnimplementedTls;

            fn tls_acceptor(&self, _settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
                Err(crate::tls::TlsServerUnsupported {}.into())
            }
        }
    }
}
/// Make sure a process-wide default `CryptoProvider` exists.
///
/// If one is already installed this does nothing. Otherwise it logs a warning and
/// installs the `ring` provider as a fallback.
fn ensure_provider_installed() {
    if CryptoProvider::get_default().is_some() {
        return;
    }

    tracing::warn!(
        "Creating a RustlsRuntime, but no CryptoProvider is installed. The application \
         should call CryptoProvider::install_default()"
    );

    let fallback = futures_rustls::rustls::crypto::ring::default_provider();
    // The outcome is deliberately discarded: an error here is not treated as fatal.
    let _idempotent_ignore = CryptoProvider::install_default(fallback);
}
impl RustlsProvider {
    /// Construct a new [`RustlsProvider`].
    ///
    /// The resulting client configuration replaces rustls' standard certificate
    /// verification with [`Verifier`], presents no client certificate, and has
    /// session resumption disabled.
    ///
    /// # Panics
    ///
    /// Panics if no default `CryptoProvider` is available even after
    /// `ensure_provider_installed` has run.
    pub(crate) fn new() -> Self {
        ensure_provider_installed();

        // The verifier checks handshake signatures with the algorithms of
        // whichever provider is installed as the process default.
        let algorithms = CryptoProvider::get_default()
            .expect("CryptoProvider not installed")
            .signature_verification_algorithms;
        let verifier = Arc::new(Verifier(algorithms));

        // `dangerous()` is required because a custom verifier replaces the
        // built-in certificate checks.
        let mut config = futures_rustls::rustls::client::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        // No connection made from this configuration may resume a TLS session.
        config.resumption = futures_rustls::rustls::client::Resumption::disabled();

        let config = Arc::new(config);
        RustlsProvider { config }
    }
}
impl Default for RustlsProvider {
fn default() -> Self {
Self::new()
}
}
/// Certificate verifier installed by [`RustlsProvider`].
///
/// It accepts any server certificate that parses as an end-entity certificate, and
/// checks TLS handshake signatures using the wrapped set of supported algorithms.
#[derive(Clone, Debug)]
struct Verifier(pub(crate) WebPkiSupportedAlgorithms);
impl danger::ServerCertVerifier for Verifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer,
_roots: &[CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<danger::ServerCertVerified, TLSError> {
let _cert: EndEntityCert<'_> = end_entity
.try_into()
.map_err(|_| TLSError::InvalidCertificate(CertificateError::BadEncoding))?;
Ok(danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, TLSError> {
verify_tls12_signature(message, cert, dss, &self.0)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, TLSError> {
verify_tls13_signature(message, cert, dss, &self.0)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.0.supported_schemes()
}
fn root_hint_subjects(&self) -> Option<&[rustls::DistinguishedName]> {
None
}
}
// Unit tests for this module; compiled only for test builds.
#[cfg(test)]
mod test {
// The following clippy lints are relaxed for the code in this test module only.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use super::*;
/// Sample certificate bytes, embedded at compile time from `tor-generated.der`.
// NOTE(review): presumably a DER certificate generated by a Tor implementation — confirm provenance.
const TOR_CERTIFICATE: &[u8] = include_bytes!("./tor-generated.der");
/// The sample certificate must parse as an `EndEntityCert`, which is the same
/// conversion that `Verifier::verify_server_cert` performs.
#[test]
fn basic_tor_cert() {
ensure_provider_installed();
let der = CertificateDer::from_slice(TOR_CERTIFICATE);
// `unwrap` makes the test fail if the certificate is rejected.
let _cert = EndEntityCert::try_from(&der).unwrap();
}
}