use std::{
io::{Error, Result},
mem,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
};
use crate::SocketAddr;
/// An owned netlink socket file descriptor; the descriptor is closed when the `Socket` is
/// dropped.
#[derive(Debug)]
pub struct Socket(RawFd);

impl Clone for Socket {
    /// Duplicates the underlying descriptor so that every `Socket` owns (and closes) its own
    /// file descriptor. Both descriptors refer to the same kernel socket.
    ///
    /// A derived `Clone` would copy the raw descriptor number. Dropping either copy would then
    /// close the descriptor out from under the other, and the second drop would close it again,
    /// possibly closing an unrelated descriptor that reused the same number.
    ///
    /// # Panics
    ///
    /// Panics if the descriptor cannot be duplicated (for example when the process has reached
    /// its open-descriptor limit), since `Clone` cannot report errors.
    fn clone(&self) -> Self {
        // SAFETY: `self.0` stays open for as long as `self` is alive. The temporary `File` only
        // borrows it: wrapping it in `ManuallyDrop` guarantees it is never closed through this
        // handle, even if `expect` below unwinds.
        let borrowed = mem::ManuallyDrop::new(unsafe { std::fs::File::from_raw_fd(self.0) });
        // std duplicates with close-on-exec set on Unix, matching the SOCK_CLOEXEC used at
        // creation time.
        let duplicate = borrowed
            .try_clone()
            .expect("failed to duplicate netlink socket descriptor");
        Socket(std::os::unix::io::IntoRawFd::into_raw_fd(duplicate))
    }
}
impl AsRawFd for Socket {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
}
}
impl Drop for Socket {
    /// Closes the wrapped descriptor. The result of `close(2)` is ignored because `drop` has no
    /// way to report it.
    fn drop(&mut self) {
        let fd = self.0;
        // SAFETY: closing a descriptor is memory-safe; this `Socket` is being destroyed and
        // does not touch `fd` afterwards.
        unsafe {
            libc::close(fd);
        }
    }
}
impl Socket {
    /// Opens a new `PF_NETLINK` datagram socket (created with `SOCK_CLOEXEC`) for the given
    /// netlink `protocol`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `socket(2)`.
    pub fn new(protocol: isize) -> Result<Self> {
        let res = unsafe {
            libc::socket(
                libc::PF_NETLINK,
                libc::SOCK_DGRAM | libc::SOCK_CLOEXEC,
                protocol as libc::c_int,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(Socket(res))
    }

    /// Binds the socket to the local address `addr`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `bind(2)`.
    pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Binds to `SocketAddr::new(0, 0)`, so the kernel assigns a port id, and then returns the
    /// address actually assigned (read back with [`Socket::get_address`]).
    ///
    /// # Errors
    ///
    /// Returns the OS error from `bind(2)` or `getsockname(2)`.
    pub fn bind_auto(&mut self) -> Result<SocketAddr> {
        let mut addr = SocketAddr::new(0, 0);
        self.bind(&addr)?;
        self.get_address(&mut addr)?;
        Ok(addr)
    }

    /// Writes the socket's local address into `addr` using `getsockname(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `getsockname(2)`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel reports an address length different from the size of `addr`'s
    /// storage.
    pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
        let (addr_ptr, mut addr_len) = addr.as_raw_mut();
        let addr_len_copy = addr_len;
        let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
        let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        assert_eq!(addr_len, addr_len_copy);
        Ok(())
    }

    /// Enables (`true`) or disables (`false`) non-blocking mode via the `FIONBIO` ioctl.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `ioctl(2)`.
    // NOTE(review): `dead_code` is allowed because this may have no non-test caller in the
    // crate; confirm it is still needed.
    #[allow(dead_code)]
    pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
        let mut non_blocking = libc::c_int::from(non_blocking);
        let res = unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Sets the default destination of the socket to `remote_addr` using `connect(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `connect(2)`.
    pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
        let (addr, addr_len) = remote_addr.as_raw();
        let res = unsafe { libc::connect(self.0, addr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Receives one datagram into the spare capacity of `buf` and returns the length reported
    /// by `recvfrom(2)`, together with the sender's address.
    ///
    /// `buf` is advanced by the number of bytes actually written. With `MSG_TRUNC` in `flags`,
    /// the returned length can exceed that number.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `recvfrom(2)`.
    pub fn recv_from<B>(&self, buf: &mut B, flags: libc::c_int) -> Result<(usize, SocketAddr)>
    where
        B: bytes::BufMut,
    {
        let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
        let addr_ptr = &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;
        // `recvfrom` reads and writes a `socklen_t` through this pointer. Keep the length in a
        // real `socklen_t`: reinterpreting a `usize` is wrong on big-endian 64-bit targets.
        let mut addrlen = mem::size_of_val(&addr) as libc::socklen_t;
        let addrlen_ptr = &mut addrlen as *mut libc::socklen_t;
        let chunk = buf.chunk_mut();
        let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
        let buf_len = chunk.len() as libc::size_t;
        let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        // With MSG_TRUNC the kernel may report more bytes than it wrote; advance only by what
        // actually landed in the buffer.
        let written = std::cmp::min(buf_len, res as usize);
        unsafe { buf.advance_mut(written) };
        Ok((res as usize, SocketAddr(addr)))
    }

    /// Receives one datagram into the spare capacity of `buf` and returns the length reported
    /// by `recv(2)`.
    ///
    /// `buf` is advanced by the number of bytes actually written. With `MSG_TRUNC` in `flags`,
    /// the returned length can exceed that number.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `recv(2)`.
    pub fn recv<B>(&self, buf: &mut B, flags: libc::c_int) -> Result<usize>
    where
        B: bytes::BufMut,
    {
        let chunk = buf.chunk_mut();
        let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
        let buf_len = chunk.len() as libc::size_t;
        let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        let written = std::cmp::min(buf_len, res as usize);
        unsafe { buf.advance_mut(written) };
        Ok(res as usize)
    }

    /// Receives one whole datagram into a freshly allocated buffer and returns it with the
    /// sender's address.
    ///
    /// First peeks with `MSG_PEEK | MSG_TRUNC` to learn the datagram's full length, then reads
    /// it into a buffer with at least that much capacity.
    ///
    /// # Errors
    ///
    /// Returns the OS error from either `recvfrom(2)` call.
    ///
    /// # Panics
    ///
    /// Panics if the real read returns a different length than the peek, e.g. if another reader
    /// consumed the peeked datagram in between.
    pub fn recv_from_full(&self) -> Result<(Vec<u8>, SocketAddr)> {
        let mut buf: Vec<u8> = Vec::new();
        let (peek_len, _) = self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?;
        buf.clear();
        buf.reserve(peek_len);
        let (rlen, addr) = self.recv_from(&mut buf, 0)?;
        assert_eq!(rlen, peek_len);
        Ok((buf, addr))
    }

    /// Sends `buf` as one datagram to `addr` and returns the number of bytes sent.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `sendto(2)`.
    pub fn send_to(&self, buf: &[u8], addr: &SocketAddr, flags: libc::c_int) -> Result<usize> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        let res = unsafe { libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Sends `buf` as one datagram to the connected destination and returns the number of
    /// bytes sent.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `send(2)`.
    pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Sets the `NETLINK_PKTINFO` socket option.
    pub fn set_pktinfo(&mut self, value: bool) -> Result<()> {
        self.set_bool_option(libc::NETLINK_PKTINFO, value)
    }

    /// Reads the `NETLINK_PKTINFO` socket option.
    pub fn get_pktinfo(&self) -> Result<bool> {
        self.get_bool_option(libc::NETLINK_PKTINFO)
    }

    /// Joins multicast `group` (`NETLINK_ADD_MEMBERSHIP`).
    pub fn add_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_ADD_MEMBERSHIP,
            group,
        )
    }

    /// Leaves multicast `group` (`NETLINK_DROP_MEMBERSHIP`).
    pub fn drop_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_DROP_MEMBERSHIP,
            group,
        )
    }

    /// Sets the `NETLINK_BROADCAST_ERROR` socket option.
    pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> {
        self.set_bool_option(libc::NETLINK_BROADCAST_ERROR, value)
    }

    /// Reads the `NETLINK_BROADCAST_ERROR` socket option.
    pub fn get_broadcast_error(&self) -> Result<bool> {
        self.get_bool_option(libc::NETLINK_BROADCAST_ERROR)
    }

    /// Sets the `NETLINK_NO_ENOBUFS` socket option.
    pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> {
        self.set_bool_option(libc::NETLINK_NO_ENOBUFS, value)
    }

    /// Reads the `NETLINK_NO_ENOBUFS` socket option.
    pub fn get_no_enobufs(&self) -> Result<bool> {
        self.get_bool_option(libc::NETLINK_NO_ENOBUFS)
    }

    /// Sets the `NETLINK_LISTEN_ALL_NSID` socket option.
    pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> {
        self.set_bool_option(libc::NETLINK_LISTEN_ALL_NSID, value)
    }

    /// Reads the `NETLINK_LISTEN_ALL_NSID` socket option.
    pub fn get_listen_all_namespaces(&self) -> Result<bool> {
        self.get_bool_option(libc::NETLINK_LISTEN_ALL_NSID)
    }

    /// Sets the `NETLINK_CAP_ACK` socket option.
    pub fn set_cap_ack(&mut self, value: bool) -> Result<()> {
        self.set_bool_option(libc::NETLINK_CAP_ACK, value)
    }

    /// Reads the `NETLINK_CAP_ACK` socket option.
    pub fn get_cap_ack(&self) -> Result<bool> {
        self.get_bool_option(libc::NETLINK_CAP_ACK)
    }

    /// Sets a boolean `SOL_NETLINK` option, passing it to the kernel as a `c_int` 0 or 1.
    fn set_bool_option(&mut self, option: libc::c_int, value: bool) -> Result<()> {
        setsockopt(self.0, libc::SOL_NETLINK, option, libc::c_int::from(value))
    }

    /// Reads a boolean `SOL_NETLINK` option stored as a `c_int`; any non-zero value is `true`.
    fn get_bool_option(&self, option: libc::c_int) -> Result<bool> {
        getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, option).map(|value| value != 0)
    }
}
/// Reads socket option `option` at `level` on `fd` as a value of type `T`.
///
/// The destination is zero-initialised before the kernel fills it, so `T` must be a plain
/// C-compatible type for which all-zero bits are a valid value.
///
/// # Errors
///
/// Returns the OS error reported by `getsockopt(2)`.
///
/// # Panics
///
/// Panics if the kernel reports an option length different from `size_of::<T>()`.
pub(crate) fn getsockopt<T: Copy>(fd: RawFd, level: libc::c_int, option: libc::c_int) -> Result<T> {
    let expected_len = mem::size_of::<T>();
    // SAFETY: T is expected to be a plain C integer-like type (callers in this file use
    // `c_int`), for which all-zero bits are valid.
    let mut value: T = unsafe { mem::zeroed() };
    let mut len = expected_len as libc::socklen_t;
    // SAFETY: `value` is a live, writable buffer of exactly `len` bytes, and `len` is a valid
    // in/out length parameter.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            level,
            option,
            &mut value as *mut T as *mut libc::c_void,
            &mut len as *mut libc::socklen_t,
        )
    };
    if rc < 0 {
        return Err(Error::last_os_error());
    }
    assert_eq!(len as usize, expected_len);
    Ok(value)
}
/// Sets socket option `option` at `level` on `fd` to `payload`.
///
/// The kernel receives a pointer to `payload` and a length of `size_of::<T>()`.
///
/// # Errors
///
/// Returns the OS error reported by `setsockopt(2)`.
fn setsockopt<T>(fd: RawFd, level: libc::c_int, option: libc::c_int, payload: T) -> Result<()> {
    let value_len = mem::size_of::<T>() as libc::socklen_t;
    let value_ptr = &payload as *const T as *const libc::c_void;
    // SAFETY: `value_ptr` points at `payload`, which lives until this function returns, and
    // `value_len` is its exact size.
    match unsafe { libc::setsockopt(fd, level, option, value_ptr, value_len) } {
        rc if rc < 0 => Err(Error::last_os_error()),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::protocols::NETLINK_ROUTE;

    /// Opening a NETLINK_ROUTE socket succeeds.
    #[test]
    fn new() {
        Socket::new(NETLINK_ROUTE).unwrap();
    }

    /// Connecting to the kernel address (port 0, no groups) succeeds.
    #[test]
    fn connect() {
        let sock = Socket::new(NETLINK_ROUTE).unwrap();
        sock.connect(&SocketAddr::new(0, 0)).unwrap();
    }

    /// Binding to an explicit port id succeeds.
    // NOTE(review): uses a fixed port id, so it can fail with EADDRINUSE if another socket in
    // the same network namespace already holds 4321.
    #[test]
    fn bind() {
        let mut sock = Socket::new(NETLINK_ROUTE).unwrap();
        sock.bind(&SocketAddr::new(4321, 0)).unwrap();
    }

    /// Auto-binding yields a non-zero, kernel-assigned port id.
    #[test]
    fn bind_auto() {
        let mut sock = Socket::new(NETLINK_ROUTE).unwrap();
        let addr = sock.bind_auto().unwrap();
        assert!(addr.port_number() != 0);
    }

    /// Non-blocking mode can be toggled on and off.
    #[test]
    fn set_non_blocking() {
        let sock = Socket::new(NETLINK_ROUTE).unwrap();
        sock.set_non_blocking(true).unwrap();
        sock.set_non_blocking(false).unwrap();
    }

    /// Boolean socket options round-trip through their setters and getters.
    #[test]
    fn options() {
        let mut sock = Socket::new(NETLINK_ROUTE).unwrap();

        sock.set_cap_ack(true).unwrap();
        assert!(sock.get_cap_ack().unwrap());
        sock.set_cap_ack(false).unwrap();
        assert!(!sock.get_cap_ack().unwrap());

        sock.set_no_enobufs(true).unwrap();
        assert!(sock.get_no_enobufs().unwrap());
        sock.set_no_enobufs(false).unwrap();
        assert!(!sock.get_no_enobufs().unwrap());

        sock.set_broadcast_error(true).unwrap();
        assert!(sock.get_broadcast_error().unwrap());
        sock.set_broadcast_error(false).unwrap();
        assert!(!sock.get_broadcast_error().unwrap());

        sock.set_pktinfo(true).unwrap();
        assert!(sock.get_pktinfo().unwrap());
        sock.set_pktinfo(false).unwrap();
        assert!(!sock.get_pktinfo().unwrap());
    }
}