use crate::types::Rank;
/// Protocol message enum; every variant is (de)serialized with rkyv via the
/// `rkyv::Archive`/`Serialize`/`Deserialize` derives below.
///
/// Compatibility: the archived byte layout follows this definition, including
/// the order of variants and the order and types of fields inside each one.
/// Reordering, retyping, adding, or removing variants/fields can change the
/// encoding, and rkyv does not version the layout on its own, so both ends of a
/// connection need the same definition. NOTE(review): presumably
/// `Hello::protocol_version` is meant to guard this — confirm it is bumped
/// whenever this enum changes.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints every field verbatim,
/// including the `Welcome::node_key` bytes, which by name look like private
/// key material. Avoid `{:?}`-logging these messages, or consider a
/// redacting manual `Debug` impl.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Debug, Clone, PartialEq)]
pub enum NexarMessage {
/// Presumably the handshake a connecting node sends first — TODO confirm.
/// `capabilities` looks like a bitmask (the tests use `0xFF` and `0`);
/// `cluster_token` is opaque bytes, and this type performs no validation of it.
Hello {
protocol_version: u16,
capabilities: u64,
cluster_token: Vec<u8>,
listen_addr: String,
},
/// Presumably the reply to [`NexarMessage::Hello`], assigning the receiver its
/// `rank` out of `world_size`. `peers` pairs ranks with address strings (the
/// tests use `"127.0.0.1:5000"`-style values). The certificate/key fields are
/// raw bytes; their encoding (DER vs. PEM) is not visible here — TODO confirm.
Welcome {
rank: Rank,
world_size: u32,
peers: Vec<(Rank, String)>,
ca_cert: Vec<u8>,
node_cert: Vec<u8>,
node_key: Vec<u8>,
},
/// Barrier request identified by `epoch` and `comm_id` (presumably a
/// communicator id — confirm).
Barrier { epoch: u64, comm_id: u64 },
/// Presumably acknowledges the [`NexarMessage::Barrier`] with the same
/// `epoch`/`comm_id`.
BarrierAck { epoch: u64, comm_id: u64 },
/// Presumably a liveness ping. By its name `timestamp_ns` is in nanoseconds;
/// the clock it is read from is not visible here.
Heartbeat { timestamp_ns: u64 },
/// Presumably announces that `rank` joined, reachable at the string `addr`.
NodeJoined { rank: Rank, addr: String },
/// Presumably announces that `rank` left.
NodeLeft { rank: Rank },
/// Remote call. `fn_id` presumably selects the handler and `req_id` is echoed
/// back in the matching [`NexarMessage::RpcResponse`]; `payload` is opaque bytes.
Rpc {
req_id: u64,
fn_id: u16,
payload: Vec<u8>,
},
/// Result of a [`NexarMessage::Rpc`], presumably matched to it by `req_id`.
RpcResponse { req_id: u64, payload: Vec<u8> },
/// Tagged data sent by `src_rank`; `payload` is opaque bytes.
Data {
/// NOTE(review): `u32` here but `u64` in [`NexarMessage::Relay`]. If relayed
/// traffic carries `Data` tags, confirm the width mismatch is intended.
/// Changing either type alters the wire format.
tag: u32,
src_rank: Rank,
payload: Vec<u8>,
},
/// Field names match InfiniBand verbs queue-pair connection info (LID, QP
/// number, PSN, GID) — presumably exchanged to connect RDMA queue pairs; confirm.
RdmaEndpoint {
lid: u16,
qpn: u32,
psn: u32,
/// Variable-length here; the tests use 16 bytes.
gid: Vec<u8>,
},
/// Presumably a communicator split in the style of `MPI_Comm_split`
/// (`color` groups ranks, `key` orders them) — confirm.
SplitRequest { color: u32, key: u32 },
/// Presumably one node's proposal of which ranks are dead for recovery `epoch`.
RecoveryVote { epoch: u64, dead_ranks: Vec<Rank> },
/// Presumably the agreed-upon dead-rank set for recovery `epoch`.
RecoveryAgreement { epoch: u64, dead_ranks: Vec<Rank> },
/// Presumably starts an elastic membership change at `epoch` — confirm.
ElasticCheckpoint { epoch: u64 },
/// Presumably acknowledges [`NexarMessage::ElasticCheckpoint`]: `joining` pairs
/// new ranks with address strings, `leaving` lists departing ranks, and
/// `new_world_size` is the resulting size.
ElasticCheckpointAck {
epoch: u64,
joining: Vec<(Rank, String)>,
leaving: Vec<Rank>,
new_world_size: u32,
},
/// A payload from `src_rank` addressed to `final_dest`, presumably forwarded
/// through intermediate nodes. Note `tag` is `u64` here but `u32` in
/// [`NexarMessage::Data`].
Relay {
src_rank: Rank,
final_dest: Rank,
tag: u64,
payload: Vec<u8>,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of variants in `NexarMessage`; must equal the number of arms in
    /// `variant_index`.
    const VARIANT_COUNT: usize = 17;

    /// Serializes `msg` with rkyv and deserializes it back into an owned value.
    ///
    /// Panics (failing the calling test) with the offending message in the text
    /// if either step returns an error.
    fn roundtrip(msg: &NexarMessage) -> NexarMessage {
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(msg)
            .unwrap_or_else(|e| panic!("serialization failed for {msg:?}: {e}"));
        rkyv::from_bytes::<NexarMessage, rkyv::rancor::Error>(&bytes)
            .unwrap_or_else(|e| panic!("deserialization failed for {msg:?}: {e}"))
    }

    /// Maps each variant to a distinct index in `0..VARIANT_COUNT`.
    ///
    /// The match intentionally has no `_` arm, so adding a variant to
    /// `NexarMessage` fails to compile here. When you add the new arm, also bump
    /// `VARIANT_COUNT` and add a sample to `test_all_variants_roundtrip`; the
    /// coverage assertion there then fails until the sample exists.
    fn variant_index(msg: &NexarMessage) -> usize {
        match msg {
            NexarMessage::Hello { .. } => 0,
            NexarMessage::Welcome { .. } => 1,
            NexarMessage::Barrier { .. } => 2,
            NexarMessage::BarrierAck { .. } => 3,
            NexarMessage::Heartbeat { .. } => 4,
            NexarMessage::NodeJoined { .. } => 5,
            NexarMessage::NodeLeft { .. } => 6,
            NexarMessage::Rpc { .. } => 7,
            NexarMessage::RpcResponse { .. } => 8,
            NexarMessage::Data { .. } => 9,
            NexarMessage::RdmaEndpoint { .. } => 10,
            NexarMessage::SplitRequest { .. } => 11,
            NexarMessage::RecoveryVote { .. } => 12,
            NexarMessage::RecoveryAgreement { .. } => 13,
            NexarMessage::ElasticCheckpoint { .. } => 14,
            NexarMessage::ElasticCheckpointAck { .. } => 15,
            NexarMessage::Relay { .. } => 16,
        }
    }

    #[test]
    fn test_hello_roundtrip() {
        let msg = NexarMessage::Hello {
            protocol_version: 1,
            capabilities: 0xFF,
            cluster_token: vec![0xDE, 0xAD, 0xBE, 0xEF],
            listen_addr: "10.0.0.1:7000".into(),
        };
        assert_eq!(roundtrip(&msg), msg);
    }

    #[test]
    fn test_welcome_roundtrip() {
        let msg = NexarMessage::Welcome {
            rank: 3,
            world_size: 8,
            peers: vec![(0, "127.0.0.1:5000".into()), (1, "127.0.0.1:5001".into())],
            ca_cert: vec![1, 2, 3],
            node_cert: vec![4, 5, 6],
            node_key: vec![7, 8, 9],
        };
        assert_eq!(roundtrip(&msg), msg);
    }

    #[test]
    fn test_all_variants_roundtrip() {
        // Distinct, non-default values per field, so a dropped field or two
        // swapped same-typed fields show up as a mismatch instead of
        // round-tripping a default value.
        let messages = [
            NexarMessage::Hello {
                protocol_version: 1,
                capabilities: 0b1010,
                cluster_token: vec![0xAB, 0xCD],
                listen_addr: "10.0.0.1:7000".into(),
            },
            NexarMessage::Welcome {
                rank: 1,
                world_size: 2,
                peers: vec![(0, "10.0.0.1:7000".into())],
                ca_cert: vec![10],
                node_cert: vec![20],
                node_key: vec![30],
            },
            NexarMessage::Barrier {
                epoch: 42,
                comm_id: 7,
            },
            NexarMessage::BarrierAck {
                epoch: 43,
                comm_id: 8,
            },
            NexarMessage::Heartbeat {
                timestamp_ns: 123_456_789,
            },
            NexarMessage::NodeJoined {
                rank: 5,
                addr: "10.0.0.5:9000".into(),
            },
            NexarMessage::NodeLeft { rank: 2 },
            NexarMessage::Rpc {
                req_id: 1,
                fn_id: 100,
                payload: vec![1, 2, 3],
            },
            NexarMessage::RpcResponse {
                req_id: 2,
                payload: vec![4, 5, 6],
            },
            NexarMessage::Data {
                tag: 7,
                src_rank: 3,
                payload: vec![0xFF; 64],
            },
            NexarMessage::RdmaEndpoint {
                lid: 1,
                qpn: 42,
                psn: 100,
                gid: (0..16).collect(),
            },
            NexarMessage::SplitRequest { color: 3, key: 1 },
            NexarMessage::RecoveryVote {
                epoch: 1,
                dead_ranks: vec![2, 3],
            },
            NexarMessage::RecoveryAgreement {
                epoch: 2,
                dead_ranks: vec![3],
            },
            NexarMessage::ElasticCheckpoint { epoch: 5 },
            NexarMessage::ElasticCheckpointAck {
                epoch: 6,
                joining: vec![(4, "127.0.0.1:9000".into())],
                leaving: vec![2],
                new_world_size: 4,
            },
            NexarMessage::Relay {
                src_rank: 1,
                final_dest: 5,
                tag: 42,
                payload: vec![1, 2, 3, 4],
            },
        ];

        let mut covered = [false; VARIANT_COUNT];
        for msg in &messages {
            covered[variant_index(msg)] = true;
            assert_eq!(roundtrip(msg), *msg, "roundtrip failed for {msg:?}");
        }

        let missing: Vec<usize> = (0..VARIANT_COUNT).filter(|&i| !covered[i]).collect();
        assert!(
            missing.is_empty(),
            "no sample message for variant indices {missing:?} (see `variant_index`)"
        );
    }

    #[test]
    fn test_boundary_values_roundtrip() {
        // Extreme integers, empty and larger byte buffers, and non-ASCII text.
        let messages = [
            NexarMessage::Hello {
                protocol_version: u16::MAX,
                capabilities: u64::MAX,
                cluster_token: vec![],
                listen_addr: String::new(),
            },
            NexarMessage::Heartbeat {
                timestamp_ns: u64::MAX,
            },
            NexarMessage::Rpc {
                req_id: u64::MAX,
                fn_id: u16::MAX,
                payload: vec![0xA5; 4096],
            },
            NexarMessage::Data {
                tag: u32::MAX,
                src_rank: 0,
                payload: Vec::new(),
            },
            NexarMessage::NodeJoined {
                rank: 1,
                addr: "nöde-ü.example:65535".into(),
            },
        ];

        for msg in &messages {
            assert_eq!(roundtrip(msg), *msg, "roundtrip failed for {msg:?}");
        }
    }
}