1use crate::raft::{
17 AppendEntriesRequest, AppendEntriesResponse, NodeId, RequestVoteRequest, RequestVoteResponse,
18};
19use async_trait::async_trait;
20use dashmap::DashMap;
21use ipfrs_core::{Error, Result};
22use serde::{Deserialize, Serialize};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{mpsc, RwLock};
28
/// Raft RPC envelope exchanged between nodes through a [`Transport`].
///
/// NOTE(review): the TCP/QUIC transports encode this with `oxicode`; if that
/// format identifies enum variants by position (as bincode-style encoders
/// typically do), reordering variants breaks compatibility with running
/// peers — confirm before changing the order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    /// Carries an `AppendEntries` RPC request.
    AppendEntries(AppendEntriesRequest),
    /// Carries the reply to an `AppendEntries` request.
    AppendEntriesResponse(AppendEntriesResponse),
    /// Carries a `RequestVote` RPC request.
    RequestVote(RequestVoteRequest),
    /// Carries the reply to a `RequestVote` request.
    RequestVoteResponse(RequestVoteResponse),
}
41
/// Abstraction over the channel Raft nodes use to exchange [`Message`]s.
///
/// Implemented in this module by [`InMemoryTransport`], [`TcpTransport`]
/// and, with the `quic` feature, `QuicTransport`.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Sends `message` to the node identified by `target`.
    async fn send(&self, target: NodeId, message: Message) -> Result<()>;

    /// Waits for the next inbound message and returns it together with the
    /// id of the node that sent it.
    async fn recv(&self) -> Result<(NodeId, Message)>;

    /// Returns the id of the local node.
    fn node_id(&self) -> NodeId;

    /// Shuts the transport down.
    async fn close(&self) -> Result<()>;
}
57
/// Transport that delivers messages through in-process channels.
///
/// All transports created with the same registry can reach one another; the
/// registry maps each node id to the sending half of that node's inbox.
pub struct InMemoryTransport {
    /// Id of this node; attached as the sender id on every outgoing message.
    node_id: NodeId,
    /// Shared map from node id to that node's inbox sender.
    registry: Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>>,
    /// Receiving half of this node's inbox; behind a mutex so `recv(&self)`
    /// can obtain the mutable access `UnboundedReceiver::recv` needs.
    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
}
70
71impl InMemoryTransport {
72 pub fn new(
78 node_id: NodeId,
79 registry: Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>>,
80 ) -> Self {
81 let (tx, rx) = mpsc::unbounded_channel();
82 registry.insert(node_id, tx);
83
84 Self {
85 node_id,
86 registry,
87 rx: Arc::new(tokio::sync::Mutex::new(rx)),
88 }
89 }
90
91 pub fn new_registry() -> Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>> {
93 Arc::new(DashMap::new())
94 }
95}
96
97#[async_trait]
98impl Transport for InMemoryTransport {
99 async fn send(&self, target: NodeId, message: Message) -> Result<()> {
100 if let Some(tx) = self.registry.get(&target) {
101 tx.send((self.node_id, message))
102 .map_err(|_| Error::Network("Failed to send message".into()))?;
103 Ok(())
104 } else {
105 Err(Error::Network(format!("Node {} not found", target.0)))
106 }
107 }
108
109 async fn recv(&self) -> Result<(NodeId, Message)> {
110 let mut rx = self.rx.lock().await;
111 rx.recv()
112 .await
113 .ok_or_else(|| Error::Network("Transport closed".into()))
114 }
115
116 fn node_id(&self) -> NodeId {
117 self.node_id
118 }
119
120 async fn close(&self) -> Result<()> {
121 self.registry.remove(&self.node_id);
122 Ok(())
123 }
124}
125
/// Transport that sends each message over a newly opened TCP connection.
///
/// A frame is a `u32` length (written with `write_u32`, i.e. big-endian)
/// followed by the `oxicode`-encoded `(sender NodeId, Message)` pair.
pub struct TcpTransport {
    /// Id of this node; sent as the sender id in every frame.
    node_id: NodeId,
    /// Listening address; replaced by the actually bound address once the
    /// listener is started.
    listen_addr: SocketAddr,
    /// Peer node id -> address lookup consulted on every send.
    peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
    /// Inbox fed by the listener task; behind a mutex so `recv(&self)` can use it.
    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
    /// Sending half of the inbox; cloned into the connection handlers.
    tx: mpsc::UnboundedSender<(NodeId, Message)>,
    /// Size limit, timeouts and retry settings.
    config: TransportConfig,
    /// Set to `true` by `close()`; read by the listener task.
    shutdown: Arc<RwLock<bool>>,
}
146
147impl TcpTransport {
148 pub async fn new(
156 node_id: NodeId,
157 listen_addr: SocketAddr,
158 peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
159 config: TransportConfig,
160 ) -> Result<Self> {
161 let (tx, rx) = mpsc::unbounded_channel();
162 let shutdown = Arc::new(RwLock::new(false));
163
164 let transport = Self {
165 node_id,
166 listen_addr,
167 peer_addrs,
168 rx: Arc::new(tokio::sync::Mutex::new(rx)),
169 tx,
170 config,
171 shutdown,
172 };
173
174 transport.start_listener().await
176 }
177
178 async fn start_listener(self) -> Result<Self> {
180 let listener = TcpListener::bind(self.listen_addr)
181 .await
182 .map_err(|e| Error::Network(format!("Failed to bind: {e}")))?;
183
184 let actual_addr = listener
186 .local_addr()
187 .map_err(|e| Error::Network(format!("Failed to get local address: {e}")))?;
188
189 let tx = self.tx.clone();
190 let max_size = self.config.max_message_size;
191 let shutdown = self.shutdown.clone();
192
193 tokio::spawn(async move {
194 loop {
195 if *shutdown.read().await {
197 break;
198 }
199
200 match listener.accept().await {
201 Ok((mut stream, _)) => {
202 let tx = tx.clone();
203 tokio::spawn(async move {
204 if let Err(e) = Self::handle_connection(&mut stream, tx, max_size).await
205 {
206 tracing::warn!("Connection error: {}", e);
207 }
208 });
209 }
210 Err(e) => {
211 tracing::error!("Accept error: {}", e);
212 }
213 }
214 }
215 });
216
217 Ok(Self {
218 listen_addr: actual_addr,
219 ..self
220 })
221 }
222
223 async fn handle_connection(
225 stream: &mut TcpStream,
226 tx: mpsc::UnboundedSender<(NodeId, Message)>,
227 max_size: usize,
228 ) -> Result<()> {
229 let len = stream
231 .read_u32()
232 .await
233 .map_err(|e| Error::Network(format!("Failed to read length: {e}")))?
234 as usize;
235
236 if len > max_size {
237 return Err(Error::Network(format!(
238 "Message too large: {len} > {max_size}"
239 )));
240 }
241
242 let mut buf = vec![0u8; len];
244 stream
245 .read_exact(&mut buf)
246 .await
247 .map_err(|e| Error::Network(format!("Failed to read message: {e}")))?;
248
249 let (sender_id, message): (NodeId, Message) =
251 oxicode::serde::decode_owned_from_slice(&buf, oxicode::config::standard())
252 .map(|(v, _)| v)
253 .map_err(|e| Error::Network(format!("Failed to deserialize: {e}")))?;
254
255 tx.send((sender_id, message))
257 .map_err(|_| Error::Network("Channel closed".into()))?;
258
259 Ok(())
260 }
261
262 async fn send_to_peer(&self, target: NodeId, message: Message) -> Result<()> {
264 let addr = self
265 .peer_addrs
266 .get(&target)
267 .ok_or_else(|| Error::Network(format!("Node {} not found", target.0)))?
268 .value()
269 .to_owned();
270
271 let data =
273 oxicode::serde::encode_to_vec(&(self.node_id, message), oxicode::config::standard())
274 .map_err(|e| Error::Network(format!("Failed to serialize: {e}")))?;
275
276 if data.len() > self.config.max_message_size {
277 return Err(Error::Network(format!(
278 "Message too large: {} > {}",
279 data.len(),
280 self.config.max_message_size
281 )));
282 }
283
284 let mut attempt = 0;
286 let mut last_error = None;
287
288 while attempt <= self.config.max_retries {
289 match self.send_with_timeout(addr, &data).await {
290 Ok(_) => return Ok(()),
291 Err(e) => {
292 last_error = Some(e);
293 attempt += 1;
294
295 if attempt <= self.config.max_retries {
296 let backoff_ms = 100 * (1 << (attempt - 1));
298 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
299 }
300 }
301 }
302 }
303
304 Err(last_error.unwrap_or_else(|| Error::Network("Send failed".into())))
305 }
306
307 async fn send_with_timeout(&self, addr: SocketAddr, data: &[u8]) -> Result<()> {
309 let connect_timeout = std::time::Duration::from_millis(self.config.connect_timeout_ms);
310 let mut stream = tokio::time::timeout(connect_timeout, TcpStream::connect(addr))
311 .await
312 .map_err(|_| Error::Network("Connection timeout".into()))?
313 .map_err(|e| Error::Network(format!("Failed to connect: {e}")))?;
314
315 stream
317 .write_u32(data.len() as u32)
318 .await
319 .map_err(|e| Error::Network(format!("Failed to write length: {e}")))?;
320
321 stream
322 .write_all(data)
323 .await
324 .map_err(|e| Error::Network(format!("Failed to write data: {e}")))?;
325
326 stream
327 .flush()
328 .await
329 .map_err(|e| Error::Network(format!("Failed to flush: {e}")))?;
330
331 Ok(())
332 }
333}
334
335#[async_trait]
336impl Transport for TcpTransport {
337 async fn send(&self, target: NodeId, message: Message) -> Result<()> {
338 self.send_to_peer(target, message).await
339 }
340
341 async fn recv(&self) -> Result<(NodeId, Message)> {
342 let mut rx = self.rx.lock().await;
343 rx.recv()
344 .await
345 .ok_or_else(|| Error::Network("Transport closed".into()))
346 }
347
348 fn node_id(&self) -> NodeId {
349 self.node_id
350 }
351
352 async fn close(&self) -> Result<()> {
353 *self.shutdown.write().await = true;
354 Ok(())
355 }
356}
357
/// Tunables shared by the network transports.
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Largest encoded message, in bytes, that is sent or accepted.
    pub max_message_size: usize,
    /// Timeout for establishing an outbound connection, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// Retries after the first failed send attempt (so up to
    /// `max_retries + 1` attempts in total).
    pub max_retries: usize,
}
370
371impl Default for TransportConfig {
372 fn default() -> Self {
373 Self {
374 max_message_size: 10 * 1024 * 1024, connect_timeout_ms: 5000, request_timeout_ms: 10000, max_retries: 3,
378 }
379 }
380}
381
/// QUIC-based transport (only built with the `quic` feature).
///
/// Each message travels on a bidirectional stream of a newly established
/// connection, framed as a big-endian `u32` length followed by the
/// `oxicode`-encoded `(sender NodeId, Message)` pair.
#[cfg(feature = "quic")]
pub struct QuicTransport {
    /// Id of this node; sent as the sender id in every frame.
    node_id: NodeId,
    /// Endpoint used both to accept inbound and to open outbound connections.
    endpoint: Arc<quinn::Endpoint>,
    /// Peer node id -> address lookup consulted on every send.
    peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
    /// Inbox fed by the listener task; behind a mutex so `recv(&self)` can use it.
    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
    /// Sending half of the inbox; cloned into the connection handlers.
    tx: mpsc::UnboundedSender<(NodeId, Message)>,
    /// Size limit, timeouts and retry settings.
    config: TransportConfig,
    /// Set to `true` by `close()`; read by the listener task.
    shutdown: Arc<RwLock<bool>>,
}
403
#[cfg(feature = "quic")]
impl QuicTransport {
    /// Creates a QUIC endpoint on `listen_addr` and spawns the task that
    /// accepts inbound connections.
    ///
    /// The server side presents a freshly generated self-signed certificate;
    /// the client side does not verify peer certificates (see
    /// `SkipServerVerification`).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Network`] if certificate generation, TLS/QUIC
    /// configuration, or endpoint creation fails.
    // NOTE(review): nothing here awaits; presumably `async` is kept for
    // parity with `TcpTransport::new`, hence the lint allowance.
    #[allow(clippy::unused_async)]
    pub async fn new(
        node_id: NodeId,
        listen_addr: SocketAddr,
        peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
        config: TransportConfig,
    ) -> Result<Self> {
        let (tx, rx) = mpsc::unbounded_channel();
        let shutdown = Arc::new(RwLock::new(false));

        let cert = generate_self_signed_cert()?;
        let server_config = configure_server(cert.clone())?;
        let client_config = configure_client()?;

        let mut endpoint = quinn::Endpoint::server(server_config, listen_addr)
            .map_err(|e| Error::Network(format!("Failed to create endpoint: {e}")))?;

        endpoint.set_default_client_config(client_config);

        let transport = Self {
            node_id,
            endpoint: Arc::new(endpoint),
            peer_addrs,
            rx: Arc::new(tokio::sync::Mutex::new(rx)),
            tx,
            config,
            shutdown,
        };

        transport.start_listener();

        Ok(transport)
    }

