//! Extension traits for [`TcpListener`], [`UdpSocket`], and [`Pool`].
//!
//! The traits in this crate expose socket setup as separate steps:
//!
//!  - [`TcpListenerExt::new`] and [`UdpSocketExt::new`] create an unbound
//!    socket with a chosen [`AddressFamily`] and [`Blocking`] mode.
//!  - [`PoolExt`] binds or connects an existing socket, checking every
//!    address against a [`Pool`] before it is used.
//!  - [`PoolExt::tcp_binder`], [`PoolExt::udp_binder`],
//!    [`PoolExt::tcp_connecter`], and [`PoolExt::udp_connecter`] check
//!    addresses once, up front, and return a value that binds or connects
//!    later without re-checking.
//!  - [`TcpListenerExt::listen`] and [`TcpListenerExt::accept_with`] complete
//!    the listening side of a TCP socket.
#![deny(missing_docs)]
#![forbid(unsafe_code)]
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/bytecodealliance/cap-std/main/media/cap-std.svg"
)]
#![doc(
    html_favicon_url = "https://raw.githubusercontent.com/bytecodealliance/cap-std/main/media/cap-std.ico"
)]
use cap_primitives::net::no_socket_addrs;
use cap_std::net::{IpAddr, Pool, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
use rustix::fd::OwnedFd;
use std::io;
/// The IP address family of a socket.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum AddressFamily {
    /// IPv4 (`AF_INET`).
    Ipv4,
    /// IPv6 (`AF_INET6`).
    Ipv6,
}

impl AddressFamily {
    /// Returns the address family that `ip_addr` belongs to.
    pub fn of_ip_addr(ip_addr: IpAddr) -> Self {
        match ip_addr {
            IpAddr::V4(_) => AddressFamily::Ipv4,
            IpAddr::V6(_) => AddressFamily::Ipv6,
        }
    }

    /// Returns the address family that `socket_addr` belongs to.
    pub fn of_socket_addr(socket_addr: SocketAddr) -> Self {
        // The family of a socket address is the family of its IP part.
        Self::of_ip_addr(socket_addr.ip())
    }
}
/// Maps this crate's [`AddressFamily`] onto rustix's socket-domain constants.
impl From<AddressFamily> for rustix::net::AddressFamily {
    fn from(family: AddressFamily) -> Self {
        match family {
            AddressFamily::Ipv4 => Self::INET,
            AddressFamily::Ipv6 => Self::INET6,
        }
    }
}
/// Whether a socket should be created (or accepted) in blocking mode.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Blocking {
    /// Non-blocking mode: the socket is requested with `NONBLOCK` semantics.
    No,
    /// Blocking mode.
    Yes,
}
/// Extension trait for [`TcpListener`] that exposes socket creation,
/// `listen`, and `accept` as separate operations.
///
/// This trait is sealed and cannot be implemented outside this crate.
pub trait TcpListenerExt: private::Sealed + Sized {
    /// Creates a new, unbound TCP socket of the given address family and
    /// blocking mode, wrapped in a `TcpListener`.
    ///
    /// The socket can then be bound with [`PoolExt::bind_existing_tcp_listener`]
    /// (or a [`TcpBinder`]) and put into listening mode with
    /// [`TcpListenerExt::listen`].
    fn new(address_family: AddressFamily, blocking: Blocking) -> io::Result<Self>;

    /// Marks the socket as listening for incoming connections.
    ///
    /// If `backlog` is `None`, a platform-dependent default backlog is used.
    fn listen(&self, backlog: Option<i32>) -> io::Result<()>;

    /// Accepts an incoming connection and returns the connected stream along
    /// with the peer's address.
    ///
    /// `blocking` selects the blocking mode requested for the accepted stream.
    fn accept_with(&self, blocking: Blocking) -> io::Result<(TcpStream, SocketAddr)>;
}
impl TcpListenerExt for TcpListener {
    fn new(address_family: AddressFamily, blocking: Blocking) -> io::Result<Self> {
        socket(address_family, blocking, rustix::net::SocketType::STREAM).map(Self::from)
    }

