1pub mod stats;
26pub mod stream;
27
28use super::resolve_socket_addrs;
29use super::{
30 ConnectionState, DiscoveredPeer, PacketTx, ReceivedPacket, Transport, TransportAddr,
31 TransportError, TransportId, TransportState, TransportType,
32};
33use crate::config::TcpConfig;
34use stats::TcpStats;
35use stream::read_fmp_packet;
36
37use futures::FutureExt;
38use socket2::TcpKeepalive;
39use std::collections::HashMap;
40use std::net::SocketAddr;
41use std::sync::Arc;
42use std::time::Duration;
43use tokio::io::AsyncWriteExt;
44use tokio::net::tcp::OwnedWriteHalf;
45use tokio::net::{TcpListener, TcpStream};
46use tokio::sync::Mutex;
47use tokio::task::JoinHandle;
48use tokio::time::Instant;
49use tracing::{debug, info, trace, warn};
50
/// An established TCP connection registered in the connection pool.
///
/// Created by connect-on-send (`connect`), by promotion of a background
/// connect (`promote_connection`), or by the accept loop.
struct TcpConnection {
    /// Write half of the stream. Shared behind an `Arc<Mutex<..>>` so senders can
    /// clone it out of the pool and write without holding the pool lock.
    writer: Arc<Mutex<OwnedWriteHalf>>,
    /// Task running `tcp_receive_loop` over the read half; aborted when the
    /// connection is closed or the transport stops.
    recv_task: JoinHandle<()>,
    /// Per-connection MTU from `read_mss_mtu` (configured MTU, capped by the TCP
    /// MSS on Linux). Recorded for future use; not read anywhere yet.
    #[allow(dead_code)]
    mtu: u16,
    /// When the connection was registered. Recorded for future use; not read yet.
    #[allow(dead_code)]
    established_at: Instant,
}
68
/// Established connections keyed by remote transport address. Shared between the
/// transport, the accept loop and every receive loop (which removes its entry on exit).
type ConnectionPool = Arc<Mutex<HashMap<TransportAddr, TcpConnection>>>;
71
/// A background connect started by `connect_async` that has not been promoted yet.
struct ConnectingEntry {
    /// Dial task. On success it yields the configured stream and its MSS-derived MTU;
    /// `connection_state_sync` collects the result and promotes it into the pool.
    task: JoinHandle<Result<(TcpStream, u16), TransportError>>,
}
80
/// Pending background connects keyed by remote transport address.
type ConnectingPool = Arc<Mutex<HashMap<TransportAddr, ConnectingEntry>>>;
83
/// TCP transport: carries framed packets (read with `read_fmp_packet`) over TCP
/// connections, optionally accepting inbound connections on a bound listener.
///
/// Outbound connections are opened on demand by `send_async` (connect-on-send) or
/// in the background by `connect_async`.
pub struct TcpTransport {
    /// Identifier attached to every received packet and log line.
    transport_id: TransportId,
    /// Optional human-readable name, included in the startup log.
    name: Option<String>,
    /// Bind address, MTU, timeouts and socket options.
    config: TcpConfig,
    /// Lifecycle state (`Configured` -> `Starting` -> `Up` -> `Down`).
    state: TransportState,
    /// Established connections keyed by remote address.
    pool: ConnectionPool,
    /// Background connects started by `connect_async` that are not yet promoted.
    connecting: ConnectingPool,
    /// Channel on which receive loops deliver inbound packets.
    packet_tx: PacketTx,
    /// Accept-loop task; `Some` only while a listener is running.
    accept_task: Option<JoinHandle<()>>,
    /// Actual listener address (resolves a `:0` port); `None` when outbound-only or stopped.
    local_addr: Option<SocketAddr>,
    /// Counters updated by the send, connect, accept and receive paths.
    stats: Arc<TcpStats>,
}
115
116impl TcpTransport {
117 pub fn new(
119 transport_id: TransportId,
120 name: Option<String>,
121 config: TcpConfig,
122 packet_tx: PacketTx,
123 ) -> Self {
124 Self {
125 transport_id,
126 name,
127 config,
128 state: TransportState::Configured,
129 pool: Arc::new(Mutex::new(HashMap::new())),
130 connecting: Arc::new(Mutex::new(HashMap::new())),
131 packet_tx,
132 accept_task: None,
133 local_addr: None,
134 stats: Arc::new(TcpStats::new()),
135 }
136 }
137
138 pub fn name(&self) -> Option<&str> {
140 self.name.as_deref()
141 }
142
143 pub fn local_addr(&self) -> Option<SocketAddr> {
145 self.local_addr
146 }
147
148 pub fn stats(&self) -> &Arc<TcpStats> {
150 &self.stats
151 }
152
153 pub async fn start_async(&mut self) -> Result<(), TransportError> {
158 if !self.state.can_start() {
159 return Err(TransportError::AlreadyStarted);
160 }
161
162 self.state = TransportState::Starting;
163
164 if let Some(ref bind_addr) = self.config.bind_addr {
166 let addr: SocketAddr = bind_addr
167 .parse()
168 .map_err(|e| TransportError::StartFailed(format!("invalid bind address: {}", e)))?;
169
170 let listener = TcpListener::bind(addr)
171 .await
172 .map_err(|e| TransportError::StartFailed(format!("bind failed: {}", e)))?;
173
174 self.local_addr = Some(
175 listener
176 .local_addr()
177 .map_err(|e| TransportError::StartFailed(format!("get local addr: {}", e)))?,
178 );
179
180 let transport_id = self.transport_id;
182 let packet_tx = self.packet_tx.clone();
183 let pool = self.pool.clone();
184 let stats = self.stats.clone();
185 let cfg = AcceptConfig {
186 mtu: self.config.mtu(),
187 max_inbound: self.config.max_inbound_connections(),
188 nodelay: self.config.nodelay(),
189 keepalive_secs: self.config.keepalive_secs(),
190 recv_buf: self.config.recv_buf_size(),
191 send_buf: self.config.send_buf_size(),
192 };
193
194 let accept_task = tokio::spawn(async move {
195 accept_loop(listener, transport_id, packet_tx, pool, cfg, stats).await;
196 });
197 self.accept_task = Some(accept_task);
198 }
199
200 self.state = TransportState::Up;
201
202 if let Some(ref name) = self.name {
203 info!(
204 name = %name,
205 local_addr = ?self.local_addr,
206 mtu = self.config.mtu(),
207 "TCP transport started"
208 );
209 } else {
210 info!(
211 local_addr = ?self.local_addr,
212 mtu = self.config.mtu(),
213 "TCP transport started"
214 );
215 }
216
217 Ok(())
218 }
219
220 pub async fn stop_async(&mut self) -> Result<(), TransportError> {
222 if !self.state.is_operational() {
223 return Err(TransportError::NotStarted);
224 }
225
226 if let Some(task) = self.accept_task.take() {
228 task.abort();
229 let _ = task.await;
230 }
231
232 let mut connecting = self.connecting.lock().await;
234 for (addr, entry) in connecting.drain() {
235 entry.task.abort();
236 debug!(
237 transport_id = %self.transport_id,
238 remote_addr = %addr,
239 "TCP connect aborted (transport stopping)"
240 );
241 }
242 drop(connecting);
243
244 let mut pool = self.pool.lock().await;
246 for (addr, conn) in pool.drain() {
247 conn.recv_task.abort();
248 let _ = conn.recv_task.await;
249 debug!(
250 transport_id = %self.transport_id,
251 remote_addr = %addr,
252 "TCP connection closed (transport stopping)"
253 );
254 }
255 drop(pool);
256
257 self.local_addr = None;
258 self.state = TransportState::Down;
259
260 info!(
261 transport_id = %self.transport_id,
262 "TCP transport stopped"
263 );
264
265 Ok(())
266 }
267
268 pub async fn send_async(
274 &self,
275 addr: &TransportAddr,
276 data: &[u8],
277 ) -> Result<usize, TransportError> {
278 if !self.state.is_operational() {
279 return Err(TransportError::NotStarted);
280 }
281
282 let mtu = self.config.mtu() as usize;
287 if data.len() > mtu {
288 self.stats.record_mtu_exceeded();
289 return Err(TransportError::MtuExceeded {
290 packet_size: data.len(),
291 mtu: self.config.mtu(),
292 });
293 }
294
295 let writer = {
297 let pool = self.pool.lock().await;
298 pool.get(addr).map(|c| c.writer.clone())
299 };
300
301 let writer = match writer {
302 Some(w) => w,
303 None => {
304 self.connect(addr).await?
306 }
307 };
308
309 let mut w = writer.lock().await;
311 match w.write_all(data).await {
312 Ok(()) => {
313 self.stats.record_send(data.len());
314 trace!(
315 transport_id = %self.transport_id,
316 remote_addr = %addr,
317 bytes = data.len(),
318 "TCP packet sent"
319 );
320 Ok(data.len())
321 }
322 Err(e) => {
323 self.stats.record_send_error();
324 drop(w);
325 let mut pool = self.pool.lock().await;
327 if let Some(conn) = pool.remove(addr) {
328 conn.recv_task.abort();
329 }
330 Err(TransportError::SendFailed(format!("{}", e)))
331 }
332 }
333 }
334
335 async fn connect(
340 &self,
341 addr: &TransportAddr,
342 ) -> Result<Arc<Mutex<OwnedWriteHalf>>, TransportError> {
343 let socket_addrs = resolve_socket_addrs(addr).await?;
344 let timeout_ms = self.config.connect_timeout_ms();
345
346 let stream = match connect_to_any_addr(&socket_addrs, timeout_ms).await {
347 Ok(stream) => stream,
348 Err(error @ TransportError::Timeout) => {
349 self.stats.record_connect_timeout();
350 return Err(error);
351 }
352 Err(error @ TransportError::ConnectionRefused) => {
353 self.stats.record_connect_refused();
354 return Err(error);
355 }
356 Err(error) => return Err(error),
357 };
358
359 let std_stream = stream
361 .into_std()
362 .map_err(|e| TransportError::StartFailed(format!("into_std: {}", e)))?;
363 configure_socket(&std_stream, &self.config)?;
364
365 let mss_mtu = read_mss_mtu(&std_stream, self.config.mtu());
367
368 let stream = TcpStream::from_std(std_stream)
370 .map_err(|e| TransportError::StartFailed(format!("from_std: {}", e)))?;
371
372 let (read_half, write_half) = stream.into_split();
374 let writer = Arc::new(Mutex::new(write_half));
375
376 let transport_id = self.transport_id;
377 let packet_tx = self.packet_tx.clone();
378 let pool = self.pool.clone();
379 let recv_stats = self.stats.clone();
380 let remote_addr = addr.clone();
381 let mtu = mss_mtu;
382
383 let recv_task = tokio::spawn(async move {
384 tcp_receive_loop(
385 read_half,
386 transport_id,
387 remote_addr.clone(),
388 packet_tx,
389 pool,
390 mtu,
391 recv_stats,
392 )
393 .await;
394 });
395
396 let conn = TcpConnection {
397 writer: writer.clone(),
398 recv_task,
399 mtu: mss_mtu,
400 established_at: Instant::now(),
401 };
402
403 let mut pool = self.pool.lock().await;
404 pool.insert(addr.clone(), conn);
405
406 self.stats.record_connection_established();
407
408 debug!(
409 transport_id = %self.transport_id,
410 remote_addr = %addr,
411 mtu = mss_mtu,
412 "TCP connection established (connect-on-send)"
413 );
414
415 Ok(writer)
416 }
417
418 pub async fn close_connection_async(&self, addr: &TransportAddr) {
423 let mut pool = self.pool.lock().await;
424 if let Some(conn) = pool.remove(addr) {
425 conn.recv_task.abort();
426 debug!(
427 transport_id = %self.transport_id,
428 remote_addr = %addr,
429 "TCP connection closed (close_connection)"
430 );
431 }
432 }
433
434 pub async fn connect_async(&self, addr: &TransportAddr) -> Result<(), TransportError> {
442 if !self.state.is_operational() {
443 return Err(TransportError::NotStarted);
444 }
445
446 {
448 let pool = self.pool.lock().await;
449 if pool.contains_key(addr) {
450 return Ok(());
451 }
452 }
453
454 {
456 let connecting = self.connecting.lock().await;
457 if connecting.contains_key(addr) {
458 return Ok(());
459 }
460 }
461
462 let socket_addrs = resolve_socket_addrs(addr).await?;
463 let timeout_ms = self.config.connect_timeout_ms();
464 let config = self.config.clone();
465 let transport_id = self.transport_id;
466 let remote_addr = addr.clone();
467
468 debug!(
469 transport_id = %transport_id,
470 remote_addr = %remote_addr,
471 timeout_ms,
472 "Initiating background TCP connect"
473 );
474
475 let task = tokio::spawn(async move {
476 let stream = match connect_to_any_addr(&socket_addrs, timeout_ms).await {
477 Ok(stream) => stream,
478 Err(error @ TransportError::ConnectionRefused) => {
479 debug!(
480 transport_id = %transport_id,
481 remote_addr = %remote_addr,
482 error = %error,
483 "Background TCP connect refused"
484 );
485 return Err(error);
486 }
487 Err(error @ TransportError::Timeout) => {
488 debug!(
489 transport_id = %transport_id,
490 remote_addr = %remote_addr,
491 "Background TCP connect timed out"
492 );
493 return Err(error);
494 }
495 Err(error) => return Err(error),
496 };
497
498 let std_stream = stream
500 .into_std()
501 .map_err(|e| TransportError::StartFailed(format!("into_std: {}", e)))?;
502 configure_socket(&std_stream, &config)?;
503
504 let mss_mtu = read_mss_mtu(&std_stream, config.mtu());
506
507 let stream = TcpStream::from_std(std_stream)
509 .map_err(|e| TransportError::StartFailed(format!("from_std: {}", e)))?;
510
511 Ok((stream, mss_mtu))
512 });
513
514 let mut connecting = self.connecting.lock().await;
515 connecting.insert(addr.clone(), ConnectingEntry { task });
516
517 Ok(())
518 }
519
520 pub fn connection_state_sync(&self, addr: &TransportAddr) -> ConnectionState {
529 if let Ok(pool) = self.pool.try_lock() {
531 if pool.contains_key(addr) {
532 return ConnectionState::Connected;
533 }
534 } else {
535 return ConnectionState::Connecting; }
537
538 let mut connecting = match self.connecting.try_lock() {
540 Ok(c) => c,
541 Err(_) => return ConnectionState::Connecting,
542 };
543
544 let entry = match connecting.get_mut(addr) {
545 Some(e) => e,
546 None => return ConnectionState::None,
547 };
548
549 if !entry.task.is_finished() {
551 return ConnectionState::Connecting;
552 }
553
554 let addr_clone = addr.clone();
558 let task = connecting.remove(&addr_clone).unwrap().task;
559
560 match task.now_or_never() {
563 Some(Ok(Ok((stream, mss_mtu)))) => {
564 self.promote_connection(addr, stream, mss_mtu);
566 ConnectionState::Connected
567 }
568 Some(Ok(Err(e))) => ConnectionState::Failed(format!("{}", e)),
569 Some(Err(e)) => {
570 ConnectionState::Failed(format!("task failed: {}", e))
572 }
573 None => {
574 ConnectionState::Connecting
576 }
577 }
578 }
579
580 fn promote_connection(&self, addr: &TransportAddr, stream: TcpStream, mss_mtu: u16) {
585 let (read_half, write_half) = stream.into_split();
586 let writer = Arc::new(Mutex::new(write_half));
587
588 let transport_id = self.transport_id;
589 let packet_tx = self.packet_tx.clone();
590 let pool = self.pool.clone();
591 let recv_stats = self.stats.clone();
592 let remote_addr = addr.clone();
593
594 let recv_task = tokio::spawn(async move {
595 tcp_receive_loop(
596 read_half,
597 transport_id,
598 remote_addr.clone(),
599 packet_tx,
600 pool,
601 mss_mtu,
602 recv_stats,
603 )
604 .await;
605 });
606
607 let conn = TcpConnection {
608 writer,
609 recv_task,
610 mtu: mss_mtu,
611 established_at: Instant::now(),
612 };
613
614 if let Ok(mut pool) = self.pool.try_lock() {
617 pool.insert(addr.clone(), conn);
618 self.stats.record_connection_established();
619 debug!(
620 transport_id = %self.transport_id,
621 remote_addr = %addr,
622 mtu = mss_mtu,
623 "TCP connection established (background connect)"
624 );
625 } else {
626 conn.recv_task.abort();
628 warn!(
629 transport_id = %self.transport_id,
630 remote_addr = %addr,
631 "Failed to promote connection (pool locked)"
632 );
633 }
634 }
635}
636
637impl Transport for TcpTransport {
638 fn transport_id(&self) -> TransportId {
639 self.transport_id
640 }
641
642 fn transport_type(&self) -> &TransportType {
643 &TransportType::TCP
644 }
645
646 fn state(&self) -> TransportState {
647 self.state
648 }
649
650 fn mtu(&self) -> u16 {
651 self.config.mtu()
652 }
653
654 fn link_mtu(&self, _addr: &TransportAddr) -> u16 {
655 self.config.mtu()
659 }
660
661 fn start(&mut self) -> Result<(), TransportError> {
662 Err(TransportError::NotSupported(
663 "use start_async() for TCP transport".into(),
664 ))
665 }
666
667 fn stop(&mut self) -> Result<(), TransportError> {
668 Err(TransportError::NotSupported(
669 "use stop_async() for TCP transport".into(),
670 ))
671 }
672
673 fn send(&self, _addr: &TransportAddr, _data: &[u8]) -> Result<(), TransportError> {
674 Err(TransportError::NotSupported(
675 "use send_async() for TCP transport".into(),
676 ))
677 }
678
679 fn discover(&self) -> Result<Vec<DiscoveredPeer>, TransportError> {
680 Ok(Vec::new())
682 }
683
684 fn accept_connections(&self) -> bool {
685 self.config.bind_addr.is_some()
687 }
688}
689
/// Snapshot of the `TcpConfig` values the accept loop needs, copied out in
/// `start_async` so the spawned task does not borrow the transport.
struct AcceptConfig {
    /// Configured MTU; capped per connection by `read_mss_mtu`.
    mtu: u16,
    /// Maximum pool size at which new inbound connections are still accepted.
    max_inbound: usize,
    /// Value for TCP_NODELAY on accepted sockets.
    nodelay: bool,
    /// Keepalive idle time in seconds; 0 leaves keepalive unconfigured.
    keepalive_secs: u64,
    /// Receive buffer size applied to accepted sockets.
    recv_buf: usize,
    /// Send buffer size applied to accepted sockets.
    send_buf: usize,
}
703
704#[allow(clippy::too_many_arguments)]
706async fn accept_loop(
707 listener: TcpListener,
708 transport_id: TransportId,
709 packet_tx: PacketTx,
710 pool: ConnectionPool,
711 cfg: AcceptConfig,
712 stats: Arc<TcpStats>,
713) {
714 let AcceptConfig {
715 mtu,
716 max_inbound,
717 nodelay,
718 keepalive_secs,
719 recv_buf,
720 send_buf,
721 } = cfg;
722 debug!(transport_id = %transport_id, "TCP accept loop starting");
723
724 loop {
725 match listener.accept().await {
726 Ok((stream, peer_addr)) => {
727 {
729 let pool_guard = pool.lock().await;
730 if pool_guard.len() >= max_inbound {
731 stats.record_connection_rejected();
732 warn!(
733 transport_id = %transport_id,
734 peer_addr = %peer_addr,
735 max = max_inbound,
736 "Rejecting inbound TCP connection (max_inbound_connections reached)"
737 );
738 continue;
739 }
740 }
741
742 let std_stream = match stream.into_std() {
744 Ok(s) => s,
745 Err(e) => {
746 warn!(
747 transport_id = %transport_id,
748 error = %e,
749 "Failed to convert accepted stream to std"
750 );
751 continue;
752 }
753 };
754
755 if let Err(e) = configure_accepted_socket(
756 &std_stream,
757 nodelay,
758 keepalive_secs,
759 recv_buf,
760 send_buf,
761 ) {
762 warn!(
763 transport_id = %transport_id,
764 peer_addr = %peer_addr,
765 error = %e,
766 "Failed to configure accepted socket"
767 );
768 continue;
769 }
770
771 let conn_mtu = read_mss_mtu(&std_stream, mtu);
773
774 let stream = match TcpStream::from_std(std_stream) {
775 Ok(s) => s,
776 Err(e) => {
777 warn!(
778 transport_id = %transport_id,
779 error = %e,
780 "Failed to convert accepted stream back to tokio"
781 );
782 continue;
783 }
784 };
785
786 let remote_addr = TransportAddr::from_string(&peer_addr.to_string());
787
788 let (read_half, write_half) = stream.into_split();
790 let writer = Arc::new(Mutex::new(write_half));
791
792 let recv_pool = pool.clone();
793 let recv_packet_tx = packet_tx.clone();
794 let recv_stats = stats.clone();
795 let recv_addr = remote_addr.clone();
796
797 let recv_task = tokio::spawn(async move {
798 tcp_receive_loop(
799 read_half,
800 transport_id,
801 recv_addr,
802 recv_packet_tx,
803 recv_pool,
804 conn_mtu,
805 recv_stats,
806 )
807 .await;
808 });
809
810 let conn = TcpConnection {
811 writer,
812 recv_task,
813 mtu: conn_mtu,
814 established_at: Instant::now(),
815 };
816
817 let mut pool_guard = pool.lock().await;
818 pool_guard.insert(remote_addr.clone(), conn);
819
820 stats.record_connection_accepted();
821
822 debug!(
823 transport_id = %transport_id,
824 remote_addr = %remote_addr,
825 mtu = conn_mtu,
826 "Accepted inbound TCP connection"
827 );
828 }
829 Err(e) => {
830 warn!(
831 transport_id = %transport_id,
832 error = %e,
833 "TCP accept error"
834 );
835 }
836 }
837 }
838}
839
840async fn tcp_receive_loop(
850 mut reader: tokio::net::tcp::OwnedReadHalf,
851 transport_id: TransportId,
852 remote_addr: TransportAddr,
853 packet_tx: PacketTx,
854 pool: ConnectionPool,
855 mtu: u16,
856 stats: Arc<TcpStats>,
857) {
858 debug!(
859 transport_id = %transport_id,
860 remote_addr = %remote_addr,
861 "TCP receive loop starting"
862 );
863
864 loop {
865 match read_fmp_packet(&mut reader, mtu).await {
866 Ok(data) => {
867 stats.record_recv(data.len());
868
869 trace!(
870 transport_id = %transport_id,
871 remote_addr = %remote_addr,
872 bytes = data.len(),
873 "TCP packet received"
874 );
875
876 let packet = ReceivedPacket::new(transport_id, remote_addr.clone(), data);
877
878 if packet_tx.send(packet).is_err() {
879 debug!(
880 transport_id = %transport_id,
881 "Packet channel closed, stopping TCP receive loop"
882 );
883 break;
884 }
885 }
886 Err(e) => {
887 stats.record_recv_error();
888 debug!(
890 transport_id = %transport_id,
891 remote_addr = %remote_addr,
892 error = %e,
893 "TCP receive error, removing connection"
894 );
895 break;
896 }
897 }
898 }
899
900 let mut pool_guard = pool.lock().await;
902 pool_guard.remove(&remote_addr);
903
904 debug!(
905 transport_id = %transport_id,
906 remote_addr = %remote_addr,
907 "TCP receive loop stopped"
908 );
909}
910
911async fn connect_to_any_addr(
916 socket_addrs: &[SocketAddr],
917 timeout_ms: u64,
918) -> Result<TcpStream, TransportError> {
919 let mut last_error = None;
920 for socket_addr in socket_addrs {
921 match tokio::time::timeout(
922 Duration::from_millis(timeout_ms),
923 TcpStream::connect(socket_addr),
924 )
925 .await
926 {
927 Ok(Ok(stream)) => return Ok(stream),
928 Ok(Err(error)) => {
929 trace!(
930 remote_addr = %socket_addr,
931 error = %error,
932 "TCP connect candidate failed"
933 );
934 last_error = Some(TransportError::ConnectionRefused);
935 }
936 Err(_) => {
937 trace!(
938 remote_addr = %socket_addr,
939 timeout_ms,
940 "TCP connect candidate timed out"
941 );
942 last_error = Some(TransportError::Timeout);
943 }
944 }
945 }
946 Err(last_error
947 .unwrap_or_else(|| TransportError::InvalidAddress("no TCP addresses to dial".to_string())))
948}
949
950fn configure_socket(
952 stream: &std::net::TcpStream,
953 config: &TcpConfig,
954) -> Result<(), TransportError> {
955 let socket = socket2::SockRef::from(stream)
956 .try_clone()
957 .map_err(|e| TransportError::StartFailed(format!("clone socket: {}", e)))?;
958
959 socket
961 .set_tcp_nodelay(config.nodelay())
962 .map_err(|e| TransportError::StartFailed(format!("set nodelay: {}", e)))?;
963
964 let keepalive_secs = config.keepalive_secs();
966 if keepalive_secs > 0 {
967 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(keepalive_secs));
968 socket
969 .set_tcp_keepalive(&keepalive)
970 .map_err(|e| TransportError::StartFailed(format!("set keepalive: {}", e)))?;
971 }
972
973 socket
975 .set_recv_buffer_size(config.recv_buf_size())
976 .map_err(|e| TransportError::StartFailed(format!("set recv buffer: {}", e)))?;
977 socket
978 .set_send_buffer_size(config.send_buf_size())
979 .map_err(|e| TransportError::StartFailed(format!("set send buffer: {}", e)))?;
980
981 Ok(())
982}
983
984fn configure_accepted_socket(
986 stream: &std::net::TcpStream,
987 nodelay: bool,
988 keepalive_secs: u64,
989 recv_buf: usize,
990 send_buf: usize,
991) -> Result<(), TransportError> {
992 let socket = socket2::SockRef::from(stream)
993 .try_clone()
994 .map_err(|e| TransportError::StartFailed(format!("clone socket: {}", e)))?;
995
996 socket
997 .set_tcp_nodelay(nodelay)
998 .map_err(|e| TransportError::StartFailed(format!("set nodelay: {}", e)))?;
999
1000 if keepalive_secs > 0 {
1001 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(keepalive_secs));
1002 socket
1003 .set_tcp_keepalive(&keepalive)
1004 .map_err(|e| TransportError::StartFailed(format!("set keepalive: {}", e)))?;
1005 }
1006
1007 socket
1008 .set_recv_buffer_size(recv_buf)
1009 .map_err(|e| TransportError::StartFailed(format!("set recv buffer: {}", e)))?;
1010 socket
1011 .set_send_buffer_size(send_buf)
1012 .map_err(|e| TransportError::StartFailed(format!("set send buffer: {}", e)))?;
1013
1014 Ok(())
1015}
1016
1017fn read_mss_mtu(stream: &std::net::TcpStream, default_mtu: u16) -> u16 {
1019 #[cfg(target_os = "linux")]
1021 {
1022 use std::os::unix::io::AsRawFd;
1023 unsafe {
1024 let mut mss: libc::c_int = 0;
1025 let mut len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
1026 let fd = stream.as_raw_fd();
1027 let ret = libc::getsockopt(
1028 fd,
1029 libc::IPPROTO_TCP,
1030 libc::TCP_MAXSEG,
1031 &mut mss as *mut libc::c_int as *mut libc::c_void,
1032 &mut len,
1033 );
1034 if ret == 0 && mss > 0 {
1035 let mss_mtu = (mss as u32).min(u16::MAX as u32) as u16;
1036 return mss_mtu.min(default_mtu);
1038 }
1039 }
1040 }
1041
1042 #[cfg(not(target_os = "linux"))]
1043 let _ = stream;
1044
1045 default_mtu
1047}
1048
1049#[cfg(test)]
1054mod tests {
1055 use super::*;
1056 use crate::transport::packet_channel;
1057 use tokio::time::{Duration, timeout};
1058
1059 fn make_config() -> TcpConfig {
1060 TcpConfig {
1061 bind_addr: Some("127.0.0.1:0".to_string()),
1062 mtu: Some(1400),
1063 ..Default::default()
1064 }
1065 }
1066
1067 fn make_outbound_config() -> TcpConfig {
1068 TcpConfig {
1069 bind_addr: None,
1070 mtu: Some(1400),
1071 ..Default::default()
1072 }
1073 }
1074
1075 async fn unused_loopback_addr(except_port: u16) -> SocketAddr {
1076 for port in 49152..65535 {
1077 if port == except_port {
1078 continue;
1079 }
1080 let addr = SocketAddr::from(([127, 0, 0, 1], port));
1081 if TcpStream::connect(addr).await.is_err() {
1082 return addr;
1083 }
1084 }
1085 panic!("no unused loopback port found");
1086 }
1087
1088 #[tokio::test]
1089 async fn test_start_stop() {
1090 let (tx, _rx) = packet_channel(100);
1091 let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1092
1093 assert_eq!(transport.state(), TransportState::Configured);
1094
1095 transport.start_async().await.unwrap();
1096 assert_eq!(transport.state(), TransportState::Up);
1097 assert!(transport.local_addr().is_some());
1098
1099 transport.stop_async().await.unwrap();
1100 assert_eq!(transport.state(), TransportState::Down);
1101 }
1102
1103 #[tokio::test]
1104 async fn connect_to_any_addr_tries_later_candidates() {
1105 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1106 let good_addr = listener.local_addr().unwrap();
1107 let bad_addr = unused_loopback_addr(good_addr.port()).await;
1108 let accept = tokio::spawn(async move { listener.accept().await });
1109
1110 let stream = connect_to_any_addr(&[bad_addr, good_addr], 1_000)
1111 .await
1112 .expect("second TCP candidate should connect");
1113 drop(stream);
1114
1115 timeout(Duration::from_secs(1), accept)
1116 .await
1117 .expect("listener should accept")
1118 .expect("accept task should not panic")
1119 .expect("accept should succeed");
1120 }
1121
1122 #[tokio::test]
1123 async fn test_start_outbound_only() {
1124 let (tx, _rx) = packet_channel(100);
1125 let mut transport =
1126 TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx);
1127
1128 transport.start_async().await.unwrap();
1129 assert_eq!(transport.state(), TransportState::Up);
1130 assert!(transport.local_addr().is_none());
1132
1133 transport.stop_async().await.unwrap();
1134 }
1135
1136 #[tokio::test]
1137 async fn test_double_start_fails() {
1138 let (tx, _rx) = packet_channel(100);
1139 let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1140
1141 transport.start_async().await.unwrap();
1142
1143 let result = transport.start_async().await;
1144 assert!(matches!(result, Err(TransportError::AlreadyStarted)));
1145
1146 transport.stop_async().await.unwrap();
1147 }
1148
1149 #[tokio::test]
1150 async fn test_stop_not_started_fails() {
1151 let (tx, _rx) = packet_channel(100);
1152 let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1153
1154 let result = transport.stop_async().await;
1155 assert!(matches!(result, Err(TransportError::NotStarted)));
1156 }
1157
1158 #[tokio::test]
1159 async fn test_send_not_started() {
1160 let (tx, _rx) = packet_channel(100);
1161 let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1162
1163 let result = transport
1164 .send_async(&TransportAddr::from_string("127.0.0.1:9999"), b"test")
1165 .await;
1166
1167 assert!(matches!(result, Err(TransportError::NotStarted)));
1168 }
1169
1170 #[tokio::test]
1171 async fn test_send_recv() {
1172 let (tx1, _rx1) = packet_channel(100);
1173 let (tx2, mut rx2) = packet_channel(100);
1174
1175 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1176 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1177
1178 t1.start_async().await.unwrap();
1179 t2.start_async().await.unwrap();
1180
1181 let addr2 = t2.local_addr().unwrap();
1182
1183 let payload_len = 4u16;
1186 let total = 4 + 12 + payload_len as usize + 16;
1187 let mut frame = vec![0u8; total];
1188 frame[0] = 0x00; frame[1] = 0x00; frame[2..4].copy_from_slice(&payload_len.to_le_bytes());
1191 for (i, byte) in frame[4..total].iter_mut().enumerate() {
1193 *byte = ((4 + i) & 0xFF) as u8;
1194 }
1195
1196 let bytes_sent = t1
1197 .send_async(&TransportAddr::from_string(&addr2.to_string()), &frame)
1198 .await
1199 .unwrap();
1200 assert_eq!(bytes_sent, frame.len());
1201
1202 let packet = timeout(Duration::from_secs(2), rx2.recv())
1204 .await
1205 .expect("timeout")
1206 .expect("channel closed");
1207
1208 assert_eq!(packet.data, frame);
1209
1210 t1.stop_async().await.unwrap();
1211 t2.stop_async().await.unwrap();
1212 }
1213
1214 #[tokio::test]
1215 async fn test_bidirectional() {
1216 let (tx1, mut rx1) = packet_channel(100);
1217 let (tx2, mut rx2) = packet_channel(100);
1218
1219 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
1220 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1221
1222 t1.start_async().await.unwrap();
1223 t2.start_async().await.unwrap();
1224
1225 let addr1 = t1.local_addr().unwrap();
1226 let addr2 = t2.local_addr().unwrap();
1227
1228 let mut msg1_frame = vec![0xAA; 114];
1230 msg1_frame[0] = 0x01; msg1_frame[1] = 0x00;
1232 msg1_frame[2..4].copy_from_slice(&110u16.to_le_bytes()); t1.send_async(&TransportAddr::from_string(&addr2.to_string()), &msg1_frame)
1236 .await
1237 .unwrap();
1238
1239 let packet = timeout(Duration::from_secs(2), rx2.recv())
1240 .await
1241 .expect("timeout")
1242 .expect("channel closed");
1243 assert_eq!(packet.data, msg1_frame);
1244
1245 let mut msg2_frame = vec![0xBB; 69];
1247 msg2_frame[0] = 0x02; msg2_frame[1] = 0x00;
1249 msg2_frame[2..4].copy_from_slice(&65u16.to_le_bytes()); t2.send_async(&TransportAddr::from_string(&addr1.to_string()), &msg2_frame)
1253 .await
1254 .unwrap();
1255
1256 let packet = timeout(Duration::from_secs(2), rx1.recv())
1257 .await
1258 .expect("timeout")
1259 .expect("channel closed");
1260 assert_eq!(packet.data, msg2_frame);
1261
1262 t1.stop_async().await.unwrap();
1263 t2.stop_async().await.unwrap();
1264 }
1265
1266 #[tokio::test]
1267 async fn test_connect_timeout() {
1268 let (tx, _rx) = packet_channel(100);
1269 let config = TcpConfig {
1270 bind_addr: None,
1271 connect_timeout_ms: Some(100), ..Default::default()
1273 };
1274 let mut transport = TcpTransport::new(TransportId::new(1), None, config, tx);
1275 transport.start_async().await.unwrap();
1276
1277 let result = transport
1279 .send_async(
1280 &TransportAddr::from_string("192.0.2.1:2121"),
1281 b"\x00\x00\x04\x00test1234567890123456789012345678",
1282 )
1283 .await;
1284
1285 assert!(result.is_err());
1286
1287 transport.stop_async().await.unwrap();
1288 }
1289
1290 #[tokio::test]
1291 async fn test_close_connection() {
1292 let (tx1, _rx1) = packet_channel(100);
1293 let (tx2, _rx2) = packet_channel(100);
1294
1295 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1296 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1297
1298 t1.start_async().await.unwrap();
1299 t2.start_async().await.unwrap();
1300
1301 let addr2 = t2.local_addr().unwrap();
1302 let remote = TransportAddr::from_string(&addr2.to_string());
1303
1304 let mut msg1 = vec![0xAA; 114];
1306 msg1[0] = 0x01;
1307 msg1[1] = 0x00;
1308 msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
1309
1310 t1.send_async(&remote, &msg1).await.unwrap();
1311
1312 {
1314 let pool = t1.pool.lock().await;
1315 assert!(pool.contains_key(&remote));
1316 }
1317
1318 t1.close_connection_async(&remote).await;
1320
1321 {
1323 let pool = t1.pool.lock().await;
1324 assert!(!pool.contains_key(&remote));
1325 }
1326
1327 t1.stop_async().await.unwrap();
1328 t2.stop_async().await.unwrap();
1329 }
1330
1331 #[tokio::test]
1332 async fn test_discover_returns_empty() {
1333 let (tx, _rx) = packet_channel(100);
1334 let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1335
1336 let peers = transport.discover().unwrap();
1337 assert!(peers.is_empty());
1338 }
1339
1340 #[test]
1341 fn test_transport_type() {
1342 let (tx, _rx) = packet_channel(100);
1343 let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1344
1345 assert_eq!(transport.transport_type().name, "tcp");
1346 assert!(transport.transport_type().connection_oriented);
1347 assert!(transport.transport_type().reliable);
1348 }
1349
1350 #[test]
1351 fn test_sync_methods_return_not_supported() {
1352 let (tx, _rx) = packet_channel(100);
1353 let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1354
1355 assert!(matches!(
1356 transport.start(),
1357 Err(TransportError::NotSupported(_))
1358 ));
1359 assert!(matches!(
1360 transport.stop(),
1361 Err(TransportError::NotSupported(_))
1362 ));
1363 assert!(matches!(
1364 transport.send(&TransportAddr::from_string("test"), b"data"),
1365 Err(TransportError::NotSupported(_))
1366 ));
1367 }
1368
1369 #[test]
1370 fn test_accept_connections_with_bind() {
1371 let (tx, _rx) = packet_channel(100);
1372 let config = TcpConfig {
1373 bind_addr: Some("0.0.0.0:0".to_string()),
1374 ..Default::default()
1375 };
1376 let transport = TcpTransport::new(TransportId::new(1), None, config, tx);
1377 assert!(transport.accept_connections());
1378 }
1379
1380 #[test]
1381 fn test_accept_connections_without_bind() {
1382 let (tx, _rx) = packet_channel(100);
1383 let config = TcpConfig {
1384 bind_addr: None,
1385 ..Default::default()
1386 };
1387 let transport = TcpTransport::new(TransportId::new(1), None, config, tx);
1388 assert!(!transport.accept_connections());
1389 }
1390
1391 #[tokio::test]
1392 async fn test_connection_drop_and_reconnect() {
1393 let (tx1, _rx1) = packet_channel(100);
1394 let (tx2, mut rx2) = packet_channel(100);
1395
1396 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1397 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1398
1399 t1.start_async().await.unwrap();
1400 t2.start_async().await.unwrap();
1401
1402 let addr2 = t2.local_addr().unwrap();
1403 let remote = TransportAddr::from_string(&addr2.to_string());
1404
1405 let mut msg1 = vec![0xAA; 114];
1407 msg1[0] = 0x01;
1408 msg1[1] = 0x00;
1409 msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
1410
1411 t1.send_async(&remote, &msg1).await.unwrap();
1413 let _ = timeout(Duration::from_secs(1), rx2.recv()).await;
1414
1415 t1.close_connection_async(&remote).await;
1417
1418 t1.send_async(&remote, &msg1).await.unwrap();
1420
1421 let packet = timeout(Duration::from_secs(2), rx2.recv())
1422 .await
1423 .expect("timeout")
1424 .expect("channel closed");
1425 assert_eq!(packet.data, msg1);
1426
1427 t1.stop_async().await.unwrap();
1428 t2.stop_async().await.unwrap();
1429 }
1430
1431 #[tokio::test]
1432 async fn test_connect_async_success() {
1433 let (tx1, mut rx1) = packet_channel(100);
1434 let (tx2, _rx2) = packet_channel(100);
1435
1436 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1437 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1438
1439 t1.start_async().await.unwrap();
1440 t2.start_async().await.unwrap();
1441
1442 let addr2 = t2.local_addr().unwrap();
1443 let remote = TransportAddr::from_string(&addr2.to_string());
1444
1445 assert_eq!(t1.connection_state_sync(&remote), ConnectionState::None);
1447
1448 t1.connect_async(&remote).await.unwrap();
1450
1451 tokio::time::sleep(Duration::from_millis(200)).await;
1453
1454 let state = t1.connection_state_sync(&remote);
1456 assert_eq!(state, ConnectionState::Connected);
1457
1458 let mut msg1 = vec![0xAA; 114];
1460 msg1[0] = 0x01;
1461 msg1[1] = 0x00;
1462 msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
1463
1464 t1.send_async(&remote, &msg1).await.unwrap();
1465
1466 let packet = timeout(Duration::from_secs(2), rx1.recv()).await;
1467 drop(packet);
1470
1471 t1.stop_async().await.unwrap();
1472 t2.stop_async().await.unwrap();
1473 }
1474
1475 #[tokio::test]
1476 async fn test_connect_async_timeout() {
1477 let (tx, _rx) = packet_channel(100);
1478 let config = TcpConfig {
1479 bind_addr: None,
1480 connect_timeout_ms: Some(100), ..Default::default()
1482 };
1483 let mut transport = TcpTransport::new(TransportId::new(1), None, config, tx);
1484 transport.start_async().await.unwrap();
1485
1486 let remote = TransportAddr::from_string("192.0.2.1:2121");
1487 transport.connect_async(&remote).await.unwrap();
1488
1489 tokio::time::sleep(Duration::from_millis(500)).await;
1491
1492 let state = transport.connection_state_sync(&remote);
1493 assert!(matches!(state, ConnectionState::Failed(_)));
1494
1495 transport.stop_async().await.unwrap();
1496 }
1497
1498 #[tokio::test]
1499 async fn test_connect_async_not_started() {
1500 let (tx, _rx) = packet_channel(100);
1501 let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
1502
1503 let result = transport
1504 .connect_async(&TransportAddr::from_string("127.0.0.1:9999"))
1505 .await;
1506
1507 assert!(matches!(result, Err(TransportError::NotStarted)));
1508 }
1509
1510 #[tokio::test]
1511 async fn test_connect_async_already_connected() {
1512 let (tx1, _rx1) = packet_channel(100);
1513 let (tx2, _rx2) = packet_channel(100);
1514
1515 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1516 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1517
1518 t1.start_async().await.unwrap();
1519 t2.start_async().await.unwrap();
1520
1521 let addr2 = t2.local_addr().unwrap();
1522 let remote = TransportAddr::from_string(&addr2.to_string());
1523
1524 t1.connect_async(&remote).await.unwrap();
1526 tokio::time::sleep(Duration::from_millis(200)).await;
1527 assert_eq!(
1528 t1.connection_state_sync(&remote),
1529 ConnectionState::Connected
1530 );
1531
1532 t1.connect_async(&remote).await.unwrap();
1534
1535 t1.stop_async().await.unwrap();
1536 t2.stop_async().await.unwrap();
1537 }
1538
1539 #[tokio::test]
1540 async fn test_connect_async_then_send_recv() {
1541 let (tx1, _rx1) = packet_channel(100);
1542 let (tx2, mut rx2) = packet_channel(100);
1543
1544 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
1545 let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
1546
1547 t1.start_async().await.unwrap();
1548 t2.start_async().await.unwrap();
1549
1550 let addr2 = t2.local_addr().unwrap();
1551 let remote = TransportAddr::from_string(&addr2.to_string());
1552
1553 t1.connect_async(&remote).await.unwrap();
1555 tokio::time::sleep(Duration::from_millis(200)).await;
1556 assert_eq!(
1557 t1.connection_state_sync(&remote),
1558 ConnectionState::Connected
1559 );
1560
1561 let mut msg1 = vec![0xAA; 114];
1563 msg1[0] = 0x01;
1564 msg1[1] = 0x00;
1565 msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
1566
1567 t1.send_async(&remote, &msg1).await.unwrap();
1569
1570 let packet = timeout(Duration::from_secs(2), rx2.recv())
1571 .await
1572 .expect("timeout")
1573 .expect("channel closed");
1574 assert_eq!(packet.data, msg1);
1575
1576 t1.stop_async().await.unwrap();
1577 t2.stop_async().await.unwrap();
1578 }
1579
#[test]
fn test_connection_state_none_for_unknown() {
    // An address this transport has never dialled or accepted reports no
    // state. The transport does not even need to be started.
    let (packet_tx, _packet_rx) = packet_channel(100);
    let transport = TcpTransport::new(TransportId::new(1), None, make_config(), packet_tx);

    let never_seen = TransportAddr::from_string("unknown:1234");
    assert_eq!(
        transport.connection_state_sync(&never_seen),
        ConnectionState::None
    );
}
1588
1589 #[tokio::test]
1590 async fn test_connect_ip_string() {
1591 let (tx1, _rx1) = packet_channel(100);
1592 let (tx2, mut rx2) = packet_channel(100);
1593
1594 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
1595 let mut t2 = TcpTransport::new(
1596 TransportId::new(2),
1597 None,
1598 TcpConfig {
1599 bind_addr: Some("127.0.0.1:0".to_string()),
1600 ..Default::default()
1601 },
1602 tx2,
1603 );
1604
1605 t1.start_async().await.unwrap();
1606 t2.start_async().await.unwrap();
1607
1608 let port2 = t2.local_addr().unwrap().port();
1609
1610 let addr = TransportAddr::from_string(&format!("127.0.0.1:{}", port2));
1612 let mut frame = vec![0xAA; 114];
1613 frame[0] = 0x01; frame[1] = 0x00; frame[2..4].copy_from_slice(&110u16.to_le_bytes()); t1.send_async(&addr, &frame).await.unwrap();
1617
1618 let packet = tokio::time::timeout(Duration::from_secs(5), rx2.recv())
1620 .await
1621 .expect("timeout")
1622 .expect("channel closed");
1623
1624 assert_eq!(packet.data, frame);
1625
1626 t1.stop_async().await.unwrap();
1627 t2.stop_async().await.unwrap();
1628 }
1629
1630 #[tokio::test]
1631 async fn test_connect_async_ip_string() {
1632 let (tx1, _rx1) = packet_channel(100);
1633 let (tx2, _rx2) = packet_channel(100);
1634
1635 let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
1636 let mut t2 = TcpTransport::new(
1637 TransportId::new(2),
1638 None,
1639 TcpConfig {
1640 bind_addr: Some("127.0.0.1:0".to_string()),
1641 ..Default::default()
1642 },
1643 tx2,
1644 );
1645
1646 t1.start_async().await.unwrap();
1647 t2.start_async().await.unwrap();
1648
1649 let port2 = t2.local_addr().unwrap().port();
1650 let addr = TransportAddr::from_string(&format!("127.0.0.1:{}", port2));
1651
1652 t1.connect_async(&addr).await.unwrap();
1654
1655 for _ in 0..50 {
1657 let state = t1.connection_state_sync(&addr);
1658 if state == ConnectionState::Connected {
1659 break;
1660 }
1661 tokio::time::sleep(Duration::from_millis(100)).await;
1662 }
1663
1664 assert_eq!(t1.connection_state_sync(&addr), ConnectionState::Connected,);
1665
1666 t1.stop_async().await.unwrap();
1667 t2.stop_async().await.unwrap();
1668 }
1669}