1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
2
3use crate::{maybe_fut_constructor_result, maybe_fut_method, maybe_fut_method_sync};
4
/// A UDP socket backed either by [`std::net::UdpSocket`] or, when the `tokio-net`
/// feature is enabled, by `tokio::net::UdpSocket`.
///
/// Obtain one with `UdpSocket::bind`, or wrap an existing socket via
/// `From<std::net::UdpSocket>` (and `From<tokio::net::UdpSocket>` with `tokio-net`).
// NOTE(review): `Unwrap` + `#[unwrap_types(...)]` appear to generate the backend accessors
// (`get_std()` / `get_tokio()`, used by this file's tests through `crate::Unwrap`), with
// the tokio accessor gated on the "tokio-net" feature — confirm in the derive crate.
#[derive(Debug, Unwrap)]
#[unwrap_types(
    std(std::net::UdpSocket),
    tokio(tokio::net::UdpSocket),
    tokio_gated("tokio-net")
)]
pub struct UdpSocket(UdpSocketInner);
17
/// Backend storage for `UdpSocket`: holds exactly one concrete socket implementation.
#[derive(Debug)]
enum UdpSocketInner {
    /// Standard-library socket.
    Std(std::net::UdpSocket),
    /// Tokio socket; this variant only exists when the `tokio-net` feature is enabled,
    /// so every `match` on this enum must gate its `Tokio` arm on the same predicate.
    #[cfg(feature = "tokio-net")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
    Tokio(tokio::net::UdpSocket),
}
25
26impl From<std::net::UdpSocket> for UdpSocket {
27 fn from(socket: std::net::UdpSocket) -> Self {
28 UdpSocket(UdpSocketInner::Std(socket))
29 }
30}
31
#[cfg(feature = "tokio-net")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
impl From<tokio::net::UdpSocket> for UdpSocket {
    /// Wraps an already-created Tokio socket in the `Tokio` backend.
    fn from(inner: tokio::net::UdpSocket) -> Self {
        let backend = UdpSocketInner::Tokio(inner);
        Self(backend)
    }
}
39
40#[cfg(unix)]
41impl std::os::fd::AsFd for UdpSocket {
42 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
43 match &self.0 {
44 UdpSocketInner::Std(file) => file.as_fd(),
45 #[cfg(tokio_net)]
46 UdpSocketInner::Tokio(file) => file.as_fd(),
47 }
48 }
49}
50
51#[cfg(unix)]
52impl std::os::fd::AsRawFd for UdpSocket {
53 fn as_raw_fd(&self) -> std::os::fd::RawFd {
54 match &self.0 {
55 UdpSocketInner::Std(file) => file.as_raw_fd(),
56 #[cfg(tokio_net)]
57 UdpSocketInner::Tokio(file) => file.as_raw_fd(),
58 }
59 }
60}
61
#[cfg(windows)]
impl std::os::windows::io::AsSocket for UdpSocket {
    /// Borrows the Windows socket handle of whichever backend socket is wrapped.
    fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
        match &self.0 {
            UdpSocketInner::Std(socket) => socket.as_socket(),
            // Gated on the exact predicate that declares the `Tokio` variant, so this
            // match stays exhaustive in every feature combination.
            #[cfg(feature = "tokio-net")]
            UdpSocketInner::Tokio(socket) => socket.as_socket(),
        }
    }
}
72
#[cfg(windows)]
impl std::os::windows::io::AsRawSocket for UdpSocket {
    /// Returns the raw Windows socket handle of whichever backend socket is wrapped.
    /// Ownership is not transferred; the handle stays owned by `self`.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match &self.0 {
            UdpSocketInner::Std(socket) => socket.as_raw_socket(),
            // Gated on the exact predicate that declares the `Tokio` variant, so this
            // match stays exhaustive in every feature combination.
            #[cfg(feature = "tokio-net")]
            UdpSocketInner::Tokio(socket) => socket.as_raw_socket(),
        }
    }
}
83
84impl UdpSocket {
85 maybe_fut_constructor_result!(
86 bind(addr: std::net::SocketAddr) -> std::io::Result<UdpSocket>,
88 std::net::UdpSocket::bind,
89 tokio::net::UdpSocket::bind,
90 tokio_net
91 );
92
93 maybe_fut_method!(
94 recv_from(buf: &mut [u8]) -> std::io::Result<(usize, std::net::SocketAddr)>,
98 UdpSocketInner::Std,
99 UdpSocketInner::Tokio,
100 tokio_net
101 );
102
103 maybe_fut_method!(
104 peek_from(buf: &mut [u8]) -> std::io::Result<(usize, std::net::SocketAddr)>,
108 UdpSocketInner::Std,
109 UdpSocketInner::Tokio,
110 tokio_net
111 );
112
113 maybe_fut_method!(
114 send_to(buf: &[u8], target: std::net::SocketAddr) -> std::io::Result<usize>,
119 UdpSocketInner::Std,
120 UdpSocketInner::Tokio,
121 tokio_net
122 );
123
124 maybe_fut_method_sync!(
125 peer_addr() -> std::io::Result<std::net::SocketAddr>,
127 UdpSocketInner::Std,
128 UdpSocketInner::Tokio,
129 tokio_net
130 );
131
132 maybe_fut_method_sync!(
133 local_addr() -> std::io::Result<std::net::SocketAddr>,
135 UdpSocketInner::Std,
136 UdpSocketInner::Tokio,
137 tokio_net
138 );
139
140 pub fn try_clone(&self) -> std::io::Result<Self> {
144 match &self.0 {
145 UdpSocketInner::Std(socket) => socket.try_clone().map(UdpSocket::from),
146 #[cfg(feature = "tokio-net")]
147 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
148 "Tokio UdpSocket does not support try_clone",
149 )),
150 }
151 }
152
153 pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
157 match &self.0 {
158 UdpSocketInner::Std(socket) => socket.set_read_timeout(timeout),
159 #[cfg(feature = "tokio-net")]
160 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
161 "Tokio UdpSocket does not support set_read_timeout",
162 )),
163 }
164 }
165
166 pub fn set_write_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
170 match &self.0 {
171 UdpSocketInner::Std(socket) => socket.set_write_timeout(timeout),
172 #[cfg(feature = "tokio-net")]
173 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
174 "Tokio UdpSocket does not support set_read_timeout",
175 )),
176 }
177 }
178
179 pub fn read_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
183 match &self.0 {
184 UdpSocketInner::Std(socket) => socket.read_timeout(),
185 #[cfg(feature = "tokio-net")]
186 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
187 "Tokio UdpSocket does not support read_timeout",
188 )),
189 }
190 }
191
192 pub fn write_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
196 match &self.0 {
197 UdpSocketInner::Std(socket) => socket.write_timeout(),
198 #[cfg(feature = "tokio-net")]
199 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
200 "Tokio UdpSocket does not support write_timeout",
201 )),
202 }
203 }
204
205 maybe_fut_method_sync!(
206 set_broadcast(broadcast: bool) -> std::io::Result<()>,
208 UdpSocketInner::Std,
209 UdpSocketInner::Tokio,
210 tokio_net
211 );
212
213 maybe_fut_method_sync!(
214 broadcast() -> std::io::Result<bool>,
216 UdpSocketInner::Std,
217 UdpSocketInner::Tokio,
218 tokio_net
219 );
220
221 maybe_fut_method_sync!(
222 set_multicast_loop_v4(loop_v4: bool) -> std::io::Result<()>,
224 UdpSocketInner::Std,
225 UdpSocketInner::Tokio,
226 tokio_net
227 );
228
229 maybe_fut_method_sync!(
230 multicast_loop_v4() -> std::io::Result<bool>,
232 UdpSocketInner::Std,
233 UdpSocketInner::Tokio,
234 tokio_net
235 );
236
237 maybe_fut_method_sync!(
238 set_multicast_ttl_v4(ttl: u32) -> std::io::Result<()>,
240 UdpSocketInner::Std,
241 UdpSocketInner::Tokio,
242 tokio_net
243 );
244
245 maybe_fut_method_sync!(
246 multicast_ttl_v4() -> std::io::Result<u32>,
248 UdpSocketInner::Std,
249 UdpSocketInner::Tokio,
250 tokio_net
251 );
252
253 maybe_fut_method_sync!(
254 set_multicast_loop_v6(loop_v6: bool) -> std::io::Result<()>,
256 UdpSocketInner::Std,
257 UdpSocketInner::Tokio,
258 tokio_net
259 );
260
261 maybe_fut_method_sync!(
262 multicast_loop_v6() -> std::io::Result<bool>,
264 UdpSocketInner::Std,
265 UdpSocketInner::Tokio,
266 tokio_net
267 );
268
269 maybe_fut_method_sync!(
270 set_ttl(ttl: u32) -> std::io::Result<()>,
272 UdpSocketInner::Std,
273 UdpSocketInner::Tokio,
274 tokio_net
275 );
276
277 maybe_fut_method_sync!(
278 ttl() -> std::io::Result<u32>,
280 UdpSocketInner::Std,
281 UdpSocketInner::Tokio,
282 tokio_net
283 );
284
285 pub fn join_multicast_v4(
287 &self,
288 multiaddr: &Ipv4Addr,
289 interface: &Ipv4Addr,
290 ) -> std::io::Result<()> {
291 match &self.0 {
292 UdpSocketInner::Std(socket) => socket.join_multicast_v4(multiaddr, interface),
293 #[cfg(feature = "tokio-net")]
294 UdpSocketInner::Tokio(socket) => socket.join_multicast_v4(*multiaddr, *interface),
295 }
296 }
297
298 pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> std::io::Result<()> {
300 match &self.0 {
301 UdpSocketInner::Std(socket) => socket.join_multicast_v6(multiaddr, interface),
302 #[cfg(feature = "tokio-net")]
303 UdpSocketInner::Tokio(socket) => socket.join_multicast_v6(multiaddr, interface),
304 }
305 }
306
307 pub fn leave_multicast_v4(
308 &self,
309 multiaddr: &Ipv4Addr,
310 interface: &Ipv4Addr,
311 ) -> std::io::Result<()> {
312 match &self.0 {
313 UdpSocketInner::Std(socket) => socket.leave_multicast_v4(multiaddr, interface),
314 #[cfg(feature = "tokio-net")]
315 UdpSocketInner::Tokio(socket) => socket.leave_multicast_v4(*multiaddr, *interface),
316 }
317 }
318
319 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> std::io::Result<()> {
320 match &self.0 {
321 UdpSocketInner::Std(socket) => socket.leave_multicast_v6(multiaddr, interface),
322 #[cfg(feature = "tokio-net")]
323 UdpSocketInner::Tokio(socket) => socket.leave_multicast_v6(multiaddr, interface),
324 }
325 }
326
327 maybe_fut_method_sync!(
328 take_error() -> std::io::Result<Option<std::io::Error>>,
330 UdpSocketInner::Std,
331 UdpSocketInner::Tokio,
332 tokio_net
333 );
334
335 pub async fn connect(&self, addr: SocketAddr) -> std::io::Result<()> {
339 match &self.0 {
340 UdpSocketInner::Std(socket) => socket.connect(addr),
341 #[cfg(feature = "tokio-net")]
342 UdpSocketInner::Tokio(socket) => socket.connect(addr).await,
343 }
344 }
345
346 maybe_fut_method!(
347 send(buf: &[u8]) -> std::io::Result<usize>,
351 UdpSocketInner::Std,
352 UdpSocketInner::Tokio,
353 tokio_net
354 );
355
356 maybe_fut_method!(
357 recv(buf: &mut [u8]) -> std::io::Result<usize>,
361 UdpSocketInner::Std,
362 UdpSocketInner::Tokio,
363 tokio_net
364 );
365
366 maybe_fut_method!(
367 peek(buf: &mut [u8]) -> std::io::Result<usize>,
371 UdpSocketInner::Std,
372 UdpSocketInner::Tokio,
373 tokio_net
374 );
375
376 pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
380 match &self.0 {
381 UdpSocketInner::Std(socket) => socket.set_nonblocking(nonblocking),
382 #[cfg(feature = "tokio-net")]
383 UdpSocketInner::Tokio(_) => Err(std::io::Error::other(
384 "Tokio UdpSocket does not support set_nonblocking",
385 )),
386 }
387 }
388}
389
#[cfg(test)]
mod test {

    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::thread::JoinHandle;

