dndx_forked_unix_udp_sock/
unix.rs

1use std::{
2    io,
3    io::IoSliceMut,
4    mem::{self, MaybeUninit},
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6    os::unix::io::AsRawFd,
7    ptr,
8    sync::atomic::AtomicUsize,
9    task::{Context, Poll},
10};
11
12use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
13use futures_core::ready;
14use socket2::SockRef;
15use tokio::{
16    io::{Interest, ReadBuf},
17    net::ToSocketAddrs,
18};
19
20use super::{cmsg, RecvMeta, UdpState};
21
22#[cfg(target_os = "freebsd")]
23type IpTosTy = libc::c_uchar;
24#[cfg(not(target_os = "freebsd"))]
25type IpTosTy = libc::c_int;
26
27/// Tokio-compatible UDP socket with some useful specializations.
28///
29/// Unlike a standard tokio UDP socket, this allows ECN bits to be read and written on some
30/// platforms.
31#[derive(Debug)]
32pub struct UdpSocket {
33    io: tokio::net::UdpSocket,
34}
35
36impl AsRawFd for UdpSocket {
37    fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
38        self.io.as_raw_fd()
39    }
40}
41
42impl UdpSocket {
43    /// Creates a new UDP socket from a previously created `std::net::UdpSocket`
44    pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpSocket> {
45        socket.set_nonblocking(true)?;
46
47        init(SockRef::from(&socket))?;
48        Ok(UdpSocket {
49            io: tokio::net::UdpSocket::from_std(socket)?,
50        })
51    }
52
53    pub fn into_std(self) -> io::Result<std::net::UdpSocket> {
54        self.io.into_std()
55    }
56
57    /// create a new UDP socket and attempt to bind to `addr`
58    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
59        let io = tokio::net::UdpSocket::bind(addr).await?;
60        init(SockRef::from(&io))?;
61        Ok(UdpSocket { io })
62    }
63
64    /// sets the value of SO_BROADCAST for this socket
65    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
66        self.io.set_broadcast(broadcast)
67    }
68
69    pub async fn connect<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
70        self.io.connect(addrs).await
71    }
72    pub async fn join_multicast_v4(
73        &self,
74        multiaddr: Ipv4Addr,
75        interface: Ipv4Addr,
76    ) -> io::Result<()> {
77        self.io.join_multicast_v4(multiaddr, interface)
78    }
79    pub async fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
80        self.io.join_multicast_v6(multiaddr, interface)
81    }
82    pub async fn leave_multicast_v4(
83        &self,
84        multiaddr: Ipv4Addr,
85        interface: Ipv4Addr,
86    ) -> io::Result<()> {
87        self.io.leave_multicast_v4(multiaddr, interface)
88    }
89    pub async fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
90        self.io.leave_multicast_v6(multiaddr, interface)
91    }
92    pub async fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
93        self.io.set_multicast_loop_v4(on)
94    }
95    pub async fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
96        self.io.set_multicast_loop_v6(on)
97    }
98    /// Sends data on the socket to the given address. On success, returns the
99    /// number of bytes written.
100    ///
101    /// calls underlying tokio [`send_to`]
102    ///
103    /// [`send_to`]: method@tokio::net::UdpSocket::send_to
104    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
105        self.io.send_to(buf, target).await
106    }
107    /// Sends data on the socket to the given address. On success, returns the
108    /// number of bytes written.
109    ///
110    /// calls underlying tokio [`poll_send_to`]
111    ///
112    /// [`poll_send_to`]: method@tokio::net::UdpSocket::poll_send_to
113    pub fn poll_send_to(
114        &self,
115        cx: &mut Context<'_>,
116        buf: &[u8],
117        target: SocketAddr,
118    ) -> Poll<io::Result<usize>> {
119        self.io.poll_send_to(cx, buf, target)
120    }
121    /// Sends data on the socket to the remote address that the socket is
122    /// connected to.
123    ///
124    /// See tokio [`send`]
125    ///
126    /// [`send`]: method@tokio::net::UdpSocket::send
127    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
128        self.io.send(buf).await
129    }
130    /// Sends data on the socket to the remote address that the socket is
131    /// connected to.
132    ///
133    /// See tokio [`poll_send`]
134    ///
135    /// [`poll_send`]: method@tokio::net::UdpSocket::poll_send
136    pub async fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
137        self.io.poll_send(cx, buf)
138    }
139    /// Receives a single datagram message on the socket. On success, returns
140    /// the number of bytes read and the origin.
141    ///
142    /// See tokio [`recv_from`]
143    ///
144    /// [`recv_from`]: method@tokio::net::UdpSocket::recv_from
145    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
146        self.io.recv_from(buf).await
147    }
148    /// Receives a single datagram message on the socket. On success, returns
149    /// the number of bytes read and the origin.
150    ///
151    /// See tokio [`poll_recv_from`]
152    ///
153    /// [`poll_recv_from`]: method@tokio::net::UdpSocket::poll_recv_from
154    pub fn poll_recv_from(
155        &self,
156        cx: &mut Context<'_>,
157        buf: &mut ReadBuf<'_>,
158    ) -> Poll<io::Result<SocketAddr>> {
159        self.io.poll_recv_from(cx, buf)
160    }
161    /// Receives a single datagram message on the socket from the remote address
162    /// to which it is connected. On success, returns the number of bytes read.
163    ///
164    /// See tokio [`recv`]
165    ///
166    /// [`recv`]: method@tokio::net::UdpSocket::recv
167    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
168        self.io.recv(buf).await
169    }
170    /// Receives a single datagram message on the socket from the remote address
171    /// to which it is connected. On success, returns the number of bytes read.
172    ///
173    /// See tokio [`poll_recv`]
174    ///
175    /// [`poll_recv`]: method@tokio::net::UdpSocket::poll_recv
176    pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
177        self.io.poll_recv(cx, buf)
178    }
179
180    /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
181    /// `transmits` with information on the data and metadata about outgoing packets.
182    ///
183    /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
184    pub async fn send_mmsg<B: AsPtr<u8>>(
185        &self,
186        state: &UdpState,
187        transmits: &[Transmit<B>],
188    ) -> Result<usize, io::Error> {
189        let n = loop {
190            self.io.writable().await?;
191            let io = &self.io;
192            match io.try_io(Interest::WRITABLE, || {
193                send(state, SockRef::from(io), transmits)
194            }) {
195                Ok(res) => break res,
196                Err(_would_block) => continue,
197            }
198        };
199        // if n == transmits.len() {}
200        Ok(n)
201    }
202
203    /// Calls syscall [`sendmsg`]. With a given `state` configured GSO and
204    /// `transmit` with information on the data and metadata about outgoing packet.
205    ///
206    /// [`sendmsg`]: https://linux.die.net/man/2/sendmsg
207    pub async fn send_msg<B: AsPtr<u8>>(
208        &self,
209        state: &UdpState,
210        transmits: Transmit<B>,
211    ) -> io::Result<usize> {
212        let n = loop {
213            self.io.writable().await?;
214            let io = &self.io;
215            match io.try_io(Interest::WRITABLE, || {
216                send_msg(state, SockRef::from(io), &transmits)
217            }) {
218                Ok(res) => break res,
219                Err(_would_block) => continue,
220            }
221        };
222        Ok(n)
223    }
224
225    /// async version of `recvmmsg`
226    pub async fn recv_mmsg(
227        &self,
228        bufs: &mut [IoSliceMut<'_>],
229        meta: &mut [RecvMeta],
230    ) -> io::Result<usize> {
231        debug_assert!(!bufs.is_empty());
232        loop {
233            self.io.readable().await?;
234            let io = &self.io;
235            match io.try_io(Interest::READABLE, || recv(SockRef::from(io), bufs, meta)) {
236                Ok(res) => return Ok(res),
237                Err(_would_block) => continue,
238            }
239        }
240    }
241
242    /// `recv_msg` is similar to `recv_from` but returns extra information
243    /// about the packet in [`RecvMeta`].
244    ///
245    /// [`RecvMeta`]: crate::RecvMeta
246    pub async fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
247        let mut iov = IoSliceMut::new(buf);
248        debug_assert!(!iov.is_empty());
249        loop {
250            self.io.readable().await?;
251            let io = &self.io;
252            match io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), &mut iov)) {
253                Ok(res) => return Ok(res),
254                Err(_would_block) => continue,
255            }
256        }
257    }
258
259    /// calls `sendmmsg`
260    pub fn poll_send_mmsg<B: AsPtr<u8>>(
261        &self,
262        state: &UdpState,
263        cx: &mut Context,
264        transmits: &[Transmit<B>],
265    ) -> Poll<io::Result<usize>> {
266        loop {
267            ready!(self.io.poll_send_ready(cx))?;
268            let io = &self.io;
269            if let Ok(res) = io.try_io(Interest::WRITABLE, || {
270                send(state, SockRef::from(io), transmits)
271            }) {
272                return Poll::Ready(Ok(res));
273            }
274        }
275    }
276    /// calls `sendmsg`
277    pub fn poll_send_msg<B: AsPtr<u8>>(
278        &self,
279        state: &UdpState,
280        cx: &mut Context,
281        transmits: Transmit<B>,
282    ) -> Poll<io::Result<usize>> {
283        loop {
284            ready!(self.io.poll_send_ready(cx))?;
285            let io = &self.io;
286            if let Ok(res) = io.try_io(Interest::WRITABLE, || {
287                send_msg(state, SockRef::from(io), &transmits)
288            }) {
289                return Poll::Ready(Ok(res));
290            }
291        }
292    }
293
294    /// calls `recvmsg`
295    pub fn poll_recv_msg(
296        &self,
297        cx: &mut Context,
298        buf: &mut IoSliceMut<'_>,
299    ) -> Poll<io::Result<RecvMeta>> {
300        loop {
301            ready!(self.io.poll_recv_ready(cx))?;
302            let io = &self.io;
303            if let Ok(res) = io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), buf)) {
304                return Poll::Ready(Ok(res));
305            }
306        }
307    }
308
309    /// calls `recvmmsg`
310    pub fn poll_recv_mmsg(
311        &self,
312        cx: &mut Context,
313        bufs: &mut [IoSliceMut<'_>],
314        meta: &mut [RecvMeta],
315    ) -> Poll<io::Result<usize>> {
316        debug_assert!(!bufs.is_empty());
317        loop {
318            ready!(self.io.poll_recv_ready(cx))?;
319            let io = &self.io;
320            if let Ok(res) = io.try_io(Interest::READABLE, || recv(SockRef::from(io), bufs, meta)) {
321                return Poll::Ready(Ok(res));
322            }
323        }
324    }
325
326    /// Returns local address this socket is bound to.
327    pub fn local_addr(&self) -> io::Result<SocketAddr> {
328        self.io.local_addr()
329    }
330}
331
332pub mod sync {
333
334    use std::os::unix::prelude::IntoRawFd;
335
336    use super::*;
337
338    #[derive(Debug)]
339    pub struct UdpSocket {
340        io: std::net::UdpSocket,
341    }
342
343    impl AsRawFd for UdpSocket {
344        fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
345            self.io.as_raw_fd()
346        }
347    }
348    impl IntoRawFd for UdpSocket {
349        fn into_raw_fd(self) -> std::os::unix::prelude::RawFd {
350            self.io.into_raw_fd()
351        }
352    }
353
354    impl UdpSocket {
355        /// Creates a new UDP socket from a previously created `std::net::UdpSocket`
356        pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
357            init(SockRef::from(&socket))?;
358            Ok(Self { io: socket })
359        }
360        /// create a new UDP socket and attempt to bind to `addr`
361        pub fn bind<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
362            let io = std::net::UdpSocket::bind(addr)?;
363            init(SockRef::from(&io))?;
364            Ok(Self { io })
365        }
366        /// sets nonblocking mode
367        pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
368            self.io.set_nonblocking(nonblocking)
369        }
370        /// sets the value of SO_BROADCAST for this socket
371        pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
372            self.io.set_broadcast(broadcast)
373        }
374        pub fn connect<A: std::net::ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
375            self.io.connect(addrs)
376        }
377        pub fn join_multicast_v4(
378            &self,
379            multiaddr: Ipv4Addr,
380            interface: Ipv4Addr,
381        ) -> io::Result<()> {
382            self.io.join_multicast_v4(&multiaddr, &interface)
383        }
384        pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
385            self.io.join_multicast_v6(multiaddr, interface)
386        }
387        pub fn leave_multicast_v4(
388            &self,
389            multiaddr: Ipv4Addr,
390            interface: Ipv4Addr,
391        ) -> io::Result<()> {
392            self.io.leave_multicast_v4(&multiaddr, &interface)
393        }
394        pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
395            self.io.leave_multicast_v6(multiaddr, interface)
396        }
397        pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
398            self.io.set_multicast_loop_v4(on)
399        }
400        pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
401            self.io.set_multicast_loop_v6(on)
402        }
403        /// Sends data on the socket to the given address. On success, returns the
404        /// number of bytes written.
405        ///
406        /// calls underlying tokio [`send_to`]
407        ///
408        /// [`send_to`]: method@tokio::net::UdpSocket::send_to
409        pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
410            self.io.send_to(buf, target)
411        }
412        /// Sends data on the socket to the remote address that the socket is
413        /// connected to.
414        ///
415        /// See tokio [`send`]
416        ///
417        /// [`send`]: method@tokio::net::UdpSocket::send
418        pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
419            self.io.send(buf)
420        }
421        /// Receives a single datagram message on the socket. On success, returns
422        /// the number of bytes read and the origin.
423        ///
424        /// See tokio [`recv_from`]
425        ///
426        /// [`recv_from`]: method@tokio::net::UdpSocket::recv_from
427        pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
428            self.io.recv_from(buf)
429        }
430        /// Receives a single datagram message on the socket from the remote address
431        /// to which it is connected. On success, returns the number of bytes read.
432        ///
433        /// See tokio [`recv`]
434        ///
435        /// [`recv`]: method@tokio::net::UdpSocket::recv
436        pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
437            self.io.recv(buf)
438        }
439        /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
440        /// `transmits` with information on the data and metadata about outgoing packets.
441        ///
442        /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
443        pub fn send_mmsg<B: AsPtr<u8>>(
444            &self,
445            state: &UdpState,
446            transmits: &[Transmit<B>],
447        ) -> Result<usize, io::Error> {
448            send(state, SockRef::from(&self.io), transmits)
449        }
450        /// Calls syscall [`sendmsg`]. With a given `state` configured GSO and
451        /// `transmit` with information on the data and metadata about outgoing packet.
452        ///
453        /// [`sendmsg`]: https://linux.die.net/man/2/sendmsg
454        pub fn send_msg<B: AsPtr<u8>>(
455            &self,
456            state: &UdpState,
457            transmits: Transmit<B>,
458        ) -> io::Result<usize> {
459            send_msg(state, SockRef::from(&self.io), &transmits)
460        }
461
462        /// async version of `recvmmsg`
463        pub fn recv_mmsg(
464            &self,
465            bufs: &mut [IoSliceMut<'_>],
466            meta: &mut [RecvMeta],
467        ) -> io::Result<usize> {
468            debug_assert!(!bufs.is_empty());
469            recv(SockRef::from(&self.io), bufs, meta)
470        }
471
472        /// `recv_msg` is similar to `recv_from` but returns extra information
473        /// about the packet in [`RecvMeta`].
474        ///
475        /// [`RecvMeta`]: crate::RecvMeta
476        pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
477            let mut iov = IoSliceMut::new(buf);
478            debug_assert!(!iov.is_empty());
479
480            recv_msg(SockRef::from(&self.io), &mut iov)
481        }
482        /// Returns local address this socket is bound to.
483        pub fn local_addr(&self) -> io::Result<SocketAddr> {
484            self.io.local_addr()
485        }
486    }
487}
488
489fn init(io: SockRef<'_>) -> io::Result<()> {
490    let mut cmsg_platform_space = 0;
491    if cfg!(target_os = "linux") {
492        cmsg_platform_space +=
493            unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
494    }
495
496    assert!(
497        CMSG_LEN
498            >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
499                + cmsg_platform_space
500    );
501    assert!(
502        mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
503        "control message buffers will be misaligned"
504    );
505
506    io.set_nonblocking(true)?;
507
508    let addr = io.local_addr()?;
509    let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;
510
511    // macos and ios do not support IP_RECVTOS on dual-stack sockets :(
512    if is_ipv4 || ((!cfg!(any(target_os = "macos", target_os = "ios"))) && !io.only_v6()?) {
513        let on: libc::c_int = 1;
514        let rc = unsafe {
515            libc::setsockopt(
516                io.as_raw_fd(),
517                libc::IPPROTO_IP,
518                libc::IP_RECVTOS,
519                &on as *const _ as _,
520                mem::size_of_val(&on) as _,
521            )
522        };
523        if rc == -1 {
524            return Err(io::Error::last_os_error());
525        }
526    }
527    #[cfg(target_os = "linux")]
528    {
529        // opportunistically try to enable GRO. See gro::gro_segments().
530        let on: libc::c_int = 1;
531        unsafe {
532            libc::setsockopt(
533                io.as_raw_fd(),
534                libc::SOL_UDP,
535                libc::UDP_GRO,
536                &on as *const _ as _,
537                mem::size_of_val(&on) as _,
538            )
539        };
540
541        // Forbid IPv4 fragmentation. Set even for IPv6 to account for IPv6 mapped IPv4 addresses.
542        let rc = unsafe {
543            libc::setsockopt(
544                io.as_raw_fd(),
545                libc::IPPROTO_IP,
546                libc::IP_MTU_DISCOVER,
547                &libc::IP_PMTUDISC_PROBE as *const _ as _,
548                mem::size_of_val(&libc::IP_PMTUDISC_PROBE) as _,
549            )
550        };
551        if rc == -1 {
552            return Err(io::Error::last_os_error());
553        }
554
555        if is_ipv4 {
556            let on: libc::c_int = 1;
557            let rc = unsafe {
558                libc::setsockopt(
559                    io.as_raw_fd(),
560                    libc::IPPROTO_IP,
561                    libc::IP_PKTINFO,
562                    &on as *const _ as _,
563                    mem::size_of_val(&on) as _,
564                )
565            };
566            if rc == -1 {
567                return Err(io::Error::last_os_error());
568            }
569        } else {
570            let rc = unsafe {
571                libc::setsockopt(
572                    io.as_raw_fd(),
573                    libc::IPPROTO_IPV6,
574                    libc::IPV6_MTU_DISCOVER,
575                    &libc::IP_PMTUDISC_PROBE as *const _ as _,
576                    mem::size_of_val(&libc::IP_PMTUDISC_PROBE) as _,
577                )
578            };
579            if rc == -1 {
580                return Err(io::Error::last_os_error());
581            }
582
583            let on: libc::c_int = 1;
584            let rc = unsafe {
585                libc::setsockopt(
586                    io.as_raw_fd(),
587                    libc::IPPROTO_IPV6,
588                    libc::IPV6_RECVPKTINFO,
589                    &on as *const _ as _,
590                    mem::size_of_val(&on) as _,
591                )
592            };
593            if rc == -1 {
594                return Err(io::Error::last_os_error());
595            }
596        }
597    }
598    if !is_ipv4 {
599        let on: libc::c_int = 1;
600        let rc = unsafe {
601            libc::setsockopt(
602                io.as_raw_fd(),
603                libc::IPPROTO_IPV6,
604                libc::IPV6_RECVTCLASS,
605                &on as *const _ as _,
606                mem::size_of_val(&on) as _,
607            )
608        };
609        if rc == -1 {
610            return Err(io::Error::last_os_error());
611        }
612    }
613    Ok(())
614}
615
616#[cfg(not(any(target_os = "macos", target_os = "ios")))]
617fn send_msg<B: AsPtr<u8>>(
618    state: &UdpState,
619    io: SockRef<'_>,
620    transmit: &Transmit<B>,
621) -> io::Result<usize> {
622    let mut msg_hdr: libc::msghdr = unsafe { mem::zeroed() };
623    let mut iovec: libc::iovec = unsafe { mem::zeroed() };
624    let mut cmsg = cmsg::Aligned([0u8; CMSG_LEN]);
625
626    let addr = socket2::SockAddr::from(transmit.dst);
627    let dst_addr = &addr;
628    prepare_msg(transmit, dst_addr, &mut msg_hdr, &mut iovec, &mut cmsg);
629
630    loop {
631        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &msg_hdr, 0) };
632        if n == -1 {
633            let e = io::Error::last_os_error();
634            match e.kind() {
635                io::ErrorKind::Interrupted => {
636                    // Retry the transmission
637                    continue;
638                }
639                io::ErrorKind::WouldBlock => return Err(e),
640                _ => {
641                    // Some network adapters do not support GSO. Unfortunately, Linux offers no easy way
642                    // for us to detect this short of an I/O error when we try to actually send
643                    // datagrams using it.
644                    #[cfg(target_os = "linux")]
645                    if e.raw_os_error() == Some(libc::EIO) {
646                        // Prevent new transmits from being scheduled using GSO. Existing GSO transmits
647                        // may already be in the pipeline, so we need to tolerate additional failures.
648                        if state.max_gso_segments() > 1 {
649                            tracing::error!("got EIO, halting segmentation offload");
650                            state
651                                .max_gso_segments
652                                .store(1, std::sync::atomic::Ordering::Relaxed);
653                        }
654                    }
655
656                    // The ERRORS section in https://man7.org/linux/man-pages/man2/sendmmsg.2.html
657                    // describes that errors will only be returned if no message could be transmitted
658                    // at all. Therefore drop the first (problematic) message,
659                    // and retry the remaining ones.
660                    return Ok(n as usize);
661                }
662            }
663        }
664        return Ok(n as usize);
665    }
666}
667
668#[cfg(not(any(target_os = "macos", target_os = "ios")))]
669fn send<B: AsPtr<u8>>(
670    state: &UdpState,
671    io: SockRef<'_>,
672    transmits: &[Transmit<B>],
673) -> io::Result<usize> {
674    let mut msgs: [libc::mmsghdr; BATCH_SIZE] = unsafe { mem::zeroed() };
675    let mut iovecs: [libc::iovec; BATCH_SIZE] = unsafe { mem::zeroed() };
676    let mut cmsgs = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE];
677    // This assume_init looks a bit weird because one might think it
678    // assumes the SockAddr data to be initialized, but that call
679    // refers to the whole array, which itself is made up of MaybeUninit
680    // containers. Their presence protects the SockAddr inside from
681    // being assumed as initialized by the assume_init call.
682    // TODO: Replace this with uninit_array once it becomes MSRV-stable
683    let mut addrs: [MaybeUninit<socket2::SockAddr>; BATCH_SIZE] =
684        unsafe { MaybeUninit::uninit().assume_init() };
685    for (i, transmit) in transmits.iter().enumerate().take(BATCH_SIZE) {
686        let dst_addr = unsafe {
687            std::ptr::write(addrs[i].as_mut_ptr(), socket2::SockAddr::from(transmit.dst));
688            &*addrs[i].as_ptr()
689        };
690        prepare_msg(
691            transmit,
692            dst_addr,
693            &mut msgs[i].msg_hdr,
694            &mut iovecs[i],
695            &mut cmsgs[i],
696        );
697    }
698    let num_transmits = transmits.len().min(BATCH_SIZE);
699
700    loop {
701        let n =
702            unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as u32, 0) };
703        if n == -1 {
704            let e = io::Error::last_os_error();
705            match e.kind() {
706                io::ErrorKind::Interrupted => {
707                    // Retry the transmission
708                    continue;
709                }
710                io::ErrorKind::WouldBlock => return Err(e),
711                _ => {
712                    // Some network adapters do not support GSO. Unfortunately, Linux offers no easy way
713                    // for us to detect this short of an I/O error when we try to actually send
714                    // datagrams using it.
715                    #[cfg(target_os = "linux")]
716                    if e.raw_os_error() == Some(libc::EIO) {
717                        // Prevent new transmits from being scheduled using GSO. Existing GSO transmits
718                        // may already be in the pipeline, so we need to tolerate additional failures.
719                        if state.max_gso_segments() > 1 {
720                            tracing::error!("got EIO, halting segmentation offload");
721                            state
722                                .max_gso_segments
723                                .store(1, std::sync::atomic::Ordering::Relaxed);
724                        }
725                    }
726
727                    // The ERRORS section in https://man7.org/linux/man-pages/man2/sendmmsg.2.html
728                    // describes that errors will only be returned if no message could be transmitted
729                    // at all. Therefore drop the first (problematic) message,
730                    // and retry the remaining ones.
731                    return Ok(num_transmits.min(1));
732                }
733            }
734        }
735        return Ok(n as usize);
736    }
737}
738
/// Sends `transmits` one at a time with `sendmsg` (macOS / iOS variant of the batched sender).
///
/// Returns the number of transmits consumed. A call interrupted by a signal is retried; a
/// transmit that fails with any error other than `WouldBlock` is skipped and still counted.
///
/// The signature mirrors the non-Apple `send` helper (generic over the payload type, no
/// extra parameters) so that the shared call sites compile on every platform.
///
/// # Errors
///
/// Returns `WouldBlock` only when not even the first transmit could be sent; if some were
/// already sent, their count is returned instead.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn send<B: AsPtr<u8>>(
    _state: &UdpState,
    io: SockRef<'_>,
    transmits: &[Transmit<B>],
) -> io::Result<usize> {
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);
    let mut sent = 0;
    while sent < transmits.len() {
        let addr = socket2::SockAddr::from(transmits[sent].dst);
        prepare_msg(&transmits[sent], &addr, &mut hdr, &mut iov, &mut ctrl);
        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            match e.kind() {
                io::ErrorKind::Interrupted => {
                    // Retry the transmission
                }
                io::ErrorKind::WouldBlock if sent != 0 => return Ok(sent),
                io::ErrorKind::WouldBlock => return Err(e),
                _ => {
                    // Drop the problematic transmit and continue with the next one.
                    sent += 1;
                }
            }
        } else {
            sent += 1;
        }
    }
    Ok(sent)
}
772
773#[cfg(not(any(target_os = "macos", target_os = "ios")))]
774fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
775    let mut names = [MaybeUninit::<libc::sockaddr_storage>::uninit(); BATCH_SIZE];
776    let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE];
777    let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() };
778    let max_msg_count = bufs.len().min(BATCH_SIZE);
779    for i in 0..max_msg_count {
780        prepare_recv(
781            &mut bufs[i],
782            &mut names[i],
783            &mut ctrls[i],
784            &mut hdrs[i].msg_hdr,
785        );
786    }
787    let msg_count = loop {
788        let n = unsafe {
789            libc::recvmmsg(
790                io.as_raw_fd(),
791                hdrs.as_mut_ptr(),
792                bufs.len().min(BATCH_SIZE) as libc::c_uint,
793                0,
794                ptr::null_mut(),
795            )
796        };
797        if n == -1 {
798            let e = io::Error::last_os_error();
799            if e.kind() == io::ErrorKind::Interrupted {
800                continue;
801            }
802            return Err(e);
803        }
804        break n;
805    };
806    for i in 0..(msg_count as usize) {
807        meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize);
808    }
809    Ok(msg_count as usize)
810}
811
812#[cfg(not(any(target_os = "macos", target_os = "ios")))]
813fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
814    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
815    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
816    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
817
818    prepare_recv(bufs, &mut name, &mut ctrl, &mut hdr);
819
820    let n = loop {
821        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
822        if n == -1 {
823            let e = io::Error::last_os_error();
824            if e.kind() == io::ErrorKind::Interrupted {
825                continue;
826            }
827            return Err(e);
828        }
829        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
830            continue;
831        }
832        break n;
833    };
834    Ok(decode_recv(&name, &hdr, n as usize))
835}
836
/// Receives a single datagram with `recvmsg` (macOS / iOS variant of the batched receiver).
///
/// Only `bufs[0]` and `meta[0]` are used, and the function always reports one message on
/// success. A call interrupted by a signal is retried, and a datagram flagged `MSG_TRUNC` is
/// discarded and the receive repeated.
///
/// # Errors
///
/// Returns the OS error reported by `recvmsg` (including `WouldBlock`).
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    prepare_recv(&mut bufs[0], &mut name, &mut ctrl, &mut hdr);

    let len = loop {
        match unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) } {
            -1 => {
                let err = io::Error::last_os_error();
                if err.kind() != io::ErrorKind::Interrupted {
                    return Err(err);
                }
                // Interrupted by a signal: try again.
            }
            // Truncated datagram: skip it and receive the next one.
            _ if hdr.msg_flags & libc::MSG_TRUNC != 0 => {}
            received => break received,
        }
    };
    meta[0] = decode_recv(&name, &hdr, len as usize);
    Ok(1)
}
860
861/// Returns the platforms UDP socket capabilities
862pub fn udp_state() -> UdpState {
863    UdpState {
864        max_gso_segments: AtomicUsize::new(gso::max_gso_segments()),
865        gro_segments: gro::gro_segments(),
866    }
867}
868
/// Size in bytes of the control-message (ancillary data) buffer that `prepare_msg` and
/// `prepare_recv` attach to a `msghdr` via `msg_control` / `msg_controllen`.
const CMSG_LEN: usize = 88;
870
871fn prepare_msg<B: AsPtr<u8>>(
872    transmit: &Transmit<B>,
873    dst_addr: &socket2::SockAddr,
874    hdr: &mut libc::msghdr,
875    iov: &mut libc::iovec,
876    ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>,
877) {
878    iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _;
879    iov.iov_len = transmit.contents.len();
880
881    // SAFETY: Casting the pointer to a mutable one is legal,
882    // as sendmsg is guaranteed to not alter the mutable pointer
883    // as per the POSIX spec. See the section on the sys/socket.h
884    // header for details. The type is only mutable in the first
885    // place because it is reused by recvmsg as well.
886    let name = dst_addr.as_ptr() as *mut libc::c_void;
887    let namelen = dst_addr.len();
888    hdr.msg_name = name as *mut _;
889    hdr.msg_namelen = namelen;
890    hdr.msg_iov = iov;
891    hdr.msg_iovlen = 1;
892
893    hdr.msg_control = ctrl.0.as_mut_ptr() as _;
894    hdr.msg_controllen = CMSG_LEN as _;
895    let mut encoder = unsafe { cmsg::Encoder::new(hdr) };
896    let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int);
897    if transmit.dst.is_ipv4() {
898        encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy);
899    } else {
900        encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn);
901    }
902
903    if let Some(segment_size) = transmit.segment_size {
904        gso::set_segment_size(&mut encoder, segment_size as u16);
905    }
906
907    if let Some(ip) = &transmit.src {
908        if cfg!(target_os = "linux") {
909            match ip {
910                Source::Ip(IpAddr::V4(v4)) => {
911                    let pktinfo = libc::in_pktinfo {
912                        ipi_ifindex: 0,
913                        ipi_spec_dst: libc::in_addr {
914                            s_addr: u32::from_ne_bytes(v4.octets()),
915                        },
916                        ipi_addr: libc::in_addr { s_addr: 0 },
917                    };
918                    encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
919                }
920                Source::Ip(IpAddr::V6(v6)) => {
921                    let pktinfo = libc::in6_pktinfo {
922                        ipi6_ifindex: 0,
923                        ipi6_addr: libc::in6_addr {
924                            s6_addr: v6.octets(),
925                        },
926                    };
927                    encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
928                }
929                Source::Interface(i) => {
930                    let pktinfo = libc::in_pktinfo {
931                        ipi_ifindex: *i as i32,
932                        ipi_spec_dst: libc::in_addr { s_addr: 0 },
933                        ipi_addr: libc::in_addr { s_addr: 0 },
934                    };
935                    encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
936                }
937                Source::InterfaceV6(i, ip) => {
938                    let pktinfo = libc::in6_pktinfo {
939                        ipi6_ifindex: *i,
940                        ipi6_addr: libc::in6_addr {
941                            s6_addr: ip.octets(),
942                        },
943                    };
944                    encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
945                }
946            }
947        }
948    }
949
950    encoder.finish();
951}
952
953fn prepare_recv(
954    buf: &mut IoSliceMut,
955    name: &mut MaybeUninit<libc::sockaddr_storage>,
956    ctrl: &mut cmsg::Aligned<MaybeUninit<[u8; CMSG_LEN]>>,
957    hdr: &mut libc::msghdr,
958) {
959    hdr.msg_name = name.as_mut_ptr() as _;
960    hdr.msg_namelen = mem::size_of::<libc::sockaddr_storage>() as _;
961    hdr.msg_iov = buf as *mut IoSliceMut as *mut libc::iovec;
962    hdr.msg_iovlen = 1;
963    hdr.msg_control = ctrl.0.as_mut_ptr() as _;
964    hdr.msg_controllen = CMSG_LEN as _;
965    hdr.msg_flags = 0;
966}
967
968fn decode_recv(
969    name: &MaybeUninit<libc::sockaddr_storage>,
970    hdr: &libc::msghdr,
971    len: usize,
972) -> RecvMeta {
973    let name = unsafe { name.assume_init() };
974    let mut ecn_bits = 0;
975    let mut dst_ip = None;
976    let mut dst_local_ip = None;
977    let mut ifindex = 0;
978    #[allow(unused_mut)] // only mutable on Linux
979    let mut stride = len;
980
981    let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
982    for cmsg in cmsg_iter {
983        match (cmsg.cmsg_level, cmsg.cmsg_type) {
984            // FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in.
985            (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
986                ecn_bits = cmsg::decode::<u8>(cmsg);
987            },
988            (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
989                // Temporary hack around broken macos ABI. Remove once upstream fixes it.
990                // https://bugreport.apple.com/web/?problemID=48761855
991                if cfg!(target_os = "macos")
992                    && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
993                {
994                    ecn_bits = cmsg::decode::<u8>(cmsg);
995                } else {
996                    ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
997                }
998            },
999            (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
1000                let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo>(cmsg) };
1001                dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
1002                    pktinfo.ipi_addr.s_addr.to_ne_bytes(),
1003                )));
1004                dst_local_ip = Some(IpAddr::V4(Ipv4Addr::from(
1005                    pktinfo.ipi_spec_dst.s_addr.to_ne_bytes(),
1006                )));
1007                ifindex = pktinfo.ipi_ifindex as _;
1008            }
1009            (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
1010                let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo>(cmsg) };
1011                dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
1012                ifindex = pktinfo.ipi6_ifindex;
1013            }
1014            #[cfg(target_os = "linux")]
1015            (libc::SOL_UDP, libc::UDP_GRO) => unsafe {
1016                stride = cmsg::decode::<libc::c_int>(cmsg) as usize;
1017            },
1018            _ => {}
1019        }
1020    }
1021
1022    let addr = match libc::c_int::from(name.ss_family) {
1023        libc::AF_INET => {
1024            // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
1025            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
1026            let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
1027            let port = u16::from_be(addr.sin_port);
1028            SocketAddr::V4(SocketAddrV4::new(ip, port))
1029        }
1030        libc::AF_INET6 => {
1031            // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
1032            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
1033            let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
1034            let port = u16::from_be(addr.sin6_port);
1035            SocketAddr::V6(SocketAddrV6::new(
1036                ip,
1037                port,
1038                addr.sin6_flowinfo,
1039                addr.sin6_scope_id,
1040            ))
1041        }
1042        _ => unreachable!(),
1043    };
1044
1045    RecvMeta {
1046        len,
1047        stride,
1048        addr,
1049        ecn: EcnCodepoint::from_bits(ecn_bits),
1050        dst_ip,
1051        dst_local_ip,
1052        ifindex,
1053    }
1054}
1055
/// Upper bound on the number of datagrams requested from a single `recvmmsg` call on
/// platforms other than macOS and iOS.
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
// Chosen somewhat arbitrarily; might benefit from additional tuning.
pub const BATCH_SIZE: usize = 32;

