quinn_udp/
windows.rs

1use std::{
2    io::{self, IoSliceMut},
3    mem,
4    net::{IpAddr, Ipv4Addr},
5    os::windows::io::AsRawSocket,
6    ptr,
7    sync::Mutex,
8    time::Instant,
9};
10
11use libc::{c_int, c_uint};
12use once_cell::sync::Lazy;
13use windows_sys::Win32::Networking::WinSock;
14
15use crate::{
16    EcnCodepoint, IO_ERROR_LOG_INTERVAL, RecvMeta, Transmit, UdpSockRef,
17    cmsg::{self, CMsgHdr},
18    log::debug,
19    log_sendmsg_error,
20};
21
/// QUIC-friendly UDP socket for Windows
///
/// Unlike a standard Windows UDP socket, this allows ECN bits to be read and written.
#[derive(Debug)]
pub struct UdpSocketState {
    /// Timestamp handed to `log_sendmsg_error` whenever a send fails.
    ///
    /// `new` backdates it by two `IO_ERROR_LOG_INTERVAL`s, presumably so the first error is
    /// logged immediately (the rate-limiting itself lives in `log_sendmsg_error` — confirm
    /// there). The mutex lets the `&self` send path update it.
    last_send_error: Mutex<Instant>,
}
29
30impl UdpSocketState {
31    pub fn new(socket: UdpSockRef<'_>) -> io::Result<Self> {
32        assert!(
33            CMSG_LEN
34                >= WinSock::CMSGHDR::cmsg_space(mem::size_of::<WinSock::IN6_PKTINFO>())
35                    + WinSock::CMSGHDR::cmsg_space(mem::size_of::<c_int>())
36                    + WinSock::CMSGHDR::cmsg_space(mem::size_of::<u32>())
37        );
38        assert!(
39            mem::align_of::<WinSock::CMSGHDR>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
40            "control message buffers will be misaligned"
41        );
42
43        socket.0.set_nonblocking(true)?;
44        let addr = socket.0.local_addr()?;
45        let is_ipv6 = addr.as_socket_ipv6().is_some();
46        let v6only = unsafe {
47            let mut result: u32 = 0;
48            let mut len = mem::size_of_val(&result) as i32;
49            let rc = WinSock::getsockopt(
50                socket.0.as_raw_socket() as _,
51                WinSock::IPPROTO_IPV6,
52                WinSock::IPV6_V6ONLY as _,
53                &mut result as *mut _ as _,
54                &mut len,
55            );
56            if rc == -1 {
57                return Err(io::Error::last_os_error());
58            }
59            result != 0
60        };
61        let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only;
62
63        // We don't support old versions of Windows that do not enable access to `WSARecvMsg()`
64        if WSARECVMSG_PTR.is_none() {
65            return Err(io::Error::new(
66                io::ErrorKind::Unsupported,
67                "network stack does not support WSARecvMsg function",
68            ));
69        }
70
71        if is_ipv4 {
72            set_socket_option(
73                &*socket.0,
74                WinSock::IPPROTO_IP,
75                WinSock::IP_DONTFRAGMENT,
76                OPTION_ON,
77            )?;
78
79            set_socket_option(
80                &*socket.0,
81                WinSock::IPPROTO_IP,
82                WinSock::IP_PKTINFO,
83                OPTION_ON,
84            )?;
85            set_socket_option(
86                &*socket.0,
87                WinSock::IPPROTO_IP,
88                WinSock::IP_RECVECN,
89                OPTION_ON,
90            )?;
91        }
92
93        if is_ipv6 {
94            set_socket_option(
95                &*socket.0,
96                WinSock::IPPROTO_IPV6,
97                WinSock::IPV6_DONTFRAG,
98                OPTION_ON,
99            )?;
100
101            set_socket_option(
102                &*socket.0,
103                WinSock::IPPROTO_IPV6,
104                WinSock::IPV6_PKTINFO,
105                OPTION_ON,
106            )?;
107
108            set_socket_option(
109                &*socket.0,
110                WinSock::IPPROTO_IPV6,
111                WinSock::IPV6_RECVECN,
112                OPTION_ON,
113            )?;
114        }
115
116        let now = Instant::now();
117        Ok(Self {
118            last_send_error: Mutex::new(now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now)),
119        })
120    }
121
122    /// Enable or disable receive offloading.
123    ///
124    /// Also referred to as UDP Receive Segment Coalescing Offload (URO) on Windows.
125    ///
126    /// <https://learn.microsoft.com/en-us/windows-hardware/drivers/network/udp-rsc-offload>
127    ///
128    /// Disabled by default on Windows due to <https://github.com/quinn-rs/quinn/issues/2041>.
129    pub fn set_gro(&self, socket: UdpSockRef<'_>, enable: bool) -> io::Result<()> {
130        set_socket_option(
131            &*socket.0,
132            WinSock::IPPROTO_UDP,
133            WinSock::UDP_RECV_MAX_COALESCED_SIZE,
134            match enable {
135                // u32 per
136                // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options.
137                // Choice of 2^16 - 1 inspired by msquic.
138                true => u16::MAX as u32,
139                false => 0,
140            },
141        )
142    }
143
144    /// Sends a [`Transmit`] on the given socket.
145    ///
146    /// This function will only ever return errors of kind [`io::ErrorKind::WouldBlock`].
147    /// All other errors will be logged and converted to `Ok`.
148    ///
149    /// UDP transmission errors are considered non-fatal because higher-level protocols must
150    /// employ retransmits and timeouts anyway in order to deal with UDP's unreliable nature.
151    /// Thus, logging is most likely the only thing you can do with these errors.
152    ///
153    /// If you would like to handle these errors yourself, use [`UdpSocketState::try_send`]
154    /// instead.
155    pub fn send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
156        match send(socket, transmit) {
157            Ok(()) => Ok(()),
158            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
159            Err(e) => {
160                log_sendmsg_error(&self.last_send_error, e, transmit);
161
162                Ok(())
163            }
164        }
165    }
166
167    /// Sends a [`Transmit`] on the given socket without any additional error handling.
168    pub fn try_send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
169        send(socket, transmit)
170    }
171
172    pub fn recv(
173        &self,
174        socket: UdpSockRef<'_>,
175        bufs: &mut [IoSliceMut<'_>],
176        meta: &mut [RecvMeta],
177    ) -> io::Result<usize> {
178        let wsa_recvmsg_ptr = WSARECVMSG_PTR.expect("valid function pointer for WSARecvMsg");
179
180        // we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG
181        let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
182        let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
183        let mut data = WinSock::WSABUF {
184            buf: bufs[0].as_mut_ptr(),
185            len: bufs[0].len() as _,
186        };
187
188        let ctrl = WinSock::WSABUF {
189            buf: ctrl_buf.0.as_mut_ptr(),
190            len: ctrl_buf.0.len() as _,
191        };
192
193        let mut wsa_msg = WinSock::WSAMSG {
194            name: &mut source as *mut _ as *mut _,
195            namelen: mem::size_of_val(&source) as _,
196            lpBuffers: &mut data,
197            Control: ctrl,
198            dwBufferCount: 1,
199            dwFlags: 0,
200        };
201
202        let mut len = 0;
203        unsafe {
204            let rc = (wsa_recvmsg_ptr)(
205                socket.0.as_raw_socket() as usize,
206                &mut wsa_msg,
207                &mut len,
208                ptr::null_mut(),
209                None,
210            );
211            if rc == -1 {
212                return Err(io::Error::last_os_error());
213            }
214        }
215
216        let addr = unsafe {
217            let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| {
218                *len = mem::size_of_val(&source) as _;
219                ptr::copy_nonoverlapping(&source, addr_storage as _, 1);
220                Ok(())
221            })?;
222            addr.as_socket()
223        };
224
225        // Decode control messages (PKTINFO and ECN)
226        let mut ecn_bits = 0;
227        let mut dst_ip = None;
228        let mut stride = len;
229
230        let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) };
231        for cmsg in cmsg_iter {
232            const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32;
233            // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding]
234            match (cmsg.cmsg_level, cmsg.cmsg_type) {
235                (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
236                    let pktinfo =
237                        unsafe { cmsg::decode::<WinSock::IN_PKTINFO, WinSock::CMSGHDR>(cmsg) };
238                    // Addr is stored in big endian format
239                    let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
240                    dst_ip = Some(ip4.into());
241                }
242                (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
243                    let pktinfo =
244                        unsafe { cmsg::decode::<WinSock::IN6_PKTINFO, WinSock::CMSGHDR>(cmsg) };
245                    // Addr is stored in big endian format
246                    dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
247                }
248                (WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
249                    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
250                    ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
251                }
252                (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
253                    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
254                    ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
255                }
256                (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => {
257                    // Has type u32 (aka DWORD) per
258                    // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options
259                    stride = unsafe { cmsg::decode::<u32, WinSock::CMSGHDR>(cmsg) };
260                }
261                _ => {}
262            }
263        }
264
265        meta[0] = RecvMeta {
266            len: len as usize,
267            stride: stride as usize,
268            addr: addr.unwrap(),
269            ecn: EcnCodepoint::from_bits(ecn_bits as u8),
270            dst_ip,
271        };
272        Ok(1)
273    }
274
275    /// The maximum amount of segments which can be transmitted if a platform
276    /// supports Generic Send Offload (GSO).
277    ///
278    /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected
279    /// while using GSO.
280    #[inline]
281    pub fn max_gso_segments(&self) -> usize {
282        *MAX_GSO_SEGMENTS
283    }
284
285    /// The number of segments to read when GRO is enabled. Used as a factor to
286    /// compute the receive buffer size.
287    ///
288    /// Returns 1 if the platform doesn't support GRO.
289    #[inline]
290    pub fn gro_segments(&self) -> usize {
291        // Arbitrary reasonable value inspired by Linux and msquic
292        64
293    }
294
295    /// Resize the send buffer of `socket` to `bytes`
296    #[inline]
297    pub fn set_send_buffer_size(&self, socket: UdpSockRef<'_>, bytes: usize) -> io::Result<()> {
298        socket.0.set_send_buffer_size(bytes)
299    }
300
301    /// Resize the receive buffer of `socket` to `bytes`
302    #[inline]
303    pub fn set_recv_buffer_size(&self, socket: UdpSockRef<'_>, bytes: usize) -> io::Result<()> {
304        socket.0.set_recv_buffer_size(bytes)
305    }
306
307    /// Get the size of the `socket` send buffer
308    #[inline]
309    pub fn send_buffer_size(&self, socket: UdpSockRef<'_>) -> io::Result<usize> {
310        socket.0.send_buffer_size()
311    }
312
313    /// Get the size of the `socket` receive buffer
314    #[inline]
315    pub fn recv_buffer_size(&self, socket: UdpSockRef<'_>) -> io::Result<usize> {
316        socket.0.recv_buffer_size()
317    }
318
319    #[inline]
320    pub fn may_fragment(&self) -> bool {
321        false
322    }
323}
324
325fn send(socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
326    // we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access
327    // to the inner field which holds the WSAMSG
328    let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
329    let daddr = socket2::SockAddr::from(transmit.destination);
330
331    let mut data = WinSock::WSABUF {
332        buf: transmit.contents.as_ptr() as *mut _,
333        len: transmit.contents.len() as _,
334    };
335
336    let ctrl = WinSock::WSABUF {
337        buf: ctrl_buf.0.as_mut_ptr(),
338        len: ctrl_buf.0.len() as _,
339    };
340
341    let mut wsa_msg = WinSock::WSAMSG {
342        name: daddr.as_ptr() as *mut _,
343        namelen: daddr.len(),
344        lpBuffers: &mut data,
345        Control: ctrl,
346        dwBufferCount: 1,
347        dwFlags: 0,
348    };
349
350    // Add control messages (ECN and PKTINFO)
351    let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) };
352
353    if let Some(ip) = transmit.src_ip {
354        let ip = std::net::SocketAddr::new(ip, 0);
355        let ip = socket2::SockAddr::from(ip);
356        match ip.family() {
357            WinSock::AF_INET => {
358                let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN) };
359                let pktinfo = WinSock::IN_PKTINFO {
360                    ipi_addr: src_ip.sin_addr,
361                    ipi_ifindex: 0,
362                };
363                encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo);
364            }
365            WinSock::AF_INET6 => {
366                let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN6) };
367                let pktinfo = WinSock::IN6_PKTINFO {
368                    ipi6_addr: src_ip.sin6_addr,
369                    ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id },
370                };
371                encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo);
372            }
373            _ => {
374                return Err(io::Error::from(io::ErrorKind::InvalidInput));
375            }
376        }
377    }
378
379    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
380    let ecn = transmit.ecn.map_or(0, |x| x as c_int);
381    // True for IPv4 or IPv4-Mapped IPv6
382    let is_ipv4 = transmit.destination.is_ipv4()
383        || matches!(transmit.destination.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some());
384    if is_ipv4 {
385        encoder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn);
386    } else {
387        encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn);
388    }
389
390    // Segment size is a u32 https://learn.microsoft.com/en-us/windows/win32/api/ws2tcpip/nf-ws2tcpip-wsasetudpsendmessagesize
391    if let Some(segment_size) = transmit.segment_size {
392        encoder.push(
393            WinSock::IPPROTO_UDP,
394            WinSock::UDP_SEND_MSG_SIZE,
395            segment_size as u32,
396        );
397    }
398
399    encoder.finish();
400
401    let mut len = 0;
402    let rc = unsafe {
403        WinSock::WSASendMsg(
404            socket.0.as_raw_socket() as usize,
405            &wsa_msg,
406            0,
407            &mut len,
408            ptr::null_mut(),
409            None,
410        )
411    };
412
413    match rc {
414        0 => Ok(()),
415        _ => Err(io::Error::last_os_error()),
416    }
417}
418
419fn set_socket_option(
420    socket: &impl AsRawSocket,
421    level: i32,
422    name: i32,
423    value: u32,
424) -> io::Result<()> {
425    let rc = unsafe {
426        WinSock::setsockopt(
427            socket.as_raw_socket() as usize,
428            level,
429            name,
430            &value as *const _ as _,
431            mem::size_of_val(&value) as _,
432        )
433    };
434
435    match rc == 0 {
436        true => Ok(()),
437        false => Err(io::Error::last_os_error()),
438    }
439}
440
/// Maximum number of datagrams exchanged per call on this platform; `recv` fills exactly one
/// `RecvMeta`. Presumably consumed by the crate's cross-platform code — confirm at call sites.
pub(crate) const BATCH_SIZE: usize = 1;

