#![cfg_attr(docsrs, feature(doc_cfg))]
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::str::FromStr;
use std::time::Duration;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use socket2::Domain;
use socket2::Protocol;
use socket2::Socket;
use socket2::Type;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::router::Router;
use tokio::net::TcpListener;
use tokio::runtime::Builder;
use tokio::task::LocalSet;
/// Tuning knobs for the thread-per-core server entry points.
///
/// Every worker thread runs its own runtime and binds its own listener on the
/// same address (with `SO_REUSEPORT` set on Unix).
#[derive(Debug, Clone)]
pub struct PerThreadConfig {
    /// Number of worker threads to spawn; each one binds its own listener.
    pub workers: usize,
    /// Pin worker `i` to the `i`-th core reported by `core_affinity`.
    /// Only honoured when the crate is built with the `affinity` feature.
    pub pin_to_core: bool,
    /// Backlog passed to `listen` for every worker's socket.
    pub backlog: i32,
    /// Upper bound on how long each worker waits for in-flight connections
    /// to finish after shutdown has been signalled.
    pub drain_timeout: Duration,
}
impl Default for PerThreadConfig {
    /// One worker per available CPU, core pinning enabled only when the
    /// `affinity` feature is compiled in, a backlog of 1024 and a 30 second
    /// drain window.
    fn default() -> Self {
        let workers = num_cpus();
        let pin_to_core = cfg!(feature = "affinity");
        let backlog = 1024;
        let drain_timeout = Duration::from_secs(30);
        Self {
            workers,
            pin_to_core,
            backlog,
            drain_timeout,
        }
    }
}
/// Shared record of per-worker bind results, read by
/// `PerThreadShutdown::wait_for_bind_outcome`.
#[derive(Default)]
struct BindStatus {
    /// Number of workers whose listener bound successfully.
    succeeded: std::sync::atomic::AtomicUsize,
    /// Number of workers that failed before serving (runtime build or bind failure).
    failed: std::sync::atomic::AtomicUsize,
    /// The first failure reported; later failures are only counted.
    first_err: std::sync::Mutex<Option<io::Error>>,
    /// Woken with `notify_waiters` after every success or failure report.
    notify: tokio::sync::Notify,
}
/// Cloneable handle shared by the controlling thread and every worker.
///
/// It carries the shutdown signal (a cancellation token) and the bind-outcome
/// bookkeeping. Clones share the same underlying state, so `trigger` on any
/// clone is observed by all workers.
#[derive(Clone, Default)]
pub struct PerThreadShutdown {
    // Cancelled by `trigger`; workers wait on it to stop accepting.
    inner: tokio_util::sync::CancellationToken,
    // Shared so every worker reports into the same counters.
    bind_status: std::sync::Arc<BindStatus>,
}
impl PerThreadShutdown {
    /// Creates a fresh handle: not yet triggered, with no bind results recorded.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Signals every worker sharing this handle to stop accepting and start draining.
    pub fn trigger(&self) {
        self.inner.cancel();
    }

    /// Completes once the shared token has been cancelled via [`Self::trigger`].
    pub async fn notified(&self) {
        self.inner.cancelled().await;
    }

    /// Records that one worker bound its listener and wakes any bind-outcome waiter.
    pub(crate) fn report_bind_success(&self) {
        self.bind_status
            .succeeded
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        self.bind_status.notify.notify_waiters();
    }

    /// Records that one worker failed before serving and wakes any bind-outcome waiter.
    ///
    /// Only the first reported error is kept; subsequent ones are just counted.
    pub(crate) fn report_bind_failure(&self, err: io::Error) {
        {
            // A poisoned lock only means another reporter panicked mid-update;
            // the `Option` inside is still valid, so keep going instead of panicking.
            let mut first = self
                .bind_status
                .first_err
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if first.is_none() {
                *first = Some(err);
            }
        }
        self.bind_status
            .failed
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        self.bind_status.notify.notify_waiters();
    }

