use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
use std::time::Duration;
use anyhow::{Context, Result};
use tokio::process::{Child, Command};
use tokio::sync::RwLock;
use super::{EmbedderClient, StdioEmbedderClient};
/// Restart and startup policy for [`EmbedderSupervisor`].
///
/// Every field can be overridden through an environment variable; see
/// [`SupervisorConfig::from_env`].
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
    /// Consecutive sidecar failures tolerated before supervision gives up
    /// (`TRUSTY_EMBEDDERD_MAX_RESTARTS`, default 5).
    pub max_restarts: u32,
    /// Upper bound, in seconds, on the exponential restart back-off
    /// (`TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS`, default 60).
    pub backoff_max_secs: u64,
    /// Seconds the sidecar has to answer its startup probe
    /// (`TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS`, default 5).
    pub startup_timeout_secs: u64,
}
impl Default for SupervisorConfig {
    fn default() -> Self {
        Self {
            max_restarts: 5,
            backoff_max_secs: 60,
            startup_timeout_secs: 5,
        }
    }
}
impl SupervisorConfig {
    /// Builds a config from the environment.
    ///
    /// Reads `TRUSTY_EMBEDDERD_MAX_RESTARTS`, `TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS`
    /// and `TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS`. Any variable that is unset or
    /// unparseable falls back to the corresponding [`Default`] value.
    pub fn from_env() -> Self {
        let def = Self::default();
        Self {
            max_restarts: parse_env("TRUSTY_EMBEDDERD_MAX_RESTARTS", def.max_restarts),
            backoff_max_secs: parse_env(
                "TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS",
                def.backoff_max_secs,
            ),
            startup_timeout_secs: parse_env(
                "TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS",
                def.startup_timeout_secs,
            ),
        }
    }
}
/// Reads environment variable `name` and parses it as `T`, falling back to `default`.
///
/// Surrounding whitespace is ignored, so `" 10\n"` (common in env files) parses as
/// `10`. A value that is unset, not valid UTF-8, or unparseable after trimming
/// yields `default`.
fn parse_env<T: std::str::FromStr>(name: &str, default: T) -> T {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(default)
}
/// Owns a `trusty-embedderd --stdio` sidecar and the state needed to restart it.
///
/// Created with [`EmbedderSupervisor::spawn_stdio`]. The client slot and PID slot
/// handed back alongside the supervisor are the same `Arc`s the supervision task
/// rewrites whenever it relaunches the sidecar.
pub struct EmbedderSupervisor {
    /// Executable launched initially and on every restart.
    binary_path: PathBuf,
    /// The current child process; `None` once it has been taken for shutdown.
    child: Arc<tokio::sync::Mutex<Option<Child>>>,
    /// Client connected to the current child's stdio; swapped out on restart.
    client_slot: Arc<RwLock<Arc<dyn EmbedderClient>>>,
    /// PID of the current child, or 0 when it is unknown or not running.
    child_pid_slot: Arc<AtomicU32>,
    /// Restart / startup policy.
    config: SupervisorConfig,
}
impl EmbedderSupervisor {
    /// Launches the sidecar (including its startup probe) and wraps it in a supervisor.
    ///
    /// Returns `(supervisor, client_slot, pid_slot)`. The two slots are shared with
    /// the supervisor, so readers always observe the most recently started sidecar.
    ///
    /// # Errors
    ///
    /// Propagates any failure from spawning or probing the sidecar.
    pub async fn spawn_stdio(
        binary_path: impl Into<PathBuf>,
        config: SupervisorConfig,
    ) -> Result<(Self, Arc<RwLock<Arc<dyn EmbedderClient>>>, Arc<AtomicU32>)> {
        let path: PathBuf = binary_path.into();
        let (sidecar, stdio_client) = spawn_child(&path, &config).await?;

        // Capture the PID before the child is moved into its mutex.
        let pid_slot = Arc::new(AtomicU32::new(sidecar.id().unwrap_or(0)));
        let initial_client: Arc<dyn EmbedderClient> = Arc::new(stdio_client);
        let client_slot = Arc::new(RwLock::new(initial_client));
        let (client_handle, pid_handle) = (Arc::clone(&client_slot), Arc::clone(&pid_slot));

        Ok((
            Self {
                binary_path: path,
                child: Arc::new(tokio::sync::Mutex::new(Some(sidecar))),
                client_slot,
                child_pid_slot: pid_slot,
                config,
            },
            client_handle,
            pid_handle,
        ))
    }
    /// Consumes the supervisor and detaches a tokio task that restarts the sidecar
    /// whenever it exits abnormally.
    ///
    /// The task's `JoinHandle` is discarded, and because `self` is consumed,
    /// [`EmbedderSupervisor::shutdown`] can no longer be called afterwards.
    pub fn start_supervisor_task(self) {
        let Self {
            binary_path,
            child,
            client_slot,
            child_pid_slot,
            config,
        } = self;
        tokio::spawn(supervision_loop(
            binary_path,
            child,
            client_slot,
            child_pid_slot,
            config,
        ));
    }
    /// Kills the sidecar and waits for it to exit, ignoring any errors (best effort).
    ///
    /// Does nothing if the child slot is already empty.
    pub async fn shutdown(self) {
        let mut slot = self.child.lock().await;
        let Some(mut sidecar) = slot.take() else {
            return;
        };
        let _ = sidecar.kill().await;
        let _ = sidecar.wait().await;
        tracing::info!("EmbedderSupervisor: sidecar terminated on shutdown");
    }
}
/// Spawns `binary_path --stdio` and checks that it answers a one-item embedding probe.
///
/// The child gets piped stdin/stdout, which are handed to the returned
/// [`StdioEmbedderClient`]. It inherits stderr and is marked `kill_on_drop`. The probe
/// must complete within `config.startup_timeout_secs`.
///
/// # Errors
///
/// Fails if the process cannot be spawned, its pipes are missing, or the probe errors
/// or times out. On a probe failure the child is killed and reaped before the error
/// is returned.
async fn spawn_child(
    binary_path: &Path,
    config: &SupervisorConfig,
) -> Result<(Child, StdioEmbedderClient)> {
    use std::process::Stdio;
    let mut child = Command::new(binary_path)
        .arg("--stdio")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::inherit())
        .kill_on_drop(true)
        .spawn()
        .with_context(|| {
            format!(
                "spawn trusty-embedderd --stdio from {}",
                binary_path.display()
            )
        })?;
    // Both handles were configured as `Stdio::piped()` above, so they should be present.
    let stdin = child
        .stdin
        .take()
        .context("child stdin handle missing (expected Stdio::piped)")?;
    let stdout = child
        .stdout
        .take()
        .context("child stdout handle missing (expected Stdio::piped)")?;
    let client = StdioEmbedderClient::new(stdin, stdout);
    let probe_result = tokio::time::timeout(
        Duration::from_secs(config.startup_timeout_secs),
        client.embed_batch(vec!["trusty-embedderd startup probe".to_string()]),
    )
    .await;
    let failure = match probe_result {
        Ok(Ok(_)) => {
            tracing::info!(
                binary = %binary_path.display(),
                "EmbedderSupervisor: sidecar started and responding"
            );
            return Ok((child, client));
        }
        // `{e:#}` renders the whole error chain rather than only the outermost message.
        Ok(Err(e)) => anyhow::anyhow!("sidecar startup probe failed: {e:#}"),
        Err(_elapsed) => anyhow::anyhow!(
            "sidecar did not respond within {}s (TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS={})",
            config.startup_timeout_secs,
            config.startup_timeout_secs
        ),
    };
    // Kill and reap the failed child here instead of leaving it to `kill_on_drop`,
    // which only reaps on a best-effort basis. This way a retry never overlaps with it.
    // The result is ignored because the child may already have exited.
    let _ = child.kill().await;
    Err(failure)
}
/// Background task that relaunches the sidecar whenever it exits abnormally.
///
/// The loop proceeds as follows:
/// - Wait for the child held in `child_slot` to exit. A successful exit ends
///   supervision.
/// - Any other exit counts as a failure. The task sleeps for `2^failures` seconds
///   (exponent capped at 16, delay capped at `config.backoff_max_secs`) and respawns
///   the sidecar.
/// - On success, the new client goes into `client_slot` and the new PID into
///   `child_pid_slot`.
///
/// `child_pid_slot` holds 0 whenever no running sidecar is known.
///
/// Failed respawn attempts also count as failures. Once the count exceeds
/// `config.max_restarts`, supervision gives up.
///
/// The failure count is forgiven only after a sidecar has stayed up for
/// `HEALTHY_UPTIME`, not on every successful respawn. Otherwise a sidecar that passes
/// its startup probe and then crashes would be restarted every 2s forever: the
/// back-off would never grow and `max_restarts` would never be reached.
async fn supervision_loop(
    binary_path: PathBuf,
    child_slot: Arc<tokio::sync::Mutex<Option<Child>>>,
    client_slot: Arc<RwLock<Arc<dyn EmbedderClient>>>,
    child_pid_slot: Arc<AtomicU32>,
    config: SupervisorConfig,
) {
    // A sidecar that stays up at least this long is treated as healthy, which resets
    // the failure count before its exit is recorded.
    const HEALTHY_UPTIME: Duration = Duration::from_secs(60);

    let mut consecutive_failures: u32 = 0;
    // The initial child was started just before this task; approximate its start time.
    let mut started_at = std::time::Instant::now();
    loop {
        // The child lock is held across `wait()`. No other code touches the slot while
        // this task runs, because `start_supervisor_task` consumes the supervisor.
        let exit_status = {
            let mut guard = child_slot.lock().await;
            match guard.as_mut() {
                Some(child) => match child.wait().await {
                    Ok(status) => status,
                    Err(e) => {
                        tracing::error!("EmbedderSupervisor: wait() failed: {e}");
                        child_pid_slot.store(0, AtomicOrdering::Release);
                        return;
                    }
                },
                None => {
                    child_pid_slot.store(0, AtomicOrdering::Release);
                    return;
                }
            }
        };
        child_pid_slot.store(0, AtomicOrdering::Release);
        if exit_status.success() {
            tracing::info!("EmbedderSupervisor: sidecar exited cleanly — stopping supervision");
            return;
        }
        if started_at.elapsed() >= HEALTHY_UPTIME {
            consecutive_failures = 0;
        }
        consecutive_failures = consecutive_failures.saturating_add(1);
        tracing::warn!(
            "EmbedderSupervisor: sidecar exited with {:?} (failure #{}/{})",
            exit_status.code(),
            consecutive_failures,
            config.max_restarts,
        );
        // Retry the respawn here until it succeeds or the limit is hit, rather than
        // going back to `wait()` on the already-dead child.
        loop {
            if consecutive_failures > config.max_restarts {
                tracing::error!(
                    "EmbedderSupervisor: exceeded max_restarts={} — giving up. \
                     Set TRUSTY_EMBEDDERD_MAX_RESTARTS to increase the limit.",
                    config.max_restarts
                );
                return;
            }
            let delay_secs = (1u64 << consecutive_failures.min(16)).min(config.backoff_max_secs);
            tracing::info!(
                "EmbedderSupervisor: restarting sidecar in {delay_secs}s (attempt {consecutive_failures})"
            );
            tokio::time::sleep(Duration::from_secs(delay_secs)).await;
            match spawn_child(&binary_path, &config).await {
                Ok((new_child, new_client)) => {
                    let new_pid = new_child.id().unwrap_or(0);
                    {
                        let mut client_guard = client_slot.write().await;
                        *client_guard = Arc::new(new_client);
                    }
                    {
                        let mut child_guard = child_slot.lock().await;
                        *child_guard = Some(new_child);
                    }
                    child_pid_slot.store(new_pid, AtomicOrdering::Release);
                    started_at = std::time::Instant::now();
                    tracing::info!(
                        "EmbedderSupervisor: sidecar restarted successfully (pid={new_pid})"
                    );
                    break;
                }
                Err(e) => {
                    tracing::error!("EmbedderSupervisor: respawn failed: {e:#}");
                    consecutive_failures = consecutive_failures.saturating_add(1);
                }
            }
        }
    }
}
pub fn locate_embedderd_binary() -> Result<PathBuf> {
if let Ok(explicit) = std::env::var("TRUSTY_EMBEDDERD_BIN") {
let p = PathBuf::from(&explicit);
if p.is_file() {
return Ok(p);
}
anyhow::bail!("TRUSTY_EMBEDDERD_BIN={explicit:?} does not point to an existing file");
}
if let Ok(exe) = std::env::current_exe()
&& let Some(dir) = exe.parent()
{
let sibling = dir.join("trusty-embedderd");
if sibling.is_file() {
return Ok(sibling);
}
let sibling_exe = dir.join("trusty-embedderd.exe");
if sibling_exe.is_file() {
return Ok(sibling_exe);
}
}
if let Ok(path) = which_embedderd() {
return Ok(path);
}
anyhow::bail!(
"could not locate trusty-embedderd binary. \
Set TRUSTY_EMBEDDERD_BIN=/path/to/trusty-embedderd or ensure it is on PATH."
)
}
/// Searches `PATH` for `trusty-embedderd` (and also `trusty-embedderd.exe` on
/// Windows), returning the first existing file.
///
/// # Errors
///
/// Returns an error if `PATH` is unset or empty, or no entry contains the binary.
fn which_embedderd() -> Result<PathBuf> {
    // `var_os` + `split_paths` instead of `var` + `split(':')`. A single non-UTF-8
    // entry no longer discards the whole PATH, and platform quoting rules are honoured.
    let path_var = std::env::var_os("PATH").unwrap_or_default();
    for dir in std::env::split_paths(&path_var) {
        // Skip empty entries, including the one produced for an unset or empty PATH.
        // An empty entry would probe the current directory and return the bare name
        // `trusty-embedderd`. `Command::new` resolves that name through its own PATH
        // search, so it need not be the file checked here.
        if dir.as_os_str().is_empty() {
            continue;
        }
        let candidate = dir.join("trusty-embedderd");
        if candidate.is_file() {
            return Ok(candidate);
        }
        #[cfg(windows)]
        {
            let candidate_exe = dir.join("trusty-embedderd.exe");
            if candidate_exe.is_file() {
                return Ok(candidate_exe);
            }
        }
    }
    anyhow::bail!("trusty-embedderd not found on PATH")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises the tests in this module that mutate process-wide environment
    /// variables. The test harness runs tests on parallel threads, and two of these
    /// tests write `TRUSTY_EMBEDDERD_MAX_RESTARTS`.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Holds [`ENV_LOCK`] and restores the listed variables to their previous values
    /// on drop, including when an assertion panics.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new(names: &[&'static str]) -> Self {
            // A panicking test poisons the lock, but its guard has already restored the
            // environment during unwinding, so the poison carries no broken state.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let saved = names
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in &self.saved {
                // SAFETY: ENV_LOCK is still held (`_lock` is dropped after this body),
                // which serialises this write with the other env-mutating tests here.
                unsafe {
                    match value {
                        Some(v) => std::env::set_var(name, v),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    const CONFIG_VARS: [&str; 3] = [
        "TRUSTY_EMBEDDERD_MAX_RESTARTS",
        "TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS",
        "TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS",
    ];

    #[test]
    fn from_env_uses_defaults_when_no_vars_set() {
        let _env = EnvGuard::new(&CONFIG_VARS);
        // SAFETY: `_env` holds ENV_LOCK for the duration of this test.
        unsafe {
            for name in CONFIG_VARS {
                std::env::remove_var(name);
            }
        }
        let cfg = SupervisorConfig::from_env();
        assert_eq!(cfg.max_restarts, 5);
        assert_eq!(cfg.backoff_max_secs, 60);
        assert_eq!(cfg.startup_timeout_secs, 5);
    }

    #[test]
    fn parse_env_uses_override() {
        let _env = EnvGuard::new(&["TRUSTY_EMBEDDERD_MAX_RESTARTS"]);
        // SAFETY: `_env` holds ENV_LOCK for the duration of this test.
        unsafe {
            std::env::set_var("TRUSTY_EMBEDDERD_MAX_RESTARTS", "99");
        }
        let cfg = SupervisorConfig::from_env();
        assert_eq!(cfg.max_restarts, 99);
    }

    #[test]
    fn locate_binary_respects_explicit_override() {
        let _env = EnvGuard::new(&["TRUSTY_EMBEDDERD_BIN"]);
        // SAFETY: `_env` holds ENV_LOCK for the duration of this test.
        unsafe {
            std::env::set_var("TRUSTY_EMBEDDERD_BIN", "/no/such/binary");
        }
        let result = locate_embedderd_binary();
        assert!(result.is_err(), "must fail on non-existent override path");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("TRUSTY_EMBEDDERD_BIN"),
            "error must mention the env var"
        );
    }
}