1use std::io;
14use std::net::SocketAddr;
15
/// Upper bound on how many datagrams a [`UdpSendBatch`] will queue and how many a
/// [`UdpRecvBatch`] will receive in a single call.
pub const MAX_BATCH_SIZE: usize = 64;
22
/// An outgoing batch of UDP datagrams, each paired with its destination.
///
/// The payload at `packets[i]` is addressed to `addresses[i]`; the two vectors are
/// always pushed, drained and cleared together, so they keep the same length.
#[derive(Debug)]
pub struct UdpSendBatch {
    /// Datagram payloads in send order.
    packets: Vec<Vec<u8>>,
    /// Destination of the payload at the same index in `packets`.
    addresses: Vec<SocketAddr>,
}
34
35impl UdpSendBatch {
36 pub fn new() -> Self {
38 Self {
39 packets: Vec::with_capacity(MAX_BATCH_SIZE),
40 addresses: Vec::with_capacity(MAX_BATCH_SIZE),
41 }
42 }
43
44 pub fn with_capacity(capacity: usize) -> Self {
46 Self {
47 packets: Vec::with_capacity(capacity),
48 addresses: Vec::with_capacity(capacity),
49 }
50 }
51
52 pub fn add(&mut self, packet: Vec<u8>, addr: SocketAddr) -> bool {
56 if self.packets.len() >= MAX_BATCH_SIZE {
57 return false;
58 }
59 self.packets.push(packet);
60 self.addresses.push(addr);
61 true
62 }
63
64 pub fn len(&self) -> usize {
66 self.packets.len()
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.packets.is_empty()
72 }
73
74 pub fn is_full(&self) -> bool {
76 self.packets.len() >= MAX_BATCH_SIZE
77 }
78
79 pub fn clear(&mut self) {
81 self.packets.clear();
82 self.addresses.clear();
83 }
84
85 #[cfg(target_os = "linux")]
92 pub async fn send(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
93 if self.is_empty() {
94 return Ok((0, 0));
95 }
96
97 self.send_mmsg(socket).await
99 }
100
101 #[cfg(not(target_os = "linux"))]
103 pub async fn send(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
104 if self.is_empty() {
105 return Ok((0, 0));
106 }
107
108 let mut total_bytes = 0;
109 let mut packets_sent = 0;
110
111 for (packet, addr) in self.packets.iter().zip(self.addresses.iter()) {
112 match socket.send_to(packet, addr).await {
113 Ok(n) => {
114 total_bytes += n;
115 packets_sent += 1;
116 }
117 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
118 break;
120 }
121 Err(e) => return Err(e),
122 }
123 }
124
125 self.packets.drain(..packets_sent);
127 self.addresses.drain(..packets_sent);
128
129 Ok((total_bytes, packets_sent))
130 }
131
132 #[cfg(target_os = "linux")]
134 async fn send_mmsg(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
135 use std::os::unix::io::AsRawFd;
136
137 if self.is_empty() {
138 return Ok((0, 0));
139 }
140
141 let fd = socket.as_raw_fd();
142 let packets = &self.packets;
143 let addresses = &self.addresses;
144
145 let result = send_mmsg_sync(fd, packets, addresses)?;
147
148 if result.1 > 0 {
150 self.packets.drain(..result.1);
151 self.addresses.drain(..result.1);
152 }
153
154 Ok(result)
155 }
156}
157
158#[cfg(target_os = "linux")]
160fn send_mmsg_sync(
161 fd: std::os::unix::io::RawFd,
162 packets: &[Vec<u8>],
163 addresses: &[SocketAddr],
164) -> io::Result<(usize, usize)> {
165 use libc::{
166 iovec, mmsghdr, sendmmsg, sockaddr_in, sockaddr_in6, sockaddr_storage, AF_INET, AF_INET6,
167 MSG_DONTWAIT,
168 };
169 use std::mem;
170
171 let count = packets.len();
172
173 let mut msgvec: Vec<mmsghdr> = Vec::with_capacity(count);
175 let mut iovecs: Vec<iovec> = Vec::with_capacity(count);
176 let mut addrs: Vec<sockaddr_storage> = Vec::with_capacity(count);
177
178 for (packet, addr) in packets.iter().zip(addresses.iter()) {
179 let iov = iovec {
181 iov_base: packet.as_ptr() as *mut _,
182 iov_len: packet.len(),
183 };
184 iovecs.push(iov);
185
186 let mut storage: sockaddr_storage = unsafe { mem::zeroed() };
188 let addr_len = match addr {
189 SocketAddr::V4(v4) => {
190 let sin = sockaddr_in {
191 sin_family: AF_INET as u16,
192 sin_port: v4.port().to_be(),
193 sin_addr: libc::in_addr {
194 s_addr: u32::from_ne_bytes(v4.ip().octets()),
195 },
196 sin_zero: [0; 8],
197 };
198 unsafe {
199 std::ptr::copy_nonoverlapping(
200 &sin as *const _ as *const u8,
201 &mut storage as *mut _ as *mut u8,
202 mem::size_of::<sockaddr_in>(),
203 );
204 }
205 mem::size_of::<sockaddr_in>() as u32
206 }
207 SocketAddr::V6(v6) => {
208 let sin6 = sockaddr_in6 {
209 sin6_family: AF_INET6 as u16,
210 sin6_port: v6.port().to_be(),
211 sin6_flowinfo: 0,
212 sin6_addr: libc::in6_addr {
213 s6_addr: v6.ip().octets(),
214 },
215 sin6_scope_id: 0,
216 };
217 unsafe {
218 std::ptr::copy_nonoverlapping(
219 &sin6 as *const _ as *const u8,
220 &mut storage as *mut _ as *mut u8,
221 mem::size_of::<sockaddr_in6>(),
222 );
223 }
224 mem::size_of::<sockaddr_in6>() as u32
225 }
226 };
227 addrs.push(storage);
228
229 let mut hdr: mmsghdr = unsafe { mem::zeroed() };
231 hdr.msg_hdr.msg_name = addrs.last_mut().unwrap() as *mut _ as *mut _;
232 hdr.msg_hdr.msg_namelen = addr_len;
233 hdr.msg_hdr.msg_iov = iovecs.last_mut().unwrap() as *mut _;
234 hdr.msg_hdr.msg_iovlen = 1;
235 msgvec.push(hdr);
236 }
237
238 #[cfg(target_env = "musl")]
241 let ret = unsafe { sendmmsg(fd, msgvec.as_mut_ptr(), count as u32, MSG_DONTWAIT as u32) };
242 #[cfg(not(target_env = "musl"))]
243 let ret = unsafe { sendmmsg(fd, msgvec.as_mut_ptr(), count as u32, MSG_DONTWAIT) };
244
245 if ret < 0 {
246 let err = io::Error::last_os_error();
247 if err.kind() == io::ErrorKind::WouldBlock {
249 return Ok((0, 0));
250 }
251 return Err(err);
252 }
253
254 let packets_sent = ret as usize;
256 let total_bytes = msgvec
257 .iter()
258 .take(packets_sent)
259 .map(|msg| msg.msg_len as usize)
260 .sum();
261
262 Ok((total_bytes, packets_sent))
263}
264
265impl Default for UdpSendBatch {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
/// A reusable set of receive slots for pulling several UDP datagrams per call.
///
/// Only the first `count` entries of `packets` and `addresses` hold data from the most
/// recent receive; the remaining slots are scratch space.
#[derive(Debug)]
pub struct UdpRecvBatch {
    /// One receive buffer per slot; filled slots are truncated to the datagram length.
    packets: Vec<Vec<u8>>,
    /// Source address of the datagram in the slot with the same index.
    addresses: Vec<SocketAddr>,
    /// Number of slots filled by the last receive.
    count: usize,
}
284
285impl UdpRecvBatch {
286 pub fn new() -> Self {
288 let mut packets = Vec::with_capacity(MAX_BATCH_SIZE);
289 for _ in 0..MAX_BATCH_SIZE {
290 packets.push(vec![0u8; 65536]); }
292
293 Self {
294 packets,
295 addresses: vec![SocketAddr::from(([0, 0, 0, 0], 0)); MAX_BATCH_SIZE],
296 count: 0,
297 }
298 }
299
300 #[cfg(target_os = "linux")]
307 pub async fn recv(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
308 self.recv_mmsg(socket).await
309 }
310
311 #[cfg(not(target_os = "linux"))]
313 pub async fn recv(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
314 self.count = 0;
315
316 for i in 0..MAX_BATCH_SIZE {
318 match socket.try_recv_from(&mut self.packets[i]) {
319 Ok((n, addr)) => {
320 self.packets[i].truncate(n);
321 self.addresses[i] = addr;
322 self.count += 1;
323 }
324 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
325 break;
327 }
328 Err(e) => return Err(e),
329 }
330 }
331
332 if self.count == 0 {
334 match socket.recv_from(&mut self.packets[0]).await {
335 Ok((n, addr)) => {
336 self.packets[0].truncate(n);
337 self.addresses[0] = addr;
338 self.count = 1;
339 }
340 Err(e) => return Err(e),
341 }
342 }
343
344 Ok(self.count)
345 }
346
347 #[cfg(target_os = "linux")]
349 async fn recv_mmsg(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
350 use std::os::unix::io::AsRawFd;
351
352 let fd = socket.as_raw_fd();
353
354 for packet in self.packets.iter_mut() {
356 packet.resize(65536, 0);
357 }
358
359 let count = match recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false) {
361 Ok(count) if count > 0 => count,
362 Ok(0) => {
363 socket.readable().await?;
365 recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false)?
367 }
368 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
369 socket.readable().await?;
371 recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false)?
373 }
374 Ok(count) => count,
375 Err(e) => return Err(e),
376 };
377
378 self.count = count;
379 Ok(count)
380 }
381
382 pub fn len(&self) -> usize {
384 self.count
385 }
386
387 pub fn is_empty(&self) -> bool {
389 self.count == 0
390 }
391
392 pub fn get(&self, index: usize) -> Option<(&[u8], SocketAddr)> {
394 if index < self.count {
395 Some((&self.packets[index], self.addresses[index]))
396 } else {
397 None
398 }
399 }
400
401 pub fn iter(&self) -> impl Iterator<Item = (&[u8], SocketAddr)> {
403 self.packets[..self.count]
404 .iter()
405 .zip(self.addresses[..self.count].iter())
406 .map(|(p, a)| (p.as_slice(), *a))
407 }
408}
409
410impl Default for UdpRecvBatch {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416#[cfg(target_os = "linux")]
418fn sockaddr_to_socketaddr(storage: &libc::sockaddr_storage, _len: u32) -> io::Result<SocketAddr> {
419 use libc::{AF_INET, AF_INET6};
420 use std::net::{Ipv4Addr, Ipv6Addr};
421
422 unsafe {
423 match storage.ss_family as i32 {
424 AF_INET => {
425 let sin: *const libc::sockaddr_in = storage as *const _ as *const _;
426 let addr = Ipv4Addr::from(u32::from_be((*sin).sin_addr.s_addr).to_ne_bytes());
427 let port = u16::from_be((*sin).sin_port);
428 Ok(SocketAddr::from((addr, port)))
429 }
430 AF_INET6 => {
431 let sin6: *const libc::sockaddr_in6 = storage as *const _ as *const _;
432 let addr = Ipv6Addr::from((*sin6).sin6_addr.s6_addr);
433 let port = u16::from_be((*sin6).sin6_port);
434 Ok(SocketAddr::from((addr, port)))
435 }
436 _ => Err(io::Error::new(
437 io::ErrorKind::InvalidInput,
438 "Unsupported address family",
439 )),
440 }
441 }
442}
443
444#[cfg(target_os = "linux")]
446fn recv_mmsg_sync(
447 fd: std::os::unix::io::RawFd,
448 packets: &mut [Vec<u8>],
449 addresses: &mut [SocketAddr],
450 _blocking: bool,
451) -> io::Result<usize> {
452 use libc::{iovec, mmsghdr, recvmmsg, sockaddr_storage, MSG_DONTWAIT};
453 use std::mem;
454
455 let count = packets.len().min(MAX_BATCH_SIZE);
456
457 let mut msgvec: Vec<mmsghdr> = Vec::with_capacity(count);
459 let mut iovecs: Vec<iovec> = Vec::with_capacity(count);
460 let mut addrs: Vec<sockaddr_storage> = Vec::with_capacity(count);
461
462 for packet in packets.iter_mut().take(count) {
463 let iov = iovec {
465 iov_base: packet.as_mut_ptr() as *mut _,
466 iov_len: packet.len(),
467 };
468 iovecs.push(iov);
469
470 let storage: sockaddr_storage = unsafe { mem::zeroed() };
472 addrs.push(storage);
473
474 let mut hdr: mmsghdr = unsafe { mem::zeroed() };
476 hdr.msg_hdr.msg_name = addrs.last_mut().unwrap() as *mut _ as *mut _;
477 hdr.msg_hdr.msg_namelen = mem::size_of::<sockaddr_storage>() as u32;
478 hdr.msg_hdr.msg_iov = iovecs.last_mut().unwrap() as *mut _;
479 hdr.msg_hdr.msg_iovlen = 1;
480 msgvec.push(hdr);
481 }
482
483 #[cfg(target_env = "musl")]
486 let ret = unsafe {
487 recvmmsg(
488 fd,
489 msgvec.as_mut_ptr(),
490 count as u32,
491 MSG_DONTWAIT as u32,
492 std::ptr::null_mut(),
493 )
494 };
495 #[cfg(not(target_env = "musl"))]
496 let ret = unsafe {
497 recvmmsg(
498 fd,
499 msgvec.as_mut_ptr(),
500 count as u32,
501 MSG_DONTWAIT,
502 std::ptr::null_mut(),
503 )
504 };
505
506 if ret < 0 {
507 return Err(io::Error::last_os_error());
508 }
509
510 let received_count = ret as usize;
511
512 for (i, msg) in msgvec.iter().enumerate().take(received_count) {
514 let bytes_received = msg.msg_len as usize;
515 packets[i].truncate(bytes_received);
516
517 addresses[i] = sockaddr_to_socketaddr(&addrs[i], msg.msg_hdr.msg_namelen)?;
519 }
520
521 Ok(received_count)
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback destination shared by every test datagram.
    fn loopback() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 5000))
    }

    #[test]
    fn test_batch_capacity() {
        let mut batch = UdpSendBatch::new();
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
        assert!(!batch.is_full());

        // Every slot up to the limit is accepted...
        for fill in (0..MAX_BATCH_SIZE).map(|i| i as u8) {
            assert!(batch.add(vec![fill; 100], loopback()));
        }

        assert_eq!(batch.len(), MAX_BATCH_SIZE);
        assert!(!batch.is_empty());
        assert!(batch.is_full());

        // ...and one more is refused.
        let overflow = vec![0u8; 100];
        assert!(!batch.add(overflow, loopback()));
    }

    #[test]
    fn test_batch_clear() {
        let mut batch = UdpSendBatch::new();

        for fill in 0..10u8 {
            batch.add(vec![fill; 100], loopback());
        }
        assert_eq!(batch.len(), 10);

        batch.clear();
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
    }
}