#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]
use ::std::{
io,
marker::PhantomData,
net::ToSocketAddrs,
os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd},
path::Path,
};
use bitflags::bitflags;
use capsicum::casper;
use nix::{
errno::Errno,
sys::socket::{
AddressFamily,
SockFlag,
SockType,
SockaddrIn,
SockaddrIn6,
SockaddrLike,
},
Result,
};
mod ffi;
pub mod std;
#[cfg(feature = "tokio")]
pub mod tokio;
casper::service_connection! {
/// A connection to the Casper `system.net` service.
///
/// Socket operations issued through this type ([`CapNetAgent::bind`],
/// [`CapNetAgent::connect`]) are forwarded to that service instead of being
/// performed directly, and [`CapNetAgent::limit`] builds restrictions for it.
#[derive(Debug)]
pub CapNetAgent,
c"system.net",
net
}
impl CapNetAgent {
    /// Binds `sock` to `addr` through the Casper `system.net` service.
    ///
    /// The operation is performed by the service (`cap_bind`) rather than by
    /// calling `bind(2)` directly.
    ///
    /// # Errors
    ///
    /// Returns the [`Errno`] reported by `cap_bind` if it fails.
    pub fn bind<F>(&mut self, sock: &F, addr: &dyn SockaddrLike) -> Result<()>
    where
        F: AsFd,
    {
        let fd = sock.as_fd().as_raw_fd();
        // SAFETY: `addr.as_ptr()`/`addr.len()` describe a sockaddr owned by
        // `addr`, which outlives this call; `fd` is borrowed from `sock`,
        // which also outlives the call.
        let res = unsafe {
            ffi::cap_bind(self.0.as_mut_ptr(), fd, addr.as_ptr(), addr.len())
        };
        Errno::result(res).map(drop)
    }

