//! nexar 0.1.2
//!
//! Distributed runtime with QUIC transport, stream-multiplexed messaging,
//! and built-in collectives.
use crate::types::Rank;

/// Control messages exchanged between nexar nodes.
///
/// Tensor data does NOT flow through this enum. Bulk tensor transfers use
/// dedicated QUIC unidirectional streams with a minimal binary header
/// followed by raw bytes — avoiding rkyv overhead on large payloads.
// NOTE(review): the rkyv archived layout depends on variant and field order,
// so reordering/inserting variants or fields is a wire-format break for
// mixed-version clusters — append new variants at the end only.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Debug, Clone, PartialEq)]
pub enum NexarMessage {
    /// Initial handshake from worker to seed.
    Hello {
        /// Wire-protocol version spoken by the sender.
        protocol_version: u16,
        /// Capability flags — presumably a bitmask; exact bit semantics are
        /// defined by the handshake logic, not visible in this file.
        capabilities: u64,
        /// Pre-shared cluster token for bootstrap authentication.
        /// Empty = no authentication required.
        cluster_token: Vec<u8>,
        /// Mesh listener address advertised by this node.
        /// Empty string means seed should use `conn.remote_address()` as fallback.
        listen_addr: String,
    },

    /// Seed's response with rank assignment, peer list, and mTLS credentials.
    Welcome {
        rank: Rank,
        world_size: u32,
        /// `(rank, socket_addr_string)` for each peer.
        peers: Vec<(Rank, String)>,
        /// DER-encoded cluster CA certificate (trust anchor for mesh mTLS).
        ca_cert: Vec<u8>,
        /// DER-encoded leaf certificate for this node, signed by the cluster CA.
        node_cert: Vec<u8>,
        /// DER-encoded private key for this node's leaf certificate.
        node_key: Vec<u8>,
    },

    /// Barrier request: all ranks must reach this epoch before proceeding.
    Barrier { epoch: u64, comm_id: u64 },

    /// Barrier acknowledgement from coordinator.
    BarrierAck { epoch: u64, comm_id: u64 },

    /// Periodic heartbeat for failure detection.
    Heartbeat { timestamp_ns: u64 },

    /// Notification that a new node has joined the cluster.
    NodeJoined { rank: Rank, addr: String },

    /// Notification that a node has left (or been detected as failed).
    NodeLeft { rank: Rank },

    /// Remote procedure call request.
    Rpc {
        /// Caller-assigned id; echoed back in the matching `RpcResponse`.
        req_id: u64,
        /// Identifies which registered function to invoke on the remote side.
        fn_id: u16,
        /// Opaque argument bytes; encoding is up to the RPC layer (not shown here).
        payload: Vec<u8>,
    },

    /// Response to an RPC request.
    RpcResponse { req_id: u64, payload: Vec<u8> },

    /// P2P data envelope for tagged point-to-point messaging.
    Data {
        tag: u32,
        src_rank: Rank,
        payload: Vec<u8>,
    },

    /// RDMA endpoint exchange for mesh formation (feature = "rdma").
    /// Carries the IB QP endpoint info needed for RDMA handshake.
    RdmaEndpoint {
        /// Local ID (LID) of the HCA port.
        lid: u16,
        /// Queue Pair Number.
        qpn: u32,
        /// Packet Sequence Number.
        psn: u32,
        /// Global ID (GID) as 16 bytes (IPv6 format).
        gid: Vec<u8>,
    },

    /// Communicator split request: carries (color, key) for group formation.
    SplitRequest { color: u32, key: u32 },

    /// Recovery vote: a survivor sends its view of dead ranks to the leader.
    RecoveryVote { epoch: u64, dead_ranks: Vec<Rank> },

    /// Recovery agreement: the leader broadcasts the agreed dead set to all survivors.
    RecoveryAgreement { epoch: u64, dead_ranks: Vec<Rank> },

    /// Elastic checkpoint: this rank is at a safe point for resize.
    /// New nodes send this to rank 0 after connecting to all existing peers.
    /// Existing nodes send this when the training loop calls elastic_checkpoint().
    ElasticCheckpoint { epoch: u64 },

