use std::convert::Infallible;
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::Request;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;
use crate::admission::{Admission, IngressLimits};
use crate::handler::IngressHandler;
use crate::http_io::{serve_request, ConnInfo};
/// How long the shutdown path waits, after signalling connections to drain,
/// for every open connection to finish before returning anyway.
pub const DRAIN_DEADLINE: Duration = Duration::from_millis(30_000);
/// Serves plaintext HTTP on `listener` with the default [`IngressLimits`].
///
/// There is no shutdown trigger; this returns only if accepting from
/// `listener` fails.
pub async fn serve<H: IngressHandler>(
    listener: TcpListener,
    handler: Arc<H>,
) -> std::io::Result<()> {
    let limits = IngressLimits::default();
    serve_with_limits(listener, handler, limits).await
}
/// Serves plaintext HTTP on `listener` using the caller-supplied `limits`.
///
/// There is no shutdown trigger; this returns only if accepting from
/// `listener` fails.
pub async fn serve_with_limits<H: IngressHandler>(
    listener: TcpListener,
    handler: Arc<H>,
    limits: IngressLimits,
) -> std::io::Result<()> {
    let no_shutdown = never();
    run(listener, handler, limits, Mode::Plain, no_shutdown).await
}
/// Serves plaintext HTTP on `listener` with the default [`IngressLimits`]
/// until `shutdown` completes.
///
/// Once `shutdown` resolves, open connections are signalled to shut down
/// gracefully and this waits up to [`DRAIN_DEADLINE`] for them to finish.
pub async fn serve_with_shutdown<H: IngressHandler>(
    listener: TcpListener,
    handler: Arc<H>,
    shutdown: impl Future<Output = ()>,
) -> std::io::Result<()> {
    let limits = IngressLimits::default();
    run(listener, handler, limits, Mode::Plain, shutdown).await
}
/// Serves HTTPS on `listener` with the default [`IngressLimits`].
///
/// TLS is terminated with the config from `provider.server_config()`. There is
/// no shutdown trigger; this returns only if accepting from `listener` fails.
pub async fn serve_tls<H, P>(
    listener: TcpListener,
    provider: Arc<P>,
    handler: Arc<H>,
) -> std::io::Result<()>
where
    H: IngressHandler,
    P: crate::tls::CryptoProvider,
{
    let limits = IngressLimits::default();
    serve_tls_with_limits(listener, provider, handler, limits).await
}
/// Serves HTTPS on `listener` using the caller-supplied `limits`.
///
/// TLS is terminated with the config from `provider.server_config()`. There is
/// no shutdown trigger; this returns only if accepting from `listener` fails.
pub async fn serve_tls_with_limits<H, P>(
    listener: TcpListener,
    provider: Arc<P>,
    handler: Arc<H>,
    limits: IngressLimits,
) -> std::io::Result<()>
where
    H: IngressHandler,
    P: crate::tls::CryptoProvider,
{
    let tls_mode = Mode::Tls(tokio_rustls::TlsAcceptor::from(provider.server_config()));
    run(listener, handler, limits, tls_mode, never()).await
}
/// Serves HTTPS on `listener` with the default [`IngressLimits`] until
/// `shutdown` completes.
///
/// TLS is terminated with the config from `provider.server_config()`. Once
/// `shutdown` resolves, open connections are signalled to shut down gracefully
/// and this waits up to [`DRAIN_DEADLINE`] for them to finish.
pub async fn serve_tls_with_shutdown<H, P>(
    listener: TcpListener,
    provider: Arc<P>,
    handler: Arc<H>,
    shutdown: impl Future<Output = ()>,
) -> std::io::Result<()>
where
    H: IngressHandler,
    P: crate::tls::CryptoProvider,
{
    let tls_mode = Mode::Tls(tokio_rustls::TlsAcceptor::from(provider.server_config()));
    let limits = IngressLimits::default();
    run(listener, handler, limits, tls_mode, shutdown).await
}
/// Transport used for each accepted TCP connection.
enum Mode {
    /// Serve HTTP directly over the raw TCP stream.
    Plain,
    /// Run a TLS handshake with this acceptor first, then serve HTTP over the
    /// resulting encrypted stream.
    Tls(tokio_rustls::TlsAcceptor),
}
/// A future that never resolves.
///
/// Passed as the shutdown signal by the `serve*` variants that have no
/// shutdown trigger.
fn never() -> impl Future<Output = ()> {
    std::future::pending::<()>()
}
/// Accept loop shared by every public `serve*` entry point.
///
/// Accepts connections on `listener` and hands each one to `spawn_conn`. A new
/// connection is dropped immediately if `limits.max_connections` are already
/// active. Accept errors that only concern the single connection being
/// accepted (refused, aborted, reset) are skipped. Any other accept error ends
/// the loop.
///
/// However the loop ends (`shutdown` completing or a fatal accept error), the
/// listening socket is closed and open connections are signalled to drain.
/// This then waits up to [`DRAIN_DEADLINE`] for them to finish.
///
/// # Errors
///
/// Returns the first non-connection-level error from `TcpListener::accept`,
/// after the drain has run.
async fn run<H: IngressHandler>(
    listener: TcpListener,
    handler: Arc<H>,
    limits: IngressLimits,
    mode: Mode,
    shutdown: impl Future<Output = ()>,
) -> std::io::Result<()> {
    let admission = Arc::new(Admission::new(limits.inflight_ceiling));
    let (drain_tx, drain_rx) = watch::channel(false);
    let active = Arc::new(AtomicUsize::new(0));
    tokio::pin!(shutdown);
    let outcome = loop {
        tokio::select! {
            accepted = listener.accept() => {
                let stream = match accepted {
                    Ok((stream, _peer)) => stream,
                    // The failure belongs to one would-be connection; the
                    // listener itself is still usable.
                    Err(err) if is_connection_error(&err) => continue,
                    Err(err) => break Err(err),
                };
                if active.load(Ordering::Acquire) >= limits.max_connections {
                    // At capacity: shed the new connection right away.
                    drop(stream);
                    continue;
                }
                spawn_conn(stream, &mode, &handler, &admission, limits, &active, &drain_rx);
            }
            () = &mut shutdown => break Ok(()),
        }
    };
    // Close the listening socket before draining. New clients are then refused
    // promptly instead of waiting in the accept backlog for the whole drain.
    drop(listener);
    // `drain_rx` is still alive in this scope, so the send has a receiver.
    let _ = drain_tx.send(true);
    await_drain(&active, DRAIN_DEADLINE).await;
    outcome
}

