use std::future::{poll_fn, Future, IntoFuture};
use std::io::Error;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::Poll;
use std::thread;
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TrySendError;
use tokio::task::{JoinError, JoinSet, LocalSet};
use crate::connection::ConnectionInfo;
use crate::server::configuration::ServerConfiguration;
use crate::server::worker::{ConnectionMessage, Worker, WorkerHandle};
use super::{IncomingStream, ShutdownMode};
/// A handle to a running server instance.
///
/// Cloning the handle is cheap: every clone talks to the same acceptor
/// thread through a shared command channel.
#[derive(Clone)]
pub struct ServerHandle {
    // Sending side of the channel used to deliver commands (e.g. shutdown
    // requests) to the acceptor thread.
    command_outbox: tokio::sync::mpsc::Sender<ServerCommand>,
}
impl ServerHandle {
    /// Create a handle for a new server instance.
    ///
    /// As a side effect, this spawns the acceptor thread that owns the
    /// listeners and the worker pool; the returned handle is the only way
    /// to issue commands to that thread.
    pub(super) fn new<HandlerFuture, ApplicationState>(
        config: ServerConfiguration,
        incoming: Vec<IncomingStream>,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
    ) -> Self
    where
        HandlerFuture: Future<Output = crate::response::Response> + 'static,
        ApplicationState: Clone + Send + Sync + 'static,
    {
        let (outbox, inbox) = tokio::sync::mpsc::channel(32);
        // The acceptor runs on its own dedicated thread; we deliberately
        // detach the join handle and keep only the command channel.
        let _ = Acceptor::new(config, incoming, handler, application_state, inbox).spawn();
        Self {
            command_outbox: outbox,
        }
    }