    /// Binds the borrowed socket to a std `SocketAddr` via `cap_bind`.
    ///
    /// Any non-zero return from `cap_bind` is reported as the last OS error.
    fn bind_std_fd(
        &mut self,
        sock: BorrowedFd,
        addr: ::std::net::SocketAddr,
    ) -> io::Result<()> {
        let ap = self.0.as_mut_ptr();
        let fd = sock.as_raw_fd();
        let res = match addr {
            ::std::net::SocketAddr::V4(addr) => {
                let sin = SockaddrIn::from(addr);
                // SAFETY: `sin` lives until the end of this arm, so the
                // pointer/length pair stays valid for the whole call.
                unsafe { ffi::cap_bind(ap, fd, sin.as_ptr(), sin.len()) }
            }
            ::std::net::SocketAddr::V6(addr) => {
                let sin6 = SockaddrIn6::from(addr);
                // SAFETY: as above, `sin6` outlives the call.
                unsafe { ffi::cap_bind(ap, fd, sin6.as_ptr(), sin6.len()) }
            }
        };
        if res == 0 {
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Creates a stream socket for each address yielded by `addrs` and
    /// returns the first one that binds successfully, converted into `S`.
    ///
    /// # Errors
    ///
    /// Returns the error from the last failed attempt. If `addrs` resolves to
    /// no addresses, returns an `InvalidInput` error. Socket creation and
    /// resolution failures are propagated immediately.
    fn bind_std_to_addrs<A, S>(&mut self, addrs: A) -> io::Result<S>
    where
        A: ToSocketAddrs,
        S: From<OwnedFd>,
    {
        let mut last_err = None;
        for addr in addrs.to_socket_addrs()? {
            let family = if addr.is_ipv4() {
                AddressFamily::Inet
            } else {
                AddressFamily::Inet6
            };
            let sock = nix::sys::socket::socket(
                family,
                SockType::Stream,
                // std creates every socket close-on-exec; do the same so the
                // descriptor is not inherited by exec'd child processes.
                SockFlag::SOCK_CLOEXEC,
                None,
            )
            .map_err(io::Error::from)?;
            match self.bind_std_fd(sock.as_fd(), addr) {
                Ok(()) => return Ok(S::from(sock)),
                Err(e) => {
                    // `sock` is dropped (closed) here; try the next address.
                    last_err = Some(e);
                }
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        }))
    }

    /// Creates a Unix-domain socket of type `sock_type` and binds it to
    /// `path` via the service.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` is not a valid Unix socket address (for
    /// example, it is too long), if the socket cannot be created, or if the
    /// bind fails.
    fn bind_std_unix<P>(
        &mut self,
        sock_type: SockType,
        path: P,
    ) -> io::Result<OwnedFd>
    where
        P: AsRef<Path>,
    {
        // Validate the caller-supplied path before allocating a descriptor.
        // Both steps can fail on ordinary input, so propagate, don't panic.
        let want = nix::sys::socket::UnixAddr::new(path.as_ref())
            .map_err(io::Error::from)?;
        let s = nix::sys::socket::socket(
            AddressFamily::Unix,
            sock_type,
            // Match std's close-on-exec default for sockets.
            SockFlag::SOCK_CLOEXEC,
            None,
        )
        .map_err(io::Error::from)?;
        self.bind(&s, &want)?;
        Ok(s)
    }

    /// Connects `sock` to `addr` through the Casper `system.net` service.
    ///
    /// The operation is performed by the service (`cap_connect`) rather than
    /// by calling `connect(2)` directly.
    ///
    /// # Errors
    ///
    /// Returns the [`Errno`] reported by `cap_connect` if it fails.
    pub fn connect<F>(
        &mut self,
        sock: &F,
        addr: &dyn SockaddrLike,
    ) -> Result<()>
    where
        F: AsFd,
    {
        let fd = sock.as_fd().as_raw_fd();
        // SAFETY: `addr` and `sock` both outlive the call, keeping the
        // sockaddr pointer/length and the raw fd valid.
        let res = unsafe {
            ffi::cap_connect(self.0.as_mut_ptr(), fd, addr.as_ptr(), addr.len())
        };
        Errno::result(res).map(drop)
    }

    /// Connects the borrowed socket to a std `SocketAddr` via `cap_connect`.
    ///
    /// Any non-zero return from `cap_connect` is reported as the last OS
    /// error.
    fn connect_std_fd(
        &mut self,
        sock: BorrowedFd,
        addr: ::std::net::SocketAddr,
    ) -> io::Result<()> {
        let ap = self.0.as_mut_ptr();
        let fd = sock.as_raw_fd();
        let res = match addr {
            ::std::net::SocketAddr::V4(addr) => {
                let sin = SockaddrIn::from(addr);
                // SAFETY: `sin` lives until the end of this arm.
                unsafe { ffi::cap_connect(ap, fd, sin.as_ptr(), sin.len()) }
            }
            ::std::net::SocketAddr::V6(addr) => {
                let sin6 = SockaddrIn6::from(addr);
                // SAFETY: `sin6` lives until the end of this arm.
                unsafe { ffi::cap_connect(ap, fd, sin6.as_ptr(), sin6.len()) }
            }
        };
        if res == 0 {
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Tries to connect `sock` to each address yielded by `addrs` in turn,
    /// stopping at the first success.
    ///
    /// # Errors
    ///
    /// Returns the error from the last failed attempt. If `addrs` resolves to
    /// no addresses, returns an `InvalidInput` error.
    fn connect_std_to_addrs<A>(
        &mut self,
        sock: BorrowedFd,
        addrs: A,
    ) -> io::Result<()>
    where
        A: ToSocketAddrs,
    {
        let mut last_err = None;
        for addr in addrs.to_socket_addrs()? {
            match self.connect_std_fd(sock, addr) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    last_err = Some(e);
                }
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        }))
    }

