// rabia_engine/network/tcp.rs

1//! TCP/IP Network Implementation for Rabia Consensus
2//!
3//! This module provides production-ready TCP networking for the Rabia consensus protocol.
4//! It supports:
5//! - Connection management and pooling
6//! - Message framing and serialization
7//! - Node discovery and dynamic topology
8//! - Fault tolerance and automatic reconnection
9//! - Performance optimizations for high throughput
10
11use async_trait::async_trait;
12use bytes::{BufMut, Bytes, BytesMut};
13// use futures_util::SinkExt;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::io::{ReadHalf, WriteHalf};
21use tokio::net::{TcpListener, TcpStream};
22use tokio::sync::{mpsc, Mutex, RwLock};
23use tokio::time::{sleep, timeout};
24use tracing::{debug, error, info, warn};
25
26use rabia_core::{
27    messages::ProtocolMessage, network::NetworkTransport, NodeId, RabiaError, Result,
28};
29
/// Configuration for TCP networking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TcpNetworkConfig {
    /// Local address to bind to. A port of 0 requests an OS-assigned port;
    /// once the listener starts, this field is overwritten with the actual
    /// bound address.
    pub bind_addr: SocketAddr,
    /// Known peer addresses for initial connection
    pub peer_addresses: HashMap<NodeId, SocketAddr>,
    /// Timeout applied to each outbound connection attempt (and, via the
    /// inbound path, to handshakes that use it).
    pub connection_timeout: Duration,
    /// Interval of the periodic connection health check; connections whose
    /// `last_seen` exceeds 2x this interval are logged as stale.
    pub keepalive_interval: Duration,
    /// Maximum message size (in bytes).
    /// NOTE(review): framing currently enforces `MessageFrame::MAX_FRAME_SIZE`
    /// (hard-coded 4 B + 16 MiB), not this value — confirm which limit is
    /// intended to win.
    pub max_message_size: usize,
    /// Connection retry settings
    pub retry_config: RetryConfig,
    /// Buffer sizes
    pub buffer_config: BufferConfig,
}
48
/// Retry policy for outbound connection attempts (exponential backoff).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of connection attempts before giving up
    pub max_attempts: usize,
    /// Base delay between retry attempts (first retry waits this long)
    pub base_delay: Duration,
    /// Maximum delay between retry attempts (backoff cap)
    pub max_delay: Duration,
    /// Exponential backoff multiplier applied to the delay (in whole
    /// milliseconds) after each failed attempt
    pub backoff_multiplier: f64,
}
60
/// Buffer sizing knobs for TCP connections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferConfig {
    /// TCP read buffer size.
    /// NOTE(review): not referenced anywhere in this module — confirm
    /// whether it should be applied to the socket/reader.
    pub read_buffer_size: usize,
    /// TCP write buffer size.
    /// NOTE(review): not referenced anywhere in this module.
    pub write_buffer_size: usize,
    /// Message queue size per connection.
    /// NOTE(review): per-connection queues are unbounded mpsc channels in
    /// this module, so this bound is not currently enforced.
    pub message_queue_size: usize,
}
70
71impl Default for TcpNetworkConfig {
72    fn default() -> Self {
73        // Adjust timeouts for Windows in CI environments
74        let (connection_timeout, keepalive_interval) =
75            if std::env::var("CI").is_ok() && cfg!(windows) {
76                (Duration::from_secs(30), Duration::from_secs(60))
77            } else {
78                (Duration::from_secs(10), Duration::from_secs(30))
79            };
80
81        Self {
82            bind_addr: "127.0.0.1:0".parse().unwrap(),
83            peer_addresses: HashMap::new(),
84            connection_timeout,
85            keepalive_interval,
86            max_message_size: 16 * 1024 * 1024, // 16MB
87            retry_config: RetryConfig::default(),
88            buffer_config: BufferConfig::default(),
89        }
90    }
91}
92
93impl Default for RetryConfig {
94    fn default() -> Self {
95        Self {
96            max_attempts: 5,
97            base_delay: Duration::from_millis(100),
98            max_delay: Duration::from_secs(30),
99            backoff_multiplier: 2.0,
100        }
101    }
102}
103
104impl Default for BufferConfig {
105    fn default() -> Self {
106        Self {
107            read_buffer_size: 64 * 1024,  // 64KB
108            write_buffer_size: 64 * 1024, // 64KB
109            message_queue_size: 1000,
110        }
111    }
112}
113
/// Message frame structure for TCP transport.
///
/// Wire format: a little-endian `u32` payload length followed by the raw
/// payload bytes (see `to_bytes` / `from_stream`).
#[derive(Debug)]
struct MessageFrame {
    /// Length of the message payload in bytes (always `payload.len()`)
    length: u32,
    /// Message payload
    payload: Bytes,
}
122
123impl MessageFrame {
124    /// Maximum frame size (length field + max payload)
125    const MAX_FRAME_SIZE: usize = 4 + 16 * 1024 * 1024; // 4 bytes + 16MB
126
127    /// Create a new message frame
128    fn new(payload: Bytes) -> Result<Self> {
129        if payload.len() > Self::MAX_FRAME_SIZE - 4 {
130            return Err(RabiaError::network(format!(
131                "Message too large: {} bytes",
132                payload.len()
133            )));
134        }
135
136        Ok(Self {
137            length: payload.len() as u32,
138            payload,
139        })
140    }
141
142    /// Serialize frame to bytes
143    fn to_bytes(&self) -> Bytes {
144        let mut buf = BytesMut::with_capacity(4 + self.payload.len());
145        buf.put_u32_le(self.length);
146        buf.put_slice(&self.payload);
147        buf.freeze()
148    }
149
150    /// Deserialize frame from bytes
151    async fn from_stream<R>(reader: &mut R) -> Result<Self>
152    where
153        R: AsyncReadExt + Unpin,
154    {
155        // Read length field
156        let length = reader
157            .read_u32_le()
158            .await
159            .map_err(|e| RabiaError::network(format!("Failed to read frame length: {}", e)))?;
160
161        if length as usize > Self::MAX_FRAME_SIZE - 4 {
162            return Err(RabiaError::network(format!(
163                "Frame too large: {} bytes",
164                length
165            )));
166        }
167
168        // Read payload
169        let mut payload = vec![0u8; length as usize];
170        reader
171            .read_exact(&mut payload)
172            .await
173            .map_err(|e| RabiaError::network(format!("Failed to read frame payload: {}", e)))?;
174
175        Ok(Self {
176            length,
177            payload: Bytes::from(payload),
178        })
179    }
180}
181
/// Connection state information.
///
/// Shared (via `Arc`) between the registry maps and the per-connection
/// reader/writer tasks.
#[derive(Debug)]
struct ConnectionInfo {
    /// Peer's node ID, learned during the handshake
    node_id: NodeId,
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Read half of the TCP stream, locked per frame read
    reader: Arc<Mutex<ReadHalf<TcpStream>>>,
    /// Write half of the TCP stream, locked per frame write
    writer: Arc<Mutex<WriteHalf<TcpStream>>>,
    /// Set once when the connection is established.
    /// NOTE(review): never refreshed afterwards (the struct sits behind an
    /// immutable `Arc`), so the connection manager's staleness check will
    /// flag every connection older than 2x the keepalive interval.
    last_seen: Instant,
    /// Sender side of this connection's outbound queue; the writer task in
    /// `run_connection_handler` drains the matching receiver.
    outbound_queue: mpsc::UnboundedSender<ProtocolMessage>,
    #[allow(dead_code)]
    is_outbound: bool,
}
195
/// TCP Network implementation.
///
/// Owns the listener-side state and the registry of live peer connections.
/// Clones share all state via `Arc` (see the manual `Clone` impl).
pub struct TcpNetwork {
    /// This node's ID
    node_id: NodeId,
    /// Configuration; `bind_addr` is rewritten with the actual bound
    /// address once the listener starts.
    config: TcpNetworkConfig,
    /// TCP listener for incoming connections.
    /// NOTE(review): always `None` in this module — the bound listener is
    /// moved into the accept-loop task instead of being stored here.
    #[allow(dead_code)]
    listener: Option<TcpListener>,
    /// Active connections by node ID
    connections: Arc<RwLock<HashMap<NodeId, Arc<ConnectionInfo>>>>,
    /// Address to node ID mapping
    addr_to_node: Arc<RwLock<HashMap<SocketAddr, NodeId>>>,
    /// Incoming message queue: every connection's reader task feeds
    /// `message_tx`; `receive()` drains `message_rx`.
    message_rx: Arc<Mutex<mpsc::UnboundedReceiver<(NodeId, ProtocolMessage)>>>,
    message_tx: mpsc::UnboundedSender<(NodeId, ProtocolMessage)>,
    /// Shutdown signal.
    /// NOTE(review): the receiver half is stored but never polled anywhere
    /// in this module, so the signal currently has no listener.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
    shutdown_rx: Arc<Mutex<Option<mpsc::Receiver<()>>>>,
}
216
217impl TcpNetwork {
218    /// Create a new TCP network instance
219    pub async fn new(node_id: NodeId, config: TcpNetworkConfig) -> Result<Self> {
220        let (message_tx, message_rx) = mpsc::unbounded_channel();
221        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
222
223        let mut network = Self {
224            node_id,
225            config,
226            listener: None,
227            connections: Arc::new(RwLock::new(HashMap::new())),
228            addr_to_node: Arc::new(RwLock::new(HashMap::new())),
229            message_rx: Arc::new(Mutex::new(message_rx)),
230            message_tx,
231            shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
232            shutdown_rx: Arc::new(Mutex::new(Some(shutdown_rx))),
233        };
234
235        // Start TCP listener
236        network.start_listener().await?;
237
238        // Start connection manager
239        network.start_connection_manager().await;
240
241        info!(
242            "TCP network started for node {} on {}",
243            node_id, network.config.bind_addr
244        );
245
246        Ok(network)
247    }
248
249    /// Start TCP listener for incoming connections
250    async fn start_listener(&mut self) -> Result<()> {
251        let listener = TcpListener::bind(&self.config.bind_addr)
252            .await
253            .map_err(|e| {
254                RabiaError::network(format!(
255                    "Failed to bind to {}: {}",
256                    self.config.bind_addr, e
257                ))
258            })?;
259
260        let actual_addr = listener
261            .local_addr()
262            .map_err(|e| RabiaError::network(format!("Failed to get local address: {}", e)))?;
263
264        info!("TCP listener bound to {}", actual_addr);
265        self.config.bind_addr = actual_addr;
266
267        // Spawn listener task
268        let connections = self.connections.clone();
269        let addr_to_node = self.addr_to_node.clone();
270        let message_tx = self.message_tx.clone();
271        let node_id = self.node_id;
272        let config = self.config.clone();
273
274        tokio::spawn(async move {
275            Self::accept_connections(
276                listener,
277                node_id,
278                config,
279                connections,
280                addr_to_node,
281                message_tx,
282            )
283            .await;
284        });
285
286        Ok(())
287    }
288
289    /// Accept incoming TCP connections
290    async fn accept_connections(
291        listener: TcpListener,
292        node_id: NodeId,
293        config: TcpNetworkConfig,
294        connections: Arc<RwLock<HashMap<NodeId, Arc<ConnectionInfo>>>>,
295        addr_to_node: Arc<RwLock<HashMap<SocketAddr, NodeId>>>,
296        message_tx: mpsc::UnboundedSender<(NodeId, ProtocolMessage)>,
297    ) {
298        loop {
299            match listener.accept().await {
300                Ok((stream, addr)) => {
301                    debug!("Accepted connection from {}", addr);
302
303                    let connections = connections.clone();
304                    let addr_to_node = addr_to_node.clone();
305                    let message_tx = message_tx.clone();
306                    let config = config.clone();
307
308                    tokio::spawn(async move {
309                        if let Err(e) = Self::handle_inbound_connection(
310                            stream,
311                            addr,
312                            node_id,
313                            config,
314                            connections,
315                            addr_to_node,
316                            message_tx,
317                        )
318                        .await
319                        {
320                            warn!("Failed to handle inbound connection from {}: {}", addr, e);
321                        }
322                    });
323                }
324                Err(e) => {
325                    error!("Failed to accept connection: {}", e);
326                    sleep(Duration::from_millis(100)).await;
327                }
328            }
329        }
330    }
331
332    /// Handle an inbound TCP connection
333    async fn handle_inbound_connection(
334        mut stream: TcpStream,
335        addr: SocketAddr,
336        local_node_id: NodeId,
337        config: TcpNetworkConfig,
338        connections: Arc<RwLock<HashMap<NodeId, Arc<ConnectionInfo>>>>,
339        addr_to_node: Arc<RwLock<HashMap<SocketAddr, NodeId>>>,
340        message_tx: mpsc::UnboundedSender<(NodeId, ProtocolMessage)>,
341    ) -> Result<()> {
342        // Perform handshake to identify the peer
343        let peer_node_id = Self::perform_inbound_handshake(&mut stream, local_node_id).await?;
344
345        info!(
346            "Established inbound connection from {} ({})",
347            peer_node_id, addr
348        );
349
350        // Create connection info
351        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
352        let (read_half, write_half) = tokio::io::split(stream);
353        let connection_info = Arc::new(ConnectionInfo {
354            node_id: peer_node_id,
355            addr,
356            reader: Arc::new(Mutex::new(read_half)),
357            writer: Arc::new(Mutex::new(write_half)),
358            last_seen: Instant::now(),
359            outbound_queue: outbound_tx,
360            is_outbound: false,
361        });
362
363        // Register connection
364        {
365            let mut connections = connections.write().await;
366            connections.insert(peer_node_id, connection_info.clone());
367        }
368        {
369            let mut addr_to_node = addr_to_node.write().await;
370            addr_to_node.insert(addr, peer_node_id);
371        }
372
373        // Start connection handler
374        tokio::spawn(Self::run_connection_handler(
375            connection_info,
376            outbound_rx,
377            message_tx,
378            config,
379        ));
380
381        Ok(())
382    }
383
384    /// Perform handshake for inbound connection
385    async fn perform_inbound_handshake(
386        stream: &mut TcpStream,
387        local_node_id: NodeId,
388    ) -> Result<NodeId> {
389        // Simple handshake protocol:
390        // 1. Peer sends their node ID
391        // 2. We send our node ID back
392        // 3. Connection is established
393
394        // Read peer's node ID
395        let frame = MessageFrame::from_stream(stream).await?;
396        let peer_node_id: NodeId = bincode::deserialize(&frame.payload).map_err(|e| {
397            RabiaError::network(format!("Failed to deserialize peer node ID: {}", e))
398        })?;
399
400        // Send our node ID
401        let our_id_bytes = bincode::serialize(&local_node_id)
402            .map_err(|e| RabiaError::network(format!("Failed to serialize node ID: {}", e)))?;
403        let response_frame = MessageFrame::new(Bytes::from(our_id_bytes))?;
404
405        stream
406            .write_all(&response_frame.to_bytes())
407            .await
408            .map_err(|e| {
409                RabiaError::network(format!("Failed to send handshake response: {}", e))
410            })?;
411
412        Ok(peer_node_id)
413    }
414
415    /// Connect to a peer node
416    pub async fn connect_to_peer(&self, peer_node_id: NodeId, addr: SocketAddr) -> Result<()> {
417        // Check if already connected
418        {
419            let connections = self.connections.read().await;
420            if connections.contains_key(&peer_node_id) {
421                debug!("Already connected to peer {}", peer_node_id);
422                return Ok(());
423            }
424        }
425
426        info!("Connecting to peer {} at {}", peer_node_id, addr);
427
428        // Attempt connection with retries
429        let mut attempts = 0;
430        let mut delay = self.config.retry_config.base_delay;
431
432        while attempts < self.config.retry_config.max_attempts {
433            match timeout(self.config.connection_timeout, TcpStream::connect(&addr)).await {
434                Ok(Ok(mut stream)) => {
435                    // Perform outbound handshake
436                    if let Err(e) = self
437                        .perform_outbound_handshake(&mut stream, peer_node_id)
438                        .await
439                    {
440                        warn!("Handshake failed with {}: {}", peer_node_id, e);
441                        attempts += 1;
442                        sleep(delay).await;
443                        delay = Duration::min(
444                            Duration::from_millis(
445                                (delay.as_millis() as f64
446                                    * self.config.retry_config.backoff_multiplier)
447                                    as u64,
448                            ),
449                            self.config.retry_config.max_delay,
450                        );
451                        continue;
452                    }
453
454                    info!(
455                        "Successfully connected to peer {} at {}",
456                        peer_node_id, addr
457                    );
458
459                    // Create connection info
460                    let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
461                    let (read_half, write_half) = tokio::io::split(stream);
462                    let connection_info = Arc::new(ConnectionInfo {
463                        node_id: peer_node_id,
464                        addr,
465                        reader: Arc::new(Mutex::new(read_half)),
466                        writer: Arc::new(Mutex::new(write_half)),
467                        last_seen: Instant::now(),
468                        outbound_queue: outbound_tx,
469                        is_outbound: true,
470                    });
471
472                    // Register connection
473                    {
474                        let mut connections = self.connections.write().await;
475                        connections.insert(peer_node_id, connection_info.clone());
476                    }
477                    {
478                        let mut addr_to_node = self.addr_to_node.write().await;
479                        addr_to_node.insert(addr, peer_node_id);
480                    }
481
482                    // Start connection handler
483                    let message_tx = self.message_tx.clone();
484                    let config = self.config.clone();
485
486                    tokio::spawn(Self::run_connection_handler(
487                        connection_info,
488                        outbound_rx,
489                        message_tx,
490                        config,
491                    ));
492
493                    return Ok(());
494                }
495                Ok(Err(e)) => {
496                    warn!(
497                        "Connection attempt {} to {} failed: {}",
498                        attempts + 1,
499                        addr,
500                        e
501                    );
502                }
503                Err(_) => {
504                    warn!("Connection attempt {} to {} timed out", attempts + 1, addr);
505                }
506            }
507
508            attempts += 1;
509            if attempts < self.config.retry_config.max_attempts {
510                sleep(delay).await;
511                delay = Duration::min(
512                    Duration::from_millis(
513                        (delay.as_millis() as f64 * self.config.retry_config.backoff_multiplier)
514                            as u64,
515                    ),
516                    self.config.retry_config.max_delay,
517                );
518            }
519        }
520
521        Err(RabiaError::network(format!(
522            "Failed to connect to {} after {} attempts",
523            addr, attempts
524        )))
525    }
526
527    /// Perform handshake for outbound connection
528    async fn perform_outbound_handshake(
529        &self,
530        stream: &mut TcpStream,
531        expected_peer_id: NodeId,
532    ) -> Result<()> {
533        // Send our node ID
534        let our_id_bytes = bincode::serialize(&self.node_id)
535            .map_err(|e| RabiaError::network(format!("Failed to serialize node ID: {}", e)))?;
536        let handshake_frame = MessageFrame::new(Bytes::from(our_id_bytes))?;
537
538        stream
539            .write_all(&handshake_frame.to_bytes())
540            .await
541            .map_err(|e| RabiaError::network(format!("Failed to send handshake: {}", e)))?;
542
543        // Read peer's response
544        let frame = MessageFrame::from_stream(stream).await?;
545        let peer_node_id: NodeId = bincode::deserialize(&frame.payload).map_err(|e| {
546            RabiaError::network(format!("Failed to deserialize peer response: {}", e))
547        })?;
548
549        if peer_node_id != expected_peer_id {
550            return Err(RabiaError::network(format!(
551                "Node ID mismatch: expected {}, got {}",
552                expected_peer_id, peer_node_id
553            )));
554        }
555
556        Ok(())
557    }
558
559    /// Run the connection handler for a specific connection
560    async fn run_connection_handler(
561        connection: Arc<ConnectionInfo>,
562        mut outbound_rx: mpsc::UnboundedReceiver<ProtocolMessage>,
563        message_tx: mpsc::UnboundedSender<(NodeId, ProtocolMessage)>,
564        _config: TcpNetworkConfig,
565    ) {
566        let node_id = connection.node_id;
567        info!("Starting connection handler for {}", node_id);
568
569        // Create separate handles for reading and writing
570        let stream_read = connection.reader.clone();
571        let stream_write = connection.writer.clone();
572
573        // Spawn reader task
574        let message_tx_clone = message_tx.clone();
575        let reader_handle = tokio::spawn(async move {
576            loop {
577                let frame_result = {
578                    let mut stream_guard = stream_read.lock().await;
579                    MessageFrame::from_stream(&mut *stream_guard).await
580                };
581
582                match frame_result {
583                    Ok(frame) => match bincode::deserialize::<ProtocolMessage>(&frame.payload) {
584                        Ok(message) => {
585                            if let Err(e) = message_tx_clone.send((node_id, message)) {
586                                debug!("Failed to send message to queue: {}", e);
587                                break;
588                            }
589                        }
590                        Err(e) => {
591                            warn!("Failed to deserialize message from {}: {}", node_id, e);
592                        }
593                    },
594                    Err(e) => {
595                        debug!("Connection to {} closed: {}", node_id, e);
596                        break;
597                    }
598                }
599            }
600        });
601
602        // Spawn writer task
603        let writer_handle = tokio::spawn(async move {
604            while let Some(message) = outbound_rx.recv().await {
605                match bincode::serialize(&message) {
606                    Ok(serialized) => {
607                        let frame = match MessageFrame::new(Bytes::from(serialized)) {
608                            Ok(frame) => frame,
609                            Err(e) => {
610                                warn!("Failed to create frame for message to {}: {}", node_id, e);
611                                continue;
612                            }
613                        };
614
615                        let write_result = {
616                            let mut stream_guard = stream_write.lock().await;
617                            stream_guard.write_all(&frame.to_bytes()).await
618                        };
619
620                        if let Err(e) = write_result {
621                            debug!("Failed to write to {}: {}", node_id, e);
622                            break;
623                        }
624                    }
625                    Err(e) => {
626                        warn!("Failed to serialize message to {}: {}", node_id, e);
627                    }
628                }
629            }
630        });
631
632        // Wait for either task to complete (indicating connection closed)
633        tokio::select! {
634            _ = reader_handle => {
635                debug!("Reader task for {} completed", node_id);
636            }
637            _ = writer_handle => {
638                debug!("Writer task for {} completed", node_id);
639            }
640        }
641
642        info!("Connection handler for {} stopped", node_id);
643    }
644
    /// Start connection manager for automatic peer connections.
    ///
    /// Spawns one task per configured peer to dial it, plus a periodic
    /// health-check task that logs connections whose `last_seen` is older
    /// than twice the keepalive interval.
    async fn start_connection_manager(&self) {
        let peer_addresses = self.config.peer_addresses.clone();
        let connections = self.connections.clone();

        // Connect to known peers; each dial gets its own task and a cheap
        // Arc-backed clone of the network handle.
        for (peer_id, addr) in peer_addresses {
            let network = self.clone();
            tokio::spawn(async move {
                if let Err(e) = network.connect_to_peer(peer_id, addr).await {
                    warn!("Failed to connect to peer {} at {}: {}", peer_id, addr, e);
                }
            });
        }

        // Start periodic connection health checks
        let connections_clone = connections.clone();
        let keepalive_interval = self.config.keepalive_interval;

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(keepalive_interval);
            loop {
                interval.tick().await;

                let connections = connections_clone.read().await;
                let now = Instant::now();

                for (node_id, connection) in connections.iter() {
                    // NOTE(review): `last_seen` is set once at connection
                    // creation and never refreshed anywhere in this module,
                    // so this warning fires for every connection older than
                    // 2x the keepalive interval, healthy or not.
                    let elapsed = now.duration_since(connection.last_seen);
                    if elapsed > keepalive_interval * 2 {
                        warn!(
                            "Connection to {} appears stale (last seen {:?} ago)",
                            node_id, elapsed
                        );
                        // Future enhancement: implement connection health check and reconnection logic
                    }
                }
            }
        });
    }
