1use std::{
2 io,
3 io::IoSliceMut,
4 mem::{self, MaybeUninit},
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6 os::unix::io::AsRawFd,
7 ptr,
8 sync::atomic::AtomicUsize,
9 task::{Context, Poll},
10};
11
12use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
13use futures_core::ready;
14use socket2::SockRef;
15use tokio::{
16 io::{Interest, ReadBuf},
17 net::ToSocketAddrs,
18};
19
20use super::{cmsg, RecvMeta, UdpState};
21
22#[cfg(target_os = "freebsd")]
23type IpTosTy = libc::c_uchar;
24#[cfg(not(target_os = "freebsd"))]
25type IpTosTy = libc::c_int;
26
27#[derive(Debug)]
32pub struct UdpSocket {
33 io: tokio::net::UdpSocket,
34}
35
36impl AsRawFd for UdpSocket {
37 fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
38 self.io.as_raw_fd()
39 }
40}
41
42impl UdpSocket {
43 pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpSocket> {
45 socket.set_nonblocking(true)?;
46
47 init(SockRef::from(&socket))?;
48 Ok(UdpSocket {
49 io: tokio::net::UdpSocket::from_std(socket)?,
50 })
51 }
52
53 pub fn into_std(self) -> io::Result<std::net::UdpSocket> {
54 self.io.into_std()
55 }
56
57 pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
59 let io = tokio::net::UdpSocket::bind(addr).await?;
60 init(SockRef::from(&io))?;
61 Ok(UdpSocket { io })
62 }
63
64 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
66 self.io.set_broadcast(broadcast)
67 }
68
69 pub async fn connect<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
70 self.io.connect(addrs).await
71 }
72 pub async fn join_multicast_v4(
73 &self,
74 multiaddr: Ipv4Addr,
75 interface: Ipv4Addr,
76 ) -> io::Result<()> {
77 self.io.join_multicast_v4(multiaddr, interface)
78 }
79 pub async fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
80 self.io.join_multicast_v6(multiaddr, interface)
81 }
82 pub async fn leave_multicast_v4(
83 &self,
84 multiaddr: Ipv4Addr,
85 interface: Ipv4Addr,
86 ) -> io::Result<()> {
87 self.io.leave_multicast_v4(multiaddr, interface)
88 }
89 pub async fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
90 self.io.leave_multicast_v6(multiaddr, interface)
91 }
92 pub async fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
93 self.io.set_multicast_loop_v4(on)
94 }
95 pub async fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
96 self.io.set_multicast_loop_v6(on)
97 }
98 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
105 self.io.send_to(buf, target).await
106 }
107 pub fn poll_send_to(
114 &self,
115 cx: &mut Context<'_>,
116 buf: &[u8],
117 target: SocketAddr,
118 ) -> Poll<io::Result<usize>> {
119 self.io.poll_send_to(cx, buf, target)
120 }
121 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
128 self.io.send(buf).await
129 }
130 pub async fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
137 self.io.poll_send(cx, buf)
138 }
139 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
146 self.io.recv_from(buf).await
147 }
148 pub fn poll_recv_from(
155 &self,
156 cx: &mut Context<'_>,
157 buf: &mut ReadBuf<'_>,
158 ) -> Poll<io::Result<SocketAddr>> {
159 self.io.poll_recv_from(cx, buf)
160 }
161 pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
168 self.io.recv(buf).await
169 }
170 pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
177 self.io.poll_recv(cx, buf)
178 }
179
180 pub async fn send_mmsg<B: AsPtr<u8>>(
185 &self,
186 state: &UdpState,
187 transmits: &[Transmit<B>],
188 ) -> Result<usize, io::Error> {
189 let n = loop {
190 self.io.writable().await?;
191 let io = &self.io;
192 match io.try_io(Interest::WRITABLE, || {
193 send(state, SockRef::from(io), transmits)
194 }) {
195 Ok(res) => break res,
196 Err(_would_block) => continue,
197 }
198 };
199 Ok(n)
201 }
202
203 pub async fn send_msg<B: AsPtr<u8>>(
208 &self,
209 state: &UdpState,
210 transmits: Transmit<B>,
211 ) -> io::Result<usize> {
212 let n = loop {
213 self.io.writable().await?;
214 let io = &self.io;
215 match io.try_io(Interest::WRITABLE, || {
216 send_msg(state, SockRef::from(io), &transmits)
217 }) {
218 Ok(res) => break res,
219 Err(_would_block) => continue,
220 }
221 };
222 Ok(n)
223 }
224
225 pub async fn recv_mmsg(
227 &self,
228 bufs: &mut [IoSliceMut<'_>],
229 meta: &mut [RecvMeta],
230 ) -> io::Result<usize> {
231 debug_assert!(!bufs.is_empty());
232 loop {
233 self.io.readable().await?;
234 let io = &self.io;
235 match io.try_io(Interest::READABLE, || recv(SockRef::from(io), bufs, meta)) {
236 Ok(res) => return Ok(res),
237 Err(_would_block) => continue,
238 }
239 }
240 }
241
242 pub async fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
247 let mut iov = IoSliceMut::new(buf);
248 debug_assert!(!iov.is_empty());
249 loop {
250 self.io.readable().await?;
251 let io = &self.io;
252 match io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), &mut iov)) {
253 Ok(res) => return Ok(res),
254 Err(_would_block) => continue,
255 }
256 }
257 }
258
259 pub fn poll_send_mmsg<B: AsPtr<u8>>(
261 &self,
262 state: &UdpState,
263 cx: &mut Context,
264 transmits: &[Transmit<B>],
265 ) -> Poll<io::Result<usize>> {
266 loop {
267 ready!(self.io.poll_send_ready(cx))?;
268 let io = &self.io;
269 if let Ok(res) = io.try_io(Interest::WRITABLE, || {
270 send(state, SockRef::from(io), transmits)
271 }) {
272 return Poll::Ready(Ok(res));
273 }
274 }
275 }
276 pub fn poll_send_msg<B: AsPtr<u8>>(
278 &self,
279 state: &UdpState,
280 cx: &mut Context,
281 transmits: Transmit<B>,
282 ) -> Poll<io::Result<usize>> {
283 loop {
284 ready!(self.io.poll_send_ready(cx))?;
285 let io = &self.io;
286 if let Ok(res) = io.try_io(Interest::WRITABLE, || {
287 send_msg(state, SockRef::from(io), &transmits)
288 }) {
289 return Poll::Ready(Ok(res));
290 }
291 }
292 }
293
294 pub fn poll_recv_msg(
296 &self,
297 cx: &mut Context,
298 buf: &mut IoSliceMut<'_>,
299 ) -> Poll<io::Result<RecvMeta>> {
300 loop {
301 ready!(self.io.poll_recv_ready(cx))?;
302 let io = &self.io;
303 if let Ok(res) = io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), buf)) {
304 return Poll::Ready(Ok(res));
305 }
306 }
307 }
308
309 pub fn poll_recv_mmsg(
311 &self,
312 cx: &mut Context,
313 bufs: &mut [IoSliceMut<'_>],
314 meta: &mut [RecvMeta],
315 ) -> Poll<io::Result<usize>> {
316 debug_assert!(!bufs.is_empty());
317 loop {
318 ready!(self.io.poll_recv_ready(cx))?;
319 let io = &self.io;
320 if let Ok(res) = io.try_io(Interest::READABLE, || recv(SockRef::from(io), bufs, meta)) {
321 return Poll::Ready(Ok(res));
322 }
323 }
324 }
325
326 pub fn local_addr(&self) -> io::Result<SocketAddr> {
328 self.io.local_addr()
329 }
330}
331
332pub mod sync {
333
334 use std::os::unix::prelude::IntoRawFd;
335
336 use super::*;
337
338 #[derive(Debug)]
339 pub struct UdpSocket {
340 io: std::net::UdpSocket,
341 }
342
343 impl AsRawFd for UdpSocket {
344 fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
345 self.io.as_raw_fd()
346 }
347 }
348 impl IntoRawFd for UdpSocket {
349 fn into_raw_fd(self) -> std::os::unix::prelude::RawFd {
350 self.io.into_raw_fd()
351 }
352 }
353
354 impl UdpSocket {
355 pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
357 init(SockRef::from(&socket))?;
358 Ok(Self { io: socket })
359 }
360 pub fn bind<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
362 let io = std::net::UdpSocket::bind(addr)?;
363 init(SockRef::from(&io))?;
364 Ok(Self { io })
365 }
366 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
368 self.io.set_nonblocking(nonblocking)
369 }
370 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
372 self.io.set_broadcast(broadcast)
373 }
374 pub fn connect<A: std::net::ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
375 self.io.connect(addrs)
376 }
377 pub fn join_multicast_v4(
378 &self,
379 multiaddr: Ipv4Addr,
380 interface: Ipv4Addr,
381 ) -> io::Result<()> {
382 self.io.join_multicast_v4(&multiaddr, &interface)
383 }
384 pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
385 self.io.join_multicast_v6(multiaddr, interface)
386 }
387 pub fn leave_multicast_v4(
388 &self,
389 multiaddr: Ipv4Addr,
390 interface: Ipv4Addr,
391 ) -> io::Result<()> {
392 self.io.leave_multicast_v4(&multiaddr, &interface)
393 }
394 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
395 self.io.leave_multicast_v6(multiaddr, interface)
396 }
397 pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
398 self.io.set_multicast_loop_v4(on)
399 }
400 pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
401 self.io.set_multicast_loop_v6(on)
402 }
403 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
410 self.io.send_to(buf, target)
411 }
412 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
419 self.io.send(buf)
420 }
421 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
428 self.io.recv_from(buf)
429 }
430 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
437 self.io.recv(buf)
438 }
439 pub fn send_mmsg<B: AsPtr<u8>>(
444 &self,
445 state: &UdpState,
446 transmits: &[Transmit<B>],
447 ) -> Result<usize, io::Error> {
448 send(state, SockRef::from(&self.io), transmits)
449 }
450 pub fn send_msg<B: AsPtr<u8>>(
455 &self,
456 state: &UdpState,
457 transmits: Transmit<B>,
458 ) -> io::Result<usize> {
459 send_msg(state, SockRef::from(&self.io), &transmits)
460 }
461
462 pub fn recv_mmsg(
464 &self,
465 bufs: &mut [IoSliceMut<'_>],
466 meta: &mut [RecvMeta],
467 ) -> io::Result<usize> {
468 debug_assert!(!bufs.is_empty());
469 recv(SockRef::from(&self.io), bufs, meta)
470 }
471
472 pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
477 let mut iov = IoSliceMut::new(buf);
478 debug_assert!(!iov.is_empty());
479
480 recv_msg(SockRef::from(&self.io), &mut iov)
481 }
482 pub fn local_addr(&self) -> io::Result<SocketAddr> {
484 self.io.local_addr()
485 }
486 }
487}
488
489fn init(io: SockRef<'_>) -> io::Result<()> {
490 let mut cmsg_platform_space = 0;
491 if cfg!(target_os = "linux") {
492 cmsg_platform_space +=
493 unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
494 }
495
496 assert!(
497 CMSG_LEN
498 >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
499 + cmsg_platform_space
500 );
501 assert!(
502 mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
503 "control message buffers will be misaligned"
504 );
505
506 io.set_nonblocking(true)?;
507
508 let addr = io.local_addr()?;
509 let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;
510
511 if is_ipv4 || ((!cfg!(any(target_os = "macos", target_os = "ios"))) && !io.only_v6()?) {
513 let on: libc::c_int = 1;
514 let rc = unsafe {
515 libc::setsockopt(
516 io.as_raw_fd(),
517 libc::IPPROTO_IP,
518 libc::IP_RECVTOS,
519 &on as *const _ as _,
520 mem::size_of_val(&on) as _,
521 )
522 };
523 if rc == -1 {
524 return Err(io::Error::last_os_error());
525 }
526 }
527 #[cfg(target_os = "linux")]
528 {
529 let on: libc::c_int = 1;
531 unsafe {
532 libc::setsockopt(
533 io.as_raw_fd(),
534 libc::SOL_UDP,
535 libc::UDP_GRO,
536 &on as *const _ as _,
537 mem::size_of_val(&on) as _,
538 )
539 };
540
541 let rc = unsafe {
543 libc::setsockopt(
544 io.as_raw_fd(),
545 libc::IPPROTO_IP,
546 libc::IP_MTU_DISCOVER,
547 &libc::IP_PMTUDISC_PROBE as *const _ as _,
548 mem::size_of_val(&libc::IP_PMTUDISC_PROBE) as _,
549 )
550 };
551 if rc == -1 {
552 return Err(io::Error::last_os_error());
553 }
554
555 if is_ipv4 {
556 let on: libc::c_int = 1;
557 let rc = unsafe {
558 libc::setsockopt(
559 io.as_raw_fd(),
560 libc::IPPROTO_IP,
561 libc::IP_PKTINFO,
562 &on as *const _ as _,
563 mem::size_of_val(&on) as _,
564 )
565 };
566 if rc == -1 {
567 return Err(io::Error::last_os_error());
568 }
569 } else {
570 let rc = unsafe {
571 libc::setsockopt(
572 io.as_raw_fd(),
573 libc::IPPROTO_IPV6,
574 libc::IPV6_MTU_DISCOVER,
575 &libc::IP_PMTUDISC_PROBE as *const _ as _,
576 mem::size_of_val(&libc::IP_PMTUDISC_PROBE) as _,
577 )
578 };
579 if rc == -1 {
580 return Err(io::Error::last_os_error());
581 }
582
583 let on: libc::c_int = 1;
584 let rc = unsafe {
585 libc::setsockopt(
586 io.as_raw_fd(),
587 libc::IPPROTO_IPV6,
588 libc::IPV6_RECVPKTINFO,
589 &on as *const _ as _,
590 mem::size_of_val(&on) as _,
591 )
592 };
593 if rc == -1 {
594 return Err(io::Error::last_os_error());
595 }
596 }
597 }
598 if !is_ipv4 {
599 let on: libc::c_int = 1;
600 let rc = unsafe {
601 libc::setsockopt(
602 io.as_raw_fd(),
603 libc::IPPROTO_IPV6,
604 libc::IPV6_RECVTCLASS,
605 &on as *const _ as _,
606 mem::size_of_val(&on) as _,
607 )
608 };
609 if rc == -1 {
610 return Err(io::Error::last_os_error());
611 }
612 }
613 Ok(())
614}
615
616#[cfg(not(any(target_os = "macos", target_os = "ios")))]
617fn send_msg<B: AsPtr<u8>>(
618 state: &UdpState,
619 io: SockRef<'_>,
620 transmit: &Transmit<B>,
621) -> io::Result<usize> {
622 let mut msg_hdr: libc::msghdr = unsafe { mem::zeroed() };
623 let mut iovec: libc::iovec = unsafe { mem::zeroed() };
624 let mut cmsg = cmsg::Aligned([0u8; CMSG_LEN]);
625
626 let addr = socket2::SockAddr::from(transmit.dst);
627 let dst_addr = &addr;
628 prepare_msg(transmit, dst_addr, &mut msg_hdr, &mut iovec, &mut cmsg);
629
630 loop {
631 let n = unsafe { libc::sendmsg(io.as_raw_fd(), &msg_hdr, 0) };
632 if n == -1 {
633 let e = io::Error::last_os_error();
634 match e.kind() {
635 io::ErrorKind::Interrupted => {
636 continue;
638 }
639 io::ErrorKind::WouldBlock => return Err(e),
640 _ => {
641 #[cfg(target_os = "linux")]
645 if e.raw_os_error() == Some(libc::EIO) {
646 if state.max_gso_segments() > 1 {
649 tracing::error!("got EIO, halting segmentation offload");
650 state
651 .max_gso_segments
652 .store(1, std::sync::atomic::Ordering::Relaxed);
653 }
654 }
655
656 return Ok(n as usize);
661 }
662 }
663 }
664 return Ok(n as usize);
665 }
666}
667
668#[cfg(not(any(target_os = "macos", target_os = "ios")))]
669fn send<B: AsPtr<u8>>(
670 state: &UdpState,
671 io: SockRef<'_>,
672 transmits: &[Transmit<B>],
673) -> io::Result<usize> {
674 let mut msgs: [libc::mmsghdr; BATCH_SIZE] = unsafe { mem::zeroed() };
675 let mut iovecs: [libc::iovec; BATCH_SIZE] = unsafe { mem::zeroed() };
676 let mut cmsgs = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE];
677 let mut addrs: [MaybeUninit<socket2::SockAddr>; BATCH_SIZE] =
684 unsafe { MaybeUninit::uninit().assume_init() };
685 for (i, transmit) in transmits.iter().enumerate().take(BATCH_SIZE) {
686 let dst_addr = unsafe {
687 std::ptr::write(addrs[i].as_mut_ptr(), socket2::SockAddr::from(transmit.dst));
688 &*addrs[i].as_ptr()
689 };
690 prepare_msg(
691 transmit,
692 dst_addr,
693 &mut msgs[i].msg_hdr,
694 &mut iovecs[i],
695 &mut cmsgs[i],
696 );
697 }
698 let num_transmits = transmits.len().min(BATCH_SIZE);
699
700 loop {
701 let n =
702 unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as u32, 0) };
703 if n == -1 {
704 let e = io::Error::last_os_error();
705 match e.kind() {
706 io::ErrorKind::Interrupted => {
707 continue;
709 }
710 io::ErrorKind::WouldBlock => return Err(e),
711 _ => {
712 #[cfg(target_os = "linux")]
716 if e.raw_os_error() == Some(libc::EIO) {
717 if state.max_gso_segments() > 1 {
720 tracing::error!("got EIO, halting segmentation offload");
721 state
722 .max_gso_segments
723 .store(1, std::sync::atomic::Ordering::Relaxed);
724 }
725 }
726
727 return Ok(num_transmits.min(1));
732 }
733 }
734 }
735 return Ok(n as usize);
736 }
737}
738
/// macOS/iOS fallback: sends `transmits` one at a time with `sendmsg(2)`.
///
/// The signature matches the batched Linux `send`, which is how every caller in this file
/// invokes it. The previous version took an extra `&mut Instant`, used a non-generic
/// `Transmit` and read a nonexistent `destination` field, so it could not compile.
///
/// Returns the number of transmits consumed.
///
/// # Errors
/// * `EINTR` retries the same transmit.
/// * `WouldBlock` returns `Ok(sent)` if something was already sent, otherwise the error.
/// * Any other error skips (drops) the failing transmit.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn send<B: AsPtr<u8>>(
    _state: &UdpState,
    io: SockRef<'_>,
    transmits: &[Transmit<B>],
) -> io::Result<usize> {
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);
    let mut sent = 0;
    while sent < transmits.len() {
        let transmit = &transmits[sent];
        // `hdr` points into `addr`, which lives until the end of this iteration.
        let addr = socket2::SockAddr::from(transmit.dst);
        prepare_msg(transmit, &addr, &mut hdr, &mut iov, &mut ctrl);
        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            match e.kind() {
                // Retry the same transmit.
                io::ErrorKind::Interrupted => {}
                io::ErrorKind::WouldBlock if sent != 0 => return Ok(sent),
                io::ErrorKind::WouldBlock => return Err(e),
                // Drop the failing transmit and continue with the next one.
                _ => sent += 1,
            }
        } else {
            sent += 1;
        }
    }
    Ok(sent)
}
772
773#[cfg(not(any(target_os = "macos", target_os = "ios")))]
774fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
775 let mut names = [MaybeUninit::<libc::sockaddr_storage>::uninit(); BATCH_SIZE];
776 let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE];
777 let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() };
778 let max_msg_count = bufs.len().min(BATCH_SIZE);
779 for i in 0..max_msg_count {
780 prepare_recv(
781 &mut bufs[i],
782 &mut names[i],
783 &mut ctrls[i],
784 &mut hdrs[i].msg_hdr,
785 );
786 }
787 let msg_count = loop {
788 let n = unsafe {
789 libc::recvmmsg(
790 io.as_raw_fd(),
791 hdrs.as_mut_ptr(),
792 bufs.len().min(BATCH_SIZE) as libc::c_uint,
793 0,
794 ptr::null_mut(),
795 )
796 };
797 if n == -1 {
798 let e = io::Error::last_os_error();
799 if e.kind() == io::ErrorKind::Interrupted {
800 continue;
801 }
802 return Err(e);
803 }
804 break n;
805 };
806 for i in 0..(msg_count as usize) {
807 meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize);
808 }
809 Ok(msg_count as usize)
810}
811
812#[cfg(not(any(target_os = "macos", target_os = "ios")))]
813fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
814 let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
815 let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
816 let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
817
818 prepare_recv(bufs, &mut name, &mut ctrl, &mut hdr);
819
820 let n = loop {
821 let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
822 if n == -1 {
823 let e = io::Error::last_os_error();
824 if e.kind() == io::ErrorKind::Interrupted {
825 continue;
826 }
827 return Err(e);
828 }
829 if hdr.msg_flags & libc::MSG_TRUNC != 0 {
830 continue;
831 }
832 break n;
833 };
834 Ok(decode_recv(&name, &hdr, n as usize))
835}
836
/// macOS/iOS fallback: receives one datagram into `bufs[0]` with `recvmsg(2)`.
///
/// Fills `meta[0]` and returns `Ok(1)`. Truncated datagrams (`MSG_TRUNC`) are discarded.
/// The header is re-prepared before each attempt, because the kernel overwrites
/// `msg_namelen`, `msg_controllen` and `msg_flags`.
///
/// # Errors
/// `EINTR` is retried; any other error (including `WouldBlock`) is returned.
///
/// # Panics
/// If `bufs` or `meta` is empty.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
    // Zero-initialised so `decode_recv` never reads uninitialised address bytes.
    let mut name = MaybeUninit::<libc::sockaddr_storage>::zeroed();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
    let n = loop {
        prepare_recv(&mut bufs[0], &mut name, &mut ctrl, &mut hdr);
        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    meta[0] = decode_recv(&name, &hdr, n as usize);
    Ok(1)
}
860
861pub fn udp_state() -> UdpState {
863 UdpState {
864 max_gso_segments: AtomicUsize::new(gso::max_gso_segments()),
865 gro_segments: gro::gro_segments(),
866 }
867}
868
/// Size in bytes of every control-message buffer used for sending and receiving.
///
/// `init` asserts at runtime that this is large enough for the control messages it enables.
const CMSG_LEN: usize = 88;
870
871fn prepare_msg<B: AsPtr<u8>>(
872 transmit: &Transmit<B>,
873 dst_addr: &socket2::SockAddr,
874 hdr: &mut libc::msghdr,
875 iov: &mut libc::iovec,
876 ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>,
877) {
878 iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _;
879 iov.iov_len = transmit.contents.len();
880
881 let name = dst_addr.as_ptr() as *mut libc::c_void;
887 let namelen = dst_addr.len();
888 hdr.msg_name = name as *mut _;
889 hdr.msg_namelen = namelen;
890 hdr.msg_iov = iov;
891 hdr.msg_iovlen = 1;
892
893 hdr.msg_control = ctrl.0.as_mut_ptr() as _;
894 hdr.msg_controllen = CMSG_LEN as _;
895 let mut encoder = unsafe { cmsg::Encoder::new(hdr) };
896 let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int);
897 if transmit.dst.is_ipv4() {
898 encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy);
899 } else {
900 encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn);
901 }
902
903 if let Some(segment_size) = transmit.segment_size {
904 gso::set_segment_size(&mut encoder, segment_size as u16);
905 }
906
907 if let Some(ip) = &transmit.src {
908 if cfg!(target_os = "linux") {
909 match ip {
910 Source::Ip(IpAddr::V4(v4)) => {
911 let pktinfo = libc::in_pktinfo {
912 ipi_ifindex: 0,
913 ipi_spec_dst: libc::in_addr {
914 s_addr: u32::from_ne_bytes(v4.octets()),
915 },
916 ipi_addr: libc::in_addr { s_addr: 0 },
917 };
918 encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
919 }
920 Source::Ip(IpAddr::V6(v6)) => {
921 let pktinfo = libc::in6_pktinfo {
922 ipi6_ifindex: 0,
923 ipi6_addr: libc::in6_addr {
924 s6_addr: v6.octets(),
925 },
926 };
927 encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
928 }
929 Source::Interface(i) => {
930 let pktinfo = libc::in_pktinfo {
931 ipi_ifindex: *i as i32,
932 ipi_spec_dst: libc::in_addr { s_addr: 0 },
933 ipi_addr: libc::in_addr { s_addr: 0 },
934 };
935 encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
936 }
937 Source::InterfaceV6(i, ip) => {
938 let pktinfo = libc::in6_pktinfo {
939 ipi6_ifindex: *i,
940 ipi6_addr: libc::in6_addr {
941 s6_addr: ip.octets(),
942 },
943 };
944 encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
945 }
946 }
947 }
948 }
949
950 encoder.finish();
951}
952
953fn prepare_recv(
954 buf: &mut IoSliceMut,
955 name: &mut MaybeUninit<libc::sockaddr_storage>,
956 ctrl: &mut cmsg::Aligned<MaybeUninit<[u8; CMSG_LEN]>>,
957 hdr: &mut libc::msghdr,
958) {
959 hdr.msg_name = name.as_mut_ptr() as _;
960 hdr.msg_namelen = mem::size_of::<libc::sockaddr_storage>() as _;
961 hdr.msg_iov = buf as *mut IoSliceMut as *mut libc::iovec;
962 hdr.msg_iovlen = 1;
963 hdr.msg_control = ctrl.0.as_mut_ptr() as _;
964 hdr.msg_controllen = CMSG_LEN as _;
965 hdr.msg_flags = 0;
966}
967
968fn decode_recv(
969 name: &MaybeUninit<libc::sockaddr_storage>,
970 hdr: &libc::msghdr,
971 len: usize,
972) -> RecvMeta {
973 let name = unsafe { name.assume_init() };
974 let mut ecn_bits = 0;
975 let mut dst_ip = None;
976 let mut dst_local_ip = None;
977 let mut ifindex = 0;
978 #[allow(unused_mut)] let mut stride = len;
980
981 let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
982 for cmsg in cmsg_iter {
983 match (cmsg.cmsg_level, cmsg.cmsg_type) {
984 (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
986 ecn_bits = cmsg::decode::<u8>(cmsg);
987 },
988 (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
989 if cfg!(target_os = "macos")
992 && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
993 {
994 ecn_bits = cmsg::decode::<u8>(cmsg);
995 } else {
996 ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
997 }
998 },
999 (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
1000 let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo>(cmsg) };
1001 dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
1002 pktinfo.ipi_addr.s_addr.to_ne_bytes(),
1003 )));
1004 dst_local_ip = Some(IpAddr::V4(Ipv4Addr::from(
1005 pktinfo.ipi_spec_dst.s_addr.to_ne_bytes(),
1006 )));
1007 ifindex = pktinfo.ipi_ifindex as _;
1008 }
1009 (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
1010 let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo>(cmsg) };
1011 dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
1012 ifindex = pktinfo.ipi6_ifindex;
1013 }
1014 #[cfg(target_os = "linux")]
1015 (libc::SOL_UDP, libc::UDP_GRO) => unsafe {
1016 stride = cmsg::decode::<libc::c_int>(cmsg) as usize;
1017 },
1018 _ => {}
1019 }
1020 }
1021
1022 let addr = match libc::c_int::from(name.ss_family) {
1023 libc::AF_INET => {
1024 let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
1026 let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
1027 let port = u16::from_be(addr.sin_port);
1028 SocketAddr::V4(SocketAddrV4::new(ip, port))
1029 }
1030 libc::AF_INET6 => {
1031 let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
1033 let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
1034 let port = u16::from_be(addr.sin6_port);
1035 SocketAddr::V6(SocketAddrV6::new(
1036 ip,
1037 port,
1038 addr.sin6_flowinfo,
1039 addr.sin6_scope_id,
1040 ))
1041 }
1042 _ => unreachable!(),
1043 };
1044
1045 RecvMeta {
1046 len,
1047 stride,
1048 addr,
1049 ecn: EcnCodepoint::from_bits(ecn_bits),
1050 dst_ip,
1051 dst_local_ip,
1052 ifindex,
1053 }
1054}
1055
/// Maximum number of datagrams handled per `sendmmsg`/`recvmmsg` call.
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
pub const BATCH_SIZE: usize = 32;

