1#[cfg(not(any(target_os = "macos", target_os = "ios")))]
2use std::ptr;
3use std::{
4 io,
5 io::IoSliceMut,
6 mem::{self, MaybeUninit},
7 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
8 os::unix::io::AsRawFd,
9 sync::atomic::{AtomicU64, AtomicUsize},
10 time::Instant,
11};
12
13use socket2::SockRef;
14
15use super::{
16 cmsg, log_sendmsg_error, Capabilities, EcnCodepoint, RecvMeta, Transmit, UdpSockRef,
17 IO_ERROR_LOG_INTERVAL,
18};
19
20#[cfg(target_os = "freebsd")]
21type IpTosTy = libc::c_uchar;
22#[cfg(not(target_os = "freebsd"))]
23type IpTosTy = libc::c_int;
24
/// Bookkeeping that accompanies a UDP socket driven through this module.
///
/// Both fields are only handed to `log_sendmsg_error` when a send fails; their exact
/// interpretation is defined there.
#[derive(Debug)]
pub struct UdpSocketState {
    /// Reference instant chosen in [`UdpSocketState::new`].
    epoch: Instant,
    /// Starts at `0`; presumably the time of the last logged send error relative to `epoch`
    /// (maintained by `log_sendmsg_error` - verify there).
    last_send_error: AtomicU64,
}
34
35impl UdpSocketState {
36 pub fn new() -> Self {
37 let now = Instant::now();
38 Self {
39 epoch: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
40 last_send_error: AtomicU64::new(0),
41 }
42 }
43
44 pub fn configure(sock: UdpSockRef<'_>) -> io::Result<()> {
45 init(sock.0)
46 }
47
48 pub fn send(
49 &self,
50 socket: UdpSockRef<'_>,
51 capabilities: &Capabilities,
52 transmits: &[Transmit],
53 ) -> Result<usize, io::Error> {
54 send(
55 capabilities,
56 socket.0,
57 &self.epoch,
58 &self.last_send_error,
59 transmits,
60 )
61 }
62
63 pub fn recv(
64 &self,
65 socket: UdpSockRef<'_>,
66 bufs: &mut [IoSliceMut<'_>],
67 meta: &mut [RecvMeta],
68 ) -> io::Result<usize> {
69 recv(socket.0, bufs, meta)
70 }
71}
72
73impl Default for UdpSocketState {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79fn init(io: SockRef<'_>) -> io::Result<()> {
80 let mut cmsg_platform_space = 0;
81 if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") || cfg!(target_os = "macos") {
82 cmsg_platform_space +=
83 unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
84 }
85
86 assert!(
87 CMSG_LEN
88 >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
89 + cmsg_platform_space
90 );
91 assert!(
92 mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
93 "control message buffers will be misaligned"
94 );
95
96 io.set_nonblocking(true)?;
97
98 let addr = io.local_addr()?;
99 let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;
100
101 if is_ipv4 || !io.only_v6()? {
104 if let Err(err) = set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVTOS, OPTION_ON) {
105 tracing::debug!("Ignoring error setting IP_RECVTOS on socket: {err:?}",);
106 }
107 }
108
109 #[cfg(target_os = "linux")]
110 {
111 let _ = set_socket_option(&*io, libc::SOL_UDP, libc::UDP_GRO, OPTION_ON);
113
114 set_socket_option(
116 &*io,
117 libc::IPPROTO_IP,
118 libc::IP_MTU_DISCOVER,
119 libc::IP_PMTUDISC_PROBE,
120 )?;
121
122 if is_ipv4 {
123 set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_PKTINFO, OPTION_ON)?;
124 } else {
125 set_socket_option(
126 &*io,
127 libc::IPPROTO_IPV6,
128 libc::IPV6_MTU_DISCOVER,
129 libc::IP_PMTUDISC_PROBE,
130 )?;
131 }
132 }
133 #[cfg(any(target_os = "freebsd", target_os = "macos"))]
134 {
138 if is_ipv4 {
139 set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVDSTADDR, OPTION_ON)?;
140 }
141 }
142
143 if !is_ipv4 {
145 set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, OPTION_ON)?;
146 set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, OPTION_ON)?;
147 }
148
149 if !is_ipv4 {
150 set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, OPTION_ON)?;
151 }
152
153 Ok(())
154}
155
156#[cfg(not(any(target_os = "macos", target_os = "ios")))]
157fn send(
158 #[allow(unused_variables)] capabilities: &Capabilities,
160 io: SockRef<'_>,
161 epoch: &Instant,
162 last_send_error: &AtomicU64,
163 transmits: &[Transmit],
164) -> io::Result<usize> {
165 #[allow(unused_mut)] let mut encode_src_ip = true;
167 #[cfg(target_os = "freebsd")]
168 {
169 let addr = io.local_addr()?;
170 let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;
171 if is_ipv4 {
172 if let Some(socket) = addr.as_socket_ipv4() {
173 encode_src_ip = socket.ip() == &Ipv4Addr::UNSPECIFIED;
174 }
175 }
176 }
177 let mut msgs: [libc::mmsghdr; BATCH_SIZE] = unsafe { mem::zeroed() };
178 let mut iovecs: [libc::iovec; BATCH_SIZE] = unsafe { mem::zeroed() };
179 let mut cmsgs = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE];
180 let mut addrs: [MaybeUninit<socket2::SockAddr>; BATCH_SIZE] =
187 unsafe { MaybeUninit::uninit().assume_init() };
188 for (i, transmit) in transmits.iter().enumerate().take(BATCH_SIZE) {
189 let dst_addr = unsafe {
190 ptr::write(
191 addrs[i].as_mut_ptr(),
192 socket2::SockAddr::from(transmit.destination),
193 );
194 &*addrs[i].as_ptr()
195 };
196 prepare_msg(
197 transmit,
198 dst_addr,
199 &mut msgs[i].msg_hdr,
200 &mut iovecs[i],
201 &mut cmsgs[i],
202 encode_src_ip,
203 );
204 }
205 let num_transmits = transmits.len().min(BATCH_SIZE);
206
207 loop {
208 let n = unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as _, 0) };
209 if n == -1 {
210 let e = io::Error::last_os_error();
211 match e.kind() {
212 io::ErrorKind::Interrupted => {
213 continue;
215 }
216 io::ErrorKind::WouldBlock => return Err(e),
217 _ => {
218 #[cfg(target_os = "linux")]
222 if e.raw_os_error() == Some(libc::EIO) {
223 if capabilities.max_gso_segments() > 1 {
226 tracing::error!("got EIO, halting segmentation offload");
227 capabilities
228 .max_gso_segments
229 .store(1, std::sync::atomic::Ordering::Relaxed);
230 }
231 }
232
233 log_sendmsg_error(epoch, last_send_error, e, &transmits[0]);
240
241 return Ok(num_transmits.min(1));
246 }
247 }
248 }
249 return Ok(n as usize);
250 }
251}
252
/// Transmits the entries of `transmits` one at a time with `sendmsg`.
///
/// Returns the number of leading transmits that were handled. A transmit whose `sendmsg` call
/// fails with anything other than `Interrupted` or `WouldBlock` is handed to
/// `log_sendmsg_error` and counted as handled; an interrupted call is retried.
///
/// # Errors
///
/// `WouldBlock` is returned only when nothing has been handled yet; otherwise the count
/// reached so far is returned.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn send(
    _capabilities: &Capabilities,
    io: SockRef<'_>,
    epoch: &Instant,
    last_send_error: &AtomicU64,
    transmits: &[Transmit],
) -> io::Result<usize> {
    // One header, iovec and control buffer are reused for every datagram.
    let mut header: libc::msghdr = unsafe { mem::zeroed() };
    let mut io_vec: libc::iovec = unsafe { mem::zeroed() };
    let mut control = cmsg::Aligned([0u8; CMSG_LEN]);
    let mut done = 0;

    while let Some(transmit) = transmits.get(done) {
        let destination = socket2::SockAddr::from(transmit.destination);
        prepare_msg(
            transmit,
            &destination,
            &mut header,
            &mut io_vec,
            &mut control,
            cfg!(target_os = "macos"),
        );

        let rc = unsafe { libc::sendmsg(io.as_raw_fd(), &header, 0) };
        if rc != -1 {
            done += 1;
            continue;
        }

        let err = io::Error::last_os_error();
        match err.kind() {
            // Leave `done` untouched so the same transmit is attempted again.
            io::ErrorKind::Interrupted => {}
            io::ErrorKind::WouldBlock if done != 0 => return Ok(done),
            io::ErrorKind::WouldBlock => return Err(err),
            _ => {
                log_sendmsg_error(epoch, last_send_error, err, transmit);
                done += 1;
            }
        }
    }
    Ok(done)
}
303
304#[cfg(not(any(target_os = "macos", target_os = "ios")))]
305fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
306 let mut names = [MaybeUninit::<libc::sockaddr_storage>::uninit(); BATCH_SIZE];
307 let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE];
308 let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() };
309 let max_msg_count = bufs.len().min(BATCH_SIZE);
310 for i in 0..max_msg_count {
311 prepare_recv(
312 &mut bufs[i],
313 &mut names[i],
314 &mut ctrls[i],
315 &mut hdrs[i].msg_hdr,
316 );
317 }
318 let msg_count = loop {
319 let n = unsafe {
320 libc::recvmmsg(
321 io.as_raw_fd(),
322 hdrs.as_mut_ptr(),
323 bufs.len().min(BATCH_SIZE) as _,
324 0,
325 ptr::null_mut(),
326 )
327 };
328 if n == -1 {
329 let e = io::Error::last_os_error();
330 if e.kind() == io::ErrorKind::Interrupted {
331 continue;
332 }
333 return Err(e);
334 }
335 break n;
336 };
337 for i in 0..(msg_count as usize) {
338 meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize);
339 }
340 Ok(msg_count as usize)
341}
342
/// Receives a single datagram with `recvmsg` into `bufs[0]` and describes it in `meta[0]`.
///
/// Interrupted calls are retried, and datagrams flagged `MSG_TRUNC` are discarded and the
/// receive is attempted again. Returns `Ok(1)` once a datagram has been stored.
///
/// # Errors
///
/// Any other `recvmsg` failure (including `WouldBlock`) is returned to the caller.
///
/// # Panics
///
/// Panics if `bufs` or `meta` is empty.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result<usize> {
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
    let n = loop {
        // `recvmsg` writes the actual address length, control length and flags back into the
        // header, so it has to be rebuilt before every attempt; otherwise a retry would offer
        // the kernel the (possibly smaller) sizes left over from the previous datagram.
        prepare_recv(&mut bufs[0], &mut name, &mut ctrl, &mut hdr);
        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        // A truncated datagram is dropped and the next one is read instead.
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    meta[0] = decode_recv(&name, &hdr, n as usize);
    Ok(1)
}
366
367pub fn capabilities() -> Capabilities {
369 Capabilities {
370 max_gso_segments: AtomicUsize::new(gso::max_gso_segments()),
371 gro_segments: gro::gro_segments(),
372 }
373}
374
/// Size in bytes of every control-message buffer used by this module.
///
/// `init` asserts at runtime that this is large enough for the control messages of the
/// current platform.
const CMSG_LEN: usize = 88;
376
377fn prepare_msg(
378 transmit: &Transmit,
379 dst_addr: &socket2::SockAddr,
380 hdr: &mut libc::msghdr,
381 iov: &mut libc::iovec,
382 ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>,
383 #[allow(unused_variables)] encode_src_ip: bool,
385) {
386 iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _;
387 iov.iov_len = transmit.contents.len();
388
389 let name = dst_addr.as_ptr() as *mut libc::c_void;
395 let namelen = dst_addr.len();
396 hdr.msg_name = name as *mut _;
397 hdr.msg_namelen = namelen;
398 hdr.msg_iov = iov;
399 hdr.msg_iovlen = 1;
400
401 hdr.msg_control = ctrl.0.as_mut_ptr() as _;
402 hdr.msg_controllen = CMSG_LEN as _;
403 let mut encoder = unsafe { cmsg::Encoder::new(hdr) };
404 let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int);
405 if transmit.destination.is_ipv4() {
406 encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy);
407 } else {
408 encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn);
409 }
410
411 if let Some(segment_size) = transmit.segment_size {
412 gso::set_segment_size(&mut encoder, segment_size as u16);
413 }
414
415 if let Some(ip) = &transmit.src_ip {
416 match ip {
417 IpAddr::V4(v4) => {
418 #[cfg(target_os = "linux")]
419 {
420 let pktinfo = libc::in_pktinfo {
421 ipi_ifindex: 0,
422 ipi_spec_dst: libc::in_addr {
423 s_addr: u32::from_ne_bytes(v4.octets()),
424 },
425 ipi_addr: libc::in_addr { s_addr: 0 },
426 };
427 encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
428 }
429 #[cfg(any(target_os = "freebsd", target_os = "macos"))]
430 {
431 if encode_src_ip {
432 let addr = libc::in_addr {
433 s_addr: u32::from_ne_bytes(v4.octets()),
434 };
435 encoder.push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr);
436 }
437 }
438 }
439 IpAddr::V6(v6) => {
440 let pktinfo = libc::in6_pktinfo {
441 ipi6_ifindex: 0,
442 ipi6_addr: libc::in6_addr {
443 s6_addr: v6.octets(),
444 },
445 };
446 encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
447 }
448 }
449 }
450
451 encoder.finish();
452}
453
454fn prepare_recv(
455 buf: &mut IoSliceMut<'_>,
456 name: &mut MaybeUninit<libc::sockaddr_storage>,
457 ctrl: &mut cmsg::Aligned<MaybeUninit<[u8; CMSG_LEN]>>,
458 hdr: &mut libc::msghdr,
459) {
460 hdr.msg_name = name.as_mut_ptr() as _;
461 hdr.msg_namelen = mem::size_of::<libc::sockaddr_storage>() as _;
462 hdr.msg_iov = buf as *mut IoSliceMut<'_> as *mut libc::iovec;
463 hdr.msg_iovlen = 1;
464 hdr.msg_control = ctrl.0.as_mut_ptr() as _;
465 hdr.msg_controllen = CMSG_LEN as _;
466 hdr.msg_flags = 0;
467}
468
469fn decode_recv(
470 name: &MaybeUninit<libc::sockaddr_storage>,
471 hdr: &libc::msghdr,
472 len: usize,
473) -> RecvMeta {
474 let name = unsafe { name.assume_init() };
475 let mut ecn_bits = 0;
476 let mut dst_ip = None;
477 #[allow(unused_mut)] let mut stride = len;
479
480 let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
481 for cmsg in cmsg_iter {
482 match (cmsg.cmsg_level, cmsg.cmsg_type) {
483 (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
485 ecn_bits = cmsg::decode::<u8>(cmsg);
486 },
487 (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
488 #[allow(clippy::unnecessary_cast)] if cfg!(target_os = "macos")
492 && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
493 {
494 ecn_bits = cmsg::decode::<u8>(cmsg);
495 } else {
496 ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
497 }
498 },
499 #[cfg(target_os = "linux")]
500 (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
501 let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo>(cmsg) };
502 dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
503 pktinfo.ipi_addr.s_addr.to_ne_bytes(),
504 )));
505 }
506 #[cfg(any(target_os = "freebsd", target_os = "macos"))]
507 (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => {
508 let in_addr = unsafe { cmsg::decode::<libc::in_addr>(cmsg) };
509 dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes())));
510 }
511 (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
512 let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo>(cmsg) };
513 dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
514 }
515 #[cfg(target_os = "linux")]
516 (libc::SOL_UDP, libc::UDP_GRO) => unsafe {
517 stride = cmsg::decode::<libc::c_int>(cmsg) as usize;
518 },
519 _ => {}
520 }
521 }
522
523 let addr = match libc::c_int::from(name.ss_family) {
524 libc::AF_INET => {
525 let addr: &libc::sockaddr_in =
527 unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
528 SocketAddr::V4(SocketAddrV4::new(
529 Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
530 u16::from_be(addr.sin_port),
531 ))
532 }
533 libc::AF_INET6 => {
534 let addr: &libc::sockaddr_in6 =
536 unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
537 SocketAddr::V6(SocketAddrV6::new(
538 Ipv6Addr::from(addr.sin6_addr.s6_addr),
539 u16::from_be(addr.sin6_port),
540 addr.sin6_flowinfo,
541 addr.sin6_scope_id,
542 ))
543 }
544 _ => unreachable!(),
545 };
546
547 RecvMeta {
548 len,
549 stride,
550 addr,
551 ecn: EcnCodepoint::from_bits(ecn_bits),
552 dst_ip,
553 }
554}
555
/// Maximum number of datagrams handled by one `send`/`recv` call.
///
/// macOS and iOS go through `sendmsg`/`recvmsg`, which handle one datagram per call.
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub const BATCH_SIZE: usize = 1;

