solana_streamer/
recvmmsg.rs

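//! The `recvmmsg` module provides batched UDP receive: a wrapper around the Linux `recvmmsg(2)` system call, plus a portable `recv_from`-based fallback on other targets.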
2
3pub use solana_perf::packet::PACKETS_PER_BATCH;
4#[cfg(target_os = "linux")]
5use {
6    crate::msghdr::create_msghdr,
7    itertools::izip,
8    libc::{iovec, mmsghdr, sockaddr_storage, socklen_t, AF_INET, AF_INET6, MSG_WAITFORONE},
9    std::{
10        mem::{self, MaybeUninit},
11        net::{SocketAddr, SocketAddrV4, SocketAddrV6},
12        os::unix::io::AsRawFd,
13    },
14};
15use {
16    crate::packet::{Meta, Packet},
17    std::{cmp, io, net::UdpSocket},
18};
19
/// Receive up to `min(PACKETS_PER_BATCH, packets.len())` datagrams from `socket`
/// into `packets` using repeated `recv_from` calls. This is the portable fallback
/// for targets without `recvmmsg(2)`.
///
/// The first read uses the socket's current blocking mode. After a datagram has
/// been received, the socket is switched to non-blocking mode so that the
/// remaining reads only pick up what is already queued. The socket is *left* in
/// non-blocking mode on return; callers that rely on blocking reads must restore
/// the mode themselves.
///
/// The packets must have all `meta()` fields cleared before calling this function.
///
/// # Errors
///
/// Returns the error from the first `recv_from` if that read fails, or the error
/// from switching the socket to non-blocking mode. A failure on any later read
/// simply ends the batch.
#[cfg(not(target_os = "linux"))]
pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num packets:*/ usize> {
    debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
    let limit = cmp::min(PACKETS_PER_BATCH, packets.len());
    let mut received = 0;
    for packet in &mut packets[..limit] {
        packet.meta_mut().size = 0;
        let (len, sender) = match socket.recv_from(packet.buffer_mut()) {
            Ok(datagram) => datagram,
            // Nothing received yet: surface the error to the caller.
            Err(err) if received == 0 => return Err(err),
            // Something already received: end the batch (e.g. queue drained).
            Err(_) => break,
        };
        let meta = packet.meta_mut();
        meta.size = len;
        meta.set_socket_addr(&sender);
        if received == 0 {
            // Drain the rest of the batch without blocking.
            socket.set_nonblocking(true)?;
        }
        received += 1;
    }
    Ok(received)
}
46
47#[cfg(target_os = "linux")]
48fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option<SocketAddr> {
49    use libc::{sa_family_t, sockaddr_in, sockaddr_in6};
50    const SOCKADDR_IN_SIZE: usize = std::mem::size_of::<sockaddr_in>();
51    const SOCKADDR_IN6_SIZE: usize = std::mem::size_of::<sockaddr_in6>();
52    if addr.ss_family == AF_INET as sa_family_t
53        && hdr.msg_hdr.msg_namelen == SOCKADDR_IN_SIZE as socklen_t
54    {
55        // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L167-L172
56        let addr = unsafe { &*(addr as *const _ as *const sockaddr_in) };
57        return Some(SocketAddr::V4(SocketAddrV4::new(
58            std::net::Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
59            u16::from_be(addr.sin_port),
60        )));
61    }
62    if addr.ss_family == AF_INET6 as sa_family_t
63        && hdr.msg_hdr.msg_namelen == SOCKADDR_IN6_SIZE as socklen_t
64    {
65        // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L174-L189
66        let addr = unsafe { &*(addr as *const _ as *const sockaddr_in6) };
67        return Some(SocketAddr::V6(SocketAddrV6::new(
68            std::net::Ipv6Addr::from(addr.sin6_addr.s6_addr),
69            u16::from_be(addr.sin6_port),
70            addr.sin6_flowinfo,
71            addr.sin6_scope_id,
72        )));
73    }
74    error!(
75        "recvmmsg unexpected ss_family:{} msg_namelen:{}",
76        addr.ss_family, hdr.msg_hdr.msg_namelen
77    );
78    None
79}
80
81/** Receive multiple messages from `sock` into buffer provided in `packets`.
82This is a wrapper around recvmmsg(7) call.
83
84The buffer provided in packets should have all `meta()` fields cleared before calling
85this function
86
87
88 This function is *supposed to* timeout in 1 second and *may* block forever
89 due to a bug in the linux kernel.
90 You may want to call `sock.set_read_timeout(Some(Duration::from_secs(1)));` or similar
91 prior to calling this function if you require this to actually time out after 1 second.
92*/
93#[cfg(target_os = "linux")]
94pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num packets:*/ usize> {
95    // Should never hit this, but bail if the caller didn't provide any Packets
96    // to receive into
97    if packets.is_empty() {
98        return Ok(0);
99    }
100    // Assert that there are no leftovers in packets.
101    debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
102    const SOCKADDR_STORAGE_SIZE: socklen_t = mem::size_of::<sockaddr_storage>() as socklen_t;
103
104    let mut iovs = [MaybeUninit::uninit(); PACKETS_PER_BATCH];
105    let mut addrs = [MaybeUninit::zeroed(); PACKETS_PER_BATCH];
106    let mut hdrs = [MaybeUninit::uninit(); PACKETS_PER_BATCH];
107
108    let sock_fd = sock.as_raw_fd();
109    let count = cmp::min(iovs.len(), packets.len());
110
111    for (packet, hdr, iov, addr) in
112        izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count)
113    {
114        let buffer = packet.buffer_mut();
115        iov.write(iovec {
116            iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
117            iov_len: buffer.len(),
118        });
119
120        let msg_hdr = create_msghdr(addr, SOCKADDR_STORAGE_SIZE, iov);
121
122        hdr.write(mmsghdr {
123            msg_len: 0,
124            msg_hdr,
125        });
126    }
127
128    let mut ts = libc::timespec {
129        tv_sec: 1,
130        tv_nsec: 0,
131    };
132    // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl
133    #[allow(clippy::useless_conversion)]
134    let nrecv = unsafe {
135        libc::recvmmsg(
136            sock_fd,
137            hdrs[0].assume_init_mut(),
138            count as u32,
139            MSG_WAITFORONE.try_into().unwrap(),
140            &mut ts,
141        )
142    };
143    let nrecv = if nrecv < 0 {
144        return Err(io::Error::last_os_error());
145    } else {
146        usize::try_from(nrecv).unwrap()
147    };
148    for (addr, hdr, pkt) in izip!(addrs, hdrs, packets.iter_mut()).take(nrecv) {
149        // SAFETY: We initialized `count` elements of `hdrs` above. `count` is
150        // passed to recvmmsg() as the limit of messages that can be read. So,
151        // `nrevc <= count` which means we initialized this `hdr` and
152        // recvmmsg() will have updated it appropriately
153        let hdr_ref = unsafe { hdr.assume_init_ref() };
154        // SAFETY: Similar to above, we initialized this `addr` and recvmmsg()
155        // will have populated it
156        let addr_ref = unsafe { addr.assume_init_ref() };
157        pkt.meta_mut().size = hdr_ref.msg_len as usize;
158        if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) {
159            pkt.meta_mut().set_socket_addr(&addr);
160        }
161    }
162
163    for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) {
164        // SAFETY: We initialized `count` elements of each array above
165        //
166        // It may be that `packets.len() != PACKETS_PER_BATCH`; thus, some elements
167        // in `iovs` / `addrs` / `hdrs` may not get initialized. So, we must
168        // manually drop `count` elements from each array instead of being able
169        // to convert [MaybeUninit<T>] to [T] and letting `Drop` do the work
170        // for us when these items go out of scope at the end of the function
171        unsafe {
172            iov.assume_init_drop();
173            addr.assume_init_drop();
174            hdr.assume_init_drop();
175        }
176    }
177
178    Ok(nrecv)
179}
180
#[cfg(test)]
mod tests {
    use {
        crate::{packet::PACKET_DATA_SIZE, recvmmsg::*},
        solana_net_utils::sockets::{
            bind_in_range_with_config, localhost_port_range_for_tests, unique_port_range_for_tests,
            SocketConfiguration as SocketConfig,
        },
        std::{
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
            time::{Duration, Instant},
        },
    };

    /// `(reader, reader_addr, sender, sender_addr)` for a pair of bound sockets.
    type TestConfig = (UdpSocket, SocketAddr, UdpSocket, SocketAddr);

    const TEST_NUM_MSGS: usize = 32;

    /// Binds a reader socket and a sender socket on `ip`, both taken from a
    /// freshly reserved two-port range.
    fn test_setup_reader_sender(ip: IpAddr) -> io::Result<TestConfig> {
        let ports = unique_port_range_for_tests(2);
        let bind = || -> io::Result<UdpSocket> {
            Ok(
                bind_in_range_with_config(ip, (ports.start, ports.end), SocketConfig::default())?
                    .1,
            )
        };
        let reader = bind()?;
        let reader_addr = reader.local_addr()?;
        let sender = bind()?;
        let sender_addr = sender.local_addr()?;
        Ok((reader, reader_addr, sender, sender_addr))
    }

    /// Runs `check` against an IPv4 socket pair (which must succeed), then
    /// against an IPv6 pair when IPv6 is available; otherwise it only warns.
    fn run_v4_and_v6(check: impl Fn(TestConfig)) {
        check(test_setup_reader_sender(IpAddr::V4(Ipv4Addr::LOCALHOST)).unwrap());

        match test_setup_reader_sender(IpAddr::V6(Ipv6Addr::LOCALHOST)) {
            Ok(config) => check(config),
            Err(e) => warn!("Failed to configure IPv6: {e:?}"),
        }
    }

    /// Sends `count` zero-filled, full-size datagrams from `sender` to `dest`.
    fn send_zeroed(sender: &UdpSocket, dest: SocketAddr, count: usize) {
        let payload = [0u8; PACKET_DATA_SIZE];
        for _ in 0..count {
            sender.send_to(&payload, dest).unwrap();
        }
    }

    /// Asserts that every packet holds a full-size datagram sent from `from`.
    fn assert_all_from(packets: &[Packet], from: SocketAddr) {
        for packet in packets {
            assert_eq!(packet.meta().size, PACKET_DATA_SIZE);
            assert_eq!(packet.meta().socket_addr(), from);
        }
    }

    /// Clears every packet's metadata so the batch can be passed to `recv_mmsg` again.
    fn reset_meta(packets: &mut [Packet]) {
        for packet in packets.iter_mut() {
            *packet.meta_mut() = Meta::default();
        }
    }

    #[test]
    pub fn test_recv_mmsg_one_iter() {
        run_v4_and_v6(|(reader, addr, sender, saddr)| {
            let sent = TEST_NUM_MSGS - 1;
            send_zeroed(&sender, addr, sent);

            let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
            let recv = recv_mmsg(&reader, &mut packets).unwrap();
            assert_eq!(sent, recv);
            assert_all_from(&packets[..recv], saddr);
        });
    }

    #[test]
    pub fn test_recv_mmsg_multi_iter() {
        run_v4_and_v6(|(reader, addr, sender, saddr)| {
            let sent = TEST_NUM_MSGS + 10;
            send_zeroed(&sender, addr, sent);

            // The first call fills the whole batch...
            let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
            let recv = recv_mmsg(&reader, &mut packets).unwrap();
            assert_eq!(TEST_NUM_MSGS, recv);
            assert_all_from(&packets[..recv], saddr);

            // ...and the second call picks up the remainder.
            reset_meta(&mut packets);
            let recv = recv_mmsg(&reader, &mut packets).unwrap();
            assert_eq!(sent - TEST_NUM_MSGS, recv);
            assert_all_from(&packets[..recv], saddr);
        });
    }

    #[test]
    pub fn test_recv_mmsg_multi_iter_timeout() {
        let (reader, reader_addr, sender, sender_addr) =
            test_setup_reader_sender(IpAddr::V4(Ipv4Addr::LOCALHOST)).unwrap();
        reader.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
        reader.set_nonblocking(false).unwrap();
        send_zeroed(&sender, reader_addr, TEST_NUM_MSGS);

        let start = Instant::now();
        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
        let recv = recv_mmsg(&reader, &mut packets).unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_all_from(&packets[..recv], sender_addr);
        reader.set_nonblocking(true).unwrap();

        // The socket is now drained and non-blocking, so this call must return
        // promptly; its result is irrelevant.
        reset_meta(&mut packets);
        let _recv = recv_mmsg(&reader, &mut packets);
        assert!(start.elapsed().as_secs() < 5);
    }

    #[test]
    pub fn test_recv_mmsg_multi_addrs() {
        let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let port_range = localhost_port_range_for_tests();
        let bind = || {
            bind_in_range_with_config(ip, port_range, SocketConfig::default())
                .unwrap()
                .1
        };

        let reader = bind();
        let reader_addr = reader.local_addr().unwrap();
        let sender1 = bind();
        let sender1_addr = sender1.local_addr().unwrap();
        let sent1 = TEST_NUM_MSGS - 1;
        let sender2 = bind();
        let sender2_addr = sender2.local_addr().unwrap();
        let sent2 = TEST_NUM_MSGS + 1;

        send_zeroed(&sender1, reader_addr, sent1);
        send_zeroed(&sender2, reader_addr, sent2);

        // The first batch holds all of sender1's datagrams, followed by the
        // start of sender2's.
        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
        let recv = recv_mmsg(&reader, &mut packets).unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_all_from(&packets[..sent1], sender1_addr);
        assert_all_from(&packets[sent1..recv], sender2_addr);

        // The second batch holds only sender2's remaining datagrams.
        reset_meta(&mut packets);
        let recv = recv_mmsg(&reader, &mut packets).unwrap();
        assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv);
        assert_all_from(&packets[..recv], sender2_addr);
    }
}