1use super::{
6 DiscoveredPeer, PacketTx, ReceivedPacket, Transport, TransportAddr, TransportError,
7 TransportId, TransportState, TransportType,
8};
9#[cfg(any(target_os = "linux", target_os = "macos"))]
10pub(crate) mod connected_peer;
11#[cfg(target_os = "macos")]
12pub(crate) mod darwin_sockopts;
13#[cfg(any(target_os = "linux", target_os = "macos"))]
14pub(crate) mod peer_drain;
15pub(crate) mod socket;
16mod stats;
17use super::resolve_socket_addr;
18use crate::config::UdpConfig;
19use crate::discovery::is_punch_packet;
20use socket::{AsyncUdpSocket, UdpRawSocket};
21use stats::UdpStats;
22use std::collections::HashMap;
23use std::net::SocketAddr;
24use std::sync::{Arc, Mutex as StdMutex};
25use std::time::{Duration, Instant};
26use tokio::task::JoinHandle;
27use tracing::{debug, info, trace, warn};
28
/// Maximum age of an entry in `UdpTransport::dns_cache`.
///
/// `resolve_cached` reuses a cached resolution only while the time elapsed
/// since it was stored is below this value (60 seconds); older entries
/// trigger a fresh lookup.
const DNS_CACHE_TTL: Duration = Duration::from_secs(60);
31
/// Reports whether `local` and `remote` belong to the same IP address family.
///
/// Returns `true` for an IPv4/IPv4 or IPv6/IPv6 pair and `false` for any
/// mixed pair.
fn socket_addr_families_compatible(local: SocketAddr, remote: SocketAddr) -> bool {
    // `SocketAddr` has exactly two variants (V4 and V6), so the two addresses
    // share a family precisely when they agree on "is this IPv4?".
    local.is_ipv4() == remote.is_ipv4()
}
38
/// UDP implementation of the `Transport` trait.
///
/// A transport is created in the `Configured` state by `new`, brought up with
/// `start_async` (bind a new socket) or `adopt_socket_async` (reuse an
/// existing one), and torn down with `stop_async` or by being dropped.
/// While it is up, a spawned task runs `udp_receive_loop` and forwards every
/// received datagram to `packet_tx`.
pub struct UdpTransport {
    /// Identifier returned by `Transport::transport_id`; also attached to
    /// received packets and log events.
    transport_id: TransportId,
    /// Optional human-readable name, included in the start-up log event.
    name: Option<String>,
    /// UDP settings read by this type: bind address, MTU, buffer sizes and
    /// the `outbound_only` / `accept_connections` flags.
    config: UdpConfig,
    /// Lifecycle state: `Configured` after `new`, `Up` after a successful
    /// start, `Down` after `stop_async`.
    state: TransportState,
    /// The socket used for sending; `Some` only between start and stop.
    socket: Option<AsyncUdpSocket>,
    /// Sender the receive loop uses to hand packets to the rest of the system.
    packet_tx: PacketTx,
    /// Handle of the spawned receive-loop task; aborted on stop and on drop.
    recv_task: Option<JoinHandle<()>>,
    /// Address reported by the socket at start-up; cleared on stop and on drop.
    local_addr: Option<SocketAddr>,
    /// Counters shared with the receive loop.
    stats: Arc<UdpStats>,
    /// Cache of resolved addresses: maps a `TransportAddr` to the resolved
    /// `SocketAddr` and the instant it was stored (see `DNS_CACHE_TTL`).
    dns_cache: StdMutex<HashMap<TransportAddr, (SocketAddr, Instant)>>,
}
90
91impl UdpTransport {
92 pub fn new(
94 transport_id: TransportId,
95 name: Option<String>,
96 config: UdpConfig,
97 packet_tx: PacketTx,
98 ) -> Self {
99 Self {
100 transport_id,
101 name,
102 config,
103 state: TransportState::Configured,
104 socket: None,
105 packet_tx,
106 recv_task: None,
107 local_addr: None,
108 stats: Arc::new(UdpStats::new()),
109 dns_cache: StdMutex::new(HashMap::new()),
110 }
111 }
112
/// Returns the optional name given at construction, borrowed as `&str`.
pub fn name(&self) -> Option<&str> {
    self.name.as_deref()
}
117
/// Returns the socket's local address recorded at start-up.
///
/// `None` before the transport has been started and again after
/// `stop_async` (which clears the field).
pub fn local_addr(&self) -> Option<SocketAddr> {
    self.local_addr
}
122
/// Returns the receive buffer size taken from the configuration.
///
/// This is the value requested when the socket is opened or adopted; the
/// size the socket actually reports is queried separately at start-up and
/// only logged.
pub fn recv_buf_size(&self) -> usize {
    self.config.recv_buf_size()
}
129
/// Returns the send buffer size taken from the configuration.
///
/// This is the value requested when the socket is opened or adopted; the
/// size the socket actually reports is queried separately at start-up and
/// only logged.
pub fn send_buf_size(&self) -> usize {
    self.config.send_buf_size()
}
134
/// Returns a clone of the packet sender that the receive loop delivers to.
// NOTE(review): presumably used by code that injects packets for this
// transport from outside the receive loop — confirm against callers.
pub fn clone_packet_tx(&self) -> PacketTx {
    self.packet_tx.clone()
}
141
/// Borrows the shared statistics handle (the same `Arc` the receive loop
/// updates). Clone the `Arc` to keep it beyond the borrow.
pub fn stats(&self) -> &Arc<UdpStats> {
    &self.stats
}
146
/// Public wrapper around the private cached resolver.
///
/// Delegates directly to `resolve_cached`, so it shares the same DNS cache
/// and returns the same errors.
// NOTE(review): the name suggests callers resolve here and then send on a
// different task using `async_socket()` — confirm against callers.
pub async fn resolve_for_off_task(
    &self,
    addr: &TransportAddr,
) -> Result<SocketAddr, TransportError> {
    self.resolve_cached(addr).await
}
160
/// Returns a clone of the socket handle, or `None` while the transport is
/// not started.
// NOTE(review): presumably cloning `AsyncUdpSocket` yields another handle to
// the same underlying socket rather than a new socket — confirm in `socket`.
pub fn async_socket(&self) -> Option<AsyncUdpSocket> {
    self.socket.clone()
}
174
175 async fn resolve_cached(&self, addr: &TransportAddr) -> Result<SocketAddr, TransportError> {
181 if let Some(s) = addr.as_str()
183 && let Ok(sock_addr) = s.parse::<SocketAddr>()
184 {
185 return Ok(sock_addr);
186 }
187
188 {
190 let cache = self.dns_cache.lock().unwrap();
191 if let Some((resolved, cached_at)) = cache.get(addr)
192 && cached_at.elapsed() < DNS_CACHE_TTL
193 {
194 return Ok(*resolved);
195 }
196 }
197
198 let resolved = resolve_socket_addr(addr).await?;
200
201 {
203 let mut cache = self.dns_cache.lock().unwrap();
204 cache.insert(addr.clone(), (resolved, Instant::now()));
205 }
206
207 Ok(resolved)
208 }
209
210 pub fn congestion(&self) -> super::TransportCongestion {
212 super::TransportCongestion {
213 recv_drops: Some(
214 self.stats
215 .kernel_drops
216 .load(std::sync::atomic::Ordering::Relaxed),
217 ),
218 }
219 }
220
221 pub async fn start_async(&mut self) -> Result<(), TransportError> {
225 if !self.state.can_start() {
226 return Err(TransportError::AlreadyStarted);
227 }
228
229 self.state = TransportState::Starting;
230
231 if self.config.outbound_only() && self.config.bind_addr.is_some() {
232 warn!(
233 configured_bind_addr = ?self.config.bind_addr,
234 "udp.outbound_only = true; configured bind_addr is ignored, binding to 0.0.0.0:0"
235 );
236 }
237
238 let bind_addr: SocketAddr = self
240 .config
241 .bind_addr()
242 .parse()
243 .map_err(|e| TransportError::StartFailed(format!("invalid bind address: {}", e)))?;
244
245 let raw_socket = UdpRawSocket::open(
247 bind_addr,
248 self.config.recv_buf_size(),
249 self.config.send_buf_size(),
250 )?;
251
252 let actual_recv = raw_socket.recv_buffer_size()?;
253 let actual_send = raw_socket.send_buffer_size()?;
254 self.local_addr = Some(raw_socket.local_addr());
255
256 let async_socket = raw_socket.into_async()?;
258 self.socket = Some(async_socket.clone());
259
260 let transport_id = self.transport_id;
262 let packet_tx = self.packet_tx.clone();
263 let mtu = self.config.mtu();
264 let stats = self.stats.clone();
265
266 let recv_task = tokio::spawn(async move {
267 udp_receive_loop(async_socket, transport_id, packet_tx, mtu, stats).await;
268 });
269
270 self.recv_task = Some(recv_task);
271 self.state = TransportState::Up;
272
273 if let Some(ref name) = self.name {
274 info!(
275 name = %name,
276 local_addr = %self.local_addr.unwrap(),
277 recv_buf = actual_recv,
278 send_buf = actual_send,
279 "UDP transport started"
280 );
281 } else {
282 info!(
283 local_addr = %self.local_addr.unwrap(),
284 recv_buf = actual_recv,
285 send_buf = actual_send,
286 "UDP transport started"
287 );
288 }
289
290 Ok(())
291 }
292
293 pub async fn adopt_socket_async(
298 &mut self,
299 socket: std::net::UdpSocket,
300 ) -> Result<(), TransportError> {
301 if !self.state.can_start() {
302 return Err(TransportError::AlreadyStarted);
303 }
304
305 self.state = TransportState::Starting;
306
307 let raw_socket = UdpRawSocket::adopt(
308 socket,
309 self.config.recv_buf_size(),
310 self.config.send_buf_size(),
311 )?;
312
313 let actual_recv = raw_socket.recv_buffer_size()?;
314 let actual_send = raw_socket.send_buffer_size()?;
315 self.local_addr = Some(raw_socket.local_addr());
316
317 let async_socket = raw_socket.into_async()?;
318 self.socket = Some(async_socket.clone());
319
320 let transport_id = self.transport_id;
321 let packet_tx = self.packet_tx.clone();
322 let mtu = self.config.mtu();
323 let stats = self.stats.clone();
324
325 let recv_task = tokio::spawn(async move {
326 udp_receive_loop(async_socket, transport_id, packet_tx, mtu, stats).await;
327 });
328
329 self.recv_task = Some(recv_task);
330 self.state = TransportState::Up;
331
332 if let Some(ref name) = self.name {
333 info!(
334 name = %name,
335 local_addr = %self.local_addr.unwrap(),
336 recv_buf = actual_recv,
337 send_buf = actual_send,
338 "UDP transport adopted existing socket"
339 );
340 } else {
341 info!(
342 local_addr = %self.local_addr.unwrap(),
343 recv_buf = actual_recv,
344 send_buf = actual_send,
345 "UDP transport adopted existing socket"
346 );
347 }
348
349 Ok(())
350 }
351
352 pub async fn stop_async(&mut self) -> Result<(), TransportError> {
354 if !self.state.is_operational() {
355 return Err(TransportError::NotStarted);
356 }
357
358 if let Some(task) = self.recv_task.take() {
360 task.abort();
361 let _ = task.await; }
363
364 self.socket.take();
366 self.local_addr = None;
367
368 self.state = TransportState::Down;
369
370 info!(
371 transport_id = %self.transport_id,
372 "UDP transport stopped"
373 );
374
375 Ok(())
376 }
377
378 pub async fn send_async(
385 &self,
386 addr: &TransportAddr,
387 data: &[u8],
388 ) -> Result<usize, TransportError> {
389 if !self.state.is_operational() {
390 return Err(TransportError::NotStarted);
391 }
392
393 if data.len() > self.config.mtu() as usize {
394 self.stats.record_mtu_exceeded();
395 return Err(TransportError::MtuExceeded {
396 packet_size: data.len(),
397 mtu: self.config.mtu(),
398 });
399 }
400
401 let socket_addr = self.resolve_cached(addr).await?;
402 let socket = self.socket.as_ref().ok_or(TransportError::NotStarted)?;
403 let local_addr = self.local_addr.ok_or(TransportError::NotStarted)?;
404 if !socket_addr_families_compatible(local_addr, socket_addr) {
405 return Err(TransportError::InvalidAddress(format!(
406 "remote address family {socket_addr} is incompatible with local UDP socket {local_addr}"
407 )));
408 }
409 match socket.send_to(data, &socket_addr).await {
410 Ok(bytes_sent) => {
411 self.stats.record_send(bytes_sent);
412 trace!(
413 transport_id = %self.transport_id,
414 remote_addr = %socket_addr,
415 bytes = bytes_sent,
416 "UDP packet sent"
417 );
418 Ok(bytes_sent)
419 }
420 Err(e) => {
421 self.stats.record_send_error();
422 Err(e)
423 }
424 }
425 }
426
/// Does nothing: the body is intentionally empty.
// NOTE(review): `send_async` passes each datagram straight to the socket, so
// there appears to be nothing to flush; presumably this exists so callers can
// treat UDP like transports that do buffer sends — confirm against callers.
pub async fn flush_pending_send(&self) {}
433}
434
435impl Transport for UdpTransport {
436 fn transport_id(&self) -> TransportId {
437 self.transport_id
438 }
439
440 fn transport_type(&self) -> &TransportType {
441 &TransportType::UDP
442 }
443
444 fn state(&self) -> TransportState {
445 self.state
446 }
447
448 fn mtu(&self) -> u16 {
449 self.config.mtu()
450 }
451
452 fn start(&mut self) -> Result<(), TransportError> {
453 Err(TransportError::NotSupported(
455 "use start_async() for UDP transport".into(),
456 ))
457 }
458
459 fn stop(&mut self) -> Result<(), TransportError> {
460 Err(TransportError::NotSupported(
462 "use stop_async() for UDP transport".into(),
463 ))
464 }
465
466 fn send(&self, _addr: &TransportAddr, _data: &[u8]) -> Result<(), TransportError> {
467 Err(TransportError::NotSupported(
469 "use send_async() for UDP transport".into(),
470 ))
471 }
472
473 fn discover(&self) -> Result<Vec<DiscoveredPeer>, TransportError> {
474 Ok(Vec::new())
477 }
478
479 fn accept_connections(&self) -> bool {
486 if self.config.outbound_only() {
487 false
488 } else {
489 self.config.accept_connections()
490 }
491 }
492}
493
494impl Drop for UdpTransport {
495 fn drop(&mut self) {
496 let had_task = self.recv_task.is_some();
497 let had_socket = self.socket.is_some();
498 if had_task || had_socket {
499 debug!(
500 transport_id = %self.transport_id,
501 state = ?self.state,
502 had_recv_task = had_task,
503 had_socket = had_socket,
504 "UdpTransport dropped without stop_async(); cleaning up",
505 );
506 }
507 if let Some(task) = self.recv_task.take() {
508 task.abort();
509 }
510 self.socket.take();
511 self.local_addr = None;
512 }
513}
514
/// Receive loop run on a spawned task for the lifetime of a started transport.
///
/// Reads datagrams from `socket`, updates `stats`, silently discards hole-punch
/// packets (as identified by `is_punch_packet`), and forwards everything else
/// to `packet_tx` as a `ReceivedPacket` tagged with `transport_id` and the
/// sender's address.
///
/// The loop ends when `packet_tx.send` fails (the channel is closed). Receive
/// errors are counted and logged, after which the loop carries on.
///
/// Two implementations are selected at compile time:
/// * Linux: batched receive of up to `BATCH` datagrams per call.
/// * Other targets: one datagram per call into a single reusable buffer.
///
/// Buffers are sized `mtu + 100` bytes.
// NOTE(review): the 100 extra bytes are presumably headroom so datagrams a
// little over the MTU are not truncated — confirm.
// NOTE(review): on a receive error the loop retries immediately; if an error
// were persistent this would spin and log continuously — confirm whether the
// socket layer can return such errors.
async fn udp_receive_loop(
    socket: AsyncUdpSocket,
    transport_id: TransportId,
    packet_tx: PacketTx,
    mtu: u16,
    stats: Arc<UdpStats>,
) {
    debug!(transport_id = %transport_id, "UDP receive loop starting");

    #[cfg(target_os = "linux")]
    {
        // Maximum number of datagrams accepted from one `recv_batch` call.
        const BATCH: usize = 32;
        let buf_size = mtu as usize + 100;
        // One owned buffer per batch slot, plus parallel arrays that
        // `recv_batch` fills with each slot's sender address and length.
        let mut backing: Vec<Vec<u8>> = (0..BATCH).map(|_| vec![0u8; buf_size]).collect();
        let mut addrs: [Option<std::net::SocketAddr>; BATCH] = std::array::from_fn(|_| None);
        let mut lens: [usize; BATCH] = [0; BATCH];

        loop {
            // Rebuilt every iteration: a fixed-size array of mutable slices
            // borrowing the backing buffers. `backing` always holds exactly
            // `BATCH` entries, so the `unwrap` cannot fail.
            let mut bufs: [&mut [u8]; BATCH] = {
                let mut iter = backing.iter_mut();
                std::array::from_fn(|_| iter.next().unwrap().as_mut_slice())
            };

            let recv_result = {
                // Profiling timer bound to the scope of the receive call.
                let _t = crate::perf_profile::Timer::start(crate::perf_profile::Stage::UdpRecv);
                socket.recv_batch(&mut bufs, &mut addrs, &mut lens).await
            };
            match recv_result {
                Ok((count, kernel_drops)) => {
                    // Updated once per batch.
                    stats.set_kernel_drops(kernel_drops as u64);
                    for i in 0..count {
                        let len = lens[i];
                        // Slots without a sender address are skipped.
                        let Some(remote_addr) = addrs[i] else {
                            continue;
                        };
                        // Counted before the punch filter, so dropped punch
                        // packets are included in the receive statistics.
                        stats.record_recv(len);

                        if is_punch_packet(&backing[i][..len]) {
                            trace!(
                                transport_id = %transport_id,
                                remote_addr = %remote_addr,
                                bytes = len,
                                "Dropping stray punch probe/ack on UDP transport"
                            );
                            continue;
                        }

                        // Hand the filled buffer to the packet without copying:
                        // swap a fresh zeroed buffer into the slot and trim the
                        // taken one to the received length.
                        let mut data = std::mem::replace(&mut backing[i], vec![0u8; buf_size]);
                        data.truncate(len);
                        let addr = TransportAddr::from_socket_addr(remote_addr);
                        let packet = ReceivedPacket::new(transport_id, addr, data);

                        trace!(
                            transport_id = %transport_id,
                            remote_addr = %remote_addr,
                            bytes = len,
                            "UDP packet received"
                        );

                        // A failed send means the channel is closed: stop.
                        if packet_tx.send(packet).is_err() {
                            debug!(
                                transport_id = %transport_id,
                                "Packet channel closed, stopping receive loop"
                            );
                            return;
                        }
                    }
                }
                Err(e) => {
                    // Count and log, then try again.
                    stats.record_recv_error();
                    warn!(
                        transport_id = %transport_id,
                        error = %e,
                        "UDP receive error"
                    );
                }
            }
        }
    }

    #[cfg(not(target_os = "linux"))]
    {
        // Single buffer reused for every datagram.
        let mut buf = vec![0u8; mtu as usize + 100];

        loop {
            match socket.recv_from(&mut buf).await {
                Ok((len, remote_addr, kernel_drops)) => {
                    // Counted before the punch filter, so dropped punch
                    // packets are included in the receive statistics.
                    stats.record_recv(len);
                    stats.set_kernel_drops(kernel_drops as u64);

                    if is_punch_packet(&buf[..len]) {
                        trace!(
                            transport_id = %transport_id,
                            remote_addr = %remote_addr,
                            bytes = len,
                            "Dropping stray punch probe/ack on UDP transport"
                        );
                        continue;
                    }

                    // Copy the payload out so `buf` can be reused.
                    let data = buf[..len].to_vec();
                    let addr = TransportAddr::from_socket_addr(remote_addr);
                    let packet = ReceivedPacket::new(transport_id, addr, data);

                    trace!(
                        transport_id = %transport_id,
                        remote_addr = %remote_addr,
                        bytes = len,
                        "UDP packet received"
                    );

                    // A failed send means the channel is closed: stop.
                    if packet_tx.send(packet).is_err() {
                        debug!(
                            transport_id = %transport_id,
                            "Packet channel closed, stopping receive loop"
                        );
                        break;
                    }
                }
                Err(e) => {
                    // Count and log, then try again.
                    stats.record_recv_error();
                    warn!(
                        transport_id = %transport_id,
                        error = %e,
                        "UDP receive error"
                    );
                }
            }
        }
    }
}
673
// Unit tests for the UDP transport. Most of them bind real sockets on
// 127.0.0.1 with port 0, so they need a working loopback interface.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::packet_channel;
    use tokio::time::{Duration, timeout};

    // Builds a config bound to 127.0.0.1 on `port` with an MTU of 1280;
    // every other setting keeps its default. The tests pass port 0.
    fn make_config(port: u16) -> UdpConfig {
        UdpConfig {
            bind_addr: Some(format!("127.0.0.1:{}", port)),
            mtu: Some(1280),
            ..Default::default()
        }
    }

    // Lifecycle: Configured -> Up (with a local address) -> Down.
    #[tokio::test]
    async fn test_start_stop() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        assert_eq!(transport.state(), TransportState::Configured);

        transport.start_async().await.unwrap();
        assert_eq!(transport.state(), TransportState::Up);
        assert!(transport.local_addr().is_some());

        transport.stop_async().await.unwrap();
        assert_eq!(transport.state(), TransportState::Down);
    }

    // Starting a transport that is already up reports AlreadyStarted.
    #[tokio::test]
    async fn test_double_start_fails() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        transport.start_async().await.unwrap();

        let result = transport.start_async().await;
        assert!(matches!(result, Err(TransportError::AlreadyStarted)));

        transport.stop_async().await.unwrap();
    }

    // Stopping a transport that was never started reports NotStarted.
    #[tokio::test]
    async fn test_stop_not_started_fails() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        let result = transport.stop_async().await;
        assert!(matches!(result, Err(TransportError::NotStarted)));
    }

    // One datagram from t1 to t2: the payload arrives intact and the packet's
    // remote address is t1's local address.
    #[tokio::test]
    async fn test_send_recv() {
        let (tx1, _rx1) = packet_channel(100);
        let (tx2, mut rx2) = packet_channel(100);

        let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
        let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);

        t1.start_async().await.unwrap();
        t2.start_async().await.unwrap();

        let addr1 = t1.local_addr().unwrap();
        let addr2 = t2.local_addr().unwrap();

        let data = b"hello world";
        let bytes_sent = t1
            .send_async(&TransportAddr::from_string(&addr2.to_string()), data)
            .await
            .unwrap();
        assert_eq!(bytes_sent, data.len());

        // Bounded wait so a lost datagram fails the test instead of hanging it.
        let packet = timeout(Duration::from_secs(1), rx2.recv())
            .await
            .expect("timeout")
            .expect("channel closed");

        assert_eq!(packet.data, data);
        assert_eq!(
            packet.remote_addr.as_str(),
            Some(addr1.to_string().as_str())
        );

        t1.stop_async().await.unwrap();
        t2.stop_async().await.unwrap();
    }

    // Traffic flows in both directions between two transports.
    #[tokio::test]
    async fn test_bidirectional() {
        let (tx1, mut rx1) = packet_channel(100);
        let (tx2, mut rx2) = packet_channel(100);

        let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
        let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);

        t1.start_async().await.unwrap();
        t2.start_async().await.unwrap();

        let addr1 = TransportAddr::from_string(&t1.local_addr().unwrap().to_string());
        let addr2 = TransportAddr::from_string(&t2.local_addr().unwrap().to_string());

        // t1 -> t2
        t1.send_async(&addr2, b"ping").await.unwrap();

        let packet = timeout(Duration::from_secs(1), rx2.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert_eq!(packet.data, b"ping");

        // t2 -> t1
        t2.send_async(&addr1, b"pong").await.unwrap();

        let packet = timeout(Duration::from_secs(1), rx1.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert_eq!(packet.data, b"pong");

        t1.stop_async().await.unwrap();
        t2.stop_async().await.unwrap();
    }

    // A 200-byte payload against an MTU of 100 is rejected with MtuExceeded.
    #[tokio::test]
    async fn test_mtu_exceeded() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(
            TransportId::new(1),
            None,
            UdpConfig {
                mtu: Some(100),
                ..make_config(0)
            },
            tx,
        );

        transport.start_async().await.unwrap();

        let oversized = vec![0u8; 200];
        let result = transport
            .send_async(&TransportAddr::from_string("127.0.0.1:9999"), &oversized)
            .await;

        assert!(matches!(result, Err(TransportError::MtuExceeded { .. })));

        transport.stop_async().await.unwrap();
    }

    // Sending on a transport that was never started reports NotStarted.
    #[tokio::test]
    async fn test_send_not_started() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        let result = transport
            .send_async(&TransportAddr::from_string("127.0.0.1:9999"), b"test")
            .await;

        assert!(matches!(result, Err(TransportError::NotStarted)));
    }

    // `discover` succeeds and yields no peers.
    #[tokio::test]
    async fn test_discover_returns_empty() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        let peers = transport.discover().unwrap();
        assert!(peers.is_empty());
    }

    // The transport type is named "udp" and is neither connection-oriented
    // nor reliable.
    #[test]
    fn test_transport_type() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        assert_eq!(transport.transport_type().name, "udp");
        assert!(!transport.transport_type().connection_oriented);
        assert!(!transport.transport_type().reliable);
    }

    // The synchronous trait methods all answer NotSupported.
    #[test]
    fn test_sync_methods_return_not_supported() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        assert!(matches!(
            transport.start(),
            Err(TransportError::NotSupported(_))
        ));
        assert!(matches!(
            transport.stop(),
            Err(TransportError::NotSupported(_))
        ));
        assert!(matches!(
            transport.send(&TransportAddr::from_string("test"), b"data"),
            Err(TransportError::NotSupported(_))
        ));
    }

    // A literal "ip:port" string resolves to that same address.
    #[tokio::test]
    async fn test_resolve_socket_addr_ip() {
        let addr = TransportAddr::from_string("192.168.1.1:2121");
        let result = resolve_socket_addr(&addr).await.unwrap();
        assert_eq!(result.to_string(), "192.168.1.1:2121");
    }

    // Resolution fails for a hostname under ".invalid" and for an address
    // built from the raw bytes 0xff 0x80.
    #[tokio::test]
    async fn test_resolve_socket_addr_invalid() {
        let invalid = TransportAddr::from_string("nonexistent.invalid:2121");
        assert!(resolve_socket_addr(&invalid).await.is_err());

        let binary = TransportAddr::new(vec![0xff, 0x80]);
        assert!(resolve_socket_addr(&binary).await.is_err());
    }

    // "localhost" resolves to a loopback address and the port is preserved.
    #[tokio::test]
    async fn test_resolve_socket_addr_hostname() {
        let addr = TransportAddr::from_string("localhost:2121");
        let result = resolve_socket_addr(&addr).await.unwrap();
        assert!(result.ip().is_loopback());
        assert_eq!(result.port(), 2121);
    }

    // A freshly constructed transport reports zero receive drops.
    #[tokio::test]
    async fn test_congestion_reports_kernel_drops() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);

        let cong = transport.congestion();
        assert_eq!(cong.recv_drops, Some(0));
    }

    // With no explicit setting, inbound connections are accepted.
    #[test]
    fn test_accept_connections_default_true() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
        assert!(transport.accept_connections());
    }

    // `accept_connections: Some(false)` is honoured.
    #[test]
    fn test_accept_connections_false_when_configured() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(
            TransportId::new(1),
            None,
            UdpConfig {
                bind_addr: Some("127.0.0.1:0".to_string()),
                accept_connections: Some(false),
                ..Default::default()
            },
            tx,
        );
        assert!(!transport.accept_connections());
    }

    // `outbound_only` overrides an explicit `accept_connections: Some(true)`.
    #[test]
    fn test_accept_connections_forced_false_in_outbound_only() {
        let (tx, _rx) = packet_channel(100);
        let transport = UdpTransport::new(
            TransportId::new(1),
            None,
            UdpConfig {
                outbound_only: Some(true),
                accept_connections: Some(true),
                ..Default::default()
            },
            tx,
        );
        assert!(!transport.accept_connections());
    }

    // In outbound-only mode the configured bind address (port 65535 here) is
    // not used: the socket ends up on a different, non-zero port of the
    // unspecified address.
    #[tokio::test]
    async fn test_outbound_only_binds_ephemeral() {
        let (tx, _rx) = packet_channel(100);
        let mut transport = UdpTransport::new(
            TransportId::new(1),
            None,
            UdpConfig {
                bind_addr: Some("127.0.0.1:65535".to_string()),
                outbound_only: Some(true),
                ..Default::default()
            },
            tx,
        );

        transport.start_async().await.unwrap();
        let local = transport.local_addr().unwrap();
        assert_ne!(local.port(), 65535);
        assert!(local.port() > 0);
        assert!(local.ip().is_unspecified());
        transport.stop_async().await.unwrap();
    }

    // Punch probe and ack datagrams are filtered by the receive loop; only
    // the ordinary datagram sent after them reaches the packet channel.
    #[tokio::test]
    async fn test_punch_probe_dropped() {
        let (tx_recv, mut rx_recv) = packet_channel(100);
        let (tx_send, _rx_send) = packet_channel(100);

        let mut t_recv = UdpTransport::new(TransportId::new(1), None, make_config(0), tx_recv);
        let mut t_send = UdpTransport::new(TransportId::new(2), None, make_config(0), tx_send);

        t_recv.start_async().await.unwrap();
        t_send.start_async().await.unwrap();

        let recv_addr = t_recv.local_addr().unwrap();
        let recv_addr_str = TransportAddr::from_string(&recv_addr.to_string());

        // 16-byte datagram starting with 0x4E505443 (ASCII "NPTC").
        let mut probe = vec![0u8; 16];
        probe[..4].copy_from_slice(&0x4E505443u32.to_be_bytes());
        t_send.send_async(&recv_addr_str, &probe).await.unwrap();

        // 16-byte datagram starting with 0x4E505441 (ASCII "NPTA").
        let mut ack = vec![0u8; 16];
        ack[..4].copy_from_slice(&0x4E505441u32.to_be_bytes());
        t_send.send_async(&recv_addr_str, &ack).await.unwrap();

        // An ordinary payload that must get through.
        let real = b"valid-fmp-frame";
        t_send.send_async(&recv_addr_str, real).await.unwrap();

        // The first packet delivered must be the ordinary one.
        let packet = timeout(Duration::from_secs(1), rx_recv.recv())
            .await
            .expect("timeout waiting for real packet")
            .expect("channel closed");
        assert_eq!(packet.data, real);

        // Nothing else may follow: a timeout here is the passing outcome.
        let no_more = timeout(Duration::from_millis(200), rx_recv.recv()).await;
        assert!(no_more.is_err(), "punch probe/ack leaked through filter");

        t_recv.stop_async().await.unwrap();
        t_send.stop_async().await.unwrap();
    }

    // Direct checks of the punch-packet classifier.
    #[test]
    fn test_is_punch_packet_helper() {
        use crate::discovery::is_punch_packet;
        // "NPTC" prefix followed by extra bytes.
        assert!(is_punch_packet(&[0x4E, 0x50, 0x54, 0x43, 0xAA, 0xBB]));
        // Exactly the four bytes "NPTA".
        assert!(is_punch_packet(&[0x4E, 0x50, 0x54, 0x41]));
        // Unrelated four bytes.
        assert!(!is_punch_packet(&[0x01, 0x02, 0x03, 0x04]));
        // Only the first three prefix bytes: too short.
        assert!(!is_punch_packet(&[0x4E, 0x50, 0x54]));
        // Empty input.
        assert!(!is_punch_packet(&[]));
    }

    // Same as `test_send_recv`, but the destination is built from a
    // hand-formatted "127.0.0.1:<port>" string.
    #[tokio::test]
    async fn test_send_recv_ip_string() {
        let (tx1, _rx1) = packet_channel(100);
        let (tx2, mut rx2) = packet_channel(100);

        let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
        let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);

        t1.start_async().await.unwrap();
        t2.start_async().await.unwrap();

        let port2 = t2.local_addr().unwrap().port();

        let data = b"hello via ip string";
        let bytes_sent = t1
            .send_async(
                &TransportAddr::from_string(&format!("127.0.0.1:{}", port2)),
                data,
            )
            .await
            .unwrap();
        assert_eq!(bytes_sent, data.len());

        let packet = timeout(Duration::from_secs(1), rx2.recv())
            .await
            .expect("timeout")
            .expect("channel closed");

        assert_eq!(packet.data, data);

        t1.stop_async().await.unwrap();
        t2.stop_async().await.unwrap();
    }
}