watchctl 0.4.0

Process supervisor with wait, watch, and retry phases
use crate::check::{Check, FileCheck, HttpCheck, TcpCheck, build_http_client};
use crate::config::WatchConfig;
use crate::error::Result;
use crate::process::Process;
use crate::signal::{Termination, TerminationListener};
use std::process::ExitStatus;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
use tokio::task::JoinSet;
use tokio::time::{Instant, interval, sleep};
use tracing::{debug, info, warn};

/// Reason why [`run_watch_phase`] stopped watching the child process.
pub enum WatchResult {
    /// The child exited on its own; carries the exit status returned by waiting on it.
    ProcessExited(ExitStatus),
    /// A health check reported a failure (or a check task panicked); carries the
    /// failure message. The watcher attempts to kill the child before returning this.
    // NOTE(review): the lint is presumably silenced because no caller reads the
    // message payload yet — confirm, or remove the attribute once it is consumed.
    #[allow(dead_code)]
    HealthCheckFailed(String),
    /// The configured watch timeout elapsed; the watcher attempts to kill the child
    /// before returning this.
    Timeout,
    /// A termination request arrived while watching; the watcher attempts to kill the
    /// child and hands the received [`Termination`] back to the caller.
    Terminated(Termination),
}

/// Reacts to a termination request by killing the child and waiting for it to go away.
///
/// Killing is best-effort: a failure is logged as a warning and otherwise ignored, so the
/// function always yields [`WatchResult::Terminated`] carrying the request it was given.
async fn terminate_child(process: &mut Process, term: Termination) -> WatchResult {
    warn!("received {}, terminating child process", term.name);

    match process.kill_and_wait().await {
        Ok(_) => {}
        Err(e) => warn!("failed to kill process: {e}"),
    }

    WatchResult::Terminated(term)
}

pub async fn run_watch_phase(
    config: &WatchConfig,
    mut process: Process,
    term: &mut TerminationListener,
) -> Result<WatchResult> {
    let start = Instant::now();

    let has_health_checks =
        !config.http.is_empty() || !config.tcp.is_empty() || !config.files.is_empty();

    if !has_health_checks && config.timeout.is_none() {
        debug!("no watch conditions, waiting for process to exit");
        select! {
            biased;
            term = term.recv() => return Ok(terminate_child(&mut process, term).await),
            status = process.wait() => return Ok(WatchResult::ProcessExited(status?)),
        }
    }

    let watch_future = async {
        if has_health_checks {
            run_health_checks(config).await
        } else {
            std::future::pending().await
        }
    };

    let timeout_future = async {
        if let Some(t) = config.timeout {
            sleep(t).await;
            Some(())
        } else {
            std::future::pending().await
        }
    };

    select! {
        biased;

        signal = term.recv() => {
            Ok(terminate_child(&mut process, signal).await)
        }

        status = process.wait() => {
            let status = status?;
            info!("process exited with {:?} after {:?}", status.code(), start.elapsed());
            Ok(WatchResult::ProcessExited(status))
        }

        result = watch_future => {
            match result {
                Ok(()) => std::future::pending().await,
                Err(msg) => {
                    warn!("health check failed: {msg}");
                    if let Err(e) = process.kill_and_wait().await {
                        warn!("failed to kill process: {e}");
                    }
                    Ok(WatchResult::HealthCheckFailed(msg))
                }
            }
        }

        _ = timeout_future => {
            warn!("watch timeout reached after {:?}", start.elapsed());
            if let Err(e) = process.kill_and_wait().await {
                warn!("failed to kill process: {e}");
            }
            Ok(WatchResult::Timeout)
        }
    }
}