    /// Spawns the accept loop.
    ///
    /// The loop ends when the shutdown flag is observed or when
    /// `Endpoint::accept` yields `None` (the endpoint was closed, which
    /// `close()` does). Each connection is handled in its own task, bounded
    /// by the configured request timeout.
    fn start_listener(&self) {
        let endpoint = self.endpoint.clone();
        let tx = self.tx.clone();
        let max_size = self.config.max_message_size;
        let request_timeout = std::time::Duration::from_millis(self.config.request_timeout_ms);
        let shutdown = self.shutdown.clone();

        tokio::spawn(async move {
            loop {
                if *shutdown.read().await {
                    break;
                }

                match endpoint.accept().await {
                    Some(incoming) => {
                        let tx = tx.clone();
                        tokio::spawn(async move {
                            let handled = tokio::time::timeout(
                                request_timeout,
                                Self::handle_connection(incoming, tx, max_size),
                            )
                            .await;
                            match handled {
                                Ok(Ok(())) => {}
                                Ok(Err(e)) => tracing::warn!("QUIC connection error: {}", e),
                                Err(_) => tracing::warn!(
                                    "QUIC connection error: no complete message within {:?}",
                                    request_timeout
                                ),
                            }
                        });
                    }
                    None => {
                        // Endpoint closed: nothing more to accept.
                        break;
                    }
                }
            }
        });
    }

