ipfrs_storage/
transport.rs

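//! Network transport abstraction for distributed RAFT.
//!
//! Provides a generic transport layer for RAFT node communication,
//! enabling multi-node clusters with different network backends.
//!
//! # Example
//!
//! ```rust,ignore
//! use ipfrs_storage::transport::{InMemoryTransport, Message, Transport};
//!
//! // All in-process nodes of one cluster share a single registry.
//! let registry = InMemoryTransport::new_registry();
//! let node1 = InMemoryTransport::new(NodeId(1), registry.clone());
//! let node2 = InMemoryTransport::new(NodeId(2), registry);
//!
//! let request: RequestVoteRequest = /* ... */;
//! node1.send(NodeId(2), Message::RequestVote(request)).await?;
//! let (sender, message) = node2.recv().await?;
//! ```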
15
16use crate::raft::{
17    AppendEntriesRequest, AppendEntriesResponse, NodeId, RequestVoteRequest, RequestVoteResponse,
18};
19use async_trait::async_trait;
20use dashmap::DashMap;
21use ipfrs_core::{Error, Result};
22use serde::{Deserialize, Serialize};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{mpsc, RwLock};
28
29/// Network message types for RAFT communication
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub enum Message {
32    /// AppendEntries RPC request
33    AppendEntries(AppendEntriesRequest),
34    /// AppendEntries RPC response
35    AppendEntriesResponse(AppendEntriesResponse),
36    /// RequestVote RPC request
37    RequestVote(RequestVoteRequest),
38    /// RequestVote RPC response
39    RequestVoteResponse(RequestVoteResponse),
40}
41
42/// Transport trait for network communication between RAFT nodes
43#[async_trait]
44pub trait Transport: Send + Sync {
45    /// Send a message to a specific node
46    async fn send(&self, target: NodeId, message: Message) -> Result<()>;
47
48    /// Receive the next message for this node
49    async fn recv(&self) -> Result<(NodeId, Message)>;
50
51    /// Get local node ID
52    fn node_id(&self) -> NodeId;
53
54    /// Close the transport and clean up resources
55    async fn close(&self) -> Result<()>;
56}
57
58/// In-memory transport for testing and local development
59///
60/// Provides a zero-copy, high-performance transport for running
61/// multiple RAFT nodes in the same process.
62pub struct InMemoryTransport {
63    /// Local node ID
64    node_id: NodeId,
65    /// Shared message registry for all nodes
66    registry: Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>>,
67    /// Receiver for incoming messages
68    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
69}
70
71impl InMemoryTransport {
72    /// Create a new in-memory transport for a node
73    ///
74    /// # Arguments
75    /// * `node_id` - Unique identifier for this node
76    /// * `registry` - Shared registry of all nodes in the cluster
77    pub fn new(
78        node_id: NodeId,
79        registry: Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>>,
80    ) -> Self {
81        let (tx, rx) = mpsc::unbounded_channel();
82        registry.insert(node_id, tx);
83
84        Self {
85            node_id,
86            registry,
87            rx: Arc::new(tokio::sync::Mutex::new(rx)),
88        }
89    }
90
91    /// Create a new registry for a cluster
92    pub fn new_registry() -> Arc<DashMap<NodeId, mpsc::UnboundedSender<(NodeId, Message)>>> {
93        Arc::new(DashMap::new())
94    }
95}
96
97#[async_trait]
98impl Transport for InMemoryTransport {
99    async fn send(&self, target: NodeId, message: Message) -> Result<()> {
100        if let Some(tx) = self.registry.get(&target) {
101            tx.send((self.node_id, message))
102                .map_err(|_| Error::Network("Failed to send message".into()))?;
103            Ok(())
104        } else {
105            Err(Error::Network(format!("Node {} not found", target.0)))
106        }
107    }
108
109    async fn recv(&self) -> Result<(NodeId, Message)> {
110        let mut rx = self.rx.lock().await;
111        rx.recv()
112            .await
113            .ok_or_else(|| Error::Network("Transport closed".into()))
114    }
115
116    fn node_id(&self) -> NodeId {
117        self.node_id
118    }
119
120    async fn close(&self) -> Result<()> {
121        self.registry.remove(&self.node_id);
122        Ok(())
123    }
124}
125
126/// TCP-based transport for real network communication
127///
128/// Provides a production-ready transport for RAFT clusters
129/// running across multiple machines.
130pub struct TcpTransport {
131    /// Local node ID
132    node_id: NodeId,
133    /// Local listening address
134    listen_addr: SocketAddr,
135    /// Mapping of node IDs to their addresses
136    peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
137    /// Receiver for incoming messages
138    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
139    /// Sender for incoming messages (used by listener)
140    tx: mpsc::UnboundedSender<(NodeId, Message)>,
141    /// Transport configuration
142    config: TransportConfig,
143    /// Shutdown signal
144    shutdown: Arc<RwLock<bool>>,
145}
146
147impl TcpTransport {
148    /// Create a new TCP transport
149    ///
150    /// # Arguments
151    /// * `node_id` - Unique identifier for this node
152    /// * `listen_addr` - Address to listen on for incoming connections
153    /// * `peer_addrs` - Map of node IDs to their addresses
154    /// * `config` - Transport configuration
155    pub async fn new(
156        node_id: NodeId,
157        listen_addr: SocketAddr,
158        peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
159        config: TransportConfig,
160    ) -> Result<Self> {
161        let (tx, rx) = mpsc::unbounded_channel();
162        let shutdown = Arc::new(RwLock::new(false));
163
164        let transport = Self {
165            node_id,
166            listen_addr,
167            peer_addrs,
168            rx: Arc::new(tokio::sync::Mutex::new(rx)),
169            tx,
170            config,
171            shutdown,
172        };
173
174        // Start listener task and get the actual bound address
175        transport.start_listener().await
176    }
177
178    /// Start listening for incoming connections
179    async fn start_listener(self) -> Result<Self> {
180        let listener = TcpListener::bind(self.listen_addr)
181            .await
182            .map_err(|e| Error::Network(format!("Failed to bind: {e}")))?;
183
184        // Get the actual bound address (important when using port 0)
185        let actual_addr = listener
186            .local_addr()
187            .map_err(|e| Error::Network(format!("Failed to get local address: {e}")))?;
188
189        let tx = self.tx.clone();
190        let max_size = self.config.max_message_size;
191        let shutdown = self.shutdown.clone();
192
193        tokio::spawn(async move {
194            loop {
195                // Check shutdown signal
196                if *shutdown.read().await {
197                    break;
198                }
199
200                match listener.accept().await {
201                    Ok((mut stream, _)) => {
202                        let tx = tx.clone();
203                        tokio::spawn(async move {
204                            if let Err(e) = Self::handle_connection(&mut stream, tx, max_size).await
205                            {
206                                tracing::warn!("Connection error: {}", e);
207                            }
208                        });
209                    }
210                    Err(e) => {
211                        tracing::error!("Accept error: {}", e);
212                    }
213                }
214            }
215        });
216
217        Ok(Self {
218            listen_addr: actual_addr,
219            ..self
220        })
221    }
222
223    /// Handle an incoming connection
224    async fn handle_connection(
225        stream: &mut TcpStream,
226        tx: mpsc::UnboundedSender<(NodeId, Message)>,
227        max_size: usize,
228    ) -> Result<()> {
229        // Read message length (4 bytes)
230        let len = stream
231            .read_u32()
232            .await
233            .map_err(|e| Error::Network(format!("Failed to read length: {e}")))?
234            as usize;
235
236        if len > max_size {
237            return Err(Error::Network(format!(
238                "Message too large: {len} > {max_size}"
239            )));
240        }
241
242        // Read message data
243        let mut buf = vec![0u8; len];
244        stream
245            .read_exact(&mut buf)
246            .await
247            .map_err(|e| Error::Network(format!("Failed to read message: {e}")))?;
248
249        // Deserialize message
250        let (sender_id, message): (NodeId, Message) =
251            oxicode::serde::decode_owned_from_slice(&buf, oxicode::config::standard())
252                .map(|(v, _)| v)
253                .map_err(|e| Error::Network(format!("Failed to deserialize: {e}")))?;
254
255        // Send to receiver channel
256        tx.send((sender_id, message))
257            .map_err(|_| Error::Network("Channel closed".into()))?;
258
259        Ok(())
260    }
261
262    /// Send a message to a peer with retry logic
263    async fn send_to_peer(&self, target: NodeId, message: Message) -> Result<()> {
264        let addr = self
265            .peer_addrs
266            .get(&target)
267            .ok_or_else(|| Error::Network(format!("Node {} not found", target.0)))?
268            .value()
269            .to_owned();
270
271        // Serialize message with sender ID
272        let data =
273            oxicode::serde::encode_to_vec(&(self.node_id, message), oxicode::config::standard())
274                .map_err(|e| Error::Network(format!("Failed to serialize: {e}")))?;
275
276        if data.len() > self.config.max_message_size {
277            return Err(Error::Network(format!(
278                "Message too large: {} > {}",
279                data.len(),
280                self.config.max_message_size
281            )));
282        }
283
284        // Retry with exponential backoff
285        let mut attempt = 0;
286        let mut last_error = None;
287
288        while attempt <= self.config.max_retries {
289            match self.send_with_timeout(addr, &data).await {
290                Ok(_) => return Ok(()),
291                Err(e) => {
292                    last_error = Some(e);
293                    attempt += 1;
294
295                    if attempt <= self.config.max_retries {
296                        // Exponential backoff: 100ms, 200ms, 400ms, etc.
297                        let backoff_ms = 100 * (1 << (attempt - 1));
298                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
299                    }
300                }
301            }
302        }
303
304        Err(last_error.unwrap_or_else(|| Error::Network("Send failed".into())))
305    }
306
307    /// Send data with timeout (single attempt)
308    async fn send_with_timeout(&self, addr: SocketAddr, data: &[u8]) -> Result<()> {
309        let connect_timeout = std::time::Duration::from_millis(self.config.connect_timeout_ms);
310        let mut stream = tokio::time::timeout(connect_timeout, TcpStream::connect(addr))
311            .await
312            .map_err(|_| Error::Network("Connection timeout".into()))?
313            .map_err(|e| Error::Network(format!("Failed to connect: {e}")))?;
314
315        // Write message length (4 bytes) + data
316        stream
317            .write_u32(data.len() as u32)
318            .await
319            .map_err(|e| Error::Network(format!("Failed to write length: {e}")))?;
320
321        stream
322            .write_all(data)
323            .await
324            .map_err(|e| Error::Network(format!("Failed to write data: {e}")))?;
325
326        stream
327            .flush()
328            .await
329            .map_err(|e| Error::Network(format!("Failed to flush: {e}")))?;
330
331        Ok(())
332    }
333}
334
335#[async_trait]
336impl Transport for TcpTransport {
337    async fn send(&self, target: NodeId, message: Message) -> Result<()> {
338        self.send_to_peer(target, message).await
339    }
340
341    async fn recv(&self) -> Result<(NodeId, Message)> {
342        let mut rx = self.rx.lock().await;
343        rx.recv()
344            .await
345            .ok_or_else(|| Error::Network("Transport closed".into()))
346    }
347
348    fn node_id(&self) -> NodeId {
349        self.node_id
350    }
351
352    async fn close(&self) -> Result<()> {
353        *self.shutdown.write().await = true;
354        Ok(())
355    }
356}
357
/// Tunables shared by the network transports.
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Largest encoded message accepted, in bytes.
    pub max_message_size: usize,
    /// Connection timeout, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// Number of retries after a failed first send attempt.
    pub max_retries: usize,
}

impl Default for TransportConfig {
    /// Defaults: 10 MiB messages, 5 s connect timeout, 10 s request timeout, 3 retries.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        TransportConfig {
            max_message_size: 10 * MIB,
            connect_timeout_ms: 5_000,
            request_timeout_ms: 10_000,
            max_retries: 3,
        }
    }
}
381
/// QUIC-based transport with TLS-encrypted traffic.
///
/// SECURITY: peers are **not authenticated**. Each node presents a freshly
/// generated self-signed certificate, and the client side accepts any server
/// certificate. Traffic is encrypted, but this offers no protection against an
/// active man-in-the-middle.
///
/// Each outgoing message is sent on a new QUIC connection. It is written as a
/// big-endian `u32` length prefix followed by the `oxicode`-encoded
/// `(sender, message)` pair.
#[cfg(feature = "quic")]
pub struct QuicTransport {
    /// ID of the node that owns this transport.
    node_id: NodeId,
    /// QUIC endpoint used both to accept and to dial connections.
    endpoint: Arc<quinn::Endpoint>,
    /// Known peer addresses, keyed by node ID.
    peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
    /// Receiving half of the inbound-message channel.
    rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(NodeId, Message)>>>,
    /// Sending half of the inbound-message channel, cloned into the listener.
    tx: mpsc::UnboundedSender<(NodeId, Message)>,
    /// Limits, timeouts and retry settings.
    config: TransportConfig,
    /// Set to `true` by `close()` to stop the listener.
    shutdown: Arc<RwLock<bool>>,
}
403
#[cfg(feature = "quic")]
impl QuicTransport {
    /// Upper bound on the exponential back-off between two send attempts.
    const MAX_RETRY_BACKOFF: std::time::Duration = std::time::Duration::from_secs(10);

    /// Create a new QUIC transport and start accepting connections.
    ///
    /// # Arguments
    /// * `node_id` - Unique identifier for this node
    /// * `listen_addr` - Address to listen on for incoming connections
    /// * `peer_addrs` - Map of node IDs to their addresses
    /// * `config` - Transport configuration
    ///
    /// # Errors
    /// Returns [`Error::Network`] if certificate generation, TLS configuration
    /// or binding the endpoint fails.
    // Nothing is awaited here; presumably `async` is kept so the signature lines
    // up with `TcpTransport::new`.
    #[allow(clippy::unused_async)]
    pub async fn new(
        node_id: NodeId,
        listen_addr: SocketAddr,
        peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
        config: TransportConfig,
    ) -> Result<Self> {
        let (tx, rx) = mpsc::unbounded_channel();
        let shutdown = Arc::new(RwLock::new(false));

        // NOTE(review): `configure_server`, as currently written, ignores this
        // certificate and generates its own key pair.
        let cert = generate_self_signed_cert()?;
        let server_config = configure_server(cert)?;
        let client_config = configure_client()?;

        // One endpoint both accepts inbound connections and dials peers.
        let mut endpoint = quinn::Endpoint::server(server_config, listen_addr)
            .map_err(|e| Error::Network(format!("Failed to create endpoint: {e}")))?;
        endpoint.set_default_client_config(client_config);

        let transport = Self {
            node_id,
            endpoint: Arc::new(endpoint),
            peer_addrs,
            rx: Arc::new(tokio::sync::Mutex::new(rx)),
            tx,
            config,
            shutdown,
        };

        transport.start_listener();

        Ok(transport)
    }

    /// Spawn the task that accepts incoming connections.
    ///
    /// The loop ends when the shutdown flag is set or the endpoint is closed
    /// (`accept` then returns `None`).
    fn start_listener(&self) {
        let endpoint = self.endpoint.clone();
        let tx = self.tx.clone();
        let max_size = self.config.max_message_size;
        let request_timeout = std::time::Duration::from_millis(self.config.request_timeout_ms);
        let shutdown = self.shutdown.clone();

        tokio::spawn(async move {
            loop {
                if *shutdown.read().await {
                    break;
                }

                match endpoint.accept().await {
                    Some(incoming) => {
                        let tx = tx.clone();
                        tokio::spawn(async move {
                            if let Err(e) =
                                Self::handle_connection(incoming, tx, max_size, request_timeout)
                                    .await
                            {
                                tracing::warn!("QUIC connection error: {}", e);
                            }
                        });
                    }
                    None => {
                        // Endpoint closed
                        break;
                    }
                }
            }
        });
    }

    /// Receive one frame on an incoming connection, decode it and forward it to `tx`.
    ///
    /// The handshake and frame read are bounded by `request_timeout`, so a peer
    /// that stalls cannot pin this task and its connection forever.
    async fn handle_connection(
        incoming: quinn::Incoming,
        tx: mpsc::UnboundedSender<(NodeId, Message)>,
        max_size: usize,
        request_timeout: std::time::Duration,
    ) -> Result<()> {
        let buf = tokio::time::timeout(request_timeout, Self::read_frame(incoming, max_size))
            .await
            .map_err(|_| Error::Network("Timed out reading message".into()))??;

        // Deserialize message
        let (sender_id, message): (NodeId, Message) =
            oxicode::serde::decode_owned_from_slice(&buf, oxicode::config::standard())
                .map(|(v, _)| v)
                .map_err(|e| Error::Network(format!("Failed to deserialize: {e}")))?;

        // Send to receiver channel
        tx.send((sender_id, message))
            .map_err(|_| Error::Network("Channel closed".into()))?;

        Ok(())
    }

    /// Complete the handshake, accept the peer's stream and read one frame.
    ///
    /// A frame is a big-endian `u32` length prefix followed by that many bytes.
    /// The connection is dropped, and therefore closed, when this returns. The
    /// sender waits for that close as its delivery confirmation.
    async fn read_frame(incoming: quinn::Incoming, max_size: usize) -> Result<Vec<u8>> {
        let connection = incoming
            .await
            .map_err(|e| Error::Network(format!("Failed to establish connection: {e}")))?;

        let (_send, mut recv) = connection
            .accept_bi()
            .await
            .map_err(|e| Error::Network(format!("Failed to accept stream: {e}")))?;

        let mut len_buf = [0u8; 4];
        recv.read_exact(&mut len_buf)
            .await
            .map_err(|e| Error::Network(format!("Failed to read length: {e}")))?;
        let len = u32::from_be_bytes(len_buf) as usize;

        if len > max_size {
            return Err(Error::Network(format!(
                "Message too large: {len} > {max_size}"
            )));
        }

        let mut buf = vec![0u8; len];
        recv.read_exact(&mut buf)
            .await
            .map_err(|e| Error::Network(format!("Failed to read message: {e}")))?;

        Ok(buf)
    }

    /// Send a message to a peer, retrying with exponential back-off.
    ///
    /// The message is tried `1 + config.max_retries` times in total.
    ///
    /// # Errors
    /// Fails if `target` is unknown, the encoded message exceeds
    /// `max_message_size`, or every attempt fails. In the last case the last
    /// attempt's error is returned.
    async fn send_to_peer(&self, target: NodeId, message: Message) -> Result<()> {
        let addr = self
            .peer_addrs
            .get(&target)
            .map(|entry| *entry.value())
            .ok_or_else(|| Error::Network(format!("Node {} not found", target.0)))?;

        // Serialize message with sender ID
        let data =
            oxicode::serde::encode_to_vec(&(self.node_id, message), oxicode::config::standard())
                .map_err(|e| Error::Network(format!("Failed to serialize: {e}")))?;

        if data.len() > self.config.max_message_size {
            return Err(Error::Network(format!(
                "Message too large: {} > {}",
                data.len(),
                self.config.max_message_size
            )));
        }

        let mut last_error = None;
        for attempt in 0..=self.config.max_retries {
            if attempt > 0 {
                tokio::time::sleep(Self::retry_backoff(attempt)).await;
            }
            match self.send_with_timeout(addr, &data).await {
                Ok(()) => return Ok(()),
                Err(e) => last_error = Some(e),
            }
        }

        Err(last_error.unwrap_or_else(|| Error::Network("Send failed".into())))
    }

    /// Delay before retry number `attempt` (1-based): 100ms, 200ms, 400ms, ...
    ///
    /// Uses saturating arithmetic and is capped at `MAX_RETRY_BACKOFF`, so a
    /// large `max_retries` cannot overflow (the old shift panicked in debug builds).
    fn retry_backoff(attempt: usize) -> std::time::Duration {
        let exponent = u32::try_from(attempt.saturating_sub(1)).unwrap_or(u32::MAX);
        let factor = 1u64.checked_shl(exponent).unwrap_or(u64::MAX);
        std::time::Duration::from_millis(100u64.saturating_mul(factor))
            .min(Self::MAX_RETRY_BACKOFF)
    }

    /// Open a connection to `addr`, send one frame and wait until the peer has read it.
    ///
    /// The handshake is bounded by `connect_timeout_ms`. Writing and waiting for
    /// delivery are bounded by `request_timeout_ms`.
    async fn send_with_timeout(&self, addr: SocketAddr, data: &[u8]) -> Result<()> {
        let connect_timeout = std::time::Duration::from_millis(self.config.connect_timeout_ms);
        let request_timeout = std::time::Duration::from_millis(self.config.request_timeout_ms);

        // The length prefix is a `u32` on the wire; refuse instead of silently truncating.
        let len = u32::try_from(data.len()).map_err(|_| {
            Error::Network(format!("Message too large for frame: {} bytes", data.len()))
        })?;

        // "localhost" is the name the self-signed server certificates are issued for.
        let connecting = self
            .endpoint
            .connect(addr, "localhost")
            .map_err(|e| Error::Network(format!("Failed to initiate connection: {e}")))?;

        let connection = tokio::time::timeout(connect_timeout, connecting)
            .await
            .map_err(|_| Error::Network("Connection timeout".into()))?
            .map_err(|e| Error::Network(format!("Failed to establish connection: {e}")))?;

        let deliver = async {
            let (mut send, _recv) = connection
                .open_bi()
                .await
                .map_err(|e| Error::Network(format!("Failed to open stream: {e}")))?;

            send.write_all(&len.to_be_bytes())
                .await
                .map_err(|e| Error::Network(format!("Failed to write length: {e}")))?;

            send.write_all(data)
                .await
                .map_err(|e| Error::Network(format!("Failed to write data: {e}")))?;

            send.finish()
                .map_err(|e| Error::Network(format!("Failed to finish stream: {e}")))?;

            // `finish` only queues the end-of-stream marker. Dropping the last handle
            // to `connection` closes it immediately, which can discard data that has
            // not been transmitted yet. The receiver (`read_frame`) keeps its side
            // open until it has read the whole frame (or failed), so wait for the
            // peer to close first.
            let _close_reason = connection.closed().await;
            Ok::<(), Error>(())
        };

        tokio::time::timeout(request_timeout, deliver)
            .await
            .map_err(|_| Error::Network("Timed out delivering message".into()))?
    }

    /// Get the local address
    ///
    /// # Errors
    /// Returns [`Error::Network`] if the endpoint cannot report its address.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.endpoint
            .local_addr()
            .map_err(|e| Error::Network(format!("Failed to get local address: {e}")))
    }
}
618
#[cfg(feature = "quic")]
#[async_trait]
impl Transport for QuicTransport {
    async fn send(&self, target: NodeId, message: Message) -> Result<()> {
        self.send_to_peer(target, message).await
    }

    /// Return the next inbound message.
    ///
    /// The transport holds its own clone of the inbound sender, so the channel
    /// never reports "closed" by itself. While idle, this re-checks the shutdown
    /// flag periodically, so `recv` fails after `close()` instead of hanging
    /// forever. Messages that were already queued are still returned first.
    async fn recv(&self) -> Result<(NodeId, Message)> {
        const SHUTDOWN_POLL_INTERVAL: std::time::Duration =
            std::time::Duration::from_millis(100);

        let mut rx = self.rx.lock().await;
        loop {
            // `mpsc::UnboundedReceiver::recv` is cancel-safe, so timing out loses nothing.
            match tokio::time::timeout(SHUTDOWN_POLL_INTERVAL, rx.recv()).await {
                Ok(Some(envelope)) => return Ok(envelope),
                Ok(None) => return Err(Error::Network("Transport closed".into())),
                Err(_elapsed) => {
                    if *self.shutdown.read().await {
                        return Err(Error::Network("Transport closed".into()));
                    }
                }
            }
        }
    }

    fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Set the shutdown flag and close the QUIC endpoint.
    async fn close(&self) -> Result<()> {
        *self.shutdown.write().await = true;
        self.endpoint.close(0u32.into(), b"Shutdown");
        Ok(())
    }
}
643
/// Generate a throwaway self-signed certificate for the name `localhost`.
///
/// Only the DER-encoded certificate is returned; its private key is discarded.
#[cfg(feature = "quic")]
fn generate_self_signed_cert() -> Result<rustls::pki_types::CertificateDer<'static>> {
    let subject_alt_names = vec!["localhost".to_string()];
    let generated = rcgen::generate_simple_self_signed(subject_alt_names)
        .map_err(|e| Error::Network(format!("Failed to generate certificate: {e}")))?;

    let der_bytes: Vec<u8> = generated.cert.der().to_vec();
    Ok(rustls::pki_types::CertificateDer::from(der_bytes))
}
653
/// Build the QUIC server configuration: ALPN `ipfrs-raft`, no client authentication.
///
/// NOTE(review): the `_cert` argument is ignored. This signature carries no
/// private key, so a new self-signed certificate and key pair for `localhost`
/// is generated here instead.
#[cfg(feature = "quic")]
fn configure_server(
    _cert: rustls::pki_types::CertificateDer<'static>,
) -> Result<quinn::ServerConfig> {
    let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
        .map_err(|e| Error::Network(format!("Failed to generate certificate: {e}")))?;

    let private_key = rustls::pki_types::PrivateKeyDer::Pkcs8(
        rustls::pki_types::PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der()),
    );
    let chain = vec![rustls::pki_types::CertificateDer::from(
        generated.cert.der().to_vec(),
    )];

    let mut tls = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| Error::Network(format!("Failed to configure server: {e}")))?;
    tls.alpn_protocols = vec![b"ipfrs-raft".to_vec()];

    let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
        .map_err(|e| Error::Network(format!("Failed to create QUIC server config: {e}")))?;

    Ok(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
}
682
/// Build the QUIC client configuration with ALPN `ipfrs-raft`.
///
/// SECURITY: server certificates are not verified (`SkipServerVerification`),
/// so connections are encrypted but the peer is not authenticated.
#[cfg(feature = "quic")]
fn configure_client() -> Result<quinn::ClientConfig> {
    let verifier = Arc::new(SkipServerVerification);
    let mut tls = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    tls.alpn_protocols = vec![b"ipfrs-raft".to_vec()];

    let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls)
        .map_err(|e| Error::Network(format!("Failed to create QUIC client config: {e}")))?;

    Ok(quinn::ClientConfig::new(Arc::new(quic_tls)))
}
701
/// Certificate verifier that accepts every server certificate and signature.
///
/// SECURITY: this disables server authentication entirely, so any peer is
/// accepted, including a man-in-the-middle. It lets nodes connect using the
/// throwaway self-signed certificates generated in this module. Replace it with
/// real verification before relying on QUIC for security.
#[cfg(feature = "quic")]
#[derive(Debug)]
struct SkipServerVerification;

#[cfg(feature = "quic")]
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes offered to the server.
    ///
    /// Nothing is actually verified, so every common scheme is advertised. This
    /// keeps the handshake independent of the server's key type; for example,
    /// TLS 1.3 with an RSA key requires one of the RSA-PSS schemes.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
        ]
    }
}
746
747#[cfg(test)]
748mod tests {
749    use super::*;
750
751    #[tokio::test]
752    async fn test_in_memory_transport_send_recv() {
753        let registry = InMemoryTransport::new_registry();
754        let transport1 = InMemoryTransport::new(NodeId(1), registry.clone());
755        let transport2 = InMemoryTransport::new(NodeId(2), registry);
756
757        // Send message from node 1 to node 2
758        let request = RequestVoteRequest {
759            term: crate::raft::Term(1),
760            candidate_id: NodeId(1),
761            last_log_index: crate::raft::LogIndex(0),
762            last_log_term: crate::raft::Term(0),
763        };
764        let message = Message::RequestVote(request);
765
766        transport1.send(NodeId(2), message.clone()).await.unwrap();
767
768        // Receive message at node 2
769        let (sender, received) = transport2.recv().await.unwrap();
770        assert_eq!(sender, NodeId(1));
771        matches!(received, Message::RequestVote(_));
772    }
773
774    #[tokio::test]
775    async fn test_in_memory_transport_node_not_found() {
776        let registry = InMemoryTransport::new_registry();
777        let transport = InMemoryTransport::new(NodeId(1), registry);
778
779        let request = RequestVoteRequest {
780            term: crate::raft::Term(1),
781            candidate_id: NodeId(1),
782            last_log_index: crate::raft::LogIndex(0),
783            last_log_term: crate::raft::Term(0),
784        };
785        let message = Message::RequestVote(request);
786
787        // Try to send to non-existent node
788        let result = transport.send(NodeId(999), message).await;
789        assert!(result.is_err());
790    }
791
792    #[tokio::test]
793    async fn test_transport_close() {
794        let registry = InMemoryTransport::new_registry();
795        let transport = InMemoryTransport::new(NodeId(1), registry.clone());
796
797        assert!(registry.contains_key(&NodeId(1)));
798
799        transport.close().await.unwrap();
800
801        assert!(!registry.contains_key(&NodeId(1)));
802    }
803
804    #[tokio::test]
805    async fn test_bidirectional_communication() {
806        let registry = InMemoryTransport::new_registry();
807        let transport1 = InMemoryTransport::new(NodeId(1), registry.clone());
808        let transport2 = InMemoryTransport::new(NodeId(2), registry);
809
810        // Node 1 sends RequestVote to Node 2
811        let vote_request = RequestVoteRequest {
812            term: crate::raft::Term(1),
813            candidate_id: NodeId(1),
814            last_log_index: crate::raft::LogIndex(0),
815            last_log_term: crate::raft::Term(0),
816        };
817        transport1
818            .send(NodeId(2), Message::RequestVote(vote_request))
819            .await
820            .unwrap();
821
822        // Node 2 receives and responds
823        let (sender, _msg) = transport2.recv().await.unwrap();
824        assert_eq!(sender, NodeId(1));
825
826        let vote_response = RequestVoteResponse {
827            term: crate::raft::Term(1),
828            vote_granted: true,
829        };
830        transport2
831            .send(NodeId(1), Message::RequestVoteResponse(vote_response))
832            .await
833            .unwrap();
834
835        // Node 1 receives response
836        let (sender, received) = transport1.recv().await.unwrap();
837        assert_eq!(sender, NodeId(2));
838        matches!(received, Message::RequestVoteResponse(_));
839    }
840
841    #[tokio::test]
842    async fn test_tcp_transport_send_recv() {
843        let peer_addrs1 = Arc::new(DashMap::new());
844        let peer_addrs2 = Arc::new(DashMap::new());
845
846        let addr1: SocketAddr = "127.0.0.1:0".parse().unwrap();
847        let addr2: SocketAddr = "127.0.0.1:0".parse().unwrap();
848
849        let config = TransportConfig::default();
850
851        let transport1 = TcpTransport::new(NodeId(1), addr1, peer_addrs1.clone(), config.clone())
852            .await
853            .unwrap();
854
855        let transport2 = TcpTransport::new(NodeId(2), addr2, peer_addrs2.clone(), config)
856            .await
857            .unwrap();
858
859        // Register peers
860        peer_addrs1.insert(NodeId(2), transport2.listen_addr);
861        peer_addrs2.insert(NodeId(1), transport1.listen_addr);
862
863        // Send message from node 1 to node 2
864        let request = RequestVoteRequest {
865            term: crate::raft::Term(1),
866            candidate_id: NodeId(1),
867            last_log_index: crate::raft::LogIndex(0),
868            last_log_term: crate::raft::Term(0),
869        };
870        let message = Message::RequestVote(request);
871
872        transport1.send(NodeId(2), message).await.unwrap();
873
874        // Receive message at node 2
875        let (sender, received) = transport2.recv().await.unwrap();
876        assert_eq!(sender, NodeId(1));
877        matches!(received, Message::RequestVote(_));
878
879        // Cleanup
880        transport1.close().await.unwrap();
881        transport2.close().await.unwrap();
882    }
883
884    #[tokio::test]
885    async fn test_tcp_transport_bidirectional() {
886        let peer_addrs1 = Arc::new(DashMap::new());
887        let peer_addrs2 = Arc::new(DashMap::new());
888
889        let addr1: SocketAddr = "127.0.0.1:0".parse().unwrap();
890        let addr2: SocketAddr = "127.0.0.1:0".parse().unwrap();
891
892        let config = TransportConfig::default();
893
894        let transport1 = TcpTransport::new(NodeId(1), addr1, peer_addrs1.clone(), config.clone())
895            .await
896            .unwrap();
897
898        let transport2 = TcpTransport::new(NodeId(2), addr2, peer_addrs2.clone(), config)
899            .await
900            .unwrap();
901
902        // Register peers
903        peer_addrs1.insert(NodeId(2), transport2.listen_addr);
904        peer_addrs2.insert(NodeId(1), transport1.listen_addr);
905
906        // Node 1 sends to Node 2
907        let vote_request = RequestVoteRequest {
908            term: crate::raft::Term(1),
909            candidate_id: NodeId(1),
910            last_log_index: crate::raft::LogIndex(0),
911            last_log_term: crate::raft::Term(0),
912        };
913        transport1
914            .send(NodeId(2), Message::RequestVote(vote_request))
915            .await
916            .unwrap();
917
918        // Node 2 receives
919        let (sender, _msg) = transport2.recv().await.unwrap();
920        assert_eq!(sender, NodeId(1));
921
922        // Node 2 responds
923        let vote_response = RequestVoteResponse {
924            term: crate::raft::Term(1),
925            vote_granted: true,
926        };
927        transport2
928            .send(NodeId(1), Message::RequestVoteResponse(vote_response))
929            .await
930            .unwrap();
931
932        // Node 1 receives response
933        let (sender, received) = transport1.recv().await.unwrap();
934        assert_eq!(sender, NodeId(2));
935        matches!(received, Message::RequestVoteResponse(_));
936
937        // Cleanup
938        transport1.close().await.unwrap();
939        transport2.close().await.unwrap();
940    }
941
942    #[cfg(feature = "quic")]
943    #[tokio::test]
944    #[ignore] // QUIC tests need timing refinement
945    async fn test_quic_transport_send_recv() {
946        // Install default crypto provider for rustls
947        let _ = rustls::crypto::ring::default_provider().install_default();
948
949        let peer_addrs1 = Arc::new(DashMap::new());
950        let peer_addrs2 = Arc::new(DashMap::new());
951
952        let addr1: SocketAddr = "127.0.0.1:0".parse().unwrap();
953        let addr2: SocketAddr = "127.0.0.1:0".parse().unwrap();
954
955        let config = TransportConfig::default();
956
957        let transport1 = QuicTransport::new(NodeId(1), addr1, peer_addrs1.clone(), config.clone())
958            .await
959            .unwrap();
960
961        let transport2 = QuicTransport::new(NodeId(2), addr2, peer_addrs2.clone(), config)
962            .await
963            .unwrap();
964
965        let addr1_actual = transport1.local_addr().unwrap();
966        let addr2_actual = transport2.local_addr().unwrap();
967
968        // Register peers
969        peer_addrs1.insert(NodeId(2), addr2_actual);
970        peer_addrs2.insert(NodeId(1), addr1_actual);
971
972        // Give the listeners time to start
973        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
974
975        // Send message from node 1 to node 2
976        let request = RequestVoteRequest {
977            term: crate::raft::Term(1),
978            candidate_id: NodeId(1),
979            last_log_index: crate::raft::LogIndex(0),
980            last_log_term: crate::raft::Term(0),
981        };
982        let message = Message::RequestVote(request);
983
984        transport1.send(NodeId(2), message).await.unwrap();
985
986        // Receive message at node 2
987        let (sender, received) = transport2.recv().await.unwrap();
988        assert_eq!(sender, NodeId(1));
989        matches!(received, Message::RequestVote(_));
990
991        // Cleanup
992        transport1.close().await.unwrap();
993        transport2.close().await.unwrap();
994    }
995
996    #[cfg(feature = "quic")]
997    #[tokio::test]
998    #[ignore] // QUIC tests need timing refinement
999    async fn test_quic_transport_bidirectional() {
1000        // Install default crypto provider for rustls
1001        let _ = rustls::crypto::ring::default_provider().install_default();
1002
1003        let peer_addrs1 = Arc::new(DashMap::new());
1004        let peer_addrs2 = Arc::new(DashMap::new());
1005
1006        let addr1: SocketAddr = "127.0.0.1:0".parse().unwrap();
1007        let addr2: SocketAddr = "127.0.0.1:0".parse().unwrap();
1008
1009        let config = TransportConfig::default();
1010
1011        let transport1 = QuicTransport::new(NodeId(1), addr1, peer_addrs1.clone(), config.clone())
1012            .await
1013            .unwrap();
1014
1015        let transport2 = QuicTransport::new(NodeId(2), addr2, peer_addrs2.clone(), config)
1016            .await
1017            .unwrap();
1018
1019        let addr1_actual = transport1.local_addr().unwrap();
1020        let addr2_actual = transport2.local_addr().unwrap();
1021
1022        // Register peers
1023        peer_addrs1.insert(NodeId(2), addr2_actual);
1024        peer_addrs2.insert(NodeId(1), addr1_actual);
1025
1026        // Give the listeners time to start
1027        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1028
1029        // Node 1 sends to Node 2
1030        let vote_request = RequestVoteRequest {
1031            term: crate::raft::Term(1),
1032            candidate_id: NodeId(1),
1033            last_log_index: crate::raft::LogIndex(0),
1034            last_log_term: crate::raft::Term(0),
1035        };
1036        transport1
1037            .send(NodeId(2), Message::RequestVote(vote_request))
1038            .await
1039            .unwrap();
1040
1041        // Node 2 receives
1042        let (sender, _msg) = transport2.recv().await.unwrap();
1043        assert_eq!(sender, NodeId(1));
1044
1045        // Node 2 responds
1046        let vote_response = RequestVoteResponse {
1047            term: crate::raft::Term(1),
1048            vote_granted: true,
1049        };
1050        transport2
1051            .send(NodeId(1), Message::RequestVoteResponse(vote_response))
1052            .await
1053            .unwrap();
1054
1055        // Node 1 receives response
1056        let (sender, received) = transport1.recv().await.unwrap();
1057        assert_eq!(sender, NodeId(2));
1058        matches!(received, Message::RequestVoteResponse(_));
1059
1060        // Cleanup
1061        transport1.close().await.unwrap();
1062        transport2.close().await.unwrap();
1063    }
1064}