use futures_util::future::BoxFuture;
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use hyper::server::accept::Accept;
use std::pin::Pin;
use std::task::{Context, Poll};
use thiserror::Error;
use crate::conn::{ConnKind, HttpOrHttpsConnection};

/// Accepts either HTTP or HTTPS connections, depending on which constructor was used
pub struct HyperHttpOrHttpsAcceptor {
    listener: tokio::net::TcpListener,
    kind: AcceptorKind,
}

enum AcceptorKind {
    Http,
    Https {
        tls_acceptor: tokio_rustls::TlsAcceptor,
        timeout: std::time::Duration,
        // The handshake future has to be boxed because its type (which includes a closure)
        // can't be written out. Side benefit: we can use `Timeout` without pin projection.
        encryption_futures: FuturesUnordered<
            tokio::time::Timeout<BoxFuture<'static, Result<HttpOrHttpsConnection, AcceptorError>>>,
        >,
    },
}

impl HyperHttpOrHttpsAcceptor {
    /// Create an acceptor that will accept HTTP connections
    pub const fn new_http(listener: tokio::net::TcpListener) -> Self {
        Self {
            listener,
            kind: AcceptorKind::Http,
        }
    }

    /// Create an acceptor that will accept HTTPS connections using the provided `TlsAcceptor`
    ///
    /// `handshake_timeout` is how long each client is given to complete the TLS handshake
    /// before its connection is dropped. Setting it to zero does not disable the timeout;
    /// instead, virtually every connection will be dropped (you probably don't want this).
    pub fn new_https(
        listener: tokio::net::TcpListener,
        tls_acceptor: tokio_rustls::TlsAcceptor,
        handshake_timeout: std::time::Duration,
    ) -> Self {
        Self {
            listener,
            kind: AcceptorKind::Https {
                tls_acceptor,
                timeout: handshake_timeout,
                encryption_futures: FuturesUnordered::new(),
            },
        }
    }
}

/// Error when accepting connections
#[derive(Error, Debug)]
pub enum AcceptorError {
    /// Failed to accept a TCP connection from the client
    #[error("TCP connection to client failed")]
    TcpConnect(#[source] std::io::Error),
    /// Failed to complete the TLS handshake with the client
    #[error("TLS handshake with client failed")]
    TlsHandshake(#[source] std::io::Error),
}
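
// `poll_accept` below receives each finished handshake wrapped in two `Result` layers.
// The outer layer comes from `tokio::time::timeout` (its error means the deadline passed).
// The inner layer carries `AcceptorError::TlsHandshake`. Timed-out handshakes are skipped.
// Completed handshakes are returned whether they succeeded or failed.
// This std-only sketch models that rule under stated assumptions:
// `TimedOut` and `next_outcome` are hypothetical stand-ins, not tokio or crate APIs,
// and a `VecDeque` of already-finished results replaces the `FuturesUnordered` queue.
#[cfg(test)]
mod skip_timeouts_example {
    use std::collections::VecDeque;

    /// Hypothetical stand-in for `tokio::time::error::Elapsed`
    #[derive(Debug)]
    pub struct TimedOut;

    /// Returns the first handshake outcome that finished before its deadline,
    /// discarding any that timed out; `None` once the queue is exhausted
    pub fn next_outcome<T, E>(
        queue: &mut VecDeque<Result<Result<T, E>, TimedOut>>,
    ) -> Option<Result<T, E>> {
        while let Some(finished) = queue.pop_front() {
            match finished {
                // Handshake completed in time: pass its own success or failure along
                Ok(outcome) => return Some(outcome),
                // Deadline passed: drop it and look at the next one
                Err(TimedOut) => continue,
            }
        }
        None
    }

    #[test]
    fn timeouts_are_skipped_but_handshake_errors_are_not() {
        let mut queue: VecDeque<Result<Result<&str, &str>, TimedOut>> = VecDeque::from(vec![
            Err(TimedOut),
            Err(TimedOut),
            Ok(Err("bad handshake")),
            Ok(Ok("client-a")),
        ]);
        assert_eq!(next_outcome(&mut queue), Some(Err("bad handshake")));
        assert_eq!(next_outcome(&mut queue), Some(Ok("client-a")));
        assert_eq!(next_outcome(&mut queue), None);
    }
}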

impl Accept for HyperHttpOrHttpsAcceptor {
    type Conn = HttpOrHttpsConnection;
    type Error = AcceptorError;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // Get a plain `&mut Self` (allowed because every field is `Unpin`) so that
        // `listener` and `kind` can be borrowed separately
        let this = self.get_mut();
        match &mut this.kind {
            // Plain HTTP: just poll the listener for the next TCP connection
            AcceptorKind::Http => match this.listener.poll_accept(cx) {
                Poll::Ready(Ok((stream, remote_addr))) => {
                    Poll::Ready(Some(Ok(HttpOrHttpsConnection {
                        remote_addr,
                        kind: ConnKind::Http(stream),
                    })))
                }
                Poll::Ready(Err(err)) => Poll::Ready(Some(Err(AcceptorError::TcpConnect(err)))),
                Poll::Pending => Poll::Pending,
            },
            // HTTPS: start a TLS handshake for every new TCP connection, then return
            // whichever handshake finishes first
            AcceptorKind::Https {
                tls_acceptor,
                timeout,
                encryption_futures,
            } => {
                // Accept all pending TCP connections at once (this task won't be woken up
                // for new TCP connections unless the listener returns `Pending` here)
                loop {
                    match this.listener.poll_accept(cx) {
                        Poll::Ready(Ok((stream, remote_addr))) => {
                            let tls_future = tls_acceptor
                                .accept(stream)
                                .map(move |res| {
                                    // Map so that we can pass along the remote address
                                    res.map(|conn| HttpOrHttpsConnection {
                                        remote_addr,
                                        kind: ConnKind::Https(conn),
                                    })
                                    .map_err(AcceptorError::TlsHandshake)
                                })
                                .boxed();
                            let timed_tls_future = tokio::time::timeout(*timeout, tls_future);
                            encryption_futures.push(timed_tls_future);
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(Some(Err(AcceptorError::TcpConnect(err))));
                        }
                        // Stop on `Pending` so we can check on the TLS queue
                        Poll::Pending => break,
                    }
                }
                // Check the queue to see whether any handshakes have finished or timed out
                loop {
                    match encryption_futures.poll_next_unpin(cx) {
                        // Already mapped to a `Result<HttpOrHttpsConnection, AcceptorError>`,
                        // so successful and failed handshakes are returned the same way
                        Poll::Ready(Some(Ok(res))) => return Poll::Ready(Some(res)),
                        // `Err(Elapsed)` means the handshake timed out: drop that connection
                        // and move on to the next finished future in the queue
                        Poll::Ready(Some(Err(_))) => continue,
                        // No handshakes in flight (`Ready(None)`) or none finished yet:
                        // the listener (and the queue, if non-empty) has registered the waker
                        Poll::Ready(None) | Poll::Pending => return Poll::Pending,
                    }
                }
            }
        }
    }
}
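
// The drain loop in `poll_accept` relies on the usual poll contract: a source only
// promises to wake the task after it has returned `Poll::Pending`. If the loop stopped
// after one ready connection, a later connection might never wake the task.
// This std-only sketch models that contract. `MockListener` and
// `wakes_after_new_connection` are hypothetical stand-ins, not tokio or crate APIs.
#[cfg(test)]
mod drain_until_pending_example {
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Counts how many times the task would have been woken
    pub struct CountingWaker(pub AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Hypothetical listener: stores the waker only when it has nothing to return
    #[derive(Default)]
    pub struct MockListener {
        queue: VecDeque<u32>,
        waker: Option<Waker>,
    }

    impl MockListener {
        pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<u32> {
            match self.queue.pop_front() {
                Some(conn) => Poll::Ready(conn),
                None => {
                    self.waker = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
        }

        /// Simulates a new client connecting; wakes the task only if a waker was stored
        pub fn push(&mut self, conn: u32) {
            self.queue.push_back(conn);
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }
        }
    }

    /// Returns how many wake-ups a newly arriving connection triggers when the caller
    /// either stops after one ready item (`drain == false`) or drains until `Pending`
    pub fn wakes_after_new_connection(drain: bool) -> usize {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(counter.clone());
        let mut cx = Context::from_waker(&waker);

        let mut listener = MockListener::default();
        listener.push(1);
        listener.push(2);

        if drain {
            while listener.poll_accept(&mut cx).is_ready() {}
        } else {
            let _ = listener.poll_accept(&mut cx);
        }

        listener.push(3);
        counter.0.load(Ordering::SeqCst)
    }

    #[test]
    fn only_draining_registers_a_wakeup() {
        assert_eq!(wakes_after_new_connection(false), 0);
        assert_eq!(wakes_after_new_connection(true), 1);
    }
}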