    /// Completes the handshake for `incoming`, accepts one bidirectional
    /// stream, reads one length-prefixed frame from it and forwards the
    /// decoded `(sender, message)` pair to the inbox.
    ///
    /// # Errors
    ///
    /// `Error::Network` on handshake/stream/read failure, a declared length
    /// above `max_size`, a decode failure, or a closed inbox.
    async fn handle_connection(
        incoming: quinn::Incoming,
        tx: mpsc::UnboundedSender<(NodeId, Message)>,
        max_size: usize,
    ) -> Result<()> {
        let connection = incoming
            .await
            .map_err(|e| Error::Network(format!("Failed to establish connection: {e}")))?;

        let (_send, mut recv) = connection
            .accept_bi()
            .await
            .map_err(|e| Error::Network(format!("Failed to accept stream: {e}")))?;

        let mut len_buf = [0u8; 4];
        recv.read_exact(&mut len_buf)
            .await
            .map_err(|e| Error::Network(format!("Failed to read length: {e}")))?;
        let len = u32::from_be_bytes(len_buf) as usize;

        // Reject before allocating so a bogus length cannot force a huge buffer.
        if len > max_size {
            return Err(Error::Network(format!(
                "Message too large: {len} > {max_size}"
            )));
        }

        let mut buf = vec![0u8; len];
        recv.read_exact(&mut buf)
            .await
            .map_err(|e| Error::Network(format!("Failed to read message: {e}")))?;

        let (sender_id, message): (NodeId, Message) =
            oxicode::serde::decode_owned_from_slice(&buf, oxicode::config::standard())
                .map(|(v, _)| v)
                .map_err(|e| Error::Network(format!("Failed to deserialize: {e}")))?;

        tx.send((sender_id, message))
            .map_err(|_| Error::Network("Channel closed".into()))?;

        Ok(())
    }

