#![cfg_attr(feature = "rustls-tls", allow(dead_code))]
use super::*;
use once_cell::sync::Lazy;
use openssl::{
error::ErrorStack,
pkey::{PKey, Private},
ssl::{self, Ssl},
x509::X509,
};
use std::pin::Pin;
use tokio_openssl::SslStream;
/// Server-side TLS acceptor used by this backend: OpenSSL's `SslAcceptor`.
pub(in crate::server) type TlsAcceptor = ssl::SslAcceptor;
/// Errors produced while accepting a TLS connection with `accept`.
///
/// NOTE(review): each message interpolates the inner error (`{0}`) *and*
/// exposes it via `source()`, so reporters that walk the full error chain
/// will print the inner error twice.
#[derive(Debug, thiserror::Error)]
pub(in crate::server) enum AcceptError {
    /// Creating a per-connection `Ssl` session from the acceptor's context failed.
    #[error("failed to construct SSL session from acceptor context: {0}")]
    Ssl(#[source] ErrorStack),
    /// Wrapping the TCP socket and `Ssl` session into an `SslStream` failed.
    #[error("failed to construct SslStream from SSL session: {0}")]
    Stream(#[source] ErrorStack),
    /// The server-side TLS handshake failed.
    #[error("failed to accept TLS connection: {0}")]
    Accept(#[from] ssl::Error),
}
/// Performs the server side of a TLS handshake over `sock`.
///
/// A new `Ssl` session is created from `acceptor`'s context for this
/// connection. The stream is returned only after the handshake completes.
///
/// # Errors
///
/// * `AcceptError::Ssl` if the session cannot be created,
/// * `AcceptError::Stream` if the `SslStream` wrapper cannot be built,
/// * `AcceptError::Accept` if the handshake itself fails.
pub(in crate::server) async fn accept(
    acceptor: &TlsAcceptor,
    sock: TcpStream,
) -> Result<SslStream<TcpStream>, AcceptError> {
    let session = Ssl::new(acceptor.context()).map_err(AcceptError::Ssl)?;
    let mut tls = SslStream::new(session, sock).map_err(AcceptError::Stream)?;
    Pin::new(&mut tls)
        .accept()
        .await
        .map_err(AcceptError::Accept)?;
    Ok(tls)
}
pub(in crate::server) async fn load_tls(
pk: &TlsKeyPath,
crts: &TlsCertPath,
) -> Result<TlsAcceptor, Error> {
let key = load_private_key(pk).await.map_err(Error::TlsKeyReadError)?;
let certs = load_certs(crts).await.map_err(Error::TlsCertsReadError)?;
configure(key, certs).map_err(|error| Error::InvalidTlsCredentials(Box::new(error)))
}
/// Builds a server-side [`TlsAcceptor`] from a private key and certificate chain.
///
/// `certs` is expected in PEM-file order: the first entry is the leaf
/// certificate and any remaining entries are sent as extra chain certificates.
/// The acceptor uses Mozilla's "intermediate" (v5) profile, does not request
/// client certificates, and advertises the protocols in `ALPN_PROTOCOLS`.
///
/// # Errors
///
/// Returns the OpenSSL `ErrorStack` if any configuration step fails. This
/// includes a key that does not match the leaf certificate, and an empty
/// `certs` (reported by OpenSSL as "no certificate assigned").
fn configure(key: PKey<Private>, certs: Vec<X509>) -> Result<TlsAcceptor, ErrorStack> {
    let mut conn = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls_server())?;
    conn.set_verify(ssl::SslVerifyMode::NONE);

    // Install the certificate *before* the key. If the key were set first,
    // OpenSSL's SSL_CTX_use_certificate would silently drop a mismatched key
    // instead of failing, and the error would only appear at handshake time.
    let mut certs = certs.into_iter();
    if let Some(leaf) = certs.next() {
        conn.set_certificate(&leaf)?;
    }
    // The remaining certificates are owned, so move them into the chain.
    for intermediate in certs {
        conn.add_extra_chain_cert(intermediate)?;
    }

    conn.set_private_key(&key)?;
    // Fails fast on a key/certificate mismatch or a missing certificate
    // (e.g. `certs` was empty) instead of panicking or failing per-handshake.
    conn.check_private_key()?;

    conn.set_alpn_protos(&ALPN_PROTOCOLS)?;
    Ok(conn.build())
}
/// ALPN protocol list in OpenSSL wire format: each protocol name is written
/// as a one-byte length followed by the name's bytes, in preference order
/// (HTTP/2 first, then HTTP/1.1).
static ALPN_PROTOCOLS: Lazy<Vec<u8>> = Lazy::new(|| {
    const PROTOCOLS: [&[u8]; 2] = [b"h2", b"http/1.1"];

    // One length byte plus the name bytes for every entry.
    let encoded_len: usize = PROTOCOLS.iter().map(|name| 1 + name.len()).sum();
    let mut wire = Vec::with_capacity(encoded_len);

    for name in PROTOCOLS.iter().filter(|name| !name.is_empty()) {
        // The length prefix is a single byte, so names must fit in a u8.
        debug_assert!(name.len() <= 255, "ALPN protocols must be less than 256 bytes");
        wire.push(name.len() as u8);
        wire.extend_from_slice(name);
    }
    wire
});
/// Reads every PEM-encoded X.509 certificate from the file at `cp`, in file order.
///
/// # Errors
///
/// Returns an I/O error if the file cannot be read or the PEM data cannot be
/// parsed. Returns `ErrorKind::InvalidData` if the file contains no
/// certificates at all. `X509::stack_from_pem` returns an empty list in that
/// case rather than an error, and the TLS setup needs at least a leaf certificate.
async fn load_certs(TlsCertPath(cp): &TlsCertPath) -> std::io::Result<Vec<X509>> {
    let pem = tokio::fs::read(cp).await?;
    let certs = X509::stack_from_pem(&pem)?;
    if certs.is_empty() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "no certificates found in PEM file",
        ));
    }
    Ok(certs)
}
async fn load_private_key(TlsKeyPath(kp): &TlsKeyPath) -> std::io::Result<PKey<Private>> {
let pem = tokio::fs::read(kp).await?;
Ok(PKey::private_key_from_pem(&pem)?)
}