    fn listen(&self, backlog: Option<i32>) -> io::Result<()> {
        let backlog = backlog.unwrap_or_else(default_backlog);
        Ok(rustix::net::listen(self, backlog)?)
    }

    fn accept_with(&self, blocking: Blocking) -> io::Result<(TcpStream, SocketAddr)> {
        let (stream, addr) = rustix::net::acceptfrom_with(self, socket_flags(blocking))?;
        set_socket_flags(&stream, blocking)?;

        // The peer address is filled in by the OS and is an `Option`; it is
        // not guaranteed to be present (NOTE(review): some BSD-derived systems
        // reportedly return an empty address for connections reset before
        // `accept` completes). Report that as an error instead of panicking
        // inside the library. `stream` is dropped (closed) on this path.
        let addr = addr
            .and_then(|addr| SocketAddr::try_from(addr).ok())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::Other,
                    "accepted connection did not report an IP peer address",
                )
            })?;

        Ok((TcpStream::from(stream), addr))
    }
}
/// Extension trait for [`UdpSocket`] that exposes socket creation as a
/// separate operation from binding and connecting.
///
/// This trait is sealed and cannot be implemented outside this crate.
pub trait UdpSocketExt: private::Sealed + Sized {
    /// Creates a new, unbound UDP socket of the given address family and
    /// blocking mode.
    ///
    /// The socket can then be bound with [`PoolExt::bind_existing_udp_socket`]
    /// or connected with [`PoolExt::connect_existing_udp_socket`].
    fn new(address_family: AddressFamily, blocking: Blocking) -> io::Result<Self>;
}
impl UdpSocketExt for UdpSocket {
    fn new(address_family: AddressFamily, blocking: Blocking) -> io::Result<Self> {
        // Create a datagram socket, then wrap the raw descriptor.
        let fd = socket(address_family, blocking, rustix::net::SocketType::DGRAM)?;
        Ok(Self::from(fd))
    }
}
/// Extension trait for [`Pool`] providing pool-checked `bind` and `connect`
/// operations on already-created sockets.
///
/// Every address is checked against the pool before it is used. An address
/// the pool does not permit is reported as an error immediately, without
/// trying the remaining addresses.
///
/// This trait is sealed and cannot be implemented outside this crate.
pub trait PoolExt: private::Sealed {
    /// Binds an existing TCP listener to the first of `addrs` for which
    /// `bind` succeeds.
    ///
    /// On non-Windows targets `SO_REUSEADDR` is enabled before each attempt.
    ///
    /// # Errors
    ///
    /// Fails if resolving `addrs` fails, if an address is not permitted by
    /// the pool, or if enabling `SO_REUSEADDR` fails. If every bind attempt
    /// fails, returns the error from the last attempt. If `addrs` resolves
    /// to no addresses, also returns an error.
    fn bind_existing_tcp_listener<A: ToSocketAddrs>(
        &self,
        listener: &TcpListener,
        addrs: A,
    ) -> io::Result<()>;

    /// Binds an existing UDP socket to the first of `addrs` for which `bind`
    /// succeeds.
    ///
    /// # Errors
    ///
    /// Fails if resolving `addrs` fails or an address is not permitted by the
    /// pool. If every bind attempt fails, returns the error from the last
    /// attempt. If `addrs` resolves to no addresses, also returns an error.
    fn bind_existing_udp_socket<A: ToSocketAddrs>(
        &self,
        socket: &UdpSocket,
        addrs: A,
    ) -> io::Result<()>;

    /// Connects `socket` (typically created with [`TcpListenerExt::new`]) to
    /// the first of `addrs` that accepts, and converts it into a
    /// [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Same as [`PoolExt::connect_existing_tcp_listener`].
    fn connect_into_tcp_stream<A: ToSocketAddrs>(
        &self,
        socket: TcpListener,
        addrs: A,
    ) -> io::Result<TcpStream>;