    /// Encodes `message` and delivers it to `target`, making up to
    /// `max_retries + 1` attempts with exponential backoff
    /// (100 ms, 200 ms, 400 ms, ...) between them.
    ///
    /// # Errors
    ///
    /// `Error::Network` if `target` has no known address, encoding fails, the
    /// encoded message exceeds `max_message_size`, or every attempt fails
    /// (the last attempt's error is returned).
    async fn send_to_peer(&self, target: NodeId, message: Message) -> Result<()> {
        let addr = self
            .peer_addrs
            .get(&target)
            .ok_or_else(|| Error::Network(format!("Node {} not found", target.0)))?
            .value()
            .to_owned();

        let data =
            oxicode::serde::encode_to_vec(&(self.node_id, message), oxicode::config::standard())
                .map_err(|e| Error::Network(format!("Failed to serialize: {e}")))?;

        if data.len() > self.config.max_message_size {
            return Err(Error::Network(format!(
                "Message too large: {} > {}",
                data.len(),
                self.config.max_message_size
            )));
        }

        let mut last_error = None;
        for attempt in 0..=self.config.max_retries {
            if attempt > 0 {
                tokio::time::sleep(Self::backoff_delay(attempt)).await;
            }
            match self.send_with_timeout(addr, &data).await {
                Ok(()) => return Ok(()),
                Err(e) => last_error = Some(e),
            }
        }

        Err(last_error.unwrap_or_else(|| Error::Network("Send failed".into())))
    }

