use std::fmt;
use std::hash::{Hash, Hasher};
use std::io::{Error, Result};
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use super::Protocol;
#[derive(Clone, Debug)]
pub struct Socket(RawFd);
impl AsRawFd for Socket {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
}
}
impl Drop for Socket {
    /// Closes the descriptor. The return value of `close` is discarded
    /// because `drop` has no way to report it.
    fn drop(&mut self) {
        let fd = self.0;
        // SAFETY: close(2) takes only an integer descriptor.
        let _ = unsafe { libc::close(fd) };
    }
}
/// A netlink socket address wrapping a raw `libc::sockaddr_nl`.
///
/// Equality, hashing and `Debug` are implemented by hand so that only
/// `nl_family`, `nl_pid` and `nl_groups` are taken into account.
#[derive(Clone, Copy)]
pub struct SocketAddr(libc::sockaddr_nl);
impl Hash for SocketAddr {
    /// Hashes exactly the fields compared by `PartialEq` (family, port number,
    /// multicast groups, in that order), so equal addresses hash equally.
    /// Any other fields of `sockaddr_nl` are ignored.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let raw = &self.0;
        (raw.nl_family, raw.nl_pid, raw.nl_groups).hash(state);
    }
}
impl PartialEq for SocketAddr {
    /// Two addresses are equal when family, port number and multicast groups
    /// all match; other fields of `sockaddr_nl` are not compared.
    fn eq(&self, other: &SocketAddr) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        (lhs.nl_family, lhs.nl_pid, lhs.nl_groups) == (rhs.nl_family, rhs.nl_pid, rhs.nl_groups)
    }
}
impl fmt::Debug for SocketAddr {
    /// Prints the three meaningful fields of the wrapped `sockaddr_nl`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = &self.0;
        let (family, pid, groups) = (raw.nl_family, raw.nl_pid, raw.nl_groups);
        write!(
            f,
            "SocketAddr(nl_family={}, nl_pid={}, nl_groups={})",
            family, pid, groups
        )
    }
}
// `PartialEq` compares only integer fields with `==`, which is reflexive,
// so the full equivalence-relation contract of `Eq` holds.
impl Eq for SocketAddr {}
impl fmt::Display for SocketAddr {
    /// Human-readable form:
    /// `address family: <family>, pid: <port>, multicast groups: <groups>`.
    ///
    /// The format string previously ended with a stray, unbalanced `)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "address family: {}, pid: {}, multicast groups: {}",
            self.0.nl_family, self.0.nl_pid, self.0.nl_groups
        )
    }
}
impl SocketAddr {
pub fn new(port_number: u32, multicast_groups: u32) -> Self {
let mut addr: libc::sockaddr_nl = unsafe { mem::zeroed() };
addr.nl_family = libc::PF_NETLINK as libc::sa_family_t;
addr.nl_pid = port_number;
addr.nl_groups = multicast_groups;
SocketAddr(addr)
}
pub fn port_number(&self) -> u32 {
self.0.nl_pid
}
pub fn multicast_groups(&self) -> u32 {
self.0.nl_groups
}
fn as_raw(&self) -> (*const libc::sockaddr, libc::socklen_t) {
let addr_ptr = &self.0 as *const libc::sockaddr_nl as *const libc::sockaddr;
let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
(addr_ptr, addr_len)
}
fn as_raw_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
let addr_ptr = &mut self.0 as *mut libc::sockaddr_nl as *mut libc::sockaddr;
let addr_len = mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t;
(addr_ptr, addr_len)
}
}
impl Socket {
    /// Opens a new `PF_NETLINK` / `SOCK_DGRAM` socket for `protocol`.
    ///
    /// # Errors
    /// Returns the OS error reported by `socket(2)`.
    pub fn new(protocol: Protocol) -> Result<Self> {
        // SAFETY: socket(2) takes only integer arguments.
        let res =
            unsafe { libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, protocol as libc::c_int) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(Socket(res))
    }

