use std::convert::Infallible;
use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::router::Router;
use tako_rs_core::types::BoxError;
use tokio::task::JoinSet;
use crate::ServerConfig;
/// Reports whether `path` denotes a Linux abstract-namespace socket name.
///
/// By this module's convention such names are written with a leading `@`.
/// Paths that are not valid UTF-8 are never considered abstract.
#[inline]
fn is_abstract_path(path: &Path) -> bool {
    match path.to_str() {
        Some(text) => text.starts_with('@'),
        None => false,
    }
}
/// Binds a tokio Unix-domain listener for `path`.
///
/// * `@`-prefixed paths are abstract-namespace names. The `@` is stripped and
///   the remaining bytes become the abstract name. This is Linux-only; on other
///   targets such paths yield an `ErrorKind::Unsupported` error. No filesystem
///   entry is inspected or removed for them.
/// * Any other path is a filesystem socket. A leftover socket file is first
///   handled by `cleanup_stale_socket`, then the listener is bound.
///
/// # Errors
/// Propagates failures from stale-socket cleanup, abstract-address
/// construction, `set_nonblocking`, and the bind itself.
async fn bind_unix_listener(path: &Path) -> io::Result<tokio::net::UnixListener> {
    /// Binds an abstract-namespace listener; `path` must satisfy `is_abstract_path`.
    #[cfg(target_os = "linux")]
    fn bind_abstract(path: &Path) -> io::Result<tokio::net::UnixListener> {
        use std::os::linux::net::SocketAddrExt;

        // `is_abstract_path` already proved the path is UTF-8 and starts with '@'.
        let name = path
            .to_str()
            .and_then(|text| text.strip_prefix('@'))
            .expect("abstract socket path is UTF-8 and starts with '@'");
        let sock_addr = std::os::unix::net::SocketAddr::from_abstract_name(name.as_bytes())?;
        let std_listener = std::os::unix::net::UnixListener::bind_addr(&sock_addr)?;
        // tokio expects an already non-blocking std listener when adopting it.
        std_listener.set_nonblocking(true)?;
        tokio::net::UnixListener::from_std(std_listener)
    }

    /// Abstract socket names do not exist outside Linux.
    #[cfg(not(target_os = "linux"))]
    fn bind_abstract(_path: &Path) -> io::Result<tokio::net::UnixListener> {
        Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "abstract Unix socket paths (`@`-prefixed) are Linux-only",
        ))
    }

    if is_abstract_path(path) {
        return bind_abstract(path);
    }

    // Filesystem-backed socket: clear out a dead socket file before binding.
    cleanup_stale_socket(path).await?;
    tokio::net::UnixListener::bind(path)
}
/// Address of the client on the other end of a Unix-socket HTTP connection.
///
/// The Unix HTTP server inserts one of these into every request's extensions
/// so handlers can inspect who connected.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct UnixPeerAddr {
    /// Filesystem path the peer's socket is bound to. This is `None` when the
    /// peer has no pathname address, e.g. an unnamed client socket or an
    /// abstract-namespace address.
    pub path: Option<std::path::PathBuf>,
}
/// Runs a raw Unix-domain socket server at `path`, calling `handler` once per
/// accepted connection.
///
/// Every connection is handled on its own detached tokio task. Errors returned
/// by `handler` are logged and do not stop the server.
///
/// # Errors
/// Returns as soon as binding or an `accept` call fails; the function never
/// returns `Ok`. The socket file is not removed when it returns.
pub async fn serve_unix<F>(path: impl AsRef<Path>, handler: F) -> std::io::Result<()>
where
    F: Fn(
            tokio::net::UnixStream,
            tokio::net::unix::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
        + Send
        + Sync
        + 'static,
{
    let socket_path: &Path = path.as_ref();
    let listener = bind_unix_listener(socket_path).await?;
    tracing::info!("Unix socket server listening on {}", socket_path.display());

    let shared = Arc::new(handler);
    loop {
        let (conn, peer) = listener.accept().await?;
        let for_task = Arc::clone(&shared);
        tokio::spawn(async move {
            let outcome = for_task(conn, peer).await;
            if let Err(e) = outcome {
                tracing::error!("Unix socket connection error: {e}");
            }
        });
    }
}
/// Like `serve_unix`, but shuts down gracefully once `signal` resolves.
///
/// Equivalent to `serve_unix_with_shutdown_and_drain` with a 30-second grace
/// period for in-flight connections.
pub async fn serve_unix_with_shutdown<F, S>(
    path: impl AsRef<Path>,
    handler: F,
    signal: S,
) -> std::io::Result<()>
where
    F: Fn(
            tokio::net::UnixStream,
            tokio::net::unix::SocketAddr,
        ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
        + Send
        + Sync
        + 'static,
    S: Future<Output = ()> + Send + 'static,
{
    // Grace period granted to in-flight connections after `signal` fires.
    const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
    serve_unix_with_shutdown_and_drain(path, handler, signal, DEFAULT_DRAIN_TIMEOUT).await
}
pub async fn serve_unix_with_shutdown_and_drain<F, S>(
path: impl AsRef<Path>,
handler: F,
signal: S,
drain_timeout: Duration,
) -> std::io::Result<()>
where
F: Fn(
tokio::net::UnixStream,
tokio::net::unix::SocketAddr,
) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
S: Future<Output = ()> + Send + 'static,
{
let path = path.as_ref();
let listener = bind_unix_listener(path).await?;
tracing::info!("Unix socket server listening on {}", path.display());
let handler = Arc::new(handler);
let mut join_set = JoinSet::new();
tokio::pin!(signal);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = result?;
let handler = Arc::clone(&handler);
join_set.spawn(async move {
if let Err(e) = handler(stream, addr).await {
tracing::error!("Unix socket connection error: {e}");
}
});
}
() = &mut signal => {
tracing::info!("Unix socket server shutting down, draining {} connections", join_set.len());
break;
}
}
}
let _ = tokio::time::timeout(drain_timeout, async {
while join_set.join_next().await.is_some() {}
})
.await;
Ok(())
}
pub async fn serve_unix_http(path: impl AsRef<Path>, router: Router) {
if let Err(e) = run_http(
path.as_ref(),
router,
None::<std::future::Pending<()>>,
ServerConfig::default(),
)
.await
{
tracing::error!("Unix HTTP server error: {e}");
}
}
pub async fn serve_unix_http_with_shutdown(
path: impl AsRef<Path>,
router: Router,
signal: impl Future<Output = ()> + Send + 'static,
) {
if let Err(e) = run_http(path.as_ref(), router, Some(signal), ServerConfig::default()).await {
tracing::error!("Unix HTTP server error: {e}");
}
}
pub async fn serve_unix_http_with_config(
path: impl AsRef<Path>,
router: Router,
config: ServerConfig,
) {
if let Err(e) = run_http(
path.as_ref(),
router,
None::<std::future::Pending<()>>,
config,
)
.await
{
tracing::error!("Unix HTTP server error: {e}");
}
}
/// Serves `router` over HTTP on the Unix socket at `path` using `config`,
/// shutting down gracefully once `signal` resolves.
///
/// Errors are logged rather than returned.
pub async fn serve_unix_http_with_shutdown_and_config(
    path: impl AsRef<Path>,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
    config: ServerConfig,
) {
    let Err(e) = run_http(path.as_ref(), router, Some(signal), config).await else {
        return;
    };
    tracing::error!("Unix HTTP server error: {e}");
}
async fn run_http(
path: &Path,
router: Router,
signal: Option<impl Future<Output = ()> + Send + 'static>,
config: ServerConfig,
) -> Result<(), BoxError> {
let listener = bind_unix_listener(path).await?;
let router = Arc::new(router);
#[cfg(feature = "plugins")]
router.setup_plugins_once();
tracing::debug!("Tako Unix HTTP listening on {}", path.display());
let mut join_set = JoinSet::new();
let mut accept_backoff = config.accept_backoff;
let max_conn_semaphore = config
.max_connections
.map(|n| Arc::new(tokio::sync::Semaphore::new(n)));
let drain_timeout = config.drain_timeout;
let header_read_timeout = config.header_read_timeout;
let keep_alive = config.keep_alive;
let cancel = tokio_util::sync::CancellationToken::new();
if let Some(s) = signal {
let cancel_for_signal = cancel.clone();
tokio::spawn(async move {
s.await;
cancel_for_signal.cancel();
});
}
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = match result {
Ok(v) => { accept_backoff.reset(); v }
Err(err) => {
tracing::warn!("Unix accept failed: {err}; backing off");
accept_backoff.sleep_and_grow().await;
continue;
}
};
let permit = if let Some(sem) = &max_conn_semaphore {
tokio::select! {
biased;
() = cancel.cancelled() => break,
permit = sem.clone().acquire_owned() => match permit {
Ok(p) => Some(p),
Err(_) => continue,
},
}
} else {
None
};
let io = hyper_util::rt::TokioIo::new(stream);
let router = router.clone();
let peer_addr = UnixPeerAddr {
path: addr.as_pathname().map(std::path::Path::to_path_buf),
};
join_set.spawn(async move {
let svc = service_fn(move |mut req| {
let router = router.clone();
let peer_addr = peer_addr.clone();
async move {
let conn_info = ConnInfo::unix(peer_addr.path.clone());
req.extensions_mut().insert(peer_addr);
req.extensions_mut().insert(conn_info);
let response = router.dispatch(req.map(TakoBody::incoming)).await;
Ok::<_, Infallible>(response)
}
});
let mut http = http1::Builder::new();
http.keep_alive(keep_alive);
http.timer(hyper_util::rt::TokioTimer::new());
if let Some(t) = header_read_timeout {
http.header_read_timeout(t);
}
let conn = http.serve_connection(io, svc).with_upgrades();
if let Err(err) = conn.await {
if err.is_incomplete_message() {
tracing::debug!("client disconnected mid-message on Unix socket: {err}");
} else {
tracing::error!("Error serving Unix HTTP connection: {err}");
}
}
drop(permit);
});
}
() = cancel.cancelled() => {
tracing::info!("Unix HTTP server shutting down...");
break;
}
}
}
let drain = tokio::time::timeout(drain_timeout, async {
while join_set.join_next().await.is_some() {}
});
if drain.await.is_err() {
tracing::warn!(
"Drain timeout exceeded, aborting {} remaining connections",
join_set.len()
);
join_set.abort_all();
}
if !is_abstract_path(path) {
let _ = std::fs::remove_file(path);
}
tracing::info!("Unix HTTP server shut down gracefully");
Ok(())
}
async fn cleanup_stale_socket(path: &Path) -> std::io::Result<()> {
use std::os::unix::fs::FileTypeExt;
use std::time::Duration;
let meta = match std::fs::symlink_metadata(path) {
Ok(m) => m,
Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(()),
Err(e) => return Err(e),
};
if !meta.file_type().is_socket() {
return Err(std::io::Error::new(
std::io::ErrorKind::AlreadyExists,
format!(
"{} exists but is not a unix socket; refusing to remove",
path.display()
),
));
}
let connect = tokio::net::UnixStream::connect(path);
match tokio::time::timeout(Duration::from_millis(50), connect).await {
Ok(Ok(_)) => Err(std::io::Error::new(
std::io::ErrorKind::AddrInUse,
format!("Unix socket {} is already in use", path.display()),
)),
Ok(Err(_)) | Err(_) => match std::fs::remove_file(path) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
},
}
}