// praborrow_lease/network.rs

1//! Network transport for Raft consensus.
2//!
3//! Provides abstract network interface and implementations for consensus messaging.
4
5use crate::raft::{LogEntry, LogIndex, NodeId, Snapshot, Term};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11
/// Maximum UDP packet size in bytes.
///
/// The theoretical UDP payload limit is 65507 bytes, but we use a much
/// smaller cap; the UDP transport's receive buffer is clamped to this value.
pub const MAX_PACKET_SIZE: usize = 4096;

/// Default read timeout applied to UDP socket receive calls.
pub const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(100);

/// Maximum backoff duration between supervisor restart attempts.
pub const MAX_BACKOFF: Duration = Duration::from_secs(5);

/// Initial backoff duration for the first supervisor restart attempt.
pub const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
23
/// Configuration for network transport.
///
/// Construct via [`NetworkConfig::default`], [`NetworkConfig::new`], or the
/// [`NetworkConfig::builder`] API.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Size of the receive buffer in bytes. The constructors clamp this to
    /// [`MAX_PACKET_SIZE`]; direct struct construction is not clamped.
    pub buffer_size: usize,
    /// Read timeout for the socket.
    pub read_timeout: Duration,
    /// Initial backoff for supervisor restart.
    pub initial_backoff: Duration,
    /// Maximum backoff for supervisor restart.
    pub max_backoff: Duration,
    /// Connection timeout.
    pub connect_timeout: Duration,
    /// Request timeout.
    pub request_timeout: Duration,
}
40
41impl Default for NetworkConfig {
42    fn default() -> Self {
43        Self {
44            buffer_size: MAX_PACKET_SIZE,
45            read_timeout: DEFAULT_READ_TIMEOUT,
46            initial_backoff: INITIAL_BACKOFF,
47            max_backoff: MAX_BACKOFF,
48            connect_timeout: Duration::from_secs(5),
49            request_timeout: Duration::from_secs(10),
50        }
51    }
52}
53
54impl NetworkConfig {
55    /// Creates a new network configuration with custom values
56    pub fn new(buffer_size: usize, read_timeout: Duration) -> Self {
57        Self {
58            buffer_size: buffer_size.min(MAX_PACKET_SIZE),
59            read_timeout,
60            ..Default::default()
61        }
62    }
63
64    /// Validates the configuration.
65    pub fn validate(&self) -> Result<(), String> {
66        if self.buffer_size == 0 {
67            return Err("Buffer size must be positive".to_string());
68        }
69        if self.read_timeout.is_zero() {
70            return Err("Read timeout must be non-zero".to_string());
71        }
72        Ok(())
73    }
74
75    /// Returns a new builder for configuration.
76    pub fn builder() -> NetworkConfigBuilder {
77        NetworkConfigBuilder::default()
78    }
79}
80
/// Builder for [`NetworkConfig`].
///
/// Every field defaults to `None`, meaning "use the default"; unset fields
/// fall back to the module constants / hard-coded defaults in `build`.
#[derive(Default)]
pub struct NetworkConfigBuilder {
    buffer_size: Option<usize>,
    read_timeout: Option<Duration>,
    initial_backoff: Option<Duration>,
    max_backoff: Option<Duration>,
    connect_timeout: Option<Duration>,
    request_timeout: Option<Duration>,
}
91
92impl NetworkConfigBuilder {
93    pub fn buffer_size(mut self, size: usize) -> Self {
94        self.buffer_size = Some(size);
95        self
96    }
97
98    pub fn read_timeout(mut self, timeout: Duration) -> Self {
99        self.read_timeout = Some(timeout);
100        self
101    }
102
103    pub fn initial_backoff(mut self, backoff: Duration) -> Self {
104        self.initial_backoff = Some(backoff);
105        self
106    }
107
108    pub fn max_backoff(mut self, backoff: Duration) -> Self {
109        self.max_backoff = Some(backoff);
110        self
111    }
112
113    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
114        self.connect_timeout = Some(timeout);
115        self
116    }
117
118    pub fn request_timeout(mut self, timeout: Duration) -> Self {
119        self.request_timeout = Some(timeout);
120        self
121    }
122
123    pub fn build(self) -> Result<NetworkConfig, String> {
124        let config = NetworkConfig {
125            buffer_size: self
126                .buffer_size
127                .unwrap_or(MAX_PACKET_SIZE)
128                .min(MAX_PACKET_SIZE),
129            read_timeout: self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT),
130            initial_backoff: self.initial_backoff.unwrap_or(INITIAL_BACKOFF),
131            max_backoff: self.max_backoff.unwrap_or(MAX_BACKOFF),
132            connect_timeout: self.connect_timeout.unwrap_or(Duration::from_secs(5)),
133            request_timeout: self.request_timeout.unwrap_or(Duration::from_secs(10)),
134        };
135        config.validate()?;
136        Ok(config)
137    }
138}
139
140// ============================================================================
141// RAFT RPC MESSAGES
142// ============================================================================
143
/// Full Raft RPC message types (per Raft paper §5-7)
///
/// `T` is the application-specific command type carried in log entries
/// and snapshots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RaftMessage<T> {
    // ===== RequestVote RPC (§5.2) =====
    /// Invoked by candidates to gather votes
    RequestVote {
        /// Candidate's term
        term: Term,
        /// Candidate requesting vote
        candidate_id: NodeId,
        /// Index of candidate's last log entry
        last_log_index: LogIndex,
        /// Term of candidate's last log entry
        last_log_term: Term,
    },

    /// Response to RequestVote RPC
    RequestVoteResponse {
        /// Current term, for candidate to update itself
        term: Term,
        /// True means candidate received vote
        vote_granted: bool,
        /// Responder's node ID
        from_id: NodeId,
    },

    // ===== AppendEntries RPC (§5.3) =====
    /// Invoked by leader to replicate log entries; also used as heartbeat
    AppendEntries {
        /// Leader's term
        term: Term,
        /// So follower can redirect clients
        leader_id: NodeId,
        /// Index of log entry immediately preceding new ones
        prev_log_index: LogIndex,
        /// Term of prev_log_index entry
        prev_log_term: Term,
        /// Log entries to store (empty for heartbeat)
        entries: Vec<LogEntry<T>>,
        /// Leader's commit index
        leader_commit: LogIndex,
    },

    /// Response to AppendEntries RPC
    AppendEntriesResponse {
        /// Current term, for leader to update itself
        term: Term,
        /// True if follower contained entry matching prev_log_index and prev_log_term
        success: bool,
        /// The index of the last entry replicated (for updating match_index)
        match_index: LogIndex,
        /// Responder's node ID
        from_id: NodeId,
    },

    // ===== InstallSnapshot RPC (§7) =====
    /// Invoked by leader to send chunks of a snapshot to a follower
    InstallSnapshot {
        /// Leader's term
        term: Term,
        /// So follower can redirect clients
        leader_id: NodeId,
        /// Snapshot data
        snapshot: Snapshot<T>,
    },

    /// Response to InstallSnapshot RPC
    InstallSnapshotResponse {
        /// Current term, for leader to update itself
        term: Term,
        /// True if snapshot was accepted
        success: bool,
        /// Responder's node ID
        from_id: NodeId,
    },
}
220
/// Legacy packet type for backward compatibility
///
/// A simplified message set predating [`RaftMessage`]; still used by the
/// [`ConsensusNetwork`] trait and the UDP transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Packet {
    /// Request for a vote in the given term.
    VoteRequest {
        term: Term,
        candidate_id: NodeId,
    },
    /// Reply to a vote request.
    VoteResponse {
        term: Term,
        vote_granted: bool,
    },
    /// Leader liveness signal.
    Heartbeat {
        leader_id: NodeId,
        term: Term,
    },
    /// Configuration change (membership change)
    ConfigChange {
        /// Type of change: "add" or "remove"
        change_type: String,
        /// Address of the peer
        peer_address: String,
        /// Node ID
        node_id: NodeId,
    },
}
246
247// ============================================================================
248// NETWORK TRAIT
249// ============================================================================
250
/// Peer information for network transport
#[derive(Debug, Clone)]
pub struct PeerInfo {
    /// Unique identifier of the peer node.
    pub id: NodeId,
    /// Transport address of the peer as a string (format depends on the
    /// transport implementation).
    pub address: String,
}
257
/// Abstract network interface for Raft consensus.
///
/// Handles sending/receiving Raft RPC messages to/from peers. Send methods
/// may return the peer's reply inline (`Ok(Some(_))`), or `Ok(None)` when
/// responses are delivered asynchronously through [`RaftNetwork::receive`]
/// (as the in-memory transport does).
#[async_trait]
pub trait RaftNetwork<T: Send + Sync + Clone>: Send + Sync {
    /// Sends a RequestVote RPC to a specific peer.
    ///
    /// # Errors
    ///
    /// Returns a [`NetworkError`] if the peer is unknown or the send fails.
    #[allow(clippy::too_many_arguments)]
    async fn send_request_vote(
        &self,
        peer_id: NodeId,
        term: Term,
        candidate_id: NodeId,
        last_log_index: LogIndex,
        last_log_term: Term,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Sends an AppendEntries RPC to a specific peer.
    ///
    /// An empty `entries` vector serves as a heartbeat.
    ///
    /// # Errors
    ///
    /// Returns a [`NetworkError`] if the peer is unknown or the send fails.
    #[allow(clippy::too_many_arguments)]
    async fn send_append_entries(
        &self,
        peer_id: NodeId,
        term: Term,
        leader_id: NodeId,
        prev_log_index: LogIndex,
        prev_log_term: Term,
        entries: Vec<LogEntry<T>>,
        leader_commit: LogIndex,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Sends an InstallSnapshot RPC to a specific peer.
    ///
    /// # Errors
    ///
    /// Returns a [`NetworkError`] if the peer is unknown or the send fails.
    async fn send_install_snapshot(
        &self,
        peer_id: NodeId,
        term: Term,
        leader_id: NodeId,
        snapshot: Snapshot<T>,
    ) -> Result<Option<RaftMessage<T>>, NetworkError>;

    /// Receives the next incoming RPC message.
    ///
    /// # Errors
    ///
    /// Returns a [`NetworkError`] if the underlying transport has shut down.
    async fn receive(&self) -> Result<RaftMessage<T>, NetworkError>;

    /// Responds to an incoming RPC.
    ///
    /// # Errors
    ///
    /// Returns a [`NetworkError`] if `to` is unknown or the send fails.
    async fn respond(&self, to: NodeId, message: RaftMessage<T>) -> Result<(), NetworkError>;

    /// Gets the list of peer IDs.
    fn peer_ids(&self) -> Vec<NodeId>;

    /// Updates the peer list.
    async fn update_peers(&self, peers: Vec<PeerInfo>) -> Result<(), NetworkError>;
}
308
/// Network errors
#[derive(Debug, thiserror::Error)]
pub enum NetworkError {
    /// The underlying connection could not be established.
    #[error("Connection failed: {0}")]
    ConnectionFailed(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// The operation did not complete within its deadline.
    #[error("Timeout")]
    Timeout,
    /// A message could not be serialized or deserialized.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// No transport route is registered for the given peer ID.
    #[error("Peer not found: {0}")]
    PeerNotFound(NodeId),
    /// Generic transport failure (e.g. a closed channel).
    #[error("Transport error: {0}")]
    TransportError(String),
}
323
/// Legacy ConsensusNetwork trait for backward compatibility
///
/// Works with the coarse [`Packet`] message set rather than the full
/// [`RaftMessage`] RPCs; prefer [`RaftNetwork`] for new code.
#[async_trait]
pub trait ConsensusNetwork: Send + Sync {
    /// Broadcast a RequestVote RPC to all peers.
    async fn broadcast_vote_request(&self, term: Term, candidate_id: NodeId) -> Result<(), String>;

    /// Send a Heartbeat (empty AppendEntries) to all peers.
    async fn send_heartbeat(&self, leader_id: NodeId, term: Term) -> Result<(), String>;

    /// Receive the next packet from the network.
    async fn receive(&self) -> Result<Packet, String>;

    /// Update the list of peers (for dynamic membership).
    async fn update_peers(&self, peers: Vec<String>) -> Result<(), String>;
}
339
340// ============================================================================
341// IN-MEMORY NETWORK (for testing)
342// ============================================================================
343
/// In-memory network transport for testing.
///
/// Routes [`RaftMessage`]s between nodes over tokio mpsc channels: each node
/// owns an inbox receiver and holds a sender ("outbox") for every registered
/// peer.
pub struct InMemoryNetwork<T> {
    /// This node's own ID (currently unused by the implementation).
    _node_id: NodeId,
    /// Known peers, keyed by node ID.
    peers: Arc<tokio::sync::RwLock<HashMap<NodeId, PeerInfo>>>,
    /// Receiver for messages addressed to this node.
    inbox: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<RaftMessage<T>>>>,
    /// Senders used to deliver messages into each peer's inbox.
    outboxes: Arc<tokio::sync::RwLock<HashMap<NodeId, tokio::sync::mpsc::Sender<RaftMessage<T>>>>>,
}
351
352impl<T: Send + Sync + Clone + 'static> InMemoryNetwork<T> {
353    /// Creates a new in-memory network node.
354    pub fn new(node_id: NodeId, inbox_rx: tokio::sync::mpsc::Receiver<RaftMessage<T>>) -> Self {
355        Self {
356            _node_id: node_id,
357            peers: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
358            inbox: Arc::new(tokio::sync::Mutex::new(inbox_rx)),
359            outboxes: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
360        }
361    }
362
363    /// Registers a peer's outbox for sending messages.
364    pub async fn register_peer(
365        &self,
366        peer_id: NodeId,
367        address: String,
368        sender: tokio::sync::mpsc::Sender<RaftMessage<T>>,
369    ) {
370        self.peers.write().await.insert(
371            peer_id,
372            PeerInfo {
373                id: peer_id,
374                address,
375            },
376        );
377        self.outboxes.write().await.insert(peer_id, sender);
378    }
379}
380
381#[async_trait]
382impl<T: Send + Sync + Clone + serde::Serialize + serde::de::DeserializeOwned + 'static>
383    RaftNetwork<T> for InMemoryNetwork<T>
384{
385    async fn send_request_vote(
386        &self,
387        peer_id: NodeId,
388        term: Term,
389        candidate_id: NodeId,
390        last_log_index: LogIndex,
391        last_log_term: Term,
392    ) -> Result<Option<RaftMessage<T>>, NetworkError> {
393        let outboxes = self.outboxes.read().await;
394        let sender = outboxes
395            .get(&peer_id)
396            .ok_or(NetworkError::PeerNotFound(peer_id))?;
397
398        sender
399            .send(RaftMessage::RequestVote {
400                term,
401                candidate_id,
402                last_log_index,
403                last_log_term,
404            })
405            .await
406            .map_err(|e| NetworkError::TransportError(e.to_string()))?;
407
408        Ok(None) // Response comes via inbox
409    }
410
411    async fn send_append_entries(
412        &self,
413        peer_id: NodeId,
414        term: Term,
415        leader_id: NodeId,
416        prev_log_index: LogIndex,
417        prev_log_term: Term,
418        entries: Vec<LogEntry<T>>,
419        leader_commit: LogIndex,
420    ) -> Result<Option<RaftMessage<T>>, NetworkError> {
421        let outboxes = self.outboxes.read().await;
422        let sender = outboxes
423            .get(&peer_id)
424            .ok_or(NetworkError::PeerNotFound(peer_id))?;
425
426        sender
427            .send(RaftMessage::AppendEntries {
428                term,
429                leader_id,
430                prev_log_index,
431                prev_log_term,
432                entries,
433                leader_commit,
434            })
435            .await
436            .map_err(|e| NetworkError::TransportError(e.to_string()))?;
437
438        Ok(None)
439    }
440
441    async fn send_install_snapshot(
442        &self,
443        peer_id: NodeId,
444        term: Term,
445        leader_id: NodeId,
446        snapshot: Snapshot<T>,
447    ) -> Result<Option<RaftMessage<T>>, NetworkError> {
448        let outboxes = self.outboxes.read().await;
449        let sender = outboxes
450            .get(&peer_id)
451            .ok_or(NetworkError::PeerNotFound(peer_id))?;
452
453        sender
454            .send(RaftMessage::InstallSnapshot {
455                term,
456                leader_id,
457                snapshot,
458            })
459            .await
460            .map_err(|e| NetworkError::TransportError(e.to_string()))?;
461
462        Ok(None)
463    }
464
465    async fn receive(&self) -> Result<RaftMessage<T>, NetworkError> {
466        let mut inbox = self.inbox.lock().await;
467        inbox
468            .recv()
469            .await
470            .ok_or(NetworkError::TransportError("Channel closed".into()))
471    }
472
473    async fn respond(&self, to: NodeId, message: RaftMessage<T>) -> Result<(), NetworkError> {
474        let outboxes = self.outboxes.read().await;
475        let sender = outboxes.get(&to).ok_or(NetworkError::PeerNotFound(to))?;
476
477        sender
478            .send(message)
479            .await
480            .map_err(|e| NetworkError::TransportError(e.to_string()))
481    }
482
483    fn peer_ids(&self) -> Vec<NodeId> {
484        // Blocking read - only for simple cases
485        Vec::new()
486    }
487
488    async fn update_peers(&self, peers: Vec<PeerInfo>) -> Result<(), NetworkError> {
489        let mut peer_map = self.peers.write().await;
490        peer_map.clear();
491        for peer in peers {
492            peer_map.insert(peer.id, peer);
493        }
494        Ok(())
495    }
496}
497
498// ============================================================================
499// UDP TRANSPORT (legacy)
500// ============================================================================
501
502#[cfg(feature = "net")]
503pub mod udp {
504    use super::*;
505    use tokio::net::UdpSocket;
506    use tokio::sync::RwLock;
507
    /// UDP-based network transport for consensus algorithms.
    ///
    /// Messages are JSON-serialized [`Packet`]s sent as individual datagrams;
    /// delivery is best-effort, as is inherent to UDP.
    pub struct UdpTransport {
        /// Shared async socket used for both sending and receiving.
        socket: Arc<UdpSocket>,
        /// Peer addresses (strings passed to `send_to`) to broadcast to.
        peers: Arc<RwLock<Vec<String>>>,
        /// Buffer-size and timeout settings for this transport.
        config: NetworkConfig,
    }
514
515    impl std::fmt::Debug for UdpTransport {
516        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517            f.debug_struct("UdpTransport")
518                .field("peers", &self.peers)
519                .field("config", &self.config)
520                .finish_non_exhaustive()
521        }
522    }
523
524    impl UdpTransport {
525        /// Creates a new UDP transport with default configuration.
526        pub async fn new(bind_addr: &str, peers: Vec<String>) -> Result<Self, String> {
527            Self::with_config(bind_addr, peers, NetworkConfig::default()).await
528        }
529
530        /// Creates a new UDP transport with custom configuration.
531        pub async fn with_config(
532            bind_addr: &str,
533            peers: Vec<String>,
534            config: NetworkConfig,
535        ) -> Result<Self, String> {
536            let socket = UdpSocket::bind(bind_addr)
537                .await
538                .map_err(|e| format!("Failed to bind UDP socket to '{}': {}", bind_addr, e))?;
539
540            tracing::info!(
541                bind_addr = bind_addr,
542                peer_count = peers.len(),
543                buffer_size = config.buffer_size,
544                "Async UDP transport initialized"
545            );
546
547            Ok(Self {
548                socket: Arc::new(socket),
549                peers: Arc::new(RwLock::new(peers)),
550                config,
551            })
552        }
553    }
554
555    #[async_trait]
556    impl ConsensusNetwork for UdpTransport {
557        async fn broadcast_vote_request(
558            &self,
559            term: Term,
560            candidate_id: NodeId,
561        ) -> Result<(), String> {
562            let packet = Packet::VoteRequest { term, candidate_id };
563            let serialized = serde_json::to_vec(&packet).map_err(|e| e.to_string())?;
564
565
566            let peers = self.peers.read().await;
567            let mut last_error = None;
568            let mut success_count = 0;
569
570            for peer in peers.iter() {
571                match self.socket.send_to(&serialized, peer).await {
572                    Ok(_) => success_count += 1,
573                    Err(e) => {
574                        tracing::warn!("Failed to send to {}: {}", peer, e);
575                        last_error = Some(e);
576                    }
577                }
578            }
579            
580            if success_count == 0 && !peers.is_empty() {
581                if let Some(e) = last_error {
582                    return Err(format!("Failed to broadcast vote request to any peer: {}", e));
583                }
584            }
585            Ok(())
586        }
587
588        async fn send_heartbeat(&self, leader_id: NodeId, term: Term) -> Result<(), String> {
589            let packet = Packet::Heartbeat { leader_id, term };
590            let serialized = serde_json::to_vec(&packet).map_err(|e| e.to_string())?;
591
592            let peers = self.peers.read().await;
593            let mut last_error = None;
594            let mut success_count = 0;
595
596            for peer in peers.iter() {
597                match self.socket.send_to(&serialized, peer).await {
598                    Ok(_) => success_count += 1,
599                    Err(e) => {
600                         // Heartbeat failures are common, debug log only
601                        tracing::debug!("Failed to heartbeat {}: {}", peer, e);
602                        last_error = Some(e);
603                    }
604                }
605            }
606             if success_count == 0 && !peers.is_empty() {
607                if let Some(e) = last_error {
608                     return Err(format!("Failed to send heartbeats to any peer: {}", e));
609                }
610            }
611            Ok(())
612        }
613
614        async fn receive(&self) -> Result<Packet, String> {
615            let mut buf = vec![0u8; self.config.buffer_size];
616
617            loop {
618                match tokio::time::timeout(
619                    self.config.read_timeout,
620                    self.socket.recv_from(&mut buf),
621                )
622                .await
623                {
624                    Ok(io_result) => match io_result {
625                        Ok((amt, _src)) => {
626                            let packet: Packet =
627                                serde_json::from_slice(&buf[..amt]).map_err(|e| e.to_string())?;
628                            return Ok(packet);
629                        }
630                        Err(e) => {
631                            tracing::error!("UDP receive IO error: {}", e);
632                            tokio::time::sleep(Duration::from_millis(50)).await;
633                        }
634                    },
635                    Err(_) => {
636                        // Timeout - continue loop
637                    }
638                }
639            }
640        }
641
642        async fn update_peers(&self, new_peers: Vec<String>) -> Result<(), String> {
643            let mut peers = self.peers.write().await;
644            *peers = new_peers;
645            tracing::info!(peer_count = peers.len(), "Updated ConsensusNetwork peers");
646            Ok(())
647        }
648    }
649}
650
651// ============================================================================
652// TESTS
653// ============================================================================
654
655#[cfg(test)]
656mod tests {
657    use super::*;
658
659    #[test]
660    fn test_network_config_default() {
661        let config = NetworkConfig::default();
662        assert_eq!(config.buffer_size, MAX_PACKET_SIZE);
663        assert_eq!(config.read_timeout, DEFAULT_READ_TIMEOUT);
664    }
665
666    #[test]
667    fn test_network_config_clamps_buffer_size() {
668        let config = NetworkConfig::new(100_000, Duration::from_millis(50));
669        assert_eq!(config.buffer_size, MAX_PACKET_SIZE);
670    }
671
672    #[test]
673    fn test_packet_serialization() {
674        let packet = Packet::VoteRequest {
675            term: 1,
676            candidate_id: 42,
677        };
678        let serialized = serde_json::to_vec(&packet).unwrap();
679        let deserialized: Packet = serde_json::from_slice(&serialized).unwrap();
680
681        match deserialized {
682            Packet::VoteRequest { term, candidate_id } => {
683                assert_eq!(term, 1);
684                assert_eq!(candidate_id, 42);
685            }
686            _ => panic!("Wrong packet type"),
687        }
688    }
689
    // UDP transport smoke tests; compiled only with the `net` feature.
    #[cfg(feature = "net")]
    mod udp_tests {
        use super::super::udp::UdpTransport;
        use super::super::*;

        // Binding to a malformed address must surface an error.
        #[tokio::test]
        async fn test_invalid_bind_address() {
            let result = UdpTransport::new("not-an-address", vec![]).await;
            assert!(result.is_err());
        }

        // Port 0 lets the OS pick any free port, so binding should succeed.
        #[tokio::test]
        async fn test_valid_bind_address() {
            let result = UdpTransport::new("127.0.0.1:0", vec![]).await;
            assert!(result.is_ok());
        }
    }
707}