/// Batch size on macOS and iOS, where the receive path reads exactly one datagram per call.
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub const BATCH_SIZE: usize = 1;
1062
1063#[cfg(target_os = "linux")]
1064mod gso {
1065    use super::*;
1066
1067    /// Checks whether GSO support is available by setting the UDP_SEGMENT
1068    /// option on a socket
1069    pub fn max_gso_segments() -> usize {
1070        const GSO_SIZE: libc::c_int = 1500;
1071
1072        let socket = match std::net::UdpSocket::bind("[::]:0") {
1073            Ok(socket) => socket,
1074            Err(_) => return 1,
1075        };
1076
1077        let rc = unsafe {
1078            libc::setsockopt(
1079                socket.as_raw_fd(),
1080                libc::SOL_UDP,
1081                libc::UDP_SEGMENT,
1082                &GSO_SIZE as *const _ as _,
1083                mem::size_of_val(&GSO_SIZE) as _,
1084            )
1085        };
1086
1087        if rc != -1 {
1088            // As defined in linux/udp.h
1089            // #define UDP_MAX_SEGMENTS        (1 << 6UL)
1090            64
1091        } else {
1092            1
1093        }
1094    }
1095
1096    pub fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) {
1097        encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size);
1098    }
1099}
1100
/// GSO stand-in for platforms other than Linux, where segmentation offload is unavailable.
#[cfg(not(target_os = "linux"))]
mod gso {
    use super::*;

