use std::process::Stdio;
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use tokio::process::{Child, ChildStderr, ChildStdout, Command};
use crate::sdk::{HealthDef, ServiceConfig, signal as sig_util};
use super::error::SupervisorError;
/// Handles returned by [`spawn_process`] for a freshly started service.
pub struct SpawnResult {
    /// Tokio handle for the spawned `/bin/sh -c <exec>` process. Its `stdout`
    /// and `stderr` fields have already been taken and are returned separately.
    pub child: Child,
    /// Read end of the service's piped standard output.
    pub stdout: ChildStdout,
    /// Read end of the service's piped standard error.
    pub stderr: ChildStderr,
    /// OS process id of the shell, as reported by `Child::id` right after spawn.
    /// `spawn_process` asks (via `setpgid(0, 0)`) for this process to lead its
    /// own process group.
    pub pid: u32,
}
/// Starts the service described by `config` as `/bin/sh -c <exec>`.
///
/// The child gets `service.dir` as its working directory (if set) and the
/// extra `service.env` variables. Its stdin is `/dev/null`, and stdout/stderr
/// are piped back to the caller. Before `exec`, every inherited descriptor
/// from 3 up to 1023 is marked close-on-exec. The child is also moved into
/// its own process group so that it can later be signalled as a unit.
///
/// # Errors
/// `SupervisorError::SpawnError` if the process cannot be started (including a
/// failed `setpgid`) or if its pid or output pipes cannot be obtained. In the
/// latter case the already-running child is sent `SIGKILL` first, so that no
/// unsupervised process is left behind.
pub fn spawn_process(config: &ServiceConfig) -> Result<SpawnResult, SupervisorError> {
    let spawn_error = |message: String| SupervisorError::SpawnError {
        service: config.service.name.clone(),
        message,
    };

    let mut cmd = Command::new("/bin/sh");
    cmd.arg("-c").arg(&config.service.exec);
    if let Some(dir) = config.service.dir.as_deref() {
        cmd.current_dir(dir);
    }
    for (key, value) in &config.service.env {
        cmd.env(key, value);
    }
    cmd.stdin(Stdio::null());
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());

    // SAFETY: the closure runs in the forked child before `exec`, where only
    // async-signal-safe operations are permitted. It calls nothing but
    // `fcntl`, `setpgid` and an errno read, and it does not allocate or lock.
    unsafe {
        cmd.pre_exec(|| {
            // Mark stray descriptors close-on-exec instead of closing them
            // here. std keeps a CLOEXEC pipe open across this hook to report
            // exec failures to the parent. Closing that pipe would make a
            // failed exec look like a successful spawn.
            for fd in 3..1024 {
                libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC);
            }
            // Own process group, so group-wide signals reach the whole service.
            if libc::setpgid(0, 0) != 0 {
                return Err(std::io::Error::last_os_error());
            }
            Ok(())
        });
    }

    let mut child = cmd.spawn().map_err(|e| spawn_error(e.to_string()))?;

    let (pid, stdout, stderr) = match (child.id(), child.stdout.take(), child.stderr.take()) {
        (Some(pid), Some(stdout), Some(stderr)) => (pid, stdout, stderr),
        (pid, stdout, _) => {
            let message = if pid.is_none() {
                "failed to get process id"
            } else if stdout.is_none() {
                "failed to capture stdout"
            } else {
                "failed to capture stderr"
            };
            // The `Child` handle is dropped on return and tokio does not kill
            // on drop by default; don't leave the service running untracked.
            let _ = child.start_kill();
            return Err(spawn_error(message.to_string()));
        }
    };

    tracing::info!(
        service = %config.service.name,
        pid = pid,
        "spawned process"
    );
    Ok(SpawnResult {
        child,
        stdout,
        stderr,
        pid,
    })
}
/// Resolves a signal given by name or number (whatever `sig_util::parse`
/// accepts, e.g. `"SIGTERM"`, `"TERM"` or `"9"`) into a nix [`Signal`].
///
/// # Errors
/// `SupervisorError::SignalError` with an empty `service` when the name is not
/// recognised (`"unknown signal"`) or when the resulting number is not a
/// signal nix knows about (`"invalid signal number"`).
pub fn parse_signal(name: &str) -> Result<Signal, SupervisorError> {
    let failure = |reason: &str| SupervisorError::SignalError {
        service: String::new(),
        signal: name.to_string(),
        message: reason.to_string(),
    };
    match sig_util::parse(name) {
        Some(number) => Signal::try_from(number).map_err(|_| failure("invalid signal number")),
        None => Err(failure("unknown signal")),
    }
}
/// Sends `sig` to the single process `pid`.
///
/// `kill(2)` gives non-positive pids special meaning: 0 targets the caller's
/// own process group, and negative values target process groups. Therefore a
/// `pid` of 0, or one too large for `i32` (which would wrap negative), is
/// rejected instead of being passed through.
///
/// # Errors
/// `SupervisorError::SignalError` (with an empty `service`) for an invalid
/// `pid` or when `kill(2)` fails.
pub fn send_signal(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
    let signal_error = |message: String| SupervisorError::SignalError {
        service: String::new(),
        signal: format!("{:?}", sig),
        message,
    };
    let raw_pid = i32::try_from(pid)
        .ok()
        .filter(|&raw| raw > 0)
        .ok_or_else(|| signal_error(format!("invalid pid {pid}")))?;
    signal::kill(Pid::from_raw(raw_pid), sig).map_err(|e| signal_error(e.to_string()))
}
/// Sends `sig` to every member of the process group whose id is `pid`
/// (`kill(-pid, sig)`).
///
/// `pid` must be greater than 1 and fit in an `i32`. `kill(2)` treats a
/// target of 0 as the caller's own process group and -1 as *every* process
/// the caller may signal, so 0, 1 and values that would wrap negative are
/// rejected.
///
/// # Errors
/// `SupervisorError::SignalError` (with an empty `service`) for an invalid
/// `pid` or when `kill(2)` fails.
pub fn send_signal_to_group(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
    let signal_error = |message: String| SupervisorError::SignalError {
        service: String::new(),
        signal: format!("{:?}", sig),
        message,
    };
    let group = i32::try_from(pid)
        .ok()
        .filter(|&raw| raw > 1)
        .ok_or_else(|| signal_error(format!("invalid process group {pid}")))?;
    signal::kill(Pid::from_raw(-group), sig).map_err(|e| signal_error(e.to_string()))
}
pub fn process_exists(pid: u32) -> bool {
signal::kill(Pid::from_raw(pid as i32), None).is_ok()
}
/// A descendant process discovered by [`get_child_processes`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct ChildProcessInfo {
    /// Process id of the descendant.
    pub pid: u32,
    /// Process name from sysinfo's `Process::name()` (lossily converted to UTF-8).
    pub name: String,
    /// Value of sysinfo's `Process::memory()` for this process; presumably
    /// bytes, per the field name — confirm against the sysinfo version in use.
    pub memory_bytes: u64,
}
/// Returns every transitive descendant of `parent_pid` (not including
/// `parent_pid` itself), sorted by pid in descending order.
///
/// The process table is read once via sysinfo and indexed by parent pid. This
/// makes the walk linear in the number of processes, instead of rescanning
/// the whole table for every descendant visited. If `parent_pid` does not
/// exist or has no children, the result is empty.
pub fn get_child_processes(parent_pid: u32) -> Vec<ChildProcessInfo> {
    use std::collections::{HashMap, HashSet};
    use sysinfo::{Pid, ProcessesToUpdate, System};

    let mut system = System::new();
    system.refresh_processes(ProcessesToUpdate::All, true);
    let root = Pid::from_u32(parent_pid);

    // parent pid -> direct children, built in a single pass.
    let mut children_of: HashMap<Pid, Vec<Pid>> = HashMap::new();
    for (pid, process) in system.processes() {
        if let Some(ppid) = process.parent() {
            children_of.entry(ppid).or_default().push(*pid);
        }
    }

    let mut descendants = Vec::new();
    let mut pending = vec![root];
    let mut visited = HashSet::new();
    while let Some(current) = pending.pop() {
        // Guards against revisiting a pid if the snapshot ever contains a cycle.
        if !visited.insert(current) {
            continue;
        }
        for &child in children_of.get(&current).into_iter().flatten() {
            if child == root {
                continue;
            }
            if let Some(process) = system.process(child) {
                pending.push(child);
                descendants.push(ChildProcessInfo {
                    pid: child.as_u32(),
                    name: process.name().to_string_lossy().to_string(),
                    memory_bytes: process.memory(),
                });
            }
        }
    }
    descendants.sort_unstable_by_key(|info| std::cmp::Reverse(info.pid));
    descendants
}
/// Outcome of a forced kill of one or more process trees.
#[derive(Debug)]
pub struct KillTreeResult {
    /// Pids that were targeted and were found to be gone afterwards.
    pub killed: Vec<u32>,
    /// Pids that were targeted but were still present afterwards.
    pub failed: Vec<u32>,
    /// `true` when `failed` is empty.
    pub success: bool,
}
/// Sends `SIGKILL` to `pid` and all of its descendants, then reports which of
/// them are gone.
///
/// Descendants are enumerated once, up front, with [`get_child_processes`].
/// They are signalled before `pid` itself. Processes forked after that
/// snapshot are not targeted.
///
/// The function then blocks the calling thread (`std::thread::sleep`). It
/// polls every 10 ms for up to 100 ms and stops as soon as every target has
/// disappeared.
///
/// A target counts as killed once it no longer exists *or* is a zombie, i.e.
/// it has exited and is only waiting to be reaped by its parent. Without that
/// rule, a direct child of the caller that has not been `wait`ed yet would be
/// reported as having survived `SIGKILL`.
pub fn kill_process_tree(pid: u32) -> KillTreeResult {
    use std::time::{Duration, Instant};

    const SETTLE_TIMEOUT: Duration = Duration::from_millis(100);
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    let children = get_child_processes(pid);
    let mut pids_to_kill: Vec<u32> = children.iter().map(|c| c.pid).collect();
    pids_to_kill.push(pid);
    tracing::info!(
        parent_pid = pid,
        children_count = children.len(),
        pids = ?pids_to_kill,
        "killing process tree"
    );
    for child in &children {
        if let Err(e) = send_signal(child.pid, Signal::SIGKILL) {
            tracing::warn!(pid = child.pid, error = %e, "failed to send SIGKILL to child process");
        } else {
            tracing::debug!(pid = child.pid, name = %child.name, "sent SIGKILL to child process");
        }
    }
    if let Err(e) = send_signal(pid, Signal::SIGKILL) {
        tracing::warn!(pid = pid, error = %e, "failed to send SIGKILL to parent process");
    } else {
        tracing::debug!(pid = pid, "sent SIGKILL to parent process");
    }

    // SIGKILL is delivered asynchronously: give the kernel a moment to tear
    // the processes down, but stop waiting as soon as all of them are gone.
    let deadline = Instant::now() + SETTLE_TIMEOUT;
    while pids_to_kill.iter().any(|&p| process_is_running(p)) && Instant::now() < deadline {
        std::thread::sleep(POLL_INTERVAL);
    }

    let mut killed = Vec::new();
    let mut failed = Vec::new();
    for &target_pid in &pids_to_kill {
        if process_is_running(target_pid) {
            tracing::error!(
                pid = target_pid,
                "CRITICAL: process still alive after SIGKILL"
            );
            failed.push(target_pid);
        } else {
            killed.push(target_pid);
        }
    }
    let success = failed.is_empty();
    if success {
        tracing::info!(
            parent_pid = pid,
            killed_count = killed.len(),
            "process tree killed successfully"
        );
    } else {
        tracing::error!(
            parent_pid = pid,
            killed_count = killed.len(),
            failed_count = failed.len(),
            failed_pids = ?failed,
            "CRITICAL: some processes survived SIGKILL"
        );
    }
    KillTreeResult {
        killed,
        failed,
        success,
    }
}

