use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::time::Duration;
use tokio::time::sleep;
use crate::daemon::{SharedState, build_shell_command, with_process, with_process_mut};
use crate::model::{HealthProbe, HttpCheck, NameHandle, ProcessStatus};
/// Selects which health flag a probe task maintains for its process.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ProbeKind {
    /// The probe drives the process's `ready` flag.
    Readiness,
    /// The probe drives the process's `alive` flag; a sustained failure also
    /// gets the running process killed.
    Liveness,
}
/// Starts a detached background task running `probe`, if one is configured.
///
/// Does nothing when `probe` is `None`. Otherwise every borrowed argument is
/// cloned into an owned value so the spawned task can outlive the caller's
/// borrows. The task's join handle is dropped, so the probe runs until
/// `run_probe` itself decides to stop.
pub(crate) fn spawn_probe_if_present(
    probe: Option<&HealthProbe>,
    kind: ProbeKind,
    name_handle: &NameHandle,
    working_dir: &Path,
    environment: &BTreeMap<String, String>,
    state: &SharedState,
) {
    let owned_probe = match probe {
        Some(configured) => configured.clone(),
        None => return,
    };

    let task = run_probe(
        kind,
        name_handle.clone(),
        owned_probe,
        state.clone(),
        working_dir.to_path_buf(),
        environment.clone(),
    );
    tokio::spawn(task);
}
/// Repeatedly runs `probe` for the process behind `name_handle` and records
/// the outcome in the shared daemon state.
///
/// The task first waits `probe.initial_delay_seconds` (when non-zero), then
/// loops with `probe.period_seconds` between checks. It stops as soon as the
/// process status is terminal or `with_process` yields no value for the
/// handle.
///
/// * After `probe.success_threshold` consecutive passes, the flag selected by
///   `kind` (`ready` or `alive`) is set to `true`.
/// * After `probe.failure_threshold` consecutive failures, that flag is set to
///   `false`. For [`ProbeKind::Liveness`] a process in
///   `ProcessStatus::Running` is additionally sent `SIGKILL` (unix only), and
///   the failure streak starts over.
async fn run_probe(
    kind: ProbeKind,
    name_handle: NameHandle,
    probe: HealthProbe,
    state: SharedState,
    working_dir: PathBuf,
    environment: BTreeMap<String, String>,
) {
    if probe.initial_delay_seconds > 0 {
        sleep(Duration::from_secs(probe.initial_delay_seconds)).await;
    }

    let mut consecutive_successes: u32 = 0;
    let mut consecutive_failures: u32 = 0;

    loop {
        // A missing lookup result is treated the same as a terminal process:
        // there is nothing left to probe.
        let keep_going = with_process(&state, &name_handle, |r| !r.status.is_terminal())
            .await
            .unwrap_or(false);
        if !keep_going {
            break;
        }

        if run_single_check(&probe, &working_dir, &environment).await {
            // Saturate instead of overflowing on a very long-lived healthy process.
            consecutive_successes = consecutive_successes.saturating_add(1);
            consecutive_failures = 0;
            if consecutive_successes >= probe.success_threshold {
                with_process_mut(&state, &name_handle, |runtime| match kind {
                    ProbeKind::Readiness => runtime.ready = true,
                    ProbeKind::Liveness => runtime.alive = true,
                })
                .await;
            }
        } else {
            consecutive_failures = consecutive_failures.saturating_add(1);
            consecutive_successes = 0;
            if consecutive_failures >= probe.failure_threshold {
                match kind {
                    ProbeKind::Readiness => {
                        with_process_mut(&state, &name_handle, |runtime| {
                            runtime.ready = false;
                        })
                        .await;
                    }
                    ProbeKind::Liveness => {
                        with_process_mut(&state, &name_handle, |runtime| {
                            runtime.alive = false;
                            if let ProcessStatus::Running { pid } = runtime.status {
                                #[cfg(unix)]
                                {
                                    use nix::sys::signal::{self, Signal};
                                    use nix::unistd::Pid;
                                    use std::convert::TryFrom;

                                    // Only ever signal a strictly positive pid:
                                    // `kill` treats 0 and negative values as
                                    // process-group targets, which could take
                                    // down the daemon itself.
                                    let target = i32::try_from(pid).ok().filter(|raw| *raw > 0);
                                    if let Some(raw) = target {
                                        // Best effort: the process may already be gone.
                                        let _ = signal::kill(Pid::from_raw(raw), Signal::SIGKILL);
                                    }
                                }
                                #[cfg(not(unix))]
                                {
                                    let _ = pid;
                                }
                            }
                        })
                        .await;
                        // Start a fresh failure streak after the kill attempt.
                        consecutive_failures = 0;
                    }
                }
            }
        }

        sleep(Duration::from_secs(probe.period_seconds)).await;
    }
}
/// Performs one probe attempt and reports whether it passed.
///
/// An `exec` probe takes precedence over an `http_get` probe; a probe with
/// neither configured always fails. Each attempt is bounded by
/// `probe.timeout_seconds`, and a timeout counts as a failure.
///
/// * `exec`: the command is built with `build_shell_command`, run in
///   `working_dir` with `environment` added to its environment, and passes
///   when it exits successfully. A command that cannot be built or spawned
///   fails the check.
/// * `http_get`: delegated to `http_get_check`.
async fn run_single_check(
    probe: &HealthProbe,
    working_dir: &Path,
    environment: &BTreeMap<String, String>,
) -> bool {
    let timeout = Duration::from_secs(probe.timeout_seconds);

    if let Some(ref exec) = probe.exec {
        let mut cmd = match build_shell_command(&exec.command) {
            Ok(c) => c,
            Err(_) => return false,
        };
        cmd.current_dir(working_dir)
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .envs(environment)
            // Without this, a timed-out probe command would keep running
            // after its future is dropped, and hung probes would accumulate.
            .kill_on_drop(true);

        // `status()` is used rather than `output()`: only the exit status is
        // needed, and `output()` would re-pipe stdout/stderr and buffer
        // everything the command prints instead of discarding it.
        return match tokio::time::timeout(timeout, cmd.status()).await {
            Ok(Ok(status)) => status.success(),
            _ => false,
        };
    }

    if let Some(ref http) = probe.http_get {
        return tokio::time::timeout(timeout, http_get_check(http))
            .await
            .unwrap_or(false);
    }

    false
}
/// Issues a plain-text `GET` to `http.host:http.port` for `http.path` and
/// reports whether the response status code is in `200..400`.
///
/// Any connect, write, or read error, an empty response, or an unparsable
/// status code yields `false`. The function applies no timeout of its own.
async fn http_get_check(http: &HttpCheck) -> bool {
    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
    use tokio::net::TcpStream;

    let addr = format!("{}:{}", http.host, http.port);
    let mut stream = match TcpStream::connect(&addr).await {
        Ok(s) => s,
        Err(_) => return false,
    };

    let request = format!(
        "GET {} HTTP/1.1\r\nHost: {}:{}\r\nConnection: close\r\n\r\n",
        http.path, http.host, http.port
    );
    if stream.write_all(request.as_bytes()).await.is_err() {
        return false;
    }

    // A single `read` may return only part of the status line, so keep
    // reading until the first line is complete, the peer closes the
    // connection, or the buffer is full.
    let mut buf = vec![0u8; 1024];
    let mut filled = 0;
    while filled < buf.len() && !status_line_complete(&buf[..filled]) {
        match stream.read(&mut buf[filled..]).await {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(_) => return false,
        }
    }
    if filled == 0 {
        return false;
    }

    // The status code is the second whitespace-separated token of the response.
    let response = String::from_utf8_lossy(&buf[..filled]);
    let status: u16 = response
        .split_whitespace()
        .nth(1)
        .and_then(|s| s.parse().ok())
        .unwrap_or(0);
    (200..400).contains(&status)
}

/// Returns `true` once `data`, ignoring leading ASCII whitespace, contains a
/// newline-terminated first line.
fn status_line_complete(data: &[u8]) -> bool {
    let start = data
        .iter()
        .position(|byte| !byte.is_ascii_whitespace())
        .unwrap_or(data.len());
    data[start..].contains(&b'\n')
}