use crate::node::NodeId;
use crate::raft::{
    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
    VoteRequest, VoteResponse,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub enum MessageType {
24 VoteRequest,
25 VoteResponse,
26 AppendEntries,
27 AppendEntriesResponse,
28 InstallSnapshot,
29 InstallSnapshotResponse,
30 ClientRequest,
31 ClientResponse,
32 Heartbeat,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Message {
42 pub message_type: MessageType,
43 pub from: NodeId,
44 pub to: NodeId,
45 pub term: u64,
46 pub payload: MessagePayload,
47 pub timestamp: u64,
48}
49
50impl Message {
51 pub fn new(
53 message_type: MessageType,
54 from: NodeId,
55 to: NodeId,
56 term: u64,
57 payload: MessagePayload,
58 ) -> Self {
59 Self {
60 message_type,
61 from,
62 to,
63 term,
64 payload,
65 timestamp: current_timestamp(),
66 }
67 }
68
69 pub fn vote_request(from: NodeId, to: NodeId, request: VoteRequest) -> Self {
71 Self::new(
72 MessageType::VoteRequest,
73 from,
74 to,
75 request.term,
76 MessagePayload::VoteRequest(request),
77 )
78 }
79
80 pub fn vote_response(from: NodeId, to: NodeId, response: VoteResponse) -> Self {
82 Self::new(
83 MessageType::VoteResponse,
84 from,
85 to,
86 response.term,
87 MessagePayload::VoteResponse(response),
88 )
89 }
90
91 pub fn append_entries(from: NodeId, to: NodeId, request: AppendEntriesRequest) -> Self {
93 Self::new(
94 MessageType::AppendEntries,
95 from,
96 to,
97 request.term,
98 MessagePayload::AppendEntries(request),
99 )
100 }
101
102 pub fn append_entries_response(
104 from: NodeId,
105 to: NodeId,
106 response: AppendEntriesResponse,
107 ) -> Self {
108 Self::new(
109 MessageType::AppendEntriesResponse,
110 from,
111 to,
112 response.term,
113 MessagePayload::AppendEntriesResponse(response),
114 )
115 }
116
117 pub fn heartbeat(from: NodeId, to: NodeId, term: u64) -> Self {
119 Self::new(
120 MessageType::Heartbeat,
121 from,
122 to,
123 term,
124 MessagePayload::Heartbeat,
125 )
126 }
127
128 pub fn to_bytes(&self) -> Vec<u8> {
130 serde_json::to_vec(self).unwrap_or_default()
131 }
132
133 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
135 serde_json::from_slice(bytes).ok()
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
145pub enum MessagePayload {
146 VoteRequest(VoteRequest),
147 VoteResponse(VoteResponse),
148 AppendEntries(AppendEntriesRequest),
149 AppendEntriesResponse(AppendEntriesResponse),
150 InstallSnapshot(InstallSnapshotRequest),
151 InstallSnapshotResponse(InstallSnapshotResponse),
152 ClientRequest(ClientRequest),
153 ClientResponse(ClientResponse),
154 Heartbeat,
155 Empty,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct ClientRequest {
165 pub request_id: String,
166 pub operation: ClientOperation,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub enum ClientOperation {
172 Get { key: String },
173 Set { key: String, value: Vec<u8> },
174 Delete { key: String },
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ClientResponse {
180 pub request_id: String,
181 pub success: bool,
182 pub value: Option<Vec<u8>>,
183 pub error: Option<String>,
184 pub leader_hint: Option<NodeId>,
185}
186
187impl ClientResponse {
188 pub fn success(request_id: String, value: Option<Vec<u8>>) -> Self {
189 Self {
190 request_id,
191 success: true,
192 value,
193 error: None,
194 leader_hint: None,
195 }
196 }
197
198 pub fn error(request_id: String, error: impl Into<String>) -> Self {
199 Self {
200 request_id,
201 success: false,
202 value: None,
203 error: Some(error.into()),
204 leader_hint: None,
205 }
206 }
207
208 pub fn not_leader(request_id: String, leader: Option<NodeId>) -> Self {
209 Self {
210 request_id,
211 success: false,
212 value: None,
213 error: Some("Not the leader".to_string()),
214 leader_hint: leader,
215 }
216 }
217}
218
219pub trait Transport: Send + Sync {
225 fn send(&self, message: Message) -> Result<(), TransportError>;
227
228 fn recv(&self) -> Result<Message, TransportError>;
230
231 fn try_recv(&self) -> Option<Message>;
233
234 fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>>;
236}
237
/// Failure modes reported by [`Transport`] operations.
#[derive(Clone, Debug)]
pub enum TransportError {
    /// The target could not be reached; carries a description of the target
    /// (`InMemoryTransport` uses the recipient's id).
    ConnectionFailed(String),
    /// The operation did not complete in time.
    Timeout,
    /// The underlying link is gone.
    Disconnected,
    /// A message could not be encoded or decoded; carries the reason.
    SerializationError(String),
    /// The outgoing queue had no room for the message.
    ChannelFull,
    /// Any other failure, with a free-form description.
    Unknown(String),
}

impl std::fmt::Display for TransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Split each variant into a fixed label plus optional detail text, then render
        // "label" or "label: detail".
        let (label, detail): (&str, Option<&String>) = match self {
            Self::ConnectionFailed(target) => ("Connection failed", Some(target)),
            Self::Timeout => ("Timeout", None),
            Self::Disconnected => ("Disconnected", None),
            Self::SerializationError(reason) => ("Serialization error", Some(reason)),
            Self::ChannelFull => ("Channel full", None),
            Self::Unknown(reason) => ("Unknown error", Some(reason)),
        };
        match detail {
            Some(text) => write!(f, "{}: {}", label, text),
            None => f.write_str(label),
        }
    }
}

impl std::error::Error for TransportError {}
267
268pub struct InMemoryTransport {
274 node_id: NodeId,
275 inboxes: Arc<RwLock<HashMap<NodeId, Vec<Message>>>>,
276}
277
278impl InMemoryTransport {
279 pub fn new_network(nodes: &[NodeId]) -> HashMap<NodeId, Self> {
281 let inboxes = Arc::new(RwLock::new(HashMap::new()));
282
283 for node in nodes {
284 inboxes.write().unwrap().insert(node.clone(), Vec::new());
285 }
286
287 nodes
288 .iter()
289 .map(|id| {
290 (
291 id.clone(),
292 Self {
293 node_id: id.clone(),
294 inboxes: Arc::clone(&inboxes),
295 },
296 )
297 })
298 .collect()
299 }
300
301 pub fn new(node_id: NodeId) -> Self {
303 let inboxes = Arc::new(RwLock::new(HashMap::new()));
304 inboxes.write().unwrap().insert(node_id.clone(), Vec::new());
305 Self { node_id, inboxes }
306 }
307}
308
309impl Transport for InMemoryTransport {
310 fn send(&self, message: Message) -> Result<(), TransportError> {
311 let mut inboxes = self.inboxes.write().unwrap();
312 if let Some(inbox) = inboxes.get_mut(&message.to) {
313 inbox.push(message);
314 Ok(())
315 } else {
316 Err(TransportError::ConnectionFailed(message.to.to_string()))
317 }
318 }
319
320 fn recv(&self) -> Result<Message, TransportError> {
321 loop {
322 if let Some(msg) = self.try_recv() {
323 return Ok(msg);
324 }
325 std::thread::sleep(std::time::Duration::from_millis(1));
326 }
327 }
328
329 fn try_recv(&self) -> Option<Message> {
330 let mut inboxes = self.inboxes.write().unwrap();
331 if let Some(inbox) = inboxes.get_mut(&self.node_id) {
332 if !inbox.is_empty() {
333 return Some(inbox.remove(0));
334 }
335 }
336 None
337 }
338
339 fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>> {
340 peers
341 .iter()
342 .map(|peer| {
343 let mut msg = message.clone();
344 msg.to = peer.clone();
345 self.send(msg)
346 })
347 .collect()
348 }
349}
350
351pub struct ConnectionPool {
357 connections: RwLock<HashMap<NodeId, ConnectionState>>,
358 max_connections: usize,
359}
360
361#[derive(Debug, Clone)]
363pub struct ConnectionState {
364 pub node_id: NodeId,
365 pub address: String,
366 pub connected: bool,
367 pub last_activity: u64,
368 pub retry_count: u32,
369}
370
371impl ConnectionPool {
372 pub fn new(max_connections: usize) -> Self {
374 Self {
375 connections: RwLock::new(HashMap::new()),
376 max_connections,
377 }
378 }
379
380 pub fn add(&self, node_id: NodeId, address: String) {
382 let mut conns = self.connections.write().unwrap();
383 if conns.len() < self.max_connections {
384 conns.insert(
385 node_id.clone(),
386 ConnectionState {
387 node_id,
388 address,
389 connected: false,
390 last_activity: current_timestamp(),
391 retry_count: 0,
392 },
393 );
394 }
395 }
396
397 pub fn remove(&self, node_id: &NodeId) {
399 self.connections.write().unwrap().remove(node_id);
400 }
401
402 pub fn get(&self, node_id: &NodeId) -> Option<ConnectionState> {
404 self.connections.read().unwrap().get(node_id).cloned()
405 }
406
407 pub fn mark_connected(&self, node_id: &NodeId) {
409 if let Some(conn) = self.connections.write().unwrap().get_mut(node_id) {
410 conn.connected = true;
411 conn.last_activity = current_timestamp();
412 conn.retry_count = 0;
413 }
414 }
415
416 pub fn mark_disconnected(&self, node_id: &NodeId) {
418 if let Some(conn) = self.connections.write().unwrap().get_mut(node_id) {
419 conn.connected = false;
420 conn.retry_count += 1;
421 }
422 }
423
424 pub fn connected_nodes(&self) -> Vec<NodeId> {
426 self.connections
427 .read()
428 .unwrap()
429 .values()
430 .filter(|c| c.connected)
431 .map(|c| c.node_id.clone())
432 .collect()
433 }
434
435 pub fn len(&self) -> usize {
437 self.connections.read().unwrap().len()
438 }
439
440 pub fn is_empty(&self) -> bool {
442 self.len() == 0
443 }
444}
445
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The `u128`
/// millisecond count is clamped to `u64::MAX` instead of being silently truncated
/// by an `as` cast.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // After `min`, the value fits in u64, so the cast below is lossless.
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_serialization() {
        let request = VoteRequest {
            term: 1,
            candidate_id: NodeId::new("node1"),
            last_log_index: 0,
            last_log_term: 0,
        };

        let msg = Message::vote_request(NodeId::new("node1"), NodeId::new("node2"), request);

        let bytes = msg.to_bytes();
        let restored = Message::from_bytes(&bytes).unwrap();

        assert_eq!(restored.message_type, MessageType::VoteRequest);
        assert_eq!(restored.from.as_str(), "node1");
        assert_eq!(restored.to.as_str(), "node2");
        assert_eq!(restored.term, 1);
    }

    #[test]
    fn test_from_bytes_rejects_garbage() {
        assert!(Message::from_bytes(b"not a message").is_none());
        assert!(Message::from_bytes(&[]).is_none());
    }

    #[test]
    fn test_in_memory_transport() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), 1);
        t1.send(msg).unwrap();

        let received = t2.try_recv().unwrap();
        assert_eq!(received.message_type, MessageType::Heartbeat);
        assert_eq!(received.from.as_str(), "node1");
        // The inbox is drained after the single message is taken.
        assert!(t2.try_recv().is_none());
    }

    #[test]
    fn test_in_memory_transport_preserves_order() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);
        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        for term in 1..=3 {
            t1.send(Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), term))
                .unwrap();
        }

        let terms: Vec<u64> = std::iter::from_fn(|| t2.try_recv()).map(|m| m.term).collect();
        assert_eq!(terms, vec![1, 2, 3]);
    }

    #[test]
    fn test_recv_returns_pending_message() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);
        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        t1.send(Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), 7))
            .unwrap();
        let received = t2.recv().unwrap();
        assert_eq!(received.term, 7);
    }

    #[test]
    fn test_send_to_unknown_node_fails() {
        let nodes = vec![NodeId::new("node1")];
        let transports = InMemoryTransport::new_network(&nodes);
        let t1 = &transports[&NodeId::new("node1")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("ghost"), 1);
        match t1.send(msg) {
            Err(TransportError::ConnectionFailed(_)) => {}
            other => panic!("expected ConnectionFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_broadcast() {
        let nodes = vec![
            NodeId::new("node1"),
            NodeId::new("node2"),
            NodeId::new("node3"),
        ];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node1"), 1);
        let peers = vec![NodeId::new("node2"), NodeId::new("node3")];
        let results = t1.broadcast(msg, &peers);

        assert_eq!(results.len(), peers.len());
        assert!(results.iter().all(|r| r.is_ok()));

        // Every peer receives its own copy, re-addressed to itself.
        for peer in &peers {
            let received = transports[peer]
                .try_recv()
                .expect("peer should have received the broadcast");
            assert_eq!(received.to, *peer);
            assert_eq!(received.message_type, MessageType::Heartbeat);
        }
        // The sender was not in `peers`, so nothing was delivered to it.
        assert!(t1.try_recv().is_none());
    }

    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(10);
        assert!(pool.is_empty());

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 2);

        pool.mark_connected(&NodeId::new("node1"));
        let connected = pool.connected_nodes();
        assert_eq!(connected.len(), 1);
        assert_eq!(connected[0].as_str(), "node1");

        pool.mark_disconnected(&NodeId::new("node1"));
        let state = pool.get(&NodeId::new("node1")).unwrap();
        assert!(!state.connected);
        assert_eq!(state.retry_count, 1);

        pool.remove(&NodeId::new("node2"));
        assert_eq!(pool.len(), 1);
        assert!(pool.get(&NodeId::new("node2")).is_none());
    }

    #[test]
    fn test_connection_pool_capacity() {
        let pool = ConnectionPool::new(1);

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        // A new node beyond capacity is ignored.
        assert_eq!(pool.len(), 1);
        assert!(pool.get(&NodeId::new("node2")).is_none());
    }

    #[test]
    fn test_client_response() {
        let success = ClientResponse::success("req1".to_string(), Some(b"value".to_vec()));
        assert!(success.success);
        assert_eq!(success.value, Some(b"value".to_vec()));

        let error = ClientResponse::error("req2".to_string(), "failed");
        assert!(!error.success);
        assert_eq!(error.error, Some("failed".to_string()));

        let not_leader = ClientResponse::not_leader("req3".to_string(), Some(NodeId::new("leader")));
        assert!(!not_leader.success);
        assert_eq!(not_leader.error, Some("Not the leader".to_string()));
        assert_eq!(not_leader.leader_hint, Some(NodeId::new("leader")));
    }
}