/// Returns `true` if `pid` exists and has not already exited (i.e. is not a zombie).
fn process_is_running(pid: u32) -> bool {
    process_exists(pid) && !process_is_zombie(pid)
}

/// Returns `true` if `/proc/<pid>/stat` reports the process state as zombie
/// (`Z`) or dead (`X`). Returns `false` if the file cannot be read, e.g. the
/// process is already gone or the platform has no procfs.
fn process_is_zombie(pid: u32) -> bool {
    let Ok(stat) = std::fs::read_to_string(format!("/proc/{pid}/stat")) else {
        return false;
    };
    // The state is the first field after the parenthesised command name,
    // which may itself contain spaces or ')' — hence the last ')'.
    matches!(
        stat.rsplit_once(')')
            .and_then(|(_, rest)| rest.split_whitespace().next()),
        Some("Z" | "X")
    )
}
/// A process holding a TCP socket whose local port is one of the requested ports.
#[derive(Debug, Clone)]
pub struct ProcessOnPort {
    /// Process id of the socket owner.
    pub pid: u32,
    /// Command name from `/proc/<pid>/comm`, or empty if it could not be read.
    pub name: String,
}

/// Finds the processes that own a TCP socket (IPv4 or IPv6) whose local port
/// is in `ports`.
///
/// The lookup uses Linux procfs. Socket inodes are taken from
/// `/proc/net/tcp` and `/proc/net/tcp6`. Each inode is then attributed to a
/// process by scanning the `socket:[<inode>]` links under `/proc/<pid>/fd`.
/// Processes whose descriptors cannot be inspected (insufficient privileges,
/// or exited meanwhile) are skipped.
///
/// Each pid appears at most once, and results are ordered by ascending pid.
/// If `ports` is empty, or procfs is unavailable, the result is empty. No
/// name-based guessing is done: callers kill what this returns, so returning
/// unrelated processes would be destructive.
pub fn find_processes_on_ports(ports: &[u16]) -> Vec<ProcessOnPort> {
    use std::collections::{BTreeMap, HashSet};

    if ports.is_empty() {
        return Vec::new();
    }
    let wanted: HashSet<u16> = ports.iter().copied().collect();

    // A dual-stack listener bound to `[::]` is listed only in the tcp6 table.
    let mut inodes: HashSet<u64> = HashSet::new();
    for table in ["/proc/net/tcp", "/proc/net/tcp6"] {
        if let Ok(content) = std::fs::read_to_string(table) {
            inodes.extend(socket_inodes_on_ports(&content, &wanted));
        }
    }
    if inodes.is_empty() {
        return Vec::new();
    }

    let Ok(proc_dir) = std::fs::read_dir("/proc") else {
        return Vec::new();
    };
    let mut owners: BTreeMap<u32, String> = BTreeMap::new();
    for entry in proc_dir.flatten() {
        let Some(pid) = entry
            .file_name()
            .to_str()
            .and_then(|name| name.parse::<u32>().ok())
        else {
            continue;
        };
        if process_holds_socket(pid, &inodes) {
            let name = std::fs::read_to_string(format!("/proc/{pid}/comm"))
                .map(|comm| comm.trim_end().to_string())
                .unwrap_or_default();
            owners.insert(pid, name);
        }
    }
    owners
        .into_iter()
        .map(|(pid, name)| ProcessOnPort { pid, name })
        .collect()
}

