use std::{
collections::HashMap,
io,
path::PathBuf,
process::{ExitStatus, Stdio},
};
use nix::{
errno::Errno,
sys::signal::{kill, SIGTERM},
unistd::Pid,
};
use serde::Deserialize;
use tokio::process::Command;
use tracing::{error, info};
use crate::{
common::{stdio, Output},
signals::Shutdown,
target::TargetPidReceiver,
};
#[derive(Debug)]
pub enum Error {
Errno(Errno),
Io(io::Error),
}
/// Configuration for [`Server`]: which program to run and how to run it.
///
/// Deserialized via serde; field names are used as-is (snake_case).
#[derive(Debug, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct Config {
    /// Path of the program to execute.
    pub command: PathBuf,
    /// Arguments passed to `command`.
    pub arguments: Vec<String>,
    /// Environment variables for the child process. `Server::run` clears the
    /// child's environment first, so apart from `TARGET_PID` / `NO_TARGET`
    /// (added by `run`) these are the only variables the child sees.
    pub environment_variables: HashMap<String, String>,
    /// Destinations for the child's stdout and stderr.
    pub output: Output,
}
/// Spawns the configured command, passes it the target's PID through its
/// environment, and supervises it until it exits or shutdown is signalled.
#[derive(Debug)]
pub struct Server {
    /// How to launch the child process.
    config: Config,
    /// Shutdown signal; when received, the child is sent `SIGTERM`.
    shutdown: Shutdown,
}
impl Server {
pub fn new(config: Config, shutdown: Shutdown) -> Result<Self, Error> {
Ok(Self { config, shutdown })
}
pub async fn run(mut self, mut pid_snd: TargetPidReceiver) -> Result<ExitStatus, Error> {
let target_pid = pid_snd
.recv()
.await
.expect("target failed to transmit PID, catastrophic failure");
drop(pid_snd);
let config = self.config;
let mut target_cmd = Command::new(config.command);
let mut environment_variables = config.environment_variables.clone();
if let Some(pid) = target_pid {
environment_variables.insert(String::from("TARGET_PID"), pid.to_string());
} else {
environment_variables.insert(String::from("NO_TARGET"), String::from("1"));
}
target_cmd
.stdin(Stdio::null())
.stdout(stdio(&config.output.stdout))
.stderr(stdio(&config.output.stderr))
.env_clear()
.kill_on_drop(true)
.args(config.arguments)
.envs(environment_variables.iter());
let mut target_child = target_cmd.spawn().map_err(Error::Io)?;
let target_wait = target_child.wait();
tokio::select! {
res = target_wait => {
match res {
Ok(status) => {
error!("child exited with status: {}", status);
Ok(status)
}
Err(err) => {
error!("child exited with error: {}", err);
Err(Error::Io(err))
}
}
},
_ = self.shutdown.recv() => {
info!("shutdown signal received");
let pid: Pid = Pid::from_raw(target_child.id().unwrap().try_into().unwrap());
kill(pid, SIGTERM).map_err(Error::Errno)?;
let res = target_child.wait().await.map_err(Error::Io)?;
Ok(res)
}
}
}
}