1use std::{
2 io::{self, IoSliceMut},
3 mem,
4 net::{IpAddr, Ipv4Addr},
5 os::windows::io::AsRawSocket,
6 ptr,
7 sync::Mutex,
8 time::Instant,
9};
10
11use libc::{c_int, c_uint};
12use once_cell::sync::Lazy;
13use windows_sys::Win32::Networking::WinSock;
14
15use crate::{
16 EcnCodepoint, IO_ERROR_LOG_INTERVAL, RecvMeta, Transmit, UdpSockRef,
17 cmsg::{self, CMsgHdr},
18 log::debug,
19 log_sendmsg_error,
20};
21
/// Per-socket state used by the Windows UDP send/receive paths in this file.
#[derive(Debug)]
pub struct UdpSocketState {
    /// Timestamp handed to `log_sendmsg_error` whenever a send fails.
    ///
    /// `new` initialises it to two `IO_ERROR_LOG_INTERVAL`s in the past (or to "now" if that
    /// subtraction is not representable).
    last_send_error: Mutex<Instant>,
}
29
30impl UdpSocketState {
31 pub fn new(socket: UdpSockRef<'_>) -> io::Result<Self> {
32 assert!(
33 CMSG_LEN
34 >= WinSock::CMSGHDR::cmsg_space(mem::size_of::<WinSock::IN6_PKTINFO>())
35 + WinSock::CMSGHDR::cmsg_space(mem::size_of::<c_int>())
36 + WinSock::CMSGHDR::cmsg_space(mem::size_of::<u32>())
37 );
38 assert!(
39 mem::align_of::<WinSock::CMSGHDR>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
40 "control message buffers will be misaligned"
41 );
42
43 socket.0.set_nonblocking(true)?;
44 let addr = socket.0.local_addr()?;
45 let is_ipv6 = addr.as_socket_ipv6().is_some();
46 let v6only = unsafe {
47 let mut result: u32 = 0;
48 let mut len = mem::size_of_val(&result) as i32;
49 let rc = WinSock::getsockopt(
50 socket.0.as_raw_socket() as _,
51 WinSock::IPPROTO_IPV6,
52 WinSock::IPV6_V6ONLY as _,
53 &mut result as *mut _ as _,
54 &mut len,
55 );
56 if rc == -1 {
57 return Err(io::Error::last_os_error());
58 }
59 result != 0
60 };
61 let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only;
62
63 if WSARECVMSG_PTR.is_none() {
65 return Err(io::Error::new(
66 io::ErrorKind::Unsupported,
67 "network stack does not support WSARecvMsg function",
68 ));
69 }
70
71 if is_ipv4 {
72 set_socket_option(
73 &*socket.0,
74 WinSock::IPPROTO_IP,
75 WinSock::IP_DONTFRAGMENT,
76 OPTION_ON,
77 )?;
78
79 set_socket_option(
80 &*socket.0,
81 WinSock::IPPROTO_IP,
82 WinSock::IP_PKTINFO,
83 OPTION_ON,
84 )?;
85 set_socket_option(
86 &*socket.0,
87 WinSock::IPPROTO_IP,
88 WinSock::IP_RECVECN,
89 OPTION_ON,
90 )?;
91 }
92
93 if is_ipv6 {
94 set_socket_option(
95 &*socket.0,
96 WinSock::IPPROTO_IPV6,
97 WinSock::IPV6_DONTFRAG,
98 OPTION_ON,
99 )?;
100
101 set_socket_option(
102 &*socket.0,
103 WinSock::IPPROTO_IPV6,
104 WinSock::IPV6_PKTINFO,
105 OPTION_ON,
106 )?;
107
108 set_socket_option(
109 &*socket.0,
110 WinSock::IPPROTO_IPV6,
111 WinSock::IPV6_RECVECN,
112 OPTION_ON,
113 )?;
114 }
115
116 let now = Instant::now();
117 Ok(Self {
118 last_send_error: Mutex::new(now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now)),
119 })
120 }
121
122 pub fn set_gro(&self, socket: UdpSockRef<'_>, enable: bool) -> io::Result<()> {
130 set_socket_option(
131 &*socket.0,
132 WinSock::IPPROTO_UDP,
133 WinSock::UDP_RECV_MAX_COALESCED_SIZE,
134 match enable {
135 true => u16::MAX as u32,
139 false => 0,
140 },
141 )
142 }
143
144 pub fn send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
156 match send(socket, transmit) {
157 Ok(()) => Ok(()),
158 Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
159 Err(e) => {
160 log_sendmsg_error(&self.last_send_error, e, transmit);
161
162 Ok(())
163 }
164 }
165 }
166
167 pub fn try_send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
169 send(socket, transmit)
170 }
171
172 pub fn recv(
173 &self,
174 socket: UdpSockRef<'_>,
175 bufs: &mut [IoSliceMut<'_>],
176 meta: &mut [RecvMeta],
177 ) -> io::Result<usize> {
178 let wsa_recvmsg_ptr = WSARECVMSG_PTR.expect("valid function pointer for WSARecvMsg");
179
180 let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
182 let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
183 let mut data = WinSock::WSABUF {
184 buf: bufs[0].as_mut_ptr(),
185 len: bufs[0].len() as _,
186 };
187
188 let ctrl = WinSock::WSABUF {
189 buf: ctrl_buf.0.as_mut_ptr(),
190 len: ctrl_buf.0.len() as _,
191 };
192
193 let mut wsa_msg = WinSock::WSAMSG {
194 name: &mut source as *mut _ as *mut _,
195 namelen: mem::size_of_val(&source) as _,
196 lpBuffers: &mut data,
197 Control: ctrl,
198 dwBufferCount: 1,
199 dwFlags: 0,
200 };
201
202 let mut len = 0;
203 unsafe {
204 let rc = (wsa_recvmsg_ptr)(
205 socket.0.as_raw_socket() as usize,
206 &mut wsa_msg,
207 &mut len,
208 ptr::null_mut(),
209 None,
210 );
211 if rc == -1 {
212 return Err(io::Error::last_os_error());
213 }
214 }
215
216 let addr = unsafe {
217 let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| {
218 *len = mem::size_of_val(&source) as _;
219 ptr::copy_nonoverlapping(&source, addr_storage as _, 1);
220 Ok(())
221 })?;
222 addr.as_socket()
223 };
224
225 let mut ecn_bits = 0;
227 let mut dst_ip = None;
228 let mut stride = len;
229
230 let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) };
231 for cmsg in cmsg_iter {
232 const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32;
233 match (cmsg.cmsg_level, cmsg.cmsg_type) {
235 (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
236 let pktinfo =
237 unsafe { cmsg::decode::<WinSock::IN_PKTINFO, WinSock::CMSGHDR>(cmsg) };
238 let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
240 dst_ip = Some(ip4.into());
241 }
242 (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
243 let pktinfo =
244 unsafe { cmsg::decode::<WinSock::IN6_PKTINFO, WinSock::CMSGHDR>(cmsg) };
245 dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
247 }
248 (WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
249 ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
251 }
252 (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
253 ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
255 }
256 (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => {
257 stride = unsafe { cmsg::decode::<u32, WinSock::CMSGHDR>(cmsg) };
260 }
261 _ => {}
262 }
263 }
264
265 meta[0] = RecvMeta {
266 len: len as usize,
267 stride: stride as usize,
268 addr: addr.unwrap(),
269 ecn: EcnCodepoint::from_bits(ecn_bits as u8),
270 dst_ip,
271 };
272 Ok(1)
273 }
274
275 #[inline]
281 pub fn max_gso_segments(&self) -> usize {
282 *MAX_GSO_SEGMENTS
283 }
284
285 #[inline]
290 pub fn gro_segments(&self) -> usize {
291 64
293 }
294
295 #[inline]
297 pub fn set_send_buffer_size(&self, socket: UdpSockRef<'_>, bytes: usize) -> io::Result<()> {
298 socket.0.set_send_buffer_size(bytes)
299 }
300
301 #[inline]
303 pub fn set_recv_buffer_size(&self, socket: UdpSockRef<'_>, bytes: usize) -> io::Result<()> {
304 socket.0.set_recv_buffer_size(bytes)
305 }
306
307 #[inline]
309 pub fn send_buffer_size(&self, socket: UdpSockRef<'_>) -> io::Result<usize> {
310 socket.0.send_buffer_size()
311 }
312
313 #[inline]
315 pub fn recv_buffer_size(&self, socket: UdpSockRef<'_>) -> io::Result<usize> {
316 socket.0.recv_buffer_size()
317 }
318
319 #[inline]
320 pub fn may_fragment(&self) -> bool {
321 false
322 }
323}
324
325fn send(socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
326 let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
329 let daddr = socket2::SockAddr::from(transmit.destination);
330
331 let mut data = WinSock::WSABUF {
332 buf: transmit.contents.as_ptr() as *mut _,
333 len: transmit.contents.len() as _,
334 };
335
336 let ctrl = WinSock::WSABUF {
337 buf: ctrl_buf.0.as_mut_ptr(),
338 len: ctrl_buf.0.len() as _,
339 };
340
341 let mut wsa_msg = WinSock::WSAMSG {
342 name: daddr.as_ptr() as *mut _,
343 namelen: daddr.len(),
344 lpBuffers: &mut data,
345 Control: ctrl,
346 dwBufferCount: 1,
347 dwFlags: 0,
348 };
349
350 let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) };
352
353 if let Some(ip) = transmit.src_ip {
354 let ip = std::net::SocketAddr::new(ip, 0);
355 let ip = socket2::SockAddr::from(ip);
356 match ip.family() {
357 WinSock::AF_INET => {
358 let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN) };
359 let pktinfo = WinSock::IN_PKTINFO {
360 ipi_addr: src_ip.sin_addr,
361 ipi_ifindex: 0,
362 };
363 encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo);
364 }
365 WinSock::AF_INET6 => {
366 let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN6) };
367 let pktinfo = WinSock::IN6_PKTINFO {
368 ipi6_addr: src_ip.sin6_addr,
369 ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id },
370 };
371 encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo);
372 }
373 _ => {
374 return Err(io::Error::from(io::ErrorKind::InvalidInput));
375 }
376 }
377 }
378
379 let ecn = transmit.ecn.map_or(0, |x| x as c_int);
381 let is_ipv4 = transmit.destination.is_ipv4()
383 || matches!(transmit.destination.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some());
384 if is_ipv4 {
385 encoder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn);
386 } else {
387 encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn);
388 }
389
390 if let Some(segment_size) = transmit.segment_size {
392 encoder.push(
393 WinSock::IPPROTO_UDP,
394 WinSock::UDP_SEND_MSG_SIZE,
395 segment_size as u32,
396 );
397 }
398
399 encoder.finish();
400
401 let mut len = 0;
402 let rc = unsafe {
403 WinSock::WSASendMsg(
404 socket.0.as_raw_socket() as usize,
405 &wsa_msg,
406 0,
407 &mut len,
408 ptr::null_mut(),
409 None,
410 )
411 };
412
413 match rc {
414 0 => Ok(()),
415 _ => Err(io::Error::last_os_error()),
416 }
417}
418
419fn set_socket_option(
420 socket: &impl AsRawSocket,
421 level: i32,
422 name: i32,
423 value: u32,
424) -> io::Result<()> {
425 let rc = unsafe {
426 WinSock::setsockopt(
427 socket.as_raw_socket() as usize,
428 level,
429 name,
430 &value as *const _ as _,
431 mem::size_of_val(&value) as _,
432 )
433 };
434
435 match rc == 0 {
436 true => Ok(()),
437 false => Err(io::Error::last_os_error()),
438 }
439}
440
/// Batch size exposed to the rest of the crate; `recv` in this file fills exactly one
/// buffer per call.
pub(crate) const BATCH_SIZE: usize = 1;
/// Size in bytes of the control-message buffers used by `send` and `recv`.
/// `UdpSocketState::new` asserts that it is large enough.
const CMSG_LEN: usize = 128;
/// Value passed to `set_socket_option` to switch an option on.
const OPTION_ON: u32 = 1;
445
446static WSARECVMSG_PTR: Lazy<WinSock::LPFN_WSARECVMSG> = Lazy::new(|| {
448 let s = unsafe { WinSock::socket(WinSock::AF_INET as _, WinSock::SOCK_DGRAM as _, 0) };
449 if s == WinSock::INVALID_SOCKET {
450 debug!(
451 "ignoring WSARecvMsg function pointer due to socket creation error: {}",
452 io::Error::last_os_error()
453 );
454 return None;
455 }
456
457 let guid = WinSock::WSAID_WSARECVMSG;
460 let mut wsa_recvmsg_ptr = None;
461 let mut len = 0;
462
463 let rc = unsafe {
465 WinSock::WSAIoctl(
466 s as _,
467 WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER,
468 &guid as *const _ as *const _,
469 mem::size_of_val(&guid) as u32,
470 &mut wsa_recvmsg_ptr as *mut _ as *mut _,
471 mem::size_of_val(&wsa_recvmsg_ptr) as u32,
472 &mut len,
473 ptr::null_mut(),
474 None,
475 )
476 };
477
478 if rc == -1 {
479 debug!(
480 "ignoring WSARecvMsg function pointer due to ioctl error: {}",
481 io::Error::last_os_error()
482 );
483 } else if len as usize != mem::size_of::<WinSock::LPFN_WSARECVMSG>() {
484 debug!("ignoring WSARecvMsg function pointer due to pointer size mismatch");
485 wsa_recvmsg_ptr = None;
486 }
487
488 unsafe {
489 WinSock::closesocket(s);
490 }
491
492 wsa_recvmsg_ptr
493});
494
495static MAX_GSO_SEGMENTS: Lazy<usize> = Lazy::new(|| {
496 let socket = match std::net::UdpSocket::bind("[::]:0")
497 .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
498 {
499 Ok(socket) => socket,
500 Err(_) => return 1,
501 };
502 const GSO_SIZE: c_uint = 1500;
503 match set_socket_option(
504 &socket,
505 WinSock::IPPROTO_UDP,
506 WinSock::UDP_SEND_MSG_SIZE,
507 GSO_SIZE,
508 ) {
509 Ok(()) => 512,
511 Err(_) => 1,
512 }
513});