use std::{
io::Cursor,
{fmt, sync::Arc},
};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{
rustls::{server::WebPkiClientVerifier, ClientConfig, RootCertStore, ServerConfig},
TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
};
use super::io::BoxedIo;
use crate::transport::{
server::{Connected, TlsStream},
Certificate, Identity,
};
/// ALPN protocol identifier for HTTP/2. It is pushed into the `alpn_protocols`
/// of both the client and server configs built in this file, and compared
/// against the negotiated protocol after a client handshake.
const ALPN_H2: &[u8] = b"h2";
/// Errors produced by the TLS helpers in this file.
#[derive(Debug)]
enum TlsError {
/// The handshake completed but the negotiated ALPN protocol was not `h2`.
H2NotNegotiated,
/// A certificate could not be parsed from PEM, or was rejected when being
/// added to a root certificate store.
CertificateParseError,
/// Reading the private key from PEM failed or yielded no key.
PrivateKeyParseError,
}
/// Client-side TLS connector: a shared rustls client configuration together
/// with the server name handed to the handshake in `connect`.
#[derive(Clone)]
pub(crate) struct TlsConnector {
// Shared rustls client configuration; cloned (by `Arc`) for every connection.
config: Arc<ClientConfig>,
// Server name parsed from the `domain` argument of `new`.
domain: Arc<ServerName<'static>>,
}
impl TlsConnector {
    /// Builds a connector from optional trust roots and an optional client identity.
    ///
    /// The root store starts empty and is filled, in order, from the native
    /// certificate store (feature `tls-roots`), the bundled webpki roots
    /// (feature `tls-webpki-roots`) and the PEM-encoded `ca_cert`, if given.
    /// When `identity` is present its certificate chain and key are used for
    /// client authentication. `h2` is advertised via ALPN.
    ///
    /// # Errors
    ///
    /// Fails if loading native certificates, parsing `ca_cert`, loading the
    /// identity, configuring client authentication, or parsing `domain` as a
    /// server name fails.
    pub(crate) fn new(
        ca_cert: Option<Certificate>,
        identity: Option<Identity>,
        domain: &str,
    ) -> Result<Self, crate::Error> {
        let builder = ClientConfig::builder();

        let mut trust_anchors = RootCertStore::empty();
        #[cfg(feature = "tls-roots")]
        trust_anchors.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
        #[cfg(feature = "tls-webpki-roots")]
        trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        if let Some(pem) = ca_cert {
            let mut reader = Cursor::new(pem);
            add_certs_from_pem(&mut reader, &mut trust_anchors)?;
        }

        let builder = builder.with_root_certificates(trust_anchors);
        let mut client_config = if let Some(identity) = identity {
            let (chain, key) = load_identity(identity)?;
            builder.with_client_auth_cert(chain, key)?
        } else {
            builder.with_no_client_auth()
        };
        client_config.alpn_protocols.push(ALPN_H2.to_vec());

        // The server name is validated last, after the config has been assembled.
        let server_name = ServerName::try_from(domain)?.to_owned();
        Ok(Self {
            config: Arc::new(client_config),
            domain: Arc::new(server_name),
        })
    }

