1use std::fmt;
2use std::marker::PhantomData;
3use std::mem::MaybeUninit;
4use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::os::fd::AsRawFd;
6
7use socket2::{Domain, Protocol, Socket, Type};
8use socket2::{MaybeUninitSlice, SockAddr};
9use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};
10use tokio::io::Interest;
11
12use crate::addr::ToIpAddr;
13
14#[repr(transparent)]
19pub(crate) struct MsgHdrMut<'addr, 'bufs, 'control> {
20 inner: libc::msghdr,
21 #[allow(clippy::type_complexity)]
22 _lifetimes: PhantomData<(
23 &'addr mut SockAddr,
24 &'bufs mut MaybeUninitSlice<'bufs>,
25 &'control mut [u8],
26 )>,
27}
28
29#[cfg(not(any(target_os = "redox", target_os = "wasi")))]
30impl<'addr, 'bufs, 'control> MsgHdrMut<'addr, 'bufs, 'control> {
31 #[allow(clippy::new_without_default)]
33 pub fn new() -> MsgHdrMut<'addr, 'bufs, 'control> {
34 MsgHdrMut {
36 inner: unsafe { std::mem::zeroed() },
37 _lifetimes: PhantomData,
38 }
39 }
40
41 #[allow(clippy::needless_pass_by_ref_mut)]
46 pub fn with_addr(mut self, addr: &'addr mut SockAddr) -> Self {
47 Self::set_msghdr_name(&mut self.inner, addr);
48 self
49 }
50
51 pub fn with_buffers(mut self, bufs: &'bufs mut [MaybeUninitSlice<'_>]) -> Self {
56 Self::set_msghdr_iov(&mut self.inner, bufs.as_mut_ptr().cast(), bufs.len());
57 self
58 }
59
60 pub fn with_control(mut self, buf: &'control mut [MaybeUninit<u8>]) -> Self {
65 Self::set_msghdr_control(&mut self.inner, buf.as_mut_ptr().cast(), buf.len());
66 self
67 }
68
69 pub fn flags(&self) -> libc::c_int {
72 self.inner.msg_flags
73 }
74
75 pub fn as_msghdr(&self) -> &libc::msghdr {
81 &self.inner
82 }
83
84 fn set_msghdr_name(msg: &mut libc::msghdr, name: &SockAddr) {
85 msg.msg_name = name.as_ptr() as *mut _;
86 msg.msg_namelen = name.len();
87 }
88
89 #[cfg(any(
90 target_os = "macos",
91 target_os = "ios",
92 target_os = "tvos",
93 target_os = "watchos",
94 target_os = "visionos",
95 target_os = "freebsd",
96 target_os = "dragonfly",
97 target_os = "openbsd",
98 target_os = "netbsd"
99 ))]
100 fn set_msghdr_iov(msg: &mut libc::msghdr, ptr: *mut libc::iovec, len: usize) {
101 msg.msg_iov = ptr;
102 msg.msg_iovlen = std::cmp::min(len, libc::c_int::MAX as usize) as libc::c_int;
103 }
104
105 #[cfg(any(
106 target_os = "linux",
107 target_os = "l4re",
108 target_os = "android",
109 target_os = "emscripten"
110 ))]
111 fn set_msghdr_iov(msg: &mut libc::msghdr, ptr: *mut libc::iovec, len: usize) {
112 msg.msg_iov = ptr;
113 msg.msg_iovlen = len;
114 }
115
116 fn set_msghdr_control(msg: &mut libc::msghdr, ptr: *mut libc::c_void, len: usize) {
117 msg.msg_control = ptr;
118 msg.msg_controllen = len as _;
119 }
120}
121
122unsafe impl Send for MsgHdrMut<'_, '_, '_> {}
123
/// Opaque `Debug` output: only the type name is formatted, never the raw
/// header fields (which are pointers into borrowed buffers).
impl fmt::Debug for MsgHdrMut<'_, '_, '_> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        "MsgHdrMut".fmt(fmt)
    }
}
129
130pub struct IcmpSocket {
140 io: AsyncFd<Socket>,
141}
142
143impl IcmpSocket {
144 pub async fn bind<A: ToIpAddr>(addr: A) -> std::io::Result<IcmpSocket> {
150 let ip_addr = addr.to_ip_addr().await?;
151 let (sock_addr, domain, protocol) = match ip_addr {
152 std::net::IpAddr::V4(ipv4_addr) => (
153 SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
154 Domain::IPV4,
155 Protocol::ICMPV4,
156 ),
157 std::net::IpAddr::V6(ipv6_addr) => (
158 SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0)),
159 Domain::IPV6,
160 Protocol::ICMPV6,
161 ),
162 };
163 let socket = Socket::new(domain, Type::RAW, Some(protocol))?;
164 socket.set_nonblocking(true)?;
165 if domain == Domain::IPV6 {
166 socket.set_recv_hoplimit_v6(true)?;
167 }
168 set_dont_fragment(&socket, domain, true)?;
170
171 socket.bind(&sock_addr.into())?;
172 let io = AsyncFd::new(socket)?;
173 Ok(Self { io })
174 }
175
176 pub async fn connect<A: ToIpAddr>(&self, addr: A) -> std::io::Result<()> {
179 let ip_addr = addr.to_ip_addr().await?;
180 let socket_addr = match ip_addr {
181 std::net::IpAddr::V4(ipv4_addr) => SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
182 std::net::IpAddr::V6(ipv6_addr) => {
183 SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0))
184 }
185 };
186 self.io.get_ref().connect(&socket_addr.into())
187 }
188
189 pub async fn ready(
191 &self,
192 interest: Interest,
193 ) -> std::io::Result<AsyncFdReadyGuard<'_, Socket>> {
194 self.io.ready(interest).await
195 }
196
197 pub async fn writable(&self) -> std::io::Result<()> {
199 let _ = self.ready(Interest::WRITABLE).await?;
200 Ok(())
201 }
202
203 pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
205 self.io.async_io(Interest::WRITABLE, |s| s.send(buf)).await
206 }
207
208 pub async fn readable(&self) -> std::io::Result<()> {
210 let _ = self.ready(Interest::READABLE).await?;
211 Ok(())
212 }
213
214 pub async fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<usize> {
216 self.io.async_io(Interest::READABLE, |s| s.recv(buf)).await
217 }
218
219 pub(crate) async fn recvmsg(&self, msg: &mut MsgHdrMut<'_, '_, '_>) -> std::io::Result<usize> {
220 self.io
221 .async_io(Interest::READABLE, |s| recvmsg(s, msg, 0))
222 .await
223 }
224}
225
226fn recvmsg(
227 socket: &Socket,
228 msg: &mut MsgHdrMut<'_, '_, '_>,
229 flags: libc::c_int,
230) -> std::io::Result<usize> {
231 let fd = socket.as_raw_fd();
232 let res = unsafe { libc::recvmsg(fd, &raw mut msg.inner, flags) };
233 if res == -1 {
234 Err(std::io::Error::last_os_error())
235 } else {
236 Ok(res as usize)
237 }
238}
239
240#[cfg(any(
241 target_os = "linux",
242 target_os = "l4re",
243 target_os = "android",
244 target_os = "emscripten"
245))]
246fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
247 match domain {
248 Domain::IPV4 => {
249 let payload = if dont_fragment {
250 libc::IP_PMTUDISC_DO
251 } else {
252 libc::IP_PMTUDISC_DONT
253 };
254
255 unsafe { setsockopt(socket, libc::IPPROTO_IP, libc::IP_MTU_DISCOVER, payload) }
256 }
257 Domain::IPV6 => {
258 let payload = if dont_fragment {
259 libc::IPV6_PMTUDISC_DO
260 } else {
261 libc::IPV6_PMTUDISC_DONT
262 };
263 unsafe { setsockopt(socket, libc::IPPROTO_IPV6, libc::IPV6_MTU_DISCOVER, payload) }
264 }
265 _ => Ok(()),
266 }
267}
268
/// Enables or disables "don't fragment" for outgoing packets (Apple platforms
/// and the BSDs).
///
/// Sets the boolean `IP_DONTFRAG` / `IPV6_DONTFRAG` option to `dont_fragment`.
/// Domains other than IPv4/IPv6 are left untouched and report success.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "watchos",
    target_os = "visionos",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "openbsd",
    target_os = "netbsd"
))]
fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
    let (level, option) = match domain {
        Domain::IPV4 => (libc::IPPROTO_IP, libc::IP_DONTFRAG),
        Domain::IPV6 => (libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG),
        _ => return Ok(()),
    };
    let enabled = libc::c_int::from(dont_fragment);
    // SAFETY: the DONTFRAG options take an `int` flag, which `enabled` is.
    unsafe { setsockopt(socket, level, option, enabled) }
}
301
302#[allow(clippy::needless_pass_by_value)]
306unsafe fn setsockopt<T>(
307 socket: &Socket,
308 opt: libc::c_int,
309 val: libc::c_int,
310 payload: T,
311) -> std::io::Result<()> {
312 let payload = (&raw const payload).cast();
313 let res = unsafe {
314 libc::setsockopt(
315 socket.as_raw_fd(),
316 opt,
317 val,
318 payload,
319 std::mem::size_of::<T>() as libc::socklen_t,
320 )
321 };
322 if res != 0 {
323 return Err(std::io::Error::last_os_error());
324 }
325 Ok(())
326}
327
#[cfg(test)]
mod tests {
    //! Checks that `IcmpSocket::bind` and `IcmpSocket::connect` accept every
    //! supported address representation.
    //!
    //! NOTE: these tests open raw sockets, which usually requires elevated
    //! privileges (e.g. root or `CAP_NET_RAW` on Linux). The IPv6 cases also
    //! need a usable `::1` loopback.

    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::IcmpSocket;

    /// Binds a socket to the IPv4 loopback address, panicking on failure.
    async fn loopback_v4() -> IcmpSocket {
        IcmpSocket::bind(Ipv4Addr::LOCALHOST).await.unwrap()
    }

    #[tokio::test]
    async fn bind_accepts_str_literal() {
        IcmpSocket::bind("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_owned_string() {
        let owned: String = "127.0.0.1".to_owned();
        IcmpSocket::bind(owned).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ipv4addr() {
        IcmpSocket::bind(Ipv4Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ipv6addr() {
        IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ip_addr() {
        let generic = IpAddr::from(Ipv4Addr::LOCALHOST);
        IcmpSocket::bind(generic).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_str_literal() {
        let socket = loopback_v4().await;
        socket.connect("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_owned_string() {
        let socket = loopback_v4().await;
        let owned: String = "127.0.0.1".to_owned();
        socket.connect(owned).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv4addr() {
        let socket = loopback_v4().await;
        socket.connect(Ipv4Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv6addr() {
        let socket = IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
        socket.connect(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ip_addr() {
        let socket = loopback_v4().await;
        let generic = IpAddr::from(Ipv4Addr::LOCALHOST);
        socket.connect(generic).await.unwrap();
    }
}