1use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::sync::RwLock;
15use tracing::error;
16
17use crate::protocol::{
18 generate_message_id, DisconnectReason, HelloAckMessage, HelloMessage, Message, MessageCodec,
19 PeerId, ProtocolVersion,
20};
21#[cfg(test)]
22use crate::protocol::{PingMessage, StatusMessage};
23use moloch_core::crypto::{PublicKey, SecretKey};
24
25#[derive(Debug, Clone)]
27pub struct NetworkConfig {
28 pub listen_addr: SocketAddr,
30 pub chain_id: String,
32 pub node_key: SecretKey,
34 pub tls: TlsConfig,
36 pub max_connections: usize,
38 pub connection_timeout: Duration,
40 pub handshake_timeout: Duration,
42 pub keepalive_interval: Duration,
44 pub idle_timeout: Duration,
46 pub reconnect_delay: Duration,
48 pub max_reconnect_attempts: usize,
50 pub max_message_size: usize,
52}
53
54impl NetworkConfig {
55 pub fn builder() -> NetworkConfigBuilder {
57 NetworkConfigBuilder::default()
58 }
59
60 pub fn node_pubkey(&self) -> PublicKey {
62 self.node_key.public_key()
63 }
64
65 pub fn peer_id(&self) -> PeerId {
67 PeerId::new(self.node_pubkey())
68 }
69}
70
71#[derive(Debug, Default)]
73pub struct NetworkConfigBuilder {
74 listen_addr: Option<SocketAddr>,
75 chain_id: Option<String>,
76 node_key: Option<SecretKey>,
77 tls: Option<TlsConfig>,
78 max_connections: Option<usize>,
79 connection_timeout: Option<Duration>,
80 handshake_timeout: Option<Duration>,
81 keepalive_interval: Option<Duration>,
82 idle_timeout: Option<Duration>,
83 reconnect_delay: Option<Duration>,
84 max_reconnect_attempts: Option<usize>,
85 max_message_size: Option<usize>,
86}
87
88impl NetworkConfigBuilder {
89 pub fn listen_addr(mut self, addr: SocketAddr) -> Self {
91 self.listen_addr = Some(addr);
92 self
93 }
94
95 pub fn listen_addr_str(mut self, addr: &str) -> Result<Self, std::net::AddrParseError> {
97 self.listen_addr = Some(addr.parse()?);
98 Ok(self)
99 }
100
101 pub fn chain_id(mut self, chain_id: impl Into<String>) -> Self {
103 self.chain_id = Some(chain_id.into());
104 self
105 }
106
107 pub fn node_key(mut self, key: SecretKey) -> Self {
109 self.node_key = Some(key);
110 self
111 }
112
113 pub fn tls(mut self, tls: TlsConfig) -> Self {
115 self.tls = Some(tls);
116 self
117 }
118
119 pub fn max_connections(mut self, max: usize) -> Self {
121 self.max_connections = Some(max);
122 self
123 }
124
125 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
127 self.connection_timeout = Some(timeout);
128 self
129 }
130
131 pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
133 self.handshake_timeout = Some(timeout);
134 self
135 }
136
137 pub fn keepalive_interval(mut self, interval: Duration) -> Self {
139 self.keepalive_interval = Some(interval);
140 self
141 }
142
143 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
145 self.idle_timeout = Some(timeout);
146 self
147 }
148
149 pub fn build(self) -> Result<NetworkConfig, TransportError> {
151 let node_key = self.node_key.unwrap_or_else(SecretKey::generate);
152
153 Ok(NetworkConfig {
154 listen_addr: self
155 .listen_addr
156 .ok_or_else(|| TransportError::Config("listen_addr is required".into()))?,
157 chain_id: self
158 .chain_id
159 .ok_or_else(|| TransportError::Config("chain_id is required".into()))?,
160 node_key,
161 tls: self.tls.unwrap_or_default(),
162 max_connections: self.max_connections.unwrap_or(100),
163 connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(10)),
164 handshake_timeout: self.handshake_timeout.unwrap_or(Duration::from_secs(5)),
165 keepalive_interval: self.keepalive_interval.unwrap_or(Duration::from_secs(30)),
166 idle_timeout: self.idle_timeout.unwrap_or(Duration::from_secs(120)),
167 reconnect_delay: self.reconnect_delay.unwrap_or(Duration::from_secs(1)),
168 max_reconnect_attempts: self.max_reconnect_attempts.unwrap_or(5),
169 max_message_size: self
170 .max_message_size
171 .unwrap_or(MessageCodec::DEFAULT_MAX_SIZE),
172 })
173 }
174}
175
/// TLS settings for the transport.
///
/// `Debug` is implemented by hand rather than derived so that the private
/// key bytes are never written to logs or error output; this matters because
/// the configuration types embedding this struct derive `Debug`.
#[derive(Clone)]
pub struct TlsConfig {
    /// Whether TLS is enabled (default: `true`).
    pub enabled: bool,
    /// Certificate bytes, if any (DER when produced by `generate_self_signed`).
    pub cert: Option<Vec<u8>>,
    /// Private key bytes, if any (DER when produced by
    /// `generate_self_signed`). Sensitive: redacted in `Debug` output.
    pub key: Option<Vec<u8>>,
    /// NOTE(review): presumably disables peer certificate verification; the
    /// flag is not consumed in this file — confirm against the TLS setup code.
    pub skip_verify: bool,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a key is present, never its contents.
        let key_marker = self.key.as_ref().map(|_| "<redacted>");
        f.debug_struct("TlsConfig")
            .field("enabled", &self.enabled)
            .field("cert", &self.cert)
            .field("key", &key_marker)
            .field("skip_verify", &self.skip_verify)
            .finish()
    }
}

