rperf3/
batch_socket.rs

//! Batch socket operations for improved performance.
//!
//! This module provides high-performance batch socket operations using
//! platform-specific APIs like `sendmmsg` and `recvmmsg` on Linux.
//! On platforms without native batch support, it falls back to standard
//! socket operations.
//!
//! # Performance
//!
//! Batch operations can improve UDP throughput by 30-50% at high packet
//! rates by reducing system call overhead.
12
13use std::io;
14use std::net::SocketAddr;
15
/// Maximum number of messages to batch in a single operation.
///
/// This value balances throughput gains against latency. Too high
/// and we introduce unnecessary latency, too low and we don't get
/// the full benefit of batching.
///
/// `UdpSendBatch::add` refuses further packets once a batch holds this many,
/// and `UdpRecvBatch::new` pre-allocates exactly this many receive buffers.
pub const MAX_BATCH_SIZE: usize = 64;
22
/// A batch of UDP packets ready to send.
///
/// This structure holds multiple packets that can be sent in a single
/// `sendmmsg` system call on Linux, or sent individually on other platforms.
///
/// `packets` and `addresses` are parallel vectors: the entry at index `i` of
/// `addresses` is the destination of the entry at index `i` of `packets`.
#[derive(Debug)]
pub struct UdpSendBatch {
    /// The packets to send
    packets: Vec<Vec<u8>>,
    /// Target addresses for each packet
    addresses: Vec<SocketAddr>,
}
34
35impl UdpSendBatch {
36    /// Creates a new empty batch.
37    pub fn new() -> Self {
38        Self {
39            packets: Vec::with_capacity(MAX_BATCH_SIZE),
40            addresses: Vec::with_capacity(MAX_BATCH_SIZE),
41        }
42    }
43
44    /// Creates a new batch with the specified capacity.
45    pub fn with_capacity(capacity: usize) -> Self {
46        Self {
47            packets: Vec::with_capacity(capacity),
48            addresses: Vec::with_capacity(capacity),
49        }
50    }
51
52    /// Adds a packet to the batch.
53    ///
54    /// Returns `true` if the packet was added, `false` if the batch is full.
55    pub fn add(&mut self, packet: Vec<u8>, addr: SocketAddr) -> bool {
56        if self.packets.len() >= MAX_BATCH_SIZE {
57            return false;
58        }
59        self.packets.push(packet);
60        self.addresses.push(addr);
61        true
62    }
63
64    /// Returns the number of packets in the batch.
65    pub fn len(&self) -> usize {
66        self.packets.len()
67    }
68
69    /// Returns `true` if the batch is empty.
70    pub fn is_empty(&self) -> bool {
71        self.packets.is_empty()
72    }
73
74    /// Returns `true` if the batch is full.
75    pub fn is_full(&self) -> bool {
76        self.packets.len() >= MAX_BATCH_SIZE
77    }
78
79    /// Clears the batch, removing all packets.
80    pub fn clear(&mut self) {
81        self.packets.clear();
82        self.addresses.clear();
83    }
84
85    /// Sends all packets in the batch using the most efficient method available.
86    ///
87    /// On Linux, uses `sendmmsg` for batched sending. On other platforms,
88    /// falls back to individual `send_to` calls.
89    ///
90    /// Returns the number of bytes sent and the number of packets successfully sent.
91    #[cfg(target_os = "linux")]
92    pub async fn send(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
93        if self.is_empty() {
94            return Ok((0, 0));
95        }
96
97        // Use sendmmsg for batch sending on Linux
98        self.send_mmsg(socket).await
99    }
100
101    /// Sends all packets in the batch using individual send_to calls.
102    #[cfg(not(target_os = "linux"))]
103    pub async fn send(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
104        if self.is_empty() {
105            return Ok((0, 0));
106        }
107
108        let mut total_bytes = 0;
109        let mut packets_sent = 0;
110
111        for (packet, addr) in self.packets.iter().zip(self.addresses.iter()) {
112            match socket.send_to(packet, addr).await {
113                Ok(n) => {
114                    total_bytes += n;
115                    packets_sent += 1;
116                }
117                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
118                    // Socket buffer full, stop sending
119                    break;
120                }
121                Err(e) => return Err(e),
122            }
123        }
124
125        // Remove sent packets from the batch
126        self.packets.drain(..packets_sent);
127        self.addresses.drain(..packets_sent);
128
129        Ok((total_bytes, packets_sent))
130    }
131
132    /// Linux-specific implementation using sendmmsg.
133    #[cfg(target_os = "linux")]
134    async fn send_mmsg(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<(usize, usize)> {
135        use std::os::unix::io::AsRawFd;
136
137        if self.is_empty() {
138            return Ok((0, 0));
139        }
140
141        let fd = socket.as_raw_fd();
142        let packets = &self.packets;
143        let addresses = &self.addresses;
144
145        // Call the synchronous helper that does all the unsafe work
146        let result = send_mmsg_sync(fd, packets, addresses)?;
147
148        // Remove sent packets from the batch
149        if result.1 > 0 {
150            self.packets.drain(..result.1);
151            self.addresses.drain(..result.1);
152        }
153
154        Ok(result)
155    }
156}
157
158/// Synchronous helper for sendmmsg (Linux only)
159#[cfg(target_os = "linux")]
160fn send_mmsg_sync(
161    fd: std::os::unix::io::RawFd,
162    packets: &[Vec<u8>],
163    addresses: &[SocketAddr],
164) -> io::Result<(usize, usize)> {
165    use libc::{
166        iovec, mmsghdr, sendmmsg, sockaddr_in, sockaddr_in6, sockaddr_storage, AF_INET, AF_INET6,
167        MSG_DONTWAIT,
168    };
169    use std::mem;
170
171    let count = packets.len();
172
173    // Prepare mmsghdr structures
174    let mut msgvec: Vec<mmsghdr> = Vec::with_capacity(count);
175    let mut iovecs: Vec<iovec> = Vec::with_capacity(count);
176    let mut addrs: Vec<sockaddr_storage> = Vec::with_capacity(count);
177
178    for (packet, addr) in packets.iter().zip(addresses.iter()) {
179        // Prepare iovec for this packet
180        let iov = iovec {
181            iov_base: packet.as_ptr() as *mut _,
182            iov_len: packet.len(),
183        };
184        iovecs.push(iov);
185
186        // Convert SocketAddr to sockaddr_storage
187        let mut storage: sockaddr_storage = unsafe { mem::zeroed() };
188        let addr_len = match addr {
189            SocketAddr::V4(v4) => {
190                let sin = sockaddr_in {
191                    sin_family: AF_INET as u16,
192                    sin_port: v4.port().to_be(),
193                    sin_addr: libc::in_addr {
194                        s_addr: u32::from_ne_bytes(v4.ip().octets()),
195                    },
196                    sin_zero: [0; 8],
197                };
198                unsafe {
199                    std::ptr::copy_nonoverlapping(
200                        &sin as *const _ as *const u8,
201                        &mut storage as *mut _ as *mut u8,
202                        mem::size_of::<sockaddr_in>(),
203                    );
204                }
205                mem::size_of::<sockaddr_in>() as u32
206            }
207            SocketAddr::V6(v6) => {
208                let sin6 = sockaddr_in6 {
209                    sin6_family: AF_INET6 as u16,
210                    sin6_port: v6.port().to_be(),
211                    sin6_flowinfo: 0,
212                    sin6_addr: libc::in6_addr {
213                        s6_addr: v6.ip().octets(),
214                    },
215                    sin6_scope_id: 0,
216                };
217                unsafe {
218                    std::ptr::copy_nonoverlapping(
219                        &sin6 as *const _ as *const u8,
220                        &mut storage as *mut _ as *mut u8,
221                        mem::size_of::<sockaddr_in6>(),
222                    );
223                }
224                mem::size_of::<sockaddr_in6>() as u32
225            }
226        };
227        addrs.push(storage);
228
229        // Prepare mmsghdr
230        let mut hdr: mmsghdr = unsafe { mem::zeroed() };
231        hdr.msg_hdr.msg_name = addrs.last_mut().unwrap() as *mut _ as *mut _;
232        hdr.msg_hdr.msg_namelen = addr_len;
233        hdr.msg_hdr.msg_iov = iovecs.last_mut().unwrap() as *mut _;
234        hdr.msg_hdr.msg_iovlen = 1;
235        msgvec.push(hdr);
236    }
237
238    // Perform the sendmmsg operation - this is non-blocking (MSG_DONTWAIT)
239    // Note: MSG_DONTWAIT is i32 on gnu libc but u32 on musl libc
240    #[cfg(target_env = "musl")]
241    let ret = unsafe { sendmmsg(fd, msgvec.as_mut_ptr(), count as u32, MSG_DONTWAIT as u32) };
242    #[cfg(not(target_env = "musl"))]
243    let ret = unsafe { sendmmsg(fd, msgvec.as_mut_ptr(), count as u32, MSG_DONTWAIT) };
244
245    if ret < 0 {
246        let err = io::Error::last_os_error();
247        // If the socket would block, return what we can send (0)
248        if err.kind() == io::ErrorKind::WouldBlock {
249            return Ok((0, 0));
250        }
251        return Err(err);
252    }
253
254    // Calculate total bytes sent
255    let packets_sent = ret as usize;
256    let total_bytes = msgvec
257        .iter()
258        .take(packets_sent)
259        .map(|msg| msg.msg_len as usize)
260        .sum();
261
262    Ok((total_bytes, packets_sent))
263}
264
/// The default send batch is the same empty batch produced by `UdpSendBatch::new`.
impl Default for UdpSendBatch {
    fn default() -> Self {
        Self::new()
    }
}
270
/// A batch of UDP packets received.
///
/// This structure holds multiple received packets from a single
/// `recvmmsg` system call on Linux.
///
/// `packets` and `addresses` are parallel vectors; only the first `count`
/// entries of each describe packets from the most recent receive.
#[derive(Debug)]
pub struct UdpRecvBatch {
    /// The received packets
    packets: Vec<Vec<u8>>,
    /// Source addresses for each packet
    addresses: Vec<SocketAddr>,
    /// Number of valid packets in the batch
    count: usize,
}
284
285impl UdpRecvBatch {
286    /// Creates a new empty batch with pre-allocated buffers.
287    pub fn new() -> Self {
288        let mut packets = Vec::with_capacity(MAX_BATCH_SIZE);
289        for _ in 0..MAX_BATCH_SIZE {
290            packets.push(vec![0u8; 65536]); // Max UDP packet size
291        }
292
293        Self {
294            packets,
295            addresses: vec![SocketAddr::from(([0, 0, 0, 0], 0)); MAX_BATCH_SIZE],
296            count: 0,
297        }
298    }
299
300    /// Receives a batch of packets using the most efficient method available.
301    ///
302    /// On Linux, uses `recvmmsg` for batched receiving. On other platforms,
303    /// receives packets individually up to MAX_BATCH_SIZE.
304    ///
305    /// Returns the number of packets received.
306    #[cfg(target_os = "linux")]
307    pub async fn recv(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
308        self.recv_mmsg(socket).await
309    }
310
311    /// Receives packets using individual recv_from calls.
312    #[cfg(not(target_os = "linux"))]
313    pub async fn recv(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
314        self.count = 0;
315
316        // Try to receive multiple packets without blocking
317        for i in 0..MAX_BATCH_SIZE {
318            match socket.try_recv_from(&mut self.packets[i]) {
319                Ok((n, addr)) => {
320                    self.packets[i].truncate(n);
321                    self.addresses[i] = addr;
322                    self.count += 1;
323                }
324                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
325                    // No more packets available
326                    break;
327                }
328                Err(e) => return Err(e),
329            }
330        }
331
332        // If we got no packets with try_recv, do a blocking receive for at least one
333        if self.count == 0 {
334            match socket.recv_from(&mut self.packets[0]).await {
335                Ok((n, addr)) => {
336                    self.packets[0].truncate(n);
337                    self.addresses[0] = addr;
338                    self.count = 1;
339                }
340                Err(e) => return Err(e),
341            }
342        }
343
344        Ok(self.count)
345    }
346
347    /// Linux-specific implementation using recvmmsg.
348    #[cfg(target_os = "linux")]
349    async fn recv_mmsg(&mut self, socket: &tokio::net::UdpSocket) -> io::Result<usize> {
350        use std::os::unix::io::AsRawFd;
351
352        let fd = socket.as_raw_fd();
353
354        // Prepare buffers for receiving
355        for packet in self.packets.iter_mut() {
356            packet.resize(65536, 0);
357        }
358
359        // Try non-blocking receive first
360        let count = match recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false) {
361            Ok(count) if count > 0 => count,
362            Ok(0) => {
363                // No packets available, wait for socket to be readable
364                socket.readable().await?;
365                // Try again after socket is readable
366                recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false)?
367            }
368            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
369                // Wait for socket to be readable
370                socket.readable().await?;
371                // Try again after socket is readable
372                recv_mmsg_sync(fd, &mut self.packets, &mut self.addresses, false)?
373            }
374            Ok(count) => count,
375            Err(e) => return Err(e),
376        };
377
378        self.count = count;
379        Ok(count)
380    }
381
382    /// Returns the number of packets in the batch.
383    pub fn len(&self) -> usize {
384        self.count
385    }
386
387    /// Returns `true` if the batch is empty.
388    pub fn is_empty(&self) -> bool {
389        self.count == 0
390    }
391
392    /// Gets a reference to a packet and its source address by index.
393    pub fn get(&self, index: usize) -> Option<(&[u8], SocketAddr)> {
394        if index < self.count {
395            Some((&self.packets[index], self.addresses[index]))
396        } else {
397            None
398        }
399    }
400
401    /// Returns an iterator over the packets and their source addresses.
402    pub fn iter(&self) -> impl Iterator<Item = (&[u8], SocketAddr)> {
403        self.packets[..self.count]
404            .iter()
405            .zip(self.addresses[..self.count].iter())
406            .map(|(p, a)| (p.as_slice(), *a))
407    }
408}
409
/// The default receive batch is the same pre-allocated batch produced by `UdpRecvBatch::new`.
impl Default for UdpRecvBatch {
    fn default() -> Self {
        Self::new()
    }
}
415
416/// Converts a sockaddr_storage to a SocketAddr (Linux-specific).
417#[cfg(target_os = "linux")]
418fn sockaddr_to_socketaddr(storage: &libc::sockaddr_storage, _len: u32) -> io::Result<SocketAddr> {
419    use libc::{AF_INET, AF_INET6};
420    use std::net::{Ipv4Addr, Ipv6Addr};
421
422    unsafe {
423        match storage.ss_family as i32 {
424            AF_INET => {
425                let sin: *const libc::sockaddr_in = storage as *const _ as *const _;
426                let addr = Ipv4Addr::from(u32::from_be((*sin).sin_addr.s_addr).to_ne_bytes());
427                let port = u16::from_be((*sin).sin_port);
428                Ok(SocketAddr::from((addr, port)))
429            }
430            AF_INET6 => {
431                let sin6: *const libc::sockaddr_in6 = storage as *const _ as *const _;
432                let addr = Ipv6Addr::from((*sin6).sin6_addr.s6_addr);
433                let port = u16::from_be((*sin6).sin6_port);
434                Ok(SocketAddr::from((addr, port)))
435            }
436            _ => Err(io::Error::new(
437                io::ErrorKind::InvalidInput,
438                "Unsupported address family",
439            )),
440        }
441    }
442}
443
444/// Synchronous helper for recvmmsg (Linux only)
445#[cfg(target_os = "linux")]
446fn recv_mmsg_sync(
447    fd: std::os::unix::io::RawFd,
448    packets: &mut [Vec<u8>],
449    addresses: &mut [SocketAddr],
450    _blocking: bool,
451) -> io::Result<usize> {
452    use libc::{iovec, mmsghdr, recvmmsg, sockaddr_storage, MSG_DONTWAIT};
453    use std::mem;
454
455    let count = packets.len().min(MAX_BATCH_SIZE);
456
457    // Prepare mmsghdr structures
458    let mut msgvec: Vec<mmsghdr> = Vec::with_capacity(count);
459    let mut iovecs: Vec<iovec> = Vec::with_capacity(count);
460    let mut addrs: Vec<sockaddr_storage> = Vec::with_capacity(count);
461
462    for packet in packets.iter_mut().take(count) {
463        // Prepare iovec for this packet
464        let iov = iovec {
465            iov_base: packet.as_mut_ptr() as *mut _,
466            iov_len: packet.len(),
467        };
468        iovecs.push(iov);
469
470        // Prepare address storage
471        let storage: sockaddr_storage = unsafe { mem::zeroed() };
472        addrs.push(storage);
473
474        // Prepare mmsghdr
475        let mut hdr: mmsghdr = unsafe { mem::zeroed() };
476        hdr.msg_hdr.msg_name = addrs.last_mut().unwrap() as *mut _ as *mut _;
477        hdr.msg_hdr.msg_namelen = mem::size_of::<sockaddr_storage>() as u32;
478        hdr.msg_hdr.msg_iov = iovecs.last_mut().unwrap() as *mut _;
479        hdr.msg_hdr.msg_iovlen = 1;
480        msgvec.push(hdr);
481    }
482
483    // Perform the recvmmsg operation
484    // Note: MSG_DONTWAIT is i32 on gnu libc but u32 on musl libc
485    #[cfg(target_env = "musl")]
486    let ret = unsafe {
487        recvmmsg(
488            fd,
489            msgvec.as_mut_ptr(),
490            count as u32,
491            MSG_DONTWAIT as u32,
492            std::ptr::null_mut(),
493        )
494    };
495    #[cfg(not(target_env = "musl"))]
496    let ret = unsafe {
497        recvmmsg(
498            fd,
499            msgvec.as_mut_ptr(),
500            count as u32,
501            MSG_DONTWAIT,
502            std::ptr::null_mut(),
503        )
504    };
505
506    if ret < 0 {
507        return Err(io::Error::last_os_error());
508    }
509
510    let received_count = ret as usize;
511
512    // Truncate buffers and extract addresses
513    for (i, msg) in msgvec.iter().enumerate().take(received_count) {
514        let bytes_received = msg.msg_len as usize;
515        packets[i].truncate(bytes_received);
516
517        // Convert sockaddr_storage to SocketAddr
518        addresses[i] = sockaddr_to_socketaddr(&addrs[i], msg.msg_hdr.msg_namelen)?;
519    }
520
521    Ok(received_count)
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback destination used for every queued packet.
    fn loopback() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 5000))
    }

    /// A batch accepts exactly `MAX_BATCH_SIZE` packets and rejects the next one.
    #[test]
    fn test_batch_capacity() {
        let mut batch = UdpSendBatch::new();
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
        assert!(!batch.is_full());

        // Fill the batch to the brim; every insertion must succeed.
        (0..MAX_BATCH_SIZE).for_each(|i| {
            assert!(batch.add(vec![i as u8; 100], loopback()));
        });

        assert_eq!(batch.len(), MAX_BATCH_SIZE);
        assert!(!batch.is_empty());
        assert!(batch.is_full());

        // Should not be able to add more
        let rejected = batch.add(vec![0u8; 100], loopback());
        assert!(!rejected);
    }

    /// Clearing a partially filled batch leaves it empty.
    #[test]
    fn test_batch_clear() {
        let mut batch = UdpSendBatch::new();

        (0..10).for_each(|i| {
            batch.add(vec![i as u8; 100], loopback());
        });

        assert_eq!(batch.len(), 10);
        batch.clear();
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
    }
}
566}