/// Parses the text of a `/proc/net/tcp` or `/proc/net/tcp6` table. Returns
/// the inodes of sockets whose local port is in `ports`.
///
/// The first line is a header. In each data row, column 1 is the local
/// address as `HEXADDR:HEXPORT` and column 9 is the socket inode.
fn socket_inodes_on_ports(table: &str, ports: &std::collections::HashSet<u16>) -> Vec<u64> {
    table
        .lines()
        .skip(1)
        .filter_map(|row| {
            let mut fields = row.split_whitespace();
            let local = fields.nth(1)?;
            // Having consumed columns 0 and 1, skipping 7 more lands on column 9.
            let inode = fields.nth(7)?.parse::<u64>().ok()?;
            let port = u16::from_str_radix(local.rsplit_once(':')?.1, 16).ok()?;
            // Inode 0 marks sockets no process owns any more (e.g. TIME_WAIT).
            (inode != 0 && ports.contains(&port)).then_some(inode)
        })
        .collect()
}

/// Returns `true` if one of `pid`'s open descriptors is a socket whose inode
/// is in `inodes`. An unreadable `/proc/<pid>/fd` yields `false`.
fn process_holds_socket(pid: u32, inodes: &std::collections::HashSet<u64>) -> bool {
    let Ok(fds) = std::fs::read_dir(format!("/proc/{pid}/fd")) else {
        return false;
    };
    fds.flatten().any(|fd| {
        std::fs::read_link(fd.path())
            .ok()
            .and_then(|target| {
                target
                    .to_str()?
                    .strip_prefix("socket:[")?
                    .strip_suffix(']')?
                    .parse::<u64>()
                    .ok()
            })
            .is_some_and(|inode| inodes.contains(&inode))
    })
}
/// Force-kills every process that [`find_processes_on_ports`] reports for
/// `ports`, together with its descendants (via [`kill_process_tree`]). The
/// per-tree results are aggregated.
///
/// The supervisor's own pid is never targeted. Each pid is handled at most
/// once: duplicates in the lookup result are dropped, and a pid that was
/// already killed as part of an earlier target's tree is skipped.
///
/// Returns `success: true` with empty lists when no other process is found.
pub fn kill_processes_on_ports(ports: &[u16]) -> KillTreeResult {
    use std::collections::HashSet;

    let own_pid = std::process::id();
    let mut processes = find_processes_on_ports(ports);
    if processes.iter().any(|p| p.pid == own_pid) {
        tracing::warn!(
            pid = own_pid,
            ports = ?ports,
            "supervisor itself holds one of the ports; not killing it"
        );
    }
    let mut seen = HashSet::new();
    processes.retain(|p| p.pid != own_pid && seen.insert(p.pid));

    if processes.is_empty() {
        tracing::debug!(ports = ?ports, "no processes found on specified ports");
        return KillTreeResult {
            killed: Vec::new(),
            failed: Vec::new(),
            success: true,
        };
    }
    tracing::info!(
        ports = ?ports,
        process_count = processes.len(),
        processes = ?processes,
        "killing processes on ports"
    );
    let mut all_killed: Vec<u32> = Vec::new();
    let mut all_failed: Vec<u32> = Vec::new();
    for process in processes {
        if all_killed.contains(&process.pid) {
            // Already taken down as a descendant of an earlier target.
            continue;
        }
        let result = kill_process_tree(process.pid);
        all_killed.extend(result.killed);
        all_failed.extend(result.failed);
    }
    let success = all_failed.is_empty();
    KillTreeResult {
        killed: all_killed,
        failed: all_failed,
        success,
    }
}
#[derive(Debug, thiserror::Error)]
pub enum HealthError {
#[error("timeout")]
Timeout,
#[error("connection error: {0}")]
Connect(#[from] std::io::Error),
#[error("exec error: {0}")]
Exec(String),
#[error("non-zero exit: {0:?}{1}")]
NonZeroExit(Option<i32>, String),
#[error("unexpected status: {expected}, got {actual}")]
UnexpectedStatus { expected: u16, actual: u16 },
}
/// Runs one health probe and reports whether the target is healthy.
///
/// Every probe is bounded by `common.timeout_ms`. Exceeding it yields
/// [`HealthError::Timeout`].
///
/// * `Tcp`: healthy if a TCP connection to `target` can be established.
/// * `Http`: `target` has the form `http://host[:port][/path]`. The port
///   defaults to 80 and the path to `/`. An HTTP/1.1 `GET` is sent, and the
///   probe is healthy only if the response status equals `expect_status`.
///   Any other status yields [`HealthError::UnexpectedStatus`].
///   - `https://` targets cannot be spoken without a TLS implementation.
///   - For them only a TCP connect to the host is attempted (port defaults to
///     443), and `expect_status` is not checked.
/// * `Exec`: runs `target` through `/bin/sh -c` with a fixed `PATH`.
///   - Healthy on exit status 0.
///   - Otherwise the result is [`HealthError::NonZeroExit`], carrying the
///     exit code and the trimmed stderr.
///   - If the timeout fires, the spawned shell is killed rather than left
///     running.
///
/// # Errors
/// - [`HealthError::Connect`] for I/O failures, including a malformed HTTP
///   response.
/// - [`HealthError::Exec`] when the shell cannot be spawned.
/// - The per-kind errors described above.
pub async fn check_health(health: &HealthDef) -> Result<(), HealthError> {
    match health {
        HealthDef::Tcp { target, common } => {
            let timeout = std::time::Duration::from_millis(common.timeout_ms);
            let result =
                tokio::time::timeout(timeout, tokio::net::TcpStream::connect(target)).await;
            match result {
                Ok(Ok(_stream)) => Ok(()),
                Ok(Err(e)) => Err(HealthError::Connect(e)),
                Err(_) => Err(HealthError::Timeout),
            }
        }
        HealthDef::Http {
            target,
            expect_status,
            common,
        } => {
            let timeout = std::time::Duration::from_millis(common.timeout_ms);
            // One budget covers connect, request and response.
            match tokio::time::timeout(timeout, probe_http(target, *expect_status)).await {
                Ok(outcome) => outcome,
                Err(_) => Err(HealthError::Timeout),
            }
        }
        HealthDef::Exec { target, common } => {
            let timeout = std::time::Duration::from_millis(common.timeout_ms);
            let mut cmd = Command::new("/bin/sh");
            cmd.arg("-c").arg(target);
            cmd.env(
                "PATH",
                "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
            );
            // The `output()` future is dropped when the timeout fires. Without
            // this, the shell would keep running after the probe gave up.
            cmd.kill_on_drop(true);
            let result = tokio::time::timeout(timeout, cmd.output()).await;
            match result {
                Ok(Ok(output)) => {
                    if output.status.success() {
                        Ok(())
                    } else {
                        let stderr = String::from_utf8_lossy(&output.stderr);
                        let stderr_msg = if stderr.is_empty() {
                            String::new()
                        } else {
                            format!(": {}", stderr.trim())
                        };
                        Err(HealthError::NonZeroExit(output.status.code(), stderr_msg))
                    }
                }
                Ok(Err(e)) => Err(HealthError::Exec(e.to_string())),
                Err(_) => Err(HealthError::Timeout),
            }
        }
    }
}

/// Performs the `Http` probe described on [`check_health`]. The caller
/// applies the timeout around the whole exchange.
async fn probe_http(target: &str, expect_status: u16) -> Result<(), HealthError> {
    use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};

