//! Asynchronous TLS listener built on tokio.
//!
//! [`TlsListener`] wraps a connection source implementing [`AsyncAccept`] and a
//! TLS acceptor implementing [`AsyncTls`]. It runs the TLS handshakes of
//! accepted connections concurrently, bounded by a maximum number of in-flight
//! handshakes and a per-handshake timeout, and yields the completed TLS streams
//! through its `Stream` implementation.
//!
//! Use [`TlsListener::new`] for the default limits, or [`builder`] to
//! configure them.
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use futures_util::stream::{FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, mem};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::{timeout, Timeout};
/// Default number of TLS handshakes that may be in progress at the same time.
///
/// [`builder`] uses this value unless it is overridden with
/// [`Builder::max_handshakes`].
pub const DEFAULT_MAX_HANDSHAKES: usize = 64;
/// Default time limit for a single TLS handshake (200 milliseconds).
///
/// [`builder`] uses this value unless it is overridden with
/// [`Builder::handshake_timeout`].
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(200);
/// Asynchronous TLS acceptor that upgrades a raw connection of type `C` into a
/// TLS stream.
///
/// Implementors must be `Clone`: [`Builder::listen`] clones the acceptor into
/// every listener it creates.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
/// Type of the TLS stream produced by a successful handshake.
type Stream;
/// Error produced when the handshake fails.
type Error: std::error::Error;
/// Future returned by [`AsyncTls::accept`]. It must be `Unpin` and resolves
/// to either the TLS stream or the handshake error.
type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Unpin;
/// Accepts a TLS session over `stream`, returning the future that resolves
/// once the handshake has finished.
fn accept(&self, stream: C) -> Self::AcceptFuture;
}
/// Source of incoming connections that can be polled asynchronously.
///
/// This is the listener side consumed by [`TlsListener`]; each accepted
/// connection is handed to an [`AsyncTls`] acceptor.
pub trait AsyncAccept {
/// Type of the accepted (not yet TLS-wrapped) connection.
type Connection: AsyncRead + AsyncWrite;
/// Error returned when accepting a connection fails.
type Error;
/// Polls for the next incoming connection.
///
/// Returns `Poll::Ready(Ok(conn))` when a connection was accepted,
/// `Poll::Ready(Err(e))` when accepting failed, and `Poll::Pending` when no
/// connection is available yet.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Connection, Self::Error>>;
}
pin_project! {
/// Wraps a connection listener and a TLS acceptor, yielding TLS streams.
///
/// Connections accepted from the inner listener are handed to the TLS
/// acceptor, and the resulting handshakes are driven concurrently. Completed
/// handshakes are yielded through the `Stream` implementation; handshakes
/// that exceed the configured timeout are dropped without yielding an item.
pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
// Underlying connection source; pinned because `AsyncAccept::poll_accept`
// takes `Pin<&mut Self>`.
#[pin]
listener: A,
// TLS acceptor used to start a handshake on every accepted connection.
tls: T,
// Handshakes currently in progress, each wrapped in a timeout.
waiting: FuturesUnordered<Timeout<T::AcceptFuture>>,
// Once `waiting` holds this many handshakes, no further connections are
// accepted until some of them finish.
max_handshakes: usize,
// Time limit applied to each individual handshake.
timeout: Duration,
}
}
/// Configuration used to create [`TlsListener`]s.
///
/// Obtain one from [`builder`], adjust it with [`Builder::max_handshakes`] and
/// [`Builder::handshake_timeout`], then call [`Builder::listen`].
#[derive(Clone)]
pub struct Builder<T> {
// TLS acceptor cloned into every listener created by `listen`.
tls: T,
// Maximum number of concurrent handshakes for created listeners.
max_handshakes: usize,
// Per-handshake time limit for created listeners.
handshake_timeout: Duration,
}
/// Error yielded by the `Stream` implementation of [`TlsListener`].
///
/// `LE` is the error type of the underlying listener ([`AsyncAccept::Error`])
/// and `TE` is the error type of the TLS acceptor ([`AsyncTls::Error`]). Both
/// variants display as the wrapped error and also expose it as their source.
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
/// The underlying listener failed to accept a connection.
#[error("{0}")]
ListenerError(#[source] LE),
/// The TLS handshake on an accepted connection failed.
#[error("{0}")]
TlsAcceptError(#[source] TE),
}
impl<A, T> TlsListener<A, T>
where
    A: AsyncAccept,
    T: AsyncTls<A::Connection>,
{
    /// Creates a `TlsListener` over `listener` using the default settings
    /// ([`DEFAULT_MAX_HANDSHAKES`] and [`DEFAULT_HANDSHAKE_TIMEOUT`]).
    ///
    /// Use [`builder`] instead to customize those limits.
    pub fn new(tls: T, listener: A) -> Self {
        let defaults = builder(tls);
        defaults.listen(listener)
    }
}
impl<A, T> TlsListener<A, T>
where
A: AsyncAccept,
A::Connection: AsyncRead + AsyncWrite + Unpin + 'static,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
pub fn accept(&mut self) -> impl Future<Output = Option<<Self as Stream>::Item>> + '_
where
Self: Unpin,
{
self.next()
}
pub fn replace_acceptor(&mut self, acceptor: T) {
self.tls = acceptor;
}
pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
*self.project().tls = acceptor;
}
}
impl<A, T> Stream for TlsListener<A, T>
where
A: AsyncAccept,
A::Connection: AsyncRead + AsyncWrite + Unpin + 'static,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Ok(conn)) => {
this.waiting
.push(timeout(*this.timeout, this.tls.accept(conn)));
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
}
}
}
loop {
return match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(conn))) => {
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
}
Poll::Ready(Some(Err(_))) => continue,
_ => Poll::Pending,
};
}
}
}
#[cfg(feature = "rustls")]
impl<C> AsyncTls<C> for tokio_rustls::TlsAcceptor
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    type Stream = tokio_rustls::server::TlsStream<C>;
    type Error = io::Error;
    type AcceptFuture = tokio_rustls::Accept<C>;

    /// Delegates to the inherent `tokio_rustls::TlsAcceptor::accept`.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // Spelled with the full path so the call cannot be mistaken for a
        // recursive call to this trait method.
        let acceptor: &tokio_rustls::TlsAcceptor = self;
        tokio_rustls::TlsAcceptor::accept(acceptor, conn)
    }
}
#[cfg(feature = "native-tls")]
impl<C> AsyncTls<C> for tokio_native_tls::TlsAcceptor
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = tokio_native_tls::TlsStream<C>;
    type Error = tokio_native_tls::native_tls::Error;
    type AcceptFuture = Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Send>>;

    /// Runs the inherent `tokio_native_tls::TlsAcceptor::accept` inside a
    /// boxed future.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // The future must own everything it uses, so it gets its own clone of
        // the acceptor instead of borrowing `self`.
        let acceptor = self.clone();
        let handshake =
            async move { tokio_native_tls::TlsAcceptor::accept(&acceptor, conn).await };
        Box::pin(handshake)
    }
}
impl<T> Builder<T> {
    /// Sets the maximum number of TLS handshakes that may be in progress at
    /// once for listeners created by this builder.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn max_handshakes(&mut self, max: usize) -> &mut Self {
        let Self { max_handshakes, .. } = self;
        *max_handshakes = max;
        self
    }

    /// Sets the time limit for a single TLS handshake for listeners created by
    /// this builder.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
        let Self {
            handshake_timeout, ..
        } = self;
        *handshake_timeout = timeout;
        self
    }

    /// Creates a [`TlsListener`] over `listener` with the current settings.
    ///
    /// The builder is only borrowed: the TLS acceptor is cloned into the new
    /// listener, so the same builder can create several listeners.
    pub fn listen<A>(&self, listener: A) -> TlsListener<A, T>
    where
        A: AsyncAccept,
        T: AsyncTls<A::Connection>,
    {
        let Self {
            tls,
            max_handshakes,
            handshake_timeout,
        } = self;
        TlsListener {
            listener,
            tls: tls.clone(),
            waiting: FuturesUnordered::new(),
            max_handshakes: *max_handshakes,
            timeout: *handshake_timeout,
        }
    }
}
/// Creates a [`Builder`] for the TLS acceptor `tls`, initialized with
/// [`DEFAULT_MAX_HANDSHAKES`] and [`DEFAULT_HANDSHAKE_TIMEOUT`].
pub fn builder<T>(tls: T) -> Builder<T> {
    let max_handshakes = DEFAULT_MAX_HANDSHAKES;
    let handshake_timeout = DEFAULT_HANDSHAKE_TIMEOUT;
    Builder {
        tls,
        max_handshakes,
        handshake_timeout,
    }
}
#[cfg(feature = "tokio-net")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
impl AsyncAccept for tokio::net::TcpListener {
    type Connection = tokio::net::TcpStream;
    type Error = io::Error;