/// Returns `true` for `accept()` errors that affect only the connection being
/// accepted, as opposed to the listening socket.
fn is_connection_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}
fn spawn_conn<H: IngressHandler>(
stream: TcpStream,
mode: &Mode,
handler: &Arc<H>,
admission: &Arc<Admission>,
limits: IngressLimits,
active: &Arc<AtomicUsize>,
drain_rx: &watch::Receiver<bool>,
) {
let _ = stream.set_nodelay(true);
active.fetch_add(1, Ordering::Relaxed);
let guard = ActiveGuard(Arc::clone(active));
let handler = Arc::clone(handler);
let admission = Arc::clone(admission);
let drain_rx = drain_rx.clone();
match mode {
Mode::Plain => {
tokio::spawn(async move {
let _guard = guard;
serve_connection(
TokioIo::new(stream),
handler,
ConnInfo::default(),
limits,
admission,
drain_rx,
)
.await;
});
}
Mode::Tls(acceptor) => {
let acceptor = acceptor.clone();
tokio::spawn(async move {
let _guard = guard;
if let Ok(tls) = acceptor.accept(stream).await {
let conn_info = conn_info_from_tls(&tls);
serve_connection(
TokioIo::new(tls),
handler,
conn_info,
limits,
admission,
drain_rx,
)
.await;
}
});
}
}
}
/// Decrements the shared active-connection counter when dropped.
///
/// The caller performs the matching increment before constructing the guard.
struct ActiveGuard(Arc<AtomicUsize>);

impl Drop for ActiveGuard {
    fn drop(&mut self) {
        let ActiveGuard(counter) = self;
        counter.fetch_sub(1, Ordering::Release);
    }
}
/// Waits until `active` reaches zero, giving up once `deadline` has elapsed.
///
/// The counter is polled every 20 ms. The caller is not told whether the count
/// drained or the deadline expired.
async fn await_drain(active: &AtomicUsize, deadline: Duration) {
    let poll_interval = Duration::from_millis(20);
    let wait_for_zero = async {
        loop {
            if active.load(Ordering::Acquire) == 0 {
                break;
            }
            tokio::time::sleep(poll_interval).await;
        }
    };
    // Timing out is not treated as an error. This function does not cancel
    // any connections that are still open.
    let _ = tokio::time::timeout(deadline, wait_for_zero).await;
}
/// Builds the per-connection [`ConnInfo`] for an established TLS session.
///
/// The session is always marked `secure`. The client certificate subject comes
/// from `crate::tls::client_subject_from_tls`.
fn conn_info_from_tls(tls: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>) -> ConnInfo {
    let client_cert_subject = crate::tls::client_subject_from_tls(tls);
    ConnInfo {
        client_cert_subject,
        secure: true,
    }
}
/// Drives one HTTP connection until it closes or a drain signal arrives.
///
/// The protocol is chosen by hyper-util's `auto::Builder`. Every request is
/// dispatched through `serve_request` with the shared `handler`, this
/// connection's `conn_info`, `limits`, and `admission`.
///
/// `drain` may change, or its sender may be dropped. Either way, hyper's
/// graceful shutdown is started and the connection is awaited to completion.
async fn serve_connection<H, IO>(
    io: IO,
    handler: Arc<H>,
    conn_info: ConnInfo,
    limits: IngressLimits,
    admission: Arc<Admission>,
    mut drain: watch::Receiver<bool>,
) where
    H: IngressHandler,
    IO: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
{
    let service = service_fn(move |req: Request<Incoming>| {
        let per_request_handler = Arc::clone(&handler);
        let per_request_info = conn_info.clone();
        let per_request_admission = Arc::clone(&admission);
        async move {
            let response = serve_request(
                &*per_request_handler,
                req,
                &per_request_info,
                limits,
                &per_request_admission,
            )
            .await;
            Ok::<_, Infallible>(response)
        }
    });
    // The connection borrows the builder, so the builder needs its own binding.
    let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
    let connection = builder.serve_connection(io, service);
    tokio::pin!(connection);
    let drain_requested = tokio::select! {
        _ = connection.as_mut() => false,
        _ = drain.changed() => true,
    };
    if drain_requested {
        // Ask hyper to stop taking new requests, then let the connection finish.
        connection.as_mut().graceful_shutdown();
        let _ = connection.await;
    }
}