1use crate::node::NodeId;
9use crate::raft::{
10 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
11 VoteRequest, VoteResponse,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
/// Discriminant describing which kind of [`Message`] is being sent.
///
/// The typed `Message` constructors pair each variant with the
/// [`MessagePayload`] variant of the matching name; `Message::new` itself
/// does not enforce that pairing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    /// Vote request (built by `Message::vote_request`).
    VoteRequest,
    /// Reply to a vote request (built by `Message::vote_response`).
    VoteResponse,
    /// Append-entries request (built by `Message::append_entries`).
    AppendEntries,
    /// Reply to `AppendEntries` (built by `Message::append_entries_response`).
    AppendEntriesResponse,
    /// Snapshot installation request. NOTE(review): no constructor visible here builds it.
    InstallSnapshot,
    /// Reply to `InstallSnapshot`.
    InstallSnapshotResponse,
    /// Command submitted by a client.
    ClientRequest,
    /// Result returned to a client.
    ClientResponse,
    /// Heartbeat carrying only a term (built by `Message::heartbeat`).
    Heartbeat,
}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Message {
42 pub message_type: MessageType,
43 pub from: NodeId,
44 pub to: NodeId,
45 pub term: u64,
46 pub payload: MessagePayload,
47 pub timestamp: u64,
48}
49
50impl Message {
51 pub fn new(
53 message_type: MessageType,
54 from: NodeId,
55 to: NodeId,
56 term: u64,
57 payload: MessagePayload,
58 ) -> Self {
59 Self {
60 message_type,
61 from,
62 to,
63 term,
64 payload,
65 timestamp: current_timestamp(),
66 }
67 }
68
69 pub fn vote_request(from: NodeId, to: NodeId, request: VoteRequest) -> Self {
71 Self::new(
72 MessageType::VoteRequest,
73 from,
74 to,
75 request.term,
76 MessagePayload::VoteRequest(request),
77 )
78 }
79
80 pub fn vote_response(from: NodeId, to: NodeId, response: VoteResponse) -> Self {
82 Self::new(
83 MessageType::VoteResponse,
84 from,
85 to,
86 response.term,
87 MessagePayload::VoteResponse(response),
88 )
89 }
90
91 pub fn append_entries(from: NodeId, to: NodeId, request: AppendEntriesRequest) -> Self {
93 Self::new(
94 MessageType::AppendEntries,
95 from,
96 to,
97 request.term,
98 MessagePayload::AppendEntries(request),
99 )
100 }
101
102 pub fn append_entries_response(
104 from: NodeId,
105 to: NodeId,
106 response: AppendEntriesResponse,
107 ) -> Self {
108 Self::new(
109 MessageType::AppendEntriesResponse,
110 from,
111 to,
112 response.term,
113 MessagePayload::AppendEntriesResponse(response),
114 )
115 }
116
117 pub fn heartbeat(from: NodeId, to: NodeId, term: u64) -> Self {
119 Self::new(
120 MessageType::Heartbeat,
121 from,
122 to,
123 term,
124 MessagePayload::Heartbeat,
125 )
126 }
127
128 pub fn to_bytes(&self) -> Vec<u8> {
130 serde_json::to_vec(self).unwrap_or_default()
131 }
132
133 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
135 serde_json::from_slice(bytes).ok()
136 }
137}
138
/// Body of a [`Message`].
///
/// Data-carrying variants wrap the corresponding request/response type;
/// `Heartbeat` and `Empty` carry nothing (a heartbeat's term lives in
/// `Message::term`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessagePayload {
    /// Vote request body.
    VoteRequest(VoteRequest),
    /// Vote response body.
    VoteResponse(VoteResponse),
    /// Append-entries request body.
    AppendEntries(AppendEntriesRequest),
    /// Append-entries response body.
    AppendEntriesResponse(AppendEntriesResponse),
    /// Snapshot installation request body.
    InstallSnapshot(InstallSnapshotRequest),
    /// Snapshot installation response body.
    InstallSnapshotResponse(InstallSnapshotResponse),
    /// Client command.
    ClientRequest(ClientRequest),
    /// Reply to a client command.
    ClientResponse(ClientResponse),
    /// Heartbeat marker with no data.
    Heartbeat,
    /// No payload. NOTE(review): no constructor visible here produces it — confirm intended use.
    Empty,
}
157
/// A command submitted by a client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRequest {
    /// Identifier used to correlate this request with a [`ClientResponse`]
    /// (which carries a `request_id` of its own).
    pub request_id: String,
    /// Operation the client wants performed.
    pub operation: ClientOperation,
}