    // Caps the bytes read while looking for the status line, so a misbehaving
    // endpoint cannot make the probe buffer without limit.
    const MAX_STATUS_LINE: u64 = 8 * 1024;

    let (rest, is_https) = match target.strip_prefix("https://") {
        Some(rest) => (rest, true),
        None => (target.strip_prefix("http://").unwrap_or(target), false),
    };
    let (authority, path) = match rest.find('/') {
        Some(slash) => rest.split_at(slash),
        None => (rest, "/"),
    };
    let address = with_default_port(authority, if is_https { 443 } else { 80 });
    let mut stream = tokio::net::TcpStream::connect(address.as_str()).await?;
    if is_https {
        return Ok(());
    }

    let request = format!("GET {path} HTTP/1.1\r\nHost: {authority}\r\nConnection: close\r\n\r\n");
    stream.write_all(request.as_bytes()).await?;

    let mut status_line = String::new();
    BufReader::new(stream.take(MAX_STATUS_LINE))
        .read_line(&mut status_line)
        .await?;
    let actual = parse_http_status(&status_line).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("malformed HTTP status line: {:?}", status_line.trim_end()),
        )
    })?;
    if actual == expect_status {
        Ok(())
    } else {
        Err(HealthError::UnexpectedStatus {
            expected: expect_status,
            actual,
        })
    }
}

