use super::ReloadState;
use std::sync::Arc;
/// Environment variable through which an upgrading parent hands the child it
/// re-executes the write end of the readiness pipe, as a decimal descriptor
/// number.
#[cfg(unix)]
pub const UPGRADE_READY_FD_ENV: &str = "HYPERSHUNT_UPGRADE_READY_FD";

/// State shared with the `SIGUSR2` binary-upgrade handler.
#[cfg(unix)]
pub struct UpgradeState {
    /// Stop-accept senders keyed by bind string (presumably one per
    /// listener); each is sent `true` once the upgraded child reports ready.
    pub stop_accept_txs:
        Arc<std::sync::Mutex<std::collections::HashMap<String, tokio::sync::watch::Sender<bool>>>>,
    /// Seconds to wait for the child's ready byte before giving up.
    pub startup_timeout_secs: u32,
    /// Sent `true` to begin draining this process after a successful upgrade.
    pub drain_signal: tokio::sync::watch::Sender<bool>,
}
/// Spawns a supervised task that attempts a binary upgrade every time the
/// process receives `SIGUSR2`.
///
/// If the signal handler cannot be installed, the error is logged and the
/// task ends. A failed upgrade is logged and the task keeps listening.
#[cfg(unix)]
pub fn spawn_sigusr2_listener(
    upgrade_state: Arc<UpgradeState>,
) -> tokio::task::JoinHandle<()> {
    let listener = async move {
        use tokio::signal::unix::{SignalKind, signal};

        let installed = signal(SignalKind::user_defined2());
        let mut usr2 = match installed {
            Err(e) => {
                tracing::error!(target: crate::reload::TARGET, "installing SIGUSR2 handler failed: {e}");
                return;
            }
            Ok(stream) => stream,
        };
        tracing::info!(target: crate::reload::TARGET, "SIGUSR2 listener installed");

        // Handle signals one at a time until the stream stops yielding.
        while let Some(()) = usr2.recv().await {
            tracing::info!(target: crate::reload::TARGET, "SIGUSR2 received; starting binary upgrade");
            let outcome = perform_upgrade(&upgrade_state).await;
            if let Err(e) = outcome {
                tracing::warn!(target: crate::reload::TARGET, "SIGUSR2: upgrade failed: {e:#}");
            }
        }
    };
    crate::task::spawn_supervised("signal.sigusr2", listener)
}
#[cfg(unix)]
async fn perform_upgrade(state: &UpgradeState) -> anyhow::Result<()> {
use anyhow::{Context, anyhow};
use nix::sys::signal::{Signal, kill};
use nix::unistd::{ForkResult, close, execv, fork, pipe};
use std::ffi::CString;
use std::os::fd::{AsRawFd, IntoRawFd};
let argv_path = std::env::current_exe()
.context("resolving current_exe for upgrade")?;
let argv_path_c = CString::new(argv_path.as_os_str().as_encoded_bytes())
.context("argv[0] contains a null byte")?;
let args: Vec<CString> = std::env::args()
.skip(1)
.map(|a| CString::new(a).expect("argv contains null byte"))
.collect();
let mut argv: Vec<CString> = Vec::with_capacity(args.len() + 1);
argv.push(argv_path_c.clone());
argv.extend(args);
let (read_end, write_end) = pipe().context("pipe() for upgrade ready")?;
match unsafe { fork() }.context("fork() for upgrade")? {
ForkResult::Child => {
drop(read_end);
unsafe {
std::env::set_var(
UPGRADE_READY_FD_ENV,
write_end.as_raw_fd().to_string(),
);
}
let _leaked = write_end.into_raw_fd();
let err = execv(&argv_path_c, &argv);
eprintln!("hypershunt upgrade: execv failed: {err:?}");
unsafe { libc::_exit(127) };
}
ForkResult::Parent { child } => {
drop(write_end);
let read_fd = read_end.into_raw_fd();
let ready = tokio::time::timeout(
std::time::Duration::from_secs(
state.startup_timeout_secs as u64,
),
read_one_byte(read_fd),
)
.await;
match ready {
Ok(Ok(())) => {
tracing::info!(target: crate::reload::TARGET,
pid = child.as_raw(),
"SIGUSR2: child reported ready; \
beginning parent drain"
);
let txs = state.stop_accept_txs.lock().expect("reload stop-accept mutex");
for (bind, tx) in txs.iter() {
tracing::debug!(target: crate::reload::TARGET, %bind, "stop-accept fired");
let _ = tx.send(true);
}
drop(txs);
let _ = state.drain_signal.send(true);
Ok(())
}
Ok(Err(e)) => {
let _ = kill(child, Signal::SIGTERM);
Err(anyhow!(
"child closed ready pipe before signalling: {e}"
))
}
Err(_) => {
let _ = kill(child, Signal::SIGTERM);
let _ = close(read_fd);
Err(anyhow!(
"child did not signal ready within {}s",
state.startup_timeout_secs
))
}
}
}
}
}
/// Waits asynchronously until one byte can be read from `fd`.
///
/// Ownership: `fd` is adopted as an `OwnedFd` as the very first step. It is
/// therefore closed whenever this future completes, fails, or is dropped, for
/// example by an enclosing `tokio::time::timeout`. Callers must pass an open
/// descriptor they own, and must not close or reuse it afterwards.
///
/// The descriptor is switched to `O_NONBLOCK` so tokio's reactor can drive it.
///
/// # Errors
/// - `UnexpectedEof` if the writer closes the pipe without writing.
/// - Any error from `fcntl`, reactor registration, or `read(2)`.
///   `EINTR` is retried rather than returned.
#[cfg(unix)]
pub(crate) async fn read_one_byte(
    fd: std::os::fd::RawFd,
) -> std::io::Result<()> {
    use nix::fcntl::{FcntlArg, OFlag, fcntl};
    use std::io::{Error, ErrorKind};
    use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd};
    use tokio::io::unix::AsyncFd;

    // SAFETY: the caller transfers ownership of an open descriptor (see the
    // docs above). Adopting it before anything fallible runs guarantees it is
    // closed on every exit path, including an fcntl failure.
    let owned = unsafe { OwnedFd::from_raw_fd(fd) };

    let flags = fcntl(owned.as_fd(), FcntlArg::F_GETFL).map_err(Error::from)?;
    fcntl(
        owned.as_fd(),
        FcntlArg::F_SETFL(OFlag::from_bits_truncate(flags) | OFlag::O_NONBLOCK),
    )
    .map_err(Error::from)?;

    let async_fd = AsyncFd::new(owned)?;
    let mut buf = [0u8; 1];
    loop {
        let mut guard = async_fd.readable().await?;
        let attempt = guard.try_io(|inner| {
            // SAFETY: `buf` is a valid, writable buffer of `buf.len()` bytes,
            // and the descriptor stays open for the duration of the call.
            let n = unsafe {
                libc::read(
                    inner.get_ref().as_raw_fd(),
                    buf.as_mut_ptr().cast(),
                    buf.len(),
                )
            };
            if n < 0 {
                Err(Error::last_os_error())
            } else {
                Ok(n as usize)
            }
        });
        match attempt {
            Ok(Ok(0)) => {
                return Err(Error::new(
                    ErrorKind::UnexpectedEof,
                    "child closed pipe without writing ready byte",
                ));
            }
            Ok(Ok(_)) => return Ok(()),
            Ok(Err(e)) if e.kind() == ErrorKind::Interrupted => continue,
            Ok(Err(e)) => return Err(e),
            // Spurious readiness: try_io already cleared it, so wait again.
            Err(_would_block) => continue,
        }
    }
}
#[cfg(unix)]
pub fn signal_upgrade_ready() {
use std::os::fd::FromRawFd;
let Some(fd) = std::env::var(UPGRADE_READY_FD_ENV)
.ok()
.and_then(|s| s.parse::<std::os::fd::RawFd>().ok())
else {
return;
};
let mut f = unsafe { std::fs::File::from_raw_fd(fd) };
use std::io::Write;
if let Err(e) = f.write_all(b".") {
tracing::warn!(target: crate::reload::TARGET, "upgrade ready signal write failed: {e}");
}
drop(f); unsafe { std::env::remove_var(UPGRADE_READY_FD_ENV) };
tracing::info!(target: crate::reload::TARGET, "upgrade: signalled parent that child is ready");
}
/// Spawns a supervised task that calls `super::reload` every time the process
/// receives `SIGHUP`.
///
/// If the signal handler cannot be installed, the error is logged and the
/// task ends. The value returned by each reload is discarded here.
/// NOTE(review): presumably `reload` reports its own failures; confirm.
#[cfg(unix)]
pub fn spawn_sighup_listener(
    reload_state: Arc<ReloadState>,
) -> tokio::task::JoinHandle<()> {
    let listener = async move {
        use tokio::signal::unix::{SignalKind, signal};

        let installed = signal(SignalKind::hangup());
        let mut hup = match installed {
            Err(e) => {
                tracing::error!(target: crate::reload::TARGET, "installing SIGHUP handler failed: {e}");
                return;
            }
            Ok(stream) => stream,
        };
        tracing::info!(target: crate::reload::TARGET, "SIGHUP listener installed");

        loop {
            if hup.recv().await.is_none() {
                break;
            }
            tracing::info!(target: crate::reload::TARGET, "SIGHUP received; reloading config");
            let _ = super::reload(&reload_state).await;
        }
    };
    crate::task::spawn_supervised("signal.sighup", listener)
}
#[cfg(test)]
#[cfg(unix)]
mod tests {
    use super::{read_one_byte, signal_upgrade_ready, UPGRADE_READY_FD_ENV};
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises tests that mutate the process environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `ENV_LOCK`, recovering from poisoning. Without this, one failing
    /// env test would make every later one fail with a `PoisonError` instead
    /// of its own message.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[tokio::test]
    async fn read_one_byte_success() {
        use std::os::fd::IntoRawFd;
        let (read_fd, write_fd) = nix::unistd::pipe().unwrap();
        let read_raw = read_fd.into_raw_fd();
        let write_raw = write_fd.into_raw_fd();
        tokio::task::spawn_blocking(move || {
            use std::io::Write;
            use std::os::fd::FromRawFd;
            let mut f = unsafe { std::fs::File::from_raw_fd(write_raw) };
            f.write_all(b".").unwrap();
        });
        read_one_byte(read_raw).await.unwrap();
    }