/// Maximum number of datagrams handled by one `send`/`recv` call.
///
/// This is the capacity of the on-stack arrays handed to `sendmmsg`/`recvmmsg`.
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
pub const BATCH_SIZE: usize = 32;
562
563#[cfg(target_os = "linux")]
564mod gso {
565 use super::*;
566
567 pub fn max_gso_segments() -> usize {
570 const GSO_SIZE: libc::c_int = 1500;
571
572 let socket = match std::net::UdpSocket::bind("[::]:0")
573 .or_else(|_| std::net::UdpSocket::bind("127.0.0.1:0"))
574 {
575 Ok(socket) => socket,
576 Err(_) => return 1,
577 };
578
579 match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_SEGMENT, GSO_SIZE) {
582 Ok(()) => 64,
583 Err(_) => 1,
584 }
585 }
586
587 pub fn set_segment_size(encoder: &mut cmsg::Encoder<'_>, segment_size: u16) {
588 encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size);
589 }
590}
591
/// Stand-in for segmentation offload on platforms other than Linux.
#[cfg(not(target_os = "linux"))]
mod gso {
    use super::*;

    /// Always `1`: no segmentation offload is offered here.
    pub fn max_gso_segments() -> usize {
        const WITHOUT_GSO: usize = 1;
        WITHOUT_GSO
    }