// macOS/iOS: `recv` uses plain `recvmsg` and fills a single slot per call.
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub const BATCH_SIZE: usize = 1;
1062
1063#[cfg(target_os = "linux")]
1064mod gso {
1065 use super::*;
1066
1067 pub fn max_gso_segments() -> usize {
1070 const GSO_SIZE: libc::c_int = 1500;
1071
1072 let socket = match std::net::UdpSocket::bind("[::]:0") {
1073 Ok(socket) => socket,
1074 Err(_) => return 1,
1075 };
1076
1077 let rc = unsafe {
1078 libc::setsockopt(
1079 socket.as_raw_fd(),
1080 libc::SOL_UDP,
1081 libc::UDP_SEGMENT,
1082 &GSO_SIZE as *const _ as _,
1083 mem::size_of_val(&GSO_SIZE) as _,
1084 )
1085 };
1086
1087 if rc != -1 {
1088 64
1091 } else {
1092 1
1093 }
1094 }
1095
1096 pub fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) {
1097 encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size);
1098 }
1099}
1100
#[cfg(not(target_os = "linux"))]
mod gso {
    use super::*;

    /// GSO is never used off Linux: each send carries a single segment.
    pub fn max_gso_segments() -> usize {
        1
    }

    /// Unsupported off Linux; always panics.
    ///
    /// # Panics
    /// Always.
    pub fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) {
        panic!("Setting a segment size is not supported on current platform");
    }
}
1113
1114#[cfg(target_os = "linux")]
1115mod gro {
1116 use super::*;
1117
1118 pub fn gro_segments() -> usize {
1119 let socket = match std::net::UdpSocket::bind("[::]:0") {
1120 Ok(socket) => socket,
1121 Err(_) => return 1,
1122 };
1123
1124 let on: libc::c_int = 1;
1125 let rc = unsafe {
1126 libc::setsockopt(
1127 socket.as_raw_fd(),
1128 libc::SOL_UDP,
1129 libc::UDP_GRO,
1130 &on as *const _ as _,
1131 mem::size_of_val(&on) as _,
1132 )
1133 };
1134
1135 if rc != -1 {
1136 64
1144 } else {
1145 1
1146 }
1147 }
1148}
1149
#[cfg(not(target_os = "linux"))]
mod gro {
    /// GRO is never enabled off Linux, so every received buffer holds one segment.
    pub fn gro_segments() -> usize {
        1
    }
}