use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Locations of the certificate and private-key files, together with the
/// rustls server configuration most recently built from them.
///
/// The active configuration sits behind a mutex so it can be replaced
/// (see `reload`) while readers take cheap `Arc` snapshots via `config`.
pub struct TlsConfig {
    /// Path of the PEM certificate-chain file.
    cert: std::path::PathBuf,
    /// Path of the PEM private-key file.
    pk: std::path::PathBuf,
    /// The currently active server configuration.
    config: Arc<Mutex<Arc<rustls::server::ServerConfig>>>,
}
impl TlsConfig {
    /// Loads the certificate chain and private key from the given PEM files
    /// and builds the initial server configuration.
    ///
    /// The paths are remembered so that `reload` can re-read them later.
    ///
    /// # Errors
    ///
    /// Returns an error if either file cannot be read, the certificate file
    /// contains no certificates, the key file contains no private key, or
    /// rustls rejects the certificate/key pair.
    pub async fn new(
        cert: &std::path::Path,
        pk: &std::path::Path,
    ) -> std::io::Result<Self> {
        let cert = cert.to_owned();
        let pk = pk.to_owned();
        let config = Self::load(&cert, &pk).await?;
        Ok(Self {
            cert,
            pk,
            config: Arc::new(Mutex::new(config)),
        })
    }

    /// Returns a snapshot of the currently active server configuration.
    ///
    /// This only clones an `Arc`. A later `reload` does not affect snapshots
    /// already handed out.
    pub fn config(&self) -> Arc<rustls::server::ServerConfig> {
        // The lock only guards an `Arc` that is replaced in a single
        // assignment, so a poisoned lock still holds a valid value. Recover it
        // instead of panicking.
        self.config
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    /// Re-reads the certificate and key files and swaps in the new
    /// configuration.
    ///
    /// If loading fails, the previous configuration stays active and the
    /// error is returned.
    // `reload` may have no caller in this crate yet. The allow keeps it
    // available without a dead-code warning.
    #[allow(dead_code)]
    pub async fn reload(&self) -> std::io::Result<()> {
        // Load before taking the lock so the lock is never held across `.await`.
        let new_config = Self::load(&self.cert, &self.pk).await?;
        *self
            .config
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = new_config;
        Ok(())
    }

    /// Reads and parses the PEM files at `cert` and `pk` into a fresh
    /// rustls server configuration without client authentication.
    async fn load(
        cert: &std::path::Path,
        pk: &std::path::Path,
    ) -> std::io::Result<Arc<rustls::server::ServerConfig>> {
        let cert_pem = Self::read_file(cert).await?;
        let key_pem = Self::read_file(pk).await?;

        let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(&cert_pem))
            .collect::<std::io::Result<Vec<_>>>()?;
        // Fail early with a clear message. Otherwise an empty chain would reach
        // rustls and fail there, or fail later on every handshake.
        if certs.is_empty() {
            return Err(std::io::Error::other(format!(
                "no certificates found in {}",
                cert.display()
            )));
        }

        let key = rustls_pemfile::private_key(&mut std::io::Cursor::new(&key_pem))?
            .ok_or_else(|| {
                std::io::Error::other(format!(
                    "error reading priv key: no private key found in {}",
                    pk.display()
                ))
            })?;

        let config = rustls::server::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(std::io::Error::other)?;
        Ok(Arc::new(config))
    }

    /// Reads the whole file at `path`. Any I/O error is prefixed with the
    /// path and keeps its original `ErrorKind`.
    async fn read_file(path: &std::path::Path) -> std::io::Result<Vec<u8>> {
        tokio::fs::read(path)
            .await
            .map_err(|e| std::io::Error::new(e.kind(), format!("{}: {e}", path.display())))
    }
}
/// A server-side connection that is either plain TCP or TLS layered on TCP.
///
/// Reads and writes are forwarded to whichever stream is wrapped.
#[derive(Debug)]
#[non_exhaustive]
pub enum MaybeTlsStream {
    /// An unencrypted TCP connection.
    Tcp(tokio::net::TcpStream),
    /// A TCP connection wrapped in a rustls server-side TLS session,
    /// e.g. as produced by `MaybeTlsStream::tls`.
    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
}
impl MaybeTlsStream {
    /// Runs a server-side TLS handshake on `tcp` using the configuration that
    /// is currently active in `tls_config`.
    ///
    /// # Errors
    ///
    /// Returns the handshake's I/O error if the TLS accept fails.
    pub async fn tls(
        tls_config: &TlsConfig,
        tcp: tokio::net::TcpStream,
    ) -> std::io::Result<Self> {
        let acceptor = tokio_rustls::TlsAcceptor::from(tls_config.config());
        acceptor.accept(tcp).await.map(Self::Tls)
    }
}
impl AsyncRead for MaybeTlsStream {
    /// Forwards the read to the wrapped TCP or TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Both inner streams are `Unpin`, so the pin can be unwrapped and
        // re-created around the selected variant.
        match Pin::into_inner(self) {
            Self::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
impl AsyncWrite for MaybeTlsStream {
    /// Forwards the write to the wrapped TCP or TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
            Self::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    /// Forwards vectored writes to the inner stream.
    ///
    /// Without this override, the trait's default would write only the first
    /// non-empty slice per call, even when the inner stream can do better.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            Self::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the wrapped stream has an efficient vectored write.
    ///
    /// Callers use this to decide whether to batch buffers with
    /// `poll_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Tcp(s) => s.is_write_vectored(),
            Self::Tls(s) => s.is_write_vectored(),
        }
    }

    /// Forwards the flush to the wrapped TCP or TLS stream.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Tcp(s) => Pin::new(s).poll_flush(cx),
            Self::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    /// Forwards the shutdown to the wrapped TCP or TLS stream.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
            Self::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}