use hyper::server::conn;
use hyper_util::rt::TokioTimer;
use std::error::Error;
use std::process::ExitCode;
use std::sync::Arc;
use std::{io, mem};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio::task::{JoinSet, coop};
use tokio::{signal, time};
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use super::io::IoWithPermit;
use super::server::ServerConfig;
use super::tls::Acceptor;
use crate::app::AppService;
use crate::error::ServerError;
/// Shared shutdown signal: a thin newtype around [`CancellationToken`].
///
/// Clones share one underlying token, so calling `start` on any clone resolves
/// every `requested` future. The accept loop and each connection task hold a
/// clone and use it to observe shutdown.
/// NOTE(review): despite its name, this type is used purely as a shutdown
/// signal in this module.
#[derive(Clone)]
struct InitializationToken(CancellationToken);
/// Reports the outcome of one joined connection task.
///
/// Expects the value produced by `JoinSet::join_next` for tasks returning
/// `Result<_, ServerError>`:
/// * a join failure is logged as `error(join)`;
/// * a task that returned an error is passed to `handle_error`;
/// * a task that completed successfully is ignored.
macro_rules! joined {
    ($result:expr) => {
        match $result {
            Err(join_error) => log!("error(join): {}", &join_error),
            Ok(Err(task_error)) => handle_error(&task_error),
            Ok(Ok(_)) => {}
        }
    };
}
/// Prints a formatted line to stderr, but only in builds with debug assertions.
///
/// `cfg!` expands to a `true`/`false` literal rather than removing code. The
/// arguments are therefore still type-checked in release builds, where the
/// branch is simply never taken.
macro_rules! log {
    ($($arg:tt)*) => {{
        if ::std::cfg!(debug_assertions) {
            ::std::eprintln!($($arg)*);
        }
    }};
}
#[inline(never)]
pub async fn accept<App, TlsAcceptor>(
config: ServerConfig,
acceptor: TlsAcceptor,
service: AppService<App>,
listener: TcpListener,
) -> ExitCode
where
App: Send + Sync + 'static,
ServerError: From<TlsAcceptor::Error>,
TlsAcceptor: Acceptor,
TlsAcceptor::Io: Send + Unpin + 'static,
{
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
drop(acceptor);
let semaphore = Arc::new(Semaphore::new(config.max_connections));
let mut connections = JoinSet::new();
let shutdown = wait_for_ctrl_c();
let exit_code = loop {
let (io, _) = tokio::select! {
result = listener.accept() => match result {
Err(error) if is_fatal(&error) => return ExitCode::FAILURE,
Ok(accepted) => accepted,
Err(error) => {
log!("error(accept): {}", error);
continue;
}
},
_ = shutdown.requested() => {
break ExitCode::FAILURE;
}
};
let Ok(permit) = semaphore.clone().try_acquire_owned() else {
continue;
};
#[cfg(any(feature = "native-tls", feature = "rustls"))]
let handshake = acceptor.accept(io);
let service = service.clone();
let shutdown = shutdown.clone();
connections.spawn(async move {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
let io = match config.tls_handshake_timeout {
Some(duration) => time::timeout(duration, handshake).await??,
None => handshake.await?,
};
serve_connection(IoWithPermit::new(io, permit), service, shutdown).await
});
if connections.len() >= 1024 {
let batch = mem::take(&mut connections);
tokio::spawn(drain_connections(false, batch));
}
};
match time::timeout(
config.shutdown_timeout,
drain_connections(true, connections),
)
.await
{
Ok(_) => exit_code,
Err(_) => ExitCode::FAILURE,
}
}
/// Joins every task in `connections`, reporting each outcome through `joined!`.
///
/// When `immediate` is `false`, `coop::consume_budget` is awaited after each
/// join. This gives the runtime a chance to reschedule this task once its
/// cooperative budget is exhausted. When `immediate` is `true`, tasks are
/// joined back-to-back.
async fn drain_connections(immediate: bool, mut connections: JoinSet<Result<(), ServerError>>) {
    if cfg!(debug_assertions) {
        let inflight = connections.len();
        println!("joining {} inflight connections...", inflight);
    }

    let cooperative = !immediate;
    while let Some(outcome) = connections.join_next().await {
        joined!(outcome);
        if cooperative {
            coop::consume_budget().await;
        }
    }
}
fn handle_error(error: &ServerError) {
if let ServerError::Http(error) = error {
if error.is_canceled()
|| error.is_incomplete_message()
|| error.source().is_some_and(|source| {
source
.downcast_ref::<std::io::Error>()
.is_some_and(|e| e.kind() == std::io::ErrorKind::NotConnected)
})
{
log!("warn(disconnect): {}", error);
} else {
log!("error(http): {}", error);
}
} else {
log!("error(task): {}", &error);
}
}
/// Returns `true` when an `accept` failure means the process has run out of a
/// resource needed to take new connections. The checked codes are `ENOMEM` (12),
/// `ENFILE` (23) and `EMFILE` (24).
///
/// Only the raw OS error code is inspected. The standard library never reports
/// OS errors as `io::ErrorKind::Other`: `ENOMEM` maps to `OutOfMemory`, and the
/// file-table limits map to an uncategorized kind. Gating on `Other` would
/// therefore never match.
#[cfg(unix)]
fn is_fatal(error: &io::Error) -> bool {
    matches!(error.raw_os_error(), Some(12 | 23 | 24))
}
/// Returns `true` when an `accept` failure means the process has run out of a
/// resource needed to take new connections. The checked codes are `WSAEMFILE`
/// (10024) and `WSAENOBUFS` (10055).
///
/// Only the raw OS error code is inspected. The standard library never reports
/// these Winsock codes as `io::ErrorKind::Other`, so gating on that kind would
/// never match.
#[cfg(windows)]
fn is_fatal(error: &io::Error) -> bool {
    matches!(error.raw_os_error(), Some(10024 | 10055))
}
/// Serves a single HTTP/2 connection until it completes or `shutdown` fires.
///
/// If shutdown is requested first, hyper is asked to begin a graceful
/// shutdown, and the connection keeps being polled until it finishes. Hyper
/// errors are converted into `ServerError` via `?`.
#[cfg(feature = "http2")]
async fn serve_connection<App, Io>(
    io: IoWithPermit<Io>,
    service: AppService<App>,
    shutdown: InitializationToken,
) -> Result<(), ServerError>
where
    App: Send + Sync + 'static,
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let connection = conn::http2::Builder::new(hyper_util::rt::TokioExecutor::new())
        .timer(TokioTimer::new())
        .serve_connection(io, service);
    tokio::pin!(connection);

    tokio::select! {
        outcome = connection.as_mut() => outcome?,
        _ = shutdown.requested() => {
            connection.as_mut().graceful_shutdown();
            connection.as_mut().await?;
        }
    }

    Ok(())
}
/// Serves a single HTTP/1 connection (with upgrade support) until it completes
/// or `shutdown` fires.
///
/// If shutdown is requested first, hyper is asked to begin a graceful
/// shutdown, and the connection keeps being polled until it finishes. Hyper
/// errors are converted into `ServerError` via `?`.
#[cfg(all(feature = "http1", not(feature = "http2")))]
async fn serve_connection<App, Io>(
    io: IoWithPermit<Io>,
    service: AppService<App>,
    shutdown: InitializationToken,
) -> Result<(), ServerError>
where
    App: Send + Sync + 'static,
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let connection = conn::http1::Builder::new()
        .timer(TokioTimer::new())
        .serve_connection(io, service)
        .with_upgrades();
    tokio::pin!(connection);

    tokio::select! {
        outcome = connection.as_mut() => outcome?,
        _ = shutdown.requested() => {
            connection.as_mut().graceful_shutdown();
            connection.as_mut().await?;
        }
    }

    Ok(())
}
/// Returns a shutdown token that a background task triggers on Ctrl-C.
///
/// The task is spawned on the current tokio runtime. If the Ctrl-C handler
/// cannot be registered, a message is printed to stderr and the token is
/// triggered anyway.
/// NOTE(review): in that failure case, anything waiting on the token begins
/// shutting down right away.
fn wait_for_ctrl_c() -> InitializationToken {
    let token = InitializationToken::new();

    tokio::spawn({
        let trigger = token.clone();
        async move {
            match signal::ctrl_c().await {
                Ok(()) => {}
                Err(_) => eprintln!("unable to register the 'ctrl-c' signal."),
            }
            trigger.start();
        }
    });

    token
}
impl InitializationToken {
pub fn new() -> Self {
Self(CancellationToken::new())
}
pub fn requested(&self) -> WaitForCancellationFuture<'_> {
self.0.cancelled()
}
pub fn start(&self) {
self.0.cancel();
}
}