    /// Waits until the bind outcome of `total` workers is known.
    ///
    /// Returns `Ok(())` as soon as at least one worker has bound successfully.
    ///
    /// # Errors
    ///
    /// Once `total` workers have reported and none succeeded, returns the first
    /// recorded failure. The stored error is taken, so a repeated call gets a
    /// generic error instead. With `total == 0` and nothing reported, returns an
    /// `InvalidInput` error, because there is nothing to wait for.
    pub async fn wait_for_bind_outcome(&self, total: usize) -> io::Result<()> {
        use std::sync::atomic::Ordering;
        loop {
            // Register interest before reading the counters: `notify_waiters` stores
            // no permit, so a report landing between the loads and the `.await`
            // would otherwise be missed.
            let notified = self.bind_status.notify.notified();
            tokio::pin!(notified);
            notified.as_mut().enable();
            let succeeded = self.bind_status.succeeded.load(Ordering::SeqCst);
            let failed = self.bind_status.failed.load(Ordering::SeqCst);
            if succeeded > 0 {
                return Ok(());
            }
            if succeeded + failed >= total {
                let stored = self
                    .bind_status
                    .first_err
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .take();
                return Err(stored.unwrap_or_else(|| {
                    if total == 0 {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "no per-thread workers to wait for (total = 0)",
                        )
                    } else {
                        io::Error::other(format!("all {total} per-thread workers failed to bind"))
                    }
                }));
            }
            notified.await;
        }
    }
}
/// Number of CPUs the process may use, falling back to 1 if the platform
/// cannot report it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        // The count is unavailable on this platform; assume a single CPU.
        Err(_) => 1,
    }
}
#[cfg(feature = "compio")]
/// Initial delay before retrying after a failed `accept` in the compio worker.
fn compio_accept_backoff() -> Duration {
    // 5 ms; the worker doubles this after each consecutive failure.
    const INITIAL_ACCEPT_BACKOFF: Duration = Duration::from_millis(5);
    INITIAL_ACCEPT_BACKOFF
}
/// Emits, at most once per process, a warning about `SO_REUSEPORT` semantics
/// on platforms other than Linux. On Linux nothing is logged.
fn warn_reuseport_platform_once() {
    static ONCE: std::sync::Once = std::sync::Once::new();
    ONCE.call_once(|| {
        // Linux needs no warning; only the non-Linux branches below emit anything.
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            tracing::warn!(
                "tako-server-pt: SO_REUSEPORT is being used on a non-Linux Unix \
                 platform. The kernel typically sends incoming connections only to \
                 the most recent binder, so multi-worker thread-per-core mode will \
                 not load-balance correctly. Use a single worker or run on Linux."
            );
        }
        #[cfg(windows)]
        {
            tracing::warn!(
                "tako-server-pt: SO_REUSEPORT does not exist on Windows. Only the \
                 first worker will accept connections; subsequent worker binds will \
                 fail with EADDRINUSE. Use a single worker on Windows."
            );
        }
    });
}
/// Creates a non-blocking TCP listener on `addr` with `SO_REUSEADDR` set, plus
/// `SO_REUSEPORT` on Unix, so that several workers can bind the same address.
fn bind_reuseport_std(addr: SocketAddr, backlog: i32) -> io::Result<std::net::TcpListener> {
    warn_reuseport_platform_once();
    let family = match addr {
        SocketAddr::V4(_) => Domain::IPV4,
        SocketAddr::V6(_) => Domain::IPV6,
    };
    let sock = Socket::new(family, Type::STREAM, Some(Protocol::TCP))?;
    // Socket options must be applied before `bind`.
    sock.set_reuse_address(true)?;
    #[cfg(unix)]
    sock.set_reuse_port(true)?;
    // The async runtimes adopting this listener expect it to be non-blocking.
    sock.set_nonblocking(true)?;
    sock.bind(&addr.into())?;
    sock.listen(backlog)?;
    Ok(std::net::TcpListener::from(sock))
}
/// Binds a reuse-port listener and hands it to Tokio. `worker_main` calls this
/// from inside the worker's runtime.
fn bind_reuseport(addr: SocketAddr, backlog: i32) -> io::Result<TcpListener> {
    let std_listener = bind_reuseport_std(addr, backlog)?;
    TcpListener::from_std(std_listener)
}
#[cfg(feature = "compio")]
/// Binds a reuse-port listener and hands it to compio.
fn bind_reuseport_compio(addr: SocketAddr, backlog: i32) -> io::Result<compio::net::TcpListener> {
    let std_listener = bind_reuseport_std(addr, backlog)?;
    compio::net::TcpListener::from_std(std_listener)
}
pub fn serve_per_thread(addr: &str, router: Router, cfg: PerThreadConfig) -> io::Result<()> {
let workers = cfg.workers;
let (handle, shutdown) = spawn_per_thread(addr, router, cfg)?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| io::Error::other(format!("ctrl-c runtime: {e}")))?;
let result: io::Result<()> = rt.block_on(async {
shutdown.wait_for_bind_outcome(workers).await?;
let _ = tokio::signal::ctrl_c().await;
Ok(())
});
shutdown.trigger();
for h in handle {
let _ = h.join();
}
result
}
/// Spawns `cfg.workers` worker threads that each bind `addr` and serve `router`,
/// without blocking the caller.
///
/// The router is leaked (`Box::leak`) to obtain a `'static` reference shared by all
/// workers; it is never freed. Bind outcomes are reported on the returned
/// [`PerThreadShutdown`]:
/// - use `wait_for_bind_outcome(cfg.workers)` to learn whether any worker is listening;
/// - use `trigger()` to stop the workers, then join the returned handles.
///
/// # Errors
///
/// - `InvalidInput` if `addr` is not a socket address or `cfg.workers == 0`.
/// - The OS error if a worker thread cannot be spawned. Workers already started
///   are signalled and joined before returning.
pub fn spawn_per_thread(
    addr: &str,
    router: Router,
    cfg: PerThreadConfig,
) -> io::Result<(Vec<std::thread::JoinHandle<()>>, PerThreadShutdown)> {
    let socket_addr =
        SocketAddr::from_str(addr).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
    if cfg.workers == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "PerThreadConfig::workers must be at least 1",
        ));
    }
    let router: &'static Router = Box::leak(Box::new(router));
    let shutdown = PerThreadShutdown::new();
    let mut handles = Vec::with_capacity(cfg.workers);
    for worker_id in 0..cfg.workers {
        let worker_cfg = cfg.clone();
        let worker_shutdown = shutdown.clone();
        let spawned = std::thread::Builder::new()
            .name(format!("tako-pt-{worker_id}"))
            .spawn(move || {
                worker_main(worker_id, socket_addr, router, worker_cfg, worker_shutdown)
            });
        match spawned {
            Ok(handle) => handles.push(handle),
            Err(e) => {
                // Don't leave a partial fleet running behind a returned error.
                shutdown.trigger();
                for handle in handles {
                    let _ = handle.join();
                }
                return Err(io::Error::new(
                    e.kind(),
                    format!("spawn tako-pt worker {worker_id}: {e}"),
                ));
            }
        }
    }
    Ok((handles, shutdown))
}
/// Body of one Tokio worker thread.
///
/// The steps are:
/// 1. Optionally pin the thread to a core.
/// 2. Build a current-thread runtime and bind a reuse-port listener on `addr`,
///    reporting the outcome on `shutdown`.
/// 3. Serve HTTP/1 connections until `shutdown` is triggered.
/// 4. Drain the open connections for at most `cfg.drain_timeout`.
#[cfg_attr(not(feature = "affinity"), allow(unused_variables))]
fn worker_main(
    worker_id: usize,
    addr: SocketAddr,
    router: &'static Router,
    cfg: PerThreadConfig,
    shutdown: PerThreadShutdown,
) {
    // Accept-error backoff bounds (same values as the compio worker), so a
    // persistent accept failure such as descriptor exhaustion cannot spin the thread.
    const ACCEPT_BACKOFF_INITIAL: Duration = Duration::from_millis(5);
    const ACCEPT_BACKOFF_MAX: Duration = Duration::from_secs(1);

    #[cfg(feature = "affinity")]
    if cfg.pin_to_core {
        if let Some(ids) = core_affinity::get_core_ids() {
            if let Some(id) = ids.get(worker_id) {
                if !core_affinity::set_for_current(*id) {
                    tracing::warn!(
                        worker_id,
                        "pin_to_core: core_affinity::set_for_current returned false; running without affinity"
                    );
                }
            } else {
                tracing::warn!(
                    worker_id,
                    available_cores = ids.len(),
                    "pin_to_core: worker_id exceeds available cores; running without affinity"
                );
            }
        } else {
            tracing::warn!(
                worker_id,
                "pin_to_core: core_affinity::get_core_ids() returned None; running without affinity"
            );
        }
    }
    let rt = match Builder::new_current_thread().enable_all().build() {
        Ok(rt) => rt,
        Err(e) => {
            tracing::error!("worker {worker_id}: failed to build runtime: {e}");
            shutdown.report_bind_failure(io::Error::other(format!(
                "worker {worker_id}: failed to build runtime: {e}"
            )));
            return;
        }
    };
    let local = LocalSet::new();
    local.block_on(&rt, async move {
        let listener = match bind_reuseport(addr, cfg.backlog) {
            Ok(l) => {
                shutdown.report_bind_success();
                l
            }
            Err(e) => {
                tracing::error!("worker {worker_id}: bind failed: {e}");
                shutdown.report_bind_failure(e);
                return;
            }
        };
        tracing::debug!("tako-pt worker {worker_id} listening on {addr}");
        let shutdown_fut = shutdown.notified();
        tokio::pin!(shutdown_fut);
        let mut connection_handles: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        let mut backoff = ACCEPT_BACKOFF_INITIAL;
        loop {
            tokio::select! {
                accept = listener.accept() => {
                    let (stream, peer) = match accept {
                        Ok(v) => {
                            backoff = ACCEPT_BACKOFF_INITIAL;
                            v
                        }
                        Err(e) => {
                            tracing::warn!("worker {worker_id}: accept failed: {e}; backing off {backoff:?}");
                            // Back off, but react to shutdown immediately.
                            tokio::select! {
                                () = tokio::time::sleep(backoff) => {}
                                () = &mut shutdown_fut => {
                                    tracing::info!("worker {worker_id}: shutdown signalled, draining");
                                    break;
                                }
                            }
                            backoff = std::cmp::min(backoff * 2, ACCEPT_BACKOFF_MAX);
                            continue;
                        }
                    };
                    if let Err(e) = stream.set_nodelay(true) {
                        tracing::debug!("worker {worker_id}: set_nodelay failed for {peer}: {e}");
                    }
                    let io = hyper_util::rt::TokioIo::new(stream);
                    let conn_shutdown = shutdown.clone();
                    connection_handles.spawn_local(async move {
                        let svc = service_fn(move |mut req| async move {
                            req.extensions_mut().insert(peer);
                            req.extensions_mut().insert(ConnInfo::tcp(peer));
                            let resp = router.dispatch(req.map(TakoBody::incoming)).await;
                            Ok::<_, Infallible>(resp)
                        });
                        let mut http = http1::Builder::new();
                        http.keep_alive(true);
                        http.pipeline_flush(true);
                        let conn = http.serve_connection(io, svc).with_upgrades();
                        tokio::pin!(conn);
                        // On shutdown, ask hyper to finish the in-flight request and
                        // then close, instead of letting idle keep-alive connections
                        // hold the drain open until `drain_timeout`.
                        let result = tokio::select! {
                            res = conn.as_mut() => res,
                            () = conn_shutdown.notified() => {
                                conn.as_mut().graceful_shutdown();
                                conn.await
                            }
                        };
                        if let Err(err) = result {
                            if err.is_incomplete_message() {
                                tracing::debug!("worker {worker_id}: client disconnected mid-message: {err}");
                            } else {
                                tracing::error!("worker {worker_id}: connection error: {err}");
                            }
                        }
                    });
                    // Reap finished connection tasks so the set does not grow unbounded.
                    while connection_handles.try_join_next().is_some() {}
                }
                () = &mut shutdown_fut => {
                    tracing::info!("worker {worker_id}: shutdown signalled, draining");
                    break;
                }
            }
        }
        let drained = tokio::time::timeout(cfg.drain_timeout, async {
            while connection_handles.join_next().await.is_some() {}
        })
        .await;
        if drained.is_err() {
            tracing::warn!(
                worker_id,
                drain_timeout = ?cfg.drain_timeout,
                still_inflight = connection_handles.len(),
                "drain timeout exceeded; remaining connections will be aborted"
            );
        }
        // Dropping `connection_handles` here aborts any task still running.
    });
}
#[cfg(feature = "compio")]
#[cfg_attr(docsrs, doc(cfg(feature = "compio")))]
/// Compio-backed variant of `serve_per_thread`: runs until Ctrl-C, then shuts down.
///
/// Each worker runs its own compio runtime. This call waits until at least one
/// worker has bound, blocks on Ctrl-C, and then signals and joins every worker.
///
/// # Errors
///
/// - `InvalidInput` if `addr` is not a socket address or `cfg.workers == 0`.
/// - The OS error if the Ctrl-C runtime cannot be built or a worker thread cannot
///   be spawned. Workers already started are signalled and joined first.
/// - The first bind failure if no worker managed to bind.
/// - The error from installing the Ctrl-C handler.
pub fn serve_per_thread_compio(addr: &str, router: Router, cfg: PerThreadConfig) -> io::Result<()> {
    let socket_addr =
        SocketAddr::from_str(addr).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
    let workers = cfg.workers;
    if workers == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "PerThreadConfig::workers must be at least 1",
        ));
    }
    // Build the signal-waiting runtime before any worker exists, so a failure here
    // cannot leave already-running worker threads behind.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| io::Error::other(format!("ctrl-c runtime: {e}")))?;
    let router: &'static Router = Box::leak(Box::new(router));
    let shutdown = PerThreadShutdown::new();
    let mut handles = Vec::with_capacity(workers);
    for worker_id in 0..workers {
        let worker_cfg = cfg.clone();
        let worker_shutdown = shutdown.clone();
        let spawned = std::thread::Builder::new()
            .name(format!("tako-pt-compio-{worker_id}"))
            .spawn(move || {
                worker_main_compio(worker_id, socket_addr, router, worker_cfg, worker_shutdown)
            });
        match spawned {
            Ok(handle) => handles.push(handle),
            Err(e) => {
                shutdown.trigger();
                for handle in handles {
                    let _ = handle.join();
                }
                return Err(io::Error::new(
                    e.kind(),
                    format!("spawn tako-pt-compio worker {worker_id}: {e}"),
                ));
            }
        }
    }
    let result: io::Result<()> = rt.block_on(async {
        shutdown.wait_for_bind_outcome(workers).await?;
        // Propagate a failed handler registration instead of treating it as a
        // shutdown request that reports success.
        tokio::signal::ctrl_c().await?;
        Ok(())
    });
    shutdown.trigger();
    for handle in handles {
        if handle.join().is_err() {
            tracing::error!("tako-pt-compio worker thread panicked");
        }
    }
    result
}
#[cfg(feature = "compio")]
/// Tracks one live compio connection.
///
/// It increments the shared `inflight` counter when created, and decrements it on
/// drop while waking the drain loop.
struct PtConnGuard {
    // Shared count of live connections for this worker.
    inflight: std::sync::Arc<std::sync::atomic::AtomicUsize>,
    // Notified on every drop so the drain loop can re-check `inflight`.
    drain_notify: std::sync::Arc<tokio::sync::Notify>,
}
#[cfg(feature = "compio")]
impl PtConnGuard {
    /// Registers one more in-flight connection and returns the guard that
    /// unregisters it when dropped.
    fn new(
        inflight: std::sync::Arc<std::sync::atomic::AtomicUsize>,
        drain_notify: std::sync::Arc<tokio::sync::Notify>,
    ) -> Self {
        let guard = Self {
            inflight,
            drain_notify,
        };
        guard
            .inflight
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        guard
    }
}
#[cfg(feature = "compio")]
impl Drop for PtConnGuard {
    /// Unregisters the connection and wakes any drain loop waiting on the notifier.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;
        let counter = &self.inflight;
        counter.fetch_sub(1, Ordering::SeqCst);
        self.drain_notify.notify_waiters();
    }
}
#[cfg(feature = "compio")]
#[cfg_attr(not(feature = "affinity"), allow(unused_variables))]
/// Body of one compio worker thread.
///
/// The steps are:
/// 1. Optionally pin the thread to a core.
/// 2. Build a compio runtime and bind a reuse-port listener, reporting the
///    outcome on `shutdown`.
/// 3. Serve HTTP/1 connections until cancelled.
/// 4. Wait up to `cfg.drain_timeout` for in-flight connections to finish.
fn worker_main_compio(
    worker_id: usize,
    addr: SocketAddr,
    router: &'static Router,
    cfg: PerThreadConfig,
    shutdown: PerThreadShutdown,
) {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use cyper_core::HyperStream;
    use futures_util::future::Either;
    use futures_util::future::select;
    use tokio::sync::Notify;
    #[cfg(feature = "affinity")]
    if cfg.pin_to_core {
        if let Some(ids) = core_affinity::get_core_ids() {
            if let Some(id) = ids.get(worker_id) {
                if !core_affinity::set_for_current(*id) {
                    tracing::warn!(
                        worker_id,
                        "pin_to_core: core_affinity::set_for_current returned false; running without affinity"
                    );
                }
            } else {
                tracing::warn!(
                    worker_id,
                    available_cores = ids.len(),
                    "pin_to_core: worker_id exceeds available cores; running without affinity"
                );
            }
        } else {
            tracing::warn!(
                worker_id,
                "pin_to_core: core_affinity::get_core_ids() returned None; running without affinity"
            );
        }
    }
    let rt = match compio::runtime::RuntimeBuilder::new().build() {
        Ok(rt) => rt,
        Err(e) => {
            tracing::error!("worker {worker_id}: failed to build compio runtime: {e}");
            shutdown.report_bind_failure(io::Error::other(format!(
                "worker {worker_id}: failed to build compio runtime: {e}"
            )));
            return;
        }
    };
    rt.block_on(async move {
        let listener = match bind_reuseport_compio(addr, cfg.backlog) {
            Ok(l) => {
                shutdown.report_bind_success();
                l
            }
            Err(e) => {
                tracing::error!("worker {worker_id}: bind failed: {e}");
                shutdown.report_bind_failure(e);
                return;
            }
        };
        tracing::debug!("tako-pt-compio worker {worker_id} listening on {addr}");
        let cancel = shutdown.inner.clone();
        let mut backoff = compio_accept_backoff();
        let inflight = Arc::new(AtomicUsize::new(0));
        let drain_notify = Arc::new(Notify::new());
        loop {
            let accept_fut = listener.accept();
            let cancel_fut = cancel.cancelled();
            tokio::pin!(accept_fut, cancel_fut);
            let (stream, peer) = match select(accept_fut, cancel_fut).await {
                Either::Left((Ok(v), _)) => {
                    backoff = compio_accept_backoff();
                    v
                }
                Either::Left((Err(e), cancel_fut)) => {
                    let delay = backoff;
                    tracing::warn!("worker {worker_id}: accept failed: {e}; backing off {delay:?}");
                    // Back off, but stop waiting as soon as shutdown is requested.
                    let sleep = std::pin::pin!(compio::time::sleep(delay));
                    if let Either::Right(_) = select(sleep, cancel_fut).await {
                        tracing::info!("worker {worker_id}: shutdown signalled, draining");
                        break;
                    }
                    backoff = std::cmp::min(backoff * 2, Duration::from_secs(1));
                    continue;
                }
                Either::Right(_) => {
                    tracing::info!("worker {worker_id}: shutdown signalled, draining");
                    break;
                }
            };
            if let Err(e) = stream.set_nodelay(true) {
                tracing::debug!("worker {worker_id}: set_nodelay failed for {peer}: {e}");
            }
            let io = HyperStream::new(stream);
            let guard = PtConnGuard::new(inflight.clone(), drain_notify.clone());
            let conn_cancel = cancel.clone();
            compio::runtime::spawn(async move {
                // Held for the task's lifetime; its drop updates the drain counter.
                let _guard = guard;
                let svc = service_fn(move |mut req| async move {
                    req.extensions_mut().insert(peer);
                    req.extensions_mut().insert(ConnInfo::tcp(peer));
                    let resp = router
                        .dispatch(req.map(tako_rs_core::body::TakoBody::new))
                        .await;
                    Ok::<_, Infallible>(resp)
                });
                let mut http = http1::Builder::new();
                http.keep_alive(true);
                let conn = std::pin::pin!(http.serve_connection(io, svc).with_upgrades());
                let cancelled = std::pin::pin!(conn_cancel.cancelled());
                // On shutdown, ask hyper to finish the in-flight request and then
                // close, instead of letting idle keep-alive connections hold the
                // drain open until `drain_timeout`.
                let result = match select(conn, cancelled).await {
                    Either::Left((res, _)) => res,
                    Either::Right(((), mut conn)) => {
                        conn.as_mut().graceful_shutdown();
                        conn.await
                    }
                };
                if let Err(err) = result {
                    if err.is_incomplete_message() {
                        tracing::debug!("worker {worker_id}: client disconnected mid-message: {err}");
                    } else {
                        tracing::error!("worker {worker_id}: connection error: {err}");
                    }
                }
            })
            .detach();
        }
        let drain_deadline = std::time::Instant::now() + cfg.drain_timeout;
        let mut timed_out = false;
        while inflight.load(Ordering::SeqCst) > 0 {
            let remaining = drain_deadline.saturating_duration_since(std::time::Instant::now());
            if remaining.is_zero() {
                timed_out = true;
                break;
            }
            // Connection tasks run on this same thread, so no guard can drop between
            // the `inflight` load above and creating this `Notified`.
            let wait = std::pin::pin!(drain_notify.notified());
            let sleep = std::pin::pin!(compio::time::sleep(remaining));
            if let Either::Right(_) = select(wait, sleep).await {
                timed_out = true;
                break;
            }
        }
        if timed_out {
            tracing::warn!(
                worker_id,
                drain_timeout = ?cfg.drain_timeout,
                still_inflight = inflight.load(Ordering::SeqCst),
                "drain timeout exceeded; remaining connections will be aborted"
            );
        }
    });
}