use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::time::Duration;
use anyhow::{Context, Result};
use tokio::net::UnixStream;
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use trusty_common::bm25_client::{locate_bm25_daemon_binary, socket_path_for_palace};
/// Name of the environment variable that switches the supervisor to external mode.
///
/// When it is set to exactly `1`, `Bm25Supervisor::ensure_running` returns the
/// palace's socket path without spawning or tracking a daemon process.
pub const ENV_EXTERNAL_BM25: &str = "TRUSTY_BM25_EXTERNAL";
/// Total time `wait_for_socket` keeps polling a socket before it gives up.
const SPAWN_PROBE_TIMEOUT: Duration = Duration::from_millis(3000);
/// First delay between two probes; `wait_for_socket` doubles it after each sleep.
const INITIAL_PROBE_INTERVAL: Duration = Duration::from_millis(20);
/// Upper bound for the doubling delay between two probes.
const MAX_PROBE_INTERVAL: Duration = Duration::from_millis(250);
/// Bookkeeping for one daemon process spawned by the supervisor.
struct ChildHandle {
// Process handle returned by `spawn_child` (created with `kill_on_drop(true)`).
child: Child,
// Socket path that was confirmed connectable after the spawn; handed back to
// callers on later `ensure_running` calls and removed from disk by `shutdown`.
socket_path: PathBuf,
}
/// Tracks the bm25 daemon processes this process has spawned, one per palace.
pub struct Bm25Supervisor {
// Spawned daemons keyed by palace name, behind an async (tokio) mutex.
children: Mutex<HashMap<String, ChildHandle>>,
}
impl Bm25Supervisor {
    /// Creates a supervisor that is not tracking any daemon yet.
    pub fn new() -> Self {
        let children = Mutex::new(HashMap::new());
        Self { children }
    }

    /// Makes sure a bm25 daemon is reachable for `palace` and returns its socket path.
    ///
    /// The steps, in order:
    /// 1. In external mode (see [`ENV_EXTERNAL_BM25`]) the canonical socket path is
    ///    returned immediately; nothing is spawned or registered.
    /// 2. If a child is already registered for `palace` and is still running, its
    ///    stored socket path is returned. A child that has exited (or whose status
    ///    cannot be queried) is evicted from the registry.
    /// 3. If something already accepts connections on the socket, that daemon is
    ///    adopted as-is: the path is returned and no child is registered.
    /// 4. Otherwise the daemon binary is located and spawned with `data_dir`, and
    ///    the call waits until the socket becomes connectable before registering
    ///    the child.
    ///
    /// # Errors
    ///
    /// Fails when the daemon binary cannot be located, when spawning fails, or when
    /// the socket does not become connectable within `SPAWN_PROBE_TIMEOUT`. In the
    /// last case the freshly spawned child is dropped before returning.
    ///
    /// NOTE(review): the registry lock is not held between the liveness check and
    /// the final insert, so concurrent calls for the same palace are not serialised
    /// here; a later insert replaces (and thereby drops) an earlier handle.
    pub async fn ensure_running(&self, palace: &str, data_dir: &Path) -> Result<PathBuf> {
        let socket_path = socket_path_for_palace(palace);

        if external_mode_enabled() {
            tracing::debug!(
                palace = %palace,
                socket = %socket_path.display(),
                "{ENV_EXTERNAL_BM25}=1 — skipping spawn supervision"
            );
            return Ok(socket_path);
        }

        // Fast path: a child we spawned earlier is still alive.
        if let Some(existing) = self.live_child_socket(palace).await {
            return Ok(existing);
        }

        // Someone else (or a previous incarnation) is already serving the socket.
        if probe_socket(&socket_path).await {
            tracing::info!(
                palace = %palace,
                socket = %socket_path.display(),
                "bm25 daemon socket already responding — not spawning a new child"
            );
            return Ok(socket_path);
        }

        let binary =
            locate_bm25_daemon_binary().context("locate trusty-bm25-daemon binary for spawn")?;
        let spawned = spawn_child(&binary, palace, data_dir).await;
        let child = spawned.with_context(|| {
            format!(
                "spawn trusty-bm25-daemon {} for palace {palace}",
                binary.display()
            )
        })?;

        if let Err(cause) = wait_for_socket(&socket_path).await {
            // Dropping the handle disposes of the child (it was spawned with
            // `kill_on_drop(true)`), so a daemon that never bound is not left behind.
            drop(child);
            let detail = format!(
                "bm25 daemon for palace {palace} did not bind {} within {:?}",
                socket_path.display(),
                SPAWN_PROBE_TIMEOUT
            );
            return Err(cause.context(detail));
        }

        tracing::info!(
            palace = %palace,
            socket = %socket_path.display(),
            binary = %binary.display(),
            "spawned trusty-bm25-daemon"
        );

        let handle = ChildHandle {
            child,
            socket_path: socket_path.clone(),
        };
        let mut registry = self.children.lock().await;
        registry.insert(palace.to_owned(), handle);
        Ok(socket_path)
    }

