aegis_replication/
transport.rs

1//! Aegis Replication Transport
2//!
3//! Network transport layer for Raft message passing.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::node::NodeId;
9use crate::raft::{
10    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
11    VoteRequest, VoteResponse,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
17// =============================================================================
18// Message Type
19// =============================================================================
20
/// Types of messages in the Raft protocol.
///
/// Stored in [`Message::message_type`] as a small `Copy` tag next to the
/// richer [`MessagePayload`]. The variant names mirror the payload variants;
/// `MessagePayload` additionally has an `Empty` variant with no tag here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    /// Tag set by [`Message::vote_request`].
    VoteRequest,
    /// Tag set by [`Message::vote_response`].
    VoteResponse,
    /// Tag set by [`Message::append_entries`].
    AppendEntries,
    /// Tag set by [`Message::append_entries_response`].
    AppendEntriesResponse,
    /// Snapshot installation request; presumably paired with
    /// `MessagePayload::InstallSnapshot` (no dedicated constructor is visible here).
    InstallSnapshot,
    /// Snapshot installation reply; presumably paired with
    /// `MessagePayload::InstallSnapshotResponse`.
    InstallSnapshotResponse,
    /// Client-originated request; presumably paired with `MessagePayload::ClientRequest`.
    ClientRequest,
    /// Reply to a client request; presumably paired with `MessagePayload::ClientResponse`.
    ClientResponse,
    /// Tag set by [`Message::heartbeat`].
    Heartbeat,
}
34
35// =============================================================================
36// Message
37// =============================================================================
38
/// A message in the Raft protocol.
///
/// An envelope pairing routing metadata (`from`, `to`, `term`) with a typed
/// payload. The serde derives are what `Message::to_bytes` and
/// `Message::from_bytes` rely on to encode and decode the message as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Tag describing which kind of message this is.
    pub message_type: MessageType,
    /// Sender id, exactly as supplied by the caller (not validated here).
    pub from: NodeId,
    /// Destination id. `InMemoryTransport::send` uses it to pick the inbox and
    /// `InMemoryTransport::broadcast` overwrites it for each peer.
    pub to: NodeId,
    /// Raft term. The typed constructors copy it from the wrapped
    /// request/response; `heartbeat` and `new` take it as an argument.
    pub term: u64,
    /// The message body.
    pub payload: MessagePayload,
    /// Creation time stamped by `Message::new`: milliseconds since the Unix epoch.
    pub timestamp: u64,
}
49
50impl Message {
51    /// Create a new message.
52    pub fn new(
53        message_type: MessageType,
54        from: NodeId,
55        to: NodeId,
56        term: u64,
57        payload: MessagePayload,
58    ) -> Self {
59        Self {
60            message_type,
61            from,
62            to,
63            term,
64            payload,
65            timestamp: current_timestamp(),
66        }
67    }
68
69    /// Create a vote request message.
70    pub fn vote_request(from: NodeId, to: NodeId, request: VoteRequest) -> Self {
71        Self::new(
72            MessageType::VoteRequest,
73            from,
74            to,
75            request.term,
76            MessagePayload::VoteRequest(request),
77        )
78    }
79
80    /// Create a vote response message.
81    pub fn vote_response(from: NodeId, to: NodeId, response: VoteResponse) -> Self {
82        Self::new(
83            MessageType::VoteResponse,
84            from,
85            to,
86            response.term,
87            MessagePayload::VoteResponse(response),
88        )
89    }
90
91    /// Create an append entries message.
92    pub fn append_entries(from: NodeId, to: NodeId, request: AppendEntriesRequest) -> Self {
93        Self::new(
94            MessageType::AppendEntries,
95            from,
96            to,
97            request.term,
98            MessagePayload::AppendEntries(request),
99        )
100    }
101
102    /// Create an append entries response message.
103    pub fn append_entries_response(
104        from: NodeId,
105        to: NodeId,
106        response: AppendEntriesResponse,
107    ) -> Self {
108        Self::new(
109            MessageType::AppendEntriesResponse,
110            from,
111            to,
112            response.term,
113            MessagePayload::AppendEntriesResponse(response),
114        )
115    }
116
117    /// Create a heartbeat message.
118    pub fn heartbeat(from: NodeId, to: NodeId, term: u64) -> Self {
119        Self::new(
120            MessageType::Heartbeat,
121            from,
122            to,
123            term,
124            MessagePayload::Heartbeat,
125        )
126    }
127
128    /// Serialize the message.
129    pub fn to_bytes(&self) -> Vec<u8> {
130        serde_json::to_vec(self).unwrap_or_default()
131    }
132
133    /// Deserialize a message.
134    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
135        serde_json::from_slice(bytes).ok()
136    }
137}
138
139// =============================================================================
140// Message Payload
141// =============================================================================
142
/// Payload of a Raft message.
///
/// Each data-carrying variant wraps the corresponding request/response type.
/// The serde derives make the payload part of the JSON produced by
/// `Message::to_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessagePayload {
    /// Body of a message built by `Message::vote_request`.
    VoteRequest(VoteRequest),
    /// Body of a message built by `Message::vote_response`.
    VoteResponse(VoteResponse),
    /// Body of a message built by `Message::append_entries`.
    AppendEntries(AppendEntriesRequest),
    /// Body of a message built by `Message::append_entries_response`.
    AppendEntriesResponse(AppendEntriesResponse),
    /// Snapshot installation request body.
    InstallSnapshot(InstallSnapshotRequest),
    /// Snapshot installation response body.
    InstallSnapshotResponse(InstallSnapshotResponse),
    /// A request submitted by a client.
    ClientRequest(ClientRequest),
    /// The reply to a client request.
    ClientResponse(ClientResponse),
    /// Data-less body of a message built by `Message::heartbeat`.
    Heartbeat,
    /// No payload. NOTE(review): not constructed anywhere in the code visible
    /// here and has no matching `MessageType` tag — confirm intended use.
    Empty,
}
157
158// =============================================================================
159// Client Request/Response
160// =============================================================================
161
/// A client request.
///
/// Carried inside `MessagePayload::ClientRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRequest {
    /// Caller-chosen identifier; `ClientResponse` has a field of the same name,
    /// presumably so replies can be matched to requests.
    pub request_id: String,
    /// The operation the client asks for.
    pub operation: ClientOperation,
}
168
/// Client operations.
///
/// The variant names suggest a key-value store; how each operation is applied
/// is not defined in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientOperation {
    /// Operation addressed at `key` that carries no value.
    Get { key: String },
    /// Operation addressed at `key` that carries a raw byte `value`.
    Set { key: String, value: Vec<u8> },
    /// Operation addressed at `key` that carries no value.
    Delete { key: String },
}
176
/// A client response.
///
/// Usually built through `ClientResponse::success`, `ClientResponse::error`
/// or `ClientResponse::not_leader`, which keep the fields consistent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientResponse {
    /// Identifier supplied by the caller when building the response.
    pub request_id: String,
    /// `true` only for responses built with `success`.
    pub success: bool,
    /// Optional result bytes; always `None` for `error` and `not_leader`.
    pub value: Option<Vec<u8>>,
    /// Failure description; `None` on success, `"Not the leader"` for `not_leader`.
    pub error: Option<String>,
    /// Leader id passed to `not_leader`, if any; `None` for the other constructors.
    pub leader_hint: Option<NodeId>,
}
186
187impl ClientResponse {
188    pub fn success(request_id: String, value: Option<Vec<u8>>) -> Self {
189        Self {
190            request_id,
191            success: true,
192            value,
193            error: None,
194            leader_hint: None,
195        }
196    }
197
198    pub fn error(request_id: String, error: impl Into<String>) -> Self {
199        Self {
200            request_id,
201            success: false,
202            value: None,
203            error: Some(error.into()),
204            leader_hint: None,
205        }
206    }
207
208    pub fn not_leader(request_id: String, leader: Option<NodeId>) -> Self {
209        Self {
210            request_id,
211            success: false,
212            value: None,
213            error: Some("Not the leader".to_string()),
214            leader_hint: leader,
215        }
216    }
217}
218
219// =============================================================================
220// Transport Trait
221// =============================================================================
222
/// Transport layer for Raft communication.
///
/// Implementations must be `Send + Sync`, and every method takes `&self`, so
/// any state an implementation mutates has to sit behind interior mutability.
pub trait Transport: Send + Sync {
    /// Send a message to a node.
    ///
    /// The destination is taken from `message.to`. Returns `Err` when the
    /// message could not be handed over (the in-memory implementation reports
    /// `TransportError::ConnectionFailed` for an unknown destination).
    fn send(&self, message: Message) -> Result<(), TransportError>;

