Skip to main content

ping_tokio/
net.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::mem::MaybeUninit;
4use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::os::fd::AsRawFd;
6
7use socket2::{Domain, Protocol, Socket, Type};
8use socket2::{MaybeUninitSlice, SockAddr};
9use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};
10use tokio::io::Interest;
11
12use crate::addr::ToIpAddr;
13
14/// Configuration of a `recvmsg(2)` system call.
15///
16/// This wraps `msghdr` on Unix and `WSAMSG` on Windows. Also see [`MsgHdr`] for
17/// the variant used by `sendmsg(2)`.
18#[repr(transparent)]
19pub(crate) struct MsgHdrMut<'addr, 'bufs, 'control> {
20    inner: libc::msghdr,
21    #[allow(clippy::type_complexity)]
22    _lifetimes: PhantomData<(
23        &'addr mut SockAddr,
24        &'bufs mut MaybeUninitSlice<'bufs>,
25        &'control mut [u8],
26    )>,
27}
28
29#[cfg(not(any(target_os = "redox", target_os = "wasi")))]
30impl<'addr, 'bufs, 'control> MsgHdrMut<'addr, 'bufs, 'control> {
31    /// Create a new `MsgHdrMut` with all empty/zero fields.
32    #[allow(clippy::new_without_default)]
33    pub fn new() -> MsgHdrMut<'addr, 'bufs, 'control> {
34        // SAFETY: all zero is valid for `msghdr` and `WSAMSG`.
35        MsgHdrMut {
36            inner: unsafe { std::mem::zeroed() },
37            _lifetimes: PhantomData,
38        }
39    }
40
41    /// Set the mutable address (name) of the message.
42    ///
43    /// Corresponds to setting `msg_name` and `msg_namelen` on Unix and `name`
44    /// and `namelen` on Windows.
45    #[allow(clippy::needless_pass_by_ref_mut)]
46    pub fn with_addr(mut self, addr: &'addr mut SockAddr) -> Self {
47        Self::set_msghdr_name(&mut self.inner, addr);
48        self
49    }
50
51    /// Set the mutable buffer(s) of the message.
52    ///
53    /// Corresponds to setting `msg_iov` and `msg_iovlen` on Unix and `lpBuffers`
54    /// and `dwBufferCount` on Windows.
55    pub fn with_buffers(mut self, bufs: &'bufs mut [MaybeUninitSlice<'_>]) -> Self {
56        Self::set_msghdr_iov(&mut self.inner, bufs.as_mut_ptr().cast(), bufs.len());
57        self
58    }
59
60    /// Set the mutable control buffer of the message.
61    ///
62    /// Corresponds to setting `msg_control` and `msg_controllen` on Unix and
63    /// `Control` on Windows.
64    pub fn with_control(mut self, buf: &'control mut [MaybeUninit<u8>]) -> Self {
65        Self::set_msghdr_control(&mut self.inner, buf.as_mut_ptr().cast(), buf.len());
66        self
67    }
68
69    /// Gets the message flags written by `recvmsg(2)` (e.g. `MSG_CTRUNC`,
70    /// `MSG_TRUNC`).
71    pub fn flags(&self) -> libc::c_int {
72        self.inner.msg_flags
73    }
74
75    /// Returns a reference to the underlying `msghdr`.
76    ///
77    /// Provided so callers can use kernel macros like `CMSG_FIRSTHDR` /
78    /// `CMSG_NXTHDR`, which take a `*const msghdr` and use `msg_control` /
79    /// `msg_controllen` from it.
80    pub fn as_msghdr(&self) -> &libc::msghdr {
81        &self.inner
82    }
83
84    fn set_msghdr_name(msg: &mut libc::msghdr, name: &SockAddr) {
85        msg.msg_name = name.as_ptr() as *mut _;
86        msg.msg_namelen = name.len();
87    }
88
89    #[cfg(any(
90        target_os = "macos",
91        target_os = "ios",
92        target_os = "tvos",
93        target_os = "watchos",
94        target_os = "visionos",
95        target_os = "freebsd",
96        target_os = "dragonfly",
97        target_os = "openbsd",
98        target_os = "netbsd"
99    ))]
100    fn set_msghdr_iov(msg: &mut libc::msghdr, ptr: *mut libc::iovec, len: usize) {
101        msg.msg_iov = ptr;
102        msg.msg_iovlen = std::cmp::min(len, libc::c_int::MAX as usize) as libc::c_int;
103    }
104
105    #[cfg(any(
106        target_os = "linux",
107        target_os = "l4re",
108        target_os = "android",
109        target_os = "emscripten"
110    ))]
111    fn set_msghdr_iov(msg: &mut libc::msghdr, ptr: *mut libc::iovec, len: usize) {
112        msg.msg_iov = ptr;
113        msg.msg_iovlen = len;
114    }
115
116    fn set_msghdr_control(msg: &mut libc::msghdr, ptr: *mut libc::c_void, len: usize) {
117        msg.msg_control = ptr;
118        msg.msg_controllen = len as _;
119    }
120}
121
122unsafe impl Send for MsgHdrMut<'_, '_, '_> {}
123
124impl fmt::Debug for MsgHdrMut<'_, '_, '_> {
125    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
126        "MsgHdrMut".fmt(fmt)
127    }
128}
129
130/// Asynchronous, non-blocking ICMP raw socket.
131///
132/// Wraps a [`socket2::Socket`] in [`tokio::io::unix::AsyncFd`] so that send
133/// and receive operations integrate with the Tokio runtime. Supports both
134/// ICMPv4 and ICMPv6; the protocol is selected by the address family of the
135/// bind address.
136///
137/// Creating an `IcmpSocket` requires permission to open raw sockets
138/// (e.g. `CAP_NET_RAW` on Linux, or running as root).
139pub struct IcmpSocket {
140    io: AsyncFd<Socket>,
141}
142
143impl IcmpSocket {
144    /// Create a new ICMP raw socket bound to `addr`.
145    ///
146    /// The address family of `addr` (after resolution) determines whether an
147    /// ICMPv4 or ICMPv6 socket is created. The socket is placed in
148    /// non-blocking mode and registered with the current Tokio runtime.
149    pub async fn bind<A: ToIpAddr>(addr: A) -> std::io::Result<IcmpSocket> {
150        let ip_addr = addr.to_ip_addr().await?;
151        let (sock_addr, domain, protocol) = match ip_addr {
152            std::net::IpAddr::V4(ipv4_addr) => (
153                SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
154                Domain::IPV4,
155                Protocol::ICMPV4,
156            ),
157            std::net::IpAddr::V6(ipv6_addr) => (
158                SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0)),
159                Domain::IPV6,
160                Protocol::ICMPV6,
161            ),
162        };
163        let socket = Socket::new(domain, Type::RAW, Some(protocol))?;
164        socket.set_nonblocking(true)?;
165        if domain == Domain::IPV6 {
166            socket.set_recv_hoplimit_v6(true)?;
167        }
168        // options not exposed by socket2
169        set_dont_fragment(&socket, domain, true)?;
170
171        socket.bind(&sock_addr.into())?;
172        let io = AsyncFd::new(socket)?;
173        Ok(Self { io })
174    }
175
176    /// Connect this socket to `addr` so that subsequent `send`/`recv` calls
177    /// communicate with that peer only.
178    pub async fn connect<A: ToIpAddr>(&self, addr: A) -> std::io::Result<()> {
179        let ip_addr = addr.to_ip_addr().await?;
180        let socket_addr = match ip_addr {
181            std::net::IpAddr::V4(ipv4_addr) => SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
182            std::net::IpAddr::V6(ipv6_addr) => {
183                SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0))
184            }
185        };
186        self.io.get_ref().connect(&socket_addr.into())
187    }
188
189    /// Wait for the socket to become ready for the given [`Interest`].
190    pub async fn ready(
191        &self,
192        interest: Interest,
193    ) -> std::io::Result<AsyncFdReadyGuard<'_, Socket>> {
194        self.io.ready(interest).await
195    }
196
197    /// Wait for the socket to become writable.
198    pub async fn writable(&self) -> std::io::Result<()> {
199        let _ = self.ready(Interest::WRITABLE).await?;
200        Ok(())
201    }
202
203    /// Send `buf` on the socket. Requires that the socket has been connected.
204    pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
205        self.io.async_io(Interest::WRITABLE, |s| s.send(buf)).await
206    }
207
208    /// Wait for the socket to become readable.
209    pub async fn readable(&self) -> std::io::Result<()> {
210        let _ = self.ready(Interest::READABLE).await?;
211        Ok(())
212    }
213
214    /// Receive a datagram into `buf`, returning the number of bytes received.
215    pub async fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<usize> {
216        self.io.async_io(Interest::READABLE, |s| s.recv(buf)).await
217    }
218
219    pub(crate) async fn recvmsg(&self, msg: &mut MsgHdrMut<'_, '_, '_>) -> std::io::Result<usize> {
220        self.io
221            .async_io(Interest::READABLE, |s| recvmsg(s, msg, 0))
222            .await
223    }
224}
225
226fn recvmsg(
227    socket: &Socket,
228    msg: &mut MsgHdrMut<'_, '_, '_>,
229    flags: libc::c_int,
230) -> std::io::Result<usize> {
231    let fd = socket.as_raw_fd();
232    let res = unsafe { libc::recvmsg(fd, &raw mut msg.inner, flags) };
233    if res == -1 {
234        Err(std::io::Error::last_os_error())
235    } else {
236        Ok(res as usize)
237    }
238}
239
240#[cfg(any(
241    target_os = "linux",
242    target_os = "l4re",
243    target_os = "android",
244    target_os = "emscripten"
245))]
246fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
247    match domain {
248        Domain::IPV4 => {
249            let payload = if dont_fragment {
250                libc::IP_PMTUDISC_DO
251            } else {
252                libc::IP_PMTUDISC_DONT
253            };
254
255            unsafe { setsockopt(socket, libc::IPPROTO_IP, libc::IP_MTU_DISCOVER, payload) }
256        }
257        Domain::IPV6 => {
258            let payload = if dont_fragment {
259                libc::IPV6_PMTUDISC_DO
260            } else {
261                libc::IPV6_PMTUDISC_DONT
262            };
263            unsafe { setsockopt(socket, libc::IPPROTO_IPV6, libc::IPV6_MTU_DISCOVER, payload) }
264        }
265        _ => Ok(()),
266    }
267}
268
/// Turn don't-fragment on or off through `IP_DONTFRAG` / `IPV6_DONTFRAG`.
///
/// The flag is passed as a `c_int` (1 = set, 0 = clear). Domains other than
/// IPv4/IPv6 are left untouched and report success.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "watchos",
    target_os = "visionos",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "openbsd",
    target_os = "netbsd"
))]
fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
    let (level, option) = match domain {
        Domain::IPV4 => (libc::IPPROTO_IP, libc::IP_DONTFRAG),
        Domain::IPV6 => (libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG),
        _ => return Ok(()),
    };
    let flag = libc::c_int::from(dont_fragment);
    // SAFETY: both `*_DONTFRAG` options take a `c_int` boolean, which is
    // exactly the type of `flag`.
    unsafe { setsockopt(socket, level, option, flag) }
}
301
302// `payload` is taken by value so we can take its address with `&raw const`
303// for `setsockopt`; the caller's value would otherwise need to outlive the
304// call. The borrow lint doesn't model this.
305#[allow(clippy::needless_pass_by_value)]
306unsafe fn setsockopt<T>(
307    socket: &Socket,
308    opt: libc::c_int,
309    val: libc::c_int,
310    payload: T,
311) -> std::io::Result<()> {
312    let payload = (&raw const payload).cast();
313    let res = unsafe {
314        libc::setsockopt(
315            socket.as_raw_fd(),
316            opt,
317            val,
318            payload,
319            std::mem::size_of::<T>() as libc::socklen_t,
320        )
321    };
322    if res != 0 {
323        return Err(std::io::Error::last_os_error());
324    }
325    Ok(())
326}
327
#[cfg(test)]
mod tests {
    //! Checks that `IcmpSocket::bind` and `IcmpSocket::connect` accept every
    //! supported address argument type.
    //!
    //! Each test opens a real raw ICMP socket, which needs raw-socket
    //! permission (e.g. `CAP_NET_RAW` on Linux, or root). Without it, the
    //! `unwrap()` on `bind` panics.

    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::IcmpSocket;

    /// Opens an ICMPv4 socket bound to the IPv4 loopback address.
    async fn bound_v4() -> IcmpSocket {
        IcmpSocket::bind(Ipv4Addr::LOCALHOST).await.unwrap()
    }

    #[tokio::test]
    async fn bind_accepts_str_literal() {
        IcmpSocket::bind("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_owned_string() {
        let owned = String::from("127.0.0.1");
        IcmpSocket::bind(owned).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ipv4addr() {
        bound_v4().await;
    }

    #[tokio::test]
    async fn bind_accepts_ipv6addr() {
        IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ip_addr() {
        let addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        IcmpSocket::bind(addr).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_str_literal() {
        let socket = bound_v4().await;
        socket.connect("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_owned_string() {
        let socket = bound_v4().await;
        socket.connect(String::from("127.0.0.1")).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv4addr() {
        let socket = bound_v4().await;
        socket.connect(Ipv4Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv6addr() {
        let socket = IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
        socket.connect(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ip_addr() {
        let socket = bound_v4().await;
        socket.connect(IpAddr::V4(Ipv4Addr::LOCALHOST)).await.unwrap();
    }
}