    #[tokio::test]
    async fn read_one_byte_eof() {
        use std::os::fd::IntoRawFd;
        let (read_fd, write_fd) = nix::unistd::pipe().unwrap();
        let read_raw = read_fd.into_raw_fd();
        drop(write_fd);
        let err = read_one_byte(read_raw).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// `perform_upgrade` relies on `read_one_byte` closing its fd when the
    /// surrounding timeout drops it, and therefore never closes it itself.
    #[tokio::test]
    async fn read_one_byte_closes_fd_when_dropped() {
        use std::io::Write;
        use std::os::fd::IntoRawFd;
        let (read_fd, write_fd) = nix::unistd::pipe().unwrap();
        let read_raw = read_fd.into_raw_fd();
        let res = tokio::time::timeout(
            std::time::Duration::from_millis(20),
            read_one_byte(read_raw),
        )
        .await;
        assert!(res.is_err(), "nothing was written, so the read must time out");
        // The read end is gone, so the writer must see a broken pipe. The
        // Rust runtime ignores SIGPIPE, so this surfaces as EPIPE.
        let mut w = std::fs::File::from(write_fd);
        let err = w.write_all(b".").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
    }

    #[test]
    fn signal_upgrade_ready_writes_and_clears_env() {
        use std::io::Read;
        use std::os::fd::{FromRawFd, IntoRawFd};
        let _guard = env_lock();
        let (read_fd, write_fd) = nix::unistd::pipe().unwrap();
        let read_raw = read_fd.into_raw_fd();
        let write_raw = write_fd.into_raw_fd();
        unsafe {
            std::env::set_var(
                UPGRADE_READY_FD_ENV,
                write_raw.to_string(),
            );
        }
        signal_upgrade_ready();
        assert!(std::env::var(UPGRADE_READY_FD_ENV).is_err(),
            "env var must be removed after signalling");
        let mut buf = [0u8; 1];
        let mut r = unsafe { std::fs::File::from_raw_fd(read_raw) };
        assert_eq!(r.read(&mut buf).unwrap(), 1);
    }

    #[test]
    fn signal_upgrade_ready_no_env_is_noop() {
        let _guard = env_lock();
        unsafe { std::env::remove_var(UPGRADE_READY_FD_ENV) };
        signal_upgrade_ready();
        assert!(std::env::var(UPGRADE_READY_FD_ENV).is_err());
    }
}