    /// Delay before retry number `attempt` (1-based): `100 ms * 2^(attempt-1)`,
    /// saturating instead of overflowing when `max_retries` is very large.
    fn backoff_delay(attempt: usize) -> std::time::Duration {
        let shift = u32::try_from(attempt.saturating_sub(1)).unwrap_or(u32::MAX);
        let millis = 1u64
            .checked_shl(shift)
            .and_then(|factor| factor.checked_mul(100))
            .unwrap_or(u64::MAX);
        std::time::Duration::from_millis(millis)
    }

    /// Connects to `addr` and sends one length-prefixed frame on a new
    /// bidirectional stream.
    ///
    /// Connecting is bounded by `connect_timeout_ms`; opening the stream,
    /// writing, and waiting for the peer to take the data is bounded by
    /// `request_timeout_ms`.
    ///
    /// # Errors
    ///
    /// `Error::Network` on timeout, connection/stream failure, or a payload
    /// whose length does not fit the `u32` length prefix.
    async fn send_with_timeout(&self, addr: SocketAddr, data: &[u8]) -> Result<()> {
        let connect_timeout = std::time::Duration::from_millis(self.config.connect_timeout_ms);
        let request_timeout = std::time::Duration::from_millis(self.config.request_timeout_ms);

        // The prefix is a u32: refuse rather than silently truncate.
        let len = u32::try_from(data.len()).map_err(|_| {
            Error::Network(format!("Message too large: {} > {}", data.len(), u32::MAX))
        })?;

        let connecting = self
            .endpoint
            .connect(addr, "localhost")
            .map_err(|e| Error::Network(format!("Failed to initiate connection: {e}")))?;

        let connection = tokio::time::timeout(connect_timeout, connecting)
            .await
            .map_err(|_| Error::Network("Connection timeout".into()))?
            .map_err(|e| Error::Network(format!("Failed to establish connection: {e}")))?;

        tokio::time::timeout(request_timeout, async {
            let (mut send, _recv) = connection
                .open_bi()
                .await
                .map_err(|e| Error::Network(format!("Failed to open stream: {e}")))?;

            send.write_all(&len.to_be_bytes())
                .await
                .map_err(|e| Error::Network(format!("Failed to write length: {e}")))?;

            send.write_all(data)
                .await
                .map_err(|e| Error::Network(format!("Failed to write data: {e}")))?;

            send.finish()
                .map_err(|e| Error::Network(format!("Failed to finish stream: {e}")))?;

            // `finish` only queues the end of the stream. Dropping `connection`
            // on return closes it immediately and may discard data that has not
            // been transmitted yet, so wait until the peer acknowledges the
            // data, stops the stream, or the connection ends. Any of those
            // outcomes is accepted: the receiver stopping or closing after it
            // has read the frame is expected and must not trigger a retry.
            let _ = send.stopped().await;

            Ok::<(), Error>(())
        })
        .await
        .map_err(|_| Error::Network("Request timeout".into()))?
    }