    /// Connects an existing TCP socket to the first of `addrs` for which
    /// `connect` succeeds.
    ///
    /// # Errors
    ///
    /// Fails if resolving `addrs` fails or an address is not permitted by the
    /// pool. If every connect attempt fails, returns the error from the last
    /// attempt. If `addrs` resolves to no addresses, also returns an error.
    fn connect_existing_tcp_listener<A: ToSocketAddrs>(
        &self,
        socket: &TcpListener,
        addrs: A,
    ) -> io::Result<()>;

    /// Connects an existing UDP socket to the first of `addrs` for which
    /// `connect` succeeds.
    ///
    /// # Errors
    ///
    /// Fails if resolving `addrs` fails or an address is not permitted by the
    /// pool. If every connect attempt fails, returns the error from the last
    /// attempt. If `addrs` resolves to no addresses, also returns an error.
    fn connect_existing_udp_socket<A: ToSocketAddrs>(
        &self,
        socket: &UdpSocket,
        addrs: A,
    ) -> io::Result<()>;

    /// Resolves `addrs` and checks them against the pool now, returning a
    /// [`TcpBinder`] that can later bind listeners to them without
    /// re-checking.
    fn tcp_binder<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<TcpBinder>;

    /// Resolves `addrs` and checks them against the pool now, returning a
    /// [`UdpBinder`] that can later bind UDP sockets to them without
    /// re-checking.
    fn udp_binder<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<UdpBinder>;

    /// Resolves `addrs` and checks them against the pool now, returning a
    /// [`TcpConnecter`] that can later connect TCP sockets to them without
    /// re-checking.
    fn tcp_connecter<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<TcpConnecter>;

    /// Resolves `addrs` and checks them against the pool now, returning a
    /// [`UdpConnecter`] that can later connect UDP sockets to them without
    /// re-checking.
    fn udp_connecter<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<UdpConnecter>;
}
impl PoolExt for Pool {
fn bind_existing_tcp_listener<A: ToSocketAddrs>(
&self,
listener: &TcpListener,
addrs: A,
) -> io::Result<()> {
let addrs = addrs.to_socket_addrs()?;
let mut last_err = None;
for addr in addrs {
self._pool().check_addr(&addr)?;
set_reuseaddr(listener)?;
match rustix::net::bind(listener, &addr) {
Ok(()) => return Ok(()),
Err(err) => last_err = Some(err.into()),
}
}
match last_err {
Some(err) => Err(err),
None => Err(no_socket_addrs()),
}
}
fn bind_existing_udp_socket<A: ToSocketAddrs>(
&self,
socket: &UdpSocket,
addrs: A,
) -> io::Result<()> {
let addrs = addrs.to_socket_addrs()?;
let mut last_err = None;
for addr in addrs {
self._pool().check_addr(&addr)?;
match rustix::net::bind(socket, &addr) {
Ok(()) => return Ok(()),
Err(err) => last_err = Some(err),
}
}
match last_err {
Some(err) => Err(err.into()),
None => Err(no_socket_addrs()),
}
}
fn connect_into_tcp_stream<A: ToSocketAddrs>(
&self,
socket: TcpListener,
addrs: A,
) -> io::Result<TcpStream> {
self.connect_existing_tcp_listener(&socket, addrs)?;
Ok(TcpStream::from(OwnedFd::from(socket)))
}
fn connect_existing_tcp_listener<A: ToSocketAddrs>(
&self,
socket: &TcpListener,
addrs: A,
) -> io::Result<()> {
let addrs = addrs.to_socket_addrs()?;
let mut last_err = None;
for addr in addrs {
self._pool().check_addr(&addr)?;
match rustix::net::connect(socket, &addr) {
Ok(()) => return Ok(()),
Err(err) => last_err = Some(err),
}
}
match last_err {
Some(err) => Err(err.into()),
None => Err(no_socket_addrs()),
}
}
fn connect_existing_udp_socket<A: ToSocketAddrs>(
&self,
socket: &UdpSocket,
addrs: A,
) -> io::Result<()> {
let addrs = addrs.to_socket_addrs()?;
let mut last_err = None;
for addr in addrs {
self._pool().check_addr(&addr)?;
match rustix::net::connect(socket, &addr) {
Ok(()) => return Ok(()),
Err(err) => last_err = Some(err),
}
}
match last_err {
Some(err) => Err(err.into()),
None => Err(no_socket_addrs()),
}
}
fn tcp_binder<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<TcpBinder> {
Ok(TcpBinder(check_addrs(self._pool(), addrs)?))
}
fn udp_binder<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<UdpBinder> {
Ok(UdpBinder(check_addrs(self._pool(), addrs)?))
}
fn tcp_connecter<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<TcpConnecter> {
Ok(TcpConnecter(check_addrs(self._pool(), addrs)?))
}
fn udp_connecter<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<UdpConnecter> {
Ok(UdpConnecter(check_addrs(self._pool(), addrs)?))
}
}
/// Resolves `addrs` and checks every resulting address against `pool`.
///
/// Returns all resolved addresses, in resolution order, when each one is
/// permitted.
///
/// # Errors
///
/// Fails if resolution fails, or with the pool's error for the first address
/// it does not permit.
fn check_addrs<A: ToSocketAddrs>(
    pool: &cap_primitives::net::Pool,
    addrs: A,
) -> io::Result<smallvec::SmallVec<[SocketAddr; 1]>> {
    addrs
        .to_socket_addrs()?
        .map(|addr| pool.check_addr(&addr).map(|()| addr))
        .collect()
}
/// A set of addresses that have already been checked against a [`Pool`],
/// ready to bind TCP listeners.
///
/// Created by [`PoolExt::tcp_binder`].
#[derive(Debug)]
pub struct TcpBinder(smallvec::SmallVec<[SocketAddr; 1]>);

