use crate::addr::NetAddr;
use std::os::unix::io::{OwnedFd, RawFd};
/// Converts a [`std::net::Ipv4Addr`] into a `libc::in_addr`.
///
/// `s_addr` has to hold the address in network byte order, meaning the four
/// octets sit in memory in the same order as they appear in dotted-quad
/// notation. Reinterpreting the octet array with `from_ne_bytes` produces
/// exactly that layout on every platform, and it is the inverse of the
/// `to_ne_bytes` decoding performed by `RecvMsg::local_ip`.
///
/// Folding the octets into a host-order integer (as was done previously)
/// yields a byte-reversed address on little-endian machines.
pub fn std_to_libc_in_addr(addr: std::net::Ipv4Addr) -> libc::in_addr {
    libc::in_addr {
        s_addr: u32::from_ne_bytes(addr.octets()),
    }
}
/// Converts a [`std::net::Ipv6Addr`] into a `libc::in6_addr` by copying its
/// sixteen octets verbatim.
pub const fn std_to_libc_in6_addr(addr: std::net::Ipv6Addr) -> libc::in6_addr {
    let s6_addr = addr.octets();
    libc::in6_addr { s6_addr }
}
/// Flags accepted by [`recv_msg`] and [`send_msg`]; an alias for
/// `nix::sys::socket::MsgFlags`.
pub type MsgFlags = nix::sys::socket::MsgFlags;
pub use std::io::{IoSlice, IoSliceMut};
use nix::libc;
/// Ancillary data to attach to an outgoing message sent with [`send_msg`].
#[derive(Debug)]
pub struct ControlMessage {
    /// Source address to send from. When set, it is written into the packet
    /// info of the matching address family before the message is sent.
    pub send_from: Option<std::net::IpAddr>,
    /// IPv4 packet info; holds the interface index set by `set_src4_intf`.
    pktinfo4: libc::in_pktinfo,
    /// IPv6 packet info; holds the interface index set by `set_src6_intf`.
    pktinfo6: libc::in6_pktinfo,
}
impl ControlMessage {
    /// Creates a control message with no source address, interface index 0
    /// for both address families, and unspecified addresses in both packet
    /// info structures.
    pub fn new() -> Self {
        let unspecified4 = std_to_libc_in_addr(std::net::Ipv4Addr::UNSPECIFIED);
        let unspecified6 = std_to_libc_in6_addr(std::net::Ipv6Addr::UNSPECIFIED);
        Self {
            send_from: None,
            pktinfo4: libc::in_pktinfo {
                ipi_ifindex: 0,
                ipi_addr: unspecified4,
                ipi_spec_dst: unspecified4,
            },
            pktinfo6: libc::in6_pktinfo {
                ipi6_ifindex: 0,
                ipi6_addr: unspecified6,
            },
        }
    }

    /// Builder-style setter for the source address (`None` clears it).
    #[must_use]
    pub const fn set_send_from(self, send_from: Option<std::net::IpAddr>) -> Self {
        Self { send_from, ..self }
    }

    /// Builder-style setter for the IPv4 interface index.
    ///
    /// The index is stored with a plain `as i32` conversion.
    #[must_use]
    pub const fn set_src4_intf(self, intf: u32) -> Self {
        Self {
            pktinfo4: libc::in_pktinfo {
                ipi_ifindex: intf as i32,
                ..self.pktinfo4
            },
            ..self
        }
    }

    /// Builder-style setter for the IPv6 interface index.
    #[must_use]
    pub const fn set_src6_intf(self, intf: u32) -> Self {
        Self {
            pktinfo6: libc::in6_pktinfo {
                ipi6_ifindex: intf,
                ..self.pktinfo6
            },
            ..self
        }
    }