    /// Receive a message (blocking).
    ///
    /// Waits until a message addressed to this transport's node is available.
    fn recv(&self) -> Result<Message, TransportError>;

    /// Try to receive a message (non-blocking).
    ///
    /// Returns `None` when no message is currently waiting.
    fn try_recv(&self) -> Option<Message>;

    /// Broadcast a message to all peers.
    ///
    /// Returns the individual send results. The in-memory implementation
    /// yields one result per entry of `peers`, in order, and replaces
    /// `message.to` with each peer before sending.
    fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>>;
}
237
238// =============================================================================
239// Transport Error
240// =============================================================================
241
/// Errors that can occur during transport.
///
/// Implements `Display` (human-readable text) and `std::error::Error`, so it
/// can be boxed or propagated with `?` like any other error type.
#[derive(Debug, Clone)]
pub enum TransportError {
    /// The destination could not be reached; the payload names the target.
    ConnectionFailed(String),
    /// An operation did not complete in time.
    Timeout,
    /// The peer or channel is no longer connected.
    Disconnected,
    /// Encoding or decoding failed; the payload describes why.
    SerializationError(String),
    /// The receiving side has no room for more messages.
    ChannelFull,
    /// Any other failure; the payload describes it.
    Unknown(String),
}

impl std::fmt::Display for TransportError {
    /// Render the error as a short English phrase, appending the payload
    /// after `": "` for the variants that carry one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Payload-free variants are fixed text.
            Self::Timeout => f.write_str("Timeout"),
            Self::Disconnected => f.write_str("Disconnected"),
            Self::ChannelFull => f.write_str("Channel full"),
            // Variants with a detail string.
            Self::ConnectionFailed(target) => write!(f, "Connection failed: {}", target),
            Self::SerializationError(reason) => write!(f, "Serialization error: {}", reason),
            Self::Unknown(reason) => write!(f, "Unknown error: {}", reason),
        }
    }
}