    /// Not supported on this platform.
    ///
    /// # Panics
    ///
    /// Always panics.
    pub fn set_segment_size(_encoder: &mut cmsg::Encoder<'_>, _segment_size: u16) {
        panic!("Setting a segment size is not supported on current platform");
    }
}
604
605#[cfg(target_os = "linux")]
606mod gro {
607 use super::*;
608
609 pub fn gro_segments() -> usize {
610 let socket = match std::net::UdpSocket::bind("[::]:0")
611 .or_else(|_| std::net::UdpSocket::bind("127.0.0.1:0"))
612 {
613 Ok(socket) => socket,
614 Err(_) => return 1,
615 };
616
617 match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_GRO, OPTION_ON) {
625 Ok(()) => 64,
626 Err(_) => 1,
627 }
628 }
629}
630
631fn set_socket_option(
632 socket: &impl AsRawFd,
633 level: libc::c_int,
634 name: libc::c_int,
635 value: libc::c_int,
636) -> Result<(), io::Error> {
637 let rc = unsafe {
638 libc::setsockopt(
639 socket.as_raw_fd(),
640 level,
641 name,
642 &value as *const _ as _,
643 mem::size_of_val(&value) as _,
644 )
645 };
646
647 match rc == 0 {
648 true => Ok(()),
649 false => Err(io::Error::last_os_error()),
650 }
651}
652
653const OPTION_ON: libc::c_int = 1;
654
/// Stand-in for receive offload on platforms other than Linux.
#[cfg(not(target_os = "linux"))]
mod gro {
    /// Always `1`: no receive offload is offered here.
    pub fn gro_segments() -> usize {
        const WITHOUT_GRO: usize = 1;
        WITHOUT_GRO
    }
}