Skip to main content

aegis_replication/
transport.rs

1//! Aegis Replication Transport
2//!
3//! Network transport layer for Raft message passing.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::node::NodeId;
9use crate::raft::{
10    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
11    VoteRequest, VoteResponse,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
17// =============================================================================
18// Message Type
19// =============================================================================
20
21/// Types of messages in the Raft protocol.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub enum MessageType {
24    VoteRequest,
25    VoteResponse,
26    AppendEntries,
27    AppendEntriesResponse,
28    InstallSnapshot,
29    InstallSnapshotResponse,
30    ClientRequest,
31    ClientResponse,
32    Heartbeat,
33}
34
35// =============================================================================
36// Message
37// =============================================================================
38
39/// A message in the Raft protocol.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Message {
42    pub message_type: MessageType,
43    pub from: NodeId,
44    pub to: NodeId,
45    pub term: u64,
46    pub payload: MessagePayload,
47    pub timestamp: u64,
48}
49
50impl Message {
51    /// Create a new message.
52    pub fn new(
53        message_type: MessageType,
54        from: NodeId,
55        to: NodeId,
56        term: u64,
57        payload: MessagePayload,
58    ) -> Self {
59        Self {
60            message_type,
61            from,
62            to,
63            term,
64            payload,
65            timestamp: current_timestamp(),
66        }
67    }
68
69    /// Create a vote request message.
70    pub fn vote_request(from: NodeId, to: NodeId, request: VoteRequest) -> Self {
71        Self::new(
72            MessageType::VoteRequest,
73            from,
74            to,
75            request.term,
76            MessagePayload::VoteRequest(request),
77        )
78    }
79
80    /// Create a vote response message.
81    pub fn vote_response(from: NodeId, to: NodeId, response: VoteResponse) -> Self {
82        Self::new(
83            MessageType::VoteResponse,
84            from,
85            to,
86            response.term,
87            MessagePayload::VoteResponse(response),
88        )
89    }
90
91    /// Create an append entries message.
92    pub fn append_entries(from: NodeId, to: NodeId, request: AppendEntriesRequest) -> Self {
93        Self::new(
94            MessageType::AppendEntries,
95            from,
96            to,
97            request.term,
98            MessagePayload::AppendEntries(request),
99        )
100    }
101
102    /// Create an append entries response message.
103    pub fn append_entries_response(
104        from: NodeId,
105        to: NodeId,
106        response: AppendEntriesResponse,
107    ) -> Self {
108        Self::new(
109            MessageType::AppendEntriesResponse,
110            from,
111            to,
112            response.term,
113            MessagePayload::AppendEntriesResponse(response),
114        )
115    }
116
117    /// Create a heartbeat message.
118    pub fn heartbeat(from: NodeId, to: NodeId, term: u64) -> Self {
119        Self::new(
120            MessageType::Heartbeat,
121            from,
122            to,
123            term,
124            MessagePayload::Heartbeat,
125        )
126    }
127
128    /// Serialize the message.
129    pub fn to_bytes(&self) -> Vec<u8> {
130        serde_json::to_vec(self).unwrap_or_default()
131    }
132
133    /// Deserialize a message.
134    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
135        serde_json::from_slice(bytes).ok()
136    }
137}
138
139// =============================================================================
140// Message Payload
141// =============================================================================
142
143/// Payload of a Raft message.
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub enum MessagePayload {
146    VoteRequest(VoteRequest),
147    VoteResponse(VoteResponse),
148    AppendEntries(AppendEntriesRequest),
149    AppendEntriesResponse(AppendEntriesResponse),
150    InstallSnapshot(InstallSnapshotRequest),
151    InstallSnapshotResponse(InstallSnapshotResponse),
152    ClientRequest(ClientRequest),
153    ClientResponse(ClientResponse),
154    Heartbeat,
155    Empty,
156}
157
158// =============================================================================
159// Client Request/Response
160// =============================================================================
161
162/// A client request.
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct ClientRequest {
165    pub request_id: String,
166    pub operation: ClientOperation,
167}
168
169/// Client operations.
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub enum ClientOperation {
172    Get { key: String },
173    Set { key: String, value: Vec<u8> },
174    Delete { key: String },
175}
176
177/// A client response.
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ClientResponse {
180    pub request_id: String,
181    pub success: bool,
182    pub value: Option<Vec<u8>>,
183    pub error: Option<String>,
184    pub leader_hint: Option<NodeId>,
185}
186
187impl ClientResponse {
188    pub fn success(request_id: String, value: Option<Vec<u8>>) -> Self {
189        Self {
190            request_id,
191            success: true,
192            value,
193            error: None,
194            leader_hint: None,
195        }
196    }
197
198    pub fn error(request_id: String, error: impl Into<String>) -> Self {
199        Self {
200            request_id,
201            success: false,
202            value: None,
203            error: Some(error.into()),
204            leader_hint: None,
205        }
206    }
207
208    pub fn not_leader(request_id: String, leader: Option<NodeId>) -> Self {
209        Self {
210            request_id,
211            success: false,
212            value: None,
213            error: Some("Not the leader".to_string()),
214            leader_hint: leader,
215        }
216    }
217}
218
219// =============================================================================
220// Transport Trait
221// =============================================================================
222
223/// Transport layer for Raft communication.
224pub trait Transport: Send + Sync {
225    /// Send a message to a node.
226    fn send(&self, message: Message) -> Result<(), TransportError>;
227
228    /// Receive a message (blocking).
229    fn recv(&self) -> Result<Message, TransportError>;
230
231    /// Try to receive a message (non-blocking).
232    fn try_recv(&self) -> Option<Message>;
233
234    /// Broadcast a message to all peers.
235    fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>>;
236}
237
238// =============================================================================
239// Transport Error
240// =============================================================================
241
/// Failure modes reported by a transport.
#[derive(Debug, Clone)]
pub enum TransportError {
    /// A peer could not be reached; the string identifies the peer or address.
    ConnectionFailed(String),
    /// An operation timed out.
    Timeout,
    /// The transport is disconnected.
    Disconnected,
    /// A message could not be encoded or decoded; the string holds the details.
    SerializationError(String),
    /// A channel was full.
    ChannelFull,
    /// Any other failure; the string holds the details.
    Unknown(String),
}