685
    /// Get the local bind address.
    /// Reflects the actual OS-assigned port once the listener has started,
    /// since `start_listener` writes the bound address back into the config.
    pub fn local_addr(&self) -> SocketAddr {
        self.config.bind_addr
    }
690
    /// Get this node's ID.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }
695
696    /// Add a known peer address for automatic connection
697    pub async fn add_peer(&mut self, node_id: NodeId, addr: SocketAddr) {
698        self.config.peer_addresses.insert(node_id, addr);
699
700        // Attempt immediate connection
701        if let Err(e) = self.connect_to_peer(node_id, addr).await {
702            warn!(
703                "Failed to connect to newly added peer {} at {}: {}",
704                node_id, addr, e
705            );
706        }
707    }
708
709    /// Remove a peer
710    pub async fn remove_peer(&mut self, node_id: NodeId) {
711        self.config.peer_addresses.remove(&node_id);
712
713        // Close existing connection
714        let mut connections = self.connections.write().await;
715        if let Some(_connection) = connections.remove(&node_id) {
716            info!("Removed connection to peer {}", node_id);
717            // Connection will be closed when the handler tasks detect the closed stream
718        }
719    }
720
721    /// Shutdown the network
722    pub async fn shutdown(&self) {
723        info!("Shutting down TCP network");
724
725        if let Some(shutdown_tx) = self.shutdown_tx.lock().await.as_ref() {
726            let _ = shutdown_tx.send(()).await;
727        }
728
729        // Close all connections
730        let connections = self.connections.read().await;
731        for (node_id, _) in connections.iter() {
732            debug!("Closing connection to {}", node_id);
733        }
734    }
735}
736
impl Clone for TcpNetwork {
    /// Cheap handle clone: all shared state (connection maps, message
    /// queues, shutdown channels) is `Arc`-backed, so clones observe the
    /// same network. The `TcpListener` is not cloneable and stays `None`
    /// on clones (it is `None` on the original too — see the struct docs).
    fn clone(&self) -> Self {
        Self {
            node_id: self.node_id,
            config: self.config.clone(),
            listener: None, // Don't clone the listener
            connections: self.connections.clone(),
            addr_to_node: self.addr_to_node.clone(),
            message_rx: self.message_rx.clone(),
            message_tx: self.message_tx.clone(),
            shutdown_tx: self.shutdown_tx.clone(),
            shutdown_rx: self.shutdown_rx.clone(),
        }
    }
}
752
753#[async_trait]
754impl NetworkTransport for TcpNetwork {
755    async fn send_to(&self, target: NodeId, message: ProtocolMessage) -> Result<()> {
756        let connections = self.connections.read().await;
757
758        if let Some(connection) = connections.get(&target) {
759            connection.outbound_queue.send(message).map_err(|_| {
760                RabiaError::network(format!("Failed to queue message to {}", target))
761            })?;
762            Ok(())
763        } else {
764            Err(RabiaError::network(format!(
765                "No connection to node {}",
766                target
767            )))
768        }
769    }
770
771    async fn broadcast(&self, message: ProtocolMessage, exclude: Option<NodeId>) -> Result<()> {
772        let connections = self.connections.read().await;
773        let mut failed_nodes = Vec::new();
774
775        for (node_id, connection) in connections.iter() {
776            if Some(*node_id) != exclude
777                && *node_id != self.node_id
778                && connection.outbound_queue.send(message.clone()).is_err()
779            {
780                failed_nodes.push(*node_id);
781            }
782        }
783
784        if !failed_nodes.is_empty() {
785            warn!("Failed to broadcast to nodes: {:?}", failed_nodes);
786        }
787
788        Ok(())
789    }
790
791    async fn receive(&mut self) -> Result<(NodeId, ProtocolMessage)> {
792        let mut rx = self.message_rx.lock().await;
793
794        match rx.recv().await {
795            Some((from, message)) => Ok((from, message)),
796            None => Err(RabiaError::network("Message channel closed")),
797        }
798    }
799
800    async fn get_connected_nodes(&self) -> Result<HashSet<NodeId>> {
801        let connections = self.connections.read().await;
802        Ok(connections.keys().copied().collect())
803    }
804
805    async fn is_connected(&self, node_id: NodeId) -> Result<bool> {
806        let connections = self.connections.read().await;
807        Ok(connections.contains_key(&node_id))
808    }
809
810    async fn disconnect(&mut self) -> Result<()> {
811        self.shutdown().await;
812        Ok(())
813    }
814
815    async fn reconnect(&mut self) -> Result<()> {
816        // Attempt to reconnect to all known peers
817        let peer_addresses = self.config.peer_addresses.clone();
818
819        for (peer_id, addr) in peer_addresses {
820            if let Err(e) = self.connect_to_peer(peer_id, addr).await {
821                warn!("Failed to reconnect to peer {} at {}: {}", peer_id, addr, e);
822            }
823        }
824
825        Ok(())
826    }
827}
828
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// A default-config network should come up and bind to a real,
    /// OS-assigned (non-zero) port.
    #[tokio::test]
    async fn test_tcp_network_creation() {
        let node_id = NodeId::new();
        let config = TcpNetworkConfig::default();

        let network = TcpNetwork::new(node_id, config).await.unwrap();
        assert_eq!(network.node_id, node_id);
        assert!(network.local_addr().port() > 0);
    }

    /// Frame construction records the payload length and serializes to a
    /// 4-byte length prefix followed by the payload.
    #[tokio::test]
    async fn test_message_frame() {
        let payload = Bytes::from("test message");
        let frame = MessageFrame::new(payload.clone()).unwrap();

        assert_eq!(frame.length, payload.len() as u32);
        assert_eq!(frame.payload, payload);

        let serialized = frame.to_bytes();
        assert!(serialized.len() == 4 + payload.len());
    }

    /// An outbound connect from network1 to network2 registers the link on
    /// BOTH sides: the handshake tells network2 who dialed it.
    #[tokio::test]
    async fn test_peer_connection() {
        let node1_id = NodeId::new();
        let node2_id = NodeId::new();

        let config1 = TcpNetworkConfig {
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            ..Default::default()
        };
        let config2 = TcpNetworkConfig {
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            ..Default::default()
        };

        let network1 = TcpNetwork::new(node1_id, config1).await.unwrap();
        let network2 = TcpNetwork::new(node2_id, config2).await.unwrap();

        let _addr1 = network1.local_addr();
        let addr2 = network2.local_addr();

        // Connect network1 to network2
        network1.connect_to_peer(node2_id, addr2).await.unwrap();

        // Give some time for connection to establish
        sleep(Duration::from_millis(100)).await;

        // Check connections
        assert!(network1.is_connected(node2_id).await.unwrap());
        assert!(network2.is_connected(node1_id).await.unwrap());

        // Cleanup
        network1.shutdown().await;
        network2.shutdown().await;
    }
}