use async_rustls::{server::TlsStream, TlsAcceptor};
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::certs;
use std::{
fmt::{Debug, Formatter},
io::{BufReader, Error, Result},
sync::Arc,
};
use trillium_tls_common::{async_trait, Acceptor, AsyncRead, AsyncWrite};
/// Trillium TLS acceptor backed by rustls.
///
/// This is a thin newtype around an async-rustls [`TlsAcceptor`]; it can be
/// built from a rustls [`ServerConfig`], an existing [`TlsAcceptor`], or a
/// PEM certificate/key pair via [`RustlsAcceptor::from_single_cert`].
#[derive(Clone)]
pub struct RustlsAcceptor(TlsAcceptor);
/// Opaque `Debug` output: the wrapped [`TlsAcceptor`] is shown as a
/// placeholder string rather than being inspected.
impl Debug for RustlsAcceptor {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Use the real type name so debug output matches what
        // `#[derive(Debug)]` would print for this newtype.
        f.debug_tuple("RustlsAcceptor")
            .field(&"<<TlsAcceptor>>")
            .finish()
    }
}
impl RustlsAcceptor {
    /// Builds a `RustlsAcceptor` from anything convertible into one, such as
    /// a rustls [`ServerConfig`] or an async-rustls [`TlsAcceptor`].
    pub fn new(t: impl Into<Self>) -> Self {
        t.into()
    }

    /// Builds a `RustlsAcceptor` from a PEM-encoded certificate chain and a
    /// PEM-encoded private key.
    ///
    /// The resulting config uses rustls' safe defaults and no client
    /// authentication. Every certificate found in `cert` is passed to rustls
    /// as the chain, in file order.
    ///
    /// `key` is searched for a PKCS#8 private key (`BEGIN PRIVATE KEY`)
    /// first. If none is present, a PKCS#1 RSA private key
    /// (`BEGIN RSA PRIVATE KEY`) is used instead. If several keys of the
    /// chosen kind are present, the first one wins.
    ///
    /// # Panics
    ///
    /// Panics if either input cannot be read as PEM, or if `key` contains
    /// neither a PKCS#8 nor a PKCS#1 RSA private key. It also panics if
    /// rustls rejects the certificate/key combination.
    pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
        let certs = certs(&mut BufReader::new(cert))
            .expect("could not read cert pemfile")
            .into_iter()
            .map(Certificate)
            .collect();

        // rustls accepts DER keys in either PKCS#8 or PKCS#1 form. Try PKCS#8
        // first, then fall back to RSA. Each pass needs a fresh reader
        // because reading consumes the input.
        let private_key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key))
            .expect("could not read key pemfile")
            .into_iter()
            .next()
            .or_else(|| {
                rustls_pemfile::rsa_private_keys(&mut BufReader::new(key))
                    .expect("could not read key pemfile")
                    .into_iter()
                    .next()
            })
            .expect("no pkcs8 or rsa private key found in `key`");

        ServerConfig::builder()
            .with_safe_defaults()
            .with_no_client_auth()
            .with_single_cert(certs, PrivateKey(private_key))
            .expect("could not create a rustls ServerConfig from the supplied cert and key")
            .into()
    }
}
impl From<ServerConfig> for RustlsAcceptor {
fn from(sc: ServerConfig) -> Self {
Self(Arc::new(sc).into())
}
}
impl From<TlsAcceptor> for RustlsAcceptor {
fn from(ta: TlsAcceptor) -> Self {
Self(ta)
}
}
/// Performs the server side of a TLS handshake over any transport meeting
/// the listed bounds, yielding an async-rustls [`TlsStream`] on success.
#[async_trait]
impl<Input> Acceptor<Input> for RustlsAcceptor
where
    Input: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
    type Output = TlsStream<Input>;
    type Error = Error;

    /// Runs the handshake on `transport` with the wrapped acceptor and
    /// propagates any I/O error it reports.
    async fn accept(&self, transport: Input) -> Result<Self::Output> {
        let Self(tls) = self;
        let stream = tls.accept(transport).await?;
        Ok(stream)
    }
}