use crate::transport::TransportError;
use std::os::unix::io::{AsRawFd, RawFd};
pub struct PacketSocket {
fd: RawFd,
if_index: i32,
ethertype: u16,
}
impl PacketSocket {
pub fn open(interface: &str, ethertype: u16) -> Result<Self, TransportError> {
let fd = unsafe {
libc::socket(
libc::AF_PACKET,
libc::SOCK_DGRAM,
(ethertype).to_be() as i32,
)
};
if fd < 0 {
let err = std::io::Error::last_os_error();
if err.raw_os_error() == Some(libc::EPERM) {
return Err(TransportError::StartFailed(
"AF_PACKET requires CAP_NET_RAW capability \
(run as root or use: setcap cap_net_raw=ep <binary>)"
.into(),
));
}
return Err(TransportError::StartFailed(format!(
"socket(AF_PACKET) failed: {}",
err
)));
}
let if_index = get_if_index(fd, interface)?;
let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
sll.sll_family = libc::AF_PACKET as u16;
sll.sll_protocol = ethertype.to_be();
sll.sll_ifindex = if_index;
let ret = unsafe {
libc::bind(
fd,
&sll as *const libc::sockaddr_ll as *const libc::sockaddr,
std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
)
};
if ret < 0 {
let err = std::io::Error::last_os_error();
unsafe { libc::close(fd) };
return Err(TransportError::StartFailed(format!(
"bind(AF_PACKET, {}) failed: {}",
interface, err
)));
}
let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
if flags < 0 {
let err = std::io::Error::last_os_error();
unsafe { libc::close(fd) };
return Err(TransportError::StartFailed(format!(
"fcntl(F_GETFL) failed: {}",
err
)));
}
let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
if ret < 0 {
let err = std::io::Error::last_os_error();
unsafe { libc::close(fd) };
return Err(TransportError::StartFailed(format!(
"fcntl(F_SETFL, O_NONBLOCK) failed: {}",
err
)));
}
Ok(Self {
fd,
if_index,
ethertype,
})
}
pub fn if_index(&self) -> i32 {
self.if_index
}
pub fn local_mac(&self) -> Result<[u8; 6], TransportError> {
get_mac_addr(self.fd, self.if_index)
}
pub fn interface_mtu(&self) -> Result<u16, TransportError> {
get_if_mtu(self.fd, self.if_index)
}
pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), TransportError> {
let size = size as libc::c_int;
let ret = unsafe {
libc::setsockopt(
self.fd,
libc::SOL_SOCKET,
libc::SO_RCVBUF,
&size as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
)
};
if ret < 0 {
return Err(TransportError::StartFailed(format!(
"setsockopt(SO_RCVBUF) failed: {}",
std::io::Error::last_os_error()
)));
}
Ok(())
}
pub fn set_send_buffer_size(&self, size: usize) -> Result<(), TransportError> {
let size = size as libc::c_int;
let ret = unsafe {
libc::setsockopt(
self.fd,
libc::SOL_SOCKET,
libc::SO_SNDBUF,
&size as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
)
};
if ret < 0 {
return Err(TransportError::StartFailed(format!(
"setsockopt(SO_SNDBUF) failed: {}",
std::io::Error::last_os_error()
)));
}
Ok(())
}
pub fn send_to(&self, data: &[u8], dest_mac: &[u8; 6]) -> std::io::Result<usize> {
let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
sll.sll_family = libc::AF_PACKET as u16;
sll.sll_protocol = self.ethertype.to_be();
sll.sll_ifindex = self.if_index;
sll.sll_halen = 6;
sll.sll_addr[..6].copy_from_slice(dest_mac);
let ret = unsafe {
libc::sendto(
self.fd,
data.as_ptr() as *const libc::c_void,
data.len(),
0,
&sll as *const libc::sockaddr_ll as *const libc::sockaddr,
std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
)
};
if ret < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(ret as usize)
}
}
pub fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, [u8; 6])> {
let mut sll: libc::sockaddr_ll = unsafe { std::mem::zeroed() };
let mut sll_len = std::mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t;
let ret = unsafe {
libc::recvfrom(
self.fd,
buf.as_mut_ptr() as *mut libc::c_void,
buf.len(),
0,
&mut sll as *mut libc::sockaddr_ll as *mut libc::sockaddr,
&mut sll_len,
)
};
if ret < 0 {
return Err(std::io::Error::last_os_error());
}
let mut src_mac = [0u8; 6];
src_mac.copy_from_slice(&sll.sll_addr[..6]);
Ok((ret as usize, src_mac))
}
}
impl AsRawFd for PacketSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for PacketSocket {
fn drop(&mut self) {
unsafe {
libc::close(self.fd);
}
}
}
fn get_if_index(_fd: RawFd, interface: &str) -> Result<i32, TransportError> {
let c_name = std::ffi::CString::new(interface).map_err(|_| {
TransportError::StartFailed(format!("invalid interface name: {}", interface))
})?;
let idx = unsafe { libc::if_nametoindex(c_name.as_ptr()) };
if idx == 0 {
return Err(TransportError::StartFailed(format!(
"interface not found: {} ({})",
interface,
std::io::Error::last_os_error()
)));
}
Ok(idx as i32)
}
fn get_mac_addr(fd: RawFd, if_index: i32) -> Result<[u8; 6], TransportError> {
let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
let mut name_buf = [0u8; libc::IFNAMSIZ];
let ret = unsafe {
libc::if_indextoname(
if_index as libc::c_uint,
name_buf.as_mut_ptr() as *mut libc::c_char,
)
};
if ret.is_null() {
return Err(TransportError::StartFailed(format!(
"if_indextoname({}) failed: {}",
if_index,
std::io::Error::last_os_error()
)));
}
let name_len = name_buf
.iter()
.position(|&b| b == 0)
.unwrap_or(name_buf.len());
let copy_len = name_len.min(libc::IFNAMSIZ - 1);
unsafe {
std::ptr::copy_nonoverlapping(
name_buf.as_ptr(),
ifr.ifr_name.as_mut_ptr() as *mut u8,
copy_len,
);
}
#[cfg(target_env = "musl")]
let ioctl_req = libc::SIOCGIFHWADDR as libc::c_int;
#[cfg(not(target_env = "musl"))]
let ioctl_req = libc::SIOCGIFHWADDR as libc::c_ulong;
let ret = unsafe { libc::ioctl(fd, ioctl_req, &ifr) };
if ret < 0 {
return Err(TransportError::StartFailed(format!(
"ioctl(SIOCGIFHWADDR) failed: {}",
std::io::Error::last_os_error()
)));
}
let mut mac = [0u8; 6];
unsafe {
let sa_data = ifr.ifr_ifru.ifru_hwaddr.sa_data;
for (i, byte) in mac.iter_mut().enumerate() {
*byte = sa_data[i] as u8;
}
}
Ok(mac)
}
fn get_if_mtu(fd: RawFd, if_index: i32) -> Result<u16, TransportError> {
let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
let mut name_buf = [0u8; libc::IFNAMSIZ];
let ret = unsafe {
libc::if_indextoname(
if_index as libc::c_uint,
name_buf.as_mut_ptr() as *mut libc::c_char,
)
};
if ret.is_null() {
return Err(TransportError::StartFailed(format!(
"if_indextoname({}) failed: {}",
if_index,
std::io::Error::last_os_error()
)));
}
let name_len = name_buf
.iter()
.position(|&b| b == 0)
.unwrap_or(name_buf.len());
let copy_len = name_len.min(libc::IFNAMSIZ - 1);
unsafe {
std::ptr::copy_nonoverlapping(
name_buf.as_ptr(),
ifr.ifr_name.as_mut_ptr() as *mut u8,
copy_len,
);
}
#[cfg(target_env = "musl")]
let ioctl_req = libc::SIOCGIFMTU as libc::c_int;
#[cfg(not(target_env = "musl"))]
let ioctl_req = libc::SIOCGIFMTU as libc::c_ulong;
let ret = unsafe { libc::ioctl(fd, ioctl_req, &ifr) };
if ret < 0 {
return Err(TransportError::StartFailed(format!(
"ioctl(SIOCGIFMTU) failed: {}",
std::io::Error::last_os_error()
)));
}
let mtu = unsafe { ifr.ifr_ifru.ifru_mtu } as u16;
Ok(mtu)
}