use std::process::Stdio;
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use tokio::process::{Child, ChildStderr, ChildStdout, Command};
use crate::sdk::{HealthDef, ServiceConfig, signal as sig_util};
use super::error::SupervisorError;
/// Everything `spawn_process` hands back for a freshly started service.
pub struct SpawnResult {
/// Handle to the spawned child; its stdout/stderr pipes have already been taken.
pub child: Child,
/// Read end of the child's piped standard output.
pub stdout: ChildStdout,
/// Read end of the child's piped standard error.
pub stderr: ChildStderr,
/// OS process id reported by the child handle right after spawning.
pub pid: u32,
}
/// Spawns the service described by `config` and returns its handles.
///
/// The command in `config.service.exec` is run through `/bin/sh -c`, in
/// `config.service.dir` when one is set, with every entry of
/// `config.service.env` added to the environment. Stdin is attached to the
/// null device; stdout and stderr are piped back to the caller.
///
/// Before `exec`, the child marks inherited descriptors close-on-exec and
/// moves itself into a new process group whose id equals its pid, so the
/// whole service can later be signalled as a group.
///
/// # Errors
///
/// Returns `SupervisorError::SpawnError` when the process cannot be spawned
/// (including a failure to create the process group), when its pid is not
/// available, or when a stdio pipe is missing.
pub fn spawn_process(config: &ServiceConfig) -> Result<SpawnResult, SupervisorError> {
    // Exclusive upper bound of the descriptor range scrubbed in the child.
    const FD_SCAN_LIMIT: libc::c_int = 1024;

    let spawn_error = |message: String| SupervisorError::SpawnError {
        service: config.service.name.clone(),
        message,
    };

    let mut cmd = Command::new("/bin/sh");
    cmd.arg("-c").arg(&config.service.exec);
    if let Some(dir) = config.service.dir.as_deref() {
        cmd.current_dir(dir);
    }
    for (key, value) in &config.service.env {
        cmd.env(key, value);
    }
    cmd.stdin(Stdio::null());
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());

    // SAFETY: the hook runs in the forked child before `exec`. It only calls
    // `fcntl` and `setpgid`, both async-signal-safe, and does not allocate
    // (`io::Error::last_os_error` just wraps `errno`).
    unsafe {
        cmd.pre_exec(|| {
            // Mark descriptors close-on-exec rather than closing them here:
            // the standard library reports `pre_exec`/`exec` failures to the
            // parent through a close-on-exec pipe living in this range, and
            // closing it would make a failed spawn look successful.
            // Failures (typically EBADF for unused slots) are irrelevant.
            for fd in 3..FD_SCAN_LIMIT {
                libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC);
            }
            // Group signalling relies on pgid == pid, so a failure here must
            // abort the spawn instead of being ignored.
            if libc::setpgid(0, 0) != 0 {
                return Err(std::io::Error::last_os_error());
            }
            Ok(())
        });
    }

    let mut child = cmd.spawn().map_err(|e| spawn_error(e.to_string()))?;
    let pid = child
        .id()
        .ok_or_else(|| spawn_error("failed to get process id".to_string()))?;
    let stdout = child
        .stdout
        .take()
        .ok_or_else(|| spawn_error("failed to capture stdout".to_string()))?;
    let stderr = child
        .stderr
        .take()
        .ok_or_else(|| spawn_error("failed to capture stderr".to_string()))?;

    tracing::info!(
        service = %config.service.name,
        pid = pid,
        "spawned process"
    );

    Ok(SpawnResult {
        child,
        stdout,
        stderr,
        pid,
    })
}
/// Resolves a textual signal specification into a [`Signal`].
///
/// `name` is first translated to a raw signal number by the SDK helper and
/// then converted into the typed `nix` signal.
///
/// # Errors
///
/// Returns `SupervisorError::SignalError` (with an empty `service`) when the
/// helper does not recognise `name`, or when the resulting number is not a
/// signal known to `nix`.
pub fn parse_signal(name: &str) -> Result<Signal, SupervisorError> {
    let failure = |reason: &str| SupervisorError::SignalError {
        service: String::new(),
        signal: name.to_string(),
        message: reason.to_string(),
    };

    match sig_util::parse(name) {
        Some(number) => {
            Signal::try_from(number).map_err(|_| failure("invalid signal number"))
        }
        None => Err(failure("unknown signal")),
    }
}
pub fn send_signal(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
signal::kill(Pid::from_raw(pid as i32), sig).map_err(|e| SupervisorError::SignalError {
service: String::new(),
signal: format!("{:?}", sig),
message: e.to_string(),
})
}
pub fn send_signal_to_group(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
signal::kill(Pid::from_raw(-(pid as i32)), sig).map_err(|e| SupervisorError::SignalError {
service: String::new(),
signal: format!("{:?}", sig),
message: e.to_string(),
})
}
pub fn process_exists(pid: u32) -> bool {
signal::kill(Pid::from_raw(pid as i32), None).is_ok()
}
/// Reasons a single health probe can fail.
#[derive(Debug, thiserror::Error)]
pub enum HealthError {
/// The probe did not finish within its configured timeout.
#[error("timeout")]
Timeout,
/// An I/O error occurred while talking to the target (e.g. connecting).
#[error("connection error: {0}")]
Connect(#[from] std::io::Error),
/// The probe command could not be run; carries the error text.
#[error("exec error: {0}")]
Exec(String),
/// The probe command exited unsuccessfully. Holds the exit code, if any,
/// and a message that is appended verbatim after the code when displayed.
#[error("non-zero exit: {0:?}{1}")]
NonZeroExit(Option<i32>, String),
/// A status code other than the expected one was observed.
#[error("unexpected status: {expected}, got {actual}")]
UnexpectedStatus { expected: u16, actual: u16 },
}
pub async fn check_health(health: &HealthDef) -> Result<(), HealthError> {
match health {
HealthDef::Tcp { target, common } => {
let timeout = std::time::Duration::from_millis(common.timeout_ms);
let result =
tokio::time::timeout(timeout, tokio::net::TcpStream::connect(target)).await;
match result {
Ok(Ok(_stream)) => Ok(()),
Ok(Err(e)) => Err(HealthError::Connect(e)),
Err(_) => Err(HealthError::Timeout),
}
}
HealthDef::Http {
target,
expect_status,
common,
} => {
let timeout = std::time::Duration::from_millis(common.timeout_ms);
let url = target
.trim_start_matches("http://")
.trim_start_matches("https://");
let (host_port, _path) = url.split_once('/').unwrap_or((url, ""));
let result =
tokio::time::timeout(timeout, tokio::net::TcpStream::connect(host_port)).await;
match result {
Ok(Ok(_stream)) => {
if *expect_status == 200 {
Ok(())
} else {
Ok(())
}
}
Ok(Err(e)) => Err(HealthError::Connect(e)),
Err(_) => Err(HealthError::Timeout),
}
}
HealthDef::Exec { target, common } => {
let timeout = std::time::Duration::from_millis(common.timeout_ms);
let mut cmd = Command::new("/bin/sh");
cmd.arg("-c").arg(target);
cmd.env(
"PATH",
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
);
let result = tokio::time::timeout(timeout, cmd.output()).await;
match result {
Ok(Ok(output)) => {
if output.status.success() {
Ok(())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
let stderr_msg = if stderr.is_empty() {
String::new()
} else {
format!(": {}", stderr.trim())
};
Err(HealthError::NonZeroExit(output.status.code(), stderr_msg))
}
}
Ok(Err(e)) => Err(HealthError::Exec(e.to_string())),
Err(_) => Err(HealthError::Timeout),
}
}
}
}
pub async fn wait_for_exit(mut child: Child) -> (Option<i32>, Option<i32>) {
match child.wait().await {
Ok(status) => {
let exit_code = status.code();
#[cfg(unix)]
let signal = {
use std::os::unix::process::ExitStatusExt;
status.signal()
};
#[cfg(not(unix))]
let signal = None;
(exit_code, signal)
}
Err(_) => (None, None),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk::HealthCommon;

    /// Builds the probe settings shared by the health-check tests; only the
    /// timeout differs between them.
    fn probe_settings(timeout_ms: u64) -> HealthCommon {
        HealthCommon {
            interval_ms: 1000,
            timeout_ms,
            retries: 1,
            start_period_ms: 0,
        }
    }

    /// Prefixed names, bare names and raw numbers all resolve to a signal.
    #[test]
    fn test_parse_signal() {
        let cases = [
            ("SIGTERM", Signal::SIGTERM),
            ("TERM", Signal::SIGTERM),
            ("9", Signal::SIGKILL),
        ];
        for (name, expected) in cases {
            assert_eq!(parse_signal(name).unwrap(), expected);
        }
    }

    /// Unknown names are rejected.
    #[test]
    fn test_parse_signal_invalid() {
        assert!(parse_signal("INVALID").is_err());
    }

    /// A TCP probe against a port nobody listens on fails.
    #[tokio::test]
    async fn test_health_check_tcp_nonexistent() {
        let probe = HealthDef::Tcp {
            target: "127.0.0.1:59999".to_string(),
            common: probe_settings(100),
        };
        assert!(check_health(&probe).await.is_err());
    }

    /// An exec probe whose command exits with status zero passes.
    #[tokio::test]
    async fn test_health_check_exec_success() {
        let probe = HealthDef::Exec {
            target: "true".to_string(),
            common: probe_settings(5000),
        };
        assert!(check_health(&probe).await.is_ok());
    }

    /// An exec probe whose command exits non-zero yields `NonZeroExit`.
    #[tokio::test]
    async fn test_health_check_exec_failure() {
        let probe = HealthDef::Exec {
            target: "false".to_string(),
            common: probe_settings(5000),
        };
        let outcome = check_health(&probe).await;
        assert!(matches!(outcome, Err(HealthError::NonZeroExit(..))));
    }
}