Skip to main content

moloch_net/
transport.rs

//! Transport layer for the Moloch network.
//!
//! Provides TCP connections with TLS 1.3 encryption, connection pooling,
//! and automatic reconnection handling.

6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::sync::RwLock;
15use tracing::error;
16
17use crate::protocol::{
18    generate_message_id, DisconnectReason, HelloAckMessage, HelloMessage, Message, MessageCodec,
19    PeerId, ProtocolVersion,
20};
21#[cfg(test)]
22use crate::protocol::{PingMessage, StatusMessage};
23use moloch_core::crypto::{PublicKey, SecretKey};
24
25/// Network configuration.
26#[derive(Debug, Clone)]
27pub struct NetworkConfig {
28    /// Address to listen on.
29    pub listen_addr: SocketAddr,
30    /// Chain ID for network separation.
31    pub chain_id: String,
32    /// Node's secret key for identity.
33    pub node_key: SecretKey,
34    /// TLS configuration.
35    pub tls: TlsConfig,
36    /// Maximum number of connections.
37    pub max_connections: usize,
38    /// Connection timeout.
39    pub connection_timeout: Duration,
40    /// Handshake timeout.
41    pub handshake_timeout: Duration,
42    /// Keep-alive interval (ping frequency).
43    pub keepalive_interval: Duration,
44    /// Maximum time without response before disconnecting.
45    pub idle_timeout: Duration,
46    /// Reconnection backoff (initial delay).
47    pub reconnect_delay: Duration,
48    /// Maximum reconnection attempts.
49    pub max_reconnect_attempts: usize,
50    /// Message codec configuration.
51    pub max_message_size: usize,
52}
53
54impl NetworkConfig {
55    /// Create a new network config builder.
56    pub fn builder() -> NetworkConfigBuilder {
57        NetworkConfigBuilder::default()
58    }
59
60    /// Get the node's public key.
61    pub fn node_pubkey(&self) -> PublicKey {
62        self.node_key.public_key()
63    }
64
65    /// Get the peer ID for this node.
66    pub fn peer_id(&self) -> PeerId {
67        PeerId::new(self.node_pubkey())
68    }
69}
70
71/// Builder for NetworkConfig.
72#[derive(Debug, Default)]
73pub struct NetworkConfigBuilder {
74    listen_addr: Option<SocketAddr>,
75    chain_id: Option<String>,
76    node_key: Option<SecretKey>,
77    tls: Option<TlsConfig>,
78    max_connections: Option<usize>,
79    connection_timeout: Option<Duration>,
80    handshake_timeout: Option<Duration>,
81    keepalive_interval: Option<Duration>,
82    idle_timeout: Option<Duration>,
83    reconnect_delay: Option<Duration>,
84    max_reconnect_attempts: Option<usize>,
85    max_message_size: Option<usize>,
86}
87
88impl NetworkConfigBuilder {
89    /// Set the listen address.
90    pub fn listen_addr(mut self, addr: SocketAddr) -> Self {
91        self.listen_addr = Some(addr);
92        self
93    }
94
95    /// Set the listen address from a string.
96    pub fn listen_addr_str(mut self, addr: &str) -> Result<Self, std::net::AddrParseError> {
97        self.listen_addr = Some(addr.parse()?);
98        Ok(self)
99    }
100
101    /// Set the chain ID.
102    pub fn chain_id(mut self, chain_id: impl Into<String>) -> Self {
103        self.chain_id = Some(chain_id.into());
104        self
105    }
106
107    /// Set the node key.
108    pub fn node_key(mut self, key: SecretKey) -> Self {
109        self.node_key = Some(key);
110        self
111    }
112
113    /// Set TLS configuration.
114    pub fn tls(mut self, tls: TlsConfig) -> Self {
115        self.tls = Some(tls);
116        self
117    }
118
119    /// Set maximum connections.
120    pub fn max_connections(mut self, max: usize) -> Self {
121        self.max_connections = Some(max);
122        self
123    }
124
125    /// Set connection timeout.
126    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
127        self.connection_timeout = Some(timeout);
128        self
129    }
130
131    /// Set handshake timeout.
132    pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
133        self.handshake_timeout = Some(timeout);
134        self
135    }
136
137    /// Set keepalive interval.
138    pub fn keepalive_interval(mut self, interval: Duration) -> Self {
139        self.keepalive_interval = Some(interval);
140        self
141    }
142
143    /// Set idle timeout.
144    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
145        self.idle_timeout = Some(timeout);
146        self
147    }
148
149    /// Build the config.
150    pub fn build(self) -> Result<NetworkConfig, TransportError> {
151        let node_key = self.node_key.unwrap_or_else(SecretKey::generate);
152
153        Ok(NetworkConfig {
154            listen_addr: self
155                .listen_addr
156                .ok_or_else(|| TransportError::Config("listen_addr is required".into()))?,
157            chain_id: self
158                .chain_id
159                .ok_or_else(|| TransportError::Config("chain_id is required".into()))?,
160            node_key,
161            tls: self.tls.unwrap_or_default(),
162            max_connections: self.max_connections.unwrap_or(100),
163            connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(10)),
164            handshake_timeout: self.handshake_timeout.unwrap_or(Duration::from_secs(5)),
165            keepalive_interval: self.keepalive_interval.unwrap_or(Duration::from_secs(30)),
166            idle_timeout: self.idle_timeout.unwrap_or(Duration::from_secs(120)),
167            reconnect_delay: self.reconnect_delay.unwrap_or(Duration::from_secs(1)),
168            max_reconnect_attempts: self.max_reconnect_attempts.unwrap_or(5),
169            max_message_size: self
170                .max_message_size
171                .unwrap_or(MessageCodec::DEFAULT_MAX_SIZE),
172        })
173    }
174}
175
/// TLS configuration.
///
/// `Debug` is implemented by hand so the private key never ends up in logs.
/// The key is printed as `<redacted>`, and the certificate as its byte length.
#[derive(Clone)]
pub struct TlsConfig {
    /// Enable TLS.
    pub enabled: bool,
    /// Certificate in DER format.
    pub cert: Option<Vec<u8>>,
    /// Private key in PKCS#8 DER format. Secret: never printed by `Debug`.
    pub key: Option<Vec<u8>>,
    /// Skip certificate verification (for testing only!).
    pub skip_verify: bool,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("enabled", &self.enabled)
            .field("cert", &self.cert.as_ref().map(|der| format!("<{} bytes>", der.len())))
            // Redact the private key: this struct is embedded in Debug-derived
            // NetworkConfig / ConnectionPool / Transport values.
            .field("key", &self.key.as_ref().map(|_| "<redacted>"))
            .field("skip_verify", &self.skip_verify)
            .finish()
    }
}

