use std::io::ErrorKind;
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};
use crate::crash::points;
use crate::crash_here;
use crate::drainable::Drainable;
use crate::error::{Error, Result};
use crate::frame::{read_message, write_message};
use crate::lock::DataDirLock;
use crate::metrics::events;
use crate::protocol::{
Capabilities, HandoffId, Message, PROTO_MAX, PROTO_MIN, ProtoVersion, Side, negotiate_version,
short_name,
};
use crate::util::now_unix_ms;
/// How long `serve` keeps retrying the data-dir flock when it has to take it
/// back after a session that ended with an error.
const RESUME_FLOCK_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Pause between heartbeat frames written while a drain or seal is running.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(2_000);
/// Longest wait for the supervisor's `HelloAck` after our `Hello` is sent.
const HELLO_READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Incumbent side of the handoff protocol: holds the data-dir lock and
/// serves handoff sessions from a supervisor over a Unix control socket.
///
/// Field order matters: struct fields are dropped in declaration order, so
/// the listener is dropped before the lock (if it is still held).
pub struct Incumbent {
    /// Control socket on which supervisor sessions are accepted.
    listener: UnixListener,
    /// Data-dir lock. Becomes `None` after a successful seal releases the
    /// flock, and is set again when the flock is re-acquired on abort/resume.
    lock: Option<DataDirLock>,
    /// Directory of `lock`, kept so the flock can be re-acquired later.
    data_dir: PathBuf,
    /// Build identifier sent in every `Hello`; empty unless set with
    /// `with_build_id`.
    build_id: Vec<u8>,
}
/// Runs `work` on the calling thread while a helper thread writes a
/// [`Message::Heartbeat`] on a clone of `stream` every [`HEARTBEAT_INTERVAL`].
///
/// Before this function returns (including when `work` unwinds), the helper
/// thread is signalled to stop and joined. No heartbeat frame can therefore
/// interleave with frames the caller writes afterwards. The join waits for
/// a heartbeat write that is already in progress.
///
/// If the stream cannot be cloned, a warning is logged and `work` runs
/// without heartbeats. A failed heartbeat write ends the helper thread
/// silently; `work` itself is not interrupted.
///
/// # Errors
///
/// Returns exactly what `work` returns.
fn run_with_heartbeats<F, T>(stream: &UnixStream, chosen: ProtoVersion, work: F) -> Result<T>
where
    F: FnOnce() -> Result<T>,
{
    /// Signals the heartbeat thread to stop and joins it when dropped.
    struct StopGuard {
        stop_tx: Option<mpsc::Sender<()>>,
        thread: Option<thread::JoinHandle<()>>,
    }

    impl Drop for StopGuard {
        fn drop(&mut self) {
            if let Some(tx) = self.stop_tx.take() {
                // The receiver is already gone if the thread exited after a
                // failed write; a failed send is harmless then.
                let _ = tx.send(());
            }
            if let Some(h) = self.thread.take() {
                let _ = h.join();
            }
        }
    }

    let writer = match stream.try_clone() {
        Ok(w) => w,
        Err(e) => {
            tracing::warn!(
                error = %e,
                "could not clone control stream for heartbeats; running without"
            );
            return work();
        }
    };
    let (stop_tx, stop_rx) = mpsc::channel::<()>();
    let hb_thread = thread::spawn(move || {
        let mut writer = writer;
        // Keep beating only while the wait times out. A stop signal ends the
        // loop, and so does a disconnected channel: `recv_timeout` returns
        // `Disconnected` immediately once the sender is gone, so treating
        // every `Err` as "keep going" would spin writing heartbeats with no
        // pause between them.
        while let Err(mpsc::RecvTimeoutError::Timeout) = stop_rx.recv_timeout(HEARTBEAT_INTERVAL) {
            let msg = Message::Heartbeat {
                ts_ms: now_unix_ms(),
            };
            if write_message(&mut writer, chosen, &msg).is_err() {
                return;
            }
        }
    });
    // Stops and joins the helper on every exit path, including a panic in `work`.
    let _guard = StopGuard {
        stop_tx: Some(stop_tx),
        thread: Some(hb_thread),
    };
    work()
}
/// Binds a fresh listener at `socket_path` and wraps it, together with
/// `lock`, in a new [`Incumbent`] whose `build_id` is empty.
///
/// The socket's parent directory is created first if needed. Any existing
/// file at `socket_path` is removed on a best-effort basis. If removal
/// fails and the file is still there, the `bind` that follows reports the
/// problem.
///
/// # Errors
///
/// Returns an error if the parent directory cannot be created or the socket
/// cannot be bound.
fn bind_unlinking(socket_path: &Path, lock: DataDirLock) -> Result<Incumbent> {
    if let Some(dir) = socket_path.parent() {
        std::fs::create_dir_all(dir)?;
    }

    // Best effort: "not found" is the normal case, and a leftover file makes
    // `bind` fail below anyway.
    let _ = std::fs::remove_file(socket_path);
    let listener = UnixListener::bind(socket_path)?;

    // Fields are evaluated in written order, so `data_dir` is read before
    // `lock` is moved into the struct.
    Ok(Incumbent {
        listener,
        data_dir: lock.data_dir().to_path_buf(),
        lock: Some(lock),
        build_id: Vec::new(),
    })
}
fn acquire_with_short_retry(data_dir: &Path, timeout: Duration) -> Result<DataDirLock> {
const RETRY_INTERVAL: Duration = Duration::from_millis(25);
let deadline = Instant::now() + timeout;
loop {
match DataDirLock::acquire(data_dir) {
Ok(lock) => return Ok(lock),
Err(Error::LockHeld { .. }) if Instant::now() < deadline => {
std::thread::sleep(RETRY_INTERVAL);
}
Err(e) => return Err(e),
}
}
}
/// How a supervisor session ended, when it ended without an error.
enum SessionOutcome {
    /// `Commit` arrived after a successful seal; `serve` returns.
    Committed,
    /// The supervisor disconnected (EOF or connection reset) while nothing
    /// was sealed; `serve` goes back to accepting sessions.
    Closed,
}
/// Handoff progress within one supervisor session. Used to decide what must
/// be undone when the session ends.
#[derive(Default)]
struct SessionState {
    /// Handoff id accepted by `PrepareHandoff`. Cleared by
    /// `ResumeAfterAbort`, `Abort`, or a failed seal.
    active: Option<HandoffId>,
    /// Set once `seal` succeeded and the data-dir lock was dropped.
    sealed: bool,
    /// Set once `drain` completed. Cleared when the drain is undone through
    /// `resume_after_abort`.
    drained: bool,
}
impl Incumbent {
pub fn bind_cold_start(socket_path: &Path, lock: DataDirLock) -> Result<Self> {
bind_unlinking(socket_path, lock)
}
pub(crate) fn bind_after_ready(socket_path: &Path, lock: DataDirLock) -> Result<Self> {
bind_unlinking(socket_path, lock)
}
pub fn with_build_id(mut self, build_id: Vec<u8>) -> Self {
self.build_id = build_id;
self
}
pub fn serve<D: Drainable + 'static>(mut self, drainable: D) -> Result<()> {
loop {
let (stream, _addr) = match self.listener.accept() {
Ok(x) => x,
Err(e) if e.kind() == ErrorKind::Interrupted => continue,
Err(e) => return Err(e.into()),
};
match self.handle_session(stream, &drainable) {
Ok(SessionOutcome::Committed) => {
tracing::info!("handoff committed; incumbent exiting serve loop");
return Ok(());
}
Ok(SessionOutcome::Closed) => continue,
Err(e) => {
tracing::error!(error = %e, "handoff session ended with error");
if self.lock.is_none() {
match acquire_with_short_retry(&self.data_dir, RESUME_FLOCK_TIMEOUT) {
Ok(lock) => {
self.lock = Some(lock);
if let Err(e2) = drainable.resume_after_abort() {
tracing::error!(
error = %e2,
"resume_after_abort failed after session error"
);
}
}
Err(e2) => {
tracing::error!(
error = %e2,
"failed to re-acquire flock after session error; \
incumbent cannot resume"
);
return Err(e2);
}
}
}
}
}
}
}
fn handle_session<D: Drainable>(
&mut self,
mut stream: UnixStream,
drainable: &D,
) -> Result<SessionOutcome> {
let our_hello = Message::Hello {
role: Side::Incumbent,
pid: std::process::id(),
build_id: self.build_id.clone(),
proto_min: PROTO_MIN,
proto_max: PROTO_MAX,
capabilities: Capabilities::default(),
};
write_message(&mut stream, PROTO_MAX, &our_hello)?;
stream.set_read_timeout(Some(HELLO_READ_TIMEOUT))?;
let read_result = read_message(&mut stream);
let _ = stream.set_read_timeout(None);
let (_v, ack) = match read_result {
Ok(x) => x,
Err(Error::Io(e))
if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) =>
{
return Err(Error::Timeout("HelloAck"));
}
Err(e) => return Err(e),
};
let chosen = match ack {
Message::HelloAck {
proto_version_chosen,
..
} => negotiate_version(
PROTO_MIN,
PROTO_MAX,
proto_version_chosen,
proto_version_chosen,
)?,
other => return Err(Error::UnexpectedMessage(short_name(&other))),
};
let mut state = SessionState::default();
let outcome = self.run_session_loop(&mut stream, chosen, drainable, &mut state);
let committed = matches!(outcome, Ok(SessionOutcome::Committed));
if state.drained
&& !state.sealed
&& !committed
&& let Err(e) = drainable.resume_after_abort()
{
tracing::error!(
error = %e,
"resume_after_abort during drained-session cleanup failed"
);
}
outcome
}
fn run_session_loop<D: Drainable>(
&mut self,
stream: &mut UnixStream,
chosen: u16,
drainable: &D,
state: &mut SessionState,
) -> Result<SessionOutcome> {
loop {
let (_v, msg) = match read_message(stream) {
Ok(x) => x,
Err(Error::Io(e))
if matches!(
e.kind(),
ErrorKind::UnexpectedEof | ErrorKind::ConnectionReset
) =>
{
if state.sealed {
return Err(Error::Protocol(
"supervisor disconnected after seal; resuming".into(),
));
}
return Ok(SessionOutcome::Closed);
}
Err(e) => return Err(e),
};
match msg {
Message::PrepareHandoff {
handoff_id,
deadline_ms,
drain_grace_ms,
..
} => {
if let Some(existing) = state.active
&& existing != handoff_id
{
return Err(Error::HandoffInProgress);
}
state.active = Some(handoff_id);
let now = Instant::now();
let deadline = now + Duration::from_millis(drain_grace_ms.min(deadline_ms));
tracing::info!(
target: events::PREPARE,
%handoff_id, "drain start"
);
let report = run_with_heartbeats(stream, chosen, || drainable.drain(deadline))?;
state.drained = true;
tracing::info!(
target: events::DRAINED,
%handoff_id, open_conns_remaining = report.open_conns_remaining,
"drain done"
);
write_message(
stream,
chosen,
&Message::Drained {
open_conns_remaining: report.open_conns_remaining,
accept_closed: report.accept_closed,
},
)?;
crash_here!(points::O_AFTER_DRAINED_SENT);
}
Message::SealRequest { handoff_id } => {
if state.active != Some(handoff_id) {
return Err(Error::UnexpectedMessage("SealRequest for unknown id"));
}
tracing::info!(
target: events::SEAL,
%handoff_id, "seal start"
);
let seal_outcome = run_with_heartbeats(stream, chosen, || drainable.seal());
match seal_outcome {
Ok(report) => {
self.lock.take();
state.sealed = true;
crash_here!(points::O_AFTER_SEAL_FLOCK_RELEASED);
tracing::info!(
target: events::SEAL_COMPLETE,
%handoff_id, "seal complete; flock released"
);
write_message(
stream,
chosen,
&Message::SealComplete {
handoff_id,
last_revision_per_shard: report.last_revision_per_shard,
data_dir_fingerprint: report.data_dir_fingerprint,
},
)?;
crash_here!(points::O_AFTER_SEAL_COMPLETE_SENT);
}
Err(e) => {
tracing::error!(
%handoff_id, error = %e, "seal failed; remaining as incumbent"
);
write_message(
stream,
chosen,
&Message::SealFailed {
handoff_id,
error: format!("{e}"),
partial_state: String::new(),
},
)?;
drainable.resume_after_abort()?;
state.drained = false;
state.active = None;
}
}
}
Message::Commit { handoff_id } => {
if state.active != Some(handoff_id) {
return Err(Error::UnexpectedMessage("Commit for unknown id"));
}
if !state.sealed {
return Err(Error::Protocol("Commit before SealComplete".into()));
}
tracing::info!(
target: events::COMMIT,
%handoff_id, "handoff committed"
);
crash_here!(points::O_AFTER_COMMIT_RECV);
return Ok(SessionOutcome::Committed);
}
Message::ResumeAfterAbort { handoff_id } => {
if state.active != Some(handoff_id) {
return Err(Error::UnexpectedMessage("Resume for unknown id"));
}
if state.sealed {
let lock = DataDirLock::acquire(&self.data_dir)?;
self.lock = Some(lock);
drainable.resume_after_abort()?;
state.sealed = false;
state.drained = false;
tracing::info!(
target: events::RESUME,
%handoff_id, "resumed after abort; flock re-acquired"
);
} else if state.drained {
drainable.resume_after_abort()?;
state.drained = false;
}
state.active = None;
}
Message::Abort { handoff_id, reason } => {
if state.active != Some(handoff_id) {
return Err(Error::UnexpectedMessage("Abort for unknown id"));
}
tracing::warn!(
target: events::ABORT,
%handoff_id, reason, "handoff aborted"
);
if state.sealed {
let lock = DataDirLock::acquire(&self.data_dir)?;
self.lock = Some(lock);
drainable.resume_after_abort()?;
state.sealed = false;
state.drained = false;
} else if state.drained {
drainable.resume_after_abort()?;
state.drained = false;
}
state.active = None;
}
Message::Heartbeat { .. } => {
write_message(
stream,
chosen,
&Message::Heartbeat {
ts_ms: now_unix_ms(),
},
)?;
}
other => return Err(Error::UnexpectedMessage(short_name(&other))),
}
}
}
}