    /// Returns the stored socket path when a registered child for `palace` is still
    /// running; evicts the entry and returns `None` when the child has exited or its
    /// status cannot be queried. Also returns `None` when nothing is registered.
    async fn live_child_socket(&self, palace: &str) -> Option<PathBuf> {
        let mut registry = self.children.lock().await;
        let handle = registry.get_mut(palace)?;
        match handle.child.try_wait() {
            Ok(None) => {
                tracing::trace!(
                    palace = %palace,
                    socket = %handle.socket_path.display(),
                    "bm25 supervisor: child already running"
                );
                return Some(handle.socket_path.clone());
            }
            Ok(Some(status)) => tracing::warn!(
                palace = %palace,
                ?status,
                "bm25 daemon exited unexpectedly — attempting one restart"
            ),
            Err(e) => tracing::warn!(
                palace = %palace,
                "bm25 supervisor: try_wait failed: {e:#} — evicting and retrying"
            ),
        }
        // Both non-running outcomes end the same way: forget the stale handle.
        registry.remove(palace);
        None
    }

    /// Terminates every registered daemon and removes its socket file.
    ///
    /// The registry is emptied first and the lock released, so the (potentially
    /// slow) termination of each child happens without holding the mutex. Failures
    /// are logged and never propagated.
    pub async fn shutdown(&self) {
        let mut registry = self.children.lock().await;
        let drained = std::mem::take(&mut *registry);
        drop(registry);

        for (palace, ChildHandle { mut child, socket_path }) in drained {
            tracing::info!(
                palace = %palace,
                pid = ?child.id(),
                "shutting down bm25 daemon"
            );

            if let Err(e) = terminate_child(&mut child).await {
                tracing::warn!(
                    palace = %palace,
                    "bm25 daemon shutdown encountered an error: {e:#}"
                );
            }

            // A socket file that is already gone is the expected case and stays silent.
            match tokio::fs::remove_file(&socket_path).await {
                Ok(()) => {}
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
                Err(e) => tracing::debug!(
                    palace = %palace,
                    socket = %socket_path.display(),
                    "could not remove bm25 daemon socket (likely already cleaned up): {e}"
                ),
            }
        }
    }