impl TcpBinder {
    /// Binds `listener` to the first pre-checked address for which `bind`
    /// succeeds.
    ///
    /// On non-Windows targets `SO_REUSEADDR` is enabled before each attempt.
    ///
    /// # Errors
    ///
    /// Fails if enabling `SO_REUSEADDR` fails. If every bind attempt fails,
    /// returns the error from the last attempt. If this binder holds no
    /// addresses, also returns an error.
    pub fn bind_existing_tcp_listener(&self, listener: &TcpListener) -> io::Result<()> {
        let mut last_err = None;
        for addr in &self.0 {
            set_reuseaddr(listener)?;
            match rustix::net::bind(listener, addr) {
                Ok(()) => return Ok(()),
                Err(err) => last_err = Some(err.into()),
            }
        }
        match last_err {
            Some(err) => Err(err),
            None => Err(no_socket_addrs()),
        }
    }
}
/// A set of addresses that have already been checked against a [`Pool`],
/// ready to bind UDP sockets.
///
/// Created by [`PoolExt::udp_binder`].
#[derive(Debug)]
pub struct UdpBinder(smallvec::SmallVec<[SocketAddr; 1]>);

impl UdpBinder {
    /// Binds `socket` to the first pre-checked address for which `bind`
    /// succeeds.
    ///
    /// # Errors
    ///
    /// If every bind attempt fails, returns the error from the last attempt.
    /// If this binder holds no addresses, also returns an error.
    pub fn bind_existing_udp_socket(&self, socket: &UdpSocket) -> io::Result<()> {
        let mut last_err = None;
        for addr in &self.0 {
            match rustix::net::bind(socket, addr) {
                Ok(()) => return Ok(()),
                Err(err) => last_err = Some(err.into()),
            }
        }
        match last_err {
            Some(err) => Err(err),
            None => Err(no_socket_addrs()),
        }
    }
}
/// A set of addresses that have already been checked against a [`Pool`],
/// ready to connect TCP sockets.
///
/// Created by [`PoolExt::tcp_connecter`].
#[derive(Debug)]
pub struct TcpConnecter(smallvec::SmallVec<[SocketAddr; 1]>);

