#![cfg(all(target_os = "linux", feature = "vsock"))]
#![cfg_attr(docsrs, doc(cfg(all(target_os = "linux", feature = "vsock"))))]
use std::convert::Infallible;
use std::future::Future;
use std::sync::Arc;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::conn_info::PeerAddr;
use tako_rs_core::conn_info::Transport;
use tako_rs_core::router::Router;
use tako_rs_core::types::BoxError;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tokio_vsock::VsockAddr;
use tokio_vsock::VsockListener;
use crate::ServerConfig;
/// Serves `router` over HTTP/1 on the vsock address `(cid, port)` using the
/// default [`ServerConfig`] and no shutdown signal.
///
/// Because no shutdown signal is supplied, the accept loop is never asked to
/// stop; the future only completes if the server fails (for example, when the
/// listener cannot be bound). Failures are logged via `tracing::error!`
/// rather than returned to the caller.
pub async fn serve_vsock_http(cid: u32, port: u32, router: Router) {
    // `Pending<()>` is only here to pin down the generic signal type; the
    // value itself is `None`, so nothing is ever awaited.
    let no_signal = None::<std::future::Pending<()>>;
    let outcome = run(cid, port, router, no_signal, ServerConfig::default()).await;
    if let Err(e) = outcome {
        tracing::error!("vsock HTTP server error: {e}");
    }
}
/// Serves `router` over HTTP/1 on the vsock address `(cid, port)` with the
/// default [`ServerConfig`], stopping once `signal` resolves.
///
/// When `signal` completes the server stops accepting connections and waits
/// for in-flight connection tasks, bounded by the configured drain timeout.
/// Any server error is logged via `tracing::error!` and not returned.
pub async fn serve_vsock_http_with_shutdown(
    cid: u32,
    port: u32,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
) {
    match run(cid, port, router, Some(signal), ServerConfig::default()).await {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("vsock HTTP server error: {e}");
        }
    }
}
/// Serves `router` over HTTP/1 on the vsock address `(cid, port)` using the
/// caller-supplied `config` and no shutdown signal.
///
/// With no shutdown signal the accept loop is never asked to stop, so the
/// future only completes if the server fails. Failures are logged via
/// `tracing::error!` and not returned.
pub async fn serve_vsock_http_with_config(
    cid: u32,
    port: u32,
    router: Router,
    config: ServerConfig,
) {
    // The `Pending<()>` turbofish only fixes the generic signal type.
    let Err(e) = run(cid, port, router, None::<std::future::Pending<()>>, config).await else {
        return;
    };
    tracing::error!("vsock HTTP server error: {e}");
}
/// Serves `router` over HTTP/1 on the vsock address `(cid, port)` using the
/// caller-supplied `config`, stopping once `signal` resolves.
///
/// This is the most general public entry point in this file: both the
/// shutdown trigger and the server tuning are provided by the caller. Any
/// server error is logged via `tracing::error!` and not returned.
pub async fn serve_vsock_http_with_shutdown_and_config(
    cid: u32,
    port: u32,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
    config: ServerConfig,
) {
    let shutdown = Some(signal);
    let served = run(cid, port, router, shutdown, config).await;
    if let Err(e) = served {
        tracing::error!("vsock HTTP server error: {e}");
    }
}
async fn run(
cid: u32,
port: u32,
router: Router,
signal: Option<impl Future<Output = ()> + Send + 'static>,
config: ServerConfig,
) -> Result<(), BoxError> {
#[cfg(feature = "tako-tracing")]
tako_rs_core::tracing::init_tracing();
let listener = VsockListener::bind(VsockAddr::new(cid, port))?;
let router = Arc::new(router);
#[cfg(feature = "plugins")]
router.setup_plugins_once();
tracing::info!("Tako vsock HTTP listening on cid={cid} port={port}");
let mut join_set = JoinSet::new();
let mut accept_backoff = config.accept_backoff;
let max_conn_semaphore = config.max_connections.map(|n| Arc::new(Semaphore::new(n)));
let drain_timeout = config.drain_timeout;
let header_read_timeout = config.header_read_timeout;
let keep_alive = config.keep_alive;
let cancel = tokio_util::sync::CancellationToken::new();
if let Some(s) = signal {
let cancel_for_signal = cancel.clone();
tokio::spawn(async move {
s.await;
cancel_for_signal.cancel();
});
}
loop {
tokio::select! {
result = listener.accept() => {
let (stream, peer) = match result {
Ok(v) => { accept_backoff.reset(); v }
Err(err) => {
tracing::warn!("vsock accept failed: {err}; backing off");
accept_backoff.sleep_and_grow().await;
continue;
}
};
let permit = if let Some(sem) = &max_conn_semaphore {
tokio::select! {
biased;
_ = cancel.cancelled() => break,
permit = sem.clone().acquire_owned() => match permit {
Ok(p) => Some(p),
Err(_) => continue,
},
}
} else {
None
};
let io = hyper_util::rt::TokioIo::new(stream);
let router = router.clone();
join_set.spawn(async move {
let peer_label = format!("vsock:{}:{}", peer.cid(), peer.port());
let svc = service_fn(move |mut req| {
let router = router.clone();
let peer_label = peer_label.clone();
async move {
let conn_info = ConnInfo {
peer: PeerAddr::Other(peer_label.clone()),
local: None,
transport: Transport::Http1,
tls: None,
};
req.extensions_mut().insert(conn_info);
let response = router.dispatch(req.map(TakoBody::incoming)).await;
Ok::<_, Infallible>(response)
}
});
let mut http = http1::Builder::new();
http.keep_alive(keep_alive);
http.timer(hyper_util::rt::TokioTimer::new());
if let Some(t) = header_read_timeout {
http.header_read_timeout(t);
}
if let Err(err) = http.serve_connection(io, svc).with_upgrades().await {
if err.is_incomplete_message() {
tracing::debug!("vsock client disconnected mid-message: {err}");
} else {
tracing::error!("vsock HTTP error: {err}");
}
}
drop(permit);
});
}
() = cancel.cancelled() => {
tracing::info!("vsock HTTP server shutting down...");
break;
}
}
}
let drain = tokio::time::timeout(drain_timeout, async {
while join_set.join_next().await.is_some() {}
});
if drain.await.is_err() {
tracing::warn!(
"Drain timeout ({:?}) exceeded, aborting {} remaining vsock connections",
drain_timeout,
join_set.len()
);
join_set.abort_all();
}
tracing::info!("vsock HTTP server shut down gracefully");
Ok(())
}