zinit 0.3.6

Process supervisor with dependency management
Documentation
//! Process spawning, signaling, and health checking.

use std::process::Stdio;

use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use tokio::process::{Child, ChildStderr, ChildStdout, Command};

use crate::sdk::{HealthDef, ServiceConfig, signal as sig_util};

use super::error::SupervisorError;

/// Result of spawning a process.
///
/// Produced by `spawn_process`. The stdout/stderr pipes have already been
/// taken out of `child`, so they are only reachable through the `stdout` and
/// `stderr` fields below.
pub struct SpawnResult {
    /// Handle to the spawned `/bin/sh -c <exec>` process.
    pub child: Child,
    /// Read end of the child's piped standard output.
    pub stdout: ChildStdout,
    /// Read end of the child's piped standard error.
    pub stderr: ChildStderr,
    /// OS process ID of the child, read right after spawning. `spawn_process`
    /// asks the child to lead a new process group (`setpgid(0, 0)`), so this is
    /// normally also the child's process-group ID.
    pub pid: u32,
}

/// Spawn a service process.
///
/// The service's `exec` string is run through `/bin/sh -c`, optionally in the
/// configured working directory and with the configured environment variables
/// added to the inherited environment. stdin is connected to `/dev/null`;
/// stdout and stderr are piped and handed back in the [`SpawnResult`] so the
/// caller can consume them.
///
/// In the child, before `exec`:
/// * every inherited descriptor in `3..1024` is marked close-on-exec, so
///   zinit's internal sockets do not leak into the service;
/// * the child becomes leader of a new process group (`setpgid(0, 0)`), so the
///   whole service tree can later be signalled through its PID.
///
/// # Errors
///
/// Returns [`SupervisorError::SpawnError`] if the process cannot be spawned
/// (including a failure of `exec` or `setpgid` in the child), or if its PID or
/// stdio pipes cannot be obtained.
pub fn spawn_process(config: &ServiceConfig) -> Result<SpawnResult, SupervisorError> {
    // Exclusive upper bound of the descriptor range scrubbed in the child.
    const MAX_INHERITED_FD: libc::c_int = 1024;

    let exec = &config.service.exec;
    let working_dir = config.service.dir.as_deref();
    let spawn_error = |message: String| SupervisorError::SpawnError {
        service: config.service.name.clone(),
        message,
    };

    // Use sh -c so `exec` may contain shell syntax.
    let mut cmd = Command::new("/bin/sh");
    cmd.arg("-c").arg(exec);

    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    for (key, value) in &config.service.env {
        cmd.env(key, value);
    }

    cmd.stdin(Stdio::null());
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());

    // SAFETY: the closure runs in the forked child between fork() and exec(),
    // so it may only use async-signal-safe operations. It captures nothing,
    // does not allocate, and only calls fcntl() and setpgid(), both of which
    // are async-signal-safe per POSIX.
    unsafe {
        cmd.pre_exec(|| {
            // Mark inherited descriptors close-on-exec instead of closing them.
            // The standard library keeps a CLOEXEC pipe open in the child to
            // report exec()/pre_exec failures to the parent; close()-ing it
            // here would make such failures look like a successful spawn.
            // fcntl() on an unused fd simply fails with EBADF, which is fine.
            for fd in 3..MAX_INHERITED_FD {
                libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC);
            }

            // Lead a new process group so group signals reach the whole service
            // tree. If this fails, group signalling would target the wrong
            // group, so surface the error instead of ignoring it.
            if libc::setpgid(0, 0) == -1 {
                return Err(std::io::Error::last_os_error());
            }
            Ok(())
        });
    }

    let mut child = cmd.spawn().map_err(|e| spawn_error(e.to_string()))?;

    let pid = child
        .id()
        .ok_or_else(|| spawn_error("failed to get process id".to_string()))?;

    let stdout = child
        .stdout
        .take()
        .ok_or_else(|| spawn_error("failed to capture stdout".to_string()))?;

    let stderr = child
        .stderr
        .take()
        .ok_or_else(|| spawn_error("failed to capture stderr".to_string()))?;

    tracing::info!(
        service = %config.service.name,
        pid = pid,
        "spawned process"
    );

    Ok(SpawnResult {
        child,
        stdout,
        stderr,
        pid,
    })
}

