use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Notify;
use tracing::{error, info, warn};
/// How long the daemon keeps running with zero connected clients before it
/// shuts itself down (300 seconds, i.e. five minutes).
pub const GRACE_PERIOD: std::time::Duration = std::time::Duration::from_secs(300);
/// Runs the daemon until a shutdown condition is reached.
///
/// Any file already present at `sock_path` is removed, a listener is bound
/// there, and an accept task is spawned that tracks the number of connected
/// clients. The function then waits for [`shutdown_monitor`] to report a
/// reason to stop, aborts the accept task and waits for it to finish.
///
/// # Returns
///
/// * `1` if the listener could not be bound; in that case
///   `cleanup(pid_path, sock_path)` has already been called.
/// * `0` after an orderly shutdown. This path does not call `cleanup`
///   itself.
pub async fn run_daemon(sock_path: &Path, pid_path: &Path) -> i32 {
    remove_stale_socket(sock_path);

    let listener = match bind_listener(sock_path) {
        Some(bound) => bound,
        None => {
            cleanup(pid_path, sock_path);
            return 1;
        }
    };
    info!(path = %sock_path.display(), "listening on Unix domain socket");

    // State shared between the accept task and the shutdown monitor.
    let client_count = Arc::new(AtomicUsize::new(0));
    let client_changed = Arc::new(Notify::new());

    let accept_handle = tokio::spawn(accept_loop(
        listener,
        Arc::clone(&client_count),
        Arc::clone(&client_changed),
    ));

    let reason = shutdown_monitor(client_count, client_changed).await;
    info!(%reason, "initiating shutdown");

    // Stop accepting new connections and wait for the task to wind down;
    // the resulting cancellation error is expected and ignored.
    accept_handle.abort();
    let _ = accept_handle.await;

    0
}
#[cfg(unix)]
async fn accept_loop(
listener: tokio::net::UnixListener,
client_count: Arc<AtomicUsize>,
client_changed: Arc<Notify>,
) {
loop {
let (stream, _addr) = match listener.accept().await {
Ok(pair) => pair,
Err(e) => {
warn!("Failed to accept connection: {e}");
continue;
}
};
let count = client_count.clone();
let notify = client_changed.clone();
let prev = count.fetch_add(1, Ordering::SeqCst);
info!(clients = prev + 1, "client connected");
notify.notify_waiters();
let _handle: tokio::task::JoinHandle<()> = tokio::spawn(async move {
handle_client(stream).await;
let prev = count.fetch_sub(1, Ordering::SeqCst);
info!(clients = prev - 1, "client disconnected");
notify.notify_waiters();
});
}
}
/// Holds a client connection open until the peer closes it or the stream
/// reports an error.
///
/// Incoming bytes are read into a scratch buffer and discarded; the
/// function only tracks how long the connection stays alive.
///
/// `readable()` may report readiness spuriously, in which case `try_read`
/// fails with `WouldBlock`. That is not a disconnect, so the loop simply
/// waits for readiness again instead of treating the client as gone.
#[cfg(unix)]
async fn handle_client(stream: tokio::net::UnixStream) {
    let mut buf = [0u8; 1024];
    loop {
        // If readiness itself cannot be awaited the stream is unusable;
        // stopping here also avoids spinning on repeated `WouldBlock`.
        if stream.readable().await.is_err() {
            break;
        }
        match stream.try_read(&mut buf) {
            // Zero bytes means the peer closed its end.
            Ok(0) => break,
            // Payload is ignored; keep draining.
            Ok(_) => {}
            // False-positive readiness: wait again.
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
            // Any other error ends the session.
            Err(_) => break,
        }
    }
}
/// Placeholder accept loop for platforms without Unix domain sockets.
///
/// There is nothing to accept, so this awaits a future that never resolves
/// and the task stays parked until it is aborted.
#[cfg(not(unix))]
async fn accept_loop(
    _listener: (),
    _client_count: Arc<AtomicUsize>,
    _client_changed: Arc<Notify>,
) {
    let never = std::future::pending::<()>();
    never.await;
}
async fn shutdown_monitor(
client_count: Arc<AtomicUsize>,
client_changed: Arc<Notify>,
) -> &'static str {
loop {
let clients = client_count.load(Ordering::SeqCst);
if clients > 0 {
tokio::select! {
() = shutdown_signal() => return "received shutdown signal",
() = client_changed.notified() => {}
}
} else {
info!(
grace_secs = GRACE_PERIOD.as_secs(),
"no active clients, grace period started"
);
tokio::select! {
() = shutdown_signal() => return "received shutdown signal",
() = tokio::time::sleep(GRACE_PERIOD) => {
if client_count.load(Ordering::SeqCst) == 0 {
return "grace period expired with no clients";
}
}
() = client_changed.notified() => {}
}
}
}
}
/// Resolves once the process is asked to terminate.
///
/// * On Unix: completes when either the SIGTERM or the SIGINT stream yields.
///   If a handler cannot be registered the error is logged and the process
///   exits with status 1.
/// * Elsewhere: completes when `ctrl_c()` resolves; its result is ignored.
async fn shutdown_signal() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};

        let mut term_stream = match signal(SignalKind::terminate()) {
            Ok(stream) => stream,
            Err(e) => {
                error!("Failed to register SIGTERM handler: {e}");
                std::process::exit(1);
            }
        };
        let mut int_stream = match signal(SignalKind::interrupt()) {
            Ok(stream) => stream,
            Err(e) => {
                error!("Failed to register SIGINT handler: {e}");
                std::process::exit(1);
            }
        };

        // Whichever stream yields first wins; the yielded value is unused.
        tokio::select! {
            _ = term_stream.recv() => {}
            _ = int_stream.recv() => {}
        }
    }
    #[cfg(not(unix))]
    {
        let _outcome = tokio::signal::ctrl_c().await;
    }
}
/// Checks the PID file for evidence of an already-running daemon.
///
/// * No PID file: nothing is running, returns `Ok(())`.
/// * PID file whose contents are not a `u32`: the file is removed
///   (best effort) and `Ok(())` is returned.
/// * PID file naming a process that `is_process_alive` reports as dead:
///   the file is removed (best effort) and `Ok(())` is returned.
///
/// # Errors
///
/// Returns a message when the PID file exists but cannot be read, or when
/// the recorded process is reported alive.
pub fn check_existing_daemon(pid_path: &Path) -> Result<(), String> {
    let raw = match std::fs::read_to_string(pid_path) {
        Ok(text) => text,
        Err(e) => {
            if e.kind() == std::io::ErrorKind::NotFound {
                return Ok(());
            }
            return Err(format!(
                "Unable to read PID file {}: {e}",
                pid_path.display()
            ));
        }
    };

    let recorded = raw.trim();
    match recorded.parse::<u32>() {
        Ok(pid) if is_process_alive(pid) => Err(format!(
            "Another synwire-daemon is already running (PID {pid})"
        )),
        Ok(pid) => {
            warn!(pid, "Removing stale PID file for dead process");
            // Best effort: a leftover file is re-examined on the next start.
            let _ = std::fs::remove_file(pid_path);
            Ok(())
        }
        Err(_) => {
            warn!(
                path = %pid_path.display(),
                contents = %recorded,
                "Removing PID file with invalid contents"
            );
            let _ = std::fs::remove_file(pid_path);
            Ok(())
        }
    }
}
/// Writes the current process ID, followed by a newline, to `pid_path`,
/// replacing any existing contents.
///
/// # Errors
///
/// Propagates the I/O error from `std::fs::write`.
pub fn write_pid_file(pid_path: &Path) -> std::io::Result<()> {
    let mut line = std::process::id().to_string();
    line.push('\n');
    std::fs::write(pid_path, line)
}
/// Reports whether a process with the given PID currently exists.
///
/// The check sends the null signal (`kill(pid, 0)` in C terms), which
/// performs existence and permission checks without delivering anything.
///
/// * PID 0 and values that do not fit a positive `i32` are never treated as
///   alive: passing 0 to `kill` addresses the caller's own process group,
///   which would always succeed and wrongly report a running daemon.
/// * `EPERM` counts as alive: it means the process exists but belongs to a
///   user we are not allowed to signal.
///
/// Note that a PID can be reused by an unrelated process, so `true` only
/// means *some* process has this PID.
#[cfg(unix)]
fn is_process_alive(pid: u32) -> bool {
    let raw = match i32::try_from(pid) {
        Ok(raw) if raw > 0 => raw,
        _ => return false,
    };
    matches!(
        nix::sys::signal::kill(nix::unistd::Pid::from_raw(raw), None),
        Ok(()) | Err(nix::errno::Errno::EPERM)
    )
}
/// Fallback liveness check for non-Unix targets.
///
/// No probe is performed here: every PID is reported as alive, so a
/// numeric PID file is never considered stale by callers on these targets.
#[cfg(not(unix))]
fn is_process_alive(_pid: u32) -> bool {
true
}
/// Deletes whatever is found at `sock_path` so a fresh socket can be bound.
///
/// Does nothing when the path does not exist. A failed removal is ignored
/// here; the subsequent bind will surface the problem.
fn remove_stale_socket(sock_path: &Path) {
    if !sock_path.exists() {
        return;
    }
    warn!(path = %sock_path.display(), "Removing stale daemon socket");
    let _ = std::fs::remove_file(sock_path);
}
/// Binds a Unix domain socket listener at `sock_path`.
///
/// After a successful bind the socket file's mode is set to `0o600`
/// (owner read/write only). Failing to change the mode is logged as a
/// warning but does not prevent the listener from being returned.
///
/// Returns `None`, after logging an error, when the bind itself fails.
#[cfg(unix)]
fn bind_listener(sock_path: &Path) -> Option<tokio::net::UnixListener> {
    use std::os::unix::fs::PermissionsExt;

    let listener = match tokio::net::UnixListener::bind(sock_path) {
        Ok(bound) => bound,
        Err(e) => {
            error!(path = %sock_path.display(), "Failed to bind Unix socket: {e}");
            return None;
        }
    };

    let owner_only = std::fs::Permissions::from_mode(0o600);
    if let Err(e) = std::fs::set_permissions(sock_path, owner_only) {
        warn!(
            path = %sock_path.display(),
            "Failed to restrict socket permissions: {e}"
        );
    }

    Some(listener)
}
/// Stand-in for platforms without Unix domain sockets.
///
/// Logs a warning that no IPC listener is available and returns `Some(())`
/// so the caller proceeds as if binding had succeeded.
#[cfg(not(unix))]
fn bind_listener(sock_path: &Path) -> Option<()> {
    let shown = sock_path.display();
    warn!(
        path = %shown,
        "Unix domain sockets are not supported on this platform; running without IPC listener"
    );
    Some(())
}
/// Removes the PID file and the socket file.
///
/// A file that is already absent is not an error. Any other removal
/// failure is logged as a warning; the function itself never fails, and
/// the socket is attempted even if the PID file could not be removed.
pub fn cleanup(pid_path: &Path, sock_path: &Path) {
    match std::fs::remove_file(pid_path) {
        Err(e) if e.kind() != std::io::ErrorKind::NotFound => {
            warn!(path = %pid_path.display(), "Failed to remove PID file: {e}");
        }
        _ => {}
    }

    match std::fs::remove_file(sock_path) {
        Err(e) if e.kind() != std::io::ErrorKind::NotFound => {
            warn!(path = %sock_path.display(), "Failed to remove socket: {e}");
        }
        _ => {}
    }
}