use crate::{
env::{EnvError, EnvManager},
events::{EventHandler, FsEvent, HandlerError, wait_for_signal},
path::AbsolutePath,
secrets::SecretKey,
};
use async_trait::async_trait;
use futures::future::BoxFuture;
use nix::sys::{
signal::{self, Signal},
termios::{SetArg, Termios, tcgetattr, tcsetattr},
};
use nix::unistd::Pid;
use secrecy::ExposeSecret;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, hash_map::DefaultHasher};
use std::hash::{Hash, Hasher};
use std::process::ExitStatus;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use thiserror::Error;
use tokio::process::Command;
use tokio::signal::unix::{SignalKind, signal};
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
/// Errors raised while configuring, spawning, or supervising the child process.
#[derive(Debug, Error)]
pub enum ProcessError {
/// An environment error; `transparent` forwards its own message unchanged.
#[error(transparent)]
Env(#[from] EnvError),
/// An OS-level I/O failure, e.g. from spawning the child.
#[error("process I/O error: {0}")]
Io(#[from] std::io::Error),
/// The child finished with a non-success status; see `ProcessError::from_status`.
// NOTE(review): on Unix, `ExitStatus`'s Display output appears to already read
// like "exit status: 1", so this message repeats "status". Consider rewording.
#[error("child process exited with status {0}")]
Exited(ExitStatus),
/// The child was terminated by a signal.
// NOTE(review): this variant is not constructed in the code visible here;
// `from_status` reports signal terminations as `Exited` as well.
#[error("child process terminated by signal")]
Signaled,
/// The command was malformed or could not be launched as requested
/// (e.g. an empty argv, or no child PID available after spawning).
#[error("invalid command: {0}")]
InvalidCommand(String),
}
impl ProcessError {
pub fn from_status(status: ExitStatus) -> Result<(), Self> {
if status.success() {
Ok(())
} else {
Err(Self::Exited(status))
}
}
}
/// A program and its argument list, kept separate rather than as one string.
// `ProcessManager` runs it via `Command::new(program).args(args)`, i.e. the
// arguments are passed directly and not interpreted by a shell.
#[derive(Debug, Clone)]
pub struct ShellCommand {
/// Executable name or path (the first argv element).
program: String,
/// Arguments passed to `program`, in order.
args: Vec<String>,
}
impl ShellCommand {
pub fn new(program: String, args: Vec<String>) -> Self {
Self { program, args }
}
pub fn try_from_vec(mut raw: Vec<String>) -> Result<Self, ProcessError> {
if raw.is_empty() {
return Err(ProcessError::InvalidCommand(
"No program specified".to_string(),
));
}
let program = raw.remove(0);
Ok(Self { program, args: raw })
}
}
impl TryFrom<Vec<String>> for ShellCommand {
type Error = ProcessError;
fn try_from(value: Vec<String>) -> Result<Self, Self::Error> {
Self::try_from_vec(value)
}
}
impl From<ShellCommand> for Vec<String> {
fn from(cmd: ShellCommand) -> Self {
let mut v = vec![cmd.program];
v.extend(cmd.args);
v
}
}
impl std::fmt::Display for ShellCommand {
    /// Renders the program followed by its space-separated arguments.
    ///
    /// No trailing space is emitted when there are no arguments. Arguments are
    /// not quoted, so an argument containing whitespace cannot be told apart
    /// from two separate arguments in the rendered text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.program)?;
        for arg in &self.args {
            write!(f, " {arg}")?;
        }
        Ok(())
    }
}
/// Timeout newtype around [`Duration`].
///
/// Deserialized from a string through `TryFrom<String>` (a bare integer is
/// taken as whole seconds; anything else is parsed by `humantime`) and
/// serialized back through its `Display` impl.
// NOTE(review): because deserialization goes through `String`, a non-string
// config value such as `timeout = 30` (instead of `"30"`) is likely rejected
// by self-describing formats. Confirm this is intended.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(try_from = "String")]
pub struct ProcessTimeout(pub Duration);
impl Serialize for ProcessTimeout {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.collect_str(self)
}
}
impl TryFrom<String> for ProcessTimeout {
type Error = String;
fn try_from(s: String) -> Result<Self, Self::Error> {
s.parse()
.map_err(|e: humantime::DurationError| e.to_string())
}
}
impl FromStr for ProcessTimeout {
    type Err = humantime::DurationError;

