use std::{future::Ready, path::Path, sync::Arc, task::Poll};
use pin_project_lite::pin_project;
use rustls::{
OtherError, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
server::ClientHello,
sign::CertifiedKey,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use crate::{Accept, IntoAccept};
/// Runs the given statements as the body of an immediately-invoked closure.
///
/// Any `?` inside the block returns from the closure, not from the enclosing function,
/// so the whole block evaluates to a `Result` that can be passed on as a value.
macro_rules! r#try {
    ($($body:tt)*) => {
        (|| {
            $($body)*
        })()
    };
}
#[derive(Debug, Clone)]
pub struct Certificate(Arc<CertifiedKey>);
impl Certificate {
#[inline]
pub fn from_pem(pem: &[u8]) -> Result<Self, rustls::Error> {
Self::from_der(
CertificateDer::pem_slice_iter(pem)
.collect::<Result<_, rustls::pki_types::pem::Error>>()
.map_err(|err| rustls::Error::Other(OtherError(Arc::new(err))))?,
PrivateKeyDer::from_pem_slice(pem)
.map_err(|err| rustls::Error::Other(OtherError(Arc::new(err))))?,
)
}
#[inline]
pub fn from_der(
cert_chain: Box<[CertificateDer<'static>]>,
private_key: PrivateKeyDer<'static>,
) -> Result<Self, rustls::Error> {
Ok(Self(Arc::new(CertifiedKey::from_der(
cert_chain.into_vec(),
private_key,
&rustls::crypto::aws_lc_rs::default_provider(),
)?)))
}
}
impl From<CertifiedKey> for Certificate {
#[inline]
fn from(value: CertifiedKey) -> Self {
Self(Arc::new(value))
}
}
impl From<Arc<CertifiedKey>> for Certificate {
#[inline]
fn from(value: Arc<CertifiedKey>) -> Self {
Self(value)
}
}
/// Chooses which certificate to present during a TLS handshake.
///
/// Implementations receive the client's hello message and may inspect it
/// (for example its requested server name) to pick a certificate.
pub trait Resolver {
/// Returns the certificate for this handshake, or `None` if there is none to offer.
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Certificate>;
}
/// The empty resolver: it never supplies a certificate, whatever the client sends.
impl Resolver for () {
    #[inline]
    fn resolve(&self, _: ClientHello<'_>) -> Option<Certificate> {
        Option::None
    }
}
impl Resolver for Certificate {
#[inline]
fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Certificate> {
Some(self.clone())
}
}
/// Conversion into a [`Resolver`], performed when the TLS acceptor is built.
///
/// This lets builder inputs defer fallible work (such as reading a PEM file) until the
/// acceptor is actually constructed.
pub trait IntoResolver {
/// The resolver produced by the conversion.
type Resolver: Resolver;
/// Performs the conversion.
///
/// # Errors
/// Any failure to produce the resolver is reported as an I/O error.
fn into_resolver(self) -> std::io::Result<Self::Resolver>;
}
/// Every [`Resolver`] is already its own resolver; the conversion cannot fail.
impl<R> IntoResolver for R
where
    R: Resolver,
{
    type Resolver = R;

    #[inline]
    fn into_resolver(self) -> std::io::Result<R> {
        let ready: std::io::Result<R> = Ok(self);
        ready
    }
}
pub struct Pem<P>(P);
impl<P> IntoResolver for Pem<P>
where
P: AsRef<Path>,
{
type Resolver = Certificate;
#[inline]
fn into_resolver(self) -> std::io::Result<Self::Resolver> {
Certificate::from_pem(&std::fs::read(self.0)?).map_err(std::io::Error::other)
}
}
/// Builder for a TLS acceptor, generic over its certificate source `T`.
///
/// It starts with `()` as the source, which supplies no certificate. Configure a real
/// source with the `with_certificate` or `with_resolver` builder methods.
#[derive(Debug, Default)]
pub struct Tls<T = ()>(T);
impl Tls {
    /// Creates a builder with no certificate source configured yet.
    #[inline]
    pub const fn new() -> Self {
        Tls(())
    }
}
impl<T> Tls<T> {
    /// Replaces the certificate source with `resolver`, discarding the current one.
    ///
    /// `resolver` can be anything implementing [`IntoResolver`]. It is converted into a
    /// [`Resolver`] only when the acceptor is built.
    #[must_use = "builder methods return a new `Tls`; the original is consumed"]
    #[inline]
    pub fn with_resolver<R>(self, resolver: R) -> Tls<R>
    where
        R: IntoResolver,
    {
        Tls(resolver)
    }