/// Runs every configured health check concurrently and reports the first failure.
///
/// One periodic task is spawned per HTTP URL, TCP address and file path in `config`,
/// each starting after `config.delay` and repeating at its kind-specific interval.
/// All HTTP checks share a single client built with `config.http_timeout`.
///
/// # Errors
///
/// Returns the failure message of the first check that fails, a fixed message when a
/// check task panics, or the stringified error when the HTTP client cannot be built.
/// Returning drops the `JoinSet`, which aborts the checks that are still running.
///
/// `Ok(())` is returned only when every task ended without reporting a failure (for
/// example because all of them were cancelled).
async fn run_health_checks(config: &WatchConfig) -> std::result::Result<(), String> {
    let mut tasks = JoinSet::new();
    let initial_delay = config.delay;

    // The client is built in the same branch that uses it, so there is no `Option` to
    // unwrap and no client is constructed when there are no HTTP checks.
    if !config.http.is_empty() {
        let client = Arc::new(build_http_client(config.http_timeout).map_err(|e| e.to_string())?);
        let every = config.http_interval;
        for url in &config.http {
            let probe = HttpCheck::new(url.clone(), Arc::clone(&client));
            tasks.spawn(run_periodic_check(Box::new(probe), initial_delay, every));
        }
    }

    let every = config.tcp_interval;
    for addr in &config.tcp {
        let probe = TcpCheck::new(addr.clone(), config.tcp_timeout);
        tasks.spawn(run_periodic_check(Box::new(probe), initial_delay, every));
    }

    let every = config.file_interval;
    for path in &config.files {
        let probe = FileCheck::new(path);
        tasks.spawn(run_periodic_check(Box::new(probe), initial_delay, every));
    }

    while let Some(joined) = tasks.join_next().await {
        match joined {
            Ok(Err(msg)) => return Err(msg),
            Err(join_err) if join_err.is_panic() => {
                return Err("health check task panicked".to_string());
            }
            // A task that finished cleanly or was cancelled is not a failure.
            Ok(Ok(())) | Err(_) => {}
        }
    }

    Ok(())
}

/// Runs `check` repeatedly until it fails.
///
/// The first probe runs after `initial_delay` (immediately when the delay is zero);
/// subsequent probes are paced by a ticker with period `interval_duration`.
///
/// # Errors
///
/// Returns the message of the first failing probe. Also returns an error right away
/// when `interval_duration` is zero. The function never returns `Ok`.
async fn run_periodic_check(
    check: Box<dyn Check>,
    initial_delay: Duration,
    interval_duration: Duration,
) -> std::result::Result<(), String> {
    // `tokio::time::interval` panics on a zero period, which would surface only as an
    // opaque "task panicked" failure; reject it here with a message naming the check.
    if interval_duration.is_zero() {
        return Err(format!(
            "{}: check interval must be greater than zero",
            check.description()
        ));
    }

    if !initial_delay.is_zero() {
        sleep(initial_delay).await;
    }

    let mut ticker = interval(interval_duration);
    // The first tick of a tokio interval completes immediately; consume it so the first
    // probe is not delayed by a full period.
    ticker.tick().await;

    let desc = check.description();
    loop {
        check.check().await?;
        debug!("{desc} healthy");
        ticker.tick().await;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Probe stub that always succeeds and counts how many times it was invoked.
    struct CountingCheck {
        calls: Arc<AtomicUsize>,
    }

    impl Check for CountingCheck {
        fn check(&self) -> crate::check::CheckFuture<'_> {
            let counter = Arc::clone(&self.calls);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn description(&self) -> &str {
            "counting"
        }
    }

    /// Spawns `run_periodic_check` over a [`CountingCheck`] with a 60 s interval and
    /// returns the shared call counter together with the task handle.
    fn spawn_counting_probe(
        initial_delay: Duration,
    ) -> (
        Arc<AtomicUsize>,
        tokio::task::JoinHandle<std::result::Result<(), String>>,
    ) {
        let calls = Arc::new(AtomicUsize::new(0));
        let probe = CountingCheck {
            calls: Arc::clone(&calls),
        };
        let handle = tokio::spawn(run_periodic_check(
            Box::new(probe),
            initial_delay,
            Duration::from_secs(60),
        ));
        (calls, handle)
    }

    /// With a 200 ms watch delay, no probe runs at 50 ms and exactly one has run by 250 ms.
    #[tokio::test]
    async fn watch_delay_defers_first_probe() {
        let (calls, task) = spawn_counting_probe(Duration::from_millis(200));

        sleep(Duration::from_millis(50)).await;
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        sleep(Duration::from_millis(200)).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        task.abort();
    }

    /// Without a watch delay the first probe runs right away, and only once within the
    /// first interval.
    #[tokio::test]
    async fn zero_watch_delay_keeps_immediate_first_probe() {
        let (calls, task) = spawn_counting_probe(Duration::ZERO);

        let first_probe_seen = async {
            while calls.load(Ordering::SeqCst) == 0 {
                tokio::task::yield_now().await;
            }
        };
        tokio::time::timeout(Duration::from_millis(100), first_probe_seen)
            .await
            .expect("first probe should run without an explicit watch delay");

        assert_eq!(calls.load(Ordering::SeqCst), 1);

        task.abort();
    }
}