    use super::*;
    use crate::{Unwrap, block_on};

    /// Longest time the blocking std round-trip test waits for the echoed datagram;
    /// without it a lost packet would hang the test run instead of failing it.
    const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    #[test]
    #[serial_test::serial]
    fn test_should_bind_udp_std() {
        let socket = block_on(UdpSocket::bind(
            "127.0.0.1:0"
                .parse::<SocketAddr>()
                .expect("failed to parse"),
        ))
        .expect("failed to bind UDP socket");

        assert!(socket.get_std().is_some());
    }

    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_bind_udp_tokio() {
        let socket = UdpSocket::bind(
            "127.0.0.1:0"
                .parse::<SocketAddr>()
                .expect("failed to parse"),
        )
        .await
        .expect("failed to bind UDP socket");

        assert!(socket.get_tokio().is_some());
    }

    #[test]
    #[serial_test::serial]
    fn test_should_send_and_recv_from_udp_std() {
        let (server_handle, server_addr, exit) = echo_server();
        let socket = bind_std();
        socket
            .set_read_timeout(Some(RECV_TIMEOUT))
            .expect("failed to set read timeout");

        let msg = b"Hello, UDP!";
        let mut buf = [0; 1024];

        let sent_bytes = block_on(socket.send_to(msg, server_addr)).expect("failed to send");
        assert_eq!(sent_bytes, msg.len());

        let (received_bytes, src) =
            block_on(socket.recv_from(&mut buf)).expect("failed to receive");
        assert_eq!(received_bytes, msg.len());
        assert_eq!(src, server_addr);
        assert_eq!(&buf[..received_bytes], msg);

        stop_echo_server(server_handle, &exit);
    }

    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_send_and_recv_from_udp_tokio() {
        let (server_handle, server_addr, exit) = echo_server();
        let socket = bind_tokio().await;

        let msg = b"Hello, UDP!";
        let mut buf = [0; 1024];

        let sent_bytes = socket
            .send_to(msg, server_addr)
            .await
            .expect("failed to send");
        assert_eq!(sent_bytes, msg.len());

        let (received_bytes, src) = socket.recv_from(&mut buf).await.expect("failed to receive");
        assert_eq!(received_bytes, msg.len());
        assert_eq!(src, server_addr);
        assert_eq!(&buf[..received_bytes], msg);

        // Joining blocks this runtime thread briefly (the server polls the flag every
        // 100 ms); acceptable here since nothing else runs on the test runtime.
        stop_echo_server(server_handle, &exit);
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_options_std() {
        let socket = bind_std();

        socket.set_broadcast(true).expect("failed to set broadcast");
        let broadcast = socket.broadcast().expect("failed to get broadcast");
        assert!(broadcast);
        socket
            .set_broadcast(false)
            .expect("failed to set broadcast");
        let broadcast = socket.broadcast().expect("failed to get broadcast");
        assert!(!broadcast);

        socket
            .set_multicast_loop_v4(true)
            .expect("failed to set multicast loop");
        let loop_v4 = socket
            .multicast_loop_v4()
            .expect("failed to get multicast loop");
        assert!(loop_v4);
        socket
            .set_multicast_loop_v4(false)
            .expect("failed to set multicast loop");
        let loop_v4 = socket
            .multicast_loop_v4()
            .expect("failed to get multicast loop");
        assert!(!loop_v4);

        socket
            .set_multicast_ttl_v4(1)
            .expect("failed to set multicast TTL");
        let ttl = socket
            .multicast_ttl_v4()
            .expect("failed to get multicast TTL");
        assert_eq!(ttl, 1);
        socket
            .set_multicast_ttl_v4(64)
            .expect("failed to set multicast TTL");
        let ttl = socket
            .multicast_ttl_v4()
            .expect("failed to get multicast TTL");
        assert_eq!(ttl, 64);

        socket.set_ttl(64).expect("failed to set TTL");
        let ttl = socket.ttl().expect("failed to get TTL");
        assert_eq!(ttl, 64);
        socket.set_ttl(128).expect("failed to set TTL");
        let ttl = socket.ttl().expect("failed to get TTL");
        assert_eq!(ttl, 128);

        let multiaddr_v4 = Ipv4Addr::new(224, 0, 0, 1);
        let interface_v4 = Ipv4Addr::new(127, 0, 0, 1);
        socket
            .join_multicast_v4(&multiaddr_v4, &interface_v4)
            .expect("failed to join multicast v4");

        socket
            .leave_multicast_v4(&multiaddr_v4, &interface_v4)
            .expect("failed to leave multicast v4");

        let error = socket.take_error().expect("failed to get SO_ERROR");
        assert!(error.is_none(), "Expected no error, got: {:?}", error);
    }

    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_options_tokio() {
        let socket = bind_tokio().await;

        socket.set_broadcast(true).expect("failed to set broadcast");
        let broadcast = socket.broadcast().expect("failed to get broadcast");
        assert!(broadcast);
        socket
            .set_broadcast(false)
            .expect("failed to set broadcast");
        let broadcast = socket.broadcast().expect("failed to get broadcast");
        assert!(!broadcast);

        socket
            .set_multicast_loop_v4(true)
            .expect("failed to set multicast loop");
        let loop_v4 = socket
            .multicast_loop_v4()
            .expect("failed to get multicast loop");
        assert!(loop_v4);
        socket
            .set_multicast_loop_v4(false)
            .expect("failed to set multicast loop");
        let loop_v4 = socket
            .multicast_loop_v4()
            .expect("failed to get multicast loop");
        assert!(!loop_v4);

        socket
            .set_multicast_ttl_v4(1)
            .expect("failed to set multicast TTL");
        let ttl = socket
            .multicast_ttl_v4()
            .expect("failed to get multicast TTL");
        assert_eq!(ttl, 1);
        socket
            .set_multicast_ttl_v4(64)
            .expect("failed to set multicast TTL");
        let ttl = socket
            .multicast_ttl_v4()
            .expect("failed to get multicast TTL");
        assert_eq!(ttl, 64);

        socket.set_ttl(64).expect("failed to set TTL");
        let ttl = socket.ttl().expect("failed to get TTL");
        assert_eq!(ttl, 64);
        socket.set_ttl(128).expect("failed to set TTL");
        let ttl = socket.ttl().expect("failed to get TTL");
        assert_eq!(ttl, 128);

        let multiaddr_v4 = Ipv4Addr::new(224, 0, 0, 1);
        let interface_v4 = Ipv4Addr::new(127, 0, 0, 1);
        socket
            .join_multicast_v4(&multiaddr_v4, &interface_v4)
            .expect("failed to join multicast v4");

        socket
            .leave_multicast_v4(&multiaddr_v4, &interface_v4)
            .expect("failed to leave multicast v4");

        let error = socket.take_error().expect("failed to get SO_ERROR");
        assert!(error.is_none(), "Expected no error, got: {:?}", error);
    }

    /// Binds a loopback socket via `block_on`, which the std bind test shows yields the
    /// `Std` backend.
    fn bind_std() -> UdpSocket {
        block_on(UdpSocket::bind(
            "127.0.0.1:0"
                .parse::<SocketAddr>()
                .expect("failed to parse"),
        ))
        .expect("failed to bind UDP socket")
    }

    /// Binds a loopback socket from inside a Tokio runtime (`Tokio` backend).
    #[cfg(feature = "tokio-net")]
    async fn bind_tokio() -> UdpSocket {
        UdpSocket::bind(
            "127.0.0.1:0"
                .parse::<SocketAddr>()
                .expect("failed to parse"),
        )
        .await
        .expect("failed to bind UDP socket")
    }

    /// Spawns a thread running a non-blocking loopback UDP echo server.
    ///
    /// Returns the thread handle, the server's address, and the flag that stops it; pass
    /// the handle and flag to `stop_echo_server` when the test is done.
    fn echo_server() -> (JoinHandle<()>, SocketAddr, Arc<AtomicBool>) {
        // NOTE(review): purpose of this random delay (up to 1 s) is unclear — the server
        // binds an ephemeral port (`:0`), so it only slows the suite. Consider removing.
        std::thread::sleep(std::time::Duration::from_millis(
            rand::random::<u64>() % 1000,
        ));

        let exit = Arc::new(AtomicBool::new(false));
        let exit_clone = exit.clone();

        let server = std::net::UdpSocket::bind("127.0.0.1:0").expect("failed to bind UDP server");
        // Non-blocking so the loop can re-check the exit flag instead of parking in recv.
        server
            .set_nonblocking(true)
            .expect("failed to set non-blocking mode");
        let addr = server.local_addr().expect("failed to get local address");
        let handle = std::thread::spawn(move || {
            let mut buf = [0; 1024];
            while !exit_clone.load(std::sync::atomic::Ordering::Relaxed) {
                match server.recv_from(&mut buf) {
                    Ok((size, src)) => {
                        println!("Received {} bytes from {}", size, src);
                        if let Err(e) = server.send_to(&buf[..size], src) {
                            eprintln!("Failed to send response: {}", e);
                        }
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        std::thread::sleep(std::time::Duration::from_millis(100));
                        continue;
                    }
                    Err(e) => eprintln!("Failed to receive data: {}", e),
                }
            }
        });
        (handle, addr, exit)
    }

    /// Signals the echo server to exit and waits for its thread, so the server never
    /// outlives the test that started it.
    fn stop_echo_server(handle: JoinHandle<()>, exit: &AtomicBool) {
        exit.store(true, std::sync::atomic::Ordering::Relaxed);
        handle.join().expect("echo server thread panicked");
    }
}