pub mod lifecycle;
mod pipes;
pub mod restart_coordination_socket;
pub mod shutdown;
pub use shutdown::{ShutdownCoordinator, ShutdownHandle, ShutdownSignal};
use crate::lifecycle::LifecycleHandler;
use crate::pipes::{
completion_pipes, create_paired_pipes, CompletionReceiver, CompletionSender, FdStringExt,
PipeMode,
};
use crate::restart_coordination_socket::{
RestartCoordinationSocket, RestartMessage, RestartRequest, RestartResponse,
};
use anyhow::anyhow;
use futures::stream::{Stream, StreamExt};
use std::env;
use std::ffi::OsString;
use std::fs::{remove_file, File as StdFile};
use std::future::Future;
use std::io;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::os::unix::net::UnixListener as StdUnixListener;
use std::os::unix::process::CommandExt;
use std::path::{Path, PathBuf};
use std::process;
use std::thread;
use thiserror::Error;
use tokio::fs::File;
use tokio::net::{UnixListener, UnixStream};
use tokio::select;
use tokio::signal::unix::{signal, Signal, SignalKind};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::UnixListenerStream;
/// Result type returned by the restart APIs of this crate (an alias for [`anyhow::Result`]).
pub type RestartResult<T> = anyhow::Result<T>;
/// Environment variable through which a restarting parent hands its child the fd number of the
/// `CompletionSender` end of the completion pipe; read and removed by [`startup_complete`].
const ENV_NOTIFY_SOCKET: &str = "OXY_NOTIFY_SOCKET";
/// Environment variable carrying the fd number of the inherited restart coordination listener;
/// set by `spawn_child` and adopted by `bind_restart_coordination_socket`.
const ENV_RESTART_SOCKET: &str = "OXY_RESTART_SOCKET";
/// Environment variable carrying the fd number of the read end of the handover pipe into which
/// the parent's lifecycle handler writes its state. The reader of this variable is not in this
/// file — NOTE(review): presumably lives in the lifecycle module; confirm.
const ENV_HANDOVER_PIPE: &str = "OXY_HANDOVER_PIPE";
/// systemd socket-activation variable naming the pid the passed listening sockets are meant for.
const ENV_SYSTEMD_PID: &str = "LISTEN_PID";
/// Placeholder written to `LISTEN_PID` for a spawned child (its pid is not known before the
/// spawn); [`fixup_systemd_env`] replaces it with the real pid inside the child.
const REBIND_SYSTEMD_PID: &str = "auto";
/// Configuration for the graceful-restart machinery.
///
/// Used both by the process that can be restarted ([`RestartConfig::try_into_restart_task`])
/// and by a client asking a running process to restart ([`RestartConfig::request_restart`]).
pub struct RestartConfig {
    /// Whether the restart coordination socket is used. When `false`, `request_restart` fails
    /// immediately and the restart task only reacts to `SIGUSR1`.
    pub enabled: bool,
    /// Filesystem path of the Unix-domain restart coordination socket.
    pub coordination_socket_path: PathBuf,
    /// Extra environment variables set on the spawned replacement process.
    pub environment: Vec<(OsString, OsString)>,
    /// Hooks invoked before, during and after spawning the replacement process.
    pub lifecycle_handler: Box<dyn LifecycleHandler>,
    /// If `true`, a child that fails to start ends the restart task with an error; if `false`,
    /// the failure is logged and the task keeps waiting for the next restart request.
    pub exit_on_error: bool,
}
impl RestartConfig {
pub fn try_into_restart_task(
self,
) -> io::Result<(impl Future<Output = RestartResult<process::Child>> + Send)> {
fixup_systemd_env();
spawn_restart_task(self)
}
pub async fn request_restart(self) -> RestartResult<u32> {
if !self.enabled {
return Err(anyhow!(
"no restart coordination socket socket defined in config"
));
}
let socket = UnixStream::connect(self.coordination_socket_path).await?;
restart_coordination_socket::RestartCoordinationSocket::new(socket)
.send_restart_command()
.await
}
pub fn request_restart_sync(self) -> RestartResult<u32> {
tokio::runtime::Runtime::new()
.unwrap()
.block_on(self.request_restart())
}
}
impl Default for RestartConfig {
fn default() -> Self {
RestartConfig {
enabled: false,
coordination_socket_path: Default::default(),
environment: vec![],
lifecycle_handler: Box::new(lifecycle::NullLifecycleHandler),
exit_on_error: true,
}
}
}
/// Replaces a `LISTEN_PID=auto` placeholder with this process's real pid (Linux only).
///
/// A restarting parent sets the placeholder for its child because it cannot know the child's
/// pid in advance. Code that checks `LISTEN_PID` against `getpid()`, as systemd's
/// `sd_listen_fds` does, then accepts the inherited sockets. Any other value is left
/// untouched. Because this mutates the process environment, call it before other threads
/// start reading the environment.
pub fn fixup_systemd_env() {
    #[cfg(target_os = "linux")]
    {
        let needs_rebind = env::var(ENV_SYSTEMD_PID).as_deref() == Ok(REBIND_SYSTEMD_PID);
        if needs_rebind {
            env::set_var(ENV_SYSTEMD_PID, process::id().to_string());
        }
    }
}
/// Signals that this process has finished starting up.
///
/// If this process was spawned by a restarting parent, `OXY_NOTIFY_SOCKET` holds the fd of
/// a completion pipe. A completion message is sent on it, which releases the parent waiting
/// in `send_parent_state`. Afterwards systemd is notified with `READY=1`, best effort, and
/// any failure is ignored.
///
/// The variable is removed from the environment *before* the descriptor is adopted. The fd is
/// therefore taken over at most once, even if this function fails part-way or is called again.
///
/// # Errors
/// Returns an I/O error if the inherited descriptor cannot be adopted or the completion
/// message cannot be sent. systemd is not notified in that case.
pub fn startup_complete() -> io::Result<()> {
    let notify_fd = env::var(ENV_NOTIFY_SOCKET);
    env::remove_var(ENV_NOTIFY_SOCKET);
    if let Ok(notify_fd) = notify_fd {
        // SAFETY: the parent set this variable to the fd of the completion pipe it left open for
        // us (see `spawn_child`). Removing the variable above guarantees this is the only place
        // that takes ownership of that descriptor.
        let pipe = unsafe { std::fs::File::from_fd_string(&notify_fd)? };
        pipes::CompletionSender(pipe).send()?;
    }
    let _ = sd_notify::notify(false, &[sd_notify::NotifyState::Ready]);
    Ok(())
}
/// Reply channel for one restart request.
///
/// `rpc` is `None` when the restart was triggered by a signal, so there is nobody to answer.
struct RestartResponder {
    rpc: Option<RestartCoordinationSocket>,
}