impl std::error::Error for TransportError {}
267
268// =============================================================================
269// In-Memory Transport (for testing)
270// =============================================================================
271
272/// In-memory transport for testing.
273pub struct InMemoryTransport {
274    node_id: NodeId,
275    inboxes: Arc<RwLock<HashMap<NodeId, Vec<Message>>>>,
276}
277
278impl InMemoryTransport {
279    /// Create a new in-memory transport network.
280    pub fn new_network(nodes: &[NodeId]) -> HashMap<NodeId, Self> {
281        let inboxes = Arc::new(RwLock::new(HashMap::new()));
282
283        for node in nodes {
284            inboxes.write().unwrap().insert(node.clone(), Vec::new());
285        }
286
287        nodes
288            .iter()
289            .map(|id| {
290                (
291                    id.clone(),
292                    Self {
293                        node_id: id.clone(),
294                        inboxes: Arc::clone(&inboxes),
295                    },
296                )
297            })
298            .collect()
299    }
300
301    /// Create a single transport (for testing).
302    pub fn new(node_id: NodeId) -> Self {
303        let inboxes = Arc::new(RwLock::new(HashMap::new()));
304        inboxes.write().unwrap().insert(node_id.clone(), Vec::new());
305        Self { node_id, inboxes }
306    }
307}
308
309impl Transport for InMemoryTransport {
310    fn send(&self, message: Message) -> Result<(), TransportError> {
311        let mut inboxes = self.inboxes.write().unwrap();
312        if let Some(inbox) = inboxes.get_mut(&message.to) {
313            inbox.push(message);
314            Ok(())
315        } else {
316            Err(TransportError::ConnectionFailed(message.to.to_string()))
317        }
318    }
319
320    fn recv(&self) -> Result<Message, TransportError> {
321        loop {
322            if let Some(msg) = self.try_recv() {
323                return Ok(msg);
324            }
325            std::thread::sleep(std::time::Duration::from_millis(1));
326        }
327    }
328
329    fn try_recv(&self) -> Option<Message> {
330        let mut inboxes = self.inboxes.write().unwrap();
331        if let Some(inbox) = inboxes.get_mut(&self.node_id) {
332            if !inbox.is_empty() {
333                return Some(inbox.remove(0));
334            }
335        }
336        None
337    }
338
339    fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>> {
340        peers
341            .iter()
342            .map(|peer| {
343                let mut msg = message.clone();
344                msg.to = peer.clone();
345                self.send(msg)
346            })
347            .collect()
348    }
349}
350
351// =============================================================================
352// Connection Pool
353// =============================================================================
354
/// Connection pool for managing peer connections.
///
/// Tracks one `ConnectionState` record per peer. The pool only keeps
/// bookkeeping data; no socket or other I/O handle is stored in it.
pub struct ConnectionPool {
    /// Per-peer records, behind a lock so the pool can be used through `&self`.
    connections: RwLock<HashMap<NodeId, ConnectionState>>,
    /// Upper bound on the number of tracked peers, enforced by `add`.
    max_connections: usize,
}
360
/// State of a connection.
///
/// `ConnectionPool::get` hands out clones of this record, so a returned value
/// is a snapshot and does not follow later changes in the pool.
#[derive(Debug, Clone)]
pub struct ConnectionState {
    /// Peer this record belongs to (same value as its key in the pool).
    pub node_id: NodeId,
    /// Address string supplied to `ConnectionPool::add`; stored verbatim.
    pub address: String,
    /// `false` after `add` and `mark_disconnected`, `true` after `mark_connected`.
    pub connected: bool,
    /// Milliseconds since the Unix epoch, refreshed by `add` and
    /// `mark_connected`; `mark_disconnected` leaves it untouched.
    pub last_activity: u64,
    /// Incremented by `mark_disconnected`, reset to 0 by `mark_connected`.
    pub retry_count: u32,
}
370
371impl ConnectionPool {
372    /// Create a new connection pool.
373    pub fn new(max_connections: usize) -> Self {
374        Self {
375            connections: RwLock::new(HashMap::new()),
376            max_connections,
377        }
378    }
379
380    /// Add a connection.
381    pub fn add(&self, node_id: NodeId, address: String) {
382        let mut conns = self.connections.write().unwrap();
383        if conns.len() < self.max_connections {
384            conns.insert(
385                node_id.clone(),
386                ConnectionState {
387                    node_id,
388                    address,
389                    connected: false,
390                    last_activity: current_timestamp(),
391                    retry_count: 0,
392                },
393            );
394        }
395    }
396
397    /// Remove a connection.
398    pub fn remove(&self, node_id: &NodeId) {
399        self.connections.write().unwrap().remove(node_id);
400    }
401
402    /// Get a connection.
403    pub fn get(&self, node_id: &NodeId) -> Option<ConnectionState> {
404        self.connections.read().unwrap().get(node_id).cloned()
405    }
406
407    /// Mark a connection as connected.
408    pub fn mark_connected(&self, node_id: &NodeId) {
409        if let Some(conn) = self.connections.write().unwrap().get_mut(node_id) {
410            conn.connected = true;
411            conn.last_activity = current_timestamp();
412            conn.retry_count = 0;
413        }
414    }
415
416    /// Mark a connection as disconnected.
417    pub fn mark_disconnected(&self, node_id: &NodeId) {
418        if let Some(conn) = self.connections.write().unwrap().get_mut(node_id) {
419            conn.connected = false;
420            conn.retry_count += 1;
421        }
422    }
423
424    /// Get all connected nodes.
425    pub fn connected_nodes(&self) -> Vec<NodeId> {
426        self.connections
427            .read()
428            .unwrap()
429            .values()
430            .filter(|c| c.connected)
431            .map(|c| c.node_id.clone())
432            .collect()
433    }
434
435    /// Get connection count.
436    pub fn len(&self) -> usize {
437        self.connections.read().unwrap().len()
438    }
439
440    /// Check if pool is empty.
441    pub fn is_empty(&self) -> bool {
442        self.len() == 0
443    }
444}
445
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
452
453// =============================================================================
454// Tests
455// =============================================================================
456
#[cfg(test)]
mod tests {
    use super::*;