    /// Elastic checkpoint ack: sent by rank 0 when all ranks (old + new) checked in.
    /// Contains the full delta: who's joining, who's leaving, new world size.
    ElasticCheckpointAck {
        epoch: u64,
        joining: Vec<(Rank, String)>,
        leaving: Vec<Rank>,
        new_world_size: u32,
    },

    /// Relay message for sparse topologies: forwarded through intermediate hops.
    ///
    /// `tag == 0` means the payload is a serialized `NexarMessage` (control).
    /// `tag > 0` means the payload is raw collective data for that tag.
    Relay {
        /// Originating rank (preserved across hops).
        src_rank: Rank,
        /// Ultimate destination rank; intermediate hops forward toward it.
        final_dest: Rank,
        tag: u64,
        payload: Vec<u8>,
    },
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `msg` with rkyv and asserts it deserializes back to an
    /// equal value. Centralizes the roundtrip boilerplate (previously
    /// duplicated verbatim in each test) and gives every failure the
    /// `{msg:?}` diagnostic context.
    fn assert_roundtrip(msg: NexarMessage) {
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&msg).unwrap();
        let back: NexarMessage =
            rkyv::from_bytes::<NexarMessage, rkyv::rancor::Error>(&bytes).unwrap();
        assert_eq!(msg, back, "roundtrip failed for {msg:?}");
    }

    #[test]
    fn test_hello_roundtrip() {
        assert_roundtrip(NexarMessage::Hello {
            protocol_version: 1,
            capabilities: 0xFF,
            cluster_token: vec![],
            listen_addr: String::new(),
        });
    }

    #[test]
    fn test_welcome_roundtrip() {
        assert_roundtrip(NexarMessage::Welcome {
            rank: 3,
            world_size: 8,
            peers: vec![(0, "127.0.0.1:5000".into()), (1, "127.0.0.1:5001".into())],
            ca_cert: vec![1, 2, 3],
            node_cert: vec![4, 5, 6],
            node_key: vec![7, 8, 9],
        });
    }

    /// One instance of every `NexarMessage` variant must survive a
    /// serialize/deserialize roundtrip.
    #[test]
    fn test_all_variants_roundtrip() {
        let messages = vec![
            NexarMessage::Hello {
                protocol_version: 1,
                capabilities: 0,
                cluster_token: vec![],
                listen_addr: String::new(),
            },
            NexarMessage::Welcome {
                rank: 0,
                world_size: 1,
                peers: vec![],
                ca_cert: vec![10],
                node_cert: vec![20],
                node_key: vec![30],
            },
            NexarMessage::Barrier {
                epoch: 42,
                comm_id: 0,
            },
            NexarMessage::BarrierAck {
                epoch: 42,
                comm_id: 0,
            },
            NexarMessage::Heartbeat {
                timestamp_ns: 123456789,
            },
            NexarMessage::NodeJoined {
                rank: 5,
                addr: "10.0.0.5:9000".into(),
            },
            NexarMessage::NodeLeft { rank: 2 },
            NexarMessage::Rpc {
                req_id: 1,
                fn_id: 100,
                payload: vec![1, 2, 3],
            },
            NexarMessage::RpcResponse {
                req_id: 1,
                payload: vec![4, 5, 6],
            },
            NexarMessage::Data {
                tag: 7,
                src_rank: 0,
                payload: vec![0xFF; 64],
            },
            NexarMessage::RdmaEndpoint {
                lid: 1,
                qpn: 42,
                psn: 100,
                gid: vec![0; 16],
            },
            NexarMessage::SplitRequest { color: 0, key: 1 },
            NexarMessage::RecoveryVote {
                epoch: 1,
                dead_ranks: vec![2, 3],
            },
            NexarMessage::RecoveryAgreement {
                epoch: 1,
                dead_ranks: vec![2, 3],
            },
            NexarMessage::ElasticCheckpoint { epoch: 5 },
            NexarMessage::ElasticCheckpointAck {
                epoch: 5,
                joining: vec![(4, "127.0.0.1:9000".into())],
                leaving: vec![2],
                new_world_size: 4,
            },
            NexarMessage::Relay {
                src_rank: 0,
                final_dest: 5,
                tag: 42,
                payload: vec![1, 2, 3, 4],
            },
        ];

        for msg in messages {
            assert_roundtrip(msg);
        }
    }
}