    /// Binds the socket to `addr` with `bind(2)`.
    ///
    /// # Errors
    /// Returns the OS error reported by `bind`.
    pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
        let (addr_ptr, addr_len) = addr.as_raw();
        // SAFETY: addr_ptr/addr_len describe the sockaddr_nl inside `addr`,
        // which outlives the call.
        let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Binds to port number 0 with no multicast groups, then reads back the
    /// address the socket ended up bound to via `getsockname(2)`.
    ///
    /// # Errors
    /// Returns the OS error from either `bind` or `getsockname`.
    pub fn bind_auto(&mut self) -> Result<SocketAddr> {
        let mut addr = SocketAddr::new(0, 0);
        self.bind(&addr)?;
        self.get_address(&mut addr)?;
        Ok(addr)
    }

    /// Writes the socket's local address (from `getsockname(2)`) into `addr`.
    ///
    /// # Errors
    /// Returns the OS error reported by `getsockname`.
    ///
    /// # Panics
    /// Panics if the returned address length differs from
    /// `size_of::<sockaddr_nl>()`.
    pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
        let (addr_ptr, mut addr_len) = addr.as_raw_mut();
        let addr_len_copy = addr_len;
        let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
        // SAFETY: addr_ptr points at the sockaddr_nl inside `addr` and
        // addr_len_ptr at a local socklen_t holding its size.
        let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        assert_eq!(addr_len, addr_len_copy);
        Ok(())
    }

    /// Enables or disables non-blocking mode via `ioctl(FIONBIO)`.
    ///
    /// # Errors
    /// Returns the OS error reported by `ioctl`.
    pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
        let mut non_blocking = non_blocking as libc::c_int;
        // SAFETY: FIONBIO reads a c_int through the pointer, which points at a
        // live local.
        let res = unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Sets the default destination of the socket with `connect(2)`.
    ///
    /// # Errors
    /// Returns the OS error reported by `connect`.
    pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
        let (addr, addr_len) = remote_addr.as_raw();
        // SAFETY: addr/addr_len describe the sockaddr_nl inside `remote_addr`.
        let res = unsafe { libc::connect(self.0, addr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Receives one datagram into `buf` with `recvfrom(2)`.
    ///
    /// Returns the value reported by `recvfrom` (the number of bytes, or the
    /// real datagram length when `MSG_TRUNC` is in `flags`) and the sender's
    /// address.
    ///
    /// # Errors
    /// Returns the OS error reported by `recvfrom`.
    pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> {
        // SAFETY: relies on the all-zero bit pattern being a valid sockaddr_nl.
        let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
        let addr_ptr = &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;
        // The kernel reads and writes the address length as a `socklen_t`.
        // Keep the variable that exact type: passing a `usize` through a
        // `*mut socklen_t` cast touches the wrong half of the integer on
        // 64-bit big-endian targets.
        let mut addrlen = mem::size_of_val(&addr) as libc::socklen_t;
        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        // SAFETY: buf_ptr/buf_len describe the writable `buf` slice, and
        // addr_ptr/addrlen describe the local `addr`; all outlive the call.
        let res = unsafe {
            libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, &mut addrlen)
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok((res as usize, SocketAddr(addr)))
    }

    /// Receives one datagram into `buf` with `recv(2)`, returning the value
    /// reported by `recv`.
    ///
    /// # Errors
    /// Returns the OS error reported by `recv`.
    pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        // SAFETY: buf_ptr/buf_len describe the writable `buf` slice.
        let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Receives the next datagram into a freshly allocated buffer of exactly
    /// its size, together with the sender's address.
    ///
    /// It first peeks (`MSG_PEEK | MSG_TRUNC`) with an empty buffer to learn
    /// the datagram length, then performs the real receive.
    ///
    /// # Errors
    /// Returns the OS error from either receive call.
    pub fn recv_from_full(&self) -> Result<(Vec<u8>, SocketAddr)> {
        let mut probe = Vec::<u8>::new();
        let (rlen, _) = self.recv_from(&mut probe, libc::MSG_PEEK | libc::MSG_TRUNC)?;
        let mut buf = vec![0u8; rlen];
        let (_, addr) = self.recv_from(&mut buf, 0)?;
        Ok((buf, addr))
    }

    /// Sends `buf` to `addr` with `sendto(2)`, returning the bytes sent.
    ///
    /// # Errors
    /// Returns the OS error reported by `sendto`.
    pub fn send_to(&self, buf: &[u8], addr: &SocketAddr, flags: libc::c_int) -> Result<usize> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        // SAFETY: buf_ptr/buf_len describe `buf`; addr_ptr/addr_len describe
        // the sockaddr_nl inside `addr`.
        let res = unsafe { libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Sends `buf` on a connected socket with `send(2)`, returning the bytes
    /// sent.
    ///
    /// # Errors
    /// Returns the OS error reported by `send`.
    pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;
        // SAFETY: buf_ptr/buf_len describe `buf`.
        let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Sets the `NETLINK_PKTINFO` socket option (see netlink(7)).
    pub fn set_pktinfo(&mut self, value: bool) -> Result<()> {
        let value: libc::c_int = if value { 1 } else { 0 };
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, value)
    }

    /// Reads `NETLINK_PKTINFO`; `true` when the kernel reports 1.
    pub fn get_pktinfo(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO)?;
        Ok(res == 1)
    }

    /// Joins multicast `group` via `NETLINK_ADD_MEMBERSHIP`.
    pub fn add_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_ADD_MEMBERSHIP,
            group,
        )
    }

    /// Leaves multicast `group` via `NETLINK_DROP_MEMBERSHIP`.
    pub fn drop_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_DROP_MEMBERSHIP,
            group,
        )
    }

    /// Sets the `NETLINK_BROADCAST_ERROR` socket option (see netlink(7)).
    pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> {
        let value: libc::c_int = if value { 1 } else { 0 };
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_BROADCAST_ERROR,
            value,
        )
    }

    /// Reads `NETLINK_BROADCAST_ERROR`; `true` when the kernel reports 1.
    pub fn get_broadcast_error(&self) -> Result<bool> {
        let res =
            getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR)?;
        Ok(res == 1)
    }

    /// Sets the `NETLINK_NO_ENOBUFS` socket option (see netlink(7)).
    pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> {
        let value: libc::c_int = if value { 1 } else { 0 };
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, value)
    }

    /// Reads `NETLINK_NO_ENOBUFS`; `true` when the kernel reports 1.
    pub fn get_no_enobufs(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS)?;
        Ok(res == 1)
    }

    /// Sets the `NETLINK_LISTEN_ALL_NSID` socket option (see netlink(7)).
    pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> {
        let value: libc::c_int = if value { 1 } else { 0 };
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_LISTEN_ALL_NSID,
            value,
        )
    }

    /// Reads `NETLINK_LISTEN_ALL_NSID`; `true` when the kernel reports 1.
    pub fn get_listen_all_namespaces(&self) -> Result<bool> {
        let res =
            getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID)?;
        Ok(res == 1)
    }

    /// Sets the `NETLINK_CAP_ACK` socket option (see netlink(7)).
    pub fn set_cap_ack(&mut self, value: bool) -> Result<()> {
        let value: libc::c_int = if value { 1 } else { 0 };
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, value)
    }

    /// Reads `NETLINK_CAP_ACK`; `true` when the kernel reports 1.
    pub fn get_cap_ack(&self) -> Result<bool> {
        let res = getsockopt::<libc::c_int>(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK)?;
        Ok(res == 1)
    }
}
/// Reads socket option `option` at `level` on `fd` into a value of type `T`.
///
/// # Errors
/// Returns the OS error reported by `getsockopt(2)`.
///
/// # Panics
/// Panics if the kernel reports an option length different from
/// `size_of::<T>()`.
pub(crate) fn getsockopt<T: Copy>(fd: RawFd, level: libc::c_int, option: libc::c_int) -> Result<T> {
    // SAFETY: assumes `T` is a plain integer-like C type for which all-zero
    // is valid; callers in this file only use `c_int`.
    let mut value: T = unsafe { mem::zeroed() };
    let expected_len = mem::size_of::<T>();
    let mut value_len = expected_len as libc::socklen_t;
    let value_ptr = (&mut value as *mut T).cast::<libc::c_void>();
    // SAFETY: value_ptr points at `value`, which is `value_len` bytes long
    // and lives for the whole call; value_len is a live local socklen_t.
    let rc = unsafe { libc::getsockopt(fd, level, option, value_ptr, &mut value_len) };
    if rc < 0 {
        return Err(Error::last_os_error());
    }
    assert_eq!(value_len as usize, expected_len);
    Ok(value)
}
/// Sets socket option `option` at `level` on `fd`, passing the raw bytes of
/// `payload` (of length `size_of::<T>()`) to `setsockopt(2)`.
///
/// # Errors
/// Returns the OS error reported by `setsockopt`.
fn setsockopt<T>(fd: RawFd, level: libc::c_int, option: libc::c_int, payload: T) -> Result<()> {
    let payload_len = mem::size_of::<T>() as libc::socklen_t;
    let payload_ptr = (&payload as *const T).cast::<libc::c_void>();
    // SAFETY: payload_ptr points at `payload`, which is exactly payload_len
    // bytes long and stays alive for the whole call.
    let rc = unsafe { libc::setsockopt(fd, level, option, payload_ptr, payload_len) };
    if rc < 0 {
        Err(Error::last_os_error())
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn new() {
        Socket::new(Protocol::Route).unwrap();
    }

    #[test]
    fn connect() {
        // Port number 0 with no multicast groups.
        let target = SocketAddr::new(0, 0);
        Socket::new(Protocol::Route).unwrap().connect(&target).unwrap();
    }

    #[test]
    fn bind() {
        let mut socket = Socket::new(Protocol::Route).unwrap();
        let local = SocketAddr::new(4321, 0);
        socket.bind(&local).unwrap();
    }

    #[test]
    fn bind_auto() {
        let mut socket = Socket::new(Protocol::Route).unwrap();
        let bound = socket.bind_auto().unwrap();
        assert_ne!(bound.port_number(), 0);
    }

    #[test]
    fn set_non_blocking() {
        let socket = Socket::new(Protocol::Route).unwrap();
        for &flag in &[true, false] {
            socket.set_non_blocking(flag).unwrap();
        }
    }

    #[test]
    fn options() {
        type Setter = fn(&mut Socket, bool) -> Result<()>;
        type Getter = fn(&Socket) -> Result<bool>;

        let mut socket = Socket::new(Protocol::Route).unwrap();
        // Each option is switched on, checked, switched off, checked.
        let toggles: [(Setter, Getter); 3] = [
            (Socket::set_cap_ack, Socket::get_cap_ack),
            (Socket::set_no_enobufs, Socket::get_no_enobufs),
            (Socket::set_broadcast_error, Socket::get_broadcast_error),
        ];
        for &(set, get) in toggles.iter() {
            for &wanted in &[true, false] {
                set(&mut socket, wanted).unwrap();
                assert_eq!(get(&socket).unwrap(), wanted);
            }
        }
    }

    #[test]
    fn address() {
        let mut addr = SocketAddr::new(42, 1234);
        assert_eq!((addr.port_number(), addr.multicast_groups()), (42, 1234));

        let (read_ptr, _) = addr.as_raw();
        // SAFETY: as_raw points at the sockaddr_nl stored inside `addr`.
        let snapshot = unsafe { *(read_ptr as *const libc::sockaddr_nl) };
        assert_eq!((snapshot.nl_pid, snapshot.nl_groups), (42, 1234));

        let (write_ptr, _) = addr.as_raw_mut();
        // SAFETY: as_raw_mut points at the sockaddr_nl stored inside `addr`,
        // and nothing else accesses `addr` while `inner` is in use.
        let inner = unsafe { &mut *(write_ptr as *mut libc::sockaddr_nl) };
        inner.nl_pid = 24;
        inner.nl_groups = 4321;

        assert_eq!(addr.port_number(), 24);
        assert_eq!(addr.multicast_groups(), 4321);
    }
}