use std::fmt;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::fd::AsRawFd;
use socket2::{Domain, Protocol, Socket, Type};
use socket2::{MaybeUninitSlice, SockAddr};
use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};
use tokio::io::Interest;
use crate::addr::ToIpAddr;
/// Mutable wrapper around a raw `libc::msghdr`, built up with the `with_*`
/// methods and then passed to `recvmsg`.
///
/// The header itself only stores raw pointers; the three lifetime parameters
/// record the borrows of the address (`'addr`), the I/O buffers (`'bufs`) and
/// the control-message buffer (`'control`) those pointers were derived from.
#[repr(transparent)]
pub(crate) struct MsgHdrMut<'addr, 'bufs, 'control> {
// The C header whose pointer/length fields are filled in by the builders.
inner: libc::msghdr,
// Zero-sized marker: makes the struct behave as if it held these mutable
// borrows, so the borrow checker keeps the pointed-to data alive.
#[allow(clippy::type_complexity)]
_lifetimes: PhantomData<(
&'addr mut SockAddr,
&'bufs mut MaybeUninitSlice<'bufs>,
&'control mut [u8],
)>,
}
#[cfg(not(any(target_os = "redox", target_os = "wasi")))]
impl<'addr, 'bufs, 'control> MsgHdrMut<'addr, 'bufs, 'control> {
    /// Creates an empty header: every field of the underlying `msghdr` is
    /// zero (null pointers, zero lengths, no flags).
    // Silences the clippy suggestion to also provide a `Default` impl.
    #[allow(clippy::new_without_default)]
    pub fn new() -> MsgHdrMut<'addr, 'bufs, 'control> {
        MsgHdrMut {
            // SAFETY: `libc::msghdr` is a plain C struct made of integers and
            // raw pointers, for which the all-zero bit pattern is a valid value.
            inner: unsafe { std::mem::zeroed() },
            _lifetimes: PhantomData,
        }
    }

    /// Stores a pointer to `addr` (and its current length) as the message
    /// name, i.e. the `msg_name` / `msg_namelen` fields.
    ///
    /// NOTE(review): `msg_namelen` is taken from `addr.len()`; when receiving,
    /// that value acts as the size of the address buffer, so `addr` should
    /// already have a length large enough for the expected family — confirm
    /// against callers.
    // The exclusive borrow is intentional even though only a shared reborrow
    // is used here: the stored pointer is later handed to `recvmsg`, which may
    // write the sender address through it.
    #[allow(clippy::needless_pass_by_ref_mut)]
    pub fn with_addr(mut self, addr: &'addr mut SockAddr) -> Self {
        Self::set_msghdr_name(&mut self.inner, addr);
        self
    }

    /// Points the header's scatter/gather array (`msg_iov` / `msg_iovlen`) at
    /// `bufs`.
    ///
    /// The slice pointer is reinterpreted as `*mut libc::iovec`; this relies on
    /// `MaybeUninitSlice` being ABI-compatible with `iovec` on Unix, as
    /// documented by socket2.
    pub fn with_buffers(mut self, bufs: &'bufs mut [MaybeUninitSlice<'_>]) -> Self {
        Self::set_msghdr_iov(&mut self.inner, bufs.as_mut_ptr().cast(), bufs.len());
        self
    }

    /// Points the header's ancillary-data buffer (`msg_control` /
    /// `msg_controllen`) at `buf`.
    pub fn with_control(mut self, buf: &'control mut [MaybeUninit<u8>]) -> Self {
        Self::set_msghdr_control(&mut self.inner, buf.as_mut_ptr().cast(), buf.len());
        self
    }

    /// Returns the header's `msg_flags` field.
    pub fn flags(&self) -> libc::c_int {
        self.inner.msg_flags
    }

    /// Borrows the underlying raw `msghdr`.
    pub fn as_msghdr(&self) -> &libc::msghdr {
        &self.inner
    }

    /// Writes the address pointer and length into `msg`.
    fn set_msghdr_name(msg: &mut libc::msghdr, name: &SockAddr) {
        // `msg_name` is a `*mut` field, hence the cast away from `*const`.
        msg.msg_name = name.as_ptr() as *mut _;
        msg.msg_namelen = name.len();
    }

    /// Writes the iovec pointer and element count into `msg`.
    ///
    /// `msg_iovlen` is `usize` on some targets (e.g. Linux with glibc,
    /// Android) and `c_int` on others (e.g. the BSDs, Apple platforms and
    /// musl-based Linux/emscripten). The count is therefore clamped to
    /// `c_int::MAX`, which fits both representations, and then converted with
    /// an inferred cast so a single definition compiles everywhere.
    // The cast is a no-op on targets where the field is already `usize`.
    #[allow(clippy::unnecessary_cast)]
    fn set_msghdr_iov(msg: &mut libc::msghdr, ptr: *mut libc::iovec, len: usize) {
        msg.msg_iov = ptr;
        msg.msg_iovlen = std::cmp::min(len, libc::c_int::MAX as usize) as _;
    }

    /// Writes the control-buffer pointer and length into `msg`.
    fn set_msghdr_control(msg: &mut libc::msghdr, ptr: *mut libc::c_void, len: usize) {
        msg.msg_control = ptr;
        // The width of `msg_controllen` is platform-dependent, so let the
        // field type drive the cast.
        msg.msg_controllen = len as _;
    }
}
// `libc::msghdr` contains raw pointers, so `Send` is not derived automatically.
// SAFETY: NOTE(review) — this asserts the header may be moved across threads;
// the pointers it stores come from the exclusive borrows recorded in the
// `'addr`, `'bufs` and `'control` lifetimes. Confirm that every pointee type
// is itself safe to access from another thread.
unsafe impl Send for MsgHdrMut<'_, '_, '_> {}
/// Opaque `Debug` representation: the raw header fields are not printed, only
/// the type name (formatted as a debug string literal).
impl fmt::Debug for MsgHdrMut<'_, '_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt("MsgHdrMut", f)
    }
}
/// Asynchronous ICMP socket: a `socket2::Socket` registered with the tokio
/// reactor through `AsyncFd`.
pub struct IcmpSocket {
// The underlying socket together with its readiness registration.
io: AsyncFd<Socket>,
}
impl IcmpSocket {
pub async fn bind<A: ToIpAddr>(addr: A) -> std::io::Result<IcmpSocket> {
let ip_addr = addr.to_ip_addr().await?;
let (sock_addr, domain, protocol) = match ip_addr {
std::net::IpAddr::V4(ipv4_addr) => (
SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
Domain::IPV4,
Protocol::ICMPV4,
),
std::net::IpAddr::V6(ipv6_addr) => (
SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0)),
Domain::IPV6,
Protocol::ICMPV6,
),
};
let socket = Socket::new(domain, Type::RAW, Some(protocol))?;
socket.set_nonblocking(true)?;
if domain == Domain::IPV6 {
socket.set_recv_hoplimit_v6(true)?;
}
set_dont_fragment(&socket, domain, true)?;
socket.bind(&sock_addr.into())?;
let io = AsyncFd::new(socket)?;
Ok(Self { io })
}
pub async fn connect<A: ToIpAddr>(&self, addr: A) -> std::io::Result<()> {
let ip_addr = addr.to_ip_addr().await?;
let socket_addr = match ip_addr {
std::net::IpAddr::V4(ipv4_addr) => SocketAddr::V4(SocketAddrV4::new(ipv4_addr, 0u16)),
std::net::IpAddr::V6(ipv6_addr) => {
SocketAddr::V6(SocketAddrV6::new(ipv6_addr, 0u16, 0, 0))
}
};
self.io.get_ref().connect(&socket_addr.into())
}
pub async fn ready(
&self,
interest: Interest,
) -> std::io::Result<AsyncFdReadyGuard<'_, Socket>> {
self.io.ready(interest).await
}
pub async fn writable(&self) -> std::io::Result<()> {
let _ = self.ready(Interest::WRITABLE).await?;
Ok(())
}
pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.io.async_io(Interest::WRITABLE, |s| s.send(buf)).await
}
pub async fn readable(&self) -> std::io::Result<()> {
let _ = self.ready(Interest::READABLE).await?;
Ok(())
}
pub async fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<usize> {
self.io.async_io(Interest::READABLE, |s| s.recv(buf)).await
}
pub(crate) async fn recvmsg(&self, msg: &mut MsgHdrMut<'_, '_, '_>) -> std::io::Result<usize> {
self.io
.async_io(Interest::READABLE, |s| recvmsg(s, msg, 0))
.await
}
}
/// Thin wrapper over `libc::recvmsg`.
///
/// Returns the number of bytes received, or the last OS error when the call
/// reports `-1`.
fn recvmsg(
    socket: &Socket,
    msg: &mut MsgHdrMut<'_, '_, '_>,
    flags: libc::c_int,
) -> std::io::Result<usize> {
    // SAFETY: the descriptor is borrowed from a live `Socket`, and the header
    // is exclusively borrowed for the duration of the call; the buffers it
    // points to are kept alive by the lifetimes carried on `MsgHdrMut`.
    let received = unsafe { libc::recvmsg(socket.as_raw_fd(), &raw mut msg.inner, flags) };
    match received {
        -1 => Err(std::io::Error::last_os_error()),
        len => Ok(len as usize),
    }
}
/// Enables or disables the don't-fragment behaviour via the path-MTU-discovery
/// socket option (`IP_MTU_DISCOVER` / `IPV6_MTU_DISCOVER`).
///
/// `dont_fragment == true` selects the `*_PMTUDISC_DO` mode, `false` selects
/// `*_PMTUDISC_DONT`. Domains other than IPv4 and IPv6 are left untouched and
/// reported as success.
#[cfg(any(
    target_os = "linux",
    target_os = "l4re",
    target_os = "android",
    target_os = "emscripten"
))]
fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
    let (level, option, mode) = match domain {
        Domain::IPV4 if dont_fragment => (
            libc::IPPROTO_IP,
            libc::IP_MTU_DISCOVER,
            libc::IP_PMTUDISC_DO,
        ),
        Domain::IPV4 => (
            libc::IPPROTO_IP,
            libc::IP_MTU_DISCOVER,
            libc::IP_PMTUDISC_DONT,
        ),
        Domain::IPV6 if dont_fragment => (
            libc::IPPROTO_IPV6,
            libc::IPV6_MTU_DISCOVER,
            libc::IPV6_PMTUDISC_DO,
        ),
        Domain::IPV6 => (
            libc::IPPROTO_IPV6,
            libc::IPV6_MTU_DISCOVER,
            libc::IPV6_PMTUDISC_DONT,
        ),
        _ => return Ok(()),
    };
    // SAFETY: the option value is a plain `c_int`, which is what these
    // options take.
    unsafe { setsockopt(socket, level, option, mode) }
}
/// Enables or disables the don't-fragment flag via `IP_DONTFRAG` /
/// `IPV6_DONTFRAG`.
///
/// The boolean is passed to the kernel as a `c_int` (1 or 0). Domains other
/// than IPv4 and IPv6 are left untouched and reported as success.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "watchos",
    target_os = "visionos",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "openbsd",
    target_os = "netbsd"
))]
fn set_dont_fragment(socket: &Socket, domain: Domain, dont_fragment: bool) -> std::io::Result<()> {
    let (level, option) = match domain {
        Domain::IPV4 => (libc::IPPROTO_IP, libc::IP_DONTFRAG),
        Domain::IPV6 => (libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG),
        _ => return Ok(()),
    };
    let enabled = libc::c_int::from(dont_fragment);
    // SAFETY: the option value is a plain `c_int`, which is what these
    // options take.
    unsafe { setsockopt(socket, level, option, enabled) }
}
/// Sets a socket option whose value is the raw bytes of `payload`.
///
/// `opt` is forwarded as the option level and `val` as the option name of
/// `libc::setsockopt`; the option length is `size_of::<T>()`.
///
/// # Safety
///
/// `T` must have exactly the in-memory representation the kernel expects for
/// the given level/name pair.
///
/// # Errors
///
/// Returns the last OS error when `libc::setsockopt` reports a non-zero status.
// `payload` is taken by value so it lives on this stack frame for the call.
#[allow(clippy::needless_pass_by_value)]
unsafe fn setsockopt<T>(
    socket: &Socket,
    opt: libc::c_int,
    val: libc::c_int,
    payload: T,
) -> std::io::Result<()> {
    let value_ptr = (&raw const payload).cast::<libc::c_void>();
    let value_len = std::mem::size_of_val(&payload) as libc::socklen_t;
    // SAFETY: `value_ptr` points to `value_len` readable bytes of `payload`,
    // which outlives the call; the caller vouches for `T` matching the option.
    let status = unsafe { libc::setsockopt(socket.as_raw_fd(), opt, val, value_ptr, value_len) };
    match status {
        0 => Ok(()),
        _ => Err(std::io::Error::last_os_error()),
    }
}
// These tests open real raw ICMP sockets on the loopback addresses.
// NOTE(review): creating raw sockets usually needs elevated privileges, so the
// tests may fail in unprivileged environments — confirm for the CI setup.
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::IcmpSocket;

    /// Socket bound to the IPv4 loopback address, shared by the IPv4
    /// `connect_*` tests.
    async fn loopback_v4() -> IcmpSocket {
        IcmpSocket::bind(Ipv4Addr::LOCALHOST).await.unwrap()
    }

    #[tokio::test]
    async fn bind_accepts_str_literal() {
        let _bound = IcmpSocket::bind("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_owned_string() {
        let address = String::from("127.0.0.1");
        let _bound = IcmpSocket::bind(address).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ipv4addr() {
        let _bound = IcmpSocket::bind(Ipv4Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ipv6addr() {
        let _bound = IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn bind_accepts_ip_addr() {
        let address = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let _bound = IcmpSocket::bind(address).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_str_literal() {
        loopback_v4().await.connect("127.0.0.1").await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_owned_string() {
        let peer = String::from("127.0.0.1");
        loopback_v4().await.connect(peer).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv4addr() {
        loopback_v4().await.connect(Ipv4Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ipv6addr() {
        let socket = IcmpSocket::bind(Ipv6Addr::LOCALHOST).await.unwrap();
        socket.connect(Ipv6Addr::LOCALHOST).await.unwrap();
    }

    #[tokio::test]
    async fn connect_accepts_ip_addr() {
        let peer = IpAddr::V4(Ipv4Addr::LOCALHOST);
        loopback_v4().await.connect(peer).await.unwrap();
    }
}