    /// Produces the nix control messages described by this value.
    ///
    /// When `send_from` is set, the address is first stored into the packet
    /// info of the matching family (`ipi_spec_dst` for IPv4, `ipi6_addr` for
    /// IPv6) and a single entry borrowing that structure is returned.
    /// Without a source address the result is empty.
    pub fn convert_to_cmsg(&'_ mut self) -> Vec<nix::sys::socket::ControlMessage<'_>> {
        use nix::sys::socket::ControlMessage as NixCmsg;

        match self.send_from {
            Some(std::net::IpAddr::V4(ip)) => {
                self.pktinfo4.ipi_spec_dst = std_to_libc_in_addr(ip);
                vec![NixCmsg::Ipv4PacketInfo(&self.pktinfo4)]
            }
            Some(std::net::IpAddr::V6(ip)) => {
                self.pktinfo6.ipi6_addr = std_to_libc_in6_addr(ip);
                vec![NixCmsg::Ipv6PacketInfo(&self.pktinfo6)]
            }
            None => Vec::new(),
        }
    }
}
impl Default for ControlMessage {
fn default() -> Self {
Self::new()
}
}
/// A message received by [`recv_msg`], together with the ancillary data that
/// was decoded from its control messages.
#[derive(Debug)]
pub struct RecvMsg {
    /// The received payload bytes.
    pub buffer: Vec<u8>,
    /// The address reported by `recvmsg`, if any.
    pub address: Option<NetAddr>,
    /// Value of an `ScmTimestamp` control message, if one was present.
    timestamp: Option<nix::sys::time::TimeVal>,
    /// Value of an `Ipv4PacketInfo` control message, if one was present.
    ipv4pktinfo: Option<libc::in_pktinfo>,
    /// Value of an `Ipv6PacketInfo` control message, if one was present.
    ipv6pktinfo: Option<libc::in6_pktinfo>,
}
impl RecvMsg {
    /// Builds a [`RecvMsg`] from the raw nix result and the already-copied
    /// payload, picking the timestamp and packet-info control messages out
    /// of the ancillary data.
    ///
    /// Control messages of any other kind, and failures to decode the
    /// ancillary data at all, are logged as warnings and otherwise ignored.
    fn new(m: nix::sys::socket::RecvMsg<NetAddr>, buffer: Vec<u8>) -> RecvMsg {
        use nix::sys::socket::ControlMessageOwned as Cmsg;

        let mut timestamp = None;
        let mut ipv4pktinfo = None;
        let mut ipv6pktinfo = None;

        match m.cmsgs() {
            Err(err) => {
                log::warn!("Failed to decode control message: {:?}", err);
            }
            Ok(cmsgs) => {
                for cmsg in cmsgs {
                    match cmsg {
                        Cmsg::ScmTimestamp(rtime) => {
                            timestamp = Some(rtime);
                        }
                        Cmsg::Ipv4PacketInfo(pi) => {
                            ipv4pktinfo = Some(pi);
                        }
                        Cmsg::Ipv6PacketInfo(pi) => {
                            ipv6pktinfo = Some(pi);
                        }
                        other => log::warn!("Unknown control message {:?}", other),
                    }
                }
            }
        }

        RecvMsg {
            buffer,
            address: m.address,
            timestamp,
            ipv4pktinfo,
            ipv6pktinfo,
        }
    }

    /// Returns the address carried in the received packet info.
    ///
    /// IPv6 packet info takes precedence over IPv4; `None` is returned when
    /// neither was received.
    pub const fn local_ip(&self) -> Option<std::net::IpAddr> {
        if let Some(v6) = self.ipv6pktinfo {
            let o = v6.ipi6_addr.s6_addr;
            // Each segment is a big-endian pair of consecutive octets.
            return Some(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
                u16::from_be_bytes([o[0], o[1]]),
                u16::from_be_bytes([o[2], o[3]]),
                u16::from_be_bytes([o[4], o[5]]),
                u16::from_be_bytes([o[6], o[7]]),
                u16::from_be_bytes([o[8], o[9]]),
                u16::from_be_bytes([o[10], o[11]]),
                u16::from_be_bytes([o[12], o[13]]),
                u16::from_be_bytes([o[14], o[15]]),
            )));
        }
        match self.ipv4pktinfo {
            Some(v4) => {
                // The octets are taken in the order they are laid out in memory.
                let [a, b, c, d] = v4.ipi_addr.s_addr.to_ne_bytes();
                Some(std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)))
            }
            None => None,
        }
    }

    /// Returns the interface index carried in the received packet info.
    ///
    /// IPv6 packet info takes precedence over IPv4; `None` is returned when
    /// neither was received.
    pub fn local_intf(&self) -> Option<i32> {
        match (self.ipv6pktinfo, self.ipv4pktinfo) {
            (Some(v6), _) => Some(v6.ipi6_ifindex as i32),
            (None, v4) => v4.map(|pi| pi.ipi_ifindex),
        }
    }
}
/// Owning wrapper around a socket file descriptor, as returned by
/// [`new_socket`].
///
/// The descriptor is held as an [`OwnedFd`], so it is closed when the
/// wrapper is dropped.
#[derive(Debug)]
pub struct SocketFd {
    fd: OwnedFd,
}

