use std::{
io::{Error, Result},
mem,
os::{
fd::{AsFd, BorrowedFd, FromRawFd},
unix::io::{AsRawFd, RawFd},
},
};
use crate::SocketAddr;
/// A netlink socket that owns its raw file descriptor.
///
/// The descriptor is obtained from `socket(2)` in [`Socket::new`] or adopted
/// through `FromRawFd`, and it is closed when the `Socket` is dropped.
#[derive(Debug)]
pub struct Socket(RawFd);

/// Cloning duplicates the underlying descriptor, so every clone owns (and
/// closes on drop) its own descriptor referring to the same socket.
///
/// A derived `Clone` would copy the raw integer instead. Both values would
/// then `close` the same number on drop, and the second close could hit an
/// unrelated descriptor that had reused that number in the meantime.
///
/// # Panics
///
/// Panics if the descriptor cannot be duplicated (for example, when the
/// process has exhausted its file-descriptor limit), because `Clone` cannot
/// report errors.
impl Clone for Socket {
    fn clone(&self) -> Self {
        use std::os::fd::IntoRawFd;

        // SAFETY: `self.0` is the descriptor wrapped by this `Socket`; we hold
        // `&self`, so it cannot be dropped (and the fd closed) during this call.
        let borrowed = unsafe { BorrowedFd::borrow_raw(self.0) };
        let duplicate = borrowed
            .try_clone_to_owned()
            .expect("failed to duplicate netlink socket file descriptor");
        // Ownership moves from the `OwnedFd` into the new `Socket`, whose
        // `Drop` is now responsible for closing it.
        Socket(duplicate.into_raw_fd())
    }
}
impl AsRawFd for Socket {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
impl AsFd for Socket {
    /// Borrows the wrapped descriptor for as long as `self` is borrowed.
    fn as_fd(&self) -> BorrowedFd<'_> {
        let raw = self.0;
        // SAFETY: `raw` is the descriptor wrapped by this `Socket`. The returned
        // borrow is tied to `&self`, so this `Socket` cannot be dropped (and the
        // descriptor closed) while the borrow is alive.
        unsafe { BorrowedFd::borrow_raw(raw) }
    }
}
impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
}
}
impl Drop for Socket {
fn drop(&mut self) {
unsafe { libc::close(self.as_raw_fd()) };
}
}
impl Socket {
    /// Opens a new netlink socket for `protocol` (for example
    /// `NETLINK_ROUTE`).
    ///
    /// The socket is created as `PF_NETLINK` / `SOCK_DGRAM` with
    /// `SOCK_CLOEXEC` set.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `socket(2)`.
    pub fn new(protocol: isize) -> Result<Self> {
        // NOTE(review): `protocol` is narrowed to `c_int` with `as`; netlink
        // protocol numbers are small, so this is presumably lossless.
        let res = unsafe {
            libc::socket(
                libc::PF_NETLINK,
                libc::SOCK_DGRAM | libc::SOCK_CLOEXEC,
                protocol as libc::c_int,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(Socket(res))
    }

    /// Binds the socket to `addr` with `bind(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `bind(2)`.
    pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let res = unsafe { libc::bind(self.as_raw_fd(), addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Binds to `SocketAddr::new(0, 0)` and returns the address the socket
    /// ended up with, as read back by [`Socket::get_address`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Socket::bind`] and [`Socket::get_address`].
    pub fn bind_auto(&mut self) -> Result<SocketAddr> {
        let mut addr = SocketAddr::new(0, 0);
        self.bind(&addr)?;
        self.get_address(&mut addr)?;
        Ok(addr)
    }

    /// Writes the socket's local address into `addr` using `getsockname(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `getsockname(2)`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel reports an address length different from the
    /// buffer length provided by `addr`.
    pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
        let (addr_ptr, mut addr_len) = addr.as_raw_mut();
        // Remember the buffer size so we can detect a short or long write.
        let addr_len_copy = addr_len;
        let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
        let res = unsafe {
            libc::getsockname(self.as_raw_fd(), addr_ptr, addr_len_ptr)
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        assert_eq!(addr_len, addr_len_copy);
        Ok(())
    }

    /// Switches the socket between blocking and non-blocking mode using the
    /// `FIONBIO` ioctl.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `ioctl(2)`.
    // NOTE(review): `dead_code` is allowed presumably because nothing in the
    // crate outside the tests calls this; confirm and document the reason.
    #[allow(dead_code)]
    pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
        let mut non_blocking = non_blocking as libc::c_int;
        let res = unsafe {
            libc::ioctl(self.as_raw_fd(), libc::FIONBIO, &mut non_blocking)
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Sets the default destination of the socket to `remote_addr` with
    /// `connect(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `connect(2)`.
    pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
        let (addr, addr_len) = remote_addr.as_raw();
        let res = unsafe { libc::connect(self.as_raw_fd(), addr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Receives one datagram into the spare capacity of `buf` with
    /// `recvfrom(2)` and returns its length together with the sender address.
    ///
    /// `buf` is advanced by the number of bytes actually written, which is at
    /// most the size of its current `chunk_mut()`. The returned length is the
    /// raw `recvfrom` result: with `MSG_TRUNC` in `flags` it may exceed the
    /// number of bytes written.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `recvfrom(2)`.
    pub fn recv_from<B>(
        &self,
        buf: &mut B,
        flags: libc::c_int,
    ) -> Result<(usize, SocketAddr)>
    where
        B: bytes::BufMut,
    {
        let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
        let addr_ptr =
            &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;
        // The value-result length must be a real `socklen_t`. The kernel reads
        // and writes exactly `size_of::<socklen_t>()` bytes through this
        // pointer. Aliasing a `usize` (as earlier code did) makes a 64-bit
        // big-endian target see a length of 0.
        let mut addrlen = mem::size_of_val(&addr) as libc::socklen_t;
        let addrlen_ptr = &mut addrlen as *mut libc::socklen_t;

        let chunk = buf.chunk_mut();
        let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
        let buf_len = chunk.len() as libc::size_t;

        let res = unsafe {
            libc::recvfrom(
                self.as_raw_fd(),
                buf_ptr,
                buf_len,
                flags,
                addr_ptr,
                addrlen_ptr,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        // Only `buf_len` bytes can have been written, even when the reported
        // length is larger.
        let written = std::cmp::min(buf_len, res as usize);
        // SAFETY: the kernel initialised the first `written` bytes of the
        // chunk returned by `chunk_mut()`.
        unsafe {
            buf.advance_mut(written);
        }
        Ok((res as usize, SocketAddr(addr)))
    }

    /// Receives one datagram into the spare capacity of `buf` with `recv(2)`
    /// and returns its length.
    ///
    /// `buf` is advanced by the number of bytes actually written. As with
    /// [`Socket::recv_from`], the returned length may exceed that when
    /// `MSG_TRUNC` is in `flags`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `recv(2)`.
    pub fn recv<B>(&self, buf: &mut B, flags: libc::c_int) -> Result<usize>
    where
        B: bytes::BufMut,
    {
        let chunk = buf.chunk_mut();
        let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
        let buf_len = chunk.len() as libc::size_t;
        let res =
            unsafe { libc::recv(self.as_raw_fd(), buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        let written = std::cmp::min(buf_len, res as usize);
        // SAFETY: the kernel initialised the first `written` bytes of the
        // chunk returned by `chunk_mut()`.
        unsafe {
            buf.advance_mut(written);
        }
        Ok(res as usize)
    }

    /// Receives one complete datagram into a freshly allocated `Vec`.
    ///
    /// The datagram is first peeked with `MSG_PEEK | MSG_TRUNC` to learn its
    /// full length, the buffer is reserved accordingly, and then the datagram
    /// is read for real.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Socket::recv_from`].
    ///
    /// # Panics
    ///
    /// Panics if the length of the datagram actually read differs from the
    /// peeked length.
    pub fn recv_from_full(&self) -> Result<(Vec<u8>, SocketAddr)> {
        let mut buf: Vec<u8> = Vec::new();
        let (peek_len, _) =
            self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?;
        // Discard the peeked prefix; only its reported length is needed.
        buf.clear();
        buf.reserve(peek_len);
        let (rlen, addr) = self.recv_from(&mut buf, 0)?;
        assert_eq!(rlen, peek_len);
        Ok((buf, addr))
    }

    /// Sends `buf` to `addr` with `sendto(2)` and returns the number of bytes
    /// sent.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `sendto(2)`.
    pub fn send_to(
        &self,
        buf: &[u8],
        addr: &SocketAddr,
        flags: libc::c_int,
    ) -> Result<usize> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        let res = unsafe {
            libc::sendto(
                self.as_raw_fd(),
                buf_ptr,
                buf_len,
                flags,
                addr_ptr,
                addr_len,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Sends `buf` to the connected peer with `send(2)` and returns the number
    /// of bytes sent.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `send(2)`.
    pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        let res =
            unsafe { libc::send(self.as_raw_fd(), buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Enables or disables the `NETLINK_PKTINFO` option (see netlink(7)).
    pub fn set_pktinfo(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_PKTINFO,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_PKTINFO` as `1`.
    pub fn get_pktinfo(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_PKTINFO,
        )?;
        Ok(res == 1)
    }

    /// Joins netlink multicast `group` (`NETLINK_ADD_MEMBERSHIP`).
    pub fn add_membership(&self, group: u32) -> Result<()> {
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_ADD_MEMBERSHIP,
            group,
        )
    }

    /// Leaves netlink multicast `group` (`NETLINK_DROP_MEMBERSHIP`).
    pub fn drop_membership(&self, group: u32) -> Result<()> {
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_DROP_MEMBERSHIP,
            group,
        )
    }

    /// Enables or disables the `NETLINK_BROADCAST_ERROR` option (see
    /// netlink(7)).
    pub fn set_broadcast_error(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_BROADCAST_ERROR,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_BROADCAST_ERROR` as `1`.
    pub fn get_broadcast_error(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_BROADCAST_ERROR,
        )?;
        Ok(res == 1)
    }

    /// Enables or disables the `NETLINK_NO_ENOBUFS` option (see netlink(7)).
    pub fn set_no_enobufs(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_NO_ENOBUFS,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_NO_ENOBUFS` as `1`.
    pub fn get_no_enobufs(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_NO_ENOBUFS,
        )?;
        Ok(res == 1)
    }

    /// Enables or disables the `NETLINK_LISTEN_ALL_NSID` option (see
    /// netlink(7)).
    pub fn set_listen_all_namespaces(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_LISTEN_ALL_NSID,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_LISTEN_ALL_NSID` as `1`.
    pub fn get_listen_all_namespaces(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_LISTEN_ALL_NSID,
        )?;
        Ok(res == 1)
    }

    /// Enables or disables the `NETLINK_CAP_ACK` option (see netlink(7)).
    pub fn set_cap_ack(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_CAP_ACK,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_CAP_ACK` as `1`.
    pub fn get_cap_ack(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_CAP_ACK,
        )?;
        Ok(res == 1)
    }

    /// Enables or disables the `NETLINK_EXT_ACK` option (see netlink(7)).
    pub fn set_ext_ack(&self, value: bool) -> Result<()> {
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_EXT_ACK,
            value,
        )
    }

    /// Returns `true` when the kernel reports `NETLINK_EXT_ACK` as `1`.
    pub fn get_ext_ack(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_EXT_ACK,
        )?;
        Ok(res == 1)
    }

    /// Sets `SO_RCVBUF`, passing the raw bytes of `size` to `setsockopt(2)`.
    // NOTE(review): `T` is unconstrained, but `SO_RCVBUF` expects a `c_int`;
    // callers presumably pass `i32`/`c_int`. Per socket(7), Linux doubles the
    // requested value, so `get_rx_buf_sz` may report twice what was set.
    pub fn set_rx_buf_sz<T>(&self, size: T) -> Result<()> {
        setsockopt(self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, size)
    }

    /// Returns the current `SO_RCVBUF` value reported by the kernel.
    pub fn get_rx_buf_sz(&self) -> Result<usize> {
        let res = getsockopt::<libc::c_int>(
            self.as_raw_fd(),
            libc::SOL_SOCKET,
            libc::SO_RCVBUF,
        )?;
        Ok(res as usize)
    }

    /// Enables or disables the `NETLINK_GET_STRICT_CHK` option (see
    /// netlink(7)).
    pub fn set_netlink_get_strict_chk(&self, value: bool) -> Result<()> {
        // `c_int`, like every other boolean netlink option in this impl.
        let value: libc::c_int = value.into();
        setsockopt(
            self.as_raw_fd(),
            libc::SOL_NETLINK,
            libc::NETLINK_GET_STRICT_CHK,
            value,
        )
    }
}
/// Reads socket option `option` at `level` from `fd` as a value of type `T`.
///
/// # Errors
///
/// Returns the OS error reported by `getsockopt(2)`.
///
/// # Panics
///
/// Panics if the kernel reports an option length other than
/// `size_of::<T>()`.
// NOTE(review): the slot is pre-filled with `mem::zeroed()`, which is only
// sound for types where all-zero bytes are a valid value (plain integers or C
// structs). `T: Copy` alone does not guarantee that.
pub(crate) fn getsockopt<T: Copy>(
    fd: RawFd,
    level: libc::c_int,
    option: libc::c_int,
) -> Result<T> {
    let expected_len = mem::size_of::<T>();
    let mut value: T = unsafe { mem::zeroed() };
    let mut len = expected_len as libc::socklen_t;

    // SAFETY: `value` and `len` are live, writable locals, and `len` holds the
    // exact size of `value`'s storage.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            level,
            option,
            (&mut value as *mut T).cast::<libc::c_void>(),
            &mut len,
        )
    };
    if rc < 0 {
        return Err(Error::last_os_error());
    }

    assert_eq!(len as usize, expected_len);
    Ok(value)
}
/// Writes `payload`'s raw bytes as socket option `option` at `level` on `fd`.
///
/// # Errors
///
/// Returns the OS error reported by `setsockopt(2)`.
fn setsockopt<T>(
    fd: RawFd,
    level: libc::c_int,
    option: libc::c_int,
    payload: T,
) -> Result<()> {
    let value_ptr: *const libc::c_void = (&payload as *const T).cast();
    let value_len = mem::size_of::<T>() as libc::socklen_t;

    // SAFETY: `value_ptr` points at `payload`, which lives until the end of this
    // function, and `value_len` is exactly its size.
    match unsafe { libc::setsockopt(fd, level, option, value_ptr, value_len) } {
        rc if rc < 0 => Err(Error::last_os_error()),
        _ => Ok(()),
    }
}
/// Smoke tests that open real `NETLINK_ROUTE` sockets, so they need a kernel
/// with netlink support.
#[cfg(test)]
mod test {
    use super::*;
    use crate::protocols::NETLINK_ROUTE;

    /// Opens a fresh `NETLINK_ROUTE` socket, panicking on failure.
    fn route_socket() -> Socket {
        Socket::new(NETLINK_ROUTE).unwrap()
    }

    #[test]
    fn new() {
        route_socket();
    }

    #[test]
    fn connect() {
        let kernel = SocketAddr::new(0, 0);
        route_socket().connect(&kernel).unwrap();
    }

    #[test]
    fn bind() {
        let mut sock = route_socket();
        let local = SocketAddr::new(4321, 0);
        sock.bind(&local).unwrap();
    }

    #[test]
    fn bind_auto() {
        let mut sock = route_socket();
        let assigned = sock.bind_auto().unwrap();
        assert!(assigned.port_number() != 0);
    }

    #[test]
    fn set_non_blocking() {
        let sock = route_socket();
        for enabled in [true, false] {
            sock.set_non_blocking(enabled).unwrap();
        }
    }

    #[test]
    fn options() {
        type Setter = fn(&Socket, bool) -> std::io::Result<()>;
        type Getter = fn(&Socket) -> std::io::Result<bool>;

        let sock = route_socket();
        // Each flag is switched on, then off, and read back after each change.
        let flags: [(Setter, Getter); 3] = [
            (Socket::set_cap_ack, Socket::get_cap_ack),
            (Socket::set_no_enobufs, Socket::get_no_enobufs),
            (Socket::set_broadcast_error, Socket::get_broadcast_error),
        ];
        for (set, get) in flags {
            for value in [true, false] {
                set(&sock, value).unwrap();
                assert_eq!(get(&sock).unwrap(), value);
            }
        }
    }
}