use std::future::{poll_fn, Future};
use std::net::SocketAddr;
use std::task::Poll;
use std::thread;
use anyhow::Context;
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TrySendError;
use crate::connection::ConnectionInfo;
use crate::server::ShutdownMode;
pub(super) struct ConnectionMessage {
pub(super) connection: TcpStream,
pub(super) peer_addr: SocketAddr,
}
/// The sending side of a worker: used to hand it connections and to ask it to shut down.
pub(super) struct WorkerHandle {
    /// Bounded queue of connections waiting to be served by the worker.
    connection_outbox: tokio::sync::mpsc::Sender<ConnectionMessage>,
    /// Unbounded channel carrying shutdown commands to the worker.
    shutdown_outbox: tokio::sync::mpsc::UnboundedSender<ShutdownWorkerCommand>,
    /// Identifier of the worker this handle talks to.
    id: usize,
}
thread_local! {
    /// Number of connections currently being served on this thread.
    ///
    /// Incremented by [`ConnectionCounterGuard::new`] and decremented when the
    /// guard is dropped. Each thread has its own independent count.
    static LIVE_CONNECTION_COUNTER: std::cell::RefCell<usize> = std::cell::RefCell::new(0);
}

/// RAII guard that counts one live connection in [`LIVE_CONNECTION_COUNTER`]
/// for as long as it is alive.
///
/// The guard is deliberately `!Send`: the counter it updates is thread-local,
/// so dropping it on a different thread than the one that created it would
/// decrement the wrong thread's counter (possibly underflowing it).
#[must_use = "the connection stops being counted as soon as the guard is dropped"]
struct ConnectionCounterGuard {
    // `*const ()` is neither `Send` nor `Sync`, which pins the guard to the
    // thread whose counter it incremented.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl ConnectionCounterGuard {
    /// Register one more live connection on the current thread.
    fn new() -> Self {
        LIVE_CONNECTION_COUNTER.with(|counter| *counter.borrow_mut() += 1);
        Self {
            _not_send: std::marker::PhantomData,
        }
    }
}

impl Drop for ConnectionCounterGuard {
    /// Unregister the connection from the current thread's counter.
    fn drop(&mut self) {
        LIVE_CONNECTION_COUNTER.with(|counter| *counter.borrow_mut() -= 1);
    }
}
impl WorkerHandle {
    /// Queue `connection` for this worker without waiting.
    ///
    /// # Errors
    ///
    /// Hands the message back inside [`TrySendError::Full`] when the worker's
    /// queue is at capacity, or [`TrySendError::Closed`] when the worker no
    /// longer receives connections.
    pub(super) fn dispatch(
        &self,
        connection: ConnectionMessage,
    ) -> Result<(), TrySendError<ConnectionMessage>> {
        let outbox = &self.connection_outbox;
        outbox.try_send(connection)
    }

    /// Identifier of the worker behind this handle.
    pub(super) fn id(&self) -> usize {
        let Self { id, .. } = self;
        *id
    }