impl std::os::unix::io::AsRawFd for SocketFd {
    /// Returns the underlying raw descriptor without giving up ownership.
    fn as_raw_fd(&self) -> RawFd {
        // Fully-qualified call: `AsRawFd` is not imported at file level, so
        // method-call syntax on `OwnedFd` would not resolve here.
        std::os::unix::io::AsRawFd::as_raw_fd(&self.fd)
    }
}
pub fn new_socket(
domain: libc::c_int,
ty: libc::c_int,
protocol: libc::c_int,
) -> Result<SocketFd, std::io::Error> {
unsafe {
use std::os::unix::io::FromRawFd as _;
let fd = libc::socket(
domain,
ty | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
protocol,
);
if fd == -1 {
return Err(std::io::Error::last_os_error());
}
Ok(SocketFd {
fd: OwnedFd::from_raw_fd(fd as RawFd),
})
}
}
/// Receives a single message from `sock`, waiting until one is available.
///
/// * `bufsize` is the size of the payload buffer handed to `recvmsg`.
/// * `flags` are passed to `recvmsg`; `MSG_DONTWAIT` is always added so the
///   syscall itself never blocks the executor.
///
/// A readiness notification does not guarantee that the following
/// `recvmsg` succeeds (the readiness may be stale, or another reader may
/// have consumed the message first). When the syscall reports `EAGAIN` the
/// readiness flag is cleared and the function waits for the socket to become
/// readable again instead of surfacing a spurious would-block error.
///
/// # Errors
///
/// Returns any error from waiting for readiness, and any `recvmsg` error
/// other than `EAGAIN`.
pub async fn recv_msg<F: std::os::unix::io::AsRawFd>(
    sock: &tokio::io::unix::AsyncFd<F>,
    bufsize: usize,
    flags: MsgFlags,
) -> Result<RecvMsg, std::io::Error> {
    let mut buf = vec![0u8; bufsize];
    // Space for ancillary data.
    let mut cmsg: Vec<u8> = vec![0u8; 65536];
    let mut flags = flags;
    flags.set(MsgFlags::MSG_DONTWAIT, true);

    loop {
        let mut ev = sock.readable().await?;
        let iov = &mut [IoSliceMut::new(buf.as_mut_slice())];
        match nix::sys::socket::recvmsg(sock.get_ref().as_raw_fd(), iov, Some(&mut cmsg), flags) {
            Ok(rm) => {
                ev.retain_ready();
                let payload: Vec<u8> = match rm.iovs().next() {
                    Some(bytes) => bytes.into(),
                    None => vec![],
                };
                return Ok(RecvMsg::new(rm, payload));
            }
            Err(nix::errno::Errno::EAGAIN) => {
                // Nothing to read after all: re-arm and wait again.
                ev.clear_ready();
            }
            Err(e) => {
                ev.retain_ready();
                return Err(e.into());
            }
        }
    }
}
/// Returns an endless stream that yields the result of one [`recv_msg`] call
/// per item, using the same `bufsize` and `flags` for every call.
///
/// Errors are yielded as items; the stream itself never terminates.
pub fn recv_msg_stream<F>(
    sock: &tokio::io::unix::AsyncFd<F>,
    bufsize: usize,
    flags: MsgFlags,
) -> impl futures::stream::Stream<Item = Result<RecvMsg, std::io::Error>> + '_
where
    F: std::os::unix::io::AsRawFd,
{
    // The socket reference is threaded through as the unfold state.
    futures::stream::unfold(sock, move |fd| async move {
        let item = recv_msg::<F>(fd, bufsize, flags).await;
        Some((item, fd))
    })
}
/// Sends `buffer` on `sock` as a single message, waiting until the socket is
/// writable.
///
/// * `cmsg` selects the ancillary data: with `send_from` set, a packet-info
///   control message of the matching family carrying that address is
///   attached; otherwise, if an IPv6 or IPv4 interface index is non-zero,
///   the corresponding packet info is attached (IPv6 is checked first).
/// * `flags` are passed to `sendmsg` unchanged.
/// * `from` is passed to `sendmsg` as its address argument.
///
/// A writability notification does not guarantee that the following
/// `sendmsg` succeeds. When the syscall reports `EAGAIN` the readiness flag
/// is cleared and the function waits for the socket to become writable again
/// instead of surfacing a spurious would-block error.
///
/// # Errors
///
/// Returns any error from waiting for readiness, and any `sendmsg` error
/// other than `EAGAIN`.
pub async fn send_msg<F: std::os::unix::io::AsRawFd>(
    sock: &tokio::io::unix::AsyncFd<F>,
    buffer: &[u8],
    cmsg: &ControlMessage,
    flags: MsgFlags,
    from: Option<&NetAddr>,
) -> std::io::Result<()> {
    loop {
        let mut ev = sock.writable().await?;

        // Everything below is rebuilt per attempt so that no nix/libc
        // temporaries are kept alive across the await above.
        let iov = &[IoSlice::new(buffer)];
        let mut cmsgs: Vec<nix::sys::socket::ControlMessage> = vec![];
        let mut in_pktinfo = cmsg.pktinfo4;
        let mut in6_pktinfo = cmsg.pktinfo6;
        if let Some(addr) = cmsg.send_from {
            match addr {
                std::net::IpAddr::V4(ip) => {
                    in_pktinfo.ipi_spec_dst = std_to_libc_in_addr(ip);
                    cmsgs.push(nix::sys::socket::ControlMessage::Ipv4PacketInfo(
                        &in_pktinfo,
                    ));
                }
                std::net::IpAddr::V6(ip) => {
                    in6_pktinfo.ipi6_addr = std_to_libc_in6_addr(ip);
                    cmsgs.push(nix::sys::socket::ControlMessage::Ipv6PacketInfo(
                        &in6_pktinfo,
                    ));
                }
            }
        } else if in6_pktinfo.ipi6_ifindex != 0 {
            cmsgs.push(nix::sys::socket::ControlMessage::Ipv6PacketInfo(
                &in6_pktinfo,
            ));
        } else if in_pktinfo.ipi_ifindex != 0 {
            cmsgs.push(nix::sys::socket::ControlMessage::Ipv4PacketInfo(
                &in_pktinfo,
            ));
        }

        match nix::sys::socket::sendmsg(sock.get_ref().as_raw_fd(), iov, &cmsgs, flags, from) {
            Ok(_) => {
                ev.retain_ready();
                return Ok(());
            }
            Err(nix::errno::Errno::EAGAIN) => {
                // Not writable after all: re-arm and wait again.
                ev.clear_ready();
            }
            Err(e) => {
                ev.retain_ready();
                return Err(e.into());
            }
        }
    }
}
/// Sets the `IPV6_UNICAST_HOPS` socket option on `fd` to `val`.
///
/// # Errors
///
/// Returns the errno reported by `setsockopt` on failure.
pub fn set_ipv6_unicast_hoplimit(fd: RawFd, val: i32) -> Result<(), nix::Error> {
    let optval: *const libc::c_void = (&val as *const i32).cast();
    let optlen = std::mem::size_of::<i32>() as libc::socklen_t;
    // SAFETY: `optval` points at `val`, which outlives the call, and
    // `optlen` is exactly the size of that value.
    let rc = unsafe {
        libc::setsockopt(
            fd,
            libc::IPPROTO_IPV6,
            libc::IPV6_UNICAST_HOPS,
            optval,
            optlen,
        )
    };
    nix::errno::Errno::result(rc).map(drop)
}
/// Sets the `IPV6_MULTICAST_HOPS` socket option on `fd` to `val`.
///
/// # Errors
///
/// Returns the errno reported by `setsockopt` on failure.
pub fn set_ipv6_multicast_hoplimit(fd: RawFd, val: i32) -> Result<(), nix::Error> {
    let optval: *const libc::c_void = (&val as *const i32).cast();
    let optlen = std::mem::size_of::<i32>() as libc::socklen_t;
    // SAFETY: `optval` points at `val`, which outlives the call, and
    // `optlen` is exactly the size of that value.
    let rc = unsafe {
        libc::setsockopt(
            fd,
            libc::IPPROTO_IPV6,
            libc::IPV6_MULTICAST_HOPS,
            optval,
            optlen,
        )
    };
    nix::errno::Errno::result(rc).map(drop)
}