1use crate::error::{ClusterError, Result};
22use crate::wire::WIRE_VERSION;
23use nodedb_raft::message::{
24 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
25 RequestVoteRequest, RequestVoteResponse,
26};
27
/// Length in bytes of the fixed frame header that precedes every payload:
/// `[version: u8][rpc_type: u8][payload_len: u32 LE][crc32c(payload): u32 LE]`.
pub const HEADER_SIZE: usize = 1 + 1 + 4 + 4;

/// Largest payload length (64 MiB) accepted in a frame header; checked by
/// `decode` and `frame_size`.
const MAX_RPC_PAYLOAD_SIZE: u32 = 1 << 26;

// One-byte wire tags stored at header offset 1. Values are part of the wire
// protocol: never renumber or reuse a tag.

// Raft log replication.
const RPC_APPEND_ENTRIES_REQ: u8 = 1;
const RPC_APPEND_ENTRIES_RESP: u8 = 2;
// Raft leader election.
const RPC_REQUEST_VOTE_REQ: u8 = 3;
const RPC_REQUEST_VOTE_RESP: u8 = 4;
// Raft snapshot transfer.
const RPC_INSTALL_SNAPSHOT_REQ: u8 = 5;
const RPC_INSTALL_SNAPSHOT_RESP: u8 = 6;
// Cluster membership.
const RPC_JOIN_REQ: u8 = 7;
const RPC_JOIN_RESP: u8 = 8;
// Liveness probes.
const RPC_PING: u8 = 9;
const RPC_PONG: u8 = 10;
// Topology propagation.
const RPC_TOPOLOGY_UPDATE: u8 = 11;
const RPC_TOPOLOGY_ACK: u8 = 12;
// Request forwarding.
const RPC_FORWARD_REQ: u8 = 13;
const RPC_FORWARD_RESP: u8 = 14;
// Opaque, pre-encoded payload carried verbatim.
const RPC_VSHARD_ENVELOPE: u8 = 15;
52
/// Payload of an `RPC_FORWARD_REQ` frame: a SQL string plus request metadata.
///
/// NOTE(review): presumably sent to another node that executes `sql` on the
/// sender's behalf — confirm against the forward handler.
///
/// Field order determines the rkyv archive layout on the wire; reordering or
/// inserting fields changes the encoding, so coordinate with `WIRE_VERSION`.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ForwardRequest {
    /// SQL text carried by the request.
    pub sql: String,
    /// Tenant identifier associated with the statement.
    pub tenant_id: u32,
    /// Remaining time budget for the request; the `_ms` suffix suggests
    /// milliseconds — TODO confirm with the sender.
    pub deadline_remaining_ms: u64,
    /// Trace identifier; presumably propagated for request tracing — confirm.
    pub trace_id: u64,
}
70
/// Payload of an `RPC_FORWARD_RESP` frame; presumably the reply to a
/// `ForwardRequest` — confirm against the forward handler.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct ForwardResponse {
    /// Whether the forwarded request succeeded.
    pub success: bool,
    /// Result payloads as opaque byte buffers; their encoding is not defined
    /// in this module.
    pub payloads: Vec<Vec<u8>>,
    /// Error description; NOTE(review): presumably empty when `success` is
    /// true — confirm.
    pub error_message: String,
}
82
/// Payload of an `RPC_PING` frame.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct PingRequest {
    /// Node id of the sender.
    pub sender_id: u64,
    /// Topology version as known to the sender (assumed from the name —
    /// confirm against the ping handler).
    pub topology_version: u64,
}
90
/// Payload of an `RPC_PONG` frame.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct PongResponse {
    /// Node id of the responder.
    pub responder_id: u64,
    /// Topology version as known to the responder (assumed from the name).
    pub topology_version: u64,
}
97
/// Payload of an `RPC_TOPOLOGY_UPDATE` frame: a versioned list of nodes.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TopologyUpdate {
    /// Version of the topology described by `nodes`.
    pub version: u64,
    /// Node descriptors (same shape as the node list in `JoinResponse`).
    pub nodes: Vec<JoinNodeInfo>,
}
104
/// Payload of an `RPC_TOPOLOGY_ACK` frame.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TopologyAck {
    /// Node id of the acknowledging node.
    pub responder_id: u64,
    /// Topology version being acknowledged; presumably the `version` of a
    /// received `TopologyUpdate` — confirm.
    pub accepted_version: u64,
}
111
/// Payload of an `RPC_JOIN_REQ` frame.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinRequest {
    /// Node id of the joining node.
    pub node_id: u64,
    /// Address the joining node listens on, as a string (the tests use the
    /// `host:port` form, e.g. `10.0.0.5:9400`).
    pub listen_addr: String,
}
119
/// Payload of an `RPC_JOIN_RESP` frame.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinResponse {
    /// Whether the join was accepted.
    pub success: bool,
    /// Error description (the tests use an empty string on success).
    pub error: String,
    /// Node descriptors for the cluster.
    pub nodes: Vec<JoinNodeInfo>,
    /// NOTE(review): presumably indexed by vShard id, each entry being a Raft
    /// group id — confirm against the producer.
    pub vshard_to_group: Vec<u64>,
    /// Raft group descriptors.
    pub groups: Vec<JoinGroupInfo>,
}
132
/// Descriptor of one cluster node, used by `JoinResponse` and `TopologyUpdate`.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinNodeInfo {
    /// Node id.
    pub node_id: u64,
    /// Node address as a string (the tests use `host:port`).
    pub addr: String,
    /// Encoded node state; the meaning of each value is defined outside this
    /// module — TODO link the state enum.
    pub state: u8,
    /// Raft group ids associated with the node (presumably the groups it is a
    /// member of — confirm).
    pub raft_groups: Vec<u64>,
}
142
/// Descriptor of one Raft group, used by `JoinResponse`.
///
/// Field order determines the rkyv archive layout on the wire.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct JoinGroupInfo {
    /// Raft group id.
    pub group_id: u64,
    /// Node id of the group's leader, presumably as known to the sender.
    pub leader: u64,
    /// Node ids of the group's members.
    pub members: Vec<u64>,
}
150
/// Every message that can be carried in a cluster transport frame.
///
/// Each variant maps 1:1 to an `RPC_*` wire tag (see `RaftRpc::rpc_type` for
/// encoding and `deserialize_payload` for decoding). All variants except
/// `VShardEnvelope` are rkyv-archived on the wire.
#[derive(Debug, Clone)]
pub enum RaftRpc {
    AppendEntriesRequest(AppendEntriesRequest),
    AppendEntriesResponse(AppendEntriesResponse),
    RequestVoteRequest(RequestVoteRequest),
    RequestVoteResponse(RequestVoteResponse),
    InstallSnapshotRequest(InstallSnapshotRequest),
    InstallSnapshotResponse(InstallSnapshotResponse),
    JoinRequest(JoinRequest),
    JoinResponse(JoinResponse),
    Ping(PingRequest),
    Pong(PongResponse),
    TopologyUpdate(TopologyUpdate),
    TopologyAck(TopologyAck),
    ForwardRequest(ForwardRequest),
    ForwardResponse(ForwardResponse),
    /// Pre-encoded bytes carried verbatim as the frame payload; the codec
    /// neither validates nor interprets them.
    VShardEnvelope(Vec<u8>),
}
180
181impl RaftRpc {
182 fn rpc_type(&self) -> u8 {
183 match self {
184 Self::AppendEntriesRequest(_) => RPC_APPEND_ENTRIES_REQ,
185 Self::AppendEntriesResponse(_) => RPC_APPEND_ENTRIES_RESP,
186 Self::RequestVoteRequest(_) => RPC_REQUEST_VOTE_REQ,
187 Self::RequestVoteResponse(_) => RPC_REQUEST_VOTE_RESP,
188 Self::InstallSnapshotRequest(_) => RPC_INSTALL_SNAPSHOT_REQ,
189 Self::InstallSnapshotResponse(_) => RPC_INSTALL_SNAPSHOT_RESP,
190 Self::JoinRequest(_) => RPC_JOIN_REQ,
191 Self::JoinResponse(_) => RPC_JOIN_RESP,
192 Self::Ping(_) => RPC_PING,
193 Self::Pong(_) => RPC_PONG,
194 Self::TopologyUpdate(_) => RPC_TOPOLOGY_UPDATE,
195 Self::TopologyAck(_) => RPC_TOPOLOGY_ACK,
196 Self::ForwardRequest(_) => RPC_FORWARD_REQ,
197 Self::ForwardResponse(_) => RPC_FORWARD_RESP,
198 Self::VShardEnvelope(_) => RPC_VSHARD_ENVELOPE,
199 }
200 }
201}
202
203pub fn encode(rpc: &RaftRpc) -> Result<Vec<u8>> {
205 let payload = serialize_payload(rpc)?;
206 let payload_len: u32 = payload.len().try_into().map_err(|_| ClusterError::Codec {
207 detail: format!("payload too large: {} bytes", payload.len()),
208 })?;
209
210 let crc = crc32c::crc32c(&payload);
211
212 let mut frame = Vec::with_capacity(HEADER_SIZE + payload.len());
213 frame.push(WIRE_VERSION as u8);
215 frame.push(rpc.rpc_type());
216 frame.extend_from_slice(&payload_len.to_le_bytes());
217 frame.extend_from_slice(&crc.to_le_bytes());
218 frame.extend_from_slice(&payload);
219
220 Ok(frame)
221}
222
223pub fn decode(data: &[u8]) -> Result<RaftRpc> {
225 if data.len() < HEADER_SIZE {
226 return Err(ClusterError::Codec {
227 detail: format!("frame too short: {} bytes, need {HEADER_SIZE}", data.len()),
228 });
229 }
230
231 let version = data[0];
232 if version != WIRE_VERSION as u8 {
233 return Err(ClusterError::Codec {
234 detail: format!("unsupported wire version: {version}, expected {WIRE_VERSION}"),
235 });
236 }
237
238 let rpc_type = data[1];
239 let payload_len = u32::from_le_bytes([data[2], data[3], data[4], data[5]]);
240 let expected_crc = u32::from_le_bytes([data[6], data[7], data[8], data[9]]);
241
242 if payload_len > MAX_RPC_PAYLOAD_SIZE {
243 return Err(ClusterError::Codec {
244 detail: format!("payload length {payload_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"),
245 });
246 }
247
248 let expected_total = HEADER_SIZE + payload_len as usize;
249 if data.len() < expected_total {
250 return Err(ClusterError::Codec {
251 detail: format!(
252 "frame truncated: got {} bytes, expected {expected_total}",
253 data.len()
254 ),
255 });
256 }
257
258 let payload = &data[HEADER_SIZE..expected_total];
259
260 let actual_crc = crc32c::crc32c(payload);
261 if actual_crc != expected_crc {
262 return Err(ClusterError::Codec {
263 detail: format!(
264 "CRC32C mismatch: expected {expected_crc:#010x}, got {actual_crc:#010x}"
265 ),
266 });
267 }
268
269 deserialize_payload(rpc_type, payload)
270}
271
272pub fn frame_size(header: &[u8; HEADER_SIZE]) -> Result<usize> {
275 let payload_len = u32::from_le_bytes([header[2], header[3], header[4], header[5]]);
276 if payload_len > MAX_RPC_PAYLOAD_SIZE {
277 return Err(ClusterError::Codec {
278 detail: format!("payload length {payload_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"),
279 });
280 }
281 Ok(HEADER_SIZE + payload_len as usize)
282}
283
284fn serialize_payload(rpc: &RaftRpc) -> Result<Vec<u8>> {
287 let bytes = match rpc {
288 RaftRpc::AppendEntriesRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
289 RaftRpc::AppendEntriesResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
290 RaftRpc::RequestVoteRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
291 RaftRpc::RequestVoteResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
292 RaftRpc::InstallSnapshotRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
293 RaftRpc::InstallSnapshotResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
294 RaftRpc::JoinRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
295 RaftRpc::JoinResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
296 RaftRpc::Ping(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
297 RaftRpc::Pong(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
298 RaftRpc::TopologyUpdate(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
299 RaftRpc::TopologyAck(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
300 RaftRpc::ForwardRequest(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
301 RaftRpc::ForwardResponse(msg) => rkyv::to_bytes::<rkyv::rancor::Error>(msg),
302 RaftRpc::VShardEnvelope(bytes) => return Ok(bytes.clone()), };
304 bytes.map(|b| b.to_vec()).map_err(|e| ClusterError::Codec {
305 detail: format!("rkyv serialize failed: {e}"),
306 })
307}
308
309fn deserialize_payload(rpc_type: u8, payload: &[u8]) -> Result<RaftRpc> {
310 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(payload.len());
313 aligned.extend_from_slice(payload);
314
315 match rpc_type {
316 RPC_APPEND_ENTRIES_REQ => {
317 let msg = rkyv::from_bytes::<AppendEntriesRequest, rkyv::rancor::Error>(&aligned)
318 .map_err(|e| ClusterError::Codec {
319 detail: format!("rkyv deserialize AppendEntriesRequest: {e}"),
320 })?;
321 Ok(RaftRpc::AppendEntriesRequest(msg))
322 }
323 RPC_APPEND_ENTRIES_RESP => {
324 let msg = rkyv::from_bytes::<AppendEntriesResponse, rkyv::rancor::Error>(&aligned)
325 .map_err(|e| ClusterError::Codec {
326 detail: format!("rkyv deserialize AppendEntriesResponse: {e}"),
327 })?;
328 Ok(RaftRpc::AppendEntriesResponse(msg))
329 }
330 RPC_REQUEST_VOTE_REQ => {
331 let msg = rkyv::from_bytes::<RequestVoteRequest, rkyv::rancor::Error>(&aligned)
332 .map_err(|e| ClusterError::Codec {
333 detail: format!("rkyv deserialize RequestVoteRequest: {e}"),
334 })?;
335 Ok(RaftRpc::RequestVoteRequest(msg))
336 }
337 RPC_REQUEST_VOTE_RESP => {
338 let msg = rkyv::from_bytes::<RequestVoteResponse, rkyv::rancor::Error>(&aligned)
339 .map_err(|e| ClusterError::Codec {
340 detail: format!("rkyv deserialize RequestVoteResponse: {e}"),
341 })?;
342 Ok(RaftRpc::RequestVoteResponse(msg))
343 }
344 RPC_INSTALL_SNAPSHOT_REQ => {
345 let msg = rkyv::from_bytes::<InstallSnapshotRequest, rkyv::rancor::Error>(&aligned)
346 .map_err(|e| ClusterError::Codec {
347 detail: format!("rkyv deserialize InstallSnapshotRequest: {e}"),
348 })?;
349 Ok(RaftRpc::InstallSnapshotRequest(msg))
350 }
351 RPC_INSTALL_SNAPSHOT_RESP => {
352 let msg = rkyv::from_bytes::<InstallSnapshotResponse, rkyv::rancor::Error>(&aligned)
353 .map_err(|e| ClusterError::Codec {
354 detail: format!("rkyv deserialize InstallSnapshotResponse: {e}"),
355 })?;
356 Ok(RaftRpc::InstallSnapshotResponse(msg))
357 }
358 RPC_JOIN_REQ => {
359 let msg =
360 rkyv::from_bytes::<JoinRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
361 ClusterError::Codec {
362 detail: format!("rkyv deserialize JoinRequest: {e}"),
363 }
364 })?;
365 Ok(RaftRpc::JoinRequest(msg))
366 }
367 RPC_JOIN_RESP => {
368 let msg =
369 rkyv::from_bytes::<JoinResponse, rkyv::rancor::Error>(&aligned).map_err(|e| {
370 ClusterError::Codec {
371 detail: format!("rkyv deserialize JoinResponse: {e}"),
372 }
373 })?;
374 Ok(RaftRpc::JoinResponse(msg))
375 }
376 RPC_PING => {
377 let msg =
378 rkyv::from_bytes::<PingRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
379 ClusterError::Codec {
380 detail: format!("rkyv deserialize PingRequest: {e}"),
381 }
382 })?;
383 Ok(RaftRpc::Ping(msg))
384 }
385 RPC_PONG => {
386 let msg =
387 rkyv::from_bytes::<PongResponse, rkyv::rancor::Error>(&aligned).map_err(|e| {
388 ClusterError::Codec {
389 detail: format!("rkyv deserialize PongResponse: {e}"),
390 }
391 })?;
392 Ok(RaftRpc::Pong(msg))
393 }
394 RPC_TOPOLOGY_UPDATE => {
395 let msg =
396 rkyv::from_bytes::<TopologyUpdate, rkyv::rancor::Error>(&aligned).map_err(|e| {
397 ClusterError::Codec {
398 detail: format!("rkyv deserialize TopologyUpdate: {e}"),
399 }
400 })?;
401 Ok(RaftRpc::TopologyUpdate(msg))
402 }
403 RPC_TOPOLOGY_ACK => {
404 let msg =
405 rkyv::from_bytes::<TopologyAck, rkyv::rancor::Error>(&aligned).map_err(|e| {
406 ClusterError::Codec {
407 detail: format!("rkyv deserialize TopologyAck: {e}"),
408 }
409 })?;
410 Ok(RaftRpc::TopologyAck(msg))
411 }
412 RPC_FORWARD_REQ => {
413 let msg =
414 rkyv::from_bytes::<ForwardRequest, rkyv::rancor::Error>(&aligned).map_err(|e| {
415 ClusterError::Codec {
416 detail: format!("rkyv deserialize ForwardRequest: {e}"),
417 }
418 })?;
419 Ok(RaftRpc::ForwardRequest(msg))
420 }
421 RPC_FORWARD_RESP => {
422 let msg = rkyv::from_bytes::<ForwardResponse, rkyv::rancor::Error>(&aligned).map_err(
423 |e| ClusterError::Codec {
424 detail: format!("rkyv deserialize ForwardResponse: {e}"),
425 },
426 )?;
427 Ok(RaftRpc::ForwardResponse(msg))
428 }
429 RPC_VSHARD_ENVELOPE => {
430 Ok(RaftRpc::VShardEnvelope(payload.to_vec()))
432 }
433 _ => Err(ClusterError::Codec {
434 detail: format!("unknown rpc_type: {rpc_type}"),
435 }),
436 }
437}
438
#[cfg(test)]
mod tests {
    //! Round-trip and rejection tests for the frame codec.

    use super::*;
    use nodedb_raft::message::LogEntry;

    /// Encodes `rpc` and immediately decodes the resulting frame.
    fn roundtrip(rpc: &RaftRpc) -> RaftRpc {
        let encoded = encode(rpc).unwrap();
        decode(&encoded).unwrap()
    }

    #[test]
    fn roundtrip_append_entries_request() {
        let req = AppendEntriesRequest {
            term: 5,
            leader_id: 1,
            prev_log_index: 99,
            prev_log_term: 4,
            entries: vec![
                LogEntry {
                    term: 5,
                    index: 100,
                    data: b"put x=1".to_vec(),
                },
                LogEntry {
                    term: 5,
                    index: 101,
                    data: b"put y=2".to_vec(),
                },
            ],
            leader_commit: 98,
            group_id: 7,
        };

        let rpc = RaftRpc::AppendEntriesRequest(req.clone());
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesRequest(d) => {
                assert_eq!(d.term, req.term);
                assert_eq!(d.leader_id, req.leader_id);
                assert_eq!(d.prev_log_index, req.prev_log_index);
                assert_eq!(d.prev_log_term, req.prev_log_term);
                assert_eq!(d.entries.len(), 2);
                assert_eq!(d.entries[0].data, b"put x=1");
                assert_eq!(d.entries[1].data, b"put y=2");
                assert_eq!(d.leader_commit, req.leader_commit);
                assert_eq!(d.group_id, req.group_id);
            }
            other => panic!("expected AppendEntriesRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_append_entries_heartbeat() {
        let req = AppendEntriesRequest {
            term: 3,
            leader_id: 1,
            prev_log_index: 10,
            prev_log_term: 2,
            entries: vec![],
            leader_commit: 8,
            group_id: 0,
        };

        let rpc = RaftRpc::AppendEntriesRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesRequest(d) => {
                assert!(d.entries.is_empty());
                assert_eq!(d.term, 3);
            }
            other => panic!("expected heartbeat, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_append_entries_response() {
        let resp = AppendEntriesResponse {
            term: 5,
            success: true,
            last_log_index: 100,
        };

        let rpc = RaftRpc::AppendEntriesResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::AppendEntriesResponse(d) => {
                assert_eq!(d.term, 5);
                assert!(d.success);
                assert_eq!(d.last_log_index, 100);
            }
            other => panic!("expected AppendEntriesResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_request_vote_request() {
        let req = RequestVoteRequest {
            term: 10,
            candidate_id: 3,
            last_log_index: 200,
            last_log_term: 9,
            group_id: 42,
        };

        let rpc = RaftRpc::RequestVoteRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::RequestVoteRequest(d) => {
                assert_eq!(d.term, 10);
                assert_eq!(d.candidate_id, 3);
                assert_eq!(d.last_log_index, 200);
                assert_eq!(d.last_log_term, 9);
                assert_eq!(d.group_id, 42);
            }
            other => panic!("expected RequestVoteRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_request_vote_response() {
        let resp = RequestVoteResponse {
            term: 10,
            vote_granted: true,
        };

        let rpc = RaftRpc::RequestVoteResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::RequestVoteResponse(d) => {
                assert_eq!(d.term, 10);
                assert!(d.vote_granted);
            }
            other => panic!("expected RequestVoteResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_request() {
        let data: Vec<u8> = [0xDE, 0xAD, 0xBE, 0xEF]
            .iter()
            .copied()
            .cycle()
            .take(1024)
            .collect();
        let req = InstallSnapshotRequest {
            term: 7,
            leader_id: 1,
            last_included_index: 500,
            last_included_term: 6,
            offset: 0,
            data: data.clone(),
            done: false,
            group_id: 3,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert_eq!(d.term, 7);
                assert_eq!(d.leader_id, 1);
                assert_eq!(d.last_included_index, 500);
                assert_eq!(d.last_included_term, 6);
                assert_eq!(d.offset, 0);
                assert_eq!(d.data, data);
                assert!(!d.done);
                assert_eq!(d.group_id, 3);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_final_chunk() {
        let req = InstallSnapshotRequest {
            term: 7,
            leader_id: 1,
            last_included_index: 500,
            last_included_term: 6,
            offset: 4096,
            data: vec![0xFF; 128],
            done: true,
            group_id: 3,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert!(d.done);
                assert_eq!(d.offset, 4096);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_install_snapshot_response() {
        let resp = InstallSnapshotResponse { term: 7 };

        let rpc = RaftRpc::InstallSnapshotResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotResponse(d) => {
                assert_eq!(d.term, 7);
            }
            other => panic!("expected InstallSnapshotResponse, got {other:?}"),
        }
    }

    #[test]
    fn crc_corruption_detected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Flip one bit in the last payload byte.
        if let Some(last) = encoded.last_mut() {
            *last ^= 0x01;
        }

        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("CRC32C mismatch"), "{err}");
    }

    #[test]
    fn version_mismatch_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        encoded[0] = 99;

        let err = decode(&encoded).unwrap_err();
        assert!(
            err.to_string().contains("unsupported wire version"),
            "{err}"
        );
    }

    #[test]
    fn truncated_frame_rejected() {
        let err = decode(&[1, 2, 3]).unwrap_err();
        assert!(err.to_string().contains("frame too short"), "{err}");
    }

    #[test]
    fn truncated_payload_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // Header is intact but the payload is one byte short of its declared length.
        encoded.pop();

        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("frame truncated"), "{err}");
    }

    #[test]
    fn unknown_rpc_type_rejected() {
        let rpc = RaftRpc::RequestVoteResponse(RequestVoteResponse {
            term: 1,
            vote_granted: false,
        });
        let mut encoded = encode(&rpc).unwrap();

        // The tag is not covered by the CRC, so only the tag lookup can fail.
        encoded[1] = 255;

        let err = decode(&encoded).unwrap_err();
        assert!(err.to_string().contains("unknown rpc_type"), "{err}");
    }

    #[test]
    fn payload_too_large_rejected() {
        let mut frame = vec![0u8; HEADER_SIZE];
        frame[0] = WIRE_VERSION as u8;
        frame[1] = RPC_APPEND_ENTRIES_REQ;
        let huge: u32 = MAX_RPC_PAYLOAD_SIZE + 1;
        frame[2..6].copy_from_slice(&huge.to_le_bytes());

        let err = decode(&frame).unwrap_err();
        assert!(err.to_string().contains("exceeds maximum"), "{err}");
    }

    #[test]
    fn frame_size_helper() {
        let rpc = RaftRpc::AppendEntriesResponse(AppendEntriesResponse {
            term: 1,
            success: true,
            last_log_index: 5,
        });
        let encoded = encode(&rpc).unwrap();

        let header: [u8; HEADER_SIZE] = encoded[..HEADER_SIZE].try_into().unwrap();
        let size = frame_size(&header).unwrap();
        assert_eq!(size, encoded.len());
    }

    #[test]
    fn large_snapshot_roundtrip() {
        let data = vec![0xAB; 1024 * 1024];
        let req = InstallSnapshotRequest {
            term: 100,
            leader_id: 5,
            last_included_index: 999_999,
            last_included_term: 99,
            offset: 0,
            data: data.clone(),
            done: false,
            group_id: 0,
        };

        let rpc = RaftRpc::InstallSnapshotRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::InstallSnapshotRequest(d) => {
                assert_eq!(d.data.len(), 1024 * 1024);
                assert_eq!(d.data, data);
            }
            other => panic!("expected InstallSnapshotRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_join_request() {
        let req = JoinRequest {
            node_id: 42,
            listen_addr: "10.0.0.5:9400".into(),
        };

        let rpc = RaftRpc::JoinRequest(req);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::JoinRequest(d) => {
                assert_eq!(d.node_id, 42);
                assert_eq!(d.listen_addr, "10.0.0.5:9400");
            }
            other => panic!("expected JoinRequest, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_join_response() {
        let resp = JoinResponse {
            success: true,
            error: String::new(),
            nodes: vec![
                JoinNodeInfo {
                    node_id: 1,
                    addr: "10.0.0.1:9400".into(),
                    state: 1,
                    raft_groups: vec![0, 1],
                },
                JoinNodeInfo {
                    node_id: 2,
                    addr: "10.0.0.2:9400".into(),
                    state: 1,
                    raft_groups: vec![0, 1],
                },
            ],
            vshard_to_group: (0..1024u64).map(|i| i % 4).collect(),
            groups: vec![JoinGroupInfo {
                group_id: 0,
                leader: 1,
                members: vec![1, 2],
            }],
        };

        let rpc = RaftRpc::JoinResponse(resp);
        let encoded = encode(&rpc).unwrap();
        let decoded = decode(&encoded).unwrap();

        match decoded {
            RaftRpc::JoinResponse(d) => {
                assert!(d.success);
                assert_eq!(d.nodes.len(), 2);
                assert_eq!(d.vshard_to_group.len(), 1024);
                assert_eq!(d.groups.len(), 1);
                assert_eq!(d.groups[0].leader, 1);
            }
            other => panic!("expected JoinResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_ping_pong() {
        let ping = RaftRpc::Ping(PingRequest {
            sender_id: 7,
            topology_version: 3,
        });
        match roundtrip(&ping) {
            RaftRpc::Ping(d) => {
                assert_eq!(d.sender_id, 7);
                assert_eq!(d.topology_version, 3);
            }
            other => panic!("expected Ping, got {other:?}"),
        }

        let pong = RaftRpc::Pong(PongResponse {
            responder_id: 2,
            topology_version: 3,
        });
        match roundtrip(&pong) {
            RaftRpc::Pong(d) => {
                assert_eq!(d.responder_id, 2);
                assert_eq!(d.topology_version, 3);
            }
            other => panic!("expected Pong, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_topology_update_and_ack() {
        let update = RaftRpc::TopologyUpdate(TopologyUpdate {
            version: 9,
            nodes: vec![JoinNodeInfo {
                node_id: 1,
                addr: "10.0.0.1:9400".into(),
                state: 1,
                raft_groups: vec![0, 2],
            }],
        });
        match roundtrip(&update) {
            RaftRpc::TopologyUpdate(d) => {
                assert_eq!(d.version, 9);
                assert_eq!(d.nodes.len(), 1);
                assert_eq!(d.nodes[0].node_id, 1);
                assert_eq!(d.nodes[0].addr, "10.0.0.1:9400");
                assert_eq!(d.nodes[0].raft_groups, vec![0, 2]);
            }
            other => panic!("expected TopologyUpdate, got {other:?}"),
        }

        let ack = RaftRpc::TopologyAck(TopologyAck {
            responder_id: 4,
            accepted_version: 9,
        });
        match roundtrip(&ack) {
            RaftRpc::TopologyAck(d) => {
                assert_eq!(d.responder_id, 4);
                assert_eq!(d.accepted_version, 9);
            }
            other => panic!("expected TopologyAck, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_forward_request_and_response() {
        let req = RaftRpc::ForwardRequest(ForwardRequest {
            sql: "SELECT 1".into(),
            tenant_id: 11,
            deadline_remaining_ms: 5_000,
            trace_id: 0xABCD,
        });
        match roundtrip(&req) {
            RaftRpc::ForwardRequest(d) => {
                assert_eq!(d.sql, "SELECT 1");
                assert_eq!(d.tenant_id, 11);
                assert_eq!(d.deadline_remaining_ms, 5_000);
                assert_eq!(d.trace_id, 0xABCD);
            }
            other => panic!("expected ForwardRequest, got {other:?}"),
        }

        let resp = RaftRpc::ForwardResponse(ForwardResponse {
            success: false,
            payloads: vec![b"row-1".to_vec(), Vec::new()],
            error_message: "boom".into(),
        });
        match roundtrip(&resp) {
            RaftRpc::ForwardResponse(d) => {
                assert!(!d.success);
                assert_eq!(d.payloads, vec![b"row-1".to_vec(), Vec::new()]);
                assert_eq!(d.error_message, "boom");
            }
            other => panic!("expected ForwardResponse, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_vshard_envelope_is_verbatim() {
        let raw = vec![0x00, 0x01, 0x7F, 0xFE, 0xFF];
        let encoded = encode(&RaftRpc::VShardEnvelope(raw.clone())).unwrap();

        // Envelope bytes are not re-serialized: the payload is exactly `raw`.
        assert_eq!(encoded[0], WIRE_VERSION as u8);
        assert_eq!(encoded[1], RPC_VSHARD_ENVELOPE);
        assert_eq!(&encoded[HEADER_SIZE..], raw.as_slice());

        match decode(&encoded).unwrap() {
            RaftRpc::VShardEnvelope(d) => assert_eq!(d, raw),
            other => panic!("expected VShardEnvelope, got {other:?}"),
        }
    }

    #[test]
    fn roundtrip_empty_vshard_envelope() {
        let encoded = encode(&RaftRpc::VShardEnvelope(Vec::new())).unwrap();
        assert_eq!(encoded.len(), HEADER_SIZE);

        match decode(&encoded).unwrap() {
            RaftRpc::VShardEnvelope(d) => assert!(d.is_empty()),
            other => panic!("expected VShardEnvelope, got {other:?}"),
        }
    }
}