use hyper::server::conn::*;
use hyper_util::rt::TokioTimer;
use std::mem;
use std::process::ExitCode;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio::task::{JoinSet, coop};
use tokio::time::timeout;
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
#[cfg(any(feature = "native-tls", feature = "rustls-23"))]
use hyper_util::rt::TokioExecutor;
use super::{IoWithPermit, tls};
use crate::app::ServiceAdapter;
use crate::error::ServerError;
/// Cloneable handle used to request, and wait for, server shutdown.
///
/// Wraps a `CancellationToken`; the derived `Clone` clones that token, and
/// cloned cancellation tokens share state, so triggering any clone is
/// observed by all of them.
///
/// NOTE(review): the name suggests an initialization signal, but in this file
/// it is only used to signal shutdown — consider renaming (e.g. `ShutdownToken`).
#[derive(Clone)]
struct InitializationToken(CancellationToken);
/// Debug-only diagnostic output.
///
/// Accepts the same arguments as `eprintln!`. The arguments are always
/// type-checked, but the message is written to stderr only when
/// `debug_assertions` is enabled. The expansion is a unit-typed expression,
/// so it can be used both as a statement and as a `match` arm.
macro_rules! log {
    ($($tokens:tt)*) => {
        if cfg!(debug_assertions) {
            ::std::eprintln!($($tokens)*);
        }
    };
}
pub(super) async fn accept<App, TlsAcceptor>(
acceptor: TlsAcceptor,
listener: TcpListener,
service: ServiceAdapter<App>,
) -> ExitCode
where
App: Send + Sync + 'static,
ServerError: From<TlsAcceptor::Error>,
TlsAcceptor: tls::Acceptor,
TlsAcceptor::Io: Send + Unpin + 'static,
{
#[cfg(not(any(feature = "native-tls", feature = "rustls-23")))]
drop(acceptor);
let semaphore = Arc::new(Semaphore::new(service.config().max_connections()));
let mut connections = JoinSet::new();
let shutdown = wait_for_ctrl_c();
let exit_code = loop {
let (io, _) = tokio::select! {
biased;
result = listener.accept() => match result {
Ok(stream) => stream,
Err(error) => {
log!("error(accept): {}", error);
#[cfg(unix)]
let Some(12 | 23 | 24) = error.raw_os_error() else {
continue;
};
#[cfg(windows)]
let Some(10024 | 10055) = error.raw_os_error() else {
continue;
};
break ExitCode::FAILURE;
}
},
_ = shutdown.requested() => {
break ExitCode::SUCCESS;
}
};
let Ok(permit) = semaphore.clone().try_acquire_owned() else {
continue;
};
let service = service.clone();
let shutdown = shutdown.clone();
#[cfg(any(feature = "native-tls", feature = "rustls-23"))]
connections.spawn({
let handshake = acceptor.accept(io);
async move {
let (io, alpn) =
timeout(service.config().tls_handshake_timeout(), handshake).await??;
let io = IoWithPermit::new(io, permit);
if alpn == tls::Alpn::HTTP_2 {
serve_http2_connection(io, service, shutdown).await
} else {
serve_http1_connection(io, service, shutdown).await
}
}
});
#[cfg(not(any(feature = "native-tls", feature = "rustls-23")))]
connections.spawn(async move {
serve_http1_connection(IoWithPermit::new(io, permit), service, shutdown).await
});
if connections.len() >= 1024 {
let batch = mem::take(&mut connections);
tokio::spawn(drain_connections(false, batch));
}
};
match timeout(
service.config().shutdown_timeout(),
drain_connections(true, connections),
)
.await
{
Ok(_) => exit_code,
Err(_) => ExitCode::FAILURE,
}
}
async fn drain_connections(immediate: bool, mut connections: JoinSet<Result<(), ServerError>>) {
log!("joining {} inflight connections...", connections.len());
while let Some(result) = connections.join_next().await {
match result {
Ok(Ok(_)) => {}
Err(error) => log!("error(connection): {}", error),
Ok(Err(error)) => log!("error(service): {}", error),
}
if !immediate {
coop::consume_budget().await;
}
}
}
/// Serves a single HTTP/1 connection, with upgrades enabled, until it completes.
///
/// If `shutdown` fires first, hyper's `graceful_shutdown` is started and the
/// connection is still driven to completion before returning. Connection
/// errors are converted into [`ServerError`].
async fn serve_http1_connection<App, Io>(
    io: IoWithPermit<Io>,
    service: ServiceAdapter<App>,
    shutdown: InitializationToken,
) -> Result<(), ServerError>
where
    App: Send + Sync + 'static,
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // Only these two settings come from the service configuration; every
    // other option is pinned to an explicit value below.
    let keep_alive = service.config().keep_alive();
    let max_buf_size = service.config().max_buf_size();

    let mut builder = http1::Builder::new();
    builder
        .allow_multiple_spaces_in_request_line_delimiters(false)
        .auto_date_header(true)
        .half_close(false)
        .ignore_invalid_headers(false)
        .keep_alive(keep_alive)
        .max_buf_size(max_buf_size)
        .pipeline_flush(false)
        .preserve_header_case(false)
        .timer(TokioTimer::new())
        .title_case_headers(false);

    let connection = builder.serve_connection(io, service).with_upgrades();
    tokio::pin!(connection);

    let outcome = tokio::select! {
        biased;
        finished = connection.as_mut() => finished,
        _ = shutdown.requested() => {
            // Ask hyper to wind the connection down, then keep polling it
            // until that finishes.
            connection.as_mut().graceful_shutdown();
            connection.await
        }
    };
    outcome?;
    Ok(())
}
/// Serves a single HTTP/2 connection until it completes.
///
/// Only compiled when a TLS backend is enabled, because HTTP/2 is selected
/// through ALPN. If `shutdown` fires first, hyper's `graceful_shutdown` is
/// started and the connection is still driven to completion before returning.
/// Connection errors are converted into [`ServerError`].
#[cfg(any(feature = "native-tls", feature = "rustls-23"))]
async fn serve_http2_connection<App, Io>(
    io: IoWithPermit<Io>,
    service: ServiceAdapter<App>,
    shutdown: InitializationToken,
) -> Result<(), ServerError>
where
    App: Send + Sync + 'static,
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // Fixed HTTP/2 limits, in bytes.
    const MAX_HEADER_LIST_SIZE: u32 = 16 * 1024;
    const CONNECTION_WINDOW_SIZE: u32 = 1024 * 1024;
    const STREAM_WINDOW_SIZE: u32 = 64 * 1024;
    const MAX_FRAME_SIZE: u32 = 16 * 1024;

    let max_concurrent_streams = service.config().http2_max_concurrent_streams();
    let max_send_buf_size = service.config().http2_max_send_buf_size();

    let mut builder = http2::Builder::new(TokioExecutor::new());
    builder
        .adaptive_window(false)
        .auto_date_header(true)
        .max_header_list_size(MAX_HEADER_LIST_SIZE)
        .initial_connection_window_size(Some(CONNECTION_WINDOW_SIZE))
        .initial_stream_window_size(Some(STREAM_WINDOW_SIZE))
        .max_frame_size(Some(MAX_FRAME_SIZE))
        .max_concurrent_streams(max_concurrent_streams)
        .max_send_buf_size(max_send_buf_size)
        .timer(TokioTimer::new());

    let connection = builder.serve_connection(io, service);
    tokio::pin!(connection);

    let outcome = tokio::select! {
        biased;
        finished = connection.as_mut() => finished,
        _ = shutdown.requested() => {
            // Ask hyper to wind the connection down, then keep polling it
            // until that finishes.
            connection.as_mut().graceful_shutdown();
            connection.await
        }
    };
    outcome?;
    Ok(())
}
/// Creates the shutdown token and spawns a background task that triggers it
/// when ctrl-c is received.
///
/// If the ctrl-c handler cannot be registered, the error is printed and the
/// token is left untouched. The server keeps running instead of shutting down
/// immediately after startup and reporting success.
///
/// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
fn wait_for_ctrl_c() -> InitializationToken {
    let token = InitializationToken::new();
    let shutdown = token.clone();
    tokio::spawn(async move {
        match tokio::signal::ctrl_c().await {
            Ok(()) => shutdown.start(),
            // A registration failure is not a shutdown request.
            Err(error) => eprintln!("unable to register the 'ctrl-c' signal: {}", error),
        }
    });
    token
}
impl InitializationToken {
fn new() -> Self {
Self(CancellationToken::new())
}
fn requested(&self) -> WaitForCancellationFuture<'_> {
self.0.cancelled()
}
fn start(&self) {
self.0.cancel();
}
}