    /// A message survives a JSON round trip with its envelope fields intact.
    #[test]
    fn test_message_serialization() {
        let request = VoteRequest {
            term: 1,
            candidate_id: NodeId::new("node1"),
            last_log_index: 0,
            last_log_term: 0,
        };

        let msg = Message::vote_request(NodeId::new("node1"), NodeId::new("node2"), request);

        let bytes = msg.to_bytes();
        assert!(!bytes.is_empty());
        let restored = Message::from_bytes(&bytes).unwrap();

        assert_eq!(restored.message_type, MessageType::VoteRequest);
        assert_eq!(restored.from.as_str(), "node1");
        assert_eq!(restored.to.as_str(), "node2");
        assert_eq!(restored.term, 1);
        assert_eq!(restored.timestamp, msg.timestamp);
        assert!(matches!(restored.payload, MessagePayload::VoteRequest(_)));
    }

    /// Bytes that are not a serialized message decode to `None`.
    #[test]
    fn test_from_bytes_rejects_invalid_input() {
        assert!(Message::from_bytes(b"not json").is_none());
        assert!(Message::from_bytes(&[]).is_none());
    }

    /// A message sent by one node shows up in the destination's inbox only.
    #[test]
    fn test_in_memory_transport() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), 1);
        t1.send(msg).unwrap();

        assert!(t1.try_recv().is_none());

        let received = t2.try_recv().unwrap();
        assert_eq!(received.message_type, MessageType::Heartbeat);
        assert_eq!(received.from.as_str(), "node1");

        // The inbox is drained after the single message was taken.
        assert!(t2.try_recv().is_none());
    }

    /// Sending to a node that is not part of the network fails.
    #[test]
    fn test_send_to_unknown_node_fails() {
        let transport = InMemoryTransport::new(NodeId::new("node1"));

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("ghost"), 1);
        let result = transport.send(msg);

        assert!(matches!(result, Err(TransportError::ConnectionFailed(_))));
        assert!(transport.try_recv().is_none());
    }

    /// Messages are received in the order they were sent.
    #[test]
    fn test_messages_are_received_in_send_order() {
        let transport = InMemoryTransport::new(NodeId::new("node1"));

        for term in 1..=3 {
            let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node1"), term);
            transport.send(msg).unwrap();
        }

        for term in 1..=3 {
            assert_eq!(transport.try_recv().unwrap().term, term);
        }
        assert!(transport.try_recv().is_none());
    }

    /// Broadcast reports success per peer AND actually delivers a copy to
    /// each of them, with `to` rewritten to the receiving peer.
    #[test]
    fn test_broadcast() {
        let nodes = vec![
            NodeId::new("node1"),
            NodeId::new("node2"),
            NodeId::new("node3"),
        ];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node1"), 1);
        let peers = vec![NodeId::new("node2"), NodeId::new("node3")];
        let results = t1.broadcast(msg, &peers);

        assert_eq!(results.len(), peers.len());
        assert!(results.iter().all(|r| r.is_ok()));

        for peer in &peers {
            let received = transports[peer]
                .try_recv()
                .expect("broadcast should reach every peer");
            assert_eq!(received.message_type, MessageType::Heartbeat);
            assert_eq!(received.from.as_str(), "node1");
            assert_eq!(received.to.as_str(), peer.as_str());
        }

        // The sender is not among the peers and must not receive a copy.
        assert!(t1.try_recv().is_none());
    }

    /// Connect / disconnect bookkeeping on tracked peers.
    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(10);
        assert!(pool.is_empty());

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 2);
        assert!(pool.connected_nodes().is_empty());

        pool.mark_connected(&NodeId::new("node1"));
        let connected = pool.connected_nodes();
        assert_eq!(connected.len(), 1);
        assert_eq!(connected[0].as_str(), "node1");

        pool.mark_disconnected(&NodeId::new("node1"));
        let state = pool.get(&NodeId::new("node1")).unwrap();
        assert!(!state.connected);
        assert_eq!(state.retry_count, 1);

        // Reconnecting clears the retry counter again.
        pool.mark_connected(&NodeId::new("node1"));
        let state = pool.get(&NodeId::new("node1")).unwrap();
        assert!(state.connected);
        assert_eq!(state.retry_count, 0);
    }

    /// A full pool does not accept additional peers.
    #[test]
    fn test_connection_pool_respects_capacity() {
        let pool = ConnectionPool::new(1);

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 1);
        assert!(pool.get(&NodeId::new("node1")).is_some());
        assert!(pool.get(&NodeId::new("node2")).is_none());

        pool.remove(&NodeId::new("node1"));
        assert!(pool.is_empty());
    }

    /// Each constructor fills exactly the fields it is responsible for.
    #[test]
    fn test_client_response() {
        let success = ClientResponse::success("req1".to_string(), Some(b"value".to_vec()));
        assert!(success.success);
        assert_eq!(success.value, Some(b"value".to_vec()));
        assert_eq!(success.error, None);

        let error = ClientResponse::error("req2".to_string(), "failed");
        assert!(!error.success);
        assert_eq!(error.error, Some("failed".to_string()));
        assert_eq!(error.value, None);

        let not_leader =
            ClientResponse::not_leader("req3".to_string(), Some(NodeId::new("leader")));
        assert!(!not_leader.success);
        assert_eq!(not_leader.error, Some("Not the leader".to_string()));
        assert_eq!(not_leader.leader_hint, Some(NodeId::new("leader")));
    }
}