#![allow(unsafe_code)]
use std::collections::HashMap;
use std::env;
use std::marker::PhantomData;
use std::net::TcpListener;
use std::os::fd::{FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use crate::drainable::ReadinessSnapshot;
use crate::error::{Error, Result};
use crate::frame::{read_message, write_message};
use crate::protocol::{
Capabilities, HandoffId, Message, PROTO_MAX, PROTO_MIN, ProtoVersion, Side, short_name,
};
use crate::util::now_unix_ms;
/// Pause between two `Heartbeat` messages sent by the successor's heartbeat
/// thread; it is the timeout passed to `recv_timeout` on the stop channel.
const SUCCESSOR_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(2);
/// Environment variable selecting the process role. [`detect_role`] treats the
/// exact value `"successor"` as a handoff; anything else (or unset) is a cold start.
pub const ENV_HANDOFF_ROLE: &str = "HANDOFF_ROLE";
/// Environment variable carrying the decimal file-descriptor number of the
/// control socket that [`detect_role`] wraps in a `UnixStream`.
pub const ENV_HANDOFF_SOCK_FD: &str = "HANDOFF_SOCK_FD";
/// Environment variable holding the number of inherited listener descriptors.
pub const ENV_LISTEN_FDS: &str = "LISTEN_FDS";
/// Environment variable holding `:`-separated names for the inherited
/// descriptors, in descriptor order.
pub const ENV_LISTEN_FDNAMES: &str = "LISTEN_FDNAMES";
/// Descriptor number of the first inherited listener; the i-th listener is
/// expected at `SD_LISTEN_FDS_START + i`.
// NOTE(review): the variable names and this start value look like systemd's
// socket-activation convention (`sd_listen_fds`) — confirm that is the intent.
pub const SD_LISTEN_FDS_START: RawFd = 3;
/// How this process was started, as decided by [`detect_role`].
pub enum Role {
/// `HANDOFF_ROLE` was not `"successor"`. `inherited` holds whatever listeners
/// were announced through `LISTEN_FDS` (possibly none).
ColdStart { inherited: InheritedListeners },
/// The process was started as a handoff successor; the payload carries the
/// control socket and the inherited listeners.
Successor(Successor),
}
/// Listener descriptors handed to this process, keyed by name.
///
/// Descriptors are kept as plain [`RawFd`] numbers. Nothing is closed when this
/// value is dropped; a descriptor only gains an owner once it is handed out by
/// [`InheritedListeners::take`].
#[derive(Default)]
pub struct InheritedListeners {
    listeners: HashMap<String, RawFd>,
}

impl InheritedListeners {
    /// Removes the descriptor registered under `name` and wraps it in a
    /// [`TcpListener`], which then owns (and eventually closes) it.
    ///
    /// Returns `None` when no descriptor is registered under `name`, which
    /// includes the case where it was already taken.
    pub fn take(&mut self, name: &str) -> Option<TcpListener> {
        self.listeners.remove(name).map(|raw| {
            // SAFETY: the entry was just removed from the map, so this number is
            // handed out at most once from here.
            // NOTE(review): that the number refers to an open TCP listening
            // socket owned by this process is trusted from whoever filled the
            // map — it is not verified here.
            unsafe { TcpListener::from_raw_fd(raw) }
        })
    }

    /// Returns the names of all listeners that have not been taken yet, in the
    /// map's (unspecified) iteration order.
    pub fn names(&self) -> Vec<String> {
        let mut remaining = Vec::with_capacity(self.listeners.len());
        remaining.extend(self.listeners.keys().cloned());
        remaining
    }

    /// Returns `true` when no listener is left to take.
    pub fn is_empty(&self) -> bool {
        self.listeners.is_empty()
    }
}
/// A freshly detected handoff successor that has not spoken on the control
/// socket yet. Call [`Successor::handshake`] to proceed.
pub struct Successor {
// Control socket adopted by `detect_role` from the `HANDOFF_SOCK_FD` descriptor.
control: UnixStream,
// Listeners announced through `LISTEN_FDS` / `LISTEN_FDNAMES`.
inherited: InheritedListeners,
}
/// A successor whose `Hello` was answered with `HelloAck`; produced by
/// [`Successor::handshake`]. Call [`HandshookSuccessor::wait_for_begin`] next.
pub struct HandshookSuccessor {
// Control socket carried over from `Successor`.
control: UnixStream,
// Inherited listeners carried over from `Successor`.
inherited: InheritedListeners,
// Handoff identifier taken from the `HelloAck` message.
handoff_id: HandoffId,
// Protocol version taken from `HelloAck::proto_version_chosen`.
proto_version: ProtoVersion,
}
/// A successor that has received a `Begin` matching its handshake id; produced
/// by [`HandshookSuccessor::wait_for_begin`]. In this state listeners can be
/// taken, heartbeats started and readiness announced.
pub struct BegunSuccessor {
// Control socket carried over from `HandshookSuccessor`.
control: UnixStream,
// Inherited listeners; handed out one by one through `take_listener`.
inherited: InheritedListeners,
// Handoff identifier from the `Begin` message (equal to the handshake id).
handoff_id: HandoffId,
// Protocol version used to frame every message written from this state.
proto_version: ProtoVersion,
}
pub fn detect_role() -> Result<Role> {
let inherited = read_inherited_listeners();
unsafe {
env::remove_var(ENV_LISTEN_FDS);
env::remove_var(ENV_LISTEN_FDNAMES);
}
match env::var(ENV_HANDOFF_ROLE) {
Ok(s) if s == "successor" => {}
_ => return Ok(Role::ColdStart { inherited }),
}
let sock_raw =
env::var(ENV_HANDOFF_SOCK_FD).map_err(|_| Error::MissingEnv(ENV_HANDOFF_SOCK_FD))?;
let sock_fd: RawFd = sock_raw.parse().map_err(|_| Error::BadEnv {
var: ENV_HANDOFF_SOCK_FD,
value: sock_raw,
})?;
unsafe {
env::remove_var(ENV_HANDOFF_ROLE);
env::remove_var(ENV_HANDOFF_SOCK_FD);
}
let control = unsafe { UnixStream::from_raw_fd(sock_fd) };
Ok(Role::Successor(Successor { control, inherited }))
}
/// Builds the name → descriptor map announced through `LISTEN_FDS` and
/// `LISTEN_FDNAMES`.
///
/// Descriptor `i` (zero-based) is `SD_LISTEN_FDS_START + i`. Its name is the
/// i-th `:`-separated entry of `LISTEN_FDNAMES`, or the decimal index `i` when
/// that list is missing or too short.
///
/// An unset, unparsable, zero or out-of-range `LISTEN_FDS` yields an empty
/// set. "Out of range" means the highest descriptor number would not fit in a
/// `RawFd`; rejecting it avoids overflow in the descriptor arithmetic.
///
/// If two descriptors end up with the same name, the later one replaces the
/// earlier one in the map.
fn read_inherited_listeners() -> InheritedListeners {
    // Largest count for which `SD_LISTEN_FDS_START + count` cannot overflow.
    let max_count = RawFd::MAX - SD_LISTEN_FDS_START;
    let announced = env::var(ENV_LISTEN_FDS)
        .ok()
        .and_then(|raw| raw.parse::<RawFd>().ok())
        .filter(|n| (1..=max_count).contains(n));
    let Some(count) = announced else {
        return InheritedListeners::default();
    };

    let names: Vec<String> = env::var(ENV_LISTEN_FDNAMES)
        .map(|raw| raw.split(':').map(str::to_owned).collect())
        .unwrap_or_default();
    // Consumed in descriptor order, so each name is moved rather than cloned.
    let mut names = names.into_iter();

    // `count` is positive here, so the conversion to `usize` cannot fail.
    let mut listeners = HashMap::with_capacity(usize::try_from(count).unwrap_or_default());
    let end_fd = SD_LISTEN_FDS_START + count;
    for (index, fd) in (SD_LISTEN_FDS_START..end_fd).enumerate() {
        let name = names.next().unwrap_or_else(|| index.to_string());
        listeners.insert(name, fd);
    }
    InheritedListeners { listeners }
}
impl Successor {
pub fn handshake(mut self, build_id: Vec<u8>) -> Result<HandshookSuccessor> {
let hello = Message::Hello {
role: Side::Successor,
pid: std::process::id(),
build_id,
proto_min: PROTO_MIN,
proto_max: PROTO_MAX,
capabilities: Capabilities::default(),
};
write_message(&mut self.control, PROTO_MAX, &hello)?;
let (_ver, ack) = read_message(&mut self.control)?;
match ack {
Message::HelloAck {
proto_version_chosen,
handoff_id,
} => Ok(HandshookSuccessor {
control: self.control,
inherited: self.inherited,
handoff_id,
proto_version: proto_version_chosen,
}),
other => Err(Error::UnexpectedMessage(short_name(&other))),
}
}
pub fn listener_names(&self) -> Vec<String> {
self.inherited.names()
}
}
impl HandshookSuccessor {
pub fn wait_for_begin(mut self) -> Result<BegunSuccessor> {
let expected = self.handoff_id;
loop {
let (_ver, msg) = read_message(&mut self.control)?;
match msg {
Message::Begin { handoff_id } if handoff_id == expected => {
return Ok(BegunSuccessor {
control: self.control,
inherited: self.inherited,
handoff_id,
proto_version: self.proto_version,
});
}
Message::Begin { handoff_id } => {
return Err(Error::Protocol(format!(
"Begin handoff_id {handoff_id} does not match \
handshake id {expected}"
)));
}
Message::Abort { reason, .. } => return Err(Error::Aborted(reason)),
Message::Heartbeat { .. } => continue,
other => return Err(Error::UnexpectedMessage(short_name(&other))),
}
}
}
pub fn listener_names(&self) -> Vec<String> {
self.inherited.names()
}
pub fn handoff_id(&self) -> HandoffId {
self.handoff_id
}
}
impl BegunSuccessor {
    /// Takes ownership of the inherited listener registered under `name`, or
    /// returns `None` when there is no such listener (any more).
    pub fn take_listener(&mut self, name: &str) -> Option<TcpListener> {
        let inherited = &mut self.inherited;
        inherited.take(name)
    }

    /// Names of the inherited listeners that have not been taken yet.
    pub fn listener_names(&self) -> Vec<String> {
        let inherited = &self.inherited;
        inherited.names()
    }

    /// The handoff identifier this successor is operating under.
    pub fn handoff_id(&self) -> HandoffId {
        let id = self.handoff_id;
        id
    }

    /// Sends `Ready`, built from `snapshot` and this handoff's id, framed with
    /// the negotiated protocol version.
    ///
    /// Consumes the successor, so the control socket it holds is dropped when
    /// this returns.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to the control socket.
    pub fn announce_ready(mut self, snapshot: ReadinessSnapshot) -> Result<()> {
        let version = self.proto_version;
        let announcement = Message::Ready {
            handoff_id: self.handoff_id,
            listening_on: snapshot.listening_on,
            healthz_ok: snapshot.healthz_ok,
            advertised_revision_per_shard: snapshot.advertised_revision_per_shard,
        };
        write_message(&mut self.control, version, &announcement)?;
        Ok(())
    }

    /// Announces readiness and, only if that succeeded, calls
    /// `Incumbent::bind_after_ready` with `socket_path` and `lock`.
    ///
    /// # Errors
    ///
    /// Returns the error from [`BegunSuccessor::announce_ready`] or, failing
    /// that, from `Incumbent::bind_after_ready`.
    pub fn announce_and_bind(
        self,
        snapshot: ReadinessSnapshot,
        socket_path: &std::path::Path,
        lock: crate::DataDirLock,
    ) -> Result<crate::Incumbent> {
        self.announce_ready(snapshot)
            .and_then(|()| crate::Incumbent::bind_after_ready(socket_path, lock))
    }

    /// Starts the background heartbeat sender on the control socket.
    ///
    /// The returned guard borrows `self`, so it has to be dropped before a
    /// consuming call such as [`BegunSuccessor::announce_ready`].
    pub fn start_heartbeats(&self) -> HeartbeatGuard<'_> {
        let version = self.proto_version;
        HeartbeatGuard::start(&self.control, version)
    }
}
/// Handle for the background thread that sends `Heartbeat` messages on the
/// control socket. Dropping the guard signals the thread to stop and joins it.
pub struct HeartbeatGuard<'a> {
// Sender used to tell the thread to stop; `None` when no thread was started
// because the control stream could not be cloned.
stop_tx: Option<mpsc::Sender<()>>,
// Join handle of the heartbeat thread; `None` in the same no-thread case.
thread: Option<thread::JoinHandle<()>>,
// Ties the guard's lifetime to the borrowed control stream even though the
// thread itself writes through a cloned handle.
_borrow: PhantomData<&'a UnixStream>,
}
impl<'a> HeartbeatGuard<'a> {
    /// Spawns a thread that writes a `Heartbeat` (stamped with `now_unix_ms()`)
    /// to a clone of `stream` every `SUCCESSOR_HEARTBEAT_INTERVAL`, framed
    /// with protocol version `chosen`.
    ///
    /// The thread ends when the stop channel delivers a message, when the stop
    /// channel is disconnected, or when a heartbeat write fails. If `stream`
    /// cannot be cloned, a warning is logged and an inert guard (no thread) is
    /// returned.
    fn start(stream: &'a UnixStream, chosen: ProtoVersion) -> Self {
        let mut writer = match stream.try_clone() {
            Ok(clone) => clone,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "could not clone control stream for successor heartbeats; running without"
                );
                return Self {
                    stop_tx: None,
                    thread: None,
                    _borrow: PhantomData,
                };
            }
        };

        let (stop_tx, stop_rx) = mpsc::channel::<()>();
        let thread = thread::spawn(move || loop {
            match stop_rx.recv_timeout(SUCCESSOR_HEARTBEAT_INTERVAL) {
                // Nothing arrived within the interval: time for a heartbeat.
                Err(mpsc::RecvTimeoutError::Timeout) => {}
                // Explicit stop, or the guard's sender is gone. A disconnected
                // channel returns immediately on every call, so continuing
                // would turn this into a busy loop of heartbeats.
                Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => return,
            }
            let beat = Message::Heartbeat {
                ts_ms: now_unix_ms(),
            };
            // A failed write ends the thread; heartbeats are best effort.
            if write_message(&mut writer, chosen, &beat).is_err() {
                return;
            }
        });

        Self {
            stop_tx: Some(stop_tx),
            thread: Some(thread),
            _borrow: PhantomData,
        }
    }
}
impl Drop for HeartbeatGuard<'_> {
    /// Signals the heartbeat thread to stop, then waits for it to finish.
    ///
    /// Both steps are best effort: a failed send (the thread has already
    /// exited) and a failed join (the thread panicked) are ignored. For an
    /// inert guard both fields are `None` and this does nothing.
    fn drop(&mut self) {
        // Signal first, join second: the thread only wakes early on a message.
        let _ = self.stop_tx.take().map(|stop| stop.send(()));
        let _ = self.thread.take().map(|worker| worker.join());
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Covers the two cold-start paths of `detect_role`: no handoff variables at
// all, and a `HANDOFF_ROLE` value other than "successor".
// NOTE(review): this mutates process-wide environment variables, so it can
// interfere with any other test that touches them concurrently.
#[test]
fn detect_env_branches() {
// Start from a clean slate so ambient variables cannot steer the result.
unsafe {
env::remove_var(ENV_HANDOFF_ROLE);
env::remove_var(ENV_HANDOFF_SOCK_FD);
env::remove_var(ENV_LISTEN_FDS);
env::remove_var(ENV_LISTEN_FDNAMES);
}
// Nothing set: cold start.
assert!(matches!(detect_role().unwrap(), Role::ColdStart { .. }));
unsafe {
env::set_var(ENV_HANDOFF_ROLE, "other");
}
// A role other than "successor" is also a cold start.
assert!(matches!(detect_role().unwrap(), Role::ColdStart { .. }));
// `detect_role` leaves `HANDOFF_ROLE` in place on the cold-start path, so
// clean it up here.
unsafe {
env::remove_var(ENV_HANDOFF_ROLE);
}
}
}