/// Key-value operations a client can request.
///
/// NOTE(review): the exact semantics are defined by whatever applies these
/// operations; nothing in this file interprets them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientOperation {
    /// Look up `key`.
    Get { key: String },
    /// Associate `value` with `key`.
    Set { key: String, value: Vec<u8> },
    /// Remove `key`.
    Delete { key: String },
}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ClientResponse {
180 pub request_id: String,
181 pub success: bool,
182 pub value: Option<Vec<u8>>,
183 pub error: Option<String>,
184 pub leader_hint: Option<NodeId>,
185}
186
187impl ClientResponse {
188 pub fn success(request_id: String, value: Option<Vec<u8>>) -> Self {
189 Self {
190 request_id,
191 success: true,
192 value,
193 error: None,
194 leader_hint: None,
195 }
196 }
197
198 pub fn error(request_id: String, error: impl Into<String>) -> Self {
199 Self {
200 request_id,
201 success: false,
202 value: None,
203 error: Some(error.into()),
204 leader_hint: None,
205 }
206 }
207
208 pub fn not_leader(request_id: String, leader: Option<NodeId>) -> Self {
209 Self {
210 request_id,
211 success: false,
212 value: None,
213 error: Some("Not the leader".to_string()),
214 leader_hint: leader,
215 }
216 }
217}
218
219pub trait Transport: Send + Sync {
225 fn send(&self, message: Message) -> Result<(), TransportError>;
227
228 fn recv(&self) -> Result<Message, TransportError>;
230
231 fn try_recv(&self) -> Option<Message>;
233
234 fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>>;
236}
237
/// Failure modes reported by a [`Transport`].
#[derive(Debug, Clone)]
pub enum TransportError {
    /// The destination could not be reached; the payload identifies it.
    ConnectionFailed(String),
    /// The operation did not complete in time.
    Timeout,
    /// The connection or channel is gone.
    Disconnected,
    /// Encoding or decoding failed; the payload describes why.
    SerializationError(String),
    /// No room to enqueue the message.
    ChannelFull,
    /// Any other failure, described by the payload.
    Unknown(String),
}

impl std::fmt::Display for TransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TransportError::ConnectionFailed(target) => write!(f, "Connection failed: {}", target),
            TransportError::Timeout => f.write_str("Timeout"),
            TransportError::Disconnected => f.write_str("Disconnected"),
            TransportError::SerializationError(reason) => {
                write!(f, "Serialization error: {}", reason)
            }
            TransportError::ChannelFull => f.write_str("Channel full"),
            TransportError::Unknown(reason) => write!(f, "Unknown error: {}", reason),
        }
    }
}

