use std::{
net::SocketAddr,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use http::{Request, Response};
use hyper::body::{Body, Incoming};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::{conn::auto::Builder, graceful::GracefulShutdown},
service::TowerToHyperService,
};
use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::Notify,
};
use tower::{MakeService, Service};
use crate::{Accept, IntoAccept};
#[cfg(feature = "systemd")]
use crate::accept::systemd;
#[cfg(feature = "https-upgrade")]
use crate::accept::https_upgrade;
/// Unwraps an `Ok` value, or skips to the next iteration of the enclosing loop
/// on `Err` (the error value is discarded).
///
/// Must be used inside a `loop`/`for`/`while` body, since it expands to `continue`.
macro_rules! ok_or_continue {
    ($result:expr) => {{
        let ::core::result::Result::Ok(value) = $result else {
            continue;
        };
        value
    }};
}
#[derive(Debug, Default, Clone)]
pub struct Shutdown(Arc<Notify>);
impl Shutdown {
#[inline]
pub fn new() -> Self {
Self(Arc::new(Notify::new()))
}
#[inline]
pub fn notify(&self) {
self.0.notify_waiters();
}
}
/// HTTP server configuration: the address to listen on, an optional
/// [`Shutdown`] handle, and an acceptor applied to accepted connections.
///
/// The acceptor type `U` defaults to `()`, which is what [`Server::new`] sets
/// until another acceptor is supplied.
#[derive(Debug, Clone)]
pub struct Server<A, U = ()> {
    /// Address to bind; resolved via `ToSocketAddrs` when serving.
    socket_addr: A,
    /// Handle that stops the accept loop; `None` means the loop never ends.
    shutdown: Option<Shutdown>,
    /// Acceptor value, converted with `IntoAccept::into_accept` when serving.
    acceptor: U,
}

impl<A> Server<A> {
    /// Creates a server for `socket_addr` with no shutdown handle and the
    /// unit `()` acceptor.
    #[inline]
    pub const fn new(socket_addr: A) -> Self {
        Server {
            acceptor: (),
            shutdown: None,
            socket_addr,
        }
    }
}
impl<A, U> Server<A, U> {
    /// Replaces the current acceptor with `acceptor`; the previous one is dropped.
    #[inline]
    pub fn with_acceptor<S>(self, acceptor: S) -> Server<A, S> {
        self.map_acceptor(move |_| acceptor)
    }

    /// Transforms the acceptor with `f`, keeping the address and shutdown handle.
    #[inline]
    pub fn map_acceptor<F, S>(self, f: F) -> Server<A, S>
    where
        F: FnOnce(U) -> S,
    {
        let Server {
            socket_addr,
            shutdown,
            acceptor,
        } = self;
        Server {
            socket_addr,
            shutdown,
            acceptor: f(acceptor),
        }
    }

    /// Installs `shutdown` as the handle that stops the accept loop,
    /// replacing any handle set earlier.
    #[inline]
    pub fn with_shutdown(self, shutdown: Shutdown) -> Self {
        Self {
            shutdown: Some(shutdown),
            ..self
        }
    }

    /// Wraps the current acceptor in `systemd::Notify` (requires the
    /// `systemd` feature).
    #[cfg(feature = "systemd")]
    #[inline]
    pub fn with_systemd_notify(self) -> Server<A, systemd::Notify<U>> {
        self.map_acceptor(|inner| systemd::Notify::new(inner))
    }

    /// Wraps the current acceptor in `https_upgrade::HttpsUpgrade` (requires
    /// the `https-upgrade` feature).
    #[cfg(feature = "https-upgrade")]
    #[inline]
    pub fn with_https_upgrade(self) -> Server<A, https_upgrade::HttpsUpgrade<U>> {
        self.map_acceptor(|inner| https_upgrade::HttpsUpgrade::new(inner))
    }
}
impl<A, U> Server<A, U>
where
A: ToSocketAddrs,
{
pub async fn serve<M, K, S, B>(self, mut make_service: M) -> std::io::Result<()>
where
M: MakeService<SocketAddr, Request<Incoming>>,
M::MakeError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
M::Service: Send + 'static,
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
U: IntoAccept<TcpStream, M::Service, Accept = K>,
K: Accept<TcpStream, M::Service, Service = S> + Send + Sync + 'static,
K::Future: Send + 'static,
K::Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
S: Service<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
let socket_addr = tokio::net::lookup_host(self.socket_addr)
.await?
.next()
.ok_or_else(|| std::io::Error::from(std::io::ErrorKind::InvalidInput))?;
let socket = Socket::new(
Domain::for_address(socket_addr),
Type::STREAM,
Some(Protocol::TCP),
)?;
socket.set_reuse_address(true)?;
socket.set_reuse_port(true)?;
socket.set_tcp_nodelay(true)?;
socket.set_nonblocking(true)?;
socket.set_tcp_keepalive(
&TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10))
.with_retries(5),
)?;
socket.bind(&socket_addr.into())?;
socket.listen(1024)?;
let listener = TcpListener::from_std(socket.into())?;
let graceful = GracefulShutdown::new();
let state = Arc::new(State {
inflight: AtomicUsize::new(0),
inflight_notify: Notify::new(),
acceptor: self.acceptor.into_accept().await?,
builder: Builder::new(TokioExecutor::new()),
});
loop {
let (stream, socket_addr) = tokio::select! {
biased;
_ = async {
match &self.shutdown {
Some(shutdown) => shutdown.0.notified().await,
None => ::core::future::pending().await,
}
} => break,
result = listener.accept() => ok_or_continue!(result),
};
::core::future::poll_fn(|cx| make_service.poll_ready(cx))
.await
.map_err(std::io::Error::other)?;
let service = ok_or_continue!(make_service.make_service(socket_addr).await);
let state = Arc::clone(&state);
state.inflight.fetch_add(1, Ordering::Relaxed);
let watcher = graceful.watcher();
tokio::task::spawn(async move {
if let Ok((stream, service)) = state.acceptor.accept(stream, service).await {
let io = TokioIo::new(stream);
let service = TowerToHyperService::new(service);
let _ = watcher
.watch(state.builder.serve_connection_with_upgrades(io, service))
.await;
}
if state.inflight.fetch_sub(1, Ordering::Release) == 1 {
state.inflight_notify.notify_one();
}
});
}
graceful.shutdown().await;
if state.inflight.load(Ordering::Acquire) != 0 {
state.inflight_notify.notified().await;
}
Ok(())
}
}
/// State shared (via `Arc`) between `serve` and every spawned connection task.
struct State<T> {
    /// Number of spawned connection tasks that have not yet finished; incremented
    /// before each spawn and decremented at the end of each task.
    inflight: AtomicUsize,
    /// Signalled with `notify_one` by the task that brings `inflight` to zero.
    inflight_notify: Notify,
    /// Acceptor built by `IntoAccept::into_accept`; applied to each accepted
    /// stream and its service.
    acceptor: T,
    /// hyper-util connection builder used to serve every connection.
    builder: Builder<TokioExecutor>,
}