    /// Reports the maximum number of GSO segments, which is always 1 on this platform.
    pub fn max_gso_segments() -> usize {
        // Every send carries exactly one datagram.
        const SINGLE_SEGMENT: usize = 1;
        SINGLE_SEGMENT
    }

    /// Unsupported on this platform.
    ///
    /// # Panics
    ///
    /// Always panics.
    pub fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) {
        panic!("Setting a segment size is not supported on current platform");
    }
}
1113
1114#[cfg(target_os = "linux")]
1115mod gro {
1116    use super::*;
1117
1118    pub fn gro_segments() -> usize {
1119        let socket = match std::net::UdpSocket::bind("[::]:0") {
1120            Ok(socket) => socket,
1121            Err(_) => return 1,
1122        };
1123
1124        let on: libc::c_int = 1;
1125        let rc = unsafe {
1126            libc::setsockopt(
1127                socket.as_raw_fd(),
1128                libc::SOL_UDP,
1129                libc::UDP_GRO,
1130                &on as *const _ as _,
1131                mem::size_of_val(&on) as _,
1132            )
1133        };
1134
1135        if rc != -1 {
1136            // As defined in net/ipv4/udp_offload.c
1137            // #define UDP_GRO_CNT_MAX 64
1138            //
1139            // NOTE: this MUST be set to UDP_GRO_CNT_MAX to ensure that the receive buffer size
1140            // (get_max_udp_payload_size() * gro_segments()) is large enough to hold the largest GRO
1141            // list the kernel might potentially produce. See
1142            // https://github.com/quinn-rs/quinn/pull/1354.
1143            64
1144        } else {
1145            1
1146        }
1147    }
1148}
1149
/// GRO stand-in for platforms other than Linux, where receive offload is unavailable.
#[cfg(not(target_os = "linux"))]
mod gro {
    /// Reports the number of GRO segments, which is always 1 on this platform.
    pub fn gro_segments() -> usize {
        // Each receive yields exactly one datagram.
        const SINGLE_SEGMENT: usize = 1;
        SINGLE_SEGMENT
    }
}
1155}