use self::future::RustlsAcceptorFuture;
use crate::{
accept::{Accept, DefaultAcceptor},
server::{io_other, Server},
Address,
};
use arc_swap::ArcSwap;
use rustls::ServerConfig;
use rustls_pki_types::pem::PemObject;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::time::Duration;
use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite},
task::spawn_blocking,
};
use tokio_rustls::server::TlsStream;
pub(crate) mod export {
#[allow(clippy::wildcard_imports)]
use super::*;
#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
pub fn bind_rustls<A: Address>(addr: A, config: RustlsConfig) -> Server<A, RustlsAcceptor> {
super::bind_rustls(addr, config)
}
#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
pub fn from_tcp_rustls(
listener: std::net::TcpListener,
config: RustlsConfig,
) -> io::Result<Server<SocketAddr, RustlsAcceptor>> {
let acceptor = RustlsAcceptor::new(config);
Ok(crate::from_tcp(listener)?.acceptor(acceptor))
}
#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
#[cfg(unix)]
pub fn from_unix_rustls(
listener: std::os::unix::net::UnixListener,
config: RustlsConfig,
) -> io::Result<Server<std::os::unix::net::SocketAddr, RustlsAcceptor>> {
let acceptor = RustlsAcceptor::new(config);
Ok(crate::from_unix(listener)?.acceptor(acceptor))
}
}
pub mod future;
/// Builds a [`Server`] for `addr` via `Server::bind` and installs a
/// [`RustlsAcceptor`] created from `config` as its acceptor.
pub fn bind_rustls<A: Address>(addr: A, config: RustlsConfig) -> Server<A, RustlsAcceptor> {
    Server::bind(addr).acceptor(RustlsAcceptor::new(config))
}
/// Builds a [`Server`] from an existing standard-library TCP listener and
/// installs a [`RustlsAcceptor`] created from `config` as its acceptor.
///
/// # Errors
///
/// Propagates any error returned by `crate::from_tcp`.
pub fn from_tcp_rustls(
    listener: std::net::TcpListener,
    config: RustlsConfig,
) -> io::Result<Server<SocketAddr, RustlsAcceptor>> {
    let server = crate::from_tcp(listener)?;
    Ok(server.acceptor(RustlsAcceptor::new(config)))
}
/// Builds a [`Server`] from an existing standard-library Unix listener and
/// installs a [`RustlsAcceptor`] created from `config` as its acceptor.
///
/// Only compiled on Unix targets.
///
/// # Errors
///
/// Propagates any error returned by `crate::from_unix`.
#[cfg(unix)]
pub fn from_unix_rustls(
    listener: std::os::unix::net::UnixListener,
    config: RustlsConfig,
) -> io::Result<Server<std::os::unix::net::SocketAddr, RustlsAcceptor>> {
    let server = crate::from_unix(listener)?;
    Ok(server.acceptor(RustlsAcceptor::new(config)))
}
/// Acceptor that wraps the stream produced by an inner acceptor `A` in a
/// `tokio_rustls` server-side `TlsStream`.
///
/// `A` defaults to `DefaultAcceptor`.
#[derive(Clone)]
pub struct RustlsAcceptor<A = DefaultAcceptor> {
// Inner acceptor; its `accept` future is handed to the TLS accept future.
inner: A,
// Shared TLS configuration handle; cloned for every accepted connection.
config: RustlsConfig,
// Timeout forwarded to the TLS accept future. `new` sets it to 10 seconds
// (1 second in test builds).
handshake_timeout: Duration,
}
impl RustlsAcceptor {
    /// Creates an acceptor around a fresh `DefaultAcceptor` using `config`.
    ///
    /// The handshake timeout starts at 10 seconds, or 1 second when the crate
    /// is compiled for its own tests.
    pub fn new(config: RustlsConfig) -> Self {
        // `cfg!(test)` is a compile-time constant, so this selects the same
        // value a pair of `#[cfg]`-gated bindings would.
        let timeout_secs = if cfg!(test) { 1 } else { 10 };
        Self {
            inner: DefaultAcceptor::new(),
            config,
            handshake_timeout: Duration::from_secs(timeout_secs),
        }
    }