impl TcpConnecter {
    /// Connects `socket` (typically created with [`TcpListenerExt::new`]) to
    /// the first pre-checked address that accepts, and converts it into a
    /// [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Same as [`TcpConnecter::connect_existing_tcp_listener`].
    pub fn connect_into_tcp_stream(&self, socket: TcpListener) -> io::Result<TcpStream> {
        self.connect_existing_tcp_listener(&socket)?;
        Ok(TcpStream::from(OwnedFd::from(socket)))
    }

    /// Connects `socket` to the first pre-checked address for which
    /// `connect` succeeds.
    ///
    /// # Errors
    ///
    /// If every connect attempt fails, returns the error from the last
    /// attempt. If this connecter holds no addresses, also returns an error.
    pub fn connect_existing_tcp_listener(&self, socket: &TcpListener) -> io::Result<()> {
        let mut last_err = None;
        for addr in &self.0 {
            match rustix::net::connect(socket, addr) {
                Ok(()) => return Ok(()),
                Err(err) => last_err = Some(err),
            }
        }
        match last_err {
            Some(err) => Err(err.into()),
            None => Err(no_socket_addrs()),
        }
    }
}
/// A set of addresses that have already been checked against a [`Pool`],
/// ready to connect UDP sockets.
///
/// Created by [`PoolExt::udp_connecter`].
#[derive(Debug)]
pub struct UdpConnecter(smallvec::SmallVec<[SocketAddr; 1]>);

