use std::convert::Infallible;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use compio::net::TcpListener;
use cyper_core::HyperStream;
use futures_util::future::Either;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::router::Router;
#[cfg(feature = "signals")]
use tako_rs_core::signals::transport as signal_tx;
use tako_rs_core::types::BoxError;
use tokio::sync::Notify;
use crate::ServerConfig;
/// RAII token representing one live connection for the server's graceful-drain logic.
///
/// Creating a guard increments the shared `inflight` counter. Dropping it decrements the
/// counter and wakes every task currently waiting on `drain_notify`, so a drain loop can
/// re-check whether all connections have finished.
pub(crate) struct ConnectionGuard {
    /// Shared count of guards that are still alive.
    inflight: Arc<AtomicUsize>,
    /// Signalled (via `notify_waiters`) after every decrement of `inflight`.
    drain_notify: Arc<Notify>,
}

impl ConnectionGuard {
    /// Registers one more live connection and returns the guard that will release it.
    pub(crate) fn new(inflight: Arc<AtomicUsize>, drain_notify: Arc<Notify>) -> Self {
        // Count the connection first; the returned guard undoes this exactly once on drop.
        inflight.fetch_add(1, Ordering::SeqCst);
        Self { inflight, drain_notify }
    }
}

impl Drop for ConnectionGuard {
    /// Releases the connection slot and wakes anyone waiting for the count to change.
    fn drop(&mut self) {
        let Self { inflight, drain_notify } = self;
        inflight.fetch_sub(1, Ordering::SeqCst);
        drain_notify.notify_waiters();
    }
}
pub async fn serve(listener: TcpListener, router: Router) {
if let Err(e) = run(
listener,
router,
None::<std::future::Pending<()>>,
ServerConfig::default(),
)
.await
{
tracing::error!("Server error: {e}");
}
}
pub async fn serve_with_shutdown(
listener: TcpListener,
router: Router,
signal: impl Future<Output = ()>,
) {
if let Err(e) = run(listener, router, Some(signal), ServerConfig::default()).await {
tracing::error!("Server error: {e}");
}
}
/// Serves `router` on `listener` with a caller-supplied [`ServerConfig`] and no shutdown
/// signal.
///
/// Without a signal the accept loop never ends, so this returns only if server startup
/// fails; that error is logged with `tracing::error!` rather than returned.
pub async fn serve_with_config(listener: TcpListener, router: Router, config: ServerConfig) {
    // A signal that can never fire: the server runs until the process stops it.
    let no_shutdown = None::<std::future::Pending<()>>;
    match run(listener, router, no_shutdown, config).await {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("Server error: {e}");
        }
    }
}
/// Serves `router` on `listener` with `config`, stopping once `signal` resolves.
///
/// After `signal` completes, no new connections are accepted and in-flight ones are given
/// up to `config.drain_timeout` to finish. Server errors are logged, not returned.
pub async fn serve_with_shutdown_and_config(
    listener: TcpListener,
    router: Router,
    signal: impl Future<Output = ()>,
    config: ServerConfig,
) {
    let outcome = run(listener, router, Some(signal), config).await;
    if let Err(e) = outcome {
        tracing::error!("Server error: {e}");
    }
}
async fn run(
listener: TcpListener,
router: Router,
signal: Option<impl Future<Output = ()>>,
config: ServerConfig,
) -> Result<(), BoxError> {
#[cfg(feature = "tako-tracing")]
tako_rs_core::tracing::init_tracing();
let router = Arc::new(router);
#[cfg(feature = "plugins")]
router.setup_plugins_once();
let addr_str = listener.local_addr()?.to_string();
#[cfg(feature = "signals")]
signal_tx::emit_server_started(&addr_str, "tcp", false).await;
tracing::debug!("Tako listening on {}", addr_str);
let inflight = Arc::new(AtomicUsize::new(0));
let drain_notify = Arc::new(Notify::new());
let drain_timeout = config.drain_timeout;
let keep_alive = config.keep_alive;
let max_conn_semaphore = config
.max_connections
.map(|n| Arc::new(tokio::sync::Semaphore::new(n)));
let mut accept_backoff = config.accept_backoff;
let signal = signal.map(|s| Box::pin(s));
let mut signal_fused = std::pin::pin!(async {
if let Some(s) = signal {
s.await;
} else {
std::future::pending::<()>().await;
}
});
loop {
let accept = std::pin::pin!(listener.accept());
match futures_util::future::select(accept, signal_fused.as_mut()).await {
Either::Left((result, _)) => {
let (stream, addr) = match result {
Ok(v) => {
accept_backoff.reset();
v
}
Err(err) => {
tracing::warn!("compio accept failed: {err}; backing off");
let d = accept_backoff.current_and_grow();
let sleep = std::pin::pin!(compio::time::sleep(d));
match futures_util::future::select(sleep, signal_fused.as_mut()).await {
Either::Left(((), _)) => continue,
Either::Right(_) => break,
}
}
};
let permit = if let Some(sem) = max_conn_semaphore.as_ref() {
let acquire = std::pin::pin!(sem.clone().acquire_owned());
match futures_util::future::select(acquire, signal_fused.as_mut()).await {
Either::Left((Ok(p), _)) => Some(p),
Either::Left((Err(_), _)) => continue,
Either::Right(_) => break,
}
} else {
None
};
let io = HyperStream::new(stream);
let router = router.clone();
let guard = ConnectionGuard::new(inflight.clone(), drain_notify.clone());
compio::runtime::spawn(async move {
let _permit = permit;
let _guard = guard;
#[cfg(feature = "signals")]
signal_tx::emit_connection_opened(&addr.to_string(), false, None).await;
let svc = service_fn(move |mut req| {
let router = router.clone();
async move {
req.extensions_mut().insert(addr);
req.extensions_mut().insert(ConnInfo::tcp(addr));
let response = router.dispatch(req.map(TakoBody::new)).await;
Ok::<_, Infallible>(response)
}
});
let mut http = http1::Builder::new();
http.keep_alive(keep_alive);
let conn = http.serve_connection(io, svc).with_upgrades();
if let Err(err) = conn.await {
if err.is_incomplete_message() {
tracing::debug!("client disconnected mid-message: {err}");
} else {
tracing::error!("Error serving connection: {err}");
}
}
#[cfg(feature = "signals")]
signal_tx::emit_connection_closed(&addr.to_string(), false, None).await;
})
.detach();
}
Either::Right(_) => {
tracing::info!("Shutdown signal received, draining connections...");
break;
}
}
}
let drain_deadline = std::time::Instant::now() + drain_timeout;
while inflight.load(Ordering::SeqCst) > 0 {
let now = std::time::Instant::now();
if now >= drain_deadline {
tracing::warn!(
"Drain timeout ({:?}) exceeded, {} connections still active",
drain_timeout,
inflight.load(Ordering::SeqCst)
);
break;
}
let remaining = drain_deadline - now;
let drain_wait = drain_notify.notified();
let sleep = compio::time::sleep(remaining);
let drain_wait = std::pin::pin!(drain_wait);
let sleep = std::pin::pin!(sleep);
if let Either::Right(_) = futures_util::future::select(drain_wait, sleep).await {
tracing::warn!(
"Drain timeout ({:?}) exceeded, {} connections still active",
drain_timeout,
inflight.load(Ordering::SeqCst)
);
break;
}
}
tracing::info!("Server shut down gracefully");
Ok(())
}