    /// Returns this acceptor with its handshake timeout replaced by `val`.
    pub fn handshake_timeout(self, val: Duration) -> Self {
        Self {
            handshake_timeout: val,
            ..self
        }
    }
}
impl<A> RustlsAcceptor<A> {
    /// Replaces the inner acceptor with `acceptor`, keeping the TLS config and
    /// the handshake timeout. The previous inner acceptor is dropped.
    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
        let Self {
            config,
            handshake_timeout,
            ..
        } = self;
        RustlsAcceptor {
            inner: acceptor,
            config,
            handshake_timeout,
        }
    }
}
impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
where
    A: Accept<I, S>,
    A::Stream: AsyncRead + AsyncWrite + Unpin,
{
    /// The inner acceptor's stream wrapped in a server-side TLS stream.
    type Stream = TlsStream<A::Stream>;
    /// The service is passed through from the inner acceptor unchanged.
    type Service = A::Service;
    type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;

    /// Starts the inner acceptor on `stream`/`service` and wraps its future in
    /// a [`RustlsAcceptorFuture`], together with a clone of the config handle
    /// and the configured handshake timeout.
    fn accept(&self, stream: I, service: S) -> Self::Future {
        // Arguments are evaluated left to right: the inner accept is started
        // first, then the config handle is cloned.
        RustlsAcceptorFuture::new(
            self.inner.accept(stream, service),
            self.config.clone(),
            self.handshake_timeout,
        )
    }
}
impl<A> fmt::Debug for RustlsAcceptor<A> {
    /// Writes only the type name; no field appears in the output, so `A` does
    /// not need to implement `Debug`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("RustlsAcceptor");
        repr.finish()
    }
}
/// Cheaply cloneable handle to a rustls `ServerConfig` that can be swapped at
/// runtime.
///
/// Clones share the same `ArcSwap`, so a config stored through one handle is
/// what every other clone loads afterwards.
#[derive(Clone)]
pub struct RustlsConfig {
// Shared, atomically swappable pointer to the current server config.
inner: Arc<ArcSwap<ServerConfig>>,
}
impl RustlsConfig {
pub fn from_config(config: Arc<ServerConfig>) -> Self {
let inner = Arc::new(ArcSwap::new(config));
Self { inner }
}
pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));
Ok(Self { inner })
}
pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
let server_config = spawn_blocking(|| config_from_pem(cert, key))
.await
.unwrap()?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));
Ok(Self { inner })
}
pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
let server_config = config_from_pem_file(cert, key).await?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));
Ok(Self { inner })
}
pub fn get_inner(&self) -> Arc<ServerConfig> {
self.inner.load_full()
}
pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
self.inner.store(config);
}
pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
let inner = Arc::new(server_config);
self.inner.store(inner);
Ok(())
}
pub async fn from_pem_chain_file(
chain: impl AsRef<Path>,
key: impl AsRef<Path>,
) -> io::Result<Self> {
let server_config = config_from_pem_chain_file(chain, key).await?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));
Ok(Self { inner })
}
pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
let server_config = spawn_blocking(|| config_from_pem(cert, key))
.await
.unwrap()?;
let inner = Arc::new(server_config);
self.inner.store(inner);
Ok(())
}
pub async fn reload_from_pem_file(
&self,
cert: impl AsRef<Path>,
key: impl AsRef<Path>,
) -> io::Result<()> {
let server_config = config_from_pem_file(cert, key).await?;
let inner = Arc::new(server_config);
self.inner.store(inner);
Ok(())
}
}
impl fmt::Debug for RustlsConfig {
    /// Writes only the type name; the wrapped server config is not printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("RustlsConfig");
        repr.finish()
    }
}
fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert = cert.into_iter().map(CertificateDer::from).collect();
let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(io_other)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Ok(config)
}
fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert: Vec<CertificateDer> = CertificateDer::pem_slice_iter(&cert)
.collect::<Result<Vec<_>, _>>()
.map_err(|_| io_other("failed to parse certificate"))?;
let mut key_result: Result<PrivateKeyDer, io::Error> =
Err(io_other("The private key file contained no keys"));
for item in rustls_pki_types::pem::PemObject::pem_slice_iter(&key) {
let key: Result<PrivateKeyDer, io::Error> =
item.map_err(|_| io_other("failed to parse PEM"));
match key_result {
Ok(_) => {
if key.is_ok() {
return Err(io_other(
"The private key file containsed multiple keys (it must only contain one)",
));
}
}
Err(_) => key_result = key,
}
}
let key = key_result?;
let cert_der: Vec<Vec<u8>> = cert.into_iter().map(|c| c.to_vec()).collect();
let key_der = key.secret_der().to_vec();
config_from_der(cert_der, key_der)
}
/// Reads the certificate file and then the key file, and builds a
/// `ServerConfig` from their PEM contents via `config_from_pem`.
///
/// # Errors
///
/// Returns the read error of the first file that cannot be read, or any error
/// from `config_from_pem`.
async fn config_from_pem_file(
    cert: impl AsRef<Path>,
    key: impl AsRef<Path>,
) -> io::Result<ServerConfig> {
    // The files are read one after the other, certificate first.
    let cert_pem = fs_err::tokio::read(cert.as_ref()).await?;
    let key_pem = fs_err::tokio::read(key.as_ref()).await?;
    config_from_pem(cert_pem, key_pem)
}
/// Builds a `ServerConfig` without client authentication from a PEM file
/// holding the certificate chain and a PEM file holding the private key,
/// advertising `h2` and `http/1.1` through ALPN like the other constructors
/// in this module.
///
/// `chain` is the path of the certificate-chain file and `key` is the path of
/// the private-key file.
///
/// # Errors
///
/// Returns an error if either file cannot be read, if the chain or the key
/// cannot be parsed, or if rustls rejects the certificate/key pair.
async fn config_from_pem_chain_file(
    chain: impl AsRef<Path>,
    key: impl AsRef<Path>,
) -> io::Result<ServerConfig> {
    // The chain is read and parsed before the key file is touched, so a
    // missing or invalid chain is reported first.
    let chain_pem = fs_err::tokio::read(chain.as_ref()).await?;
    let certs = CertificateDer::pem_slice_iter(&chain_pem)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_| io_other("failed to parse certificate"))?;

    let key_pem = fs_err::tokio::read(key.as_ref()).await?;
    let private_key: PrivateKeyDer = PrivateKeyDer::from_pem_slice(&key_pem)
        .map_err(|_| io_other("could not parse pem file"))?;

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, private_key)
        .map_err(|_| io_other("invalid certificate"))?;

    // Without ALPN entries the server cannot negotiate HTTP/2; match the
    // protocol list used by `config_from_der`.
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(config)
}
// Tests for the rustls integration. Most of them start a TLS server on a
// loopback address with an OS-assigned port and connect to it with a rustls
// client that skips certificate verification. Certificates are read from
// `examples/self-signed-certs/`.
#[cfg(test)]
mod tests {
use crate::handle::Handle;
use crate::tls_rustls::{self, RustlsConfig};
use axum::body::Body;
use axum::routing::get;
use axum::Router;
use bytes::Bytes;
use http::{response, Request};
use http_body_util::BodyExt;
use hyper::client::conn::http1::{handshake, SendRequest};
use hyper_util::rt::TokioIo;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
use std::fmt::Debug;
use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio::{net::TcpStream, task::JoinHandle};
use tokio_rustls::TlsConnector;
// Sends one request over TLS and checks the response body.
#[tokio::test]
async fn start_and_request() {
let (_handle, _server_task, addr) = start_server().await;
let (mut client, _conn) = connect(addr).await;
let (_parts, body) = send_empty_request(&mut client).await;
assert_eq!(body.as_ref(), b"Hello, world!");
}
// Ignored by default. Opens a plain TCP connection that never starts a TLS
// handshake and expects the connection to be counted after 500 ms and gone
// after a further 1000 ms.
// NOTE(review): presumably relies on the 1-second handshake timeout that
// `RustlsAcceptor::new` uses in test builds — confirm.
#[ignore]
#[tokio::test]
async fn tls_timeout() {
let (handle, _server_task, addr) = start_server().await;
assert_eq!(handle.connection_count(), 0);
let _stream = TcpStream::connect(addr).await.unwrap();
sleep(Duration::from_millis(500)).await;
assert_eq!(handle.connection_count(), 1);
tokio::time::sleep(Duration::from_millis(1000)).await;
assert_eq!(handle.connection_count(), 0);
}
// Checks that reloading the config changes the certificate the running server
// presents, and that reloading the original files restores it.
#[tokio::test]
async fn test_reload() {
let handle = Handle::new();
let config = RustlsConfig::from_pem_file(
"examples/self-signed-certs/cert.pem",
"examples/self-signed-certs/key.pem",
)
.await
.unwrap();
let server_handle = handle.clone();
// The server gets a clone; `config` stays here to trigger the reloads.
let rustls_config = config.clone();
tokio::spawn(async move {
let app = Router::new().route("/", get(|| async { "Hello, world!" }));
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
tls_rustls::bind_rustls(addr, rustls_config)
.handle(server_handle)
.serve(app.into_make_service())
.await
});
let addr = handle.listening().await.unwrap();
// Two connections without a reload in between see the same certificate.
let cert_a = get_first_cert(addr).await;
let mut cert_b = get_first_cert(addr).await;
assert_eq!(cert_a, cert_b);
// After switching to the `reload` key pair the certificate differs.
config
.reload_from_pem_file(
"examples/self-signed-certs/reload/cert.pem",
"examples/self-signed-certs/reload/key.pem",
)
.await
.unwrap();
cert_b = get_first_cert(addr).await;
assert_ne!(cert_a, cert_b);
// Reloading the original key pair restores the first certificate.
config
.reload_from_pem_file(
"examples/self-signed-certs/cert.pem",
"examples/self-signed-certs/key.pem",
)
.await
.unwrap();
cert_b = get_first_cert(addr).await;
assert_eq!(cert_a, cert_b);
}
// Spawns a TLS server on 127.0.0.1 with port 0 and waits until the handle
// reports the listening address. Returns the handle, the server task and that
// address.
async fn start_server() -> (Handle<SocketAddr>, JoinHandle<io::Result<()>>, SocketAddr) {
let handle = Handle::new();
let server_handle = handle.clone();
let server_task = tokio::spawn(async move {
let app = Router::new().route("/", get(|| async { "Hello, world!" }));
let config = RustlsConfig::from_pem_file(
"examples/self-signed-certs/cert.pem",
"examples/self-signed-certs/key.pem",
)
.await?;
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
tls_rustls::bind_rustls(addr, config)
.handle(server_handle)
.serve(app.into_make_service())
.await
});
let addr = handle.listening().await.unwrap();
(handle, server_task, addr)
}
// Completes a TLS handshake with the server at `addr` and returns the first
// certificate the server presented.
async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
let stream = TcpStream::connect(addr).await.unwrap();
let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();
let (_io, client_connection) = tls_stream.into_inner();
client_connection.peer_certificates().unwrap()[0].clone()
}
// Opens a TLS connection to `addr`, performs the hyper HTTP/1 client handshake
// and drives the connection on a spawned task. Returns the request sender and
// that task.
async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
let stream = TcpStream::connect(addr).await.unwrap();
let tls_stream = TokioIo::new(tls_connector().connect(dns_name(), stream).await.unwrap());
let (send_request, connection) = handshake(tls_stream).await.unwrap();
let task = tokio::spawn(async move {
let _ = connection.await;
});
(send_request, task)
}
// Sends a default request with an empty body and returns the response head
// together with the fully collected body bytes.
async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
let (parts, body) = client
.send_request(Request::new(Body::empty()))
.await
.unwrap()
.into_parts();
let body = body.collect().await.unwrap().to_bytes();
(parts, body)
}
// Builds a TLS connector whose verifier accepts every server certificate and
// signature. Only suitable for tests against self-signed certificates.
fn tls_connector() -> TlsConnector {
// Verifier that approves everything without inspecting its arguments.
#[derive(Debug)]
struct NoVerify;
impl ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer,
_intermediates: &[CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
// Only RSA schemes are listed.
// NOTE(review): presumably sufficient because the fixture certificates use
// RSA keys — confirm before adding non-RSA fixtures.
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
]
}
}
let mut client_config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify))
.with_no_client_auth();
// The client offers the same ALPN protocols that `config_from_der` sets.
client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
TlsConnector::from(Arc::new(client_config))
}
// Server name used for every client connection in these tests.
fn dns_name() -> ServerName<'static> {
ServerName::try_from("localhost").unwrap()
}
// A missing certificate file and a missing key file must both surface as
// `NotFound`, with a message that names the missing path.
// NOTE(review): the expected text contains Unix-style OS error wording —
// presumably these tests only run on Unix-like hosts; confirm.
#[tokio::test]
async fn from_pem_file_not_found() {
let err = RustlsConfig::from_pem_file(
"examples/self-signed-certs/missing.pem",
"examples/self-signed-certs/key.pem",
)
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotFound);
assert_eq!(
err.to_string(),
"failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
);
let err = RustlsConfig::from_pem_file(
"examples/self-signed-certs/cert.pem",
"examples/self-signed-certs/missing.pem",
)
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotFound);
assert_eq!(
err.to_string(),
"failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
);
}
// Same checks as `from_pem_file_not_found`, for `from_pem_chain_file`: first a
// missing chain file, then a missing key file.
#[tokio::test]
async fn from_pem_file_chain_file_not_found() {
let err = RustlsConfig::from_pem_chain_file(
"examples/self-signed-certs/missing.pem",
"examples/self-signed-certs/key.pem",
)
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotFound);
assert_eq!(
err.to_string(),
"failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
);
let err = RustlsConfig::from_pem_chain_file(
"examples/self-signed-certs/cert.pem",
"examples/self-signed-certs/missing.pem",
)
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotFound);
assert_eq!(
err.to_string(),
"failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
);
}
}