use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Protocol version number as carried in `Hello` (`proto_min` / `proto_max`)
/// and `HelloAck` (`proto_version_chosen`).
pub type ProtoVersion = u16;

/// Lowest protocol version this build supports.
pub const PROTO_MIN: ProtoVersion = 1;

/// Highest protocol version this build supports.
pub const PROTO_MAX: ProtoVersion = 1;
/// Which role a peer claims in a handoff; sent as `role` in `Message::Hello`.
///
/// NOTE(review): the names suggest `Incumbent` is the process currently serving and
/// `Successor` is the process taking over — confirm against the handoff driver.
///
/// Do not reorder variants: serde formats that encode enum variants by index
/// (rather than by name) would change their wire representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Side {
    Incumbent,
    Successor,
}
/// Identifier of a single handoff attempt, wrapping a [`Uuid`].
///
/// Carried by most handoff-related [`Message`] variants so both peers can tell
/// which attempt a message belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HandoffId(pub Uuid);

impl HandoffId {
    /// Creates a fresh, random (UUID v4) handoff identifier.
    #[must_use]
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }
}

impl Default for HandoffId {
    /// Equivalent to [`HandoffId::new`]: every call yields a new random id,
    /// not a fixed "zero" value.
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for HandoffId {
    /// Formats exactly as the wrapped [`Uuid`]'s `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call: plain `self.0.fmt(f)` needs a formatting trait in scope
        // (none is imported here) and is ambiguous between Uuid's `Display`, `Debug`
        // and hex impls as soon as more than one of them is imported.
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Capability word a peer advertises in `Message::Hello`.
///
/// `Default` yields `reserved == 0`.
/// NOTE(review): presumably meant to become a feature bitset in later protocol
/// versions — confirm before assigning meaning to any bits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Capabilities {
    /// Reserved for future use; nothing in this module reads it.
    pub reserved: u64,
}
/// Every message of the handoff protocol.
///
/// Variant order and field order are part of the serialized form for serde
/// formats that encode positionally / by variant index. Append new variants at the
/// end and never reorder existing variants or fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    /// Opening handshake: the sender's role, its supported protocol range
    /// (`proto_min..=proto_max`) and capabilities.
    /// `pid` / `build_id` presumably identify the sending process and binary — confirm.
    Hello {
        role: Side,
        pid: u32,
        build_id: Vec<u8>,
        proto_min: ProtoVersion,
        proto_max: ProtoVersion,
        capabilities: Capabilities,
    },
    /// Reply to `Hello`: the protocol version selected (presumably via
    /// `negotiate_version`) and the id of the handoff being set up.
    HelloAck {
        proto_version_chosen: ProtoVersion,
        handoff_id: HandoffId,
    },
    /// Starts handoff `handoff_id` towards the successor process `successor_pid`.
    /// NOTE(review): the `_ms` suffix suggests milliseconds; whether `deadline_ms`
    /// is a duration or an absolute timestamp is not visible here — confirm.
    PrepareHandoff {
        handoff_id: HandoffId,
        successor_pid: u32,
        deadline_ms: u64,
        drain_grace_ms: u64,
    },
    /// Drain status report: connections still open and whether accepting has stopped
    /// (meaning inferred from field names — confirm).
    Drained {
        open_conns_remaining: u32,
        accept_closed: bool,
    },
    /// Requests sealing for `handoff_id`; presumably answered by `SealProgress`,
    /// then `SealComplete` or `SealFailed`.
    SealRequest { handoff_id: HandoffId },
    /// Incremental sealing progress (`shards_sealed` of `shards_total`) plus the
    /// latest revision reached.
    SealProgress {
        shards_sealed: u32,
        shards_total: u32,
        last_revision: u64,
    },
    /// Sealing finished: last revision per shard and a 32-byte data-directory
    /// fingerprint (presumably a hash the successor can verify — confirm).
    SealComplete {
        handoff_id: HandoffId,
        last_revision_per_shard: Vec<u64>,
        data_dir_fingerprint: [u8; 32],
    },
    /// Sealing failed; both fields are free-form text.
    SealFailed {
        handoff_id: HandoffId,
        error: String,
        partial_state: String,
    },
    /// Signals the start of the takeover phase for `handoff_id`.
    Begin { handoff_id: HandoffId },
    /// Readiness report: addresses being listened on, health-check result, and the
    /// per-shard revisions the sender advertises.
    Ready {
        handoff_id: HandoffId,
        listening_on: Vec<String>,
        healthz_ok: bool,
        advertised_revision_per_shard: Vec<u64>,
    },
    /// Finalizes handoff `handoff_id`.
    Commit { handoff_id: HandoffId },
    /// Cancels handoff `handoff_id` with a free-form `reason`.
    Abort {
        handoff_id: HandoffId,
        reason: String,
    },
    /// Presumably tells the sender's peer to resume normal operation after an
    /// `Abort` of `handoff_id` — confirm.
    ResumeAfterAbort { handoff_id: HandoffId },
    /// Liveness ping; `ts_ms` is a timestamp whose epoch/clock is not visible here.
    Heartbeat { ts_ms: u64 },
}
pub fn short_name(msg: &Message) -> &'static str {
match msg {
Message::Hello { .. } => "Hello",
Message::HelloAck { .. } => "HelloAck",
Message::PrepareHandoff { .. } => "PrepareHandoff",
Message::Drained { .. } => "Drained",
Message::SealRequest { .. } => "SealRequest",
Message::SealProgress { .. } => "SealProgress",
Message::SealComplete { .. } => "SealComplete",
Message::SealFailed { .. } => "SealFailed",
Message::Begin { .. } => "Begin",
Message::Ready { .. } => "Ready",
Message::Commit { .. } => "Commit",
Message::Abort { .. } => "Abort",
Message::ResumeAfterAbort { .. } => "ResumeAfterAbort",
Message::Heartbeat { .. } => "Heartbeat",
}
}
pub fn negotiate_version(
our_min: ProtoVersion,
our_max: ProtoVersion,
their_min: ProtoVersion,
their_max: ProtoVersion,
) -> crate::error::Result<ProtoVersion> {
let lo = our_min.max(their_min);
let hi = our_max.min(their_max);
if lo > hi {
Err(crate::error::Error::VersionMismatch {
our_min,
our_max,
their_min,
their_max,
})
} else {
Ok(hi)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn negotiate_picks_highest_overlap() {
        assert_eq!(negotiate_version(1, 3, 2, 5).unwrap(), 3);
        assert_eq!(negotiate_version(1, 1, 1, 1).unwrap(), 1);
        assert_eq!(negotiate_version(1, 5, 3, 4).unwrap(), 4);
    }

    #[test]
    fn negotiate_rejects_disjoint() {
        assert!(matches!(
            negotiate_version(1, 1, 2, 2),
            Err(crate::error::Error::VersionMismatch { .. })
        ));
        // Same gap with the roles swapped.
        assert!(matches!(
            negotiate_version(2, 2, 1, 1),
            Err(crate::error::Error::VersionMismatch { .. })
        ));
    }

    #[test]
    fn negotiate_rejects_inverted_ranges() {
        // A peer advertising min > max must never yield a version.
        assert!(matches!(
            negotiate_version(1, 5, 4, 2),
            Err(crate::error::Error::VersionMismatch { .. })
        ));
        // Same for a malformed local range.
        assert!(matches!(
            negotiate_version(5, 1, 1, 5),
            Err(crate::error::Error::VersionMismatch { .. })
        ));
    }

    #[test]
    fn own_range_negotiates_with_itself() {
        // Fails if PROTO_MIN/PROTO_MAX are ever edited into an inverted range.
        assert_eq!(
            negotiate_version(PROTO_MIN, PROTO_MAX, PROTO_MIN, PROTO_MAX).unwrap(),
            PROTO_MAX
        );
    }

    #[test]
    fn handoff_id_unique() {
        let a = HandoffId::new();
        let b = HandoffId::new();
        assert_ne!(a, b);
    }

    #[test]
    fn handoff_id_display_matches_uuid() {
        let id = HandoffId::new();
        assert_eq!(id.to_string(), id.0.to_string());
    }

    #[test]
    fn short_name_names_variant() {
        let id = HandoffId::new();
        assert_eq!(short_name(&Message::Commit { handoff_id: id }), "Commit");
        assert_eq!(short_name(&Message::Heartbeat { ts_ms: 0 }), "Heartbeat");
    }
}