    /// Accepts either a bare non-negative integer, interpreted as whole seconds,
    /// or any duration string understood by `humantime::parse_duration`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let duration = match s.parse::<u64>() {
            Ok(secs) => Duration::from_secs(secs),
            Err(_) => humantime::parse_duration(s)?,
        };
        Ok(ProcessTimeout(duration))
    }
}
impl std::fmt::Display for ProcessTimeout {
    /// Formats the wrapped duration with `humantime::format_duration`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let human = humantime::format_duration(self.0);
        write!(f, "{}", human)
    }
}
impl From<ProcessTimeout> for Duration {
fn from(val: ProcessTimeout) -> Self {
val.0
}
}
impl Default for ProcessTimeout {
fn default() -> Self {
ProcessTimeout(Duration::from_secs(30))
}
}
/// Lifecycle state of the supervised child.
enum ProcessState {
/// No child is currently being supervised.
Idle,
/// A child has been spawned and is being supervised.
Running {
/// Signal target: the child's PID in interactive mode, otherwise the negated
/// PID, which `kill` interprets as the child's process group.
pid: Pid,
/// Set to `false` by the monitor task once waiting on the child returns
/// (successfully or with an error).
alive: Arc<AtomicBool>,
/// Task relaying signals received by this process to `pid`.
forwarder: JoinHandle<()>,
/// Task that waits on the child and publishes its exit status.
monitor: JoinHandle<()>,
/// Starts as `None`; becomes `Some(status)` once the monitor's wait succeeds.
exit_rx: watch::Receiver<Option<ExitStatus>>,
},
}
/// Supervises one child process: spawns it with a resolved environment,
/// forwards signals to it, and restarts it when the resolved environment changes.
pub struct ProcessManager {
/// Source of the environment variables passed to the child.
env: EnvManager,
/// Program and arguments to run.
cmd: ShellCommand,
/// Hash of the last applied resolved environment (see `hash_env`); used to
/// skip restarts when the environment is unchanged.
env_hash: u64,
/// Whether a child is currently being supervised.
state: ProcessState,
/// Terminal attributes captured in `new` (interactive mode only), restored
/// by `reset_tty`.
termios: Option<Mutex<Termios>>,
/// When true, the child inherits stdio and stays in this process's group;
/// when false it gets a fresh process group and a null stdin.
interactive: bool,
/// How long `stop` waits after SIGTERM before escalating to SIGKILL.
timeout: Duration,
}
impl ProcessManager {
pub fn new(
env: EnvManager,
cmd: ShellCommand,
interactive: bool,
timeout: impl Into<Duration>,
) -> Self {
let termios = if interactive {
tcgetattr(std::io::stdin()).ok().map(Mutex::new)
} else {
None
};
ProcessManager {
env,
cmd,
env_hash: 0,
state: ProcessState::Idle,
termios,
interactive,
timeout: timeout.into(),
}
}
fn hash_env(map: &HashMap<SecretKey, secrecy::SecretString>) -> u64 {
let mut hasher = DefaultHasher::new();
let mut keys: Vec<_> = map.keys().collect();
keys.sort();
for k in keys {
k.hash(&mut hasher);
map.get(k).unwrap().expose_secret().hash(&mut hasher);
}
hasher.finish()
}
fn spawn_forwarder(target: Pid, interactive: bool) -> JoinHandle<()> {
tokio::spawn(async move {
let mut signals = vec![
(SignalKind::interrupt(), Signal::SIGINT, "SIGINT"),
(SignalKind::terminate(), Signal::SIGTERM, "SIGTERM"),
(SignalKind::hangup(), Signal::SIGHUP, "SIGHUP"),
(SignalKind::quit(), Signal::SIGQUIT, "SIGQUIT"),
(SignalKind::user_defined1(), Signal::SIGUSR1, "SIGUSR1"),
(SignalKind::user_defined2(), Signal::SIGUSR2, "SIGUSR2"),
(SignalKind::window_change(), Signal::SIGWINCH, "SIGWINCH"),
];
if interactive {
signals.retain(|(_, sig, _)| *sig != Signal::SIGINT && *sig != Signal::SIGQUIT);
}
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
for (kind, sig, name) in signals {
match signal(kind) {
Ok(mut stream) => {
let tx = tx.clone();
tokio::spawn(async move {
while stream.recv().await.is_some() {
if tx.send((sig, name)).await.is_err() {
break;
}
}
});
}
Err(e) => tracing::warn!("failed to register listener for {}: {}", name, e),
}
}
while let Some((sig, name)) = rx.recv().await {
debug!("forwarding {} to process {}", name, target);
if signal::kill(target, sig).is_err() {
break;
}
}
})
}
fn reset_tty(&self) {
if let Some(mutex) = &self.termios
&& let Ok(guard) = mutex.lock()
{
let _ = tcsetattr(std::io::stdin(), SetArg::TCSANOW, &guard);
}
}
async fn restart(
&mut self,
env_map: &HashMap<SecretKey, secrecy::SecretString>,
) -> Result<(), ProcessError> {
self.stop().await;
let mut command = Command::new(&self.cmd.program);
command.args(&self.cmd.args);
command.envs(env_map.iter().map(|(k, v)| (k.as_ref(), v.expose_secret())));
if self.interactive {
command.stdin(std::process::Stdio::inherit());
command.stdout(std::process::Stdio::inherit());
command.stderr(std::process::Stdio::inherit());
} else {
command.process_group(0);
command.stdin(std::process::Stdio::null());
}
command.kill_on_drop(false);
info!(cmd = ?self.cmd, "Spawning child process");
let mut child = command.spawn()?;
let child_id = child.id();
let (tx, rx) = watch::channel(None);
let alive = Arc::new(AtomicBool::new(true));
let alive_clone = alive.clone();
let monitor = tokio::spawn(async move {
match child.wait().await {
Ok(status) => {
alive_clone.store(false, Ordering::SeqCst);
info!("Child process exited: {}", status);
let _ = tx.send(Some(status));
}
Err(e) => {
alive_clone.store(false, Ordering::SeqCst);
error!("Monitor failed to wait on child: {}", e);
}
}
});
if let Some(id) = child_id {
let pid = if self.interactive {
Pid::from_raw(id as i32)
} else {
Pid::from_raw(-(id as i32))
};
let forwarder = Self::spawn_forwarder(pid, self.interactive);
self.state = ProcessState::Running {
pid,
alive,
forwarder,
monitor,
exit_rx: rx,
};
} else {
return Err(ProcessError::InvalidCommand(
"Failed to get child pid (process exited immediately?)".into(),
));
}
Ok(())
}
pub async fn start(&mut self) -> Result<(), ProcessError> {
let env = self.env.resolve().await?;
self.env_hash = Self::hash_env(&env);
self.restart(&env).await?;
Ok(())
}
pub async fn stop(&mut self) {
let old_state = std::mem::replace(&mut self.state, ProcessState::Idle);
if let ProcessState::Running {
pid,
alive,
forwarder,
mut monitor,
..
} = old_state
{
forwarder.abort();
if alive.load(Ordering::SeqCst) {
debug!("Stopping process {:?}", pid);
if let Err(e) = signal::kill(pid, Signal::SIGTERM) {
debug!("Failed to send SIGTERM: {}", e);
}
let sleep = tokio::time::sleep(self.timeout);
tokio::pin!(sleep);
let mut interrupt =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
.expect("failed to install interrupt handler");
let mut finished = false;
tokio::select! {
res = &mut monitor => {
match res {
Ok(_) => debug!("Child exited gracefully"),
Err(e) => error!("Monitor task failed: {}", e),
}
finished = true;
}
_ = &mut sleep => {
warn!("Child timed out after {:?}, sending SIGKILL", self.timeout);
let _ = signal::kill(pid, Signal::SIGKILL);
}
_ = interrupt.recv() => {
warn!("Received Ctrl+C during shutdown, sending SIGKILL");
let _ = signal::kill(pid, Signal::SIGKILL);
}
}
if !finished && let Err(e) = monitor.await {
error!("Failed to join monitor task after kill: {}", e);
}
} else {
debug!("Process {:?} already exited, don't need to terminate.", pid);
if let Err(e) = monitor.await {
debug!("Monitor task join error: {:?}", e);
}
}
self.reset_tty();
}
}
}
impl Drop for ProcessManager {
    /// Last-resort cleanup if the manager is dropped while a child is running.
    ///
    /// Aborts the forwarder and monitor tasks, SIGKILLs `pid` if the child is
    /// still marked alive, and restores the terminal. Does nothing when idle.
    fn drop(&mut self) {
        let ProcessState::Running {
            pid,
            alive,
            forwarder,
            monitor,
            ..
        } = &self.state
        else {
            return;
        };
        forwarder.abort();
        monitor.abort();
        if alive.load(Ordering::SeqCst) {
            debug!("ProcessManager dropped, force killing PID {:?}", pid);
            let _ = signal::kill(*pid, Signal::SIGKILL);
        }
        self.reset_tty();
    }
}
#[async_trait]
impl EventHandler for ProcessManager {
fn paths(&self) -> Vec<AbsolutePath> {
self.env.files()
}
async fn handle(&mut self, events: Vec<FsEvent>) -> Result<(), HandlerError> {
if events.is_empty() {
return Ok(());
}
match self.env.resolve().await {
Ok(resolved) => {
let new_hash = Self::hash_env(&resolved);
if new_hash != self.env_hash {
self.env_hash = new_hash;
tracing::info!(
"Environment changed ({} events), restarting process...",
events.len()
);
if let Err(e) = self.restart(&resolved).await {
error!("Failed to restart process: {}", e);
}
} else {
debug!("Files changed but resolved environment is identical; skipping restart");
}
}
Err(e) => {
error!("Failed to reload environment: {}", e);
}
}
Ok(())
}
fn wait(&self) -> BoxFuture<'static, Result<(), HandlerError>> {
let os_signal = wait_for_signal(self.interactive);
match &self.state {
ProcessState::Running { exit_rx, .. } => {
let mut rx = exit_rx.clone();
let child_exit = async move {
let _ = rx.wait_for(|val| val.is_some()).await;
*rx.borrow()
};
Box::pin(async move {
tokio::select! {
Some(status) = child_exit => {
HandlerError::from_status(status)
}
_ = os_signal => {
Ok(())
}
}
})
}
ProcessState::Idle => Box::pin(async move {
os_signal.await;
Ok(())
}),
}
}
async fn cleanup(&mut self) {
self.stop().await;
}
}