1use crate::raft::{LogEntry, LogIndex, NodeId, Snapshot, Term};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11
/// Upper bound, in bytes, on the per-datagram receive buffer.
///
/// [`NetworkConfig::new`] and [`NetworkConfigBuilder::build`] clamp
/// `buffer_size` to this value.
pub const MAX_PACKET_SIZE: usize = 4096;

/// Default value of [`NetworkConfig::read_timeout`].
pub const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(100);

/// Default value of [`NetworkConfig::max_backoff`].
pub const MAX_BACKOFF: Duration = Duration::from_secs(5);

/// Default value of [`NetworkConfig::initial_backoff`].
pub const INITIAL_BACKOFF: Duration = Duration::from_millis(100);

/// Default value of [`NetworkConfig::connect_timeout`].
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Default value of [`NetworkConfig::request_timeout`].
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

/// Tunable parameters for the network transports.
///
/// Obtain one via [`NetworkConfig::default`], [`NetworkConfig::new`], or
/// [`NetworkConfig::builder`]. Only the builder validates its result; call
/// [`NetworkConfig::validate`] yourself when assembling the struct directly.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Size in bytes of the receive buffer (clamped to [`MAX_PACKET_SIZE`] by
    /// `new` and the builder).
    pub buffer_size: usize,
    /// Timeout applied to each individual socket read.
    pub read_timeout: Duration,
    /// Starting retry backoff delay (not read elsewhere in this module).
    pub initial_backoff: Duration,
    /// Cap on the retry backoff delay (not read elsewhere in this module).
    pub max_backoff: Duration,
    /// Connection timeout (not read elsewhere in this module).
    pub connect_timeout: Duration,
    /// Per-request timeout (not read elsewhere in this module).
    pub request_timeout: Duration,
}

impl Default for NetworkConfig {
    /// Returns the canonical defaults; the builder falls back to these values
    /// for every field it was not given.
    fn default() -> Self {
        Self {
            buffer_size: MAX_PACKET_SIZE,
            read_timeout: DEFAULT_READ_TIMEOUT,
            initial_backoff: INITIAL_BACKOFF,
            max_backoff: MAX_BACKOFF,
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            request_timeout: DEFAULT_REQUEST_TIMEOUT,
        }
    }
}

impl NetworkConfig {
    /// Creates a config with the given buffer size and read timeout; every
    /// other field takes its default.
    ///
    /// `buffer_size` is clamped to [`MAX_PACKET_SIZE`]. The result is **not**
    /// validated (e.g. a zero buffer is accepted); call [`Self::validate`] if
    /// the inputs are not trusted.
    pub fn new(buffer_size: usize, read_timeout: Duration) -> Self {
        Self {
            buffer_size: buffer_size.min(MAX_PACKET_SIZE),
            read_timeout,
            ..Default::default()
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns a human-readable description of the first violated invariant:
    /// - `buffer_size` is zero;
    /// - `read_timeout`, `connect_timeout` or `request_timeout` is zero;
    /// - `initial_backoff` is larger than `max_backoff`.
    pub fn validate(&self) -> Result<(), String> {
        if self.buffer_size == 0 {
            return Err("Buffer size must be positive".to_string());
        }
        if self.read_timeout.is_zero() {
            return Err("Read timeout must be non-zero".to_string());
        }
        // The other timeouts follow the same rule as `read_timeout`: a zero
        // duration can never be satisfied and is almost certainly a mistake.
        if self.connect_timeout.is_zero() {
            return Err("Connect timeout must be non-zero".to_string());
        }
        if self.request_timeout.is_zero() {
            return Err("Request timeout must be non-zero".to_string());
        }
        // A backoff that starts above its own cap is contradictory.
        if self.initial_backoff > self.max_backoff {
            return Err(format!(
                "Initial backoff ({:?}) must not exceed max backoff ({:?})",
                self.initial_backoff, self.max_backoff
            ));
        }
        Ok(())
    }

    /// Starts a [`NetworkConfigBuilder`] with no fields set.
    pub fn builder() -> NetworkConfigBuilder {
        NetworkConfigBuilder::default()
    }
}

/// Incremental builder for [`NetworkConfig`].
///
/// Unset fields fall back to [`NetworkConfig::default`]; [`Self::build`]
/// clamps the buffer size and runs [`NetworkConfig::validate`].
#[derive(Debug, Default)]
pub struct NetworkConfigBuilder {
    buffer_size: Option<usize>,
    read_timeout: Option<Duration>,
    initial_backoff: Option<Duration>,
    max_backoff: Option<Duration>,
    connect_timeout: Option<Duration>,
    request_timeout: Option<Duration>,
}

impl NetworkConfigBuilder {
    /// Sets the receive buffer size (clamped to [`MAX_PACKET_SIZE`] on build).
    pub fn buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = Some(size);
        self
    }

    /// Sets the per-read timeout.
    pub fn read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    /// Sets the starting backoff delay.
    pub fn initial_backoff(mut self, backoff: Duration) -> Self {
        self.initial_backoff = Some(backoff);
        self
    }

    /// Sets the backoff cap.
    pub fn max_backoff(mut self, backoff: Duration) -> Self {
        self.max_backoff = Some(backoff);
        self
    }

    /// Sets the connect timeout.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Sets the request timeout.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = Some(timeout);
        self
    }