/// Appends `:default_port` to `authority` unless it already names a port.
/// Bracketed IPv6 literals (`[::1]`, `[::1]:8080`) are handled.
fn with_default_port(authority: &str, default_port: u16) -> String {
    let has_port = match authority.rfind(']') {
        Some(close) => authority[close + 1..].starts_with(':'),
        None => authority.contains(':'),
    };
    if has_port {
        authority.to_string()
    } else {
        format!("{authority}:{default_port}")
    }
}

/// Extracts the numeric code from an HTTP status line such as
/// `HTTP/1.1 200 OK\r\n`. Returns `None` if the line is malformed.
fn parse_http_status(line: &str) -> Option<u16> {
    let mut parts = line.split_whitespace();
    if !parts.next()?.starts_with("HTTP/") {
        return None;
    }
    let code = parts.next()?;
    if code.len() != 3 {
        return None;
    }
    code.parse().ok()
}
/// Waits for `child` to terminate and reports how it ended.
///
/// Returns `(exit_code, signal)`:
/// - `exit_code` is the process exit code, if it exited normally.
/// - `signal` is the number of the signal that terminated it (Unix only).
///
/// Both are `None` if waiting on the child fails.
pub async fn wait_for_exit(mut child: Child) -> (Option<i32>, Option<i32>) {
    let Ok(status) = child.wait().await else {
        return (None, None);
    };

    #[cfg(unix)]
    let terminating_signal = std::os::unix::process::ExitStatusExt::signal(&status);
    #[cfg(not(unix))]
    let terminating_signal = None;

    (status.code(), terminating_signal)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk::HealthCommon;

    /// Probe timing shared by the health-check tests; only the timeout varies.
    fn common_with_timeout(timeout_ms: u64) -> HealthCommon {
        HealthCommon {
            interval_ms: 1000,
            timeout_ms,
            retries: 1,
            start_period_ms: 0,
        }
    }

    #[test]
    fn test_parse_signal() {
        let cases = [
            ("SIGTERM", Signal::SIGTERM),
            ("TERM", Signal::SIGTERM),
            ("9", Signal::SIGKILL),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_signal(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_parse_signal_invalid() {
        assert!(parse_signal("INVALID").is_err());
    }

    #[tokio::test]
    async fn test_health_check_tcp_nonexistent() {
        let probe = HealthDef::Tcp {
            target: "127.0.0.1:59999".to_string(),
            common: common_with_timeout(100),
        };
        assert!(check_health(&probe).await.is_err());
    }

    #[tokio::test]
    async fn test_health_check_exec_success() {
        let probe = HealthDef::Exec {
            target: "true".to_string(),
            common: common_with_timeout(5000),
        };
        assert!(check_health(&probe).await.is_ok());
    }

    #[tokio::test]
    async fn test_health_check_exec_failure() {
        let probe = HealthDef::Exec {
            target: "false".to_string(),
            common: common_with_timeout(5000),
        };
        let outcome = check_health(&probe).await;
        assert!(matches!(outcome, Err(HealthError::NonZeroExit(..))));
    }
}