impl std::fmt::Display for TransportError {
    /// Render the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without data have a fixed description.
            Self::Timeout => f.write_str("Timeout"),
            Self::Disconnected => f.write_str("Disconnected"),
            Self::ChannelFull => f.write_str("Channel full"),
            // Variants with data append their detail string to a fixed prefix.
            Self::ConnectionFailed(addr) => write!(f, "Connection failed: {}", addr),
            Self::SerializationError(detail) => write!(f, "Serialization error: {}", detail),
            Self::Unknown(detail) => write!(f, "Unknown error: {}", detail),
        }
    }
}

impl std::error::Error for TransportError {}
267
268// =============================================================================
269// In-Memory Transport (for testing)
270// =============================================================================
271
272/// In-memory transport for testing.
273pub struct InMemoryTransport {
274    node_id: NodeId,
275    inboxes: Arc<RwLock<HashMap<NodeId, Vec<Message>>>>,
276}
277
278impl InMemoryTransport {
279    /// Create a new in-memory transport network.
280    pub fn new_network(nodes: &[NodeId]) -> HashMap<NodeId, Self> {
281        let inboxes = Arc::new(RwLock::new(HashMap::new()));
282
283        for node in nodes {
284            inboxes
285                .write()
286                .expect("transport inboxes lock poisoned")
287                .insert(node.clone(), Vec::new());
288        }
289
290        nodes
291            .iter()
292            .map(|id| {
293                (
294                    id.clone(),
295                    Self {
296                        node_id: id.clone(),
297                        inboxes: Arc::clone(&inboxes),
298                    },
299                )
300            })
301            .collect()
302    }
303
304    /// Create a single transport (for testing).
305    pub fn new(node_id: NodeId) -> Self {
306        let inboxes = Arc::new(RwLock::new(HashMap::new()));
307        inboxes
308            .write()
309            .expect("transport inboxes lock poisoned")
310            .insert(node_id.clone(), Vec::new());
311        Self { node_id, inboxes }
312    }
313}
314
315impl Transport for InMemoryTransport {
316    fn send(&self, message: Message) -> Result<(), TransportError> {
317        let mut inboxes = self
318            .inboxes
319            .write()
320            .expect("transport inboxes lock poisoned");
321        if let Some(inbox) = inboxes.get_mut(&message.to) {
322            inbox.push(message);
323            Ok(())
324        } else {
325            Err(TransportError::ConnectionFailed(message.to.to_string()))
326        }
327    }
328
329    fn recv(&self) -> Result<Message, TransportError> {
330        loop {
331            if let Some(msg) = self.try_recv() {
332                return Ok(msg);
333            }
334            std::thread::sleep(std::time::Duration::from_millis(1));
335        }
336    }
337
338    fn try_recv(&self) -> Option<Message> {
339        let mut inboxes = self
340            .inboxes
341            .write()
342            .expect("transport inboxes lock poisoned");
343        if let Some(inbox) = inboxes.get_mut(&self.node_id) {
344            if !inbox.is_empty() {
345                return Some(inbox.remove(0));
346            }
347        }
348        None
349    }
350
351    fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>> {
352        peers
353            .iter()
354            .map(|peer| {
355                let mut msg = message.clone();
356                msg.to = peer.clone();
357                self.send(msg)
358            })
359            .collect()
360    }
361}
362
363// =============================================================================
364// Connection Pool
365// =============================================================================
366
367/// Connection pool for managing peer connections.
368pub struct ConnectionPool {
369    connections: RwLock<HashMap<NodeId, ConnectionState>>,
370    max_connections: usize,
371}
372
373/// State of a connection.
374#[derive(Debug, Clone)]
375pub struct ConnectionState {
376    pub node_id: NodeId,
377    pub address: String,
378    pub connected: bool,
379    pub last_activity: u64,
380    pub retry_count: u32,
381}
382
383impl ConnectionPool {
384    /// Create a new connection pool.
385    pub fn new(max_connections: usize) -> Self {
386        Self {
387            connections: RwLock::new(HashMap::new()),
388            max_connections,
389        }
390    }
391
392    /// Add a connection.
393    pub fn add(&self, node_id: NodeId, address: String) {
394        let mut conns = self
395            .connections
396            .write()
397            .expect("connection pool lock poisoned");
398        if conns.len() < self.max_connections {
399            conns.insert(
400                node_id.clone(),
401                ConnectionState {
402                    node_id,
403                    address,
404                    connected: false,
405                    last_activity: current_timestamp(),
406                    retry_count: 0,
407                },
408            );
409        }
410    }
411
412    /// Remove a connection.
413    pub fn remove(&self, node_id: &NodeId) {
414        self.connections
415            .write()
416            .expect("connection pool lock poisoned")
417            .remove(node_id);
418    }
419
420    /// Get a connection.
421    pub fn get(&self, node_id: &NodeId) -> Option<ConnectionState> {
422        self.connections
423            .read()
424            .expect("connection pool lock poisoned")
425            .get(node_id)
426            .cloned()
427    }
428
429    /// Mark a connection as connected.
430    pub fn mark_connected(&self, node_id: &NodeId) {
431        if let Some(conn) = self
432            .connections
433            .write()
434            .expect("connection pool lock poisoned")
435            .get_mut(node_id)
436        {
437            conn.connected = true;
438            conn.last_activity = current_timestamp();
439            conn.retry_count = 0;
440        }
441    }
442
443    /// Mark a connection as disconnected.
444    pub fn mark_disconnected(&self, node_id: &NodeId) {
445        if let Some(conn) = self
446            .connections
447            .write()
448            .expect("connection pool lock poisoned")
449            .get_mut(node_id)
450        {
451            conn.connected = false;
452            conn.retry_count += 1;
453        }
454    }
455
456    /// Get all connected nodes.
457    pub fn connected_nodes(&self) -> Vec<NodeId> {
458        self.connections
459            .read()
460            .expect("connection pool lock poisoned")
461            .values()
462            .filter(|c| c.connected)
463            .map(|c| c.node_id.clone())
464            .collect()
465    }
466
467    /// Get connection count.
468    pub fn len(&self) -> usize {
469        self.connections
470            .read()
471            .expect("connection pool lock poisoned")
472            .len()
473    }
474
475    /// Check if pool is empty.
476    pub fn is_empty(&self) -> bool {
477        self.len() == 0
478    }
479}
480
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
487
488// =============================================================================
489// Tests
490// =============================================================================
491
#[cfg(test)]
mod tests {
    use super::*;

