use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::{Mutex, RwLock};
use trusty_common::embedder_client::EmbedderClient;
pub use trusty_common::embedder_client::EmbedderSupervisor;
/// Tuning knobs for the lazily spawned `trusty-embedderd` daemon.
///
/// The first three fields are forwarded unchanged to
/// `trusty_common::embedder_client::SupervisorConfig` (see `into_common`);
/// `idle_shutdown_secs` is consumed only by this module's idle watchdog.
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
/// Forwarded to the shared supervisor; presumably the time allowed for the
/// daemon to come up (env: `TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS`).
pub startup_timeout_secs: u64,
/// Forwarded to the shared supervisor; presumably the cap on restart backoff
/// (env: `TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS`).
pub backoff_max_secs: u64,
/// Forwarded to the shared supervisor (env: `TRUSTY_EMBEDDERD_MAX_RESTARTS`).
pub max_restarts: u32,
/// Seconds of inactivity after which the idle watchdog kills the daemon and
/// resets the spawn gate. `0` disables idle shutdown (the default)
/// (env: `TRUSTY_EMBEDDERD_IDLE_SHUTDOWN_SECS`).
pub idle_shutdown_secs: u64,
}
impl SupervisorConfig {
    /// Builds a config from the `TRUSTY_EMBEDDERD_*` environment variables.
    ///
    /// Every variable is optional. A missing, non-Unicode, or unparsable value
    /// falls back to the matching field of [`SupervisorConfig::default`]. The
    /// fallbacks are read from `Default` so that `from_env` and `default`
    /// cannot drift apart.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        Self {
            startup_timeout_secs: parse_env_u64(
                "TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS",
                defaults.startup_timeout_secs,
            ),
            backoff_max_secs: parse_env_u64(
                "TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS",
                defaults.backoff_max_secs,
            ),
            max_restarts: parse_env_u32("TRUSTY_EMBEDDERD_MAX_RESTARTS", defaults.max_restarts),
            idle_shutdown_secs: parse_env_u64(
                "TRUSTY_EMBEDDERD_IDLE_SHUTDOWN_SECS",
                defaults.idle_shutdown_secs,
            ),
        }
    }

    /// Converts into the shared crate's supervisor config.
    ///
    /// `idle_shutdown_secs` is dropped: idle shutdown is implemented locally
    /// by this module, not by the shared supervisor.
    pub fn into_common(self) -> trusty_common::embedder_client::SupervisorConfig {
        trusty_common::embedder_client::SupervisorConfig {
            startup_timeout_secs: self.startup_timeout_secs,
            backoff_max_secs: self.backoff_max_secs,
            max_restarts: self.max_restarts,
        }
    }
}
impl Default for SupervisorConfig {
    /// Built-in defaults: 30 s startup timeout, 60 s restart-backoff cap,
    /// 5 restarts, and idle shutdown disabled (`0`).
    fn default() -> Self {
        const STARTUP_TIMEOUT_SECS: u64 = 30;
        const BACKOFF_MAX_SECS: u64 = 60;
        const MAX_RESTARTS: u32 = 5;
        const IDLE_SHUTDOWN_DISABLED: u64 = 0;

        Self {
            idle_shutdown_secs: IDLE_SHUTDOWN_DISABLED,
            max_restarts: MAX_RESTARTS,
            backoff_max_secs: BACKOFF_MAX_SECS,
            startup_timeout_secs: STARTUP_TIMEOUT_SECS,
        }
    }
}
pub fn locate_embedderd_binary() -> anyhow::Result<PathBuf> {
trusty_common::embedder_client::locate_embedderd_binary()
}
/// Returns a per-process socket path: `$TMPDIR/trusty-embedderd-<pid>.sock`.
///
/// `TMPDIR` is used only when it is set, valid Unicode, and not blank.
/// Otherwise the path falls back to `/tmp`. Embedding the PID keeps
/// concurrent processes from colliding, and it makes the result stable
/// within one process.
pub fn default_socket_path() -> PathBuf {
    let base = match std::env::var("TMPDIR") {
        Ok(tmpdir) if !tmpdir.trim().is_empty() => PathBuf::from(tmpdir),
        _ => PathBuf::from("/tmp"),
    };
    base.join(format!("trusty-embedderd-{}.sock", std::process::id()))
}
/// Everything owned by one spawned embedderd generation.
///
/// Stored in `LazyEmbedderHandle::state`. Replacing it with `None` drops
/// these fields, which is how the spawn gate is reset.
struct SpawnedState {
/// Hot-swappable client; `embed_via` reads the current client from it.
client_slot: Arc<RwLock<Arc<dyn EmbedderClient>>>,
// Never sent on: it is held only so that dropping this state closes the
// channel, which completes the idle watchdog's `shutdown_rx` and stops it.
#[allow(dead_code)]
shutdown_tx: tokio::sync::oneshot::Sender<()>,
// The supervisor's child-PID slot for this generation. It is not read
// through this field; NOTE(review): appears to be kept only to tie the slot's
// lifetime to the generation — confirm it is still needed.
#[allow(dead_code)]
pid_slot: Arc<AtomicU32>,
}
/// Handle that defers spawning `trusty-embedderd` until the first embed
/// request, and that can optionally tear it down again after a period of
/// inactivity (see `SupervisorConfig::idle_shutdown_secs`).
pub struct LazyEmbedderHandle {
/// Executable launched on first use.
binary_path: PathBuf,
/// Supervisor and idle-shutdown settings.
config: SupervisorConfig,
/// `None` until the first `embed_via` call spawns the daemon; reset to
/// `None` by the idle watchdog so the next request respawns it. The lock
/// also serializes spawning.
state: Arc<Mutex<Option<SpawnedState>>>,
/// PID of the current daemon as exposed to the application; `0` before the
/// first spawn and after an idle shutdown.
app_pid_slot: Arc<AtomicU32>,
/// Most recent embed activity recorded by `embed_via`; read by the idle
/// watchdog to decide whether the daemon is idle.
last_use: Arc<Mutex<Option<Instant>>>,
}
impl LazyEmbedderHandle {
    /// Creates a handle that will launch `binary_path` on first use.
    ///
    /// Nothing is spawned here. The daemon is started by the first
    /// [`embed_via`](Self::embed_via) call.
    pub fn new(binary_path: PathBuf, config: SupervisorConfig) -> Self {
        tracing::info!(
            "embedderd supervisor armed, deferred spawn enabled \
             (idle_shutdown_secs={})",
            config.idle_shutdown_secs,
        );
        Self {
            binary_path,
            config,
            state: Arc::new(Mutex::new(None)),
            app_pid_slot: Arc::new(AtomicU32::new(0)),
            last_use: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns a shared handle to the app-facing PID slot.
    ///
    /// The slot holds `0` while no daemon is running.
    pub fn app_pid_slot(&self) -> Arc<AtomicU32> {
        Arc::clone(&self.app_pid_slot)
    }

    /// Runs `op` against the embedder client, spawning the daemon first if
    /// no generation is currently running.
    ///
    /// Spawning happens while the state lock is held. Concurrent callers
    /// therefore wait for a single spawn instead of racing to start several.
    ///
    /// `last_use` is stamped twice:
    /// - when the request is dispatched, while the state lock is still held,
    ///   so the idle watchdog, which re-checks under the same lock, cannot
    ///   kill a daemon that a request is about to use;
    /// - again when `op` succeeds.
    ///
    /// A single request that runs longer than `idle_shutdown_secs` is still
    /// not protected.
    ///
    /// # Errors
    /// - `EmbedderError::ModelError("lazy embedderd spawn failed: …")` if the
    ///   daemon cannot be spawned.
    /// - Any error returned by `op`.
    pub async fn embed_via<F, Fut, T>(
        &self,
        op: F,
    ) -> Result<T, trusty_common::embedder_client::EmbedderError>
    where
        F: FnOnce(Arc<dyn EmbedderClient>) -> Fut,
        Fut: std::future::Future<Output = Result<T, trusty_common::embedder_client::EmbedderError>>,
    {
        let client_slot = {
            let mut guard = self.state.lock().await;
            if guard.is_none() {
                let spawned = do_spawn(
                    &self.binary_path,
                    &self.config,
                    Arc::clone(&self.app_pid_slot),
                    Arc::clone(&self.state),
                    Arc::clone(&self.last_use),
                )
                .await
                .map_err(|e| {
                    trusty_common::embedder_client::EmbedderError::ModelError(format!(
                        "lazy embedderd spawn failed: {e:#}"
                    ))
                })?;
                *guard = Some(spawned);
            }
            // Count the dispatch as activity before releasing the state lock.
            // The watchdog re-reads `last_use` under this lock before killing.
            // Lock order is always state -> last_use, and nothing holds
            // last_use while waiting for state, so this cannot deadlock.
            *self.last_use.lock().await = Some(Instant::now());
            let spawned = guard.as_ref().expect("state is Some after spawn");
            Arc::clone(&spawned.client_slot)
        };
        let client = client_slot.read().await.clone();
        let result = op(client).await;
        if result.is_ok() {
            *self.last_use.lock().await = Some(Instant::now());
        }
        result
    }
}
/// Spawns one `trusty-embedderd` generation and wires up its helper tasks.
///
/// `LazyEmbedderHandle::embed_via` calls this while it holds the state lock,
/// so at most one spawn is in flight at a time. On success this function:
///
/// 1. starts the daemon via `EmbedderSupervisor::spawn_stdio` and then its
///    supervisor task;
/// 2. publishes the child PID into `app_pid_slot` and starts
///    [`mirror_child_pid`] to keep it current;
/// 3. seeds `last_use` with "now", so a timestamp left over from a previous
///    (idle-shut-down) generation cannot make this one look idle on the
///    watchdog's first poll;
/// 4. when `config.idle_shutdown_secs > 0`, starts `idle_watchdog`. That task
///    exits once the returned state's `shutdown_tx` is dropped.
///
/// # Errors
/// Returns the error from `EmbedderSupervisor::spawn_stdio` unchanged.
async fn do_spawn(
    binary_path: &Path,
    config: &SupervisorConfig,
    app_pid_slot: Arc<AtomicU32>,
    state_cell: Arc<Mutex<Option<SpawnedState>>>,
    last_use: Arc<Mutex<Option<Instant>>>,
) -> Result<SpawnedState> {
    tracing::info!(
        binary = %binary_path.display(),
        "LazyEmbedderHandle: first embed request — spawning trusty-embedderd",
    );
    let (supervisor, client_slot, child_pid_slot) = EmbedderSupervisor::spawn_stdio(
        binary_path.to_path_buf(),
        config.clone().into_common(),
    )
    .await?;

    let initial_pid = child_pid_slot.load(AtomicOrdering::Acquire);
    app_pid_slot.store(initial_pid, AtomicOrdering::Release);
    tokio::spawn(mirror_child_pid(
        Arc::clone(&child_pid_slot),
        Arc::clone(&app_pid_slot),
        initial_pid,
    ));

    // The caller holds the state lock here. Taking last_use under it matches
    // the lock order used everywhere else (state -> last_use).
    *last_use.lock().await = Some(Instant::now());

    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
    supervisor.start_supervisor_task();
    let idle_secs = config.idle_shutdown_secs;
    if idle_secs > 0 {
        tokio::spawn(idle_watchdog(
            idle_secs,
            state_cell,
            Arc::clone(&app_pid_slot),
            last_use,
            shutdown_rx,
        ));
    }
    Ok(SpawnedState {
        client_slot,
        shutdown_tx,
        pid_slot: child_pid_slot,
    })
}

/// Copies the supervisor's child PID (`src`) into the app-facing slot (`dst`)
/// every 500 ms until the child PID reads `0` or this generation is retired.
///
/// The copy uses a compare-and-swap against the value this task last
/// published. Once anyone else writes `dst` (the idle watchdog resetting it to
/// `0`, or a newer generation publishing its own PID), the CAS fails and the
/// task exits. A retired mirror therefore never writes a stale PID back into
/// the slot.
async fn mirror_child_pid(src: Arc<AtomicU32>, dst: Arc<AtomicU32>, mut published: u32) {
    loop {
        let pid = src.load(AtomicOrdering::Acquire);
        if dst
            .compare_exchange(published, pid, AtomicOrdering::AcqRel, AtomicOrdering::Acquire)
            .is_err()
        {
            break;
        }
        published = pid;
        if pid == 0 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
}
/// Background task that tears down an idle embedderd generation.
///
/// Every `min(10 s, idle_secs)` the task checks how long it has been since
/// the last recorded use. The clock starts no earlier than the moment this
/// watchdog was armed. Once that reaches `idle_secs`, it does the following:
///
/// 1. Takes the state lock.
/// 2. Re-checks idleness while holding the lock, so an embed that slipped in
///    between the first check and the lock is honoured.
/// 3. On Unix, sends SIGTERM and then, 500 ms later, SIGKILL to the PID in
///    `app_pid_slot`.
/// 4. Clears the state and stores `0` in `app_pid_slot`. The next
///    `embed_via` call then respawns the daemon.
///
/// The task exits early when `shutdown_rx` completes. That happens when the
/// owning `SpawnedState` (which holds the sender) is dropped. An `idle_secs`
/// of `0` means "disabled", and the task returns immediately.
///
/// NOTE(review): if the shared supervisor restarts children that die from a
/// signal, it may bring the daemon back after this kill. Confirm that
/// dropping the client slot stops it.
async fn idle_watchdog(
    idle_secs: u64,
    state_cell: Arc<Mutex<Option<SpawnedState>>>,
    app_pid_slot: Arc<AtomicU32>,
    last_use: Arc<Mutex<Option<Instant>>>,
    mut shutdown_rx: tokio::sync::oneshot::Receiver<()>,
) {
    // A zero threshold would make the poll interval zero, which is a busy spin
    // that kills the daemon on the first tick.
    if idle_secs == 0 {
        return;
    }
    let idle_threshold = Duration::from_secs(idle_secs);
    let poll_interval = Duration::from_secs(10).min(idle_threshold);
    let armed_at = Instant::now();
    loop {
        tokio::select! {
            _ = tokio::time::sleep(poll_interval) => {}
            _ = &mut shutdown_rx => {
                tracing::debug!("idle_watchdog: shutdown signal received, exiting");
                return;
            }
        }
        // Cheap pre-check without the state lock, so busy handles are not
        // contended on every poll.
        if idle_elapsed(&last_use, armed_at).await < idle_threshold {
            continue;
        }
        let mut guard = state_cell.lock().await;
        // Authoritative check: `embed_via` stamps `last_use` while holding this
        // same lock, so no new dispatch can slip past this point.
        if idle_elapsed(&last_use, armed_at).await < idle_threshold {
            continue;
        }
        tracing::info!(
            idle_secs = idle_secs,
            "LazyEmbedderHandle: idle threshold exceeded — shutting down embedderd"
        );
        if guard.is_some() {
            let pid = app_pid_slot.load(AtomicOrdering::Acquire);
            if pid != 0 {
                #[cfg(unix)]
                {
                    use nix::sys::signal::{kill, Signal};
                    use nix::unistd::Pid;
                    // A u32 above i32::MAX would wrap negative, and a negative
                    // PID makes kill(2) signal a whole process group.
                    let raw = pid as i32;
                    if raw > 0 {
                        // Best effort: an error here just means the child is
                        // already gone.
                        let _ = kill(Pid::from_raw(raw), Signal::SIGTERM);
                        tokio::time::sleep(Duration::from_millis(500)).await;
                        let _ = kill(Pid::from_raw(raw), Signal::SIGKILL);
                    } else {
                        tracing::warn!(
                            pid = pid,
                            "idle_watchdog: child PID out of range; not signalling"
                        );
                    }
                }
                #[cfg(not(unix))]
                {
                    tracing::warn!(
                        "idle_watchdog: idle kill not supported on this platform; \
                         clearing state only"
                    );
                }
            }
            *guard = None;
            app_pid_slot.store(0, AtomicOrdering::Release);
            tracing::info!(
                "LazyEmbedderHandle: embedderd idle-shutdown complete; spawn gate reset"
            );
        }
        return;
    }
}

/// Returns how long the embedder has been idle, measured from the later of
/// the last recorded use and `armed_at`.
///
/// Clamping to `armed_at` has two effects. A daemon that was never used
/// successfully still ages out. A timestamp left over from a previous
/// generation cannot count against the current one.
async fn idle_elapsed(last_use: &Mutex<Option<Instant>>, armed_at: Instant) -> Duration {
    let since = match *last_use.lock().await {
        Some(stamp) => stamp.max(armed_at),
        None => armed_at,
    };
    since.elapsed()
}
/// Reads environment variable `var` as a `u64`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or
/// does not parse as a `u64`. No trimming is applied, so surrounding
/// whitespace also yields `default`.
fn parse_env_u64(var: &str, default: u64) -> u64 {
    match std::env::var(var) {
        Ok(raw) => raw.parse::<u64>().unwrap_or(default),
        Err(_) => default,
    }
}
/// Reads environment variable `var` as a `u32`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or
/// does not parse as a `u32`. Out-of-range values such as `4294967296` also
/// fall back to `default`.
fn parse_env_u32(var: &str, default: u32) -> u32 {
    if let Ok(raw) = std::env::var(var) {
        if let Ok(value) = raw.parse::<u32>() {
            return value;
        }
    }
    default
}
// Unit tests for config parsing, path helpers, and the lazy-spawn gate.
// Tests that mutate environment variables are `#[serial]` because the process
// environment is global; `EnvGuard` restores the previous value on drop.
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use std::sync::atomic::Ordering;
// With every TRUSTY_EMBEDDERD_* knob unset, `from_env` must yield the
// built-in defaults (idle shutdown disabled).
#[test]
#[serial]
fn config_from_env_defaults() {
let _g1 = EnvGuard::remove("TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS");
let _g2 = EnvGuard::remove("TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS");
let _g3 = EnvGuard::remove("TRUSTY_EMBEDDERD_MAX_RESTARTS");
let _g4 = EnvGuard::remove("TRUSTY_EMBEDDERD_IDLE_SHUTDOWN_SECS");
let cfg = SupervisorConfig::from_env();
assert_eq!(cfg.startup_timeout_secs, 30);
assert_eq!(cfg.backoff_max_secs, 60);
assert_eq!(cfg.max_restarts, 5);
assert_eq!(
cfg.idle_shutdown_secs, 0,
"idle-shutdown must default to disabled"
);
}
// Every knob set to a valid number must override its default.
#[test]
#[serial]
fn config_from_env_overrides() {
let _g1 = EnvGuard::set("TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS", "15");
let _g2 = EnvGuard::set("TRUSTY_EMBEDDERD_RESTART_BACKOFF_MAX_SECS", "120");
let _g3 = EnvGuard::set("TRUSTY_EMBEDDERD_MAX_RESTARTS", "10");
let _g4 = EnvGuard::set("TRUSTY_EMBEDDERD_IDLE_SHUTDOWN_SECS", "300");
let cfg = SupervisorConfig::from_env();
assert_eq!(cfg.startup_timeout_secs, 15);
assert_eq!(cfg.backoff_max_secs, 120);
assert_eq!(cfg.max_restarts, 10);
assert_eq!(cfg.idle_shutdown_secs, 300);
}
// Unparsable values fall back to defaults. The backoff variable is left
// untouched and not asserted, so ambient values for it cannot break this test.
#[test]
#[serial]
fn config_from_env_ignores_malformed() {
let _g1 = EnvGuard::set("TRUSTY_EMBEDDERD_STARTUP_TIMEOUT_SECS", "not_a_number");
let _g2 = EnvGuard::set("TRUSTY_EMBEDDERD_MAX_RESTARTS", "bad");
let _g3 = EnvGuard::set("TRUSTY_EMBEDDERD_IDLE_SHUTDOWN_SECS", "nope");
let cfg = SupervisorConfig::from_env();
assert_eq!(cfg.startup_timeout_secs, 30);
assert_eq!(cfg.max_restarts, 5);
assert_eq!(cfg.idle_shutdown_secs, 0);
}
// `into_common` must copy the three shared fields verbatim.
#[test]
fn into_common_maps_fields() {
let cfg = SupervisorConfig {
startup_timeout_secs: 99,
backoff_max_secs: 77,
max_restarts: 3,
idle_shutdown_secs: 600,
};
let common = cfg.into_common();
assert_eq!(common.startup_timeout_secs, 99);
assert_eq!(common.backoff_max_secs, 77);
assert_eq!(common.max_restarts, 3);
}
// The socket path embeds this process's PID and is stable across calls.
#[test]
fn default_socket_path_is_pid_specific() {
let p = default_socket_path();
let pid = std::process::id().to_string();
assert!(
p.to_string_lossy().contains(&pid),
"socket path {p:?} must contain PID {pid}"
);
assert_eq!(
p,
default_socket_path(),
"must be deterministic for same PID"
);
}
#[test]
fn default_socket_path_has_parent() {
let p = default_socket_path();
assert!(
p.parent().is_some_and(|pp| !pp.as_os_str().is_empty()),
"socket path {p:?} must have a non-empty parent"
);
}
// An explicit TRUSTY_EMBEDDERD_BIN pointing at a missing file must error
// rather than fall through to another lookup strategy.
#[test]
#[serial]
fn locate_binary_bad_explicit_path_errors() {
let _g = EnvGuard::set("TRUSTY_EMBEDDERD_BIN", "/nonexistent/path/trusty-embedderd");
let result = locate_embedderd_binary();
assert!(result.is_err(), "expected Err, got {result:?}");
}
// An existing file named by TRUSTY_EMBEDDERD_BIN is returned as-is.
// NOTE(review): the temp file is not marked executable. Whether the shared
// locator checks the execute bit is decided in trusty_common; confirm.
#[test]
#[serial]
fn locate_binary_via_explicit_env() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let path = tmp.path().to_path_buf();
let _g = EnvGuard::set("TRUSTY_EMBEDDERD_BIN", path.to_str().unwrap());
let result = locate_embedderd_binary();
assert!(result.is_ok(), "expected Ok, got {result:?}");
assert_eq!(result.unwrap(), path);
}
// Construction alone must not spawn anything, so the PID slot stays 0.
#[test]
fn lazy_handle_defers_spawn_pid_is_zero_at_construction() {
let handle = LazyEmbedderHandle::new(
PathBuf::from("/nonexistent/trusty-embedderd"),
SupervisorConfig::default(),
);
let pid = handle.app_pid_slot().load(Ordering::Relaxed);
assert_eq!(
pid, 0,
"PID slot must be 0 before any embed request; got {pid}"
);
}
#[tokio::test]
async fn lazy_handle_state_is_none_before_first_use() {
let handle = LazyEmbedderHandle::new(
PathBuf::from("/nonexistent/trusty-embedderd"),
SupervisorConfig::default(),
);
let guard = handle.state.lock().await;
assert!(guard.is_none(), "state must be None before first embed");
}
// A spawn failure surfaces as an error from `embed_via`; `op` is never run.
#[tokio::test]
async fn lazy_handle_spawn_failure_propagates_as_error() {
let handle = LazyEmbedderHandle::new(
PathBuf::from("/nonexistent/trusty-embedderd"),
SupervisorConfig {
startup_timeout_secs: 1,
..SupervisorConfig::default()
},
);
let result = handle
.embed_via(|_client| async {
Ok::<Vec<Vec<f32>>, trusty_common::embedder_client::EmbedderError>(vec![])
})
.await;
assert!(result.is_err(), "expected Err when binary is absent");
let err = result.unwrap_err().to_string();
assert!(
err.contains("spawn") || err.contains("embedderd") || err.contains("nonexistent"),
"error must describe the spawn failure; got: {err}"
);
}
// Two concurrent first requests against a missing binary must both get Err
// and neither task may panic. NOTE(review): this does not count spawn
// attempts, so it does not by itself prove single-flight spawning.
#[tokio::test]
async fn lazy_handle_single_flight_concurrent_spawn_attempts() {
use std::sync::Arc as StdArc;
let handle = StdArc::new(LazyEmbedderHandle::new(
PathBuf::from("/nonexistent/trusty-embedderd"),
SupervisorConfig {
startup_timeout_secs: 1,
..SupervisorConfig::default()
},
));
let h1 = StdArc::clone(&handle);
let h2 = StdArc::clone(&handle);
let (r1, r2) = tokio::join!(
tokio::spawn(async move {
h1.embed_via(|_c| async {
Ok::<Vec<Vec<f32>>, trusty_common::embedder_client::EmbedderError>(vec![])
})
.await
}),
tokio::spawn(async move {
h2.embed_via(|_c| async {
Ok::<Vec<Vec<f32>>, trusty_common::embedder_client::EmbedderError>(vec![])
})
.await
}),
);
assert!(r1.is_ok(), "task 1 panicked: {:?}", r1);
assert!(r2.is_ok(), "task 2 panicked: {:?}", r2);
assert!(
r1.unwrap().is_err(),
"task 1 should return Err for missing binary"
);
assert!(
r2.unwrap().is_err(),
"task 2 should return Err for missing binary"
);
}
// Only checks that constructing a handle with idle shutdown disabled leaves
// the state empty. NOTE(review): the watchdog is only armed inside
// `do_spawn`, so this test never exercises the watchdog path itself.
#[tokio::test]
async fn lazy_handle_no_watchdog_when_idle_secs_is_zero() {
let handle = LazyEmbedderHandle::new(
PathBuf::from("/nonexistent/trusty-embedderd"),
SupervisorConfig {
idle_shutdown_secs: 0,
..SupervisorConfig::default()
},
);
let guard = handle.state.lock().await;
assert!(
guard.is_none(),
"state must be None; watchdog must not trigger spawn"
);
}
// RAII guard that sets or removes an environment variable and restores the
// previous value (or absence) when dropped.
struct EnvGuard {
key: String,
old: Option<String>,
}
impl EnvGuard {
// `set_var`/`remove_var` are `unsafe` on edition 2024 because mutating
// the environment is not thread-safe. The `#[serial]` attribute on
// every caller is what keeps these calls from racing.
fn set(key: &str, value: &str) -> Self {
let old = std::env::var(key).ok();
unsafe { std::env::set_var(key, value) }
Self {
key: key.to_owned(),
old,
}
}
fn remove(key: &str) -> Self {
let old = std::env::var(key).ok();
unsafe { std::env::remove_var(key) }
Self {
key: key.to_owned(),
old,
}
}
}
impl Drop for EnvGuard {
fn drop(&mut self) {
unsafe {
match &self.old {
Some(v) => std::env::set_var(&self.key, v),
None => std::env::remove_var(&self.key),
}
}
}
}
}