use super::Error;
use super::Result;
use libc::c_int;
use nix::sys::socket;
use std::marker::PhantomData;
use std::os::fd::AsRawFd;
use std::os::fd::OwnedFd;
/// Owned netlink socket descriptor, tagged with a zero-sized `Mode` marker.
///
/// The descriptor is held in an `OwnedFd`, so it is closed when the value is dropped.
/// `Mode` carries no data; it only selects which `Ipc` implementation applies
/// (this file implements `Ipc` for `Socket<Blocking>` and `Socket<Nonblocking>`).
pub struct Socket<Mode>(OwnedFd, PhantomData<Mode>);
// Flag values that appear to mirror the kernel's netlink config flags
// `NL_CFG_F_NONROOT_RECV` (bit 0) and `NL_CFG_F_NONROOT_SEND` (bit 1); they are
// only used when retrying socket creation.
const NL_CFG_F_NONROOT_RECV: c_int = 0b01;
const NL_CFG_F_NONROOT_SEND: c_int = 0b10;
/// Length in bytes of the netlink message header that prefixes every message this
/// module sends, and that is stripped from every message it receives.
const NLMSG_HDRSIZE: usize = 16;
impl<T> Socket<T> {
fn __new() -> Result<Self> {
let fd = match socket::socket(
nix::sys::socket::AddressFamily::Netlink,
nix::sys::socket::SockType::Raw,
nix::sys::socket::SockFlag::empty(),
nix::sys::socket::SockProtocol::NetlinkUserSock,
) {
Ok(fd) => fd,
_ => socket::socket(
nix::sys::socket::AddressFamily::Netlink,
nix::sys::socket::SockType::Raw,
nix::sys::socket::SockFlag::from_bits_truncate(NL_CFG_F_NONROOT_RECV)
| nix::sys::socket::SockFlag::from_bits_truncate(NL_CFG_F_NONROOT_SEND),
nix::sys::socket::SockProtocol::NetlinkUserSock,
)?,
};
let pid = unsafe { libc::getpid() };
socket::bind(
fd.as_raw_fd(),
&nix::sys::socket::NetlinkAddr::new(pid as u32, 0),
)?;
Ok(Socket(fd, PhantomData))
}
pub fn new() -> Result<Self> {
let s = Self::__new()?;
let opt = 22;
use std::mem;
s.setsockopt(
270,
libc::NETLINK_ADD_MEMBERSHIP,
&opt as *const i32 as *const libc::c_void,
mem::size_of::<c_int>() as u32,
)?;
let to = libc::timespec {
tv_sec: 1 as libc::time_t,
tv_nsec: 0 as libc::c_long,
};
s.setsockopt(
libc::SOL_SOCKET,
libc::SO_RCVTIMEO,
&to as *const libc::timespec as *const libc::c_void,
mem::size_of::<libc::timespec>() as u32,
)?;
Ok(s)
}
fn setsockopt(
&self,
level: c_int,
option: c_int,
val: *const libc::c_void,
sz: u32,
) -> Result<()> {
let res = unsafe { libc::setsockopt(self.0.as_raw_fd(), level, option as c_int, val, sz) };
if res == -1 {
return Err(Error::from(nix::Error::last()));
}
Ok(())
}
fn __recv(&self, buf: &mut [u8], flags: nix::sys::socket::MsgFlags) -> Result<usize> {
let mut nl_buf = [0u8; 1024];
let end = socket::recvmsg::<'_, '_, '_, ()>(
self.0.as_raw_fd(),
&mut [std::io::IoSliceMut::new(&mut nl_buf[..])],
None,
flags,
)
.map(|r| r.bytes)
.map_err(Error::from)?;
buf[..(end - NLMSG_HDRSIZE)].copy_from_slice(&nl_buf[NLMSG_HDRSIZE..end]);
Ok(end - NLMSG_HDRSIZE)
}
fn __send(&self, buf: &[u8]) -> Result<()> {
let len = NLMSG_HDRSIZE + buf.len();
let mut msg = Vec::<u8>::with_capacity(len);
msg.resize(4, 0u8);
super::super::serialize::u32_to_u8s(&mut msg[0..4], len as u32);
msg.extend_from_slice(&[0u8; 12]);
msg.extend_from_slice(buf);
socket::sendmsg(
self.0.as_raw_fd(),
&[std::io::IoSlice::new(&msg[..])],
&[],
nix::sys::socket::MsgFlags::empty(),
None::<&()>,
)
.map(|_| ())
.map_err(Error::from)
}
fn __close(&mut self) -> Result<()> {
let ok = unsafe { libc::close(self.0.as_raw_fd()) as i32 };
if ok < 0 {
Err(Error(format!("could not close netlink socket: {}", ok)))
} else {
Ok(())
}
}
}
use super::Blocking;
impl super::Ipc for Socket<Blocking> {
    // Netlink messages here carry no peer address, so the address type is the unit type.
    type Addr = ();

    /// Name of this IPC transport.
    fn name() -> String {
        "netlink".to_owned()
    }

    /// Waits for a message with no extra receive flags, copies its payload into `buf`,
    /// and returns the payload length paired with the (empty) address.
    fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Addr)> {
        let received = self.__recv(buf, nix::sys::socket::MsgFlags::empty())?;
        Ok((received, ()))
    }

    /// Sends `buf` as one netlink message; the destination argument is ignored.
    fn send(&self, buf: &[u8], _to: &Self::Addr) -> Result<()> {
        self.__send(buf)
    }

    /// Closes the underlying socket.
    fn close(&mut self) -> Result<()> {
        self.__close()
    }
}
use super::Nonblocking;
impl super::Ipc for Socket<Nonblocking> {
    // Netlink messages here carry no peer address, so the address type is the unit type.
    type Addr = ();

    /// Name of this IPC transport.
    fn name() -> String {
        "netlink".to_owned()
    }

    /// Receives with `MSG_DONTWAIT`, so the call returns an error instead of waiting when
    /// no message is ready; on success returns the payload length and the (empty) address.
    fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Addr)> {
        let received = self.__recv(buf, nix::sys::socket::MsgFlags::MSG_DONTWAIT)?;
        Ok((received, ()))
    }

    /// Sends `buf` as one netlink message; the destination argument is ignored.
    fn send(&self, buf: &[u8], _to: &Self::Addr) -> Result<()> {
        self.__send(buf)
    }

    /// Closes the underlying socket.
    fn close(&mut self) -> Result<()> {
        self.__close()
    }
}