impl Default for TlsConfig {
    /// TLS enabled, no certificate or key, verification not skipped.
    fn default() -> Self {
        Self {
            enabled: true,
            cert: None,
            key: None,
            skip_verify: false,
        }
    }
}
199
200impl TlsConfig {
201 pub fn generate_self_signed(common_name: &str) -> Result<Self, TransportError> {
203 use rcgen::{generate_simple_self_signed, CertifiedKey};
204
205 let subject_alt_names = vec![common_name.to_string()];
206 let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names)
207 .map_err(|e| {
208 TransportError::Tls(format!("failed to generate self-signed cert: {}", e))
209 })?;
210
211 Ok(Self {
212 enabled: true,
213 cert: Some(cert.der().to_vec()),
214 key: Some(key_pair.serialize_der()),
215 skip_verify: false,
216 })
217 }
218}
219
220#[derive(Debug, thiserror::Error)]
222pub enum TransportError {
223 #[error("I/O error: {0}")]
224 Io(#[from] std::io::Error),
225
226 #[error("TLS error: {0}")]
227 Tls(String),
228
229 #[error("configuration error: {0}")]
230 Config(String),
231
232 #[error("connection failed: {0}")]
233 ConnectionFailed(String),
234
235 #[error("handshake failed: {0}")]
236 HandshakeFailed(String),
237
238 #[error("protocol mismatch: {0}")]
239 ProtocolMismatch(String),
240
241 #[error("chain ID mismatch: expected {expected}, got {got}")]
242 ChainMismatch { expected: String, got: String },
243
244 #[error("connection closed: {0:?}")]
245 ConnectionClosed(DisconnectReason),
246
247 #[error("timeout: {0}")]
248 Timeout(String),
249
250 #[error("too many connections")]
251 TooManyConnections,
252
253 #[error("duplicate connection")]
254 DuplicateConnection,
255
256 #[error("message codec error: {0}")]
257 Codec(#[from] crate::protocol::CodecError),
258
259 #[error("serialization error: {0}")]
260 Serialization(#[from] Box<bincode::ErrorKind>),
261
262 #[error("peer not found: {0}")]
263 PeerNotFound(String),
264}
265
266#[derive(Debug)]
268pub struct Connection {
269 pub id: ConnectionId,
271 pub peer_id: PeerId,
273 pub remote_addr: SocketAddr,
275 pub state: ConnectionState,
277 pub connected_at: DateTime<Utc>,
279 pub last_activity: DateTime<Utc>,
281 pub messages_sent: u64,
283 pub messages_received: u64,
285 pub latency: Option<Duration>,
287 pub outbound: bool,
289}
290
291pub type ConnectionId = u64;
293
/// Lifecycle state of a connection.
///
/// The `Display` form is the lowercase snake_case label noted on each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Displayed as `connecting`.
    Connecting,
    /// Displayed as `tls_handshake`.
    TlsHandshake,
    /// Displayed as `handshaking`.
    Handshaking,
    /// Displayed as `active`; the only state the pool reports as connected.
    Active,
    /// Displayed as `closing`.
    Closing,
    /// Displayed as `closed`.
    Closed,
}

impl std::fmt::Display for ConnectionState {
    /// Writes the state's label verbatim; width and padding flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Connecting => "connecting",
            Self::TlsHandshake => "tls_handshake",
            Self::Handshaking => "handshaking",
            Self::Active => "active",
            Self::Closing => "closing",
            Self::Closed => "closed",
        };
        f.write_str(label)
    }
}
323
324#[derive(Debug)]
326pub struct ConnectionPool {
327 config: Arc<NetworkConfig>,
329 connections: RwLock<HashMap<PeerId, Connection>>,
331 by_address: RwLock<HashMap<SocketAddr, PeerId>>,
333 next_id: std::sync::atomic::AtomicU64,
335}
336
337impl ConnectionPool {
338 pub fn new(config: NetworkConfig) -> Self {
340 Self {
341 config: Arc::new(config),
342 connections: RwLock::new(HashMap::new()),
343 by_address: RwLock::new(HashMap::new()),
344 next_id: std::sync::atomic::AtomicU64::new(1),
345 }
346 }
347
348 pub fn config(&self) -> &NetworkConfig {
350 &self.config
351 }
352
353 fn next_connection_id(&self) -> ConnectionId {
355 self.next_id
356 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
357 }
358
359 pub async fn connection_count(&self) -> usize {
361 self.connections.read().await.len()
362 }
363
364 pub async fn is_connected(&self, peer_id: &PeerId) -> bool {
366 let conns = self.connections.read().await;
367 conns
368 .get(peer_id)
369 .map(|c| c.state == ConnectionState::Active)
370 .unwrap_or(false)
371 }
372
373 pub async fn get_connection(&self, peer_id: &PeerId) -> Option<Connection> {
375 let conns = self.connections.read().await;
376 conns.get(peer_id).map(|c| Connection {
378 id: c.id,
379 peer_id: c.peer_id.clone(),
380 remote_addr: c.remote_addr,
381 state: c.state,
382 connected_at: c.connected_at,
383 last_activity: c.last_activity,
384 messages_sent: c.messages_sent,
385 messages_received: c.messages_received,
386 latency: c.latency,
387 outbound: c.outbound,
388 })
389 }
390
391 pub async fn peer_ids(&self) -> Vec<PeerId> {
393 self.connections.read().await.keys().cloned().collect()
394 }
395
396 pub async fn add_connection(
398 &self,
399 peer_id: PeerId,
400 remote_addr: SocketAddr,
401 outbound: bool,
402 ) -> Result<ConnectionId, TransportError> {
403 let mut conns = self.connections.write().await;
404 let mut by_addr = self.by_address.write().await;
405
406 if conns.len() >= self.config.max_connections {
408 return Err(TransportError::TooManyConnections);
409 }
410
411 if conns.contains_key(&peer_id) {
413 return Err(TransportError::DuplicateConnection);
414 }
415
416 let conn_id = self.next_connection_id();
417 let now = Utc::now();
418
419 let conn = Connection {
420 id: conn_id,
421 peer_id: peer_id.clone(),
422 remote_addr,
423 state: ConnectionState::Active,
424 connected_at: now,
425 last_activity: now,
426 messages_sent: 0,
427 messages_received: 0,
428 latency: None,
429 outbound,
430 };
431
432 conns.insert(peer_id.clone(), conn);
433 by_addr.insert(remote_addr, peer_id);
434
435 Ok(conn_id)
436 }
437
438 pub async fn remove_connection(&self, peer_id: &PeerId) -> Option<Connection> {
440 let mut conns = self.connections.write().await;
441 let mut by_addr = self.by_address.write().await;
442
443 if let Some(conn) = conns.remove(peer_id) {
444 by_addr.remove(&conn.remote_addr);
445 Some(conn)
446 } else {
447 None
448 }
449 }
450
451 pub async fn update_state(&self, peer_id: &PeerId, state: ConnectionState) {
453 let mut conns = self.connections.write().await;
454 if let Some(conn) = conns.get_mut(peer_id) {
455 conn.state = state;
456 }
457 }
458
459 pub async fn record_sent(&self, peer_id: &PeerId) {
461 let mut conns = self.connections.write().await;
462 if let Some(conn) = conns.get_mut(peer_id) {
463 conn.messages_sent += 1;
464 conn.last_activity = Utc::now();
465 }
466 }
467
468 pub async fn record_received(&self, peer_id: &PeerId) {
470 let mut conns = self.connections.write().await;
471 if let Some(conn) = conns.get_mut(peer_id) {
472 conn.messages_received += 1;
473 conn.last_activity = Utc::now();
474 }
475 }
476
477 pub async fn update_latency(&self, peer_id: &PeerId, latency: Duration) {
479 let mut conns = self.connections.write().await;
480 if let Some(conn) = conns.get_mut(peer_id) {
481 conn.latency = Some(latency);
482 }
483 }
484
485 pub async fn get_idle_connections(&self, max_idle: Duration) -> Vec<PeerId> {
487 let conns = self.connections.read().await;
488 let now = Utc::now();
489
490 conns
491 .iter()
492 .filter(|(_, conn)| {
493 let idle_time = now.signed_duration_since(conn.last_activity);
494 idle_time.num_milliseconds() > max_idle.as_millis() as i64
495 })
496 .map(|(peer_id, _)| peer_id.clone())
497 .collect()
498 }
499}
500
501#[derive(Debug)]
503pub struct Transport {
504 config: Arc<NetworkConfig>,
506 pool: Arc<ConnectionPool>,
508 codec: MessageCodec,
510}
511
512impl Transport {
513 pub fn new(config: NetworkConfig) -> Self {
515 let codec = MessageCodec::with_max_size(config.max_message_size);
516 let config = Arc::new(config);
517 let pool = Arc::new(ConnectionPool::new((*config).clone()));
518
519 Self {
520 config,
521 pool,
522 codec,
523 }
524 }
525
526 pub fn pool(&self) -> &Arc<ConnectionPool> {
528 &self.pool
529 }
530
531 pub fn config(&self) -> &NetworkConfig {
533 &self.config
534 }
535
536 pub fn create_hello(
538 &self,
539 height: Option<u64>,
540 head_hash: Option<moloch_core::block::BlockHash>,
541 ) -> HelloMessage {
542 let timestamp = Utc::now();
543 let message_bytes = format!(
544 "{}:{}:{}",
545 self.config.chain_id,
546 height.unwrap_or(0),
547 timestamp.timestamp_millis()
548 );
549 let signature = self.config.node_key.sign(message_bytes.as_bytes());
550
551 HelloMessage {
552 id: generate_message_id(),
553 version: ProtocolVersion::CURRENT,
554 chain_id: self.config.chain_id.clone(),
555 node_key: self.config.node_pubkey(),
556 height,
557 head_hash,
558 timestamp,
559 signature,
560 }
561 }
562
563 pub fn create_hello_ack(
565 &self,
566 request_id: u64,
567 height: Option<u64>,
568 head_hash: Option<moloch_core::block::BlockHash>,
569 ) -> HelloAckMessage {
570 let timestamp = Utc::now();
571 let message_bytes = format!(
572 "ack:{}:{}:{}",
573 self.config.chain_id,
574 height.unwrap_or(0),
575 timestamp.timestamp_millis()
576 );
577 let signature = self.config.node_key.sign(message_bytes.as_bytes());
578
579 HelloAckMessage {
580 request_id,
581 version: ProtocolVersion::CURRENT,
582 chain_id: self.config.chain_id.clone(),
583 node_key: self.config.node_pubkey(),
584 height,
585 head_hash,
586 timestamp,
587 signature,
588 }
589 }
590
591 pub fn validate_hello(&self, hello: &HelloMessage) -> Result<(), TransportError> {
593 if !hello.version.is_compatible_with(&ProtocolVersion::CURRENT) {
595 return Err(TransportError::ProtocolMismatch(format!(
596 "incompatible protocol version: {}",
597 hello.version
598 )));
599 }
600
601 if hello.chain_id != self.config.chain_id {
603 return Err(TransportError::ChainMismatch {
604 expected: self.config.chain_id.clone(),
605 got: hello.chain_id.clone(),
606 });
607 }
608
609 let message_bytes = format!(
611 "{}:{}:{}",
612 hello.chain_id,
613 hello.height.unwrap_or(0),
614 hello.timestamp.timestamp_millis()
615 );
616
617 hello
618 .node_key
619 .verify(message_bytes.as_bytes(), &hello.signature)
620 .map_err(|_| TransportError::HandshakeFailed("invalid signature".into()))?;
621
622 Ok(())
623 }
624
625 pub fn validate_hello_ack(&self, ack: &HelloAckMessage) -> Result<(), TransportError> {
627 if !ack.version.is_compatible_with(&ProtocolVersion::CURRENT) {
629 return Err(TransportError::ProtocolMismatch(format!(
630 "incompatible protocol version: {}",
631 ack.version
632 )));
633 }
634
635 if ack.chain_id != self.config.chain_id {
637 return Err(TransportError::ChainMismatch {
638 expected: self.config.chain_id.clone(),
639 got: ack.chain_id.clone(),
640 });
641 }
642
643 let message_bytes = format!(
645 "ack:{}:{}:{}",
646 ack.chain_id,
647 ack.height.unwrap_or(0),
648 ack.timestamp.timestamp_millis()
649 );
650
651 ack.node_key
652 .verify(message_bytes.as_bytes(), &ack.signature)
653 .map_err(|_| TransportError::HandshakeFailed("invalid signature".into()))?;
654
655 Ok(())
656 }
657
658 pub fn encode_message(&self, message: &Message) -> Result<Vec<u8>, TransportError> {
660 Ok(self.codec.encode(message)?)
661 }
662
663 pub fn decode_message(&self, data: &[u8]) -> Result<Message, TransportError> {
665 Ok(self.codec.decode(data)?)
666 }
667
668 pub async fn read_message<R: AsyncRead + Unpin>(
670 &self,
671 reader: &mut R,
672 ) -> Result<Message, TransportError> {
673 let mut len_buf = [0u8; 4];
675 reader.read_exact(&mut len_buf).await?;
676 let length = u32::from_be_bytes(len_buf) as usize;
677
678 if length > self.config.max_message_size {
679 return Err(TransportError::Codec(
680 crate::protocol::CodecError::MessageTooLarge {
681 size: length,
682 max: self.config.max_message_size,
683 },
684 ));
685 }
686
687 let mut payload = vec![0u8; length];
689 reader.read_exact(&mut payload).await?;
690
691 let message = bincode::deserialize(&payload)?;
692 Ok(message)
693 }
694
695 pub async fn write_message<W: AsyncWrite + Unpin>(
697 &self,
698 writer: &mut W,
699 message: &Message,
700 ) -> Result<(), TransportError> {
701 let frame = self.encode_message(message)?;
702 writer.write_all(&frame).await?;
703 writer.flush().await?;
704 Ok(())
705 }
706}
707
708#[derive(Debug, Clone, Default, Serialize, Deserialize)]
710pub struct TransportStats {
711 pub connections_established: u64,
713 pub connections_closed: u64,
715 pub messages_sent: u64,
717 pub messages_received: u64,
719 pub bytes_sent: u64,
721 pub bytes_received: u64,
723 pub active_connections: usize,
725 pub connection_failures: u64,
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Plaintext (TLS disabled) configuration on an ephemeral loopback port.
    fn test_config() -> NetworkConfig {
        let plaintext = TlsConfig {
            enabled: false,
            ..Default::default()
        };

        NetworkConfig::builder()
            .listen_addr_str("127.0.0.1:0")
            .unwrap()
            .chain_id("moloch-test")
            .node_key(SecretKey::generate())
            .tls(plaintext)
            .build()
            .unwrap()
    }

    /// Peer identity derived from a freshly generated key.
    fn random_peer() -> PeerId {
        PeerId::new(SecretKey::generate().public_key())
    }

    /// Loopback socket address on the given port.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    #[test]
    fn test_network_config_builder() {
        let cfg = test_config();
        assert_eq!(cfg.chain_id, "moloch-test");
        assert_eq!(cfg.max_connections, 100);
    }

    #[test]
    fn test_network_config_builder_missing_fields() {
        // Neither listen_addr nor chain_id is provided.
        assert!(NetworkConfig::builder().build().is_err());
    }

    #[test]
    fn test_tls_config_self_signed() {
        let tls = TlsConfig::generate_self_signed("localhost").unwrap();
        assert!(tls.enabled);
        assert!(tls.cert.is_some());
        assert!(tls.key.is_some());
    }

    #[test]
    fn test_connection_state_display() {
        assert_eq!(ConnectionState::Active.to_string(), "active");
        assert_eq!(ConnectionState::Connecting.to_string(), "connecting");
    }

    #[tokio::test]
    async fn test_connection_pool_add_remove() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer();

        // Registering the peer yields a positive ID and an active connection.
        let id = pool
            .add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();
        assert!(id > 0);
        assert_eq!(pool.connection_count().await, 1);
        assert!(pool.is_connected(&peer).await);

        // Removing it empties the pool again.
        assert!(pool.remove_connection(&peer).await.is_some());
        assert_eq!(pool.connection_count().await, 0);
        assert!(!pool.is_connected(&peer).await);
    }

    #[tokio::test]
    async fn test_connection_pool_duplicate() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer();
        let addr = loopback(8000);

        pool.add_connection(peer.clone(), addr, true)
            .await
            .unwrap();

        // A second registration for the same peer must be refused.
        let second = pool.add_connection(peer, addr, true).await;
        assert!(matches!(second, Err(TransportError::DuplicateConnection)));
    }

    #[tokio::test]
    async fn test_connection_pool_max_connections() {
        let mut cfg = test_config();
        cfg.max_connections = 2;
        let pool = ConnectionPool::new(cfg);

        // Fill the pool to capacity.
        for port in [8000u16, 8001] {
            pool.add_connection(random_peer(), loopback(port), true)
                .await
                .unwrap();
        }

        // One more peer exceeds the limit.
        let overflow = pool
            .add_connection(random_peer(), loopback(8002), true)
            .await;
        assert!(matches!(overflow, Err(TransportError::TooManyConnections)));
    }

    #[tokio::test]
    async fn test_connection_pool_stats() {
        let pool = ConnectionPool::new(test_config());
        let peer = random_peer();

        pool.add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();

        for _ in 0..2 {
            pool.record_sent(&peer).await;
        }
        pool.record_received(&peer).await;
        pool.update_latency(&peer, Duration::from_millis(50))
            .await;

        let snapshot = pool.get_connection(&peer).await.unwrap();
        assert_eq!(snapshot.messages_sent, 2);
        assert_eq!(snapshot.messages_received, 1);
        assert_eq!(snapshot.latency, Some(Duration::from_millis(50)));
    }

    #[test]
    fn test_transport_hello_creation() {
        let transport = Transport::new(test_config());

        let hello = transport.create_hello(Some(100), None);

        assert_eq!(hello.chain_id, "moloch-test");
        assert_eq!(hello.height, Some(100));
        assert!(hello.version.is_compatible_with(&ProtocolVersion::CURRENT));
    }

    #[test]
    fn test_transport_hello_validation() {
        let transport = Transport::new(test_config());

        let hello = transport.create_hello(Some(100), None);
        assert!(transport.validate_hello(&hello).is_ok());
    }

    #[test]
    fn test_transport_hello_wrong_chain() {
        let transport = Transport::new(test_config());

        let mut hello = transport.create_hello(Some(100), None);
        hello.chain_id = "wrong-chain".into();

        let outcome = transport.validate_hello(&hello);
        assert!(matches!(outcome, Err(TransportError::ChainMismatch { .. })));
    }

    #[test]
    fn test_transport_hello_ack() {
        let transport = Transport::new(test_config());

        let hello = transport.create_hello(Some(100), None);
        let ack = transport.create_hello_ack(hello.id, Some(50), None);

        assert_eq!(ack.request_id, hello.id);
        assert!(transport.validate_hello_ack(&ack).is_ok());
    }

    #[tokio::test]
    async fn test_transport_message_roundtrip() {
        let transport = Transport::new(test_config());

        let original = Message::Status(StatusMessage {
            height: Some(100),
            head_hash: None,
            peer_count: 5,
            syncing: false,
            timestamp: Utc::now(),
        });

        let frame = transport.encode_message(&original).unwrap();
        let decoded = transport.decode_message(&frame).unwrap();

        if let Message::Status(status) = decoded {
            assert_eq!(status.height, Some(100));
            assert_eq!(status.peer_count, 5);
        } else {
            panic!("wrong message type");
        }
    }

    #[tokio::test]
    async fn test_transport_async_message_io() {
        use tokio::io::duplex;

        let transport = Transport::new(test_config());
        let (mut client, mut server) = duplex(1024);

        let ping = Message::Ping(PingMessage {
            id: 42,
            timestamp: Utc::now(),
        });

        // Write on one end of the in-memory pipe, read on the other.
        transport
            .write_message(&mut client, &ping)
            .await
            .unwrap();
        let received = transport.read_message(&mut server).await.unwrap();

        if let Message::Ping(p) = received {
            assert_eq!(p.id, 42);
        } else {
            panic!("wrong message type");
        }
    }

    #[tokio::test]
    async fn test_get_idle_connections() {
        let threshold = Duration::from_millis(100);

        let mut cfg = test_config();
        cfg.idle_timeout = threshold;
        let pool = ConnectionPool::new(cfg);

        let peer = random_peer();
        pool.add_connection(peer.clone(), loopback(8000), true)
            .await
            .unwrap();

        // Freshly added: not idle yet.
        assert!(pool.get_idle_connections(threshold).await.is_empty());

        // After sleeping past the threshold the peer is reported as idle.
        tokio::time::sleep(Duration::from_millis(150)).await;
        let idle = pool.get_idle_connections(threshold).await;
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0], peer);
    }
}