    /// A message survives a round trip through `to_bytes` / `from_bytes`.
    #[test]
    fn test_message_serialization() {
        let request = VoteRequest {
            term: 1,
            candidate_id: NodeId::new("node1"),
            last_log_index: 0,
            last_log_term: 0,
        };

        let msg = Message::vote_request(NodeId::new("node1"), NodeId::new("node2"), request);

        let bytes = msg.to_bytes();
        assert!(!bytes.is_empty());
        let restored = Message::from_bytes(&bytes).expect("round trip should decode");

        assert_eq!(restored.message_type, MessageType::VoteRequest);
        assert_eq!(restored.from.as_str(), "node1");
        assert_eq!(restored.to.as_str(), "node2");
        assert_eq!(restored.term, 1);
        assert_eq!(restored.timestamp, msg.timestamp);
    }

    /// Bytes that are not a serialized message decode to `None`.
    #[test]
    fn test_message_from_invalid_bytes() {
        assert!(Message::from_bytes(b"").is_none());
        assert!(Message::from_bytes(b"not json").is_none());
    }

    /// A message sent by one node is received by the addressed node only.
    #[test]
    fn test_in_memory_transport() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), 1);
        t1.send(msg).unwrap();

        let received = t2.try_recv().unwrap();
        assert_eq!(received.message_type, MessageType::Heartbeat);
        assert_eq!(received.from.as_str(), "node1");

        // The message is consumed, and the sender never saw it.
        assert!(t2.try_recv().is_none());
        assert!(t1.try_recv().is_none());
    }

    /// Messages are delivered in the order they were sent.
    #[test]
    fn test_in_memory_transport_fifo_order() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        for term in 1..=3u64 {
            let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), term);
            t1.send(msg).unwrap();
        }

        for term in 1..=3u64 {
            let received = t2.try_recv().expect("a queued message should be available");
            assert_eq!(received.term, term);
        }
        assert!(t2.try_recv().is_none());
    }

    /// Sending to a node outside the network fails and delivers nothing.
    #[test]
    fn test_send_to_unknown_node() {
        let transport = InMemoryTransport::new(NodeId::new("node1"));

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("ghost"), 1);
        let result = transport.send(msg);

        assert!(matches!(result, Err(TransportError::ConnectionFailed(_))));
        assert!(transport.try_recv().is_none());
    }

    /// Broadcast reaches every peer, readdressing each copy to its recipient.
    #[test]
    fn test_broadcast() {
        let nodes = vec![
            NodeId::new("node1"),
            NodeId::new("node2"),
            NodeId::new("node3"),
        ];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node1"), 1);
        let peers = vec![NodeId::new("node2"), NodeId::new("node3")];
        let results = t1.broadcast(msg, &peers);

        assert_eq!(results.len(), peers.len());
        assert!(results.iter().all(|r| r.is_ok()));

        for peer in &peers {
            let received = transports[peer]
                .try_recv()
                .expect("every peer should receive the broadcast");
            assert_eq!(received.message_type, MessageType::Heartbeat);
            assert_eq!(received.from.as_str(), "node1");
            assert_eq!(received.to.as_str(), peer.as_str());
            assert!(transports[peer].try_recv().is_none());
        }

        // The broadcaster is not among the peers, so its own inbox stays empty.
        assert!(t1.try_recv().is_none());
    }

    /// Pool bookkeeping: registration, connect/disconnect transitions, removal.
    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(10);
        assert!(pool.is_empty());

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 2);

        pool.mark_connected(&NodeId::new("node1"));
        let connected = pool.connected_nodes();
        assert_eq!(connected.len(), 1);
        assert_eq!(connected[0].as_str(), "node1");

        pool.mark_disconnected(&NodeId::new("node1"));
        let state = pool.get(&NodeId::new("node1")).unwrap();
        assert!(!state.connected);
        assert_eq!(state.retry_count, 1);

        pool.remove(&NodeId::new("node1"));
        assert!(pool.get(&NodeId::new("node1")).is_none());
        assert_eq!(pool.len(), 1);
    }

    /// A new peer is not stored once the pool has reached its capacity.
    #[test]
    fn test_connection_pool_capacity() {
        let pool = ConnectionPool::new(1);

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 1);
        assert!(pool.get(&NodeId::new("node1")).is_some());
        assert!(pool.get(&NodeId::new("node2")).is_none());
    }

    /// Each `ClientResponse` constructor fills in the expected fields.
    #[test]
    fn test_client_response() {
        let success = ClientResponse::success("req1".to_string(), Some(b"value".to_vec()));
        assert!(success.success);
        assert_eq!(success.value, Some(b"value".to_vec()));
        assert_eq!(success.error, None);

        let error = ClientResponse::error("req2".to_string(), "failed");
        assert!(!error.success);
        assert_eq!(error.error, Some("failed".to_string()));
        assert_eq!(error.leader_hint, None);

        let not_leader =
            ClientResponse::not_leader("req3".to_string(), Some(NodeId::new("leader")));
        assert!(!not_leader.success);
        assert_eq!(not_leader.error, Some("Not the leader".to_string()));
        assert_eq!(not_leader.leader_hint, Some(NodeId::new("leader")));
    }
}
584}