#![allow(unsafe_code)]
use std::io::ErrorKind;
use std::os::fd::{AsRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::{Path, PathBuf};
use std::process::{Child, Command, Stdio};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use nix::sys::socket::{AddressFamily, SockFlag, SockType, socketpair};
use crate::crash::points;
use crate::crash_here;
use crate::error::{Error, Result};
use crate::fd::pass_listener_fds_on_spawn;
use crate::frame::{read_message, write_message};
use crate::metrics::events;
use crate::protocol::{
HandoffId, Message, PROTO_MAX, PROTO_MIN, ProtoVersion, Side, negotiate_version, short_name,
};
use crate::state::{Phase, StateJournal};
use crate::util::now_unix_ms;
/// Floor applied to every socket read timeout computed in this file, and to the value
/// returned by `remaining_until`.
const MIN_READ_TIMEOUT: Duration = Duration::from_millis(100);
/// Upper bound on a single blocking read while waiting for a phase message. A read that
/// hits this bound is reported as a timeout even if the phase deadline has not elapsed.
const LIVENESS_TIMEOUT: Duration = Duration::from_secs(10);
/// Read timeout used while waiting for a peer's `Hello`.
const HELLO_READ_TIMEOUT: Duration = Duration::from_secs(5);
/// Extra allowance added on top of a phase deadline when waiting for the peer's reply.
const WIRE_SLACK: Duration = Duration::from_secs(1);
/// Drives a handoff from the running incumbent process to a freshly spawned successor.
pub struct Supervisor {
// Unix socket path that `perform_handoff` and `resume_from_journal` connect to in order
// to reach the incumbent.
socket_path: PathBuf,
// Named descriptors handed to `pass_listener_fds_on_spawn` when the successor is spawned.
listener_fds: Vec<(String, RawFd)>,
// Where the `StateJournal` is written; `None` disables journaling entirely.
journal_path: Option<PathBuf>,
// Set through `with_build_id`; not read anywhere in the visible part of this file.
build_id: Vec<u8>,
// Held for the duration of `perform_handoff` so a concurrent call is rejected.
in_flight: Mutex<()>,
}
/// Describes how to launch the successor process and how long the handoff may take.
#[derive(Debug, Clone)]
pub struct SpawnSpec {
    /// Executable to run as the successor.
    pub binary: PathBuf,
    /// Command-line arguments passed to `binary`.
    pub args: Vec<String>,
    /// Extra environment variables set on the successor.
    pub env: Vec<(String, String)>,
    /// Total time budget for the whole handoff, measured from its start.
    pub deadline: Duration,
    /// Drain grace period announced to the incumbent in `PrepareHandoff`.
    pub drain_grace: Duration,
}

impl SpawnSpec {
    /// Builds a spec for `binary` with no arguments, no extra environment, a 300 second
    /// overall deadline and a 60 second drain grace period.
    pub fn new(binary: impl Into<PathBuf>) -> Self {
        const DEFAULT_DEADLINE: Duration = Duration::from_secs(300);
        const DEFAULT_DRAIN_GRACE: Duration = Duration::from_secs(60);

        let binary: PathBuf = binary.into();
        Self {
            binary,
            args: vec![],
            env: vec![],
            deadline: DEFAULT_DEADLINE,
            drain_grace: DEFAULT_DRAIN_GRACE,
        }
    }
}
/// Result of a handoff attempt that ran to a decision (commit or clean abort).
#[derive(Debug)]
pub struct HandoffOutcome {
// Identifier generated for this attempt.
pub handoff_id: HandoffId,
// `true` once the successor reported `Ready` and the commit path ran.
pub committed: bool,
// Human-readable reason; `Some` on the abort paths, `None` on commit.
pub abort_reason: Option<String>,
// The successor process handle: `Some` on commit, `None` when the successor was killed.
pub child: Option<Child>,
}
/// Owns the spawned successor until the handoff commits; dropping it while it still holds
/// the child kills and reaps that process.
struct ChildGuard {
// `Option` so that `disarm`, `kill_and_reap` and `Drop` can take the child out.
child: Option<Child>,
// Pid captured at construction; stays readable after the child has been taken.
pid: u32,
}
impl ChildGuard {
    /// Wraps `child`, remembering its pid.
    fn new(child: Child) -> Self {
        // `child.id()` is evaluated before `child` is moved into the struct.
        Self {
            pid: child.id(),
            child: Some(child),
        }
    }

    /// Pid of the guarded process.
    fn id(&self) -> u32 {
        self.pid
    }

    /// Releases the child to the caller without killing it.
    ///
    /// # Panics
    /// Panics if the inner child is already gone, which would be a bug in this type.
    fn disarm(mut self) -> Child {
        let released = self.child.take();
        released.expect("BUG: ChildGuard inner Child missing — constructor invariant violated")
    }

    /// Kills the child and waits for it; errors from either call are ignored.
    fn kill_and_reap(mut self) {
        let Some(mut doomed) = self.child.take() else {
            return;
        };
        let _ = doomed.kill();
        let _ = doomed.wait();
    }
}
impl Drop for ChildGuard {
    /// Safety net: if the guard still owns the child when it goes out of scope, log a
    /// warning, then kill and reap the process. Does nothing after `disarm` or
    /// `kill_and_reap`, which leave the slot empty.
    fn drop(&mut self) {
        let Some(mut leaked) = self.child.take() else {
            return;
        };
        tracing::warn!(
            pid = self.pid,
            "killing leaked successor child on guard drop"
        );
        let _ = leaked.kill();
        let _ = leaked.wait();
    }
}
impl Supervisor {
pub fn new(socket_path: &Path) -> Result<Self> {
Ok(Self {
socket_path: socket_path.to_path_buf(),
listener_fds: Vec::new(),
journal_path: None,
build_id: Vec::new(),
in_flight: Mutex::new(()),
})
}
/// Registers a named listener descriptor to pass to the successor; returns the supervisor
/// for chaining.
pub fn with_listener(self, name: impl Into<String>, fd: RawFd) -> Self {
    let mut configured = self;
    let entry = (name.into(), fd);
    configured.listener_fds.push(entry);
    configured
}
/// Enables journaling of handoff state to `path`; returns the supervisor for chaining.
pub fn with_journal(self, path: PathBuf) -> Self {
    let mut configured = self;
    configured.journal_path = Some(path);
    configured
}
/// Stores `build_id` on the supervisor; returns the supervisor for chaining.
pub fn with_build_id(self, build_id: Vec<u8>) -> Self {
    let mut configured = self;
    configured.build_id = build_id;
    configured
}
pub fn perform_handoff(&self, spec: SpawnSpec) -> Result<HandoffOutcome> {
let _in_flight = self
.in_flight
.try_lock()
.map_err(|_| Error::HandoffInProgress)?;
let handoff_id = HandoffId::new();
let started_instant = Instant::now();
let started_unix_ms = now_unix_ms();
let total_deadline_at = started_instant + spec.deadline;
let mut o_stream = UnixStream::connect(&self.socket_path)?;
let chosen_o =
self.exchange_hello_as_supervisor(&mut o_stream, handoff_id, Side::Incumbent, None)?;
crash_here!(points::S_AFTER_O_HELLO);
let (s_end, n_end) = make_socketpair()?;
let n_end_raw = n_end.as_raw_fd();
let child = self.spawn_successor(&spec, n_end_raw)?;
let child_guard = ChildGuard::new(child);
let successor_pid = child_guard.id();
drop(n_end);
crash_here!(points::S_AFTER_SPAWN_SUCCESSOR);
let mut n_stream = s_end;
let chosen_n = self.exchange_hello_as_supervisor(
&mut n_stream,
handoff_id,
Side::Successor,
Some(successor_pid),
)?;
crash_here!(points::S_AFTER_N_HELLO);
self.journal_set(
handoff_id,
Phase::Negotiating,
successor_pid,
started_unix_ms,
)?;
let prepare_at = Instant::now();
let deadline_ms = remaining_until(total_deadline_at).as_millis() as u64;
let drain_grace_ms = spec.drain_grace.as_millis() as u64;
tracing::info!(
target: events::PREPARE,
%handoff_id, successor_pid,
drain_grace_ms,
deadline_ms,
"prepare handoff"
);
write_message(
&mut o_stream,
chosen_o,
&Message::PrepareHandoff {
handoff_id,
successor_pid,
deadline_ms,
drain_grace_ms,
},
)?;
crash_here!(points::S_AFTER_PREPARE_SENT);
let drained_msg = read_until(
&mut o_stream,
spec.drain_grace + WIRE_SLACK,
"Drained",
|m| matches!(m, Message::Drained { .. }),
)?;
let (drained_open_conns, drained_accept_closed) = match &drained_msg {
Message::Drained {
open_conns_remaining,
accept_closed,
} => (*open_conns_remaining, *accept_closed),
_ => unreachable!("read_until predicate restricts variant"),
};
tracing::info!(
target: events::DRAINED,
%handoff_id,
open_conns_remaining = drained_open_conns,
accept_closed = drained_accept_closed,
drain_seconds = prepare_at.elapsed().as_secs_f64(),
"drain complete"
);
crash_here!(points::S_AFTER_DRAINED_RECV);
self.journal_set(handoff_id, Phase::Draining, successor_pid, started_unix_ms)?;
let seal_at = Instant::now();
tracing::info!(target: events::SEAL, %handoff_id, "seal request");
write_message(
&mut o_stream,
chosen_o,
&Message::SealRequest { handoff_id },
)?;
crash_here!(points::S_AFTER_SEAL_REQUEST_SENT);
let seal_read_deadline = total_deadline_at + WIRE_SLACK;
let seal_outcome: std::result::Result<(), String> = loop {
let now = Instant::now();
if now >= seal_read_deadline {
let _ = o_stream.set_read_timeout(None);
send_best_effort_abort(
&mut n_stream,
chosen_n,
handoff_id,
"seal phase timed out".into(),
);
child_guard.kill_and_reap();
self.journal_clear();
return Err(Error::Timeout("SealComplete"));
}
let remaining = seal_read_deadline - now;
let recv_timeout = LIVENESS_TIMEOUT.min(remaining).max(MIN_READ_TIMEOUT);
o_stream.set_read_timeout(Some(recv_timeout))?;
match read_message(&mut o_stream) {
Ok((_, Message::SealProgress { .. })) => continue,
Ok((_, Message::Heartbeat { .. })) => continue,
Ok((_, Message::SealComplete { handoff_id: id, .. })) if id == handoff_id => {
break Ok(());
}
Ok((
_,
Message::SealFailed {
handoff_id: id,
error,
..
},
)) if id == handoff_id => break Err(error),
Ok((_, other)) => {
let _ = o_stream.set_read_timeout(None);
return Err(Error::UnexpectedMessage(short_name(&other)));
}
Err(Error::Io(e)) if is_timeout(&e) => {
let _ = o_stream.set_read_timeout(None);
send_best_effort_abort(
&mut n_stream,
chosen_n,
handoff_id,
"incumbent unresponsive during seal".into(),
);
child_guard.kill_and_reap();
self.journal_clear();
return Err(Error::Timeout("SealComplete"));
}
Err(e) => {
let _ = o_stream.set_read_timeout(None);
return Err(e);
}
}
};
let _ = o_stream.set_read_timeout(None);
if let Err(error) = seal_outcome {
tracing::warn!(
target: events::ABORT,
%handoff_id, error = %error, "seal failed; aborting handoff"
);
send_best_effort_abort(
&mut n_stream,
chosen_n,
handoff_id,
format!("seal failed: {error}"),
);
child_guard.kill_and_reap();
self.journal_clear();
return Ok(HandoffOutcome {
handoff_id,
committed: false,
abort_reason: Some(format!("seal failed: {error}")),
child: None,
});
}
tracing::info!(
target: events::SEAL_COMPLETE,
%handoff_id,
seal_seconds = seal_at.elapsed().as_secs_f64(),
"seal complete; flock released by O"
);
crash_here!(points::S_AFTER_SEAL_COMPLETE_RECV);
self.journal_set(handoff_id, Phase::Sealing, successor_pid, started_unix_ms)?;
let begin_at = Instant::now();
let ready_result =
match write_message(&mut n_stream, chosen_n, &Message::Begin { handoff_id }) {
Ok(()) => {
crash_here!(points::S_AFTER_BEGIN_SENT);
self.journal_set(
handoff_id,
Phase::AwaitingReady,
successor_pid,
started_unix_ms,
)?;
let ready_timeout = remaining_until(total_deadline_at) + WIRE_SLACK;
read_until(&mut n_stream, ready_timeout, "Ready", |m| {
matches!(m, Message::Ready { .. })
})
}
Err(e) => Err(e),
};
let ready_result = match ready_result {
Ok(Message::Ready { handoff_id: id, .. }) if id != handoff_id => Err(Error::Protocol(
format!("Ready carries wrong handoff_id: got {id}, expected {handoff_id}"),
)),
other => other,
};
match ready_result {
Ok(Message::Ready {
handoff_id: id,
listening_on,
healthz_ok,
advertised_revision_per_shard,
}) if id == handoff_id => {
tracing::info!(
target: events::READY,
%handoff_id,
healthz_ok,
listeners = ?listening_on,
advertised_revisions = ?advertised_revision_per_shard,
begin_to_ready_seconds = begin_at.elapsed().as_secs_f64(),
"successor ready"
);
crash_here!(points::S_AFTER_READY_RECV);
let child = child_guard.disarm();
tracing::info!(
target: events::COMMIT,
%handoff_id,
total_seconds = started_instant.elapsed().as_secs_f64(),
"commit"
);
if let Err(e) =
write_message(&mut o_stream, chosen_o, &Message::Commit { handoff_id })
{
tracing::warn!(
%handoff_id, error = %e,
"failed to send Commit to incumbent; O may have crashed — \
N is the new incumbent regardless, handoff is committed"
);
}
crash_here!(points::S_AFTER_COMMIT_SENT);
if let Err(e) =
self.journal_set(handoff_id, Phase::Committed, successor_pid, started_unix_ms)
{
tracing::warn!(%handoff_id, error = %e, "journal Committed failed");
}
self.journal_clear();
crash_here!(points::S_AFTER_JOURNAL_CLEAR);
Ok(HandoffOutcome {
handoff_id,
committed: true,
abort_reason: None,
child: Some(child),
})
}
other => {
let reason = match &other {
Ok(m) => format!("expected Ready, got {}", short_name(m)),
Err(Error::Timeout(s)) => format!("ready phase timed out waiting for {s}"),
Err(e) => format!("ready read failed: {e}"),
};
tracing::warn!(
target: events::ABORT,
%handoff_id, reason, "aborting handoff before commit"
);
send_best_effort_abort(&mut n_stream, chosen_n, handoff_id, reason.clone());
child_guard.kill_and_reap();
write_message(
&mut o_stream,
chosen_o,
&Message::ResumeAfterAbort { handoff_id },
)?;
tracing::info!(
target: events::RESUME,
%handoff_id, "sent ResumeAfterAbort to O"
);
self.journal_set(
handoff_id,
Phase::ResumingAfterAbort,
successor_pid,
started_unix_ms,
)?;
self.journal_clear();
Ok(HandoffOutcome {
handoff_id,
committed: false,
abort_reason: Some(reason),
child: None,
})
}
}
}
/// Looks for a journal left behind by an earlier handoff and clears it.
///
/// Returns `Ok(None)` when journaling is not configured or no journal is on disk.
/// Otherwise it connects to the incumbent, attempts one framed read with a two-second
/// timeout (failures are only logged), deletes the journal file, and returns the journal
/// that was found.
///
/// # Errors
/// Propagates failures from reading or deleting the journal.
pub fn resume_from_journal(&self) -> Result<Option<StateJournal>> {
    let path = match self.journal_path.as_deref() {
        Some(p) => p,
        None => return Ok(None),
    };
    let journal = match StateJournal::read(path)? {
        Some(found) => found,
        None => return Ok(None),
    };
    tracing::warn!(
        handoff_id = %journal.handoff_id,
        phase = ?journal.phase,
        "found prior handoff state on disk; verifying incumbent then clearing"
    );
    match UnixStream::connect(&self.socket_path) {
        Err(e) => {
            tracing::warn!(error = %e, "incumbent unreachable during journal resume");
        }
        Ok(mut probe) => {
            let _ = probe.set_read_timeout(Some(Duration::from_secs(2)));
            if let Err(e) = read_message(&mut probe) {
                tracing::debug!(
                    error = %e,
                    "incumbent Hello read failed during journal resume probe; \
                     continuing — EOF on drop will reset incumbent session"
                );
            }
            // `probe` is dropped at the end of this arm, closing the connection.
        }
    }
    StateJournal::delete(path)?;
    Ok(Some(journal))
}
fn exchange_hello_as_supervisor(
&self,
stream: &mut UnixStream,
handoff_id: HandoffId,
expected_role: Side,
expected_pid: Option<u32>,
) -> Result<ProtoVersion> {
stream.set_read_timeout(Some(HELLO_READ_TIMEOUT))?;
let read_result = read_message(stream);
let _ = stream.set_read_timeout(None);
let (_v, peer_hello) = match read_result {
Ok(x) => x,
Err(Error::Io(e)) if is_timeout(&e) => return Err(Error::Timeout("peer Hello")),
Err(e) => return Err(e),
};
let (their_role, their_pid, their_min, their_max) = match peer_hello {
Message::Hello {
role,
pid,
proto_min,
proto_max,
..
} => (role, pid, proto_min, proto_max),
other => return Err(Error::UnexpectedMessage(short_name(&other))),
};
if their_role != expected_role {
return Err(Error::Protocol(format!(
"peer announced role {:?}, expected {:?}",
their_role, expected_role
)));
}
if let Some(expected) = expected_pid
&& their_pid != expected
{
return Err(Error::PidMismatch {
expected,
announced: their_pid,
});
}
let chosen = negotiate_version(PROTO_MIN, PROTO_MAX, their_min, their_max)?;
write_message(
stream,
chosen,
&Message::HelloAck {
proto_version_chosen: chosen,
handoff_id,
},
)?;
Ok(chosen)
}
/// Launches the successor described by `spec`.
///
/// The child inherits stdin/stdout/stderr, gets `HANDOFF_ROLE=successor`, and is told via
/// `HANDOFF_SOCK_FD` which descriptor number carries the handoff socket (3 plus the
/// number of listener descriptors). Descriptor passing itself is delegated to
/// `pass_listener_fds_on_spawn`.
///
/// # Errors
/// Returns the spawn failure converted into the crate error type.
fn spawn_successor(&self, spec: &SpawnSpec, n_sock_fd: RawFd) -> Result<Child> {
    let successor_sock_slot = 3 + self.listener_fds.len() as RawFd;
    let mut launcher = Command::new(&spec.binary);
    launcher
        .args(&spec.args)
        // Caller-supplied variables first, so the HANDOFF_* values below take precedence.
        .envs(spec.env.iter().map(|(key, value)| (key, value)))
        .env("HANDOFF_ROLE", "successor")
        .env("HANDOFF_SOCK_FD", successor_sock_slot.to_string())
        .stdin(Stdio::inherit())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit());
    pass_listener_fds_on_spawn(&mut launcher, &self.listener_fds, Some(n_sock_fd));
    Ok(launcher.spawn()?)
}
/// Records the current handoff phase in the journal, if journaling is configured.
///
/// The record stores this process's pid as `incumbent_pid`. Does nothing and returns
/// `Ok(())` when no journal path is set.
///
/// # Errors
/// Propagates the failure from `StateJournal::write_atomic`.
fn journal_set(
    &self,
    handoff_id: HandoffId,
    phase: Phase,
    successor_pid: u32,
    started_at_unix_ms: u64,
) -> Result<()> {
    let Some(path) = &self.journal_path else {
        return Ok(());
    };
    let record = StateJournal {
        handoff_id,
        phase,
        incumbent_pid: std::process::id(),
        successor_pid: Some(successor_pid),
        started_at_unix_ms,
    };
    record.write_atomic(path)?;
    Ok(())
}
fn journal_clear(&self) {
if let Some(path) = &self.journal_path
&& let Err(e) = StateJournal::delete(path)
{
tracing::warn!(
error = %e,
path = %path.display(),
"failed to clear handoff journal; next supervisor start will see stale state"
);
}
}
}
/// Sends `Abort` with `reason` on `stream`. A write failure is logged as a warning and
/// never returned to the caller.
fn send_best_effort_abort(
    stream: &mut UnixStream,
    version: ProtoVersion,
    handoff_id: HandoffId,
    reason: String,
) {
    // Keep a copy for the log line, since `reason` moves into the message.
    let logged_reason = reason.clone();
    let abort = Message::Abort { handoff_id, reason };
    let Err(e) = write_message(stream, version, &abort) else {
        return;
    };
    tracing::warn!(
        %handoff_id,
        reason = %logged_reason,
        error = %e,
        "best-effort Abort to successor failed; child will still be killed and reaped"
    );
}
/// Creates a connected pair of Unix stream sockets, both flagged close-on-exec, and
/// returns them as `UnixStream`s.
///
/// # Errors
/// Propagates the failure from `socketpair`.
fn make_socketpair() -> Result<(UnixStream, UnixStream)> {
    use std::os::fd::FromRawFd;

    let (first, second) = socketpair(
        AddressFamily::Unix,
        SockType::Stream,
        None,
        SockFlag::SOCK_CLOEXEC,
    )?;
    // SAFETY: `first` came from a successful `socketpair` call and `into_raw_fd` gives up
    // its ownership, so the new `UnixStream` is the descriptor's only owner.
    let first = unsafe { UnixStream::from_raw_fd(first.into_raw_fd()) };
    // SAFETY: same reasoning as above, for the other end of the pair.
    let second = unsafe { UnixStream::from_raw_fd(second.into_raw_fd()) };
    Ok((first, second))
}
fn read_until<F>(
stream: &mut UnixStream,
timeout: Duration,
awaiting: &'static str,
pred: F,
) -> Result<Message>
where
F: Fn(&Message) -> bool,
{
let deadline = Instant::now() + timeout.max(MIN_READ_TIMEOUT);
loop {
let now = Instant::now();
if now >= deadline {
let _ = stream.set_read_timeout(None);
return Err(Error::Timeout(awaiting));
}
let remaining = deadline - now;
let recv_timeout = LIVENESS_TIMEOUT.min(remaining).max(MIN_READ_TIMEOUT);
stream.set_read_timeout(Some(recv_timeout))?;
match read_message(stream) {
Ok((_, Message::Heartbeat { .. })) => continue,
Ok((_, Message::SealProgress { .. })) => continue,
Ok((_, msg)) if pred(&msg) => {
let _ = stream.set_read_timeout(None);
return Ok(msg);
}
Ok((_, other)) => {
let _ = stream.set_read_timeout(None);
return Err(Error::UnexpectedMessage(short_name(&other)));
}
Err(Error::Io(e)) if is_timeout(&e) => {
let _ = stream.set_read_timeout(None);
return Err(Error::Timeout(awaiting));
}
Err(e) => {
let _ = stream.set_read_timeout(None);
return Err(e);
}
}
}
}
/// Time left until `deadline`, never less than `MIN_READ_TIMEOUT` (also returned when the
/// deadline has already passed).
fn remaining_until(deadline: Instant) -> Duration {
    let left = deadline.saturating_duration_since(Instant::now());
    if left < MIN_READ_TIMEOUT {
        MIN_READ_TIMEOUT
    } else {
        left
    }
}
/// Returns `true` when `e` is one of the two error kinds this file treats as a read
/// timeout: `WouldBlock` or `TimedOut`.
fn is_timeout(e: &std::io::Error) -> bool {
    let kind = e.kind();
    kind == ErrorKind::WouldBlock || kind == ErrorKind::TimedOut
}