use std::convert::Infallible;
use std::future::Future;
use std::sync::Arc;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::router::Router;
#[cfg(feature = "signals")]
use tako_rs_core::signals::transport as signal_tx;
use tako_rs_core::types::BoxError;
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::ServerConfig;
/// Serves `router` on `listener` with `ServerConfig::default()` and no shutdown signal.
///
/// An error returned by the server loop is not propagated; it is logged through
/// `tracing::error!` and the function then returns normally.
pub async fn serve(listener: TcpListener, router: Router) {
    // `Pending<()>` only pins down the generic signal type; `None` means no signal.
    let no_signal = None::<std::future::Pending<()>>;
    let outcome = run(listener, router, no_signal, ServerConfig::default()).await;
    if let Err(e) = outcome {
        tracing::error!("Server error: {e}");
    }
}
/// Serves `router` on `listener` with `ServerConfig::default()` until `signal` completes.
///
/// `signal` is forwarded to the server loop as its shutdown trigger. An error
/// returned by the server loop is logged through `tracing::error!` rather than
/// being returned to the caller.
pub async fn serve_with_shutdown(
    listener: TcpListener,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
) {
    match run(listener, router, Some(signal), ServerConfig::default()).await {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("Server error: {e}");
        }
    }
}
/// Serves `router` on `listener` using the caller-supplied `config` and no shutdown signal.
///
/// An error returned by the server loop is logged through `tracing::error!`
/// rather than being returned to the caller.
pub async fn serve_with_config(listener: TcpListener, router: Router, config: ServerConfig) {
    let failure = run(listener, router, None::<std::future::Pending<()>>, config)
        .await
        .err();
    if let Some(e) = failure {
        tracing::error!("Server error: {e}");
    }
}
/// Serves `router` on `listener` using `config`, until `signal` completes.
///
/// `signal` is forwarded to the server loop as its shutdown trigger. An error
/// returned by the server loop is logged through `tracing::error!` rather than
/// being returned to the caller.
pub async fn serve_with_shutdown_and_config(
    listener: TcpListener,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
    config: ServerConfig,
) {
    let shutdown = Some(signal);
    let outcome = run(listener, router, shutdown, config).await;
    if let Err(e) = outcome {
        tracing::error!("Server error: {e}");
    }
}
/// Accept loop shared by all `serve*` entry points.
///
/// Accepts TCP connections from `listener`, serves each one over HTTP/1 on its
/// own task by dispatching requests to `router`, and keeps going until the
/// optional `signal` future completes. After that it stops accepting, asks
/// every live connection to shut down gracefully, and waits up to
/// `config.drain_timeout` for the connection tasks to finish before aborting
/// whatever is left.
///
/// # Parameters
/// - `listener`: bound TCP listener to accept from.
/// - `router`: request router; it is leaked to obtain a `'static` reference
///   that every connection task can share.
/// - `signal`: optional shutdown trigger; `None` means the loop never stops on
///   its own.
/// - `config`: accept backoff, connection limit, keep-alive, header read
///   timeout and drain timeout settings.
///
/// # Errors
/// Returns an error only if the listener's local address cannot be read.
/// Accept failures and per-connection errors are logged, not returned.
async fn run(
    listener: TcpListener,
    router: Router,
    signal: Option<impl Future<Output = ()> + Send + 'static>,
    config: ServerConfig,
) -> Result<(), BoxError> {
    #[cfg(feature = "tako-tracing")]
    tako_rs_core::tracing::init_tracing();

    // Intentionally leaked: connection tasks need a `'static` router reference.
    let router: &'static Router = Box::leak(Box::new(router));

    #[cfg(feature = "plugins")]
    router.setup_plugins_once();

    let addr_str = listener.local_addr()?.to_string();

    #[cfg(feature = "signals")]
    signal_tx::emit_server_started(&addr_str, "tcp", false).await;

    tracing::debug!("Tako listening on {}", addr_str);

    let mut join_set = JoinSet::new();
    let mut accept_backoff = config.accept_backoff;
    let max_conn_semaphore = config.max_connections.map(|n| Arc::new(Semaphore::new(n)));
    let keep_alive = config.keep_alive;
    let header_read_timeout = config.header_read_timeout;
    let keep_alive_timeout = config.keep_alive_timeout;
    let drain_timeout = config.drain_timeout;

    if let Some(t) = keep_alive_timeout {
        tracing::warn!(
            "ServerConfig::keep_alive_timeout ({:?}) is not currently plumbed to hyper's http1 builder (upstream gap); the value will be ignored.",
            t
        );
    }

    // Single cancellation source observed by the accept loop and by every
    // connection task.
    let cancel = CancellationToken::new();
    if let Some(s) = signal {
        let cancel_for_signal = cancel.clone();
        tokio::spawn(async move {
            s.await;
            cancel_for_signal.cancel();
        });
    }

    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, addr) = match result {
                    Ok(v) => {
                        accept_backoff.reset();
                        v
                    }
                    Err(err) => {
                        tracing::warn!("accept failed: {err}; backing off");
                        // Race the backoff against cancellation so a shutdown
                        // request is not delayed by the sleep.
                        tokio::select! {
                            biased;
                            () = cancel.cancelled() => break,
                            _ = accept_backoff.sleep_and_grow() => {}
                        }
                        continue;
                    }
                };

                // With a connection limit configured, wait for a free slot
                // before serving, but give up immediately on shutdown.
                let permit = if let Some(sem) = &max_conn_semaphore {
                    tokio::select! {
                        biased;
                        () = cancel.cancelled() => break,
                        permit = sem.clone().acquire_owned() => match permit {
                            Ok(p) => Some(p),
                            Err(_) => continue,
                        },
                    }
                } else {
                    None
                };

                let _ = stream.set_nodelay(true);
                let io = hyper_util::rt::TokioIo::new(stream);
                let conn_cancel = cancel.clone();

                join_set.spawn(async move {
                    #[cfg(feature = "signals")]
                    signal_tx::emit_connection_opened(&addr.to_string(), false, None).await;

                    let svc = service_fn(move |mut req| async move {
                        req.extensions_mut().insert(addr);
                        req.extensions_mut().insert(ConnInfo::tcp(addr));
                        let response = router.dispatch(req.map(TakoBody::incoming)).await;
                        Ok::<_, Infallible>(response)
                    });

                    let mut http = http1::Builder::new();
                    http.keep_alive(keep_alive);
                    http.pipeline_flush(true);
                    http.timer(hyper_util::rt::TokioTimer::new());
                    if let Some(t) = header_read_timeout {
                        http.header_read_timeout(t);
                    }

                    let conn = http.serve_connection(io, svc).with_upgrades();
                    tokio::pin!(conn);

                    // On shutdown, ask hyper to finish the in-flight exchange
                    // and close, then keep polling until it has done so.
                    // Without this an idle keep-alive connection would hold
                    // the drain until `drain_timeout` expires.
                    let served = tokio::select! {
                        res = conn.as_mut() => res,
                        () = conn_cancel.cancelled() => {
                            conn.as_mut().graceful_shutdown();
                            conn.as_mut().await
                        }
                    };

                    if let Err(err) = served {
                        if err.is_incomplete_message() {
                            tracing::debug!("client disconnected mid-message: {err}");
                        } else {
                            tracing::error!("Error serving connection: {err}");
                        }
                    }

                    #[cfg(feature = "signals")]
                    signal_tx::emit_connection_closed(&addr.to_string(), false, None).await;

                    // The permit is held for the connection's whole lifetime.
                    drop(permit);
                });
            }
            // Reap finished connection tasks as they complete; otherwise the
            // JoinSet keeps one entry per closed connection for as long as
            // the server runs. When the set is empty `join_next` yields
            // `None`, the pattern fails, and this arm is skipped.
            Some(joined) = join_set.join_next() => {
                if let Err(err) = joined {
                    if err.is_panic() {
                        tracing::error!("connection task panicked: {err}");
                    }
                }
            }
            () = cancel.cancelled() => break,
        }
    }

    // The loop only exits through cancellation, so every exit path logs here.
    tracing::info!("Shutdown signal received, draining connections...");

    let drain = tokio::time::timeout(drain_timeout, async {
        while join_set.join_next().await.is_some() {}
    });
    if drain.await.is_err() {
        tracing::warn!(
            "Drain timeout ({:?}) exceeded, aborting {} remaining connections",
            drain_timeout,
            join_set.len()
        );
        join_set.abort_all();
    }

    tracing::info!("Server shut down gracefully");
    Ok(())
}