// Display + Debug are all `std::error::Error` needs; there is no inner source.
impl std::error::Error for TransportError {}
267
268pub struct InMemoryTransport {
274 node_id: NodeId,
275 inboxes: Arc<RwLock<HashMap<NodeId, Vec<Message>>>>,
276}
277
278impl InMemoryTransport {
279 pub fn new_network(nodes: &[NodeId]) -> HashMap<NodeId, Self> {
281 let inboxes = Arc::new(RwLock::new(HashMap::new()));
282
283 for node in nodes {
284 inboxes
285 .write()
286 .expect("transport inboxes lock poisoned")
287 .insert(node.clone(), Vec::new());
288 }
289
290 nodes
291 .iter()
292 .map(|id| {
293 (
294 id.clone(),
295 Self {
296 node_id: id.clone(),
297 inboxes: Arc::clone(&inboxes),
298 },
299 )
300 })
301 .collect()
302 }
303
304 pub fn new(node_id: NodeId) -> Self {
306 let inboxes = Arc::new(RwLock::new(HashMap::new()));
307 inboxes
308 .write()
309 .expect("transport inboxes lock poisoned")
310 .insert(node_id.clone(), Vec::new());
311 Self { node_id, inboxes }
312 }
313}
314
315impl Transport for InMemoryTransport {
316 fn send(&self, message: Message) -> Result<(), TransportError> {
317 let mut inboxes = self
318 .inboxes
319 .write()
320 .expect("transport inboxes lock poisoned");
321 if let Some(inbox) = inboxes.get_mut(&message.to) {
322 inbox.push(message);
323 Ok(())
324 } else {
325 Err(TransportError::ConnectionFailed(message.to.to_string()))
326 }
327 }
328
329 fn recv(&self) -> Result<Message, TransportError> {
330 loop {
331 if let Some(msg) = self.try_recv() {
332 return Ok(msg);
333 }
334 std::thread::sleep(std::time::Duration::from_millis(1));
335 }
336 }
337
338 fn try_recv(&self) -> Option<Message> {
339 let mut inboxes = self
340 .inboxes
341 .write()
342 .expect("transport inboxes lock poisoned");
343 if let Some(inbox) = inboxes.get_mut(&self.node_id) {
344 if !inbox.is_empty() {
345 return Some(inbox.remove(0));
346 }
347 }
348 None
349 }
350
351 fn broadcast(&self, message: Message, peers: &[NodeId]) -> Vec<Result<(), TransportError>> {
352 peers
353 .iter()
354 .map(|peer| {
355 let mut msg = message.clone();
356 msg.to = peer.clone();
357 self.send(msg)
358 })
359 .collect()
360 }
361}
362
363pub struct ConnectionPool {
369 connections: RwLock<HashMap<NodeId, ConnectionState>>,
370 max_connections: usize,
371}
372
373#[derive(Debug, Clone)]
375pub struct ConnectionState {
376 pub node_id: NodeId,
377 pub address: String,
378 pub connected: bool,
379 pub last_activity: u64,
380 pub retry_count: u32,
381}
382
383impl ConnectionPool {
384 pub fn new(max_connections: usize) -> Self {
386 Self {
387 connections: RwLock::new(HashMap::new()),
388 max_connections,
389 }
390 }
391
392 pub fn add(&self, node_id: NodeId, address: String) {
394 let mut conns = self
395 .connections
396 .write()
397 .expect("connection pool lock poisoned");
398 if conns.len() < self.max_connections {
399 conns.insert(
400 node_id.clone(),
401 ConnectionState {
402 node_id,
403 address,
404 connected: false,
405 last_activity: current_timestamp(),
406 retry_count: 0,
407 },
408 );
409 }
410 }
411
412 pub fn remove(&self, node_id: &NodeId) {
414 self.connections
415 .write()
416 .expect("connection pool lock poisoned")
417 .remove(node_id);
418 }
419
420 pub fn get(&self, node_id: &NodeId) -> Option<ConnectionState> {
422 self.connections
423 .read()
424 .expect("connection pool lock poisoned")
425 .get(node_id)
426 .cloned()
427 }
428
429 pub fn mark_connected(&self, node_id: &NodeId) {
431 if let Some(conn) = self
432 .connections
433 .write()
434 .expect("connection pool lock poisoned")
435 .get_mut(node_id)
436 {
437 conn.connected = true;
438 conn.last_activity = current_timestamp();
439 conn.retry_count = 0;
440 }
441 }
442
443 pub fn mark_disconnected(&self, node_id: &NodeId) {
445 if let Some(conn) = self
446 .connections
447 .write()
448 .expect("connection pool lock poisoned")
449 .get_mut(node_id)
450 {
451 conn.connected = false;
452 conn.retry_count += 1;
453 }
454 }
455
456 pub fn connected_nodes(&self) -> Vec<NodeId> {
458 self.connections
459 .read()
460 .expect("connection pool lock poisoned")
461 .values()
462 .filter(|c| c.connected)
463 .map(|c| c.node_id.clone())
464 .collect()
465 }
466
467 pub fn len(&self) -> usize {
469 self.connections
470 .read()
471 .expect("connection pool lock poisoned")
472 .len()
473 }
474
475 pub fn is_empty(&self) -> bool {
477 self.len() == 0
478 }
479}
480
/// Wall-clock time in whole milliseconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        // `as_millis` is u128; the u64 cast only truncates hundreds of
        // millions of years from now.
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
487
#[cfg(test)]
mod tests {
    //! Unit tests for message encoding, the in-memory transport, the
    //! connection pool, and client response constructors.

    use super::*;

    #[test]
    fn test_message_serialization() {
        let request = VoteRequest {
            term: 1,
            candidate_id: NodeId::new("node1"),
            last_log_index: 0,
            last_log_term: 0,
        };

        let msg = Message::vote_request(NodeId::new("node1"), NodeId::new("node2"), request);

        let bytes = msg.to_bytes();
        assert!(!bytes.is_empty(), "serialization must not fall back to an empty buffer");
        let restored = Message::from_bytes(&bytes).unwrap();

        assert_eq!(restored.message_type, MessageType::VoteRequest);
        assert_eq!(restored.from.as_str(), "node1");
        assert_eq!(restored.to.as_str(), "node2");
        assert_eq!(restored.term, 1);
        assert_eq!(restored.timestamp, msg.timestamp);
        match restored.payload {
            MessagePayload::VoteRequest(ref req) => {
                assert_eq!(req.term, 1);
                assert_eq!(req.candidate_id.as_str(), "node1");
            }
            ref other => panic!("unexpected payload: {:?}", other),
        }
    }