    /// Produces a validated [`NetworkConfig`].
    ///
    /// # Errors
    ///
    /// Returns the message from [`NetworkConfig::validate`] if the resulting
    /// configuration is invalid.
    pub fn build(self) -> Result<NetworkConfig, String> {
        // Single source of truth for defaults: whatever `Default` says.
        let defaults = NetworkConfig::default();
        let config = NetworkConfig {
            buffer_size: self
                .buffer_size
                .unwrap_or(defaults.buffer_size)
                .min(MAX_PACKET_SIZE),
            read_timeout: self.read_timeout.unwrap_or(defaults.read_timeout),
            initial_backoff: self.initial_backoff.unwrap_or(defaults.initial_backoff),
            max_backoff: self.max_backoff.unwrap_or(defaults.max_backoff),
            connect_timeout: self.connect_timeout.unwrap_or(defaults.connect_timeout),
            request_timeout: self.request_timeout.unwrap_or(defaults.request_timeout),
        };
        config.validate()?;
        Ok(config)
    }
}
139
/// Raft RPC requests and their responses, as exchanged through a [`RaftNetwork`].
///
/// `T` is the application payload type carried by [`LogEntry`] and [`Snapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RaftMessage<T> {
    /// A candidate asking the receiver for its vote (Raft `RequestVote` RPC).
    RequestVote {
        /// Term of the requesting candidate.
        term: Term,
        /// Node requesting the vote.
        candidate_id: NodeId,
        /// Index of the candidate's last log entry.
        last_log_index: LogIndex,
        /// Term of the candidate's last log entry.
        last_log_term: Term,
    },

    /// Reply to [`RaftMessage::RequestVote`].
    RequestVoteResponse {
        /// Responder's current term.
        term: Term,
        /// Whether the responder granted its vote.
        vote_granted: bool,
        /// Node that produced this response.
        from_id: NodeId,
    },

    /// Log replication request from a leader (Raft `AppendEntries` RPC).
    AppendEntries {
        /// Leader's term.
        term: Term,
        /// Node sending the request.
        leader_id: NodeId,
        /// Index of the log entry immediately preceding `entries`.
        prev_log_index: LogIndex,
        /// Term of the entry at `prev_log_index`.
        prev_log_term: Term,
        /// Entries to append; may be empty.
        entries: Vec<LogEntry<T>>,
        /// Leader's commit index.
        leader_commit: LogIndex,
    },

    /// Reply to [`RaftMessage::AppendEntries`].
    AppendEntriesResponse {
        /// Responder's current term.
        term: Term,
        /// Whether the append was accepted.
        success: bool,
        // NOTE(review): presumably the highest log index the responder knows to
        // match the leader's log — confirm against the handler that fills it in.
        match_index: LogIndex,
        /// Node that produced this response.
        from_id: NodeId,
    },

    /// Snapshot transfer from a leader; the whole snapshot travels in one message
    /// (there are no chunk offset/done fields).
    InstallSnapshot {
        /// Leader's term.
        term: Term,
        /// Node sending the snapshot.
        leader_id: NodeId,
        /// The complete snapshot.
        snapshot: Snapshot<T>,
    },

    /// Reply to [`RaftMessage::InstallSnapshot`].
    InstallSnapshotResponse {
        /// Responder's current term.
        term: Term,
        /// Whether the snapshot was installed.
        success: bool,
        /// Node that produced this response.
        from_id: NodeId,
    },
}
220
/// Lightweight messages exchanged by [`ConsensusNetwork`] transports.
///
/// The UDP transport serializes each packet as JSON into a single datagram.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Packet {
    /// Request for a vote from `candidate_id` in `term`.
    VoteRequest {
        term: Term,
        candidate_id: NodeId,
    },
    /// Answer to a [`Packet::VoteRequest`].
    VoteResponse {
        term: Term,
        vote_granted: bool,
    },
    /// Liveness signal from the leader for `term`.
    Heartbeat {
        leader_id: NodeId,
        term: Term,
    },
    /// Cluster membership change notification.
    ConfigChange {
        // NOTE(review): stringly typed; the accepted values are not defined in
        // this module — consider an enum.
        change_type: String,
        // Address of the affected peer; its format is not checked here.
        peer_address: String,
        // Id of the affected node.
        node_id: NodeId,
    },
}
246
/// Identity and address of a cluster peer.
#[derive(Debug, Clone)]
pub struct PeerInfo {
    /// The peer's node id (also used as its map key by [`InMemoryNetwork`]).
    pub id: NodeId,
    /// Transport address of the peer; its format is not validated here.
    pub address: String,
}
257
/// Transport used by the Raft core to exchange [`RaftMessage`]s with peers.
///
/// Each `send_*` method addresses a single peer by [`NodeId`]. A return of
/// `Ok(None)` means the message was handed off and any reply will arrive later
/// through [`RaftNetwork::receive`] (this is what [`InMemoryNetwork`] does).
/// `Ok(Some(_))` presumably carries a reply obtained synchronously — confirm
/// with the transport implementations.
#[async_trait]
pub trait RaftNetwork<T: Send + Sync + Clone>: Send + Sync {
    /// Sends a `RequestVote` RPC to `peer_id`.
    // NOTE(review): counting `&self` this takes 6 arguments, below clippy's
    // default `too_many_arguments` threshold (7), so this allow is likely redundant.
    #[allow(clippy::too_many_arguments)]
    async fn send_request_vote(
        &self,
        peer_id: NodeId,
        term: Term,
        candidate_id: NodeId,
        last_log_index: LogIndex,
        last_log_term: Term,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Sends an `AppendEntries` RPC carrying `entries` to `peer_id`.
    // The RPC fields are passed individually, which exceeds clippy's
    // argument-count threshold; the allow accepts that deliberately.
    #[allow(clippy::too_many_arguments)]
    async fn send_append_entries(
        &self,
        peer_id: NodeId,
        term: Term,
        leader_id: NodeId,
        prev_log_index: LogIndex,
        prev_log_term: Term,
        entries: Vec<LogEntry<T>>,
        leader_commit: LogIndex,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Sends an `InstallSnapshot` RPC carrying `snapshot` to `peer_id`.
    async fn send_install_snapshot(
        &self,
        peer_id: NodeId,
        term: Term,
        leader_id: NodeId,
        snapshot: Snapshot<T>,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Waits for the next message addressed to this node.
    async fn receive(&self) -> Result<RaftMessage<T>, NetworkError>;

    /// Delivers `message` (typically a response variant) to node `to`.
    async fn respond(&self, to: NodeId, message: RaftMessage<T>) -> Result<(), NetworkError>;

    /// Returns the ids of the peers this transport currently knows about.
    fn peer_ids(&self) -> Vec<NodeId>;

    /// Replaces the transport's peer set with `peers`.
    async fn update_peers(&self, peers: Vec<PeerInfo>) -> Result<(), NetworkError>;
}
308
/// Errors returned by [`RaftNetwork`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum NetworkError {
    /// Wraps the underlying error of a failed connection.
    // NOTE(review): the message both formats `{0}` and marks it `#[source]`, so
    // reporters that walk the error chain will print the cause twice.
    #[error("Connection failed: {0}")]
    ConnectionFailed(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// The operation did not complete in time.
    #[error("Timeout")]
    Timeout,
    /// A message could not be encoded or decoded.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// No route/outbox is registered for the given node id.
    #[error("Peer not found: {0}")]
    PeerNotFound(NodeId),
    /// Any other transport failure (e.g. a closed channel in [`InMemoryNetwork`]).
    #[error("Transport error: {0}")]
    TransportError(String),
}
323
/// Minimal broadcast-style transport for [`Packet`]s, reporting errors as `String`s.
#[async_trait]
pub trait ConsensusNetwork: Send + Sync {
    /// Expected to send a [`Packet::VoteRequest`] to the known peers (the UDP
    /// implementation sends it to every configured address).
    async fn broadcast_vote_request(&self, term: Term, candidate_id: NodeId) -> Result<(), String>;

    /// Expected to send a [`Packet::Heartbeat`] to the known peers.
    async fn send_heartbeat(&self, leader_id: NodeId, term: Term) -> Result<(), String>;

    /// Waits for the next incoming [`Packet`].
    async fn receive(&self) -> Result<Packet, String>;

    /// Replaces the list of peer addresses.
    async fn update_peers(&self, peers: Vec<String>) -> Result<(), String>;
}
339
/// In-process [`RaftNetwork`] built on tokio mpsc channels (presumably for tests
/// and simulations).
///
/// Each node owns the receiving end of its inbox and holds senders into the
/// inboxes of the peers registered with `register_peer`.
pub struct InMemoryNetwork<T> {
    // This node's id; stored but not read anywhere (hence the underscore).
    _node_id: NodeId,
    // Peer metadata keyed by node id; written by `register_peer` and `update_peers`.
    peers: Arc<tokio::sync::RwLock<HashMap<NodeId, PeerInfo>>>,
    // Receiving end of this node's channel; behind a Mutex because `receive`
    // takes `&self` but `recv` needs `&mut`.
    inbox: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<RaftMessage<T>>>>,
    // Senders into other nodes' inboxes, keyed by node id; written only by
    // `register_peer` (so it can diverge from `peers` after `update_peers`).
    outboxes: Arc<tokio::sync::RwLock<HashMap<NodeId, tokio::sync::mpsc::Sender<RaftMessage<T>>>>>,
}
351
352impl<T: Send + Sync + Clone + 'static> InMemoryNetwork<T> {
353 pub fn new(node_id: NodeId, inbox_rx: tokio::sync::mpsc::Receiver<RaftMessage<T>>) -> Self {
355 Self {
356 _node_id: node_id,
357 peers: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
358 inbox: Arc::new(tokio::sync::Mutex::new(inbox_rx)),
359 outboxes: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
360 }
361 }
362
363 pub async fn register_peer(
365 &self,
366 peer_id: NodeId,
367 address: String,
368 sender: tokio::sync::mpsc::Sender<RaftMessage<T>>,
369 ) {
370 self.peers.write().await.insert(
371 peer_id,
372 PeerInfo {
373 id: peer_id,
374 address,
375 },
376 );
377 self.outboxes.write().await.insert(peer_id, sender);
378 }
379}
380
381#[async_trait]
382impl<T: Send + Sync + Clone + serde::Serialize + serde::de::DeserializeOwned + 'static>
383 RaftNetwork<T> for InMemoryNetwork<T>
384{
385 async fn send_request_vote(
386 &self,
387 peer_id: NodeId,
388 term: Term,
389 candidate_id: NodeId,
390 last_log_index: LogIndex,
391 last_log_term: Term,
392 ) -> Result<Option<RaftMessage<T>>, NetworkError> {
393 let outboxes = self.outboxes.read().await;
394 let sender = outboxes
395 .get(&peer_id)
396 .ok_or(NetworkError::PeerNotFound(peer_id))?;
397
398 sender
399 .send(RaftMessage::RequestVote {
400 term,
401 candidate_id,
402 last_log_index,
403 last_log_term,
404 })
405 .await
406 .map_err(|e| NetworkError::TransportError(e.to_string()))?;
407
408 Ok(None) }
410
411 async fn send_append_entries(
412 &self,
413 peer_id: NodeId,
414 term: Term,
415 leader_id: NodeId,
416 prev_log_index: LogIndex,
417 prev_log_term: Term,
418 entries: Vec<LogEntry<T>>,
419 leader_commit: LogIndex,
420 ) -> Result<Option<RaftMessage<T>>, NetworkError> {
421 let outboxes = self.outboxes.read().await;
422 let sender = outboxes
423 .get(&peer_id)
424 .ok_or(NetworkError::PeerNotFound(peer_id))?;
425
426 sender
427 .send(RaftMessage::AppendEntries {
428 term,
429 leader_id,
430 prev_log_index,
431 prev_log_term,
432 entries,
433 leader_commit,
434 })
435 .await
436 .map_err(|e| NetworkError::TransportError(e.to_string()))?;
437
438 Ok(None)
439 }
440
441 async fn send_install_snapshot(
442 &self,
443 peer_id: NodeId,
444 term: Term,
445 leader_id: NodeId,
446 snapshot: Snapshot<T>,
447 ) -> Result<Option<RaftMessage<T>>, NetworkError> {
448 let outboxes = self.outboxes.read().await;
449 let sender = outboxes
450 .get(&peer_id)
451 .ok_or(NetworkError::PeerNotFound(peer_id))?;
452
453 sender
454 .send(RaftMessage::InstallSnapshot {
455 term,
456 leader_id,
457 snapshot,
458 })
459 .await
460 .map_err(|e| NetworkError::TransportError(e.to_string()))?;
461
462 Ok(None)
463 }
464
465 async fn receive(&self) -> Result<RaftMessage<T>, NetworkError> {
466 let mut inbox = self.inbox.lock().await;
467 inbox
468 .recv()
469 .await
470 .ok_or(NetworkError::TransportError("Channel closed".into()))
471 }
472
473 async fn respond(&self, to: NodeId, message: RaftMessage<T>) -> Result<(), NetworkError> {
474 let outboxes = self.outboxes.read().await;
475 let sender = outboxes.get(&to).ok_or(NetworkError::PeerNotFound(to))?;
476
477 sender
478 .send(message)
479 .await
480 .map_err(|e| NetworkError::TransportError(e.to_string()))
481 }
482
483 fn peer_ids(&self) -> Vec<NodeId> {
484 Vec::new()
486 }
487
488 async fn update_peers(&self, peers: Vec<PeerInfo>) -> Result<(), NetworkError> {
489 let mut peer_map = self.peers.write().await;
490 peer_map.clear();
491 for peer in peers {
492 peer_map.insert(peer.id, peer);
493 }
494 Ok(())
495 }
496}
497
/// UDP implementation of [`ConsensusNetwork`] that sends JSON-encoded [`Packet`]s,
/// one per datagram. Compiled only with the `net` feature.
#[cfg(feature = "net")]
pub mod udp {
    use super::*;
    use tokio::net::UdpSocket;
    use tokio::sync::RwLock;

    /// Datagram transport that sends every packet to a list of peer addresses.
    pub struct UdpTransport {
        // Bound socket shared by all send and receive operations.
        socket: Arc<UdpSocket>,
        // Peer addresses, passed verbatim to `send_to` on every send.
        peers: Arc<RwLock<Vec<String>>>,
        // Validated configuration (see `with_config`).
        config: NetworkConfig,
    }

    impl std::fmt::Debug for UdpTransport {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The socket is omitted; `finish_non_exhaustive` renders it as `..`.
            f.debug_struct("UdpTransport")
                .field("peers", &self.peers)
                .field("config", &self.config)
                .finish_non_exhaustive()
        }
    }

    impl UdpTransport {
        /// Binds a socket at `bind_addr` using [`NetworkConfig::default`].
        ///
        /// # Errors
        ///
        /// Returns a message if the socket cannot be bound.
        pub async fn new(bind_addr: &str, peers: Vec<String>) -> Result<Self, String> {
            Self::with_config(bind_addr, peers, NetworkConfig::default()).await
        }

        /// Binds a socket at `bind_addr` with an explicit configuration.
        ///
        /// # Errors
        ///
        /// Returns a message if `config` fails [`NetworkConfig::validate`] or the
        /// socket cannot be bound.
        pub async fn with_config(
            bind_addr: &str,
            peers: Vec<String>,
            config: NetworkConfig,
        ) -> Result<Self, String> {
            // `NetworkConfig` has public fields, so callers can bypass the
            // builder's validation. Enforce the same invariants here; e.g. a
            // zero `buffer_size` would truncate every datagram and make each
            // `receive` fail to parse.
            config
                .validate()
                .map_err(|e| format!("Invalid network configuration: {}", e))?;

            let socket = UdpSocket::bind(bind_addr)
                .await
                .map_err(|e| format!("Failed to bind UDP socket to '{}': {}", bind_addr, e))?;

            tracing::info!(
                bind_addr = bind_addr,
                peer_count = peers.len(),
                buffer_size = config.buffer_size,
                "Async UDP transport initialized"
            );

            Ok(Self {
                socket: Arc::new(socket),
                peers: Arc::new(RwLock::new(peers)),
                config,
            })
        }

        /// Sends `payload` to every configured peer, reporting each per-peer
        /// failure to `on_peer_error` without stopping.
        ///
        /// Returns the last I/O error only when there is at least one peer and
        /// every send failed; partial delivery counts as success.
        async fn send_to_all<F>(&self, payload: &[u8], on_peer_error: F) -> std::io::Result<()>
        where
            F: Fn(&str, &std::io::Error) + Send + Sync,
        {
            let peers = self.peers.read().await;
            let mut delivered = false;
            let mut last_error = None;

            for peer in peers.iter() {
                match self.socket.send_to(payload, peer).await {
                    Ok(_) => delivered = true,
                    Err(e) => {
                        on_peer_error(peer.as_str(), &e);
                        last_error = Some(e);
                    }
                }
            }

            match last_error {
                Some(e) if !delivered => Err(e),
                _ => Ok(()),
            }
        }
    }

    #[async_trait]
    impl ConsensusNetwork for UdpTransport {
        /// Sends a JSON [`Packet::VoteRequest`] to every peer.
        ///
        /// # Errors
        ///
        /// Fails on serialization errors, or when peers exist but none could be
        /// reached.
        async fn broadcast_vote_request(
            &self,
            term: Term,
            candidate_id: NodeId,
        ) -> Result<(), String> {
            let packet = Packet::VoteRequest { term, candidate_id };
            let serialized = serde_json::to_vec(&packet).map_err(|e| e.to_string())?;

            self.send_to_all(&serialized, |peer, e| {
                tracing::warn!("Failed to send to {}: {}", peer, e);
            })
            .await
            .map_err(|e| format!("Failed to broadcast vote request to any peer: {}", e))
        }

        /// Sends a JSON [`Packet::Heartbeat`] to every peer.
        ///
        /// # Errors
        ///
        /// Fails on serialization errors, or when peers exist but none could be
        /// reached.
        async fn send_heartbeat(&self, leader_id: NodeId, term: Term) -> Result<(), String> {
            let packet = Packet::Heartbeat { leader_id, term };
            let serialized = serde_json::to_vec(&packet).map_err(|e| e.to_string())?;

            // Heartbeat failures are frequent and expected, hence `debug`.
            self.send_to_all(&serialized, |peer, e| {
                tracing::debug!("Failed to heartbeat {}: {}", peer, e);
            })
            .await
            .map_err(|e| format!("Failed to send heartbeats to any peer: {}", e))
        }

        /// Waits for the next datagram and decodes it as a [`Packet`].
        ///
        /// # Errors
        ///
        /// Returns the JSON error if the received datagram is not a valid
        /// `Packet`. I/O errors are logged and retried after 50 ms rather than
        /// returned.
        async fn receive(&self) -> Result<Packet, String> {
            let mut buf = vec![0u8; self.config.buffer_size];

            loop {
                match tokio::time::timeout(
                    self.config.read_timeout,
                    self.socket.recv_from(&mut buf),
                )
                .await
                {
                    Ok(Ok((amt, _src))) => {
                        let packet: Packet =
                            serde_json::from_slice(&buf[..amt]).map_err(|e| e.to_string())?;
                        return Ok(packet);
                    }
                    Ok(Err(e)) => {
                        tracing::error!("UDP receive IO error: {}", e);
                        // Brief pause so a persistent socket error does not hot-loop.
                        tokio::time::sleep(Duration::from_millis(50)).await;
                    }
                    // NOTE(review): an elapsed `read_timeout` just re-arms the wait;
                    // it never makes `receive` return, so it does not bound the
                    // call's total duration.
                    Err(_elapsed) => {}
                }
            }
        }

        /// Replaces the peer address list used by subsequent sends.
        async fn update_peers(&self, new_peers: Vec<String>) -> Result<(), String> {
            let mut peers = self.peers.write().await;
            *peers = new_peers;
            tracing::info!(peer_count = peers.len(), "Updated ConsensusNetwork peers");
            Ok(())
        }
    }
}
650
/// Unit tests for configuration defaults, clamping, packet encoding and
/// (with the `net` feature) UDP socket binding.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_network_config_default() {
        let NetworkConfig {
            buffer_size,
            read_timeout,
            ..
        } = NetworkConfig::default();
        assert_eq!(buffer_size, MAX_PACKET_SIZE);
        assert_eq!(read_timeout, DEFAULT_READ_TIMEOUT);
    }

    #[test]
    fn test_network_config_clamps_buffer_size() {
        let oversized = 100_000;
        let clamped = NetworkConfig::new(oversized, Duration::from_millis(50)).buffer_size;
        assert_eq!(clamped, MAX_PACKET_SIZE);
    }

    #[test]
    fn test_packet_serialization() {
        let original = Packet::VoteRequest {
            term: 1,
            candidate_id: 42,
        };
        let bytes = serde_json::to_vec(&original).unwrap();
        let decoded: Packet = serde_json::from_slice(&bytes).unwrap();

        // Round-tripping must preserve both the variant and its fields.
        if let Packet::VoteRequest { term, candidate_id } = decoded {
            assert_eq!(term, 1);
            assert_eq!(candidate_id, 42);
        } else {
            panic!("Wrong packet type");
        }
    }

    #[cfg(feature = "net")]
    mod udp_tests {
        use super::super::udp::UdpTransport;
        use super::super::*;

        #[tokio::test]
        async fn test_invalid_bind_address() {
            let outcome = UdpTransport::new("not-an-address", vec![]).await;
            assert!(outcome.is_err());
        }

        #[tokio::test]
        async fn test_valid_bind_address() {
            let outcome = UdpTransport::new("127.0.0.1:0", vec![]).await;
            assert!(outcome.is_ok());
        }
    }
}