    /// Starts building a set of restrictions for this connection.
    ///
    /// `flags` is passed as the mode to `cap_net_limit_init`. Register
    /// addresses on the returned [`Limit`], then call [`Limit::limit`] to
    /// submit them.
    ///
    /// # Panics
    ///
    /// Panics if `cap_net_limit_init` returns a null pointer.
    pub fn limit(&mut self, flags: LimitFlags) -> Limit<'_> {
        // SAFETY: `self.0` is a live service channel, exclusively borrowed
        // for the lifetime of the returned `Limit`.
        let limit = unsafe {
            ffi::cap_net_limit_init(self.0.as_mut_ptr(), flags.bits())
        };
        assert!(!limit.is_null());
        Limit {
            limit,
            phantom: PhantomData,
        }
    }
}
/// A pending set of restrictions for a [`CapNetAgent`].
///
/// Created by [`CapNetAgent::limit`]. Register addresses with
/// [`Limit::bind`] and [`Limit::connect`], then call [`Limit::limit`] to
/// submit them. Nothing is applied until [`Limit::limit`] is called, so
/// dropping a `Limit` early leaves the agent unrestricted.
#[must_use = "restrictions are not applied until `Limit::limit` is called"]
#[repr(transparent)]
pub struct Limit<'a> {
    /// Handle obtained from `cap_net_limit_init`; passed to `cap_net_limit`
    /// when the restrictions are submitted.
    // NOTE(review): no `Drop` impl is visible in this file, so the handle is
    // presumably leaked if `limit()` is never called — consider freeing it.
    limit: *mut ffi::cap_net_limit_t,
    /// Ties this builder to the exclusive borrow of its agent.
    phantom: PhantomData<&'a mut CapNetAgent>,
}
bitflags! {
/// Mode flags passed to `cap_net_limit_init` by [`CapNetAgent::limit`],
/// selecting which network operations a [`Limit`] covers.
pub struct LimitFlags: u64 {
/// Corresponds to the C `CAPNET_BIND` mode bit (bind operations).
const BIND = ffi::CAPNET_BIND as u64;
/// Corresponds to the C `CAPNET_CONNECT` mode bit (connect operations).
const CONNECT = ffi::CAPNET_CONNECT as u64;
}
}
impl Limit<'_> {
    /// Registers `sa` with the bind portion of this restriction set.
    ///
    /// Returns `&mut self` so registrations can be chained.
    ///
    /// # Panics
    ///
    /// Panics if `cap_net_limit_bind` returns a handle other than the one it
    /// was given.
    pub fn bind(&mut self, sa: &dyn SockaddrLike) -> &mut Self {
        let (addr_ptr, addr_len) = (sa.as_ptr(), sa.len());
        // SAFETY: `self.limit` is the handle this `Limit` was created with,
        // and `addr_ptr`/`addr_len` describe the sockaddr borrowed from `sa`,
        // which outlives the call.
        let returned =
            unsafe { ffi::cap_net_limit_bind(self.limit, addr_ptr, addr_len) };
        assert_eq!(returned, self.limit);
        self
    }

    /// Registers `sa` with the connect portion of this restriction set.
    ///
    /// Returns `&mut self` so registrations can be chained.
    ///
    /// # Panics
    ///
    /// Panics if `cap_net_limit_connect` returns a handle other than the one
    /// it was given.
    pub fn connect(&mut self, sa: &dyn SockaddrLike) -> &mut Self {
        let (addr_ptr, addr_len) = (sa.as_ptr(), sa.len());
        // SAFETY: same invariants as in `bind` above.
        let returned = unsafe {
            ffi::cap_net_limit_connect(self.limit, addr_ptr, addr_len)
        };
        assert_eq!(returned, self.limit);
        self
    }

    /// Submits the accumulated restrictions via `cap_net_limit`, consuming
    /// this `Limit`.
    ///
    /// # Errors
    ///
    /// Returns the last OS error if `cap_net_limit` reports failure (any
    /// non-zero status).
    pub fn limit(self) -> io::Result<()> {
        // SAFETY: `self.limit` is the handle this `Limit` wraps; taking `self`
        // by value prevents any further use of it from Rust afterwards.
        let status = unsafe { ffi::cap_net_limit(self.limit) };
        match status {
            0 => Ok(()),
            _ => Err(io::Error::last_os_error()),
        }
    }
}