Skip to main content

nodedb_cluster/
rpc_codec.rs

1//! Raft RPC binary codec.
2//!
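//! Encodes/decodes every cluster RPC — Raft consensus and cluster management —
//! into a compact binary wire format. Payloads are rkyv archives (except
//! `VShardEnvelope`, whose pre-serialized bytes pass through unchanged); on
//! decode they are copied into an aligned buffer and deserialized into owned
//! types. Every frame includes a CRC32C integrity checksum and a version field
//! for protocol evolution.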
6//!
//! Wire layout (10-byte header + payload):
8//!
9//! ```text
10//! ┌─────────┬──────────┬────────────┬──────────┬─────────────────────┐
11//! │ version │ rpc_type │ payload_len│ crc32c   │ rkyv payload bytes  │
12//! │  1 byte │  1 byte  │  4 bytes   │ 4 bytes  │  payload_len bytes  │
13//! └─────────┴──────────┴────────────┴──────────┴─────────────────────┘
14//! ```
15//!
16//! - `version`: Wire protocol version (currently `1`).
17//! - `rpc_type`: Discriminant for [`RaftRpc`] variant.
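//! - `payload_len`: Little-endian u32, byte count of the payload; frames that
//!   declare more than 64 MiB are rejected.
//! - `crc32c`: Little-endian u32, CRC32C over the payload bytes only. The header
//!   (including `rpc_type`) is not covered by the checksum.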
20
21use crate::error::{ClusterError, Result};
22use crate::wire::WIRE_VERSION;
23use nodedb_raft::message::{
24    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
25    RequestVoteRequest, RequestVoteResponse,
26};
27
/// Length of the fixed frame header, in bytes:
/// `version` (1) + `rpc_type` (1) + `payload_len` (4) + `crc32c` (4).
pub const HEADER_SIZE: usize = 1 + 1 + 4 + 4;
30
/// Upper bound on an RPC payload, in bytes (64 MiB).
///
/// Kept separate from the WAL's `MAX_WAL_PAYLOAD_SIZE`. The header's
/// `payload_len` is checked against this before anything is sized from it, so a
/// corrupt frame cannot trigger a degenerate allocation.
const MAX_RPC_PAYLOAD_SIZE: u32 = 64 << 20;
35
// ── RPC type discriminants ──────────────────────────────────────────
//
// Values of the `rpc_type` header byte. They are part of the wire protocol:
// peers map these bytes back to message types, so never renumber or reuse one.

/// Raft AppendEntries request.
const RPC_APPEND_ENTRIES_REQ: u8 = 1;
/// Raft AppendEntries response.
const RPC_APPEND_ENTRIES_RESP: u8 = 2;
/// Raft RequestVote request.
const RPC_REQUEST_VOTE_REQ: u8 = 3;
/// Raft RequestVote response.
const RPC_REQUEST_VOTE_RESP: u8 = 4;
/// Raft InstallSnapshot request.
const RPC_INSTALL_SNAPSHOT_REQ: u8 = 5;
/// Raft InstallSnapshot response.
const RPC_INSTALL_SNAPSHOT_RESP: u8 = 6;
/// Cluster join request.
const RPC_JOIN_REQ: u8 = 7;
/// Cluster join response.
const RPC_JOIN_RESP: u8 = 8;
/// Health-check ping.
const RPC_PING: u8 = 9;
/// Health-check pong.
const RPC_PONG: u8 = 10;
/// Topology push.
const RPC_TOPOLOGY_UPDATE: u8 = 11;
/// Topology push acknowledgement.
const RPC_TOPOLOGY_ACK: u8 = 12;
/// Forwarded SQL query.
const RPC_FORWARD_REQ: u8 = 13;
/// Result of a forwarded SQL query.
const RPC_FORWARD_RESP: u8 = 14;
/// Raw, pre-serialized VShardEnvelope.
const RPC_VSHARD_ENVELOPE: u8 = 15;
52
53// ── Cluster management wire types ───────────────────────────────────
54
/// Forward a SQL query to the leader node for a vShard.
///
/// Used when a client connects to a non-leader node. The receiving node
/// re-plans and executes the SQL locally against its Data Plane.
///
/// Sent as [`RaftRpc::ForwardRequest`]; the reply is a [`ForwardResponse`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ForwardRequest {
    /// The SQL statement to execute.
    pub sql: String,
    /// Tenant ID (authenticated on the originating node, trusted here).
    pub tenant_id: u32,
    /// Milliseconds remaining until the client's deadline.
    pub deadline_remaining_ms: u64,
    /// Distributed trace ID for observability.
    pub trace_id: u64,
}
70
/// Response to a forwarded SQL query.
///
/// Sent as [`RaftRpc::ForwardResponse`] in reply to a [`ForwardRequest`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ForwardResponse {
    /// True if the query succeeded.
    pub success: bool,
    /// Result payloads — one per result set produced by the query.
    /// Each payload is the raw bytes from the Data Plane response.
    pub payloads: Vec<Vec<u8>>,
    /// Non-empty if success=false.
    pub error_message: String,
}
82
/// Health check ping.
///
/// Sent as [`RaftRpc::Ping`]; the matching reply is a [`PongResponse`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct PingRequest {
    /// ID of the node sending the ping.
    pub sender_id: u64,
    /// Sender's current topology version — lets the responder detect staleness.
    pub topology_version: u64,
}
90
/// Health check pong.
///
/// Sent as [`RaftRpc::Pong`] in reply to a [`PingRequest`]. Archived with rkyv,
/// so changing the fields (order or types) changes the wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct PongResponse {
    /// ID of the node answering the ping.
    pub responder_id: u64,
    /// NOTE(review): presumably the responder's own topology version, mirroring
    /// `PingRequest::topology_version` — confirm against the ping handler.
    pub topology_version: u64,
}
97
/// Push topology update to a peer.
///
/// Sent as [`RaftRpc::TopologyUpdate`]; acknowledged with a [`TopologyAck`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TopologyUpdate {
    /// Version of the topology being pushed.
    pub version: u64,
    /// Nodes in this topology, in the same wire format used by [`JoinResponse`].
    pub nodes: Vec<JoinNodeInfo>,
}
104
/// Acknowledgement of a topology update.
///
/// Sent as [`RaftRpc::TopologyAck`] in reply to a [`TopologyUpdate`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TopologyAck {
    /// ID of the node sending the acknowledgement.
    pub responder_id: u64,
    /// Topology version the responder reports as accepted. NOTE(review): whether
    /// this can differ from the pushed `TopologyUpdate::version` (e.g. a newer
    /// local version) is not visible here — confirm with the handler.
    pub accepted_version: u64,
}
111
/// Request to join an existing cluster.
///
/// Sent as [`RaftRpc::JoinRequest`]; answered with a [`JoinResponse`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinRequest {
    /// ID of the node asking to join.
    pub node_id: u64,
    /// Listen address for Raft RPCs (e.g. "10.0.0.5:9400").
    pub listen_addr: String,
}
119
/// Response to a join request — carries full cluster state.
///
/// Sent as [`RaftRpc::JoinResponse`] in reply to a [`JoinRequest`].
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinResponse {
    /// Whether the join was accepted.
    pub success: bool,
    /// Error description; the codec tests send an empty string on success.
    pub error: String,
    /// All nodes in the cluster.
    pub nodes: Vec<JoinNodeInfo>,
    /// vShard → Raft group mapping (1024 entries).
    ///
    /// Presumably indexed by vShard ID, each value being a Raft group ID. The
    /// length is not validated by this codec.
    pub vshard_to_group: Vec<u64>,
    /// Raft group membership.
    pub groups: Vec<JoinGroupInfo>,
}
132
/// Node info in the join response wire format.
///
/// Also reused by [`TopologyUpdate`]. Archived with rkyv, so changing the
/// fields (order or types) changes the wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinNodeInfo {
    /// ID of the described node.
    pub node_id: u64,
    /// Network address of the node (e.g. "10.0.0.1:9400"); presumably its Raft
    /// RPC listen address, as in `JoinRequest::listen_addr` — confirm.
    pub addr: String,
    /// NodeState as u8 (0=Joining, 1=Active, 2=Draining, 3=Decommissioned).
    /// This codec does not validate the value.
    pub state: u8,
    /// IDs of the Raft groups associated with this node (presumably the groups
    /// it is a member of).
    pub raft_groups: Vec<u64>,
}
142
/// Raft group membership in the join response wire format.
///
/// Archived with rkyv, so changing the fields (order or types) changes the
/// wire format.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinGroupInfo {
    /// Raft group ID.
    pub group_id: u64,
    /// Node ID of the group's leader. NOTE(review): how "no known leader" is
    /// encoded is not shown in this file.
    pub leader: u64,
    /// Node IDs of the group's members.
    pub members: Vec<u64>,
}
150
151// ── RPC enum ────────────────────────────────────────────────────────
152
/// An RPC message — Raft consensus or cluster management.
///
/// Each variant is framed with its own `rpc_type` header byte (see the private
/// `RaftRpc::rpc_type`). Its payload is the rkyv archive of the wrapped type,
/// except for [`RaftRpc::VShardEnvelope`], whose bytes are sent verbatim.
#[derive(Debug, Clone)]
pub enum RaftRpc {
    // Raft consensus
    /// Raft AppendEntries (log replication; an empty `entries` list is a heartbeat).
    AppendEntriesRequest(AppendEntriesRequest),
    /// Reply to [`RaftRpc::AppendEntriesRequest`].
    AppendEntriesResponse(AppendEntriesResponse),
    /// Raft RequestVote (leader election).
    RequestVoteRequest(RequestVoteRequest),
    /// Reply to [`RaftRpc::RequestVoteRequest`].
    RequestVoteResponse(RequestVoteResponse),
    /// Raft InstallSnapshot; snapshots travel in chunks (`offset` / `done`).
    InstallSnapshotRequest(InstallSnapshotRequest),
    /// Reply to [`RaftRpc::InstallSnapshotRequest`].
    InstallSnapshotResponse(InstallSnapshotResponse),
    // Cluster management
    /// A node asking to join the cluster.
    JoinRequest(JoinRequest),
    /// Reply to [`RaftRpc::JoinRequest`], carrying full cluster state.
    JoinResponse(JoinResponse),
    // Health check
    /// Health-check ping.
    Ping(PingRequest),
    /// Reply to [`RaftRpc::Ping`].
    Pong(PongResponse),
    // Topology broadcast
    /// Topology pushed to a peer.
    TopologyUpdate(TopologyUpdate),
    /// Acknowledgement of a [`RaftRpc::TopologyUpdate`].
    TopologyAck(TopologyAck),
    // Query forwarding
    /// SQL query forwarded to a vShard leader.
    ForwardRequest(ForwardRequest),
    /// Result of a [`RaftRpc::ForwardRequest`].
    ForwardResponse(ForwardResponse),
    // VShardEnvelope — carries graph BSP, timeseries scatter-gather, migration,
    // retention, and archival messages. The inner VShardMessageType determines
    // the handler.
    /// Pre-serialized envelope bytes: written as the frame payload without an
    /// rkyv pass, and returned unchanged by decoding.
    VShardEnvelope(Vec<u8>), // Serialized VShardEnvelope bytes.
}
180
181impl RaftRpc {
182    fn rpc_type(&self) -> u8 {
183        match self {
184            Self::AppendEntriesRequest(_) => RPC_APPEND_ENTRIES_REQ,
185            Self::AppendEntriesResponse(_) => RPC_APPEND_ENTRIES_RESP,
186            Self::RequestVoteRequest(_) => RPC_REQUEST_VOTE_REQ,
187            Self::RequestVoteResponse(_) => RPC_REQUEST_VOTE_RESP,
188            Self::InstallSnapshotRequest(_) => RPC_INSTALL_SNAPSHOT_REQ,
189            Self::InstallSnapshotResponse(_) => RPC_INSTALL_SNAPSHOT_RESP,
190            Self::JoinRequest(_) => RPC_JOIN_REQ,
191            Self::JoinResponse(_) => RPC_JOIN_RESP,
192            Self::Ping(_) => RPC_PING,
193            Self::Pong(_) => RPC_PONG,
194            Self::TopologyUpdate(_) => RPC_TOPOLOGY_UPDATE,
195            Self::TopologyAck(_) => RPC_TOPOLOGY_ACK,
196            Self::ForwardRequest(_) => RPC_FORWARD_REQ,
197            Self::ForwardResponse(_) => RPC_FORWARD_RESP,
198            Self::VShardEnvelope(_) => RPC_VSHARD_ENVELOPE,
199        }
200    }
201}
202
203/// Encode a [`RaftRpc`] into a framed binary message.
204pub fn encode(rpc: &RaftRpc) -> Result<Vec<u8>> {
205    let payload = serialize_payload(rpc)?;
206    let payload_len: u32 = payload.len().try_into().map_err(|_| ClusterError::Codec {
207        detail: format!("payload too large: {} bytes", payload.len()),
208    })?;
209
210    let crc = crc32c::crc32c(&payload);
211
212    let mut frame = Vec::with_capacity(HEADER_SIZE + payload.len());
213    // Version field is 1 byte on the wire (see header diagram); narrowing cast is intentional.
214    frame.push(WIRE_VERSION as u8);
215    frame.push(rpc.rpc_type());
216    frame.extend_from_slice(&payload_len.to_le_bytes());
217    frame.extend_from_slice(&crc.to_le_bytes());
218    frame.extend_from_slice(&payload);
219
220    Ok(frame)
221}
222
223/// Decode a framed binary message into a [`RaftRpc`].
224pub fn decode(data: &[u8]) -> Result<RaftRpc> {
225    if data.len() < HEADER_SIZE {
226        return Err(ClusterError::Codec {
227            detail: format!("frame too short: {} bytes, need {HEADER_SIZE}", data.len()),
228        });
229    }
230
231    let version = data[0];
232    if version != WIRE_VERSION as u8 {
233        return Err(ClusterError::Codec {
234            detail: format!("unsupported wire version: {version}, expected {WIRE_VERSION}"),
235        });
236    }
237
238    let rpc_type = data[1];
239    let payload_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]);
240    let expected_crc = u32::from_le_bytes([data[6], data[7], data[8], data[9]]);
241
242    if payload_len > MAX_RPC_PAYLOAD_SIZE {
243        return Err(ClusterError::Codec {
244            detail: format!("payload length {payload_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"),
245        });
246    }
247
248    let expected_total = HEADER_SIZE + payload_len as usize;
249    if data.len() < expected_total {
250        return Err(ClusterError::Codec {
251            detail: format!(
252                "frame truncated: got {} bytes, expected {expected_total}",
253                data.len()
254            ),
255        });
256    }
257
258    let payload = &data[HEADER_SIZE..expected_total];
259
260    let actual_crc = crc32c::crc32c(payload);
261    if actual_crc != expected_crc {
262        return Err(ClusterError::Codec {
263            detail: format!(
264                "CRC32C mismatch: expected {expected_crc:#010x}, got {actual_crc:#010x}"
265            ),
266        });
267    }
268
269    deserialize_payload(rpc_type, payload)
270}
271
272/// Return the total frame size for a buffer that starts with a valid header.
273/// Useful for stream framing — read the header, then read the remaining payload.
274pub fn frame_size(header: &[u8; HEADER_SIZE]) -> Result<usize> {
275    let payload_len = u32::from_le_bytes([header[2], header[3], header[4], header[5]]);
276    if payload_len > MAX_RPC_PAYLOAD_SIZE {
277        return Err(ClusterError::Codec {
278            detail: format!("payload length {payload_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"),
279        });
280    }
281    Ok(HEADER_SIZE + payload_len as usize)
282}
283
284// ── Serialization helpers ───────────────────────────────────────────
285
286fn serialize_payload(rpc: &RaftRpc) -> Result<Vec<u8>> {
287    let bytes = match rpc {
288        RaftRpc::AppendEntriesRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
289        RaftRpc::AppendEntriesResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
290        RaftRpc::RequestVoteRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
291        RaftRpc::RequestVoteResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
292        RaftRpc::InstallSnapshotRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
293        RaftRpc::InstallSnapshotResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
294        RaftRpc::JoinRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
295        RaftRpc::JoinResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
296        RaftRpc::Ping(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
297        RaftRpc::Pong(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
298        RaftRpc::TopologyUpdate(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
299        RaftRpc::TopologyAck(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
300        RaftRpc::ForwardRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
301        RaftRpc::ForwardResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
302        RaftRpc::VShardEnvelope(bytes) => return Ok(bytes.clone()), // Already serialized.
303    };
304    bytes.map(|b| b.to_vec()).map_err(|e| ClusterError::Codec {
305        detail: format!("rkyv serialize failed: {e}"),
306    })
307}
308
309fn deserialize_payload(rpc_type: u8, payload: &[u8]) -> Result<RaftRpc> {
310    // rkyv requires aligned data for zero-copy access. Network-received slices
311    // are not guaranteed to be aligned, so copy into an AlignedVec first.
312    let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(payload.len());
313    aligned.extend_from_slice(payload);
314
315    match rpc_type {
316        RPC_APPEND_ENTRIES_REQ => {
317            let msg = rkyv::from_bytes::<AppendEntriesRequest, rkyv::rancor::Error>(&aligned)
318                .map_err(|e| ClusterError::Codec {
319                    detail: format!("rkyv deserialize AppendEntriesRequest: {e}"),
320                })?;
321            Ok(RaftRpc::AppendEntriesRequest(msg))
322        }
323        RPC_APPEND_ENTRIES_RESP => {
324            let msg = rkyv::from_bytes::<AppendEntriesResponse, rkyv::rancor::Error>(&aligned)
325                .map_err(|e| ClusterError::Codec {
326                    detail: format!("rkyv deserialize AppendEntriesResponse: {e}"),
327                })?;
328            Ok(RaftRpc::AppendEntriesResponse(msg))
329        }
330        RPC_REQUEST_VOTE_REQ => {
331            let msg = rkyv::from_bytes::<RequestVoteRequest, rkyv::rancor::Error>(&aligned)
332                .map_err(|e| ClusterError::Codec {
333                    detail: format!("rkyv deserialize RequestVoteRequest: {e}"),
334                })?;
335            Ok(RaftRpc::RequestVoteRequest(msg))
336        }
337        RPC_REQUEST_VOTE_RESP => {
338            let msg = rkyv::from_bytes::<RequestVoteResponse, rkyv::rancor::Error>(&aligned)
339                .map_err(|e| ClusterError::Codec {
340                    detail: format!("rkyv deserialize RequestVoteResponse: {e}"),
341                })?;
342            Ok(RaftRpc::RequestVoteResponse(msg))
343        }
344        RPC_INSTALL_SNAPSHOT_REQ => {
345            let msg = rkyv::from_bytes::<InstallSnapshotRequest, rkyv::rancor::Error>(&aligned)
346                .map_err(|e| ClusterError::Codec {
347                    detail: format!("rkyv deserialize InstallSnapshotRequest: {e}"),
348                })?;
349            Ok(RaftRpc::InstallSnapshotRequest(msg))
350        }
351        RPC_INSTALL_SNAPSHOT_RESP => {
352            let msg = rkyv::from_bytes::<InstallSnapshotResponse, rkyv::rancor::Error>(&aligned)
353                .map_err(|e| ClusterError::Codec {
354                    detail: format!("rkyv deserialize InstallSnapshotResponse: {e}"),
355                })?;
356            Ok(RaftRpc::InstallSnapshotResponse(msg))
357        }
358        RPC_JOIN_REQ => {
359            let msg =
360                rkyv::from_bytes::<JoinRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
361                    ClusterError::Codec {
362                        detail: format!("rkyv deserialize JoinRequest: {e}"),
363                    }
364                })?;
365            Ok(RaftRpc::JoinRequest(msg))
366        }
367        RPC_JOIN_RESP => {
368            let msg =
369                rkyv::from_bytes::<JoinResponse, rkyv::rancor::Error>(&aligned).map_err(|e| {
370                    ClusterError::Codec {
371                        detail: format!("rkyv deserialize JoinResponse: {e}"),
372                    }
373                })?;
374            Ok(RaftRpc::JoinResponse(msg))
375        }
376        RPC_PING => {
377            let msg =
378                rkyv::from_bytes::<PingRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
379                    ClusterError::Codec {
380                        detail: format!("rkyv deserialize PingRequest: {e}"),
381                    }
382                })?;
383            Ok(RaftRpc::Ping(msg))
384        }
385        RPC_PONG => {
386            let msg =
387                rkyv::from_bytes::<PongResponse, rkyv::rancor::Error>(&aligned).map_err(|e| {
388                    ClusterError::Codec {
389                        detail: format!("rkyv deserialize PongResponse: {e}"),
390                    }
391                })?;
392            Ok(RaftRpc::Pong(msg))
393        }
394        RPC_TOPOLOGY_UPDATE => {
395            let msg =
396                rkyv::from_bytes::<TopologyUpdate, rkyv::rancor::Error>(&aligned).map_err(|e| {
397                    ClusterError::Codec {
398                        detail: format!("rkyv deserialize TopologyUpdate: {e}"),
399                    }
400                })?;
401            Ok(RaftRpc::TopologyUpdate(msg))
402        }
403        RPC_TOPOLOGY_ACK => {
404            let msg =
405                rkyv::from_bytes::<TopologyAck, rkyv::rancor::Error>(&aligned).map_err(|e| {
406                    ClusterError::Codec {
407                        detail: format!("rkyv deserialize TopologyAck: {e}"),
408                    }
409                })?;
410            Ok(RaftRpc::TopologyAck(msg))
411        }
412        RPC_FORWARD_REQ => {
413            let msg =
414                rkyv::from_bytes::<ForwardRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
415                    ClusterError::Codec {
416                        detail: format!("rkyv deserialize ForwardRequest: {e}"),
417                    }
418                })?;
419            Ok(RaftRpc::ForwardRequest(msg))
420        }
421        RPC_FORWARD_RESP => {
422            let msg = rkyv::from_bytes::<ForwardResponse, rkyv::rancor::Error>(&aligned).map_err(
423                |e| ClusterError::Codec {
424                    detail: format!("rkyv deserialize ForwardResponse: {e}"),
425                },
426            )?;
427            Ok(RaftRpc::ForwardResponse(msg))
428        }
429        RPC_VSHARD_ENVELOPE => {
430            // VShardEnvelope is already in its own binary format — pass through raw.
431            Ok(RaftRpc::VShardEnvelope(payload.to_vec()))
432        }
433        _ => Err(ClusterError::Codec {
434            detail: format!("unknown rpc_type: {rpc_type}"),
435        }),
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_raft::message::LogEntry;

    #[test]
    fn roundtrip_append_entries_request() {
        let req = AppendEntriesRequest {
            term: 5,
            leader_id: 1,
            prev_log_index: 99,
            prev_log_term: 4,
            entries: vec![
                LogEntry {
                    term: 5,
                    index: 100,
                    data: b"put x=1".to_vec(),
                },
                LogEntry {
                    term: 5,
                    index: 101,
                    data: b"put y=2".to_vec(),
                },
            ],
            leader_commit: 98,
            group_id: 7,
        };

        let rpc = RaftRpc::AppendEntriesRequest(req.clone());
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesRequest(d) => {
                assert_eq!(d.term, req.term);
                assert_eq!(d.leader_id, req.leader_id);
                assert_eq!(d.prev_log_index, req.prev_log_index);
                assert_eq!(d.prev_log_term, req.prev_log_term);
                assert_eq!(d.entries.len(), 2);
                assert_eq!(d.entries[0].data, b"put x=1");
                assert_eq!(d.entries[1].data, b"put y=2");
                assert_eq!(d.leader_commit, req.leader_commit);
                assert_eq!(d.group_id, req.group_id);
            }
            other => panic!("expected AppendEntriesRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_append_entries_heartbeat() {
        let req = AppendEntriesRequest {
            term: 3,
            leader_id: 1,
            prev_log_index: 10,
            prev_log_term: 2,
            entries: vec![],
            leader_commit: 8,
            group_id: 0,
        };

        let rpc = RaftRpc::AppendEntriesRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesRequest(d) => {
                assert!(d.entries.is_empty());
                assert_eq!(d.term, 3);
            }
            other => panic!("expected heartbeat, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_append_entries_response() {
        let resp = AppendEntriesResponse {
            term: 5,
            success: true,
            last_log_index: 100,
        };

        let rpc = RaftRpc::AppendEntriesResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesResponse(d) => {
                assert_eq!(d.term, 5);
                assert!(d.success);
                assert_eq!(d.last_log_index, 100);
            }
            other => panic!("expected AppendEntriesResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_request_vote_request() {
        let req = RequestVoteRequest {
            term: 10,
            candidate_id: 3,
            last_log_index: 200,
            last_log_term: 9,
            group_id: 42,
        };

        let rpc = RaftRpc::RequestVoteRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::RequestVoteRequest(d) => {
                assert_eq!(d.term, 10);
                assert_eq!(d.candidate_id, 3);
                assert_eq!(d.last_log_index, 200);
                assert_eq!(d.last_log_term, 9);
                assert_eq!(d.group_id, 42);
            }
            other => panic!("expected RequestVoteRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_request_vote_response() {
        let resp = RequestVoteResponse {
            term: 10,
            vote_granted: true,
        };

        let rpc = RaftRpc::RequestVoteResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::RequestVoteResponse(d) => {
                assert_eq!(d.term, 10);
                assert!(d.vote_granted);
            }
            other => panic!("expected RequestVoteResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_request() {
        let data: Vec<u8> = [0xDE, 0xAD, 0xBE, 0xEF]
            .iter()
            .copied()
            .cycle()
            .take(1024)
            .collect();
        let req = InstallSnapshotRequest {
            term: 7,
            leader_id: 1,
            last_included_index: 500,
            last_included_term: 6,
            offset: 0,
            data: data.clone(),
            done: false,
            group_id: 3,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert_eq!(d.term, 7);
                assert_eq!(d.leader_id, 1);
                assert_eq!(d.last_included_index, 500);
                assert_eq!(d.last_included_term, 6);
                assert_eq!(d.offset, 0);
                assert_eq!(d.data, data);
                assert!(!d.done);
                assert_eq!(d.group_id, 3);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_final_chunk() {
        let req = InstallSnapshotRequest {
            term: 7,
            leader_id: 1,
            last_included_index: 500,
            last_included_term: 6,
            offset: 4096,
            data: vec![0xFF; 128],
            done: true,
            group_id: 3,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert!(d.done);
                assert_eq!(d.offset, 4096);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_response() {
        let resp = InstallSnapshotResponse { term: 7 };

        let rpc = RaftRpc::InstallSnapshotResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotResponse(d) => {
                assert_eq!(d.term, 7);
            }
            other => panic!("expected InstallSnapshotResponse, got {other:?}"),
        }
    }

    #[test]
    fn crc_corruption_detected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Flip a bit in the payload.
        if let Some(last) = encoded.last_mut() {
            *last ^= 0x01;
        }

        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("CRC32C mismatch"), "{err}");
    }

    #[test]
    fn version_mismatch_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Set version to 99.
        encoded[0] = 99;

        let err = decode(&encoded).unwrap_err();
        assert!(
            err.to_string().contains("unsupported wire version"),
            "{err}"
        );
    }

    #[test]
    fn truncated_frame_rejected() {
        let err = decode(&[1, 2, 3]).unwrap_err();
        assert!(err.to_string().contains("frame too short"), "{err}");
    }

    #[test]
    fn truncated_payload_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: true,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Header intact, but the payload is one byte short of `payload_len`.
        encoded.pop();

        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("frame truncated"), "{err}");
    }

    #[test]
    fn unknown_rpc_type_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Set rpc_type to 255.
        encoded[1] = 255;

        // The CRC covers only the payload, not the header, so rewriting the
        // rpc_type byte leaves the checksum valid. Decoding gets past the CRC
        // check and fails on the unknown discriminant instead.
        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("unknown rpc_type"), "{err}");
    }

    #[test]
    fn payload_too_large_rejected() {
        // Craft a header claiming a massive payload.
        let mut frame = vec![0u8; HEADER_SIZE];
        frame[0] = WIRE_VERSION as u8;
        frame[1] = RPC_APPEND_ENTRIES_REQ;
        let huge: u32 = MAX_RPC_PAYLOAD_SIZE + 1;
        frame[2..6].copy_from_slice(&huge.to_le_bytes());

        let err = decode(&frame).unwrap_err();
        assert!(err.to_string().contains("exceeds maximum"), "{err}");
    }

    #[test]
    fn frame_size_helper() {
        let rpc = RaftRpc::AppendEntriesResponse(AppendEntriesResponse {
            term: 1,
            success: true,
            last_log_index: 5,
        });
        let encoded = encode(&rpc).unwrap();

        let header: [u8; HEADER_SIZE] = encoded[..HEADER_SIZE].try_into().unwrap();
        let size = frame_size(&header).unwrap();
        assert_eq!(size, encoded.len());
    }

    #[test]
    fn header_fields_match_layout() {
        let rpc = RaftRpc::Ping(PingRequest {
            sender_id: 1,
            topology_version: 2,
        });
        let encoded = encode(&rpc).unwrap();
        let payload = &encoded[HEADER_SIZE..];

        assert_eq!(encoded[0], WIRE_VERSION as u8);
        assert_eq!(encoded[1], RPC_PING);
        assert_eq!(
            u32::from_le_bytes(encoded[2..6].try_into().unwrap()) as usize,
            payload.len()
        );
        assert_eq!(
            u32::from_le_bytes(encoded[6..10].try_into().unwrap()),
            crc32c::crc32c(payload)
        );
    }

    #[test]
    fn large_snapshot_roundtrip() {
        // 1 MiB snapshot chunk.
        let data = vec![0xAB; 1024 * 1024];
        let req = InstallSnapshotRequest {
            term: 100,
            leader_id: 5,
            last_included_index: 999_999,
            last_included_term: 99,
            offset: 0,
            data: data.clone(),
            done: false,
            group_id: 0,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert_eq!(d.data.len(), 1024 * 1024);
                assert_eq!(d.data, data);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_join_request() {
        let req = JoinRequest {
            node_id: 42,
            listen_addr: "10.0.0.5:9400".into(),
        };

        let rpc = RaftRpc::JoinRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::JoinRequest(d) => {
                assert_eq!(d.node_id, 42);
                assert_eq!(d.listen_addr, "10.0.0.5:9400");
            }
            other => panic!("expected JoinRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_join_response() {
        let resp = JoinResponse {
            success: true,
            error: String::new(),
            nodes: vec![
                JoinNodeInfo {
                    node_id: 1,
                    addr: "10.0.0.1:9400".into(),
                    state: 1,
                    raft_groups: vec![0, 1],
                },
                JoinNodeInfo {
                    node_id: 2,
                    addr: "10.0.0.2:9400".into(),
                    state: 1,
                    raft_groups: vec![0, 1],
                },
            ],
            vshard_to_group: (0..1024u64).map(|i| i % 4).collect(),
            groups: vec![JoinGroupInfo {
                group_id: 0,
                leader: 1,
                members: vec![1, 2],
            }],
        };

        let rpc = RaftRpc::JoinResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::JoinResponse(d) => {
                assert!(d.success);
                assert_eq!(d.nodes.len(), 2);
                assert_eq!(d.vshard_to_group.len(), 1024);
                assert_eq!(d.groups.len(), 1);
                assert_eq!(d.groups[0].leader, 1);
            }
            other => panic!("expected JoinResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_ping_pong() {
        let ping = RaftRpc::Ping(PingRequest {
            sender_id: 3,
            topology_version: 17,
        });
        match decode(&encode(&ping).unwrap()).unwrap() {
            RaftRpc::Ping(d) => {
                assert_eq!(d.sender_id, 3);
                assert_eq!(d.topology_version, 17);
            }
            other => panic!("expected Ping, got {other:?}"),
        }

        let pong = RaftRpc::Pong(PongResponse {
            responder_id: 4,
            topology_version: 18,
        });
        match decode(&encode(&pong).unwrap()).unwrap() {
            RaftRpc::Pong(d) => {
                assert_eq!(d.responder_id, 4);
                assert_eq!(d.topology_version, 18);
            }
            other => panic!("expected Pong, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_topology_update_and_ack() {
        let update = RaftRpc::TopologyUpdate(TopologyUpdate {
            version: 9,
            nodes: vec![JoinNodeInfo {
                node_id: 1,
                addr: "10.0.0.1:9400".into(),
                state: 2,
                raft_groups: vec![0, 3],
            }],
        });
        match decode(&encode(&update).unwrap()).unwrap() {
            RaftRpc::TopologyUpdate(d) => {
                assert_eq!(d.version, 9);
                assert_eq!(d.nodes.len(), 1);
                assert_eq!(d.nodes[0].node_id, 1);
                assert_eq!(d.nodes[0].addr, "10.0.0.1:9400");
                assert_eq!(d.nodes[0].state, 2);
                assert_eq!(d.nodes[0].raft_groups, vec![0, 3]);
            }
            other => panic!("expected TopologyUpdate, got {other:?}"),
        }

        let ack = RaftRpc::TopologyAck(TopologyAck {
            responder_id: 2,
            accepted_version: 9,
        });
        match decode(&encode(&ack).unwrap()).unwrap() {
            RaftRpc::TopologyAck(d) => {
                assert_eq!(d.responder_id, 2);
                assert_eq!(d.accepted_version, 9);
            }
            other => panic!("expected TopologyAck, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_forward_request_and_response() {
        let req = RaftRpc::ForwardRequest(ForwardRequest {
            sql: "SELECT 1".into(),
            tenant_id: 12,
            deadline_remaining_ms: 1500,
            trace_id: 0xABCD,
        });
        match decode(&encode(&req).unwrap()).unwrap() {
            RaftRpc::ForwardRequest(d) => {
                assert_eq!(d.sql, "SELECT 1");
                assert_eq!(d.tenant_id, 12);
                assert_eq!(d.deadline_remaining_ms, 1500);
                assert_eq!(d.trace_id, 0xABCD);
            }
            other => panic!("expected ForwardRequest, got {other:?}"),
        }

        let resp = RaftRpc::ForwardResponse(ForwardResponse {
            success: false,
            payloads: vec![b"row-1".to_vec(), Vec::new()],
            error_message: "permission denied".into(),
        });
        match decode(&encode(&resp).unwrap()).unwrap() {
            RaftRpc::ForwardResponse(d) => {
                assert!(!d.success);
                assert_eq!(d.payloads, vec![b"row-1".to_vec(), Vec::new()]);
                assert_eq!(d.error_message, "permission denied");
            }
            other => panic!("expected ForwardResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_vshard_envelope_is_passthrough() {
        let raw = vec![0x01, 0x02, 0x03, 0xFE, 0xFF];
        let encoded = encode(&RaftRpc::VShardEnvelope(raw.clone())).unwrap();

        // The payload is written verbatim after the header — no rkyv framing.
        assert_eq!(&encoded[HEADER_SIZE..], raw.as_slice());

        match decode(&encoded).unwrap() {
            RaftRpc::VShardEnvelope(d) => assert_eq!(d, raw),
            other => panic!("expected VShardEnvelope, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_empty_vshard_envelope() {
        let encoded = encode(&RaftRpc::VShardEnvelope(Vec::new())).unwrap();
        assert_eq!(encoded.len(), HEADER_SIZE);

        match decode(&encoded).unwrap() {
            RaftRpc::VShardEnvelope(d) => assert!(d.is_empty()),
            other => panic!("expected VShardEnvelope, got {other:?}"),
        }
    }
}