use std::{io, net::SocketAddr, sync::Arc};
pub fn tls_config(key: &String, cert: &String) -> Result<Arc<ServerConfig>, std::io::Error> {
use rustls_pki_types::pem::PemObject;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let certs = CertificateDer::pem_file_iter(cert)
.map_err(|_| io::Error::other("open cert failed"))?
.collect::<Result<Vec<_>, _>>()
.map_err(|_| io::Error::other("invalid cert pem"))?;
let key = PrivateKeyDer::from_pem_file(key).map_err(|_| io::Error::other("failed to read private key"))?;
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(std::io::Error::other)?;
config.alpn_protocols = vec![
b"h2".to_vec(), b"http/1.1".to_vec(), ];
Ok(Arc::new(config))
}
/// Loads the key/certificate pair through [`tls_config`] and wraps the
/// resulting configuration in a `tokio_rustls::TlsAcceptor`.
///
/// # Errors
///
/// Propagates whatever error [`tls_config`] reports.
#[allow(dead_code)] // presumably not called in every build of the crate — confirm before removing
pub fn rust_tls_acceptor(key: &String, cert: &String) -> Result<tokio_rustls::TlsAcceptor, std::io::Error> {
    let server_config = tls_config(key, cert)?;
    Ok(tokio_rustls::TlsAcceptor::from(server_config))
}
use core::task::{Context, Poll};
use std::future::Future;
use std::pin::Pin;
use futures_util::ready;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpListener, TcpStream},
};
use tokio_rustls::rustls::{ServerConfig, ServerConnection};
/// Pairs a listener with the TLS server configuration that is handed to every
/// connection it accepts.
pub struct TlsAcceptor<L = TcpListener> {
    /// Configuration cloned (by `Arc`) into each accepted stream.
    config: Arc<ServerConfig>,
    /// Source of inbound connections.
    listener: L,
}

impl TlsAcceptor {
    /// Creates an acceptor from an already-bound listener and a TLS config.
    pub fn new(config: Arc<ServerConfig>, listener: TcpListener) -> Self {
        TlsAcceptor { config, listener }
    }

    /// Swaps in a new TLS configuration.
    ///
    /// Streams returned by earlier [`accept`](Self::accept) calls keep the
    /// `Arc` they were given; only later calls see `new_config`.
    pub fn replace_config(&mut self, new_config: Arc<ServerConfig>) {
        self.config = new_config;
    }

    /// Waits for the next inbound TCP connection and wraps it in a
    /// [`TlsStream`] together with the peer address.
    ///
    /// The TLS handshake is not driven here: the returned stream starts in its
    /// handshaking state and completes the handshake when it is first read
    /// from or written to.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the listener's `accept`.
    pub async fn accept(&mut self) -> Result<(TlsStream, SocketAddr), io::Error> {
        let (tcp, peer) = self.listener.accept().await?;
        let tls = TlsStream::new(tcp, Arc::clone(&self.config));
        Ok((tls, peer))
    }
}
/// Builds a [`TlsAcceptor`] from any `(config, listener)` pair whose members
/// convert into `Arc<ServerConfig>` and `TcpListener` respectively.
impl<C, L> From<(C, L)> for TlsAcceptor
where
    C: Into<Arc<ServerConfig>>,
    L: Into<TcpListener>,
{
    fn from(pair: (C, L)) -> Self {
        // Convert in the same order as the tuple: config first, then listener.
        let config: Arc<ServerConfig> = pair.0.into();
        let listener: TcpListener = pair.1.into();
        Self::new(config, listener)
    }
}
/// A server-side TLS stream whose handshake is performed lazily, on the first
/// read or write, instead of at accept time.
pub struct TlsStream<C = TcpStream> {
    /// Either the pending handshake or the established TLS session.
    state: State<C>,
}

impl<C> TlsStream<C>
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `stream` in a pending handshake that uses `config`.
    fn new(stream: C, config: Arc<ServerConfig>) -> Self {
        let acceptor = tokio_rustls::TlsAcceptor::from(config);
        let state = State::Handshaking(acceptor.accept(stream));
        Self { state }
    }

    /// Borrows the underlying IO object.
    ///
    /// While handshaking this forwards whatever the pending accept future's
    /// `get_ref` reports; once streaming it is always `Some`.
    pub fn _io(&self) -> Option<&C> {
        match &self.state {
            State::Streaming(tls) => {
                let (io, _) = tls.get_ref();
                Some(io)
            }
            State::Handshaking(pending) => pending.get_ref(),
        }
    }

    /// Borrows the rustls server connection, or `None` while the handshake is
    /// still pending.
    pub fn _connection(&self) -> Option<&ServerConnection> {
        match &self.state {
            State::Streaming(tls) => {
                let (_, connection) = tls.get_ref();
                Some(connection)
            }
            State::Handshaking(_) => None,
        }
    }
}
impl<C> AsyncRead for TlsStream<C>
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted data, first driving the TLS handshake to completion if
    /// it has not finished yet.
    ///
    /// Returns `Poll::Pending` while the handshake is pending and the
    /// handshake's error if it fails.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_read(cx, buf),
            State::Handshaking(handshake) => {
                // `ready!` bails out with `Pending`; `?` bails out with the
                // handshake error wrapped in `Poll::Ready`.
                let mut tls = ready!(Pin::new(handshake).poll(cx))?;
                // Perform the requested read on the fresh session before
                // storing it, so this call still makes progress.
                let outcome = Pin::new(&mut tls).poll_read(cx, buf);
                this.state = State::Streaming(tls);
                outcome
            }
        }
    }
}
impl<C> AsyncWrite for TlsStream<C>
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes plaintext data, first driving the TLS handshake to completion if
    /// it has not finished yet.
    ///
    /// Returns `Poll::Pending` while the handshake is pending and the
    /// handshake's error if it fails.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_write(cx, buf),
            State::Handshaking(handshake) => {
                // `ready!` bails out with `Pending`; `?` bails out with the
                // handshake error wrapped in `Poll::Ready`.
                let mut tls = ready!(Pin::new(handshake).poll(cx))?;
                // Perform the requested write on the fresh session before
                // storing it, so this call still makes progress.
                let outcome = Pin::new(&mut tls).poll_write(cx, buf);
                this.state = State::Streaming(tls);
                outcome
            }
        }
    }

    /// Flushes the TLS session. While the handshake is still pending there is
    /// no session yet, so this reports success without doing anything.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_flush(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }

    /// Shuts down the TLS session. While the handshake is still pending this
    /// reports success without touching the pending handshake.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_shutdown(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }
}
/// Phase of a `TlsStream`: it starts as `Handshaking` and is switched to
/// `Streaming` by `poll_read`/`poll_write` once the accept future resolves.
enum State<C> {
/// Handshake not finished yet; holds the pending `tokio_rustls::Accept` future.
Handshaking(tokio_rustls::Accept<C>),
/// Handshake finished; holds the established server-side TLS stream.
Streaming(tokio_rustls::server::TlsStream<C>),
}