    /// Performs the client TLS handshake over `io` and returns the encrypted
    /// stream wrapped in a `BoxedIo`.
    ///
    /// # Errors
    ///
    /// Fails if the handshake fails, or with `TlsError::H2NotNegotiated` when
    /// the peer did not select `h2` via ALPN.
    pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
    where
        I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    {
        let connector = RustlsConnector::from(self.config.clone());
        let server_name = self.domain.as_ref().to_owned();
        let stream = connector.connect(server_name, io).await?;

        // Only the session half of the pair is needed to inspect the ALPN result.
        let h2_negotiated = stream.get_ref().1.alpn_protocol() == Some(ALPN_H2);
        if h2_negotiated {
            Ok(BoxedIo::new(stream))
        } else {
            Err(TlsError::H2NotNegotiated.into())
        }
    }
}
impl fmt::Debug for TlsConnector {
    /// Prints only the type name; no fields are included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TlsConnector")
    }
}
/// Server-side TLS acceptor wrapping a shared rustls server configuration.
#[derive(Clone)]
pub(crate) struct TlsAcceptor {
// Shared rustls server configuration; cloned (by `Arc`) for every accepted connection.
inner: Arc<ServerConfig>,
}
impl TlsAcceptor {
    /// Builds an acceptor that presents `identity` to clients.
    ///
    /// When `client_ca_root` is given, client certificates are verified
    /// against the PEM-encoded roots it contains; `client_auth_optional`
    /// additionally allows clients that present no certificate. Without
    /// `client_ca_root` no client authentication is configured. `h2` is
    /// advertised via ALPN.
    ///
    /// # Errors
    ///
    /// Fails if the client roots cannot be parsed, the verifier cannot be
    /// built, the identity cannot be loaded, or the certificate/key pair is
    /// rejected by the server config builder.
    pub(crate) fn new(
        identity: Identity,
        client_ca_root: Option<Certificate>,
        client_auth_optional: bool,
    ) -> Result<Self, crate::Error> {
        let builder = ServerConfig::builder();
        let builder = if let Some(pem) = client_ca_root {
            let mut client_roots = RootCertStore::empty();
            add_certs_from_pem(&mut Cursor::new(pem), &mut client_roots)?;

            let mut verifier_builder = WebPkiClientVerifier::builder(client_roots.into());
            if client_auth_optional {
                verifier_builder = verifier_builder.allow_unauthenticated();
            }
            builder.with_client_cert_verifier(verifier_builder.build()?)
        } else {
            builder.with_no_client_auth()
        };

        let (cert_chain, private_key) = load_identity(identity)?;
        let mut server_config = builder.with_single_cert(cert_chain, private_key)?;
        server_config.alpn_protocols.push(ALPN_H2.to_vec());

        Ok(Self {
            inner: Arc::new(server_config),
        })
    }

    /// Performs the server TLS handshake over `io` and returns the resulting
    /// TLS stream; any handshake error is converted into `crate::Error`.
    pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
    where
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    {
        let stream = RustlsAcceptor::from(self.inner.clone())
            .accept(io)
            .await?;
        Ok(stream)
    }
}
impl fmt::Debug for TlsAcceptor {
    /// Prints only the type name; no fields are included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TlsAcceptor")
    }
}
impl fmt::Display for TlsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
TlsError::PrivateKeyParseError => write!(
f,
"Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
),
}
}
}
// Marker impl with no overridden methods; it lets `TlsError` be used wherever
// a `std::error::Error` is expected.
impl std::error::Error for TlsError {}
/// Splits an `Identity` into its DER certificate chain and private key.
///
/// Every certificate found in the identity's PEM certificate data is returned,
/// together with the key returned by `rustls_pemfile::private_key` for the
/// identity's PEM key data.
///
/// # Errors
///
/// Returns `TlsError::CertificateParseError` if any certificate fails to
/// parse, and `TlsError::PrivateKeyParseError` if reading the key fails or no
/// key is present.
fn load_identity(
    identity: Identity,
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError> {
    let mut cert_reader = Cursor::new(identity.cert);
    let chain = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_| TlsError::CertificateParseError)?;

    let mut key_reader = Cursor::new(identity.key);
    let key = match rustls_pemfile::private_key(&mut key_reader) {
        Ok(Some(key)) => key,
        // A read failure and a missing key are reported identically.
        Ok(None) | Err(_) => return Err(TlsError::PrivateKeyParseError),
    };

    Ok((chain, key))
}
/// Parses every PEM certificate from `certs` and adds each one to `roots`.
///
/// The whole input is parsed before anything is added, so a PEM parse error
/// leaves `roots` untouched.
///
/// # Errors
///
/// Propagates the reader/parse error from `rustls_pemfile::certs`, or returns
/// `TlsError::CertificateParseError` when the store rejects a certificate.
fn add_certs_from_pem(
    mut certs: &mut dyn std::io::BufRead,
    roots: &mut RootCertStore,
) -> Result<(), crate::Error> {
    let parsed = rustls_pemfile::certs(&mut certs).collect::<Result<Vec<_>, _>>()?;
    parsed.into_iter().try_for_each(|der| {
        roots
            .add(der)
            .map_err(|_| TlsError::CertificateParseError)
    })?;
    Ok(())
}