use std::convert::Infallible;
use std::future::Future;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tokio::task::JoinSet;
use crate::body::TakoBody;
use crate::router::Router;
use crate::types::BoxError;
/// How long a server waits, after shutdown is requested, for in-flight
/// connection tasks to finish before the remaining ones are aborted (30 s).
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Address of the remote peer of a Unix-domain-socket connection.
///
/// The Unix HTTP server inserts one of these into every request's extensions,
/// so handlers can retrieve it to see who connected.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnixPeerAddr {
    /// Filesystem path the peer's socket is bound to, or `None` when the peer
    /// address has no pathname (for example an unbound client socket).
    pub path: Option<std::path::PathBuf>,
}
pub async fn serve_unix<F>(path: impl AsRef<Path>, handler: F) -> std::io::Result<()>
where
F: Fn(
tokio::net::UnixStream,
tokio::net::unix::SocketAddr,
) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
let path = path.as_ref();
cleanup_stale_socket(path)?;
let listener = tokio::net::UnixListener::bind(path)?;
tracing::info!("Unix socket server listening on {}", path.display());
let handler = Arc::new(handler);
loop {
let (stream, addr) = listener.accept().await?;
let handler = Arc::clone(&handler);
tokio::spawn(async move {
if let Err(e) = handler(stream, addr).await {
tracing::error!("Unix socket connection error: {e}");
}
});
}
}
pub async fn serve_unix_with_shutdown<F, S>(
path: impl AsRef<Path>,
handler: F,
signal: S,
) -> std::io::Result<()>
where
F: Fn(
tokio::net::UnixStream,
tokio::net::unix::SocketAddr,
) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
S: Future<Output = ()> + Send + 'static,
{
let path = path.as_ref();
cleanup_stale_socket(path)?;
let listener = tokio::net::UnixListener::bind(path)?;
tracing::info!("Unix socket server listening on {}", path.display());
let handler = Arc::new(handler);
let mut join_set = JoinSet::new();
tokio::pin!(signal);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = result?;
let handler = Arc::clone(&handler);
join_set.spawn(async move {
if let Err(e) = handler(stream, addr).await {
tracing::error!("Unix socket connection error: {e}");
}
});
}
() = &mut signal => {
tracing::info!("Unix socket server shutting down, draining {} connections", join_set.len());
break;
}
}
}
let drain_timeout = Duration::from_secs(30);
let _ = tokio::time::timeout(drain_timeout, async {
while join_set.join_next().await.is_some() {}
})
.await;
Ok(())
}
pub async fn serve_unix_http(path: impl AsRef<Path>, router: Router) {
if let Err(e) = run_http(path.as_ref(), router, None::<std::future::Pending<()>>).await {
tracing::error!("Unix HTTP server error: {e}");
}
}
/// Serve HTTP/1 over a Unix domain socket at `path`, dispatching every request
/// to `router`, until `signal` resolves and in-flight connections are drained.
///
/// Failures (socket setup, bind, accept) are logged rather than returned.
pub async fn serve_unix_http_with_shutdown(
    path: impl AsRef<Path>,
    router: Router,
    signal: impl Future<Output = ()>,
) {
    let socket_path: &Path = path.as_ref();
    let outcome = run_http(socket_path, router, Some(signal)).await;
    if let Err(e) = outcome {
        tracing::error!("Unix HTTP server error: {e}");
    }
}
async fn run_http(
path: &Path,
router: Router,
signal: Option<impl Future<Output = ()>>,
) -> Result<(), BoxError> {
cleanup_stale_socket(path)?;
let listener = tokio::net::UnixListener::bind(path)?;
let router = Arc::new(router);
#[cfg(feature = "plugins")]
router.setup_plugins_once();
tracing::debug!("Tako Unix HTTP listening on {}", path.display());
let mut join_set = JoinSet::new();
let signal = signal.map(|s| Box::pin(s));
let signal_fused = async {
if let Some(s) = signal {
s.await;
} else {
std::future::pending::<()>().await;
}
};
tokio::pin!(signal_fused);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = result?;
let io = hyper_util::rt::TokioIo::new(stream);
let router = router.clone();
let peer_addr = UnixPeerAddr {
path: addr.as_pathname().map(|p| p.to_path_buf()),
};
join_set.spawn(async move {
let svc = service_fn(move |mut req| {
let router = router.clone();
let peer_addr = peer_addr.clone();
async move {
req.extensions_mut().insert(peer_addr);
let response = router.dispatch(req.map(TakoBody::incoming)).await;
Ok::<_, Infallible>(response)
}
});
let mut http = http1::Builder::new();
http.keep_alive(true);
let conn = http.serve_connection(io, svc).with_upgrades();
if let Err(err) = conn.await {
tracing::error!("Error serving Unix HTTP connection: {err}");
}
});
}
() = &mut signal_fused => {
tracing::info!("Unix HTTP server shutting down...");
break;
}
}
}
let drain = tokio::time::timeout(DEFAULT_DRAIN_TIMEOUT, async {
while join_set.join_next().await.is_some() {}
});
if drain.await.is_err() {
tracing::warn!(
"Drain timeout exceeded, aborting {} remaining connections",
join_set.len()
);
join_set.abort_all();
}
let _ = std::fs::remove_file(path);
tracing::info!("Unix HTTP server shut down gracefully");
Ok(())
}
/// Remove a leftover Unix socket file at `path` so a new listener can bind there.
///
/// Only an entry that is itself a socket and refuses connections (no process
/// is listening on it) is unlinked. Anything else at `path` is left untouched.
///
/// # Errors
/// - `AddrInUse` if a live listener accepts a connection on `path`.
/// - `AlreadyExists` if `path` exists but is not a socket (regular file,
///   directory, symlink, ...). It is never deleted.
/// - Any other I/O error from inspecting, probing or removing `path` (e.g.
///   `PermissionDenied`). In that case the entry is not removed, because its
///   staleness cannot be established.
fn cleanup_stale_socket(path: &Path) -> std::io::Result<()> {
    use std::io::{Error, ErrorKind};
    use std::os::unix::fs::FileTypeExt;

    // Inspect the entry itself (not a symlink target): it is the entry we
    // would delete.
    let metadata = match std::fs::symlink_metadata(path) {
        Ok(metadata) => metadata,
        Err(e) if e.kind() == ErrorKind::NotFound => return Ok(()),
        Err(e) => return Err(e),
    };
    if !metadata.file_type().is_socket() {
        return Err(Error::new(
            ErrorKind::AlreadyExists,
            format!(
                "{} exists and is not a Unix socket; refusing to remove it",
                path.display()
            ),
        ));
    }

    match std::os::unix::net::UnixStream::connect(path) {
        Ok(_) => Err(Error::new(
            ErrorKind::AddrInUse,
            format!("Unix socket {} is already in use", path.display()),
        )),
        Err(e) => match e.kind() {
            // Nobody is listening: the socket file is stale.
            ErrorKind::ConnectionRefused => std::fs::remove_file(path),
            // Removed concurrently; nothing left to clean up.
            ErrorKind::NotFound => Ok(()),
            _ => Err(e),
        },
    }
}