mod common;
use std::thread;
use std::time::Duration;
use handoff::frame::{read_message, write_message};
use handoff::protocol::{HandoffId, Message, PROTO_MAX, Side};
use handoff::{DataDirLock, Incumbent};
use common::{MockDrainable, connect_with_retry};
/// Starts a cold incumbent in a fresh temp directory and drives one client
/// connection through the opening handshake.
///
/// The incumbent is bound to `<temp>/control.sock` and given the lock on
/// `<temp>/data`. It serves a default [`MockDrainable`] on a background thread.
/// The returned stream has already read the incumbent's `Hello` and answered it
/// with a `HelloAck` that carries a freshly generated [`HandoffId`] at
/// `PROTO_MAX`.
///
/// Returns `(temp_dir, stream, handoff_id, server_thread)`. Callers must keep
/// `temp_dir` bound for the whole test, because the socket and data directory
/// live inside it.
///
/// # Panics
///
/// Panics if any setup step fails or if the first frame received is not `Hello`.
fn make_active_session() -> (
    tempfile::TempDir,
    std::os::unix::net::UnixStream,
    HandoffId,
    thread::JoinHandle<handoff::Result<()>>,
) {
    let root = tempfile::tempdir().unwrap();
    let control_socket = root.path().join("control.sock");
    let data_path = root.path().join("data");

    // Bring the incumbent up and hand it to a dedicated server thread.
    let data_lock = DataDirLock::acquire(&data_path).unwrap();
    let incumbent = Incumbent::bind_cold_start(&control_socket, data_lock).unwrap();
    let mock = MockDrainable::default();
    let server = thread::spawn(move || incumbent.serve(mock));

    // The incumbent speaks first; the client acknowledges with a new handoff id.
    let mut client = connect_with_retry(&control_socket);
    let greeting = read_message(&mut client).unwrap().1;
    assert!(matches!(greeting, Message::Hello { .. }));

    let id = HandoffId::new();
    let ack = Message::HelloAck {
        proto_version_chosen: PROTO_MAX,
        handoff_id: id,
    };
    write_message(&mut client, PROTO_MAX, &ack).unwrap();

    (root, client, id, server)
}
/// Sends `Commit` straight after `Drained`, skipping `SealRequest`. The
/// incumbent must close the session, so the client's next read fails.
#[test]
fn commit_before_seal_closes_session() {
    let (_dir, mut conn, id, _server) = make_active_session();

    let prepare = Message::PrepareHandoff {
        handoff_id: id,
        successor_pid: 9999,
        deadline_ms: 5000,
        drain_grace_ms: 1000,
    };
    write_message(&mut conn, PROTO_MAX, &prepare).unwrap();
    let reply = read_message(&mut conn).unwrap().1;
    assert!(matches!(reply, Message::Drained { .. }));

    // Jump straight to Commit without sealing first.
    let premature_commit = Message::Commit { handoff_id: id };
    write_message(&mut conn, PROTO_MAX, &premature_commit).unwrap();
    let after_commit = read_message(&mut conn);
    assert!(after_commit.is_err());
}
/// After draining, sends a `SealRequest` whose `handoff_id` differs from the
/// one negotiated in the handshake. The incumbent must close the session.
#[test]
fn seal_request_with_wrong_id_closes_session() {
    let (_dir, mut conn, negotiated, _server) = make_active_session();

    let prepare = Message::PrepareHandoff {
        handoff_id: negotiated,
        successor_pid: 9999,
        deadline_ms: 5000,
        drain_grace_ms: 1000,
    };
    write_message(&mut conn, PROTO_MAX, &prepare).unwrap();
    let reply = read_message(&mut conn).unwrap().1;
    assert!(matches!(reply, Message::Drained { .. }));

    // A freshly generated id is assumed not to collide with the negotiated one.
    let stranger = HandoffId::new();
    let bogus_seal = Message::SealRequest {
        handoff_id: stranger,
    };
    write_message(&mut conn, PROTO_MAX, &bogus_seal).unwrap();
    let after_seal = read_message(&mut conn);
    assert!(after_seal.is_err());
}
/// Runs a full prepare/seal cycle with the negotiated id, then sends `Commit`
/// with a different `handoff_id`. The incumbent must close the session.
#[test]
fn commit_with_wrong_id_closes_session() {
    let (_temp, mut stream, handoff_id, _thread) = make_active_session();
    write_message(
        &mut stream,
        PROTO_MAX,
        &Message::PrepareHandoff {
            handoff_id,
            successor_pid: 9999,
            deadline_ms: 5000,
            drain_grace_ms: 1000,
        },
    )
    .unwrap();
    // Verify the prepare step, as the sibling tests do. Otherwise an unexpected
    // reply here would surface later as a misleading seal/commit failure.
    let (_v, drained) = read_message(&mut stream).unwrap();
    assert!(matches!(drained, Message::Drained { .. }));

    write_message(&mut stream, PROTO_MAX, &Message::SealRequest { handoff_id }).unwrap();
    let (_v, sealed) = read_message(&mut stream).unwrap();
    assert!(matches!(sealed, Message::SealComplete { .. }));

    // A freshly generated id is assumed not to collide with the negotiated one.
    let other_id = HandoffId::new();
    write_message(
        &mut stream,
        PROTO_MAX,
        &Message::Commit {
            handoff_id: other_id,
        },
    )
    .unwrap();
    assert!(read_message(&mut stream).is_err());
}
/// For each frame the incumbent should reject from the successor right after
/// the handshake, checks two things. First, sending it closes the current
/// session. Second, the same incumbent still accepts a new connection and
/// greets it with `Hello`, i.e. it keeps serving.
///
/// Each case gets its own incumbent from `make_active_session`, so one
/// rejected frame cannot affect the next case.
#[test]
fn wrong_direction_frames_close_session_but_keep_serving() {
    let illegal: Vec<Message> = vec![
        Message::Hello {
            role: Side::Successor,
            pid: 1,
            build_id: Vec::new(),
            proto_min: PROTO_MAX,
            proto_max: PROTO_MAX,
            capabilities: Default::default(),
        },
        Message::HelloAck {
            proto_version_chosen: PROTO_MAX,
            handoff_id: HandoffId::new(),
        },
        Message::Drained {
            open_conns_remaining: 0,
            accept_closed: true,
        },
        Message::SealComplete {
            handoff_id: HandoffId::new(),
            last_revision_per_shard: vec![1],
            data_dir_fingerprint: [0u8; 32],
        },
        Message::SealFailed {
            handoff_id: HandoffId::new(),
            error: "x".into(),
            partial_state: String::new(),
        },
        Message::SealProgress {
            shards_sealed: 1,
            shards_total: 2,
            last_revision: 0,
        },
        Message::Begin {
            handoff_id: HandoffId::new(),
        },
        Message::Ready {
            handoff_id: HandoffId::new(),
            listening_on: vec!["tcp".into()],
            healthz_ok: true,
            advertised_revision_per_shard: vec![1],
        },
    ];
    for (index, msg) in illegal.into_iter().enumerate() {
        // `Discriminant`'s Debug output is opaque, so the index is reported too.
        let label = std::mem::discriminant(&msg);
        let (temp, mut stream, _id, _thread) = make_active_session();
        write_message(&mut stream, PROTO_MAX, &msg).unwrap();
        assert!(
            read_message(&mut stream).is_err(),
            "session must close after illegal {label:?}, but read succeeded"
        );
        drop(stream);

        // The rejected session is gone. The same incumbent must still be
        // listening on its control socket and greet a new client.
        let mut probe = connect_with_retry(&temp.path().join("control.sock"));
        // Bound the wait so a stalled incumbent fails the test instead of hanging it.
        probe
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();
        let greeting = read_message(&mut probe);
        assert!(
            matches!(greeting, Ok((_, Message::Hello { .. }))),
            "incumbent stopped serving after illegal {label:?} (case #{index})"
        );
    }
}
/// Checks that three `Heartbeat` frames are each answered with a `Heartbeat`,
/// both right after the handshake and after the incumbent reports `Drained`.
/// It then checks that the handoff still reaches `SealComplete` and accepts
/// `Commit`.
#[test]
fn heartbeat_is_idempotent_in_any_state() {
    let (_temp, mut stream, handoff_id, _thread) = make_active_session();

    // State 1: freshly handshaken, no handoff in progress.
    for _ in 0..3 {
        write_message(&mut stream, PROTO_MAX, &Message::Heartbeat { ts_ms: 0 }).unwrap();
        let (_v, echoed) = read_message(&mut stream).unwrap();
        assert!(matches!(echoed, Message::Heartbeat { .. }));
    }

    write_message(
        &mut stream,
        PROTO_MAX,
        &Message::PrepareHandoff {
            handoff_id,
            successor_pid: 9999,
            deadline_ms: 5000,
            drain_grace_ms: 1000,
        },
    )
    .unwrap();
    // Confirm the session really entered the drained state. Otherwise the
    // heartbeats below would not be exercising that state at all.
    let (_v, drained) = read_message(&mut stream).unwrap();
    assert!(matches!(drained, Message::Drained { .. }));

    // State 2: drained, awaiting SealRequest.
    for _ in 0..3 {
        write_message(&mut stream, PROTO_MAX, &Message::Heartbeat { ts_ms: 1 }).unwrap();
        let (_v, echoed) = read_message(&mut stream).unwrap();
        assert!(matches!(echoed, Message::Heartbeat { .. }));
    }

    // The interleaved heartbeats must not have derailed the handoff.
    write_message(&mut stream, PROTO_MAX, &Message::SealRequest { handoff_id }).unwrap();
    let (_v, sealed) = read_message(&mut stream).unwrap();
    assert!(matches!(sealed, Message::SealComplete { .. }));
    write_message(&mut stream, PROTO_MAX, &Message::Commit { handoff_id }).unwrap();
    // NOTE(review): presumably gives the incumbent time to process Commit
    // before the temp dir and stream are dropped. Confirm the intent.
    thread::sleep(Duration::from_millis(50));
}