impl RestartResponder {
    /// Reports the restart outcome (new pid, or an error description) to the requester, if any.
    ///
    /// A failure to deliver the reply is logged and otherwise ignored.
    async fn respond(self, result: Result<u32, String>) {
        let Some(mut rpc) = self.rpc else {
            return;
        };
        let response =
            result.map_or_else(RestartResponse::RestartFailed, RestartResponse::RestartComplete);
        let message = RestartMessage::Response(response);
        if let Err(e) = rpc.send_message(message).await {
            log::warn!("Failed to respond to restart coordinator: {}", e);
        }
    }
}
pub fn spawn_restart_task(
settings: RestartConfig,
) -> io::Result<impl Future<Output = RestartResult<process::Child>> + Send> {
let socket = match settings.enabled {
true => Some(settings.coordination_socket_path.as_ref()),
false => None,
};
let mut signal_stream = signal(SignalKind::user_defined1())?;
let (restart_fd, mut socket_stream) = new_restart_coordination_socket_stream(socket)?;
let mut child_spawner =
ChildSpawner::new(restart_fd, settings.environment, settings.lifecycle_handler);
Ok(async move {
startup_complete()?;
loop {
let responder = next_restart_request(&mut signal_stream, &mut socket_stream).await?;
log::debug!("Spawning new process");
let res = child_spawner.spawn_new_process().await;
responder
.respond(res.as_ref().map(|p| p.id()).map_err(|e| e.to_string()))
.await;
match res {
Ok(child) => {
log::debug!("New process spawned with pid {}", child.id());
if let Err(e) =
sd_notify::notify(true, &[sd_notify::NotifyState::MainPid(child.id())])
{
log::error!("Failed to notify systemd: {}", e);
}
return Ok(child);
}
Err(ChildSpawnError::ChildError(e)) => {
if settings.exit_on_error {
return Err(anyhow!("Restart failed: {}", e));
} else {
log::error!("Restart failed: {}", e);
}
}
Err(ChildSpawnError::RestartThreadGone) => {
res?;
}
}
}
})
}
struct ChildSpawner {
signal_sender: Sender<()>,
pid_receiver: Receiver<io::Result<process::Child>>,
}
impl ChildSpawner {
fn new(
restart_fd: Option<OwnedFd>,
environment: Vec<(OsString, OsString)>,
mut lifecycle_handler: Box<dyn LifecycleHandler>,
) -> Self {
let (signal_sender, mut signal_receiver) = channel(1);
let (pid_sender, pid_receiver) = channel(1);
thread::spawn(move || {
let restart_fd = restart_fd.as_ref().map(OwnedFd::as_fd);
while let Some(()) = signal_receiver.blocking_recv() {
let child = tokio::runtime::Runtime::new()
.unwrap()
.block_on(spawn_child(
restart_fd,
&environment,
&mut *lifecycle_handler,
));
pid_sender
.blocking_send(child)
.expect("parent needs to receive the child");
}
});
ChildSpawner {
signal_sender,
pid_receiver,
}
}
async fn spawn_new_process(&mut self) -> Result<process::Child, ChildSpawnError> {
self.signal_sender
.send(())
.await
.map_err(|_| ChildSpawnError::RestartThreadGone)?;
match self.pid_receiver.recv().await {
Some(Ok(child)) => Ok(child),
Some(Err(e)) => Err(ChildSpawnError::ChildError(e)),
None => Err(ChildSpawnError::RestartThreadGone),
}
}
}
/// Errors produced while asking the background spawner thread for a new child process.
#[derive(Error, Debug)]
pub enum ChildSpawnError {
    /// The spawner thread can no longer be reached. Either its request channel is closed or
    /// it exited without sending back a result.
    #[error("Restart thread exited")]
    RestartThreadGone,
    /// Spawning the child, or handing parent state over to it, failed with an I/O error.
    #[error("Child failed to start: {0}")]
    ChildError(io::Error),
}
/// Waits for the next restart request, from either `SIGUSR1` or the coordination socket.
///
/// Signal-triggered requests get a responder with no reply channel.
///
/// # Errors
/// Fails if the socket request stream terminates.
async fn next_restart_request(
    signal_stream: &mut Signal,
    mut socket_stream: impl Stream<Item = RestartResponder> + Unpin,
) -> RestartResult<RestartResponder> {
    select! {
        // `recv` yields `None` once the signal stream can deliver no more events. That is not
        // a restart request, so the pattern mismatch disables this branch and we keep waiting
        // on the socket instead of spawning a spurious replacement process.
        Some(()) = signal_stream.recv() => Ok(RestartResponder { rpc: None }),
        r = socket_stream.next() => match r {
            Some(r) => Ok(r),
            None => Err(anyhow!("Restart coordinator socket acceptor terminated")),
        },
    }
}
/// Creates the stream of socket-based restart requests.
///
/// With a path, the coordination socket is bound (or inherited) and made non-blocking. A
/// duplicate of its fd is returned so a future child can inherit the listener. Without a
/// path, there is no fd and the stream never yields.
fn new_restart_coordination_socket_stream(
    restart_coordination_socket: Option<&Path>,
) -> io::Result<(Option<OwnedFd>, impl Stream<Item = RestartResponder>)> {
    let Some(path) = restart_coordination_socket else {
        return Ok((None, futures::stream::pending::<RestartResponder>().boxed()));
    };
    let std_listener = bind_restart_coordination_socket(path)?;
    std_listener.set_nonblocking(true)?;
    let inheritable = OwnedFd::from(std_listener.try_clone()?);
    let tokio_listener = UnixListener::from_std(std_listener)?;
    let requests = listen_for_restart_events(tokio_listener).boxed();
    Ok((Some(inheritable), requests))
}
/// Returns the restart coordination listener.
///
/// The listener is inherited from the parent when `OXY_RESTART_SOCKET` is set. Otherwise a
/// fresh one is bound at `path`.
fn bind_restart_coordination_socket(path: &Path) -> io::Result<StdUnixListener> {
    if let Ok(inherited_fd) = env::var(ENV_RESTART_SOCKET) {
        // SAFETY: `spawn_child` sets this variable to the raw fd of the parent's listener and
        // clears its CLOEXEC flag, so the descriptor is open in this process and meant for us.
        return unsafe { StdUnixListener::from_fd_string(&inherited_fd) };
    }
    // Remove any stale socket file first; failures (e.g. no such file) are deliberately ignored.
    let _ = remove_file(path);
    StdUnixListener::bind(path)
}
fn listen_for_restart_events(
restart_coordination_socket: UnixListener,
) -> impl Stream<Item = RestartResponder> {
UnixListenerStream::new(restart_coordination_socket).filter_map(move |r| async move {
let sock = match r {
Ok(sock) => sock,
Err(e) => {
log::error!("Restart coordination socket accept error: {}", e);
return None;
}
};
let mut rpc = RestartCoordinationSocket::new(sock);
match rpc.receive_message().await {
Ok(RestartMessage::Request(RestartRequest::TryRestart)) => {
Some(RestartResponder { rpc: Some(rpc) })
}
Ok(m) => {
log::warn!(
"Restart coordination socket received unexpected message: {:?}",
m
);
None
}
Err(e) => {
log::warn!("Restart coordination socket connection error: {}", e);
None
}
}
})
}
fn clear_cloexec(fd: RawFd) -> nix::Result<()> {
use nix::fcntl::*;
let mut current_flags = FdFlag::from_bits_truncate(fcntl(fd, FcntlArg::F_GETFD)?);
current_flags.remove(FdFlag::FD_CLOEXEC);
fcntl(fd, FcntlArg::F_SETFD(current_flags))?;
Ok(())
}
/// Re-executes the current program as a replacement process and hands parent state over to it.
///
/// The program is started from `argv[0]` with the current arguments, plus:
/// - the user-supplied environment from `user_envs`;
/// - `LISTEN_PID=auto`;
/// - the fds of the handover and completion pipes;
/// - when `restart_fd` is given, the restart listener fd, made inheritable.
///
/// NOTE(review): `argv[0]` rather than `current_exe()` is presumably used so a binary replaced
/// at the same path is picked up; confirm.
///
/// # Errors
/// Fails if `argv[0]` is missing, if pipe creation or the spawn fails, or if handing state to
/// the child fails. In the last case the child is killed and reaped before returning.
async fn spawn_child(
    restart_fd: Option<BorrowedFd<'_>>,
    user_envs: &[(OsString, OsString)],
    lifecycle_handler: &mut dyn LifecycleHandler,
) -> io::Result<process::Child> {
    // `args_os` rather than `args`: the latter panics if any argument is not valid UTF-8.
    let mut args = env::args_os();
    let process_name = args.next().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::Other,
            "cannot re-execute: argv[0] is missing",
        )
    })?;
    lifecycle_handler.pre_new_process().await;
    let (notif_r, notif_w) = completion_pipes()?;
    let (handover_r, handover_w) = create_paired_pipes(PipeMode::ParentWrites)?;
    let mut cmd = process::Command::new(process_name);
    cmd.args(args)
        .envs(user_envs.iter().map(|(k, v)| (k, v)))
        .env(ENV_SYSTEMD_PID, REBIND_SYSTEMD_PID)
        .env(ENV_HANDOVER_PIPE, handover_r.fd_string())
        .env(ENV_NOTIFY_SOCKET, notif_w.0.fd_string());
    if let Some(fd) = restart_fd {
        let fd = fd.as_raw_fd();
        cmd.env(ENV_RESTART_SOCKET, fd.to_string());
        // SAFETY: the hook runs between fork and exec, where only async-signal-safe work is
        // allowed. `clear_cloexec` performs two fcntl(2) calls on an fd number captured by
        // value and does not allocate.
        unsafe {
            cmd.pre_exec(move || {
                clear_cloexec(fd)?;
                Ok(())
            });
        }
    }
    let mut child = cmd.spawn()?;
    if let Err(e) = send_parent_state(lifecycle_handler, notif_r, notif_w, handover_w).await {
        match child.kill() {
            Ok(()) => {
                log::error!("Killed child process because failed to send parent state: {e:?}");
                // SIGKILL was sent, so this returns promptly. Reaping here keeps a failed
                // restart from leaving a zombie process behind.
                let _ = child.wait();
            }
            Err(_) => {
                log::error!("Child process has already exited. Failed to send parent state: {e:?}");
                // Non-blocking reap in case the exited child has not been collected yet.
                let _ = child.try_wait();
            }
        }
        return Err(e);
    }
    Ok(child)
}
/// Hands parent state to a freshly spawned child and waits for it to report startup.
///
/// The lifecycle handler writes its state into the handover pipe. The parent's copy of the
/// completion-pipe sender is then closed, and the function waits for the child's completion
/// message. If that wait fails, the lifecycle handler is told the new process failed.
///
/// # Errors
/// Propagates a handover failure unchanged, as well as the error from waiting for completion.
async fn send_parent_state(
    lifecycle_handler: &mut dyn LifecycleHandler,
    mut notif_r: CompletionReceiver,
    notif_w: CompletionSender,
    handover_w: StdFile,
) -> io::Result<()> {
    let handover = Box::pin(File::from(handover_w));
    lifecycle_handler.send_to_new_process(handover).await?;
    // Close our copy of the write end. NOTE(review): presumably so that `recv` reports an
    // error instead of blocking forever if the child exits without signalling; confirm
    // against `pipes::CompletionReceiver`.
    drop(notif_w);
    if let Err(e) = notif_r.recv() {
        lifecycle_handler.new_process_failed().await;
        return Err(e);
    }
    Ok(())
}