    #[test]
    fn test_from_bytes_rejects_invalid_input() {
        assert!(Message::from_bytes(b"").is_none());
        assert!(Message::from_bytes(b"not json").is_none());
    }

    #[test]
    fn test_in_memory_transport() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node2"), 1);
        t1.send(msg).unwrap();

        let received = t2.try_recv().unwrap();
        assert_eq!(received.message_type, MessageType::Heartbeat);
        assert_eq!(received.from.as_str(), "node1");
        // The message is consumed and never lands in the sender's own inbox.
        assert!(t2.try_recv().is_none());
        assert!(t1.try_recv().is_none());
    }

    #[test]
    fn test_in_memory_transport_preserves_order() {
        let nodes = vec![NodeId::new("node1"), NodeId::new("node2")];
        let transports = InMemoryTransport::new_network(&nodes);
        let t1 = &transports[&NodeId::new("node1")];
        let t2 = &transports[&NodeId::new("node2")];

        for term in 1..=3 {
            t1.send(Message::heartbeat(
                NodeId::new("node1"),
                NodeId::new("node2"),
                term,
            ))
            .unwrap();
        }

        let terms: Vec<u64> = std::iter::from_fn(|| t2.try_recv())
            .map(|m| m.term)
            .collect();
        assert_eq!(terms, vec![1, 2, 3]);
    }

    #[test]
    fn test_send_to_unknown_node_fails() {
        let t1 = InMemoryTransport::new(NodeId::new("node1"));
        let result = t1.send(Message::heartbeat(
            NodeId::new("node1"),
            NodeId::new("node9"),
            1,
        ));
        assert!(matches!(result, Err(TransportError::ConnectionFailed(_))));
    }

    #[test]
    fn test_broadcast() {
        let nodes = vec![
            NodeId::new("node1"),
            NodeId::new("node2"),
            NodeId::new("node3"),
        ];
        let transports = InMemoryTransport::new_network(&nodes);

        let t1 = &transports[&NodeId::new("node1")];

        let msg = Message::heartbeat(NodeId::new("node1"), NodeId::new("node1"), 1);
        let peers = vec![NodeId::new("node2"), NodeId::new("node3")];
        let results = t1.broadcast(msg, &peers);

        assert_eq!(results.len(), peers.len());
        assert!(results.iter().all(|r| r.is_ok()));

        // Each peer receives exactly one copy addressed to itself.
        for name in ["node2", "node3"].iter() {
            let inbox = &transports[&NodeId::new(*name)];
            let received = inbox.try_recv().expect("peer should have received the broadcast");
            assert_eq!(received.to.as_str(), *name);
            assert_eq!(received.from.as_str(), "node1");
            assert!(inbox.try_recv().is_none());
        }
        // The original `to` (node1) must not have been delivered to.
        assert!(t1.try_recv().is_none());
    }

    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(10);
        assert!(pool.is_empty());

        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 2);
        assert!(!pool.is_empty());

        pool.mark_connected(&NodeId::new("node1"));
        let connected = pool.connected_nodes();
        assert_eq!(connected.len(), 1);
        assert_eq!(connected[0].as_str(), "node1");

        pool.mark_disconnected(&NodeId::new("node1"));
        let state = pool.get(&NodeId::new("node1")).unwrap();
        assert!(!state.connected);
        assert_eq!(state.retry_count, 1);

        pool.remove(&NodeId::new("node2"));
        assert!(pool.get(&NodeId::new("node2")).is_none());
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_connection_pool_respects_capacity() {
        let pool = ConnectionPool::new(1);
        pool.add(NodeId::new("node1"), "127.0.0.1:5000".to_string());
        pool.add(NodeId::new("node2"), "127.0.0.1:5001".to_string());

        assert_eq!(pool.len(), 1);
        assert!(pool.get(&NodeId::new("node2")).is_none());
    }

    #[test]
    fn test_client_response() {
        let success = ClientResponse::success("req1".to_string(), Some(b"value".to_vec()));
        assert!(success.success);
        assert_eq!(success.value, Some(b"value".to_vec()));
        assert!(success.error.is_none());

        let error = ClientResponse::error("req2".to_string(), "failed");
        assert!(!error.success);
        assert_eq!(error.error, Some("failed".to_string()));

        let not_leader =
            ClientResponse::not_leader("req3".to_string(), Some(NodeId::new("leader")));
        assert!(!not_leader.success);
        assert_eq!(not_leader.leader_hint, Some(NodeId::new("leader")));
        assert_eq!(not_leader.error, Some("Not the leader".to_string()));
    }
}