impl Default for TlsConfig {
    /// TLS enabled, no certificate or key material, verification on.
    fn default() -> Self {
        Self {
            enabled: true,
            cert: None,
            key: None,
            skip_verify: false,
        }
    }
}
199
200impl TlsConfig {
201    /// Create a self-signed certificate for testing.
202    pub fn generate_self_signed(common_name: &str) -> Result<Self, TransportError> {
203        use rcgen::{generate_simple_self_signed, CertifiedKey};
204
205        let subject_alt_names = vec![common_name.to_string()];
206        let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names)
207            .map_err(|e| {
208                TransportError::Tls(format!("failed to generate self-signed cert: {}", e))
209            })?;
210
211        Ok(Self {
212            enabled: true,
213            cert: Some(cert.der().to_vec()),
214            key: Some(key_pair.serialize_der()),
215            skip_verify: false,
216        })
217    }
218}
219
220/// Transport errors.
221#[derive(Debug, thiserror::Error)]
222pub enum TransportError {
223    #[error("I/O error: {0}")]
224    Io(#[from] std::io::Error),
225
226    #[error("TLS error: {0}")]
227    Tls(String),
228
229    #[error("configuration error: {0}")]
230    Config(String),
231
232    #[error("connection failed: {0}")]
233    ConnectionFailed(String),
234
235    #[error("handshake failed: {0}")]
236    HandshakeFailed(String),
237
238    #[error("protocol mismatch: {0}")]
239    ProtocolMismatch(String),
240
241    #[error("chain ID mismatch: expected {expected}, got {got}")]
242    ChainMismatch { expected: String, got: String },
243
244    #[error("connection closed: {0:?}")]
245    ConnectionClosed(DisconnectReason),
246
247    #[error("timeout: {0}")]
248    Timeout(String),
249
250    #[error("too many connections")]
251    TooManyConnections,
252
253    #[error("duplicate connection")]
254    DuplicateConnection,
255
256    #[error("message codec error: {0}")]
257    Codec(#[from] crate::protocol::CodecError),
258
259    #[error("serialization error: {0}")]
260    Serialization(#[from] Box<bincode::ErrorKind>),
261
262    #[error("peer not found: {0}")]
263    PeerNotFound(String),
264}
265
266/// A network connection with a peer.
267#[derive(Debug)]
268pub struct Connection {
269    /// Unique connection ID.
270    pub id: ConnectionId,
271    /// Remote peer ID.
272    pub peer_id: PeerId,
273    /// Remote address.
274    pub remote_addr: SocketAddr,
275    /// Connection state.
276    pub state: ConnectionState,
277    /// When the connection was established.
278    pub connected_at: DateTime<Utc>,
279    /// Last activity timestamp.
280    pub last_activity: DateTime<Utc>,
281    /// Number of messages sent.
282    pub messages_sent: u64,
283    /// Number of messages received.
284    pub messages_received: u64,
285    /// Round-trip latency (from ping/pong).
286    pub latency: Option<Duration>,
287    /// Is this an outbound connection?
288    pub outbound: bool,
289}
290
291/// Unique connection identifier.
292pub type ConnectionId = u64;
293
/// Lifecycle stage of a peer connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// The connection is being established.
    Connecting,
    /// The TLS handshake is in progress.
    TlsHandshake,
    /// The protocol (Hello/HelloAck) handshake is in progress.
    Handshaking,
    /// The connection is usable.
    Active,
    /// The connection is shutting down.
    Closing,
    /// The connection has been closed.
    Closed,
}

impl std::fmt::Display for ConnectionState {
    /// Writes the state's lowercase snake_case label (padding flags are ignored).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ConnectionState::Connecting => "connecting",
            ConnectionState::TlsHandshake => "tls_handshake",
            ConnectionState::Handshaking => "handshaking",
            ConnectionState::Active => "active",
            ConnectionState::Closing => "closing",
            ConnectionState::Closed => "closed",
        };
        f.write_str(label)
    }
}
323
324/// Connection pool managing multiple peer connections.
325#[derive(Debug)]
326pub struct ConnectionPool {
327    /// Network configuration.
328    config: Arc<NetworkConfig>,
329    /// Active connections by peer ID.
330    connections: RwLock<HashMap<PeerId, Connection>>,
331    /// Connection count by address (for deduplication).
332    by_address: RwLock<HashMap<SocketAddr, PeerId>>,
333    /// Next connection ID.
334    next_id: std::sync::atomic::AtomicU64,
335}
336
337impl ConnectionPool {
338    /// Create a new connection pool.
339    pub fn new(config: NetworkConfig) -> Self {
340        Self {
341            config: Arc::new(config),
342            connections: RwLock::new(HashMap::new()),
343            by_address: RwLock::new(HashMap::new()),
344            next_id: std::sync::atomic::AtomicU64::new(1),
345        }
346    }
347
348    /// Get the network configuration.
349    pub fn config(&self) -> &NetworkConfig {
350        &self.config
351    }
352
353    /// Generate a new connection ID.
354    fn next_connection_id(&self) -> ConnectionId {
355        self.next_id
356            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
357    }
358
359    /// Get the number of active connections.
360    pub async fn connection_count(&self) -> usize {
361        self.connections.read().await.len()
362    }
363
364    /// Check if we're connected to a peer.
365    pub async fn is_connected(&self, peer_id: &PeerId) -> bool {
366        let conns = self.connections.read().await;
367        conns
368            .get(peer_id)
369            .map(|c| c.state == ConnectionState::Active)
370            .unwrap_or(false)
371    }
372
373    /// Get connection info for a peer.
374    pub async fn get_connection(&self, peer_id: &PeerId) -> Option<Connection> {
375        let conns = self.connections.read().await;
376        // Clone the connection info (not the actual stream)
377        conns.get(peer_id).map(|c| Connection {
378            id: c.id,
379            peer_id: c.peer_id.clone(),
380            remote_addr: c.remote_addr,
381            state: c.state,
382            connected_at: c.connected_at,
383            last_activity: c.last_activity,
384            messages_sent: c.messages_sent,
385            messages_received: c.messages_received,
386            latency: c.latency,
387            outbound: c.outbound,
388        })
389    }
390
391    /// Get all peer IDs.
392    pub async fn peer_ids(&self) -> Vec<PeerId> {
393        self.connections.read().await.keys().cloned().collect()
394    }
395
396    /// Add a connection to the pool.
397    pub async fn add_connection(
398        &self,
399        peer_id: PeerId,
400        remote_addr: SocketAddr,
401        outbound: bool,
402    ) -> Result<ConnectionId, TransportError> {
403        let mut conns = self.connections.write().await;
404        let mut by_addr = self.by_address.write().await;
405
406        // Check limits
407        if conns.len() >= self.config.max_connections {
408            return Err(TransportError::TooManyConnections);
409        }
410
411        // Check for duplicate
412        if conns.contains_key(&peer_id) {
413            return Err(TransportError::DuplicateConnection);
414        }
415
416        let conn_id = self.next_connection_id();
417        let now = Utc::now();
418
419        let conn = Connection {
420            id: conn_id,
421            peer_id: peer_id.clone(),
422            remote_addr,
423            state: ConnectionState::Active,
424            connected_at: now,
425            last_activity: now,
426            messages_sent: 0,
427            messages_received: 0,
428            latency: None,
429            outbound,
430        };
431
432        conns.insert(peer_id.clone(), conn);
433        by_addr.insert(remote_addr, peer_id);
434
435        Ok(conn_id)
436    }
437
438    /// Remove a connection from the pool.
439    pub async fn remove_connection(&self, peer_id: &PeerId) -> Option<Connection> {
440        let mut conns = self.connections.write().await;
441        let mut by_addr = self.by_address.write().await;
442
443        if let Some(conn) = conns.remove(peer_id) {
444            by_addr.remove(&conn.remote_addr);
445            Some(conn)
446        } else {
447            None
448        }
449    }
450
451    /// Update connection state.
452    pub async fn update_state(&self, peer_id: &PeerId, state: ConnectionState) {
453        let mut conns = self.connections.write().await;
454        if let Some(conn) = conns.get_mut(peer_id) {
455            conn.state = state;
456        }
457    }
458
459    /// Record message sent.
460    pub async fn record_sent(&self, peer_id: &PeerId) {
461        let mut conns = self.connections.write().await;
462        if let Some(conn) = conns.get_mut(peer_id) {
463            conn.messages_sent += 1;
464            conn.last_activity = Utc::now();
465        }
466    }
467
468    /// Record message received.
469    pub async fn record_received(&self, peer_id: &PeerId) {
470        let mut conns = self.connections.write().await;
471        if let Some(conn) = conns.get_mut(peer_id) {
472            conn.messages_received += 1;
473            conn.last_activity = Utc::now();
474        }
475    }
476
477    /// Update latency measurement.
478    pub async fn update_latency(&self, peer_id: &PeerId, latency: Duration) {
479        let mut conns = self.connections.write().await;
480        if let Some(conn) = conns.get_mut(peer_id) {
481            conn.latency = Some(latency);
482        }
483    }
484
485    /// Get connections that have been idle too long.
486    pub async fn get_idle_connections(&self, max_idle: Duration) -> Vec<PeerId> {
487        let conns = self.connections.read().await;
488        let now = Utc::now();
489
490        conns
491            .iter()
492            .filter(|(_, conn)| {
493                let idle_time = now.signed_duration_since(conn.last_activity);
494                idle_time.num_milliseconds() > max_idle.as_millis() as i64
495            })
496            .map(|(peer_id, _)| peer_id.clone())
497            .collect()
498    }
499}
500
501/// Transport layer for network communication.
502#[derive(Debug)]
503pub struct Transport {
504    /// Network configuration.
505    config: Arc<NetworkConfig>,
506    /// Connection pool.
507    pool: Arc<ConnectionPool>,
508    /// Message codec.
509    codec: MessageCodec,
510}
511
512impl Transport {
513    /// Create a new transport.
514    pub fn new(config: NetworkConfig) -> Self {
515        let codec = MessageCodec::with_max_size(config.max_message_size);
516        let config = Arc::new(config);
517        let pool = Arc::new(ConnectionPool::new((*config).clone()));
518
519        Self {
520            config,
521            pool,
522            codec,
523        }
524    }
525
526    /// Get the connection pool.
527    pub fn pool(&self) -> &Arc<ConnectionPool> {
528        &self.pool
529    }
530
531    /// Get the network configuration.
532    pub fn config(&self) -> &NetworkConfig {
533        &self.config
534    }
535
536    /// Create a Hello message for handshaking.
537    pub fn create_hello(
538        &self,
539        height: Option<u64>,
540        head_hash: Option<moloch_core::block::BlockHash>,
541    ) -> HelloMessage {
542        let timestamp = Utc::now();
543        let message_bytes = format!(
544            "{}:{}:{}",
545            self.config.chain_id,
546            height.unwrap_or(0),
547            timestamp.timestamp_millis()
548        );
549        let signature = self.config.node_key.sign(message_bytes.as_bytes());
550
551        HelloMessage {
552            id: generate_message_id(),
553            version: ProtocolVersion::CURRENT,
554            chain_id: self.config.chain_id.clone(),
555            node_key: self.config.node_pubkey(),
556            height,
557            head_hash,
558            timestamp,
559            signature,
560        }
561    }
562
563    /// Create a HelloAck response.
564    pub fn create_hello_ack(
565        &self,
566        request_id: u64,
567        height: Option<u64>,
568        head_hash: Option<moloch_core::block::BlockHash>,
569    ) -> HelloAckMessage {
570        let timestamp = Utc::now();
571        let message_bytes = format!(
572            "ack:{}:{}:{}",
573            self.config.chain_id,
574            height.unwrap_or(0),
575            timestamp.timestamp_millis()
576        );
577        let signature = self.config.node_key.sign(message_bytes.as_bytes());
578
579        HelloAckMessage {
580            request_id,
581            version: ProtocolVersion::CURRENT,
582            chain_id: self.config.chain_id.clone(),
583            node_key: self.config.node_pubkey(),
584            height,
585            head_hash,
586            timestamp,
587            signature,
588        }
589    }
590
591    /// Validate a Hello message.
592    pub fn validate_hello(&self, hello: &HelloMessage) -> Result<(), TransportError> {
593        // Check protocol version
594        if !hello.version.is_compatible_with(&ProtocolVersion::CURRENT) {
595            return Err(TransportError::ProtocolMismatch(format!(
596                "incompatible protocol version: {}",
597                hello.version
598            )));
599        }
600
601        // Check chain ID
602        if hello.chain_id != self.config.chain_id {
603            return Err(TransportError::ChainMismatch {
604                expected: self.config.chain_id.clone(),
605                got: hello.chain_id.clone(),
606            });
607        }
608
609        // Verify signature (proves key ownership)
610        let message_bytes = format!(
611            "{}:{}:{}",
612            hello.chain_id,
613            hello.height.unwrap_or(0),
614            hello.timestamp.timestamp_millis()
615        );
616
617        hello
618            .node_key
619            .verify(message_bytes.as_bytes(), &hello.signature)
620            .map_err(|_| TransportError::HandshakeFailed("invalid signature".into()))?;
621
622        Ok(())
623    }
624
625    /// Validate a HelloAck message.
626    pub fn validate_hello_ack(&self, ack: &HelloAckMessage) -> Result<(), TransportError> {
627        // Check protocol version
628        if !ack.version.is_compatible_with(&ProtocolVersion::CURRENT) {
629            return Err(TransportError::ProtocolMismatch(format!(
630                "incompatible protocol version: {}",
631                ack.version
632            )));
633        }
634
635        // Check chain ID
636        if ack.chain_id != self.config.chain_id {
637            return Err(TransportError::ChainMismatch {
638                expected: self.config.chain_id.clone(),
639                got: ack.chain_id.clone(),
640            });
641        }
642
643        // Verify signature
644        let message_bytes = format!(
645            "ack:{}:{}:{}",
646            ack.chain_id,
647            ack.height.unwrap_or(0),
648            ack.timestamp.timestamp_millis()
649        );
650
651        ack.node_key
652            .verify(message_bytes.as_bytes(), &ack.signature)
653            .map_err(|_| TransportError::HandshakeFailed("invalid signature".into()))?;
654
655        Ok(())
656    }
657
658    /// Encode a message for sending.
659    pub fn encode_message(&self, message: &Message) -> Result<Vec<u8>, TransportError> {
660        Ok(self.codec.encode(message)?)
661    }
662
663    /// Decode a received message.
664    pub fn decode_message(&self, data: &[u8]) -> Result<Message, TransportError> {
665        Ok(self.codec.decode(data)?)
666    }
667
668    /// Read a message from an async reader.
669    pub async fn read_message<R: AsyncRead + Unpin>(
670        &self,
671        reader: &mut R,
672    ) -> Result<Message, TransportError> {
673        // Read length prefix
674        let mut len_buf = [0u8; 4];
675        reader.read_exact(&mut len_buf).await?;
676        let length = u32::from_be_bytes(len_buf) as usize;
677
678        if length > self.config.max_message_size {
679            return Err(TransportError::Codec(
680                crate::protocol::CodecError::MessageTooLarge {
681                    size: length,
682                    max: self.config.max_message_size,
683                },
684            ));
685        }
686
687        // Read payload
688        let mut payload = vec![0u8; length];
689        reader.read_exact(&mut payload).await?;
690
691        let message = bincode::deserialize(&payload)?;
692        Ok(message)
693    }
694
695    /// Write a message to an async writer.
696    pub async fn write_message<W: AsyncWrite + Unpin>(
697        &self,
698        writer: &mut W,
699        message: &Message,
700    ) -> Result<(), TransportError> {
701        let frame = self.encode_message(message)?;
702        writer.write_all(&frame).await?;
703        writer.flush().await?;
704        Ok(())
705    }
706}
707
708/// Statistics for the transport layer.
709#[derive(Debug, Clone, Default, Serialize, Deserialize)]
710pub struct TransportStats {
711    /// Total connections established.
712    pub connections_established: u64,
713    /// Total connections closed.
714    pub connections_closed: u64,
715    /// Total messages sent.
716    pub messages_sent: u64,
717    /// Total messages received.
718    pub messages_received: u64,
719    /// Total bytes sent.
720    pub bytes_sent: u64,
721    /// Total bytes received.
722    pub bytes_received: u64,
723    /// Current active connections.
724    pub active_connections: usize,
725    /// Failed connection attempts.
726    pub connection_failures: u64,
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal complete config: loopback listener, test chain, TLS disabled.
    fn test_config() -> NetworkConfig {
        let tls = TlsConfig {
            enabled: false,
            ..TlsConfig::default()
        };
        NetworkConfig::builder()
            .listen_addr_str("127.0.0.1:0")
            .expect("loopback address parses")
            .chain_id("moloch-test")
            .node_key(SecretKey::generate())
            .tls(tls)
            .build()
            .expect("test config is complete")
    }

    /// Peer ID backed by a freshly generated key.
    fn random_peer_id() -> PeerId {
        PeerId::new(SecretKey::generate().public_key())
    }

    /// IPv4 loopback socket address on `port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    #[test]
    fn test_network_config_builder() {
        let cfg = test_config();
        assert_eq!(cfg.chain_id, "moloch-test");
        assert_eq!(cfg.max_connections, 100);
    }

    #[test]
    fn test_network_config_builder_missing_fields() {
        assert!(NetworkConfig::builder().build().is_err());
    }

    #[test]
    fn test_tls_config_self_signed() {
        let tls = TlsConfig::generate_self_signed("localhost").unwrap();
        assert!(tls.enabled);
        assert!(tls.cert.is_some());
        assert!(tls.key.is_some());
    }

    #[test]
    fn test_connection_state_display() {
        assert_eq!(ConnectionState::Active.to_string(), "active");
        assert_eq!(ConnectionState::Connecting.to_string(), "connecting");
    }

    #[tokio::test]
    async fn test_connection_pool_add_remove() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer_id();

        let conn_id = pool
            .add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();
        assert!(conn_id > 0);
        assert_eq!(pool.connection_count().await, 1);
        assert!(pool.is_connected(&peer).await);

        assert!(pool.remove_connection(&peer).await.is_some());
        assert_eq!(pool.connection_count().await, 0);
        assert!(!pool.is_connected(&peer).await);
    }

    #[tokio::test]
    async fn test_connection_pool_duplicate() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer_id();
        let addr = loopback(8000);

        pool.add_connection(peer.clone(), addr, true).await.unwrap();

        let second = pool.add_connection(peer, addr, true).await;
        assert!(matches!(second, Err(TransportError::DuplicateConnection)));
    }

    #[tokio::test]
    async fn test_connection_pool_max_connections() {
        let mut cfg = test_config();
        cfg.max_connections = 2;
        let pool = ConnectionPool::new(cfg);

        for port in 8000..8002 {
            pool.add_connection(random_peer_id(), loopback(port), true)
                .await
                .unwrap();
        }

        let overflow = pool
            .add_connection(random_peer_id(), loopback(8002), true)
            .await;
        assert!(matches!(overflow, Err(TransportError::TooManyConnections)));
    }

    #[tokio::test]
    async fn test_connection_pool_stats() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer_id();

        pool.add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();

        for _ in 0..2 {
            pool.record_sent(&peer).await;
        }
        pool.record_received(&peer).await;
        pool.update_latency(&peer, Duration::from_millis(50)).await;

        let snapshot = pool.get_connection(&peer).await.unwrap();
        assert_eq!(snapshot.messages_sent, 2);
        assert_eq!(snapshot.messages_received, 1);
        assert_eq!(snapshot.latency, Some(Duration::from_millis(50)));
    }

    #[test]
    fn test_transport_hello_creation() {
        let transport = Transport::new(test_config());

        let hello = transport.create_hello(Some(100), None);

        assert_eq!(hello.chain_id, "moloch-test");
        assert_eq!(hello.height, Some(100));
        assert!(hello.version.is_compatible_with(&ProtocolVersion::CURRENT));
    }

    #[test]
    fn test_transport_hello_validation() {
        let transport = Transport::new(test_config());
        let hello = transport.create_hello(Some(100), None);
        assert!(transport.validate_hello(&hello).is_ok());
    }

    #[test]
    fn test_transport_hello_wrong_chain() {
        let transport = Transport::new(test_config());

        let mut hello = transport.create_hello(Some(100), None);
        hello.chain_id = "wrong-chain".into();

        let outcome = transport.validate_hello(&hello);
        assert!(matches!(outcome, Err(TransportError::ChainMismatch { .. })));
    }

    #[test]
    fn test_transport_hello_ack() {
        let transport = Transport::new(test_config());

        let hello = transport.create_hello(Some(100), None);
        let ack = transport.create_hello_ack(hello.id, Some(50), None);

        assert_eq!(ack.request_id, hello.id);
        assert!(transport.validate_hello_ack(&ack).is_ok());
    }

    #[tokio::test]
    async fn test_transport_message_roundtrip() {
        let transport = Transport::new(test_config());

        let original = Message::Status(StatusMessage {
            height: Some(100),
            head_hash: None,
            peer_count: 5,
            syncing: false,
            timestamp: Utc::now(),
        });

        let bytes = transport.encode_message(&original).unwrap();
        let decoded = transport.decode_message(&bytes).unwrap();

        if let Message::Status(status) = decoded {
            assert_eq!(status.height, Some(100));
            assert_eq!(status.peer_count, 5);
        } else {
            panic!("wrong message type");
        }
    }

    #[tokio::test]
    async fn test_transport_async_message_io() {
        use tokio::io::duplex;

        let transport = Transport::new(test_config());
        let (mut client, mut server) = duplex(1024);

        let ping = Message::Ping(PingMessage {
            id: 42,
            timestamp: Utc::now(),
        });

        transport.write_message(&mut client, &ping).await.unwrap();
        let received = transport.read_message(&mut server).await.unwrap();

        if let Message::Ping(p) = received {
            assert_eq!(p.id, 42);
        } else {
            panic!("wrong message type");
        }
    }

    #[tokio::test]
    async fn test_get_idle_connections() {
        let mut cfg = test_config();
        cfg.idle_timeout = Duration::from_millis(100);
        let pool = ConnectionPool::new(cfg);
        let peer = random_peer_id();

        pool.add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();

        let threshold = Duration::from_millis(100);

        // Freshly added: not idle yet.
        assert!(pool.get_idle_connections(threshold).await.is_empty());

        // After waiting past the threshold it should be reported.
        tokio::time::sleep(Duration::from_millis(150)).await;
        let idle = pool.get_idle_connections(threshold).await;
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0], peer);
    }
}