/// Size in bytes of the control-message buffers used by `send` and `recv`.
///
/// It must hold, header plus data for each message:
/// - `max(IP_PKTINFO + IP_ECN, IPV6_PKTINFO + IPV6_ECN)`, plus
/// - `max(UDP_SEND_MSG_SIZE, UDP_COALESCED_INFO)`,
/// with some extra margin on top. `UdpSocketState::new` asserts the required minimum.
const CMSG_LEN: usize = 128;

/// Value passed to `setsockopt` to switch a boolean socket option on.
const OPTION_ON: u32 = 1;
445
446// FIXME this could use [`std::sync::OnceLock`] once the MSRV is bumped to 1.70 and upper
447static WSARECVMSG_PTR: Lazy<WinSock::LPFN_WSARECVMSG> = Lazy::new(|| {
448    let s = unsafe { WinSock::socket(WinSock::AF_INET as _, WinSock::SOCK_DGRAM as _, 0) };
449    if s == WinSock::INVALID_SOCKET {
450        debug!(
451            "ignoring WSARecvMsg function pointer due to socket creation error: {}",
452            io::Error::last_os_error()
453        );
454        return None;
455    }
456
457    // Detect if OS expose WSARecvMsg API based on
458    // https://github.com/Azure/mio-uds-windows/blob/a3c97df82018086add96d8821edb4aa85ec1b42b/src/stdnet/ext.rs#L601
459    let guid = WinSock::WSAID_WSARECVMSG;
460    let mut wsa_recvmsg_ptr = None;
461    let mut len = 0;
462
463    // Safety: Option handles the NULL pointer with a None value
464    let rc = unsafe {
465        WinSock::WSAIoctl(
466            s as _,
467            WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER,
468            &guid as *const _ as *const _,
469            mem::size_of_val(&guid) as u32,
470            &mut wsa_recvmsg_ptr as *mut _ as *mut _,
471            mem::size_of_val(&wsa_recvmsg_ptr) as u32,
472            &mut len,
473            ptr::null_mut(),
474            None,
475        )
476    };
477
478    if rc == -1 {
479        debug!(
480            "ignoring WSARecvMsg function pointer due to ioctl error: {}",
481            io::Error::last_os_error()
482        );
483    } else if len as usize != mem::size_of::<WinSock::LPFN_WSARECVMSG>() {
484        debug!("ignoring WSARecvMsg function pointer due to pointer size mismatch");
485        wsa_recvmsg_ptr = None;
486    }
487
488    unsafe {
489        WinSock::closesocket(s);
490    }
491
492    wsa_recvmsg_ptr
493});
494
495static MAX_GSO_SEGMENTS: Lazy<usize> = Lazy::new(|| {
496    let socket = match std::net::UdpSocket::bind("[::]:0")
497        .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
498    {
499        Ok(socket) => socket,
500        Err(_) => return 1,
501    };
502    const GSO_SIZE: c_uint = 1500;
503    match set_socket_option(
504        &socket,
505        WinSock::IPPROTO_UDP,
506        WinSock::UDP_SEND_MSG_SIZE,
507        GSO_SIZE,
508    ) {
509        // Empirically found on Windows 11 x64
510        Ok(()) => 512,
511        Err(_) => 1,
512    }
513});