/// Parse a signal name to a nix Signal.
///
/// Accepts whatever `sdk::signal::parse` accepts (for example `"SIGTERM"`,
/// `"TERM"` or `"9"`) and converts the resulting number into a [`Signal`].
///
/// # Errors
///
/// Returns [`SupervisorError::SignalError`] (with an empty `service`) whose
/// message is `"unknown signal"` when the name is not recognised, or
/// `"invalid signal number"` when the number is not a valid [`Signal`].
pub fn parse_signal(name: &str) -> Result<Signal, SupervisorError> {
    let reject = |reason: &str| SupervisorError::SignalError {
        service: String::new(),
        signal: name.to_string(),
        message: reason.to_string(),
    };

    match sig_util::parse(name) {
        None => Err(reject("unknown signal")),
        Some(number) => Signal::try_from(number).map_err(|_| reject("invalid signal number")),
    }
}

/// Send a signal to a process.
pub fn send_signal(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
    signal::kill(Pid::from_raw(pid as i32), sig).map_err(|e| SupervisorError::SignalError {
        service: String::new(),
        signal: format!("{:?}", sig),
        message: e.to_string(),
    })
}

/// Send a signal to a process group.
pub fn send_signal_to_group(pid: u32, sig: Signal) -> Result<(), SupervisorError> {
    // Negative PID sends to process group
    signal::kill(Pid::from_raw(-(pid as i32)), sig).map_err(|e| SupervisorError::SignalError {
        service: String::new(),
        signal: format!("{:?}", sig),
        message: e.to_string(),
    })
}

/// Check if a process exists (without sending a signal).
pub fn process_exists(pid: u32) -> bool {
    // kill with signal 0 checks if process exists without sending a signal
    signal::kill(Pid::from_raw(pid as i32), None).is_ok()
}