    /// Forwards to the inherent `TcpListener::poll_accept`, keeping only the
    /// accepted stream.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Connection, Self::Error>> {
        // The inherent method yields a `(stream, address)` pair; this trait
        // only has room for the stream, so the second element is discarded.
        (*self)
            .poll_accept(cx)
            .map(|accepted| accepted.map(|(stream, _peer)| stream))
    }
}
#[cfg(all(unix, feature = "tokio-net"))]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
impl AsyncAccept for tokio::net::UnixListener {
    type Connection = tokio::net::UnixStream;
    type Error = io::Error;

    /// Forwards to the inherent `UnixListener::poll_accept`, keeping only the
    /// accepted stream.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Connection, Self::Error>> {
        // The inherent method yields a `(stream, address)` pair; this trait
        // only has room for the stream, so the second element is discarded.
        (*self)
            .poll_accept(cx)
            .map(|accepted| accepted.map(|(stream, _peer)| stream))
    }
}
#[cfg(any(feature = "hyper-h1", feature = "hyper-h2"))]
mod hyper_impl {
    //! Glue between this crate and hyper: lets hyper's `AddrIncoming` act as a
    //! connection source, and lets a `TlsListener` act as a hyper acceptor.

    use super::*;
    use hyper::server::accept::Accept as HyperAccept;
    use hyper::server::conn::{AddrIncoming, AddrStream};

    #[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
    impl AsyncAccept for AddrIncoming {
        type Connection = AddrStream;
        type Error = io::Error;

        /// Forwards to hyper's `Accept::poll_accept` for `AddrIncoming`.
        ///
        /// # Panics
        ///
        /// Panics if `AddrIncoming` reports the end of the stream (`None`),
        /// which this implementation treats as impossible.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::Connection, Self::Error>> {
            let polled = <AddrIncoming as HyperAccept>::poll_accept(self, cx);
            polled.map(|next| match next {
                Some(res) => res,
                None => unreachable!("None returned from AddrIncoming"),
            })
        }
    }

    #[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
    impl<A, T> HyperAccept for TlsListener<A, T>
    where
        A: AsyncAccept,
        A::Connection: AsyncRead + AsyncWrite + Unpin + 'static,
        A::Error: std::error::Error,
        T: AsyncTls<A::Connection>,
    {
        type Conn = T::Stream;
        type Error = Error<A::Error, T::Error>;

        /// Hyper's acceptor interface has the same shape as `Stream`, so this
        /// simply delegates to the `Stream` implementation.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            <Self as Stream>::poll_next(self, cx)
        }
    }
}