    /// Instruct the server to stop accepting new connections and shut down.
    ///
    /// Resolves once the shutdown has been carried out — or immediately, if
    /// the acceptor has already exited and the command channel is closed.
    #[doc(alias("stop"))]
    pub async fn shutdown(self, mode: ShutdownMode) {
        let (completion_notifier, completion) = tokio::sync::oneshot::channel();
        let command = ServerCommand::Shutdown {
            completion_notifier,
            mode,
        };
        match self.command_outbox.send(command).await {
            // Wait for the acceptor to confirm that the shutdown completed.
            Ok(()) => {
                let _ = completion.await;
            }
            // The acceptor is already gone: nothing left to shut down.
            Err(_) => {}
        }
    }
}
impl IntoFuture for ServerHandle {
type Output = ();
type IntoFuture = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { self.command_outbox.closed().await })
}
}
/// Commands that a [`ServerHandle`] can send to the acceptor thread.
enum ServerCommand {
    /// Ask the acceptor to stop accepting connections and wind down.
    Shutdown {
        // Used by the acceptor to signal back that the shutdown completed.
        completion_notifier: tokio::sync::oneshot::Sender<()>,
        // Graceful (bounded wait for in-flight work) vs. immediate shutdown.
        mode: ShutdownMode,
    },
}
/// The server's accept loop: owns the listeners, accepts incoming TCP
/// connections and dispatches them to a pool of worker threads.
#[must_use]
struct Acceptor<HandlerFuture, ApplicationState> {
    // Receiving side of the command channel; the sender lives in
    // `ServerHandle`.
    command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
    // The listeners to accept connections from.
    incoming: Vec<IncomingStream>,
    // One handle per spawned worker, indexed by worker id.
    worker_handles: Vec<WorkerHandle>,
    #[allow(dead_code)]
    config: ServerConfiguration,
    // Index into `worker_handles`: the worker that will be offered the
    // next connection.
    next_worker: usize,
    // Capacity of each worker's connection queue.
    max_queue_length: usize,
    // The request handler, shared by all workers.
    handler: fn(
        http::Request<hyper::body::Incoming>,
        Option<ConnectionInfo>,
        ApplicationState,
    ) -> HandlerFuture,
    application_state: ApplicationState,
    // Ties the `HandlerFuture` type parameter to the struct without
    // storing a future (fn-pointer variance, no drop obligations).
    handler_output_future: PhantomData<fn() -> HandlerFuture>,
}
/// The two kinds of events the acceptor's event loop reacts to.
enum AcceptorInboxMessage {
    /// A command issued through a `ServerHandle`.
    ServerCommand(ServerCommand),
    /// The outcome of an `accept` task: the listener it came from, the new
    /// connection and the peer address — or a `JoinError` if the task
    /// panicked/was aborted; `None` if the JoinSet was empty.
    Connection(Option<Result<(IncomingStream, TcpStream, SocketAddr), JoinError>>),
}
impl<HandlerFuture, ApplicationState> Acceptor<HandlerFuture, ApplicationState>
where
HandlerFuture: Future<Output = crate::response::Response> + 'static,
ApplicationState: Clone + Send + Sync + 'static,
{
fn new(
config: ServerConfiguration,
incoming: Vec<IncomingStream>,
handler: fn(
http::Request<hyper::body::Incoming>,
Option<ConnectionInfo>,
ApplicationState,
) -> HandlerFuture,
application_state: ApplicationState,
command_inbox: tokio::sync::mpsc::Receiver<ServerCommand>,
) -> Self {
let max_queue_length = 15;
let n_workers = config.n_workers.get();
let mut worker_handles = Vec::with_capacity(n_workers);
for i in 0..n_workers {
let (worker, handle) =
Worker::new(i, max_queue_length, handler, application_state.clone());
worker_handles.push(handle);
worker.spawn().expect("Failed to spawn worker thread");
}
Self {
command_inbox,
incoming,
worker_handles,
config,
max_queue_length,
handler,
handler_output_future: Default::default(),
next_worker: 0,
application_state,
}
}
async fn run(self) {
async fn accept_connection(
incoming: IncomingStream,
) -> (IncomingStream, TcpStream, SocketAddr) {
#[allow(deprecated)]
fn is_rt_shutdown_err(err: &Error) -> bool {
const RT_SHUTDOWN_ERR: &str =
"A Tokio 1.x context was found, but it is being shutdown.";
if err.kind() != std::io::ErrorKind::Other {
return false;
}
let Some(inner) = err.get_ref() else {
return false;
};
inner.source().is_none() && inner.description() == RT_SHUTDOWN_ERR
}
loop {
match incoming.accept().await {
Ok((connection, remote_peer)) => return (incoming, connection, remote_peer),
Err(e) => {
if is_rt_shutdown_err(&e) {
tracing::debug!(error.msg = %e, error.details = ?e, "Failed to accept connection");
} else {
tracing::error!(error.msg = %e, error.details = ?e, "Failed to accept connection");
}
continue;
}
}
}
}
let Self {
mut command_inbox,
mut next_worker,
mut worker_handles,
incoming,
config: _,
max_queue_length,
handler,
application_state,
handler_output_future: _,
} = self;
let n_workers = worker_handles.len();
let mut incoming_join_set = JoinSet::new();
for incoming in incoming.into_iter() {
incoming_join_set.spawn(accept_connection(incoming));
}
let error = 'event_loop: loop {
let message =
poll_fn(|cx| Self::poll_inboxes(cx, &mut command_inbox, &mut incoming_join_set))
.await;
match message {
AcceptorInboxMessage::ServerCommand(command) => match command {
ServerCommand::Shutdown {
completion_notifier,
mode,
} => {
Self::shutdown(
completion_notifier,
mode,
incoming_join_set,
worker_handles,
)
.await;
return;
}
},
AcceptorInboxMessage::Connection(msg) => {
let (incoming, connection, remote_peer) = match msg {
Some(Ok((incoming, connection, remote_peer))) => {
(incoming, connection, remote_peer)
}
Some(Err(e)) => {
break 'event_loop e;
}
None => {
unreachable!(
"The JoinSet for incoming connections cannot ever be empty"
)
}
};
incoming_join_set.spawn(accept_connection(incoming));
let mut has_been_handled = false;
let mut connection_message = ConnectionMessage {
connection,
peer_addr: remote_peer,
};
for _ in 0..n_workers {
let mut has_crashed: Option<usize> = None;
let worker_handle = &worker_handles[next_worker];
if let Err(e) = worker_handle.dispatch(connection_message) {
connection_message = match e {
TrySendError::Full(message) => message,
TrySendError::Closed(conn) => {
has_crashed = Some(worker_handle.id());
conn
}
};
next_worker = (next_worker + 1) % n_workers;
} else {
has_been_handled = true;
break;
}
if let Some(worker_id) = has_crashed {
tracing::warn!(worker_id = worker_id, "Worker crashed, restarting it");
let (worker, worker_handle) = Worker::new(
worker_id,
max_queue_length,
handler,
application_state.clone(),
);
worker.spawn().expect("Failed to spawn worker thread");
worker_handles[worker_id] = worker_handle;
}
}
if !has_been_handled {
tracing::error!(
remote_peer = %remote_peer,
"All workers are busy, dropping connection",
);
}
}
}
};
tracing::error!(
error.msg = %error,
error.details = ?error,
"Failed to accept new connections. The acceptor thread will exit now."
);
}
fn poll_inboxes(
cx: &mut std::task::Context<'_>,
server_command_inbox: &mut tokio::sync::mpsc::Receiver<ServerCommand>,
incoming_join_set: &mut JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
) -> Poll<AcceptorInboxMessage> {
if let Poll::Ready(Some(message)) = server_command_inbox.poll_recv(cx) {
return Poll::Ready(AcceptorInboxMessage::ServerCommand(message));
}
if let Poll::Ready(message) = incoming_join_set.poll_join_next(cx) {
return Poll::Ready(AcceptorInboxMessage::Connection(message));
}
Poll::Pending
}
fn spawn(self) -> thread::JoinHandle<()> {
thread::Builder::new()
.name("pavex-acceptor".to_string())
.spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build single-threaded Tokio runtime for acceptor thread");
LocalSet::new().block_on(&rt, self.run());
})
.expect("Failed to spawn acceptor thread")
}
async fn shutdown(
completion_notifier: tokio::sync::oneshot::Sender<()>,
mode: ShutdownMode,
incoming_join_set: JoinSet<(IncomingStream, TcpStream, SocketAddr)>,
worker_handles: Vec<WorkerHandle>,
) {
drop(incoming_join_set);
let mut shutdown_join_set = JoinSet::new();
for worker_handle in worker_handles {
let mode2 = mode.clone();
let future = worker_handle.shutdown(mode2);
if mode.is_graceful() {
shutdown_join_set.spawn_local(future);
}
}
if let ShutdownMode::Graceful { timeout } = mode {
let _ = tokio::time::timeout(timeout, async move {
while shutdown_join_set.join_next().await.is_some() {}
})
.await;
}
let _ = completion_notifier.send(());
}
}