use std::future::{pending, Future};
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use futures::stream::StreamExt;
use http::{Request, Response};
use http_body::Body;
use hyper::body::Incoming;
use hyper::rt::{Read, Write};
use hyper::service::Service;
use hyper_util::rt::TokioIo;
use hyper_util::server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::sleep;
use tokio_rustls::TlsAcceptor;
use tokio_stream::Stream;
use tracing::{debug, trace};
use crate::fuse::Fuse;
use crate::io::Transport;
/// Completes once `wait_for` has elapsed; with `None` it never completes.
///
/// Used as an optional timer inside `tokio::select!`: a `None` duration yields
/// a branch that simply never fires.
#[inline]
async fn sleep_or_pending(wait_for: Option<Duration>) {
    if let Some(wait) = wait_for {
        sleep(wait).await;
    } else {
        pending::<()>().await;
    }
}
/// Performs the server side of a TLS handshake on `io`.
///
/// The handshake is awaited directly on the calling task. `TlsAcceptor::accept`
/// is already asynchronous, so moving it onto the blocking pool (and
/// `block_on`-ing it there) would only pin one blocking thread per in-progress
/// handshake. A client that stalls the handshake could then hold such a thread
/// indefinitely and exhaust the bounded pool.
///
/// # Errors
///
/// Returns the handshake's I/O error converted into [`crate::Error`].
async fn accept_tls_connection<IO>(
    io: IO,
    tls_acceptor: Arc<TlsAcceptor>,
) -> Result<tokio_rustls::server::TlsStream<IO>, crate::Error>
where
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    tls_acceptor.accept(io).await.map_err(Into::into)
}
#[inline]
pub async fn serve_http_connection<B, IO, S, E>(
hyper_io: IO,
hyper_service: S,
builder: HttpConnectionBuilder<E>,
watcher: Option<tokio::sync::watch::Receiver<()>>,
max_connection_age: Option<Duration>,
) where
B: Body + 'static,
B::Error: Into<crate::Error>,
IO: Read + Write + Unpin + Send + 'static,
S: Service<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send,
S::Error: Into<crate::Error>,
E: HttpServerConnExec<S::Future, B>,
{
let mut watcher = watcher;
let mut sig = pin!(Fuse {
inner: watcher.as_mut().map(|w| w.changed()),
});
let sleep = sleep_or_pending(max_connection_age);
tokio::pin!(sleep);
let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service));
loop {
tokio::select! {
result = &mut conn => {
if let Err(err) = result {
debug!("failed serving HTTP connection: {:#}", err);
}
break;
},
_ = &mut sleep => {
conn.as_mut().graceful_shutdown();
sleep.set(sleep_or_pending(None));
},
_ = &mut sig => {
conn.as_mut().graceful_shutdown();
}
}
}
trace!("HTTP connection closed");
}
/// Accepts connections from `incoming` and serves each one on its own task.
///
/// Each accepted stream is wrapped in TLS when `tls_config` is `Some`, then
/// handed to [`serve_http_connection`] with no maximum connection age.
///
/// The accept loop stops when `signal` resolves or when `incoming` is
/// exhausted. If a `signal` was supplied, every live connection is then asked
/// to shut down gracefully, and this function waits until all of them have
/// closed. Without a `signal`, connection tasks keep running detached after
/// this function returns.
///
/// # Errors
///
/// Currently always returns `Ok(())`. Accept, TLS handshake and per-connection
/// HTTP failures are logged rather than propagated.
pub async fn serve_http_with_shutdown<E, F, I, IO, IE, ResBody, S>(
    service: S,
    incoming: I,
    builder: HttpConnectionBuilder<E>,
    tls_config: Option<Arc<rustls::ServerConfig>>,
    signal: Option<F>,
) -> Result<(), super::Error>
where
    F: Future<Output = ()> + Send + 'static,
    I: Stream<Item = Result<IO, IE>> + Send + 'static,
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    IE: Into<crate::Error> + Send + 'static,
    S: Service<Request<Incoming>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    S::Error: Into<crate::Error>,
    ResBody: Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::Error> + Send,
    E: HttpServerConnExec<S::Future, ResBody> + Send + Sync + 'static,
{
    // Each graceful connection holds a clone of `signal_rx`. `closed()` below
    // resolves once every clone has been dropped, i.e. once every connection has
    // finished.
    let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
    let graceful = signal.is_some();
    let mut sig = pin!(Fuse { inner: signal });
    let tls_acceptor = tls_config.map(|config| Arc::new(TlsAcceptor::from(config)));
    let incoming = crate::tcp::serve_tcp_incoming(incoming);
    let mut incoming = pin!(incoming);

    loop {
        tokio::select! {
            _ = &mut sig => {
                trace!("signal received, shutting down");
                break;
            },
            next = incoming.next() => {
                // Handle end-of-stream explicitly. A refutable `Some(..)` select
                // pattern would only disable this branch and leave the loop
                // parked on `sig`, which is forever when no signal was supplied.
                let io = match next {
                    Some(Ok(io)) => io,
                    Some(Err(e)) => {
                        // Accept failures are per-connection; keep serving.
                        trace!("error accepting connection: {:#}", e);
                        continue;
                    }
                    None => {
                        trace!("incoming stream ended, shutting down");
                        break;
                    }
                };
                let connection_service = service.clone();
                let connection_builder = builder.clone();
                let connection_signal_rx = graceful.then(|| signal_rx.clone());
                let connection_tls_acceptor = tls_acceptor.clone();
                tokio::spawn(async move {
                    trace!("TCP streaming connection accepted");
                    let transport = match connection_tls_acceptor {
                        Some(acceptor) => match accept_tls_connection(io, acceptor).await {
                            Ok(tls_stream) => Transport::new_tls(tls_stream),
                            Err(e) => {
                                debug!("TLS handshake failed: {:#}", e);
                                return;
                            }
                        },
                        None => Transport::new_plain(io),
                    };
                    let hyper_io = TokioIo::new(transport);
                    serve_http_connection(
                        hyper_io,
                        connection_service,
                        connection_builder,
                        connection_signal_rx,
                        None,
                    )
                    .await;
                });
            },
        }
    }

    if graceful {
        // Ask every live connection to shut down gracefully, then drop our own
        // receiver so that `closed()` only waits on the connections.
        let _ = signal_tx.send(());
        drop(signal_rx);
        trace!(
            "waiting for {} connections to close",
            signal_tx.receiver_count()
        );
        signal_tx.closed().await;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! End-to-end tests for `serve_http_with_shutdown`, over plain TCP and over TLS.
    //!
    //! Every test binds a fresh listener on an ephemeral `127.0.0.1` port, drives
    //! it with a hyper HTTP/1 client, and triggers shutdown through a `oneshot`
    //! sender. Some tests depend on wall-clock timing (`sleep`s of tens of ms).
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use http_body_util::{BodyExt, Full};
    use hyper::{Request, Response};
    use hyper_util::rt::TokioExecutor;
    use rustls::ServerConfig;
    use std::net::SocketAddr;
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::oneshot;
    use tokio_stream::wrappers::TcpListenerStream;

    /// Routing service shared by every test server.
    ///
    /// - `GET /` returns `Hello, World!`
    /// - `POST /echo` returns the collected request body unchanged
    /// - `GET /delay` waits 100 ms, then returns `Delayed response`
    /// - `GET /large` returns 1 MiB (1024 * 1024 bytes) of `b'x'`
    /// - anything else returns 404 with body `Not Found`
    async fn test_handler(req: Request<Incoming>) -> Result<Response<Full<Bytes>>, hyper::Error> {
        match (req.method(), req.uri().path()) {
            (&hyper::Method::GET, "/") => {
                Ok(Response::new(Full::new(Bytes::from("Hello, World!"))))
            }
            (&hyper::Method::POST, "/echo") => {
                let body = req.collect().await?.to_bytes();
                Ok(Response::new(Full::new(body)))
            }
            (&hyper::Method::GET, "/delay") => {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(Response::new(Full::new(Bytes::from("Delayed response"))))
            }
            (&hyper::Method::GET, "/large") => {
                let large_data = vec![b'x'; 1024 * 1024];
                Ok(Response::new(Full::new(Bytes::from(large_data))))
            }
            _ => {
                let mut res = Response::new(Full::new(Bytes::from("Not Found")));
                *res.status_mut() = StatusCode::NOT_FOUND;
                Ok(res)
            }
        }
    }

    /// Spawns a plaintext server on an ephemeral `127.0.0.1` port.
    ///
    /// Returns the bound address and a sender that triggers graceful shutdown.
    /// Dropping the sender also completes the shutdown signal, because the
    /// receive error is discarded with `.ok()`.
    ///
    /// `_max_conn_age` is currently ignored: `serve_http_with_shutdown` has no
    /// connection-age parameter.
    async fn setup_test_server(
        _max_conn_age: Option<Duration>,
    ) -> (SocketAddr, oneshot::Sender<()>) {
        // Port 0 lets the OS pick a free port; the real one is read back below.
        let addr = SocketAddr::from(([127, 0, 0, 1], 0));
        let listener = TcpListener::bind(addr).await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let incoming = TcpListenerStream::new(listener);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let builder = HttpConnectionBuilder::new(TokioExecutor::new());
        let service = hyper::service::service_fn(test_handler);
        tokio::spawn(serve_http_with_shutdown(
            service,
            incoming,
            builder,
            None,
            Some(async {
                shutdown_rx.await.ok();
            }),
        ));
        (server_addr, shutdown_tx)
    }

    mod payload_tests {
        use super::*;

        /// Downloads a 1 MiB response and then echoes a 1 MiB request body,
        /// both over the same HTTP/1 connection.
        #[tokio::test]
        async fn test_large_payload() {
            let (addr, shutdown_tx) = setup_test_server(None).await;
            let stream = TcpStream::connect(addr).await.unwrap();
            let io = TokioIo::new(stream);
            let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
            // Drive the client side of the connection in the background.
            tokio::spawn(async move {
                if let Err(err) = conn.await {
                    eprintln!("Connection failed: {:?}", err);
                }
            });
            let req = Request::builder()
                .uri("/large")
                .body(Full::new(Bytes::new()))
                .unwrap();
            let res = sender.send_request(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
            let body = res.collect().await.unwrap().to_bytes();
            assert_eq!(body.len(), 1024 * 1024);
            let large_data = vec![b'x'; 1024 * 1024];
            let req = Request::builder()
                .method(hyper::Method::POST)
                .uri("/echo")
                .body(Full::new(Bytes::from(large_data.clone())))
                .unwrap();
            let res = sender.send_request(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
            let body = res.collect().await.unwrap().to_bytes();
            assert_eq!(body.len(), large_data.len());
            shutdown_tx.send(()).unwrap();
        }

        /// Three clients, each on its own connection and task, download
        /// `/large` concurrently.
        #[tokio::test]
        async fn test_concurrent_large_payloads() {
            let (addr, shutdown_tx) = setup_test_server(None).await;
            let mut handles = Vec::new();
            for _ in 0..3 {
                let socket_addr = addr;
                let handle = tokio::spawn(async move {
                    let stream = TcpStream::connect(socket_addr).await.unwrap();
                    let io = TokioIo::new(stream);
                    let (mut sender, conn) =
                        hyper::client::conn::http1::handshake(io).await.unwrap();
                    tokio::spawn(async move {
                        if let Err(err) = conn.await {
                            eprintln!("Connection failed: {:?}", err);
                        }
                    });
                    let req = Request::builder()
                        .uri("/large")
                        .body(Full::new(Bytes::new()))
                        .unwrap();
                    let res = sender.send_request(req).await.unwrap();
                    assert_eq!(res.status(), StatusCode::OK);
                    let body = res.collect().await.unwrap();
                    assert_eq!(body.to_bytes().len(), 1024 * 1024);
                });
                handles.push(handle);
            }
            // Joining each handle propagates any assertion panic from the client tasks.
            for handle in handles {
                handle.await.unwrap();
            }
            shutdown_tx.send(()).unwrap();
        }
    }

    mod shutdown_tests {
        use super::*;

        /// Shutdown is triggered about 50 ms into a request that takes 100 ms
        /// on the server. The in-flight request must still complete
        /// successfully. Timing-dependent.
        #[tokio::test]
        async fn test_graceful_shutdown_with_active_requests() {
            let (addr, shutdown_tx) = setup_test_server(None).await;
            let slow_req = tokio::spawn(async move {
                let stream = TcpStream::connect(addr).await.unwrap();
                let io = TokioIo::new(stream);
                let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
                tokio::spawn(async move {
                    if let Err(err) = conn.await {
                        eprintln!("Connection failed: {:?}", err);
                    }
                });
                let req = Request::builder()
                    .uri("/delay")
                    .body(Full::new(Bytes::new()))
                    .unwrap();
                sender.send_request(req).await
            });
            // Give the request time to reach the handler before shutting down.
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown_tx.send(()).unwrap();
            let res = slow_req.await.unwrap().unwrap();
            assert_eq!(res.status(), StatusCode::OK);
            let body = res.collect().await.unwrap().to_bytes();
            assert_eq!(&body[..], b"Delayed response");
        }

        /// After shutdown with no live connections, new TCP connects must fail.
        /// This expects the server future to have returned and dropped its
        /// listener within 50 ms. Timing-dependent.
        #[tokio::test]
        async fn test_shutdown_rejects_new_connections() {
            let (addr, shutdown_tx) = setup_test_server(None).await;
            shutdown_tx.send(()).unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
            let result = TcpStream::connect(addr).await;
            assert!(result.is_err());
        }
    }

    mod https_tests {
        use super::*;
        use crate::test::helper::RUSTLS;
        use crate::{load_certs, load_private_key};
        use once_cell::sync::Lazy;

        /// Builds a server config from `examples/sample.pem` and
        /// `examples/sample.rsa`, without client authentication.
        /// Panics if the files cannot be loaded.
        async fn setup_test_tls_config() -> Arc<ServerConfig> {
            let certs = load_certs("examples/sample.pem").unwrap();
            let key = load_private_key("examples/sample.rsa").unwrap();
            let config = ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(certs, key)
                .unwrap();
            Arc::new(config)
        }

        /// Builds a client connector that trusts only the certificates in
        /// `examples/sample.pem`, plus the server name `localhost`.
        async fn setup_test_client() -> (
            tokio_rustls::TlsConnector,
            rustls::pki_types::ServerName<'static>,
        ) {
            let mut root_store = rustls::RootCertStore::empty();
            root_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap());
            let client_config = rustls::ClientConfig::builder()
                .with_root_certificates(root_store)
                .with_no_client_auth();
            let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
            let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap();
            (connector, domain)
        }

        /// TLS variant of `setup_test_server`. It also returns the server config,
        /// which the current tests ignore.
        async fn setup_tls_test_server() -> (SocketAddr, oneshot::Sender<()>, Arc<ServerConfig>) {
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));
            let listener = TcpListener::bind(addr).await.unwrap();
            let server_addr = listener.local_addr().unwrap();
            let incoming = TcpListenerStream::new(listener);
            let (shutdown_tx, shutdown_rx) = oneshot::channel();
            let tls_config = setup_test_tls_config().await;
            let builder = HttpConnectionBuilder::new(TokioExecutor::new());
            let service = hyper::service::service_fn(test_handler);
            tokio::spawn(serve_http_with_shutdown(
                service,
                incoming,
                builder,
                Some(tls_config.clone()),
                Some(async {
                    shutdown_rx.await.ok();
                }),
            ));
            (server_addr, shutdown_tx, tls_config)
        }

        /// Opens a TCP connection, then performs the TLS handshake and the
        /// HTTP/1 handshake, and returns the request sender. The connection
        /// driver runs on a spawned task. Panics on any handshake failure.
        async fn connect_tls_client(
            addr: SocketAddr,
            connector: tokio_rustls::TlsConnector,
            domain: rustls::pki_types::ServerName<'static>,
        ) -> hyper::client::conn::http1::SendRequest<Full<Bytes>> {
            let tcp = TcpStream::connect(addr).await.unwrap();
            let tls_stream = connector.connect(domain, tcp).await.unwrap();
            let io = TokioIo::new(tls_stream);
            let (sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
            tokio::spawn(async move {
                if let Err(err) = conn.await {
                    eprintln!("Connection failed: {:?}", err);
                }
            });
            sender
        }

        // Each TLS test forces `RUSTLS` before building any rustls config.
        // NOTE(review): this presumably installs the process-wide crypto
        // provider; confirm in `crate::test::helper`.

        mod tls_connection_tests {
            use super::*;

            #[tokio::test]
            async fn test_tls_basic_request() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let (connector, domain) = setup_test_client().await;
                let mut sender = connect_tls_client(addr, connector, domain).await;
                let req = Request::builder()
                    .uri("/")
                    .body(Full::new(Bytes::new()))
                    .unwrap();
                let res = sender.send_request(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK);
                let body = res.collect().await.unwrap().to_bytes();
                assert_eq!(&body[..], b"Hello, World!");
                shutdown_tx.send(()).unwrap();
            }

            /// Three sequential requests reuse one TLS connection (keep-alive).
            #[tokio::test]
            async fn test_tls_multiple_requests_same_connection() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let (connector, domain) = setup_test_client().await;
                let mut sender = connect_tls_client(addr, connector, domain).await;
                for _ in 0..3 {
                    let req = Request::builder()
                        .uri("/")
                        .body(Full::new(Bytes::new()))
                        .unwrap();
                    let res = sender.send_request(req).await.unwrap();
                    assert_eq!(res.status(), StatusCode::OK);
                    let body = res.collect().await.unwrap().to_bytes();
                    assert_eq!(&body[..], b"Hello, World!");
                }
                shutdown_tx.send(()).unwrap();
            }

            /// Five clients, each with its own TLS connection, make one request each.
            #[tokio::test]
            async fn test_tls_concurrent_connections() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let mut handles = Vec::new();
                for _ in 0..5 {
                    let socket_addr = addr;
                    let handle = tokio::spawn(async move {
                        let (connector, domain) = setup_test_client().await;
                        let mut sender = connect_tls_client(socket_addr, connector, domain).await;
                        let req = Request::builder()
                            .uri("/")
                            .body(Full::new(Bytes::new()))
                            .unwrap();
                        let res = sender.send_request(req).await.unwrap();
                        assert_eq!(res.status(), StatusCode::OK);
                        let body = res.collect().await.unwrap().to_bytes();
                        assert_eq!(&body[..], b"Hello, World!");
                    });
                    handles.push(handle);
                }
                for handle in handles {
                    handle.await.unwrap();
                }
                shutdown_tx.send(()).unwrap();
            }
        }

        mod tls_error_tests {
            use super::*;

            /// Despite its name, this test involves no client certificate. The
            /// client uses an empty root store, so it cannot verify the server's
            /// certificate, and the handshake must fail.
            #[tokio::test]
            async fn test_invalid_client_cert() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let client_config = rustls::ClientConfig::builder()
                    .with_root_certificates(rustls::RootCertStore::empty())
                    .with_no_client_auth();
                let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
                let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap();
                let tcp = TcpStream::connect(addr).await.unwrap();
                let result = connector.connect(domain, tcp).await;
                assert!(result.is_err());
                shutdown_tx.send(()).unwrap();
            }

            /// The client trusts the sample certificate but connects with server
            /// name `wronghost`, and the handshake must fail. This presumes the
            /// sample certificate does not cover `wronghost`.
            #[tokio::test]
            async fn test_wrong_hostname() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let (connector, _) = setup_test_client().await;
                let wrong_domain = rustls::pki_types::ServerName::try_from("wronghost").unwrap();
                let tcp = TcpStream::connect(addr).await.unwrap();
                let result = connector.connect(wrong_domain, tcp).await;
                assert!(result.is_err());
                shutdown_tx.send(()).unwrap();
            }
        }

        mod tls_payload_tests {
            use super::*;

            /// TLS counterpart of `test_large_payload`: a 1 MiB download followed
            /// by a 1 MiB echo on one connection.
            #[tokio::test]
            async fn test_tls_large_payload() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let (connector, domain) = setup_test_client().await;
                let mut sender = connect_tls_client(addr, connector, domain).await;
                let req = Request::builder()
                    .uri("/large")
                    .body(Full::new(Bytes::new()))
                    .unwrap();
                let res = sender.send_request(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK);
                let body = res.collect().await.unwrap().to_bytes();
                assert_eq!(body.len(), 1024 * 1024);
                let large_data = vec![b'x'; 1024 * 1024];
                let req = Request::builder()
                    .method(hyper::Method::POST)
                    .uri("/echo")
                    .body(Full::new(Bytes::from(large_data.clone())))
                    .unwrap();
                let res = sender.send_request(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK);
                let body = res.collect().await.unwrap().to_bytes();
                assert_eq!(body.len(), large_data.len());
                shutdown_tx.send(()).unwrap();
            }
        }

        mod tls_shutdown_tests {
            use super::*;

            /// After a completed request, shutdown is triggered. A second request
            /// on the now-idle keep-alive connection, sent about 50 ms later, must
            /// fail because the server has closed it gracefully. Timing-dependent.
            #[tokio::test]
            async fn test_tls_graceful_shutdown() {
                Lazy::force(&RUSTLS);
                let (addr, shutdown_tx, _) = setup_tls_test_server().await;
                let (connector, domain) = setup_test_client().await;
                let mut sender = connect_tls_client(addr, connector, domain).await;
                let req = Request::builder()
                    .uri("/")
                    .body(Full::new(Bytes::new()))
                    .unwrap();
                let res = sender.send_request(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK);
                shutdown_tx.send(()).unwrap();
                tokio::time::sleep(Duration::from_millis(50)).await;
                let req = Request::builder()
                    .uri("/")
                    .body(Full::new(Bytes::new()))
                    .unwrap();
                let result = sender.send_request(req).await;
                assert!(result.is_err());
            }
        }
    }
}