    /// Returns the local address the endpoint is bound to.
    ///
    /// # Errors
    ///
    /// `Error::Network` if the endpoint cannot report its address.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.endpoint
            .local_addr()
            .map_err(|e| Error::Network(format!("Failed to get local address: {e}")))
    }
}
618
#[cfg(feature = "quic")]
#[async_trait]
impl Transport for QuicTransport {
    /// Delegates to `send_to_peer`, which resolves the peer address, frames
    /// the message and retries with backoff.
    async fn send(&self, target: NodeId, message: Message) -> Result<()> {
        self.send_to_peer(target, message).await
    }

    /// Waits for the next message delivered by the listener task.
    ///
    /// NOTE(review): the transport keeps its own `tx`, so the inbox channel
    /// never closes while the transport exists; after `close()` this keeps
    /// waiting rather than returning "Transport closed".
    async fn recv(&self) -> Result<(NodeId, Message)> {
        let mut inbox = self.rx.lock().await;
        match inbox.recv().await {
            Some(envelope) => Ok(envelope),
            None => Err(Error::Network("Transport closed".into())),
        }
    }

    fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Raises the shutdown flag, then closes the endpoint with code 0. Closing
    /// the endpoint makes the listener's `accept` yield `None`.
    async fn close(&self) -> Result<()> {
        {
            let mut stopping = self.shutdown.write().await;
            *stopping = true;
        }
        self.endpoint.close(0u32.into(), b"Shutdown");
        Ok(())
    }
}
643
/// Generates a self-signed certificate for `localhost` and returns its DER
/// encoding. The matching private key is not returned.
///
/// # Errors
///
/// `Error::Network` if certificate generation fails.
#[cfg(feature = "quic")]
fn generate_self_signed_cert() -> Result<rustls::pki_types::CertificateDer<'static>> {
    let subject_alt_names = vec!["localhost".to_string()];
    let generated = rcgen::generate_simple_self_signed(subject_alt_names)
        .map_err(|e| Error::Network(format!("Failed to generate certificate: {e}")))?;

    let der_bytes: Vec<u8> = generated.cert.der().to_vec();
    Ok(rustls::pki_types::CertificateDer::from(der_bytes))
}
653
/// Builds the QUIC server configuration: no client authentication, ALPN
/// protocol `ipfrs-raft`, and a self-signed `localhost` certificate.
///
/// NOTE(review): the `_cert` argument is ignored. A fresh certificate and key
/// pair are generated here instead, because the caller's certificate arrives
/// without its private key. Callers therefore pay for generating a
/// certificate that is never used.
///
/// # Errors
///
/// `Error::Network` if certificate generation or TLS/QUIC configuration fails.
#[cfg(feature = "quic")]
fn configure_server(
    _cert: rustls::pki_types::CertificateDer<'static>,
) -> Result<quinn::ServerConfig> {
    let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
        .map_err(|e| Error::Network(format!("Failed to generate certificate: {e}")))?;

    let pkcs8_bytes = generated.signing_key.serialize_der();
    let private_key = rustls::pki_types::PrivateKeyDer::Pkcs8(
        rustls::pki_types::PrivatePkcs8KeyDer::from(pkcs8_bytes),
    );
    let chain = vec![rustls::pki_types::CertificateDer::from(
        generated.cert.der().to_vec(),
    )];

    let mut tls = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| Error::Network(format!("Failed to configure server: {e}")))?;
    tls.alpn_protocols = vec![b"ipfrs-raft".to_vec()];

    let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
        .map_err(|e| Error::Network(format!("Failed to create QUIC server config: {e}")))?;

    Ok(quinn::ServerConfig::with_crypto(Arc::new(quic_crypto)))
}
682
/// Builds the QUIC client configuration: no client certificate and ALPN
/// protocol `ipfrs-raft`.
///
/// SECURITY: server certificates are NOT verified (`SkipServerVerification`
/// accepts everything), so any endpoint can impersonate a peer and traffic is
/// open to man-in-the-middle attacks. Replace the verifier with one that pins
/// or validates peer certificates before using this on an untrusted network.
///
/// # Errors
///
/// `Error::Network` if the QUIC client crypto config cannot be created.
#[cfg(feature = "quic")]
fn configure_client() -> Result<quinn::ClientConfig> {
    let verifier = Arc::new(SkipServerVerification);
    let mut tls = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    tls.alpn_protocols = vec![b"ipfrs-raft".to_vec()];

    let quic_crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls)
        .map_err(|e| Error::Network(format!("Failed to create QUIC client config: {e}")))?;

    Ok(quinn::ClientConfig::new(Arc::new(quic_crypto)))
}
701
/// Certificate verifier that accepts every server certificate and handshake
/// signature without checking them.
///
/// SECURITY: this disables server authentication entirely; it is only
/// tolerable where every peer on the network is trusted.
#[cfg(feature = "quic")]
#[derive(Debug)]
struct SkipServerVerification;
706
/// Accept-everything verifier: every check below returns success without
/// inspecting its inputs.
///
/// SECURITY: this skips both certificate validation and handshake-signature
/// verification, so it provides no protection against impersonation.
#[cfg(feature = "quic")]
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any server certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server.
    ///
    /// The server must sign its handshake with one of these. QUIC always uses
    /// TLS 1.3, where RSA keys must sign with RSA-PSS, so the PSS variants are
    /// required for RSA-keyed peers. Because signatures are never checked
    /// here, advertising more schemes only widens interoperability.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// `RequestVote` request used throughout: term 1, candidate 1, empty log.
    fn vote_request() -> RequestVoteRequest {
        RequestVoteRequest {
            term: crate::raft::Term(1),
            candidate_id: NodeId(1),
            last_log_index: crate::raft::LogIndex(0),
            last_log_term: crate::raft::Term(0),
        }
    }

    /// Granted vote for term 1.
    fn vote_response() -> RequestVoteResponse {
        RequestVoteResponse {
            term: crate::raft::Term(1),
            vote_granted: true,
        }
    }

    /// Two TCP transports on ephemeral loopback ports that know each other's address.
    async fn tcp_pair() -> (TcpTransport, TcpTransport) {
        let peer_addrs1 = Arc::new(DashMap::new());
        let peer_addrs2 = Arc::new(DashMap::new());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let config = TransportConfig::default();

        let transport1 = TcpTransport::new(NodeId(1), addr, peer_addrs1.clone(), config.clone())
            .await
            .unwrap();
        let transport2 = TcpTransport::new(NodeId(2), addr, peer_addrs2.clone(), config)
            .await
            .unwrap();

        // Ports were chosen by the OS, so register the resolved addresses.
        peer_addrs1.insert(NodeId(2), transport2.listen_addr);
        peer_addrs2.insert(NodeId(1), transport1.listen_addr);

        (transport1, transport2)
    }

    /// Two QUIC transports on ephemeral loopback ports that know each other's address.
    #[cfg(feature = "quic")]
    async fn quic_pair() -> (QuicTransport, QuicTransport) {
        // Select ring as rustls' process-wide crypto provider; an error from a
        // repeated installation (another test got there first) is ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let peer_addrs1 = Arc::new(DashMap::new());
        let peer_addrs2 = Arc::new(DashMap::new());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let config = TransportConfig::default();

        let transport1 = QuicTransport::new(NodeId(1), addr, peer_addrs1.clone(), config.clone())
            .await
            .unwrap();
        let transport2 = QuicTransport::new(NodeId(2), addr, peer_addrs2.clone(), config)
            .await
            .unwrap();

        peer_addrs1.insert(NodeId(2), transport2.local_addr().unwrap());
        peer_addrs2.insert(NodeId(1), transport1.local_addr().unwrap());

        // Give the listener tasks a moment to start.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        (transport1, transport2)
    }

    #[tokio::test]
    async fn test_in_memory_transport_send_recv() {
        let registry = InMemoryTransport::new_registry();
        let transport1 = InMemoryTransport::new(NodeId(1), registry.clone());
        let transport2 = InMemoryTransport::new(NodeId(2), registry);

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));
    }

    #[tokio::test]
    async fn test_in_memory_transport_node_not_found() {
        let registry = InMemoryTransport::new_registry();
        let transport = InMemoryTransport::new(NodeId(1), registry);

        let result = transport
            .send(NodeId(999), Message::RequestVote(vote_request()))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_transport_close() {
        let registry = InMemoryTransport::new_registry();
        let transport = InMemoryTransport::new(NodeId(1), registry.clone());

        assert!(registry.contains_key(&NodeId(1)));

        transport.close().await.unwrap();

        assert!(!registry.contains_key(&NodeId(1)));
        // The registry held the only sender for this inbox, so it is now closed.
        assert!(transport.recv().await.is_err());
    }

    #[tokio::test]
    async fn test_bidirectional_communication() {
        let registry = InMemoryTransport::new_registry();
        let transport1 = InMemoryTransport::new(NodeId(1), registry.clone());
        let transport2 = InMemoryTransport::new(NodeId(2), registry);

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));

        transport2
            .send(NodeId(1), Message::RequestVoteResponse(vote_response()))
            .await
            .unwrap();

        let (sender, received) = transport1.recv().await.unwrap();
        assert_eq!(sender, NodeId(2));
        assert!(matches!(received, Message::RequestVoteResponse(_)));
    }

    #[tokio::test]
    async fn test_tcp_transport_send_recv() {
        let (transport1, transport2) = tcp_pair().await;

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));

        transport1.close().await.unwrap();
        transport2.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_tcp_transport_bidirectional() {
        let (transport1, transport2) = tcp_pair().await;

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));

        transport2
            .send(NodeId(1), Message::RequestVoteResponse(vote_response()))
            .await
            .unwrap();

        let (sender, received) = transport1.recv().await.unwrap();
        assert_eq!(sender, NodeId(2));
        assert!(matches!(received, Message::RequestVoteResponse(_)));

        transport1.close().await.unwrap();
        transport2.close().await.unwrap();
    }

    #[cfg(feature = "quic")]
    #[tokio::test]
    #[ignore]
    async fn test_quic_transport_send_recv() {
        let (transport1, transport2) = quic_pair().await;

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));

        transport1.close().await.unwrap();
        transport2.close().await.unwrap();
    }

    #[cfg(feature = "quic")]
    #[tokio::test]
    #[ignore]
    async fn test_quic_transport_bidirectional() {
        let (transport1, transport2) = quic_pair().await;

        transport1
            .send(NodeId(2), Message::RequestVote(vote_request()))
            .await
            .unwrap();

        let (sender, received) = transport2.recv().await.unwrap();
        assert_eq!(sender, NodeId(1));
        assert!(matches!(received, Message::RequestVote(_)));

        transport2
            .send(NodeId(1), Message::RequestVoteResponse(vote_response()))
            .await
            .unwrap();

        let (sender, received) = transport1.recv().await.unwrap();
        assert_eq!(sender, NodeId(2));
        assert!(matches!(received, Message::RequestVoteResponse(_)));

        transport1.close().await.unwrap();
        transport2.close().await.unwrap();
    }
}