    /// Number of daemons currently registered with this supervisor.
    pub async fn supervised_count(&self) -> usize {
        let registry = self.children.lock().await;
        registry.len()
    }
}
impl Default for Bm25Supervisor {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for Bm25Supervisor {
    /// Formats the supervisor without acquiring the registry mutex; the `children`
    /// field is always rendered as the placeholder string `"<locked>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Bm25Supervisor");
        repr.field("children", &"<locked>");
        repr.finish()
    }
}
/// Reports whether external mode is on, i.e. the `ENV_EXTERNAL_BM25` environment
/// variable is set to exactly `1`. Unset, non-UTF-8 and any other value mean "off".
fn external_mode_enabled() -> bool {
    matches!(std::env::var(ENV_EXTERNAL_BM25), Ok(flag) if flag == "1")
}
/// Tries once to connect to the Unix socket at `path`.
///
/// Returns `true` only when the connection is established within 200 ms; a
/// connect error and a timeout both yield `false`. The stream is dropped right
/// away — nothing is sent over it.
async fn probe_socket(path: &Path) -> bool {
    let attempt = tokio::time::timeout(Duration::from_millis(200), UnixStream::connect(path));
    attempt.await.is_ok_and(|connected| connected.is_ok())
}
/// Polls the socket at `path` until it accepts a connection.
///
/// Probing starts immediately; between failed probes the task sleeps for a delay
/// that begins at `INITIAL_PROBE_INTERVAL`, doubles each round, and is capped at
/// `MAX_PROBE_INTERVAL` (and at the time left until the deadline).
///
/// # Errors
///
/// Fails once `SPAWN_PROBE_TIMEOUT` has elapsed without a successful probe.
async fn wait_for_socket(path: &Path) -> Result<()> {
    let give_up_at = tokio::time::Instant::now() + SPAWN_PROBE_TIMEOUT;
    let mut backoff = INITIAL_PROBE_INTERVAL;

    while !probe_socket(path).await {
        let now = tokio::time::Instant::now();
        anyhow::ensure!(
            now < give_up_at,
            "socket {} did not become connectable within {:?}",
            path.display(),
            SPAWN_PROBE_TIMEOUT
        );

        // Never sleep past the deadline, so the last probe happens right at it.
        let nap = backoff.min(give_up_at.saturating_duration_since(now));
        tokio::time::sleep(nap).await;
        backoff = (backoff * 2).min(MAX_PROBE_INTERVAL);
    }

    Ok(())
}
/// Spawns the bm25 daemon `binary` for `palace`, storing its data under `data_dir`.
///
/// `data_dir` (including missing parents) is created first. The daemon is started
/// as `binary --palace <palace> --data-dir <data_dir>` with stdin and stdout
/// detached, stderr inherited from this process, and `kill_on_drop(true)` so that
/// dropping the returned handle disposes of the process.
///
/// # Errors
///
/// Fails when `data_dir` cannot be created (for example because the path exists
/// but is not a directory) or when the process cannot be spawned.
async fn spawn_child(binary: &Path, palace: &str, data_dir: &Path) -> Result<Child> {
    // `create_dir_all` already succeeds when the directory exists, so no prior
    // `exists()` check is needed — that check was a blocking stat on the async
    // executor and left a window between the check and the creation.
    tokio::fs::create_dir_all(data_dir)
        .await
        .with_context(|| format!("create bm25 data dir {}", data_dir.display()))?;

    Command::new(binary)
        .arg("--palace")
        .arg(palace)
        .arg("--data-dir")
        .arg(data_dir)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::inherit())
        .kill_on_drop(true)
        .spawn()
        .with_context(|| format!("spawn {}", binary.display()))
}
/// Stops `child`, politely first and forcefully if needed.
///
/// On Unix a `SIGTERM` is sent to the child's pid (when it still has one). The
/// child then gets 2 seconds to exit; if it has not exited by then it is killed
/// via `Child::kill`.
///
/// # Errors
///
/// Fails when waiting on the child fails, or when the forced kill fails.
async fn terminate_child(child: &mut Child) -> Result<()> {
    #[cfg(unix)]
    if let Some(pid) = child.id() {
        // SAFETY: `kill(2)` only takes integer arguments and does not touch any
        // memory owned by this process; the pid comes straight from `child.id()`.
        // The result is ignored on purpose: the wait below handles every outcome.
        let _ = unsafe { libc::kill(pid as libc::pid_t, libc::SIGTERM) };
    }

    let grace_period = Duration::from_millis(2000);
    match tokio::time::timeout(grace_period, child.wait()).await {
        Err(_elapsed) => {
            tracing::warn!("bm25 daemon ignored SIGTERM after 2s — sending SIGKILL");
            child
                .kill()
                .await
                .context("SIGKILL bm25 daemon after SIGTERM timeout")
        }
        Ok(waited) => {
            let status = waited.context("wait on bm25 daemon child after SIGTERM")?;
            tracing::debug!(?status, "bm25 daemon exited after SIGTERM");
            Ok(())
        }
    }
}
// Unit tests for the supervisor. None of them spawns a real daemon: they cover
// the empty state, external mode, adoption of an already-bound socket, and the
// socket probe helper.
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::Mutex as TokioMutex;
// Shared async mutex used to serialise the tests in this module that change
// `ENV_EXTERNAL_BM25`; the environment is process-global while tests may run
// in parallel.
fn env_lock() -> std::sync::Arc<TokioMutex<()>> {
static LOCK: std::sync::OnceLock<std::sync::Arc<TokioMutex<()>>> =
std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Arc::new(TokioMutex::new(())))
.clone()
}
// A fresh supervisor tracks no children.
#[tokio::test]
async fn supervisor_starts_empty() {
let sup = Bm25Supervisor::new();
assert_eq!(sup.supervised_count().await, 0);
}
// `Default` yields the same empty supervisor as `new`.
#[tokio::test]
async fn supervisor_default_matches_new() {
let sup: Bm25Supervisor = Default::default();
assert_eq!(sup.supervised_count().await, 0);
}
// With the external-mode variable set to "1", `ensure_running` returns the
// canonical socket path and registers nothing.
#[tokio::test]
async fn external_mode_skips_spawn() {
// Take the env lock before touching the variable; `_guard` is declared after
// `_env`, so it is dropped (restoring the variable) before the lock is released.
let lock = env_lock();
let _env = lock.lock().await;
let _guard = EnvGuard::set(ENV_EXTERNAL_BM25, "1");
let tmp = tempfile::tempdir().expect("tempdir");
let sup = Bm25Supervisor::new();
let palace = "ext-skip";
let path = sup
.ensure_running(palace, tmp.path())
.await
.expect("external mode must return socket path without spawning");
assert_eq!(path, socket_path_for_palace(palace));
assert_eq!(
sup.supervised_count().await,
0,
"external mode must not register a child"
);
}
// When something is already listening on the palace's socket, `ensure_running`
// adopts it: the path is returned and no child is registered.
#[tokio::test]
async fn already_running_skips_spawn() {
let lock = env_lock();
let _env = lock.lock().await;
// Make sure external mode is off so the probe/adoption path is the one exercised.
let _g = EnvGuard::remove(ENV_EXTERNAL_BM25);
// Palace name derived from the low 16 bits of this process's pid.
let palace = format!("a{:x}", std::process::id() & 0xffff);
let socket = socket_path_for_palace(&palace);
// Clear any leftover socket file so the bind below can succeed.
let _ = std::fs::remove_file(&socket);
let listener =
tokio::net::UnixListener::bind(&socket).expect("bind dummy listener at canonical path");
let tmp = tempfile::tempdir().expect("tempdir");
let sup = Bm25Supervisor::new();
let path = sup
.ensure_running(&palace, tmp.path())
.await
.expect("ensure_running must adopt existing socket");
assert_eq!(path, socket);
assert_eq!(
sup.supervised_count().await,
0,
"adoption path must not register a child"
);
// Best-effort cleanup; not reached if an assertion above fails.
drop(listener);
let _ = std::fs::remove_file(&socket);
}
// `shutdown` on an empty supervisor completes and leaves it empty.
#[tokio::test]
async fn shutdown_with_no_children_is_noop() {
let sup = Bm25Supervisor::new();
sup.shutdown().await;
assert_eq!(sup.supervised_count().await, 0);
}
// Compile-time check that the supervisor can be shared across threads.
#[test]
fn supervisor_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Bm25Supervisor>();
}
// Probing a path where no socket exists reports `false`.
#[tokio::test]
async fn probe_returns_false_for_missing_socket() {
let tmp = tempfile::tempdir().expect("tempdir");
let missing = tmp.path().join("nonexistent.sock");
assert!(!probe_socket(&missing).await);
}
// Probing a path with a bound listener reports `true`.
#[tokio::test]
async fn probe_returns_true_for_bound_socket() {
let tmp = tempfile::tempdir().expect("tempdir");
let sock = tmp.path().join("listen.sock");
let _listener =
tokio::net::UnixListener::bind(&sock).expect("bind listener for probe test");
assert!(probe_socket(&sock).await);
}
// Scope guard for one environment variable: remembers the value it had when the
// guard was created and puts it back (or unsets the variable) on drop.
struct EnvGuard {
key: String,
// Value before the guard changed it; `None` when `std::env::var` returned an error.
prev: Option<String>,
}
impl EnvGuard {
// Sets `key` to `value`, remembering the previous value.
fn set(key: &str, value: &str) -> Self {
let prev = std::env::var(key).ok();
// NOTE(review): environment mutation is process-global. The tests above hold
// `env_lock()` while a guard is alive; code outside this module is not covered
// by that lock.
unsafe { std::env::set_var(key, value) }
Self {
key: key.to_string(),
prev,
}
}
// Unsets `key`, remembering the previous value.
fn remove(key: &str) -> Self {
let prev = std::env::var(key).ok();
// NOTE(review): same process-global caveat as in `set`.
unsafe { std::env::remove_var(key) }
Self {
key: key.to_string(),
prev,
}
}
}
impl Drop for EnvGuard {
// Restores the variable to the state recorded at construction.
fn drop(&mut self) {
unsafe {
match &self.prev {
Some(v) => std::env::set_var(&self.key, v),
None => std::env::remove_var(&self.key),
}
}
}
}
}