impl UdpConnecter {
    /// Connects `socket` to the first pre-checked address for which
    /// `connect` succeeds.
    ///
    /// # Errors
    ///
    /// If every connect attempt fails, returns the error from the last
    /// attempt. If this connecter holds no addresses, also returns an error.
    pub fn connect_existing_udp_socket(&self, socket: &UdpSocket) -> io::Result<()> {
        let mut last_err = None;
        for addr in &self.0 {
            match rustix::net::connect(socket, addr) {
                Ok(()) => return Ok(()),
                Err(err) => last_err = Some(err),
            }
        }
        match last_err {
            Some(err) => Err(err.into()),
            None => Err(no_socket_addrs()),
        }
    }
}
/// Creates a new socket of the given family and type, then applies the
/// requested blocking mode (and close-on-exec where supported).
///
/// Flags that can be requested at creation time come from `socket_flags`.
/// `set_socket_flags` covers the targets where they cannot.
fn socket(
    address_family: AddressFamily,
    blocking: Blocking,
    socket_type: rustix::net::SocketType,
) -> io::Result<OwnedFd> {
    #[cfg(windows)]
    {
        // NOTE(review): this deliberately failing std connect to 0.0.0.0:0
        // appears intended to make std perform its one-time Winsock
        // initialization before rustix creates a raw socket. It runs once per
        // process.
        static WINSOCK_INIT: std::sync::Once = std::sync::Once::new();
        WINSOCK_INIT.call_once(|| {
            let unspecified = std::net::SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0);
            std::net::TcpStream::connect(unspecified).unwrap_err();
        });
    }

    let creation_flags = socket_flags(blocking);
    let fd = rustix::net::socket_with(address_family.into(), socket_type, creation_flags, None)?;
    set_socket_flags(&fd, blocking)?;
    Ok(fd)
}
/// Computes the flags to pass at socket creation (`socket_with` /
/// `acceptfrom_with`) for the requested blocking mode.
///
/// On most targets this requests close-on-exec, plus non-blocking mode for
/// [`Blocking::No`], atomically at creation. On Windows, Apple targets, and
/// Haiku it returns empty flags; `set_socket_flags` handles those targets
/// after creation.
fn socket_flags(blocking: Blocking) -> rustix::net::SocketFlags {
    // Only read on some targets; keeps the unused-variable lint quiet elsewhere.
    let _ = blocking;
    // Never mutated on targets where the block below is compiled out.
    #[allow(unused_mut)]
    let mut flags = rustix::net::SocketFlags::empty();
    #[cfg(not(any(
        windows,
        target_os = "macos",
        target_os = "ios",
        target_os = "tvos",
        target_os = "watchos",
        target_os = "visionos",
        target_os = "haiku"
    )))]
    {
        flags |= rustix::net::SocketFlags::CLOEXEC;
        if blocking == Blocking::No {
            flags |= rustix::net::SocketFlags::NONBLOCK;
        }
    }
    flags
}
/// Applies close-on-exec and the requested blocking mode to `fd` on targets
/// where `socket_flags` cannot request them at creation time.
///
/// - Apple targets: close-on-exec via `FIOCLEX`, blocking mode via `FIONBIO`.
/// - Windows: blocking mode via `FIONBIO`.
/// - Haiku: close-on-exec via `F_SETFD`, blocking mode via `F_SETFL`.
/// - All other targets: no-op.
fn set_socket_flags(fd: &OwnedFd, blocking: Blocking) -> io::Result<()> {
    // Only used on some targets; keeps the unused-variable lint quiet elsewhere.
    let _ = fd;
    let _ = blocking;

    #[cfg(any(
        target_os = "macos",
        target_os = "ios",
        target_os = "tvos",
        target_os = "watchos",
        target_os = "visionos",
    ))]
    {
        rustix::io::ioctl_fioclex(fd)?;
    }

    #[cfg(any(
        windows,
        target_os = "macos",
        target_os = "ios",
        target_os = "tvos",
        target_os = "watchos",
        target_os = "visionos"
    ))]
    {
        // Set the mode explicitly in both directions. With canonical BSD
        // semantics, a socket returned by `accept` inherits non-blocking mode
        // from the listener (NOTE(review): Winsock reportedly behaves the
        // same), so `Blocking::Yes` must actively clear it.
        rustix::io::ioctl_fionbio(fd, blocking == Blocking::No)?;
    }

    #[cfg(target_os = "haiku")]
    {
        // Close-on-exec is a file *descriptor* flag (F_GETFD/F_SETFD).
        // Non-blocking is a file *status* flag (F_GETFL/F_SETFL). They must
        // be updated through separate calls.
        let mut fd_flags = rustix::io::fcntl_getfd(fd)?;
        fd_flags.insert(rustix::io::FdFlags::CLOEXEC);
        rustix::io::fcntl_setfd(fd, fd_flags)?;

        let mut status_flags = rustix::fs::fcntl_getfl(fd)?;
        status_flags.set(rustix::fs::OFlags::NONBLOCK, blocking == Blocking::No);
        rustix::fs::fcntl_setfl(fd, status_flags)?;
    }

    Ok(())
}
/// Enables `SO_REUSEADDR` on `listener` before binding.
///
/// This is skipped on Windows, where it is a no-op here.
fn set_reuseaddr(listener: &TcpListener) -> io::Result<()> {
    #[cfg(not(windows))]
    {
        rustix::net::sockopt::set_socket_reuseaddr(listener, true)?;
    }
    #[cfg(windows)]
    {
        // Keep the parameter "used" on Windows.
        let _ = listener;
    }
    Ok(())
}
/// Backlog used by `listen` when the caller does not supply one: 20 on
/// Horizon, 128 everywhere else.
///
/// NOTE(review): presumably chosen to mirror the backlog std's
/// `TcpListener::bind` uses — confirm.
fn default_backlog() -> i32 {
    if cfg!(target_os = "horizon") {
        20
    } else {
        128
    }
}
mod private {
pub trait Sealed {}
impl Sealed for super::TcpListener {}
impl Sealed for super::UdpSocket {}
impl Sealed for super::Pool {}
}