    /// Ask the worker to shut down according to `mode`.
    ///
    /// The command is sent eagerly, when this method is called; the returned
    /// future only waits for the worker to acknowledge completion. If the
    /// command could not be delivered, the future resolves immediately.
    pub(super) fn shutdown(self, mode: ShutdownMode) -> impl Future<Output = ()> {
        let (completion_notifier, completion) = tokio::sync::oneshot::channel();
        let command = ShutdownWorkerCommand {
            completion_notifier,
            mode,
        };
        let delivered = match self.shutdown_outbox.send(command) {
            Ok(()) => true,
            Err(_) => false,
        };
        async move {
            if !delivered {
                return;
            }
            // An error only means the worker dropped the notifier without
            // firing it; either way there is nothing left to wait for.
            let _ = completion.await;
        }
    }
}
/// A request, sent through [`WorkerHandle::shutdown`], asking a worker to stop.
pub(super) struct ShutdownWorkerCommand {
    /// Fired by the worker once it has finished handling the shutdown.
    completion_notifier: tokio::sync::oneshot::Sender<()>,
    /// `Graceful` waits (up to a timeout) for in-flight connections; `Forced` stops right away.
    mode: ShutdownMode,
}
/// A worker that serves HTTP connections on its own dedicated thread.
///
/// Connections and shutdown commands reach it through the channels held by
/// the matching [`WorkerHandle`].
#[must_use]
pub(super) struct Worker<HandlerFuture, ApplicationState> {
    /// Incoming connections, fed by [`WorkerHandle::dispatch`].
    connection_inbox: tokio::sync::mpsc::Receiver<ConnectionMessage>,
    /// Shutdown requests, fed by [`WorkerHandle::shutdown`].
    shutdown_inbox: tokio::sync::mpsc::UnboundedReceiver<ShutdownWorkerCommand>,
    /// Invoked once per HTTP request with the request, the peer's connection
    /// info, and a clone of `application_state`.
    handler: fn(
        http::Request<hyper::body::Incoming>,
        Option<ConnectionInfo>,
        ApplicationState,
    ) -> HandlerFuture,
    /// State cloned for every connection, and again for every request.
    application_state: ApplicationState,
    /// Identifier used in the worker thread's name and in log events.
    id: usize,
}
impl<HandlerFuture, ApplicationState> Worker<HandlerFuture, ApplicationState>
where
    HandlerFuture: Future<Output = crate::response::Response> + 'static,
    ApplicationState: Clone + Send + Sync + 'static,
{
    /// Create a worker together with the [`WorkerHandle`] used to feed it.
    ///
    /// The worker does nothing until [`Worker::spawn`] is called.
    /// `max_queue_length` is the capacity of the bounded connection inbox: once
    /// that many connections are queued, [`WorkerHandle::dispatch`] fails with
    /// [`TrySendError::Full`].
    ///
    /// # Panics
    ///
    /// Panics if `max_queue_length` is zero, since Tokio's bounded channels
    /// require a non-zero capacity.
    pub(super) fn new(
        id: usize,
        max_queue_length: usize,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
    ) -> (Self, WorkerHandle) {
        let (connection_outbox, connection_inbox) = tokio::sync::mpsc::channel(max_queue_length);
        let (shutdown_outbox, shutdown_inbox) = tokio::sync::mpsc::unbounded_channel();
        let self_ = Self {
            connection_inbox,
            shutdown_inbox,
            handler,
            application_state,
            id,
        };
        let handle = WorkerHandle {
            connection_outbox,
            shutdown_outbox,
            id,
        };
        (self_, handle)
    }

    /// Move the worker onto a new OS thread named `pavex-worker-{id}`.
    ///
    /// The thread builds a single-threaded Tokio runtime and runs the event
    /// loop inside a [`tokio::task::LocalSet`], which is what lets connections
    /// be served by futures that are not `Send`.
    ///
    /// # Errors
    ///
    /// Returns an error if the thread cannot be spawned.
    ///
    /// # Panics
    ///
    /// The *worker thread* (not the caller) panics if the Tokio runtime cannot
    /// be built; the panic surfaces when the returned `JoinHandle` is joined.
    pub(super) fn spawn(self) -> Result<thread::JoinHandle<()>, anyhow::Error> {
        thread::Builder::new()
            .name(format!("pavex-worker-{}", self.id))
            .spawn(move || {
                let runtime = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to build single-threaded Tokio runtime for worker thread");
                let local = tokio::task::LocalSet::new();
                // Once `run` returns, the runtime and the `LocalSet` are dropped
                // together with any connection task still in flight.
                local.block_on(&runtime, self.run());
            })
            .context("Failed to spawn worker thread")
    }

    /// The worker's event loop.
    ///
    /// Each accepted connection is served as its own task on the thread's
    /// `LocalSet`. The loop ends when:
    ///
    /// - a shutdown command arrives. `Graceful` stops accepting connections,
    ///   serves the ones already queued, then waits (at most `timeout`) for live
    ///   connections to finish; `Forced` exits right away. The command's
    ///   completion notifier is fired in both cases;
    /// - every sender has been dropped without a shutdown command (e.g. the
    ///   [`WorkerHandle`] was dropped). Nobody can reach the worker anymore, so
    ///   it exits as if a `Forced` shutdown had been requested.
    async fn run(self) {
        let Self {
            mut connection_inbox,
            mut shutdown_inbox,
            handler,
            application_state,
            id,
        } = self;
        'event_loop: loop {
            let message =
                poll_fn(|cx| Self::poll_inboxes(cx, &mut shutdown_inbox, &mut connection_inbox))
                    .await;
            match message {
                Some(WorkerInboxMessage::Connection(connection)) => {
                    Self::handle_connection(connection, handler, application_state.clone());
                }
                Some(WorkerInboxMessage::Shutdown(shutdown)) => {
                    let ShutdownWorkerCommand {
                        completion_notifier,
                        mode,
                    } = shutdown;
                    match mode {
                        ShutdownMode::Graceful { timeout } => {
                            // Refuse new connections, but serve those already queued.
                            connection_inbox.close();
                            while let Some(connection) = connection_inbox.recv().await {
                                Self::handle_connection(
                                    connection,
                                    handler,
                                    application_state.clone(),
                                );
                            }
                            // Best effort: if the deadline elapses we shut down anyway.
                            // NOTE(review): live connections are never told to shut down
                            // (hyper's `graceful_shutdown` is not called), so idle
                            // keep-alive connections may hold this wait open until
                            // `timeout` expires.
                            let _ =
                                tokio::time::timeout(timeout, Self::wait_for_live_connections())
                                    .await;
                        }
                        ShutdownMode::Forced => {}
                    }
                    // The requester may have stopped waiting; that is not an error.
                    let _ = completion_notifier.send(());
                    break 'event_loop;
                }
                None => {
                    tracing::warn!(
                        worker_id = id,
                        "Every handle to the worker was dropped without a shutdown command"
                    );
                    break 'event_loop;
                }
            }
        }
        tracing::info!(worker_id = id, "Worker shut down");
    }

    /// Resolve once this thread's live-connection counter has dropped to zero.
    ///
    /// The counter is sampled every 500ms; the first sample is taken immediately.
    async fn wait_for_live_connections() {
        let mut ticker = tokio::time::interval(std::time::Duration::from_millis(500));
        loop {
            ticker.tick().await;
            let no_live_connections =
                LIVE_CONNECTION_COUNTER.with(|counter| *counter.borrow() == 0);
            if no_live_connections {
                return;
            }
        }
    }

    /// Serve `connection_message` on a new task spawned on the current `LocalSet`.
    ///
    /// Every request on the connection is routed to `handler`, together with
    /// the peer address and a fresh clone of `application_state`. The
    /// connection is counted in `LIVE_CONNECTION_COUNTER` for as long as its
    /// task is alive.
    ///
    /// Connection-level failures (a client resetting the socket, malformed
    /// HTTP, ...) are routine for a server and are logged rather than panicking
    /// the task.
    ///
    /// Must be called from within a `LocalSet`: `tokio::task::spawn_local`
    /// panics otherwise.
    fn handle_connection(
        connection_message: ConnectionMessage,
        handler: fn(
            http::Request<hyper::body::Incoming>,
            Option<ConnectionInfo>,
            ApplicationState,
        ) -> HandlerFuture,
        application_state: ApplicationState,
    ) {
        let ConnectionMessage {
            connection,
            peer_addr,
        } = connection_message;
        let service = hyper::service::service_fn(move |request| {
            let state = application_state.clone();
            async move {
                let handler = (handler)(request, Some(ConnectionInfo { peer_addr }), state);
                let response = handler.await;
                let response = hyper::Response::from(response);
                Ok::<_, hyper::Error>(response)
            }
        });
        let connection_counter_guard = ConnectionCounterGuard::new();
        tokio::task::spawn_local(async move {
            let _guard = connection_counter_guard;
            let builder = hyper_util::server::conn::auto::Builder::new(LocalExec);
            let io = TokioIo::new(connection);
            if let Err(error) = builder.serve_connection(io, service).await {
                tracing::warn!(
                    peer_addr = %peer_addr,
                    error = %error,
                    "Failed to serve a connection"
                );
            }
        });
    }

    /// Poll both inboxes, giving shutdown commands priority over connections.
    ///
    /// Resolves to `None` once both channels are closed and drained: in that
    /// state neither receiver registers a waker, so returning `Poll::Pending`
    /// would park the event loop forever.
    fn poll_inboxes(
        cx: &mut std::task::Context<'_>,
        shutdown_inbox: &mut tokio::sync::mpsc::UnboundedReceiver<ShutdownWorkerCommand>,
        connection_inbox: &mut tokio::sync::mpsc::Receiver<ConnectionMessage>,
    ) -> Poll<Option<WorkerInboxMessage>> {
        let shutdown_inbox_closed = match shutdown_inbox.poll_recv(cx) {
            Poll::Ready(Some(command)) => return Poll::Ready(Some(command.into())),
            Poll::Ready(None) => true,
            Poll::Pending => false,
        };
        match connection_inbox.poll_recv(cx) {
            Poll::Ready(Some(connection)) => Poll::Ready(Some(connection.into())),
            Poll::Ready(None) if shutdown_inbox_closed => Poll::Ready(None),
            // At least one inbox is still open and returned `Pending`, so it has
            // registered `cx`'s waker and will wake us up.
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }
}
/// Anything the worker's event loop can receive, across both of its inboxes.
enum WorkerInboxMessage {
    /// A shutdown request coming from the worker's handle.
    Shutdown(ShutdownWorkerCommand),
    /// A newly accepted connection to serve.
    Connection(ConnectionMessage),
}
impl From<ConnectionMessage> for WorkerInboxMessage {
fn from(connection: ConnectionMessage) -> Self {
Self::Connection(connection)
}
}
impl From<ShutdownWorkerCommand> for WorkerInboxMessage {
fn from(command: ShutdownWorkerCommand) -> Self {
Self::Shutdown(command)
}
}
/// Executor handed to hyper that runs its background futures as local tasks.
///
/// Futures are spawned with [`tokio::task::spawn_local`], so they need not be
/// `Send`; in exchange, `execute` must be called from within a
/// [`tokio::task::LocalSet`] (Tokio panics otherwise).
#[derive(Clone, Copy, Debug)]
struct LocalExec;

impl<Fut> hyper::rt::Executor<Fut> for LocalExec
where
    Fut: Future + 'static,
{
    fn execute(&self, fut: Fut) {
        // `execute` returns nothing, so the JoinHandle is dropped and the task detached.
        drop(tokio::task::spawn_local(fut));
    }
}