/// Health check error.
///
/// Reason a health check (`check_health`) reported the service unhealthy.
#[derive(Debug, thiserror::Error)]
pub enum HealthError {
    /// The check did not complete within its configured timeout.
    #[error("timeout")]
    Timeout,
    /// An I/O error occurred while probing the target (e.g. a refused TCP
    /// connection). Converted automatically from [`std::io::Error`] via `?`.
    #[error("connection error: {0}")]
    Connect(#[from] std::io::Error),
    /// The health-check command could not be run at all; holds the underlying
    /// error message.
    #[error("exec error: {0}")]
    Exec(String),
    /// The health-check command ran but did not exit successfully.
    ///
    /// Field 0 is the value of `ExitStatus::code` (`None` when no exit code is
    /// available, e.g. the process was killed by a signal). Field 1 is either
    /// empty or a `": <stderr>"` suffix that the `Display` text appends as-is.
    #[error("non-zero exit: {0:?}{1}")]
    NonZeroExit(Option<i32>, String),
    /// An HTTP check received status `actual` instead of the configured
    /// `expected` status.
    #[error("unexpected status: {expected}, got {actual}")]
    UnexpectedStatus { expected: u16, actual: u16 },
}

/// Run a health check.
pub async fn check_health(health: &HealthDef) -> Result<(), HealthError> {
    match health {
        HealthDef::Tcp { target, common } => {
            let timeout = std::time::Duration::from_millis(common.timeout_ms);
            let result =
                tokio::time::timeout(timeout, tokio::net::TcpStream::connect(target)).await;

            match result {
                Ok(Ok(_stream)) => Ok(()),
                Ok(Err(e)) => Err(HealthError::Connect(e)),
                Err(_) => Err(HealthError::Timeout),
            }
        }
        HealthDef::Http {
            target,
            expect_status,
            common,
        } => {
            // Simple HTTP check using TCP + manual HTTP request
            // For a real implementation, you'd want to use a proper HTTP client
            let timeout = std::time::Duration::from_millis(common.timeout_ms);

            // Parse URL to get host:port
            let url = target
                .trim_start_matches("http://")
                .trim_start_matches("https://");
            let (host_port, _path) = url.split_once('/').unwrap_or((url, ""));

            let result =
                tokio::time::timeout(timeout, tokio::net::TcpStream::connect(host_port)).await;

            match result {
                Ok(Ok(_stream)) => {
                    // For simplicity, we just check if we can connect
                    // A real implementation would send HTTP request and check status
                    if *expect_status == 200 {
                        Ok(())
                    } else {
                        // Would need actual HTTP check here
                        Ok(())
                    }
                }
                Ok(Err(e)) => Err(HealthError::Connect(e)),
                Err(_) => Err(HealthError::Timeout),
            }
        }
        HealthDef::Exec { target, common } => {
            let timeout = std::time::Duration::from_millis(common.timeout_ms);

            let mut cmd = Command::new("/bin/sh");
            cmd.arg("-c").arg(target);
            // Set PATH so health checks can find system binaries (ip, systemctl, etc.)
            cmd.env(
                "PATH",
                "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
            );

            let result = tokio::time::timeout(timeout, cmd.output()).await;

            match result {
                Ok(Ok(output)) => {
                    if output.status.success() {
                        Ok(())
                    } else {
                        // Include stderr in error for debugging
                        let stderr = String::from_utf8_lossy(&output.stderr);
                        let stderr_msg = if stderr.is_empty() {
                            String::new()
                        } else {
                            format!(": {}", stderr.trim())
                        };
                        Err(HealthError::NonZeroExit(output.status.code(), stderr_msg))
                    }
                }
                Ok(Err(e)) => Err(HealthError::Exec(e.to_string())),
                Err(_) => Err(HealthError::Timeout),
            }
        }
    }
}

/// Wait for a process to exit and return exit information.
///
/// Consumes `child` and returns `(exit_code, signal)`:
/// * `exit_code`: the process's exit code, if one is available;
/// * `signal`: on Unix, the number of the signal that terminated the process,
///   if any. It is always `None` on other platforms.
///
/// If waiting on the child fails, `(None, None)` is returned.
pub async fn wait_for_exit(mut child: Child) -> (Option<i32>, Option<i32>) {
    let status = match child.wait().await {
        Ok(status) => status,
        Err(_) => return (None, None),
    };

    #[cfg(unix)]
    let terminating_signal = {
        use std::os::unix::process::ExitStatusExt;
        status.signal()
    };
    #[cfg(not(unix))]
    let terminating_signal = None;

    (status.code(), terminating_signal)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk::HealthCommon;

    /// Health-check timing shared by the tests below; only the timeout varies.
    fn common_with_timeout(timeout_ms: u64) -> HealthCommon {
        HealthCommon {
            interval_ms: 1000,
            timeout_ms,
            retries: 1,
            start_period_ms: 0,
        }
    }

    #[test]
    fn test_parse_signal() {
        let cases = [
            ("SIGTERM", Signal::SIGTERM),
            ("TERM", Signal::SIGTERM),
            ("9", Signal::SIGKILL),
        ];
        for (name, expected) in cases {
            assert_eq!(parse_signal(name).unwrap(), expected, "parsing {:?}", name);
        }
    }

    #[test]
    fn test_parse_signal_invalid() {
        assert!(parse_signal("INVALID").is_err());
    }

    #[tokio::test]
    async fn test_health_check_tcp_nonexistent() {
        let probe = HealthDef::Tcp {
            target: String::from("127.0.0.1:59999"),
            common: common_with_timeout(100),
        };

        assert!(check_health(&probe).await.is_err());
    }

    #[tokio::test]
    async fn test_health_check_exec_success() {
        let probe = HealthDef::Exec {
            target: String::from("true"),
            common: common_with_timeout(5000),
        };

        assert!(check_health(&probe).await.is_ok());
    }

    #[tokio::test]
    async fn test_health_check_exec_failure() {
        let probe = HealthDef::Exec {
            target: String::from("false"),
            common: common_with_timeout(5000),
        };

        let outcome = check_health(&probe).await;
        assert!(matches!(outcome, Err(HealthError::NonZeroExit(..))));
    }
}