    /// Serves the certificate chain and private key stored in the PEM file at `path`.
    ///
    /// The file is not read here. It is loaded when the acceptor is built, so a missing or
    /// malformed file is reported at that point.
    #[must_use = "builder methods return a new `Tls`; the original is consumed"]
    #[inline]
    pub fn with_certificate<P>(self, path: P) -> Tls<Pem<P>>
    where
        P: AsRef<Path>,
    {
        self.with_resolver(Pem(path))
    }
}
impl<I, S, T> IntoAccept<I, S> for Tls<T>
where
I: AsyncRead + AsyncWrite + Unpin,
T: IntoResolver,
T::Resolver: Send + Sync + 'static,
{
type Accept = TlsAcceptor;
type Future = Ready<std::io::Result<Self::Accept>>;
#[inline]
fn into_accept(self) -> Self::Future {
::core::future::ready(r#try! {
let resolver = Arc::new(ResolvesServerCert(self.0.into_resolver()?));
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
let mut config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(std::io::Error::other)?
.with_no_client_auth()
.with_cert_resolver(resolver);
config.alpn_protocols = vec![b"h2".into(), b"http/1.1".into()];
let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
Ok(TlsAcceptor(acceptor))
})
}
}
/// Connection acceptor produced by `Tls`; performs the server side of the TLS handshake.
pub struct TlsAcceptor(tokio_rustls::TlsAcceptor);
impl<I, S> Accept<I, S> for TlsAcceptor
where
    I: AsyncRead + AsyncWrite + Unpin,
{
    type Stream = TlsStream<I>;
    type Service = S;
    type Future = TlsAcceptorFuture<I, S>;

    /// Starts the handshake on `stream`; `service` is handed back once the handshake succeeds.
    #[inline]
    fn accept(&self, stream: I, service: S) -> Self::Future {
        let TlsAcceptor(inner) = self;
        let handshake = inner.accept(stream);
        TlsAcceptorFuture {
            accept: handshake,
            service: Some(service),
        }
    }
}
// Future returned by `TlsAcceptor::accept`. It drives the TLS handshake and resolves to
// the TLS stream paired with the service that was passed in alongside the connection.
pin_project! {
#[doc(hidden)]
pub struct TlsAcceptorFuture<I, S> {
// Held until the handshake succeeds, then moved out into the future's output.
service: Option<S>,
// The in-progress handshake; projected as `Pin<&mut _>` so it can be polled.
#[pin] accept: tokio_rustls::Accept<I>,
}
}
impl<I, S> Future for TlsAcceptorFuture<I, S>
where
    I: AsyncRead + AsyncWrite + Unpin,
{
    type Output = std::io::Result<(TlsStream<I>, S)>;

    /// Polls the handshake. On success, it pairs the TLS stream with the stored service.
    ///
    /// A handshake error is forwarded unchanged.
    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // `ready!` returns `Poll::Pending` early; `?` turns a handshake failure into
        // `Poll::Ready(Err(_))`.
        let stream = std::task::ready!(this.accept.poll(cx))?;
        // The service is taken exactly once, on the first successful completion. Finding it
        // empty means the future was polled again after it had already completed.
        let service = this
            .service
            .take()
            .expect("TlsAcceptorFuture polled after completion");
        Poll::Ready(Ok((stream, service)))
    }
}
/// Adapter that exposes a crate-level resolver to rustls as a certificate resolver.
struct ResolvesServerCert<T>(T);
impl<T> std::fmt::Debug for ResolvesServerCert<T> {
    // `T` is not required to implement `Debug`, so the wrapped type's name is reported instead.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner_name = ::core::any::type_name::<T>();
        let mut tuple = f.debug_tuple("ResolvesServerCert");
        tuple.field(&inner_name);
        tuple.finish()
    }
}
impl<T> rustls::server::ResolvesServerCert for ResolvesServerCert<T>
where
T: Resolver + Send + Sync,
{
#[inline]
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.0.resolve(client_hello).map(|x| x.0)
}
}