use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use crate::config::TlsConfig;
pub enum ClientStream {
Plain(TcpStream),
Tls(Box<TlsStream<TcpStream>>),
}
impl ClientStream {
pub fn peer_cert_present(&self) -> bool {
match self {
ClientStream::Plain(_) => false,
ClientStream::Tls(s) => s
.get_ref()
.1
.peer_certificates()
.map(|c| !c.is_empty())
.unwrap_or(false),
}
}
}
impl AsyncRead for ClientStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
ClientStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
}
}
}
impl AsyncWrite for ClientStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match self.get_mut() {
ClientStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
ClientStream::Plain(s) => Pin::new(s).poll_flush(cx),
ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
ClientStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
}
}
}
pub fn build_tls_acceptor(tls: &TlsConfig) -> Result<TlsAcceptor, String> {
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
let cert_chain: Vec<CertificateDer<'static>> = {
let data = std::fs::read(&tls.cert_path)
.map_err(|e| format!("reading cert {}: {}", tls.cert_path, e))?;
rustls_pemfile::certs(&mut &data[..])
.collect::<Result<Vec<_>, _>>()
.map_err(|e| format!("parsing cert {}: {}", tls.cert_path, e))?
};
if cert_chain.is_empty() {
return Err(format!("no certificates found in {}", tls.cert_path));
}
let key: PrivateKeyDer<'static> = {
let data = std::fs::read(&tls.key_path)
.map_err(|e| format!("reading key {}: {}", tls.key_path, e))?;
rustls_pemfile::private_key(&mut &data[..])
.map_err(|e| format!("parsing key {}: {}", tls.key_path, e))?
.ok_or_else(|| format!("no private key found in {}", tls.key_path))?
};
let builder = rustls::ServerConfig::builder();
let config = if tls.require_client_cert {
let ca_path = tls
.ca_path
.as_ref()
.ok_or_else(|| "require_client_cert is set but ca_path is missing".to_string())?;
let ca_data =
std::fs::read(ca_path).map_err(|e| format!("reading ca {}: {}", ca_path, e))?;
let mut roots = rustls::RootCertStore::empty();
for ca in rustls_pemfile::certs(&mut &ca_data[..]) {
let ca = ca.map_err(|e| format!("parsing ca {}: {}", ca_path, e))?;
roots
.add(ca)
.map_err(|e| format!("adding ca cert: {}", e))?;
}
let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.map_err(|e| format!("building client verifier: {}", e))?;
builder
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)
.map_err(|e| format!("server config (mTLS): {}", e))?
} else {
builder
.with_no_client_auth()
.with_single_cert(cert_chain, key)
.map_err(|e| format!("server config: {}", e))?
};
Ok(TlsAcceptor::from(Arc::new(config)))
}