use std::{
fmt::Debug,
io::{ErrorKind, Result},
net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs},
ops::Deref,
sync::{Arc, OnceLock},
task::{Context, Poll},
};
use futures::{future::poll_fn, AsyncRead, AsyncWrite, Stream};
/// Backend-facing traits implemented by a network driver.
///
/// The handle types defined next to this module (`TcpListener`, `TcpStream`, `UdpSocket` and,
/// on Unix, `unix::UnixListener` / `unix::UnixStream`) store boxed implementations of these
/// traits and dereference to them, so driver methods can be called on the handles directly.
pub mod syscall {
    use super::*;

    /// Driver traits for Unix-domain sockets (compiled on Unix targets only).
    #[cfg(unix)]
    pub mod unix {
        use super::*;

        /// Driver-side implementation of a listening Unix-domain socket; boxed inside
        /// `crate::net::unix::UnixListener`.
        pub trait DriverUnixListener: Sync + Send {
            /// Local address of the listening socket, as reported by the driver.
            fn local_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
            /// Polls for the next accepted connection together with the peer's address.
            ///
            /// `UnixListener::accept` awaits this directly. The `Stream` impl of
            /// `UnixListener` treats an error of kind `ErrorKind::BrokenPipe` as end of stream
            /// and yields any other error as an item.
            fn poll_next(
                &self,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(crate::net::unix::UnixStream, std::os::unix::net::SocketAddr)>>;
        }

        /// Driver-side implementation of a connected Unix-domain stream; one instance is
        /// shared (through an `Arc`) by every clone of a `crate::net::unix::UnixStream`.
        pub trait DriverUnixStream: Sync + Send {
            /// Local address of the stream, as reported by the driver.
            fn local_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
            /// Address of the connected peer, as reported by the driver.
            fn peer_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
            /// Shuts down the half (or halves) selected by `how`; the `AsyncWrite::poll_close`
            /// impl of `UnixStream` calls this with `Shutdown::Both`.
            fn shutdown(&self, how: Shutdown) -> Result<()>;
            /// Backs `AsyncRead::poll_read` for `UnixStream`.
            fn poll_read(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>>;
            /// Backs `AsyncWrite::poll_write` for `UnixStream`.
            fn poll_write(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
            /// Polled by `UnixStream::connect_with` right after `Driver::unix_connect`; the
            /// stream is handed to the caller only once this yields `Ok(())`, and an error is
            /// returned to the caller as-is.
            fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
        }
    }

    /// A network backend: the factory for every socket kind exposed by this module's handles.
    ///
    /// A driver is either passed explicitly to the `*_with` constructors or installed once,
    /// process-wide, with `register_network_driver` and then fetched by `get_network_driver`.
    pub trait Driver: Send + Sync {
        /// Creates a TCP listener for `laddrs`, which `TcpListener::bind_with` has already
        /// resolved. NOTE(review): whether every address or only the first usable one is bound
        /// is up to the driver — confirm per implementation.
        fn tcp_listen(&self, laddrs: &[SocketAddr]) -> Result<TcpListener>;
        /// Wraps an existing listening socket descriptor.
        ///
        /// # Safety
        ///
        /// The contract is driver-defined. NOTE(review): presumably `fd` must be a valid, open
        /// listening socket whose ownership passes to the returned listener — confirm.
        #[cfg(unix)]
        unsafe fn tcp_listener_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<TcpListener>;
        /// Windows counterpart of `tcp_listener_from_raw_fd`.
        ///
        /// # Safety
        ///
        /// Driver-defined; presumably `socket` must be a valid, owned listening socket — confirm.
        #[cfg(windows)]
        unsafe fn tcp_listener_from_raw_socket(
            &self,
            socket: std::os::windows::io::RawSocket,
        ) -> Result<TcpListener>;
        /// Starts a TCP connection to a single resolved address.
        ///
        /// `TcpStream::connect_with` calls this once per candidate address and then polls
        /// `DriverTcpStream::poll_ready` on the result before returning it.
        fn tcp_connect(&self, raddrs: &SocketAddr) -> Result<TcpStream>;
        /// Wraps an existing connected TCP socket descriptor.
        ///
        /// # Safety
        ///
        /// Driver-defined; presumably `fd` must be a valid, owned, connected socket — confirm.
        #[cfg(unix)]
        unsafe fn tcp_stream_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<TcpStream>;
        /// Windows counterpart of `tcp_stream_from_raw_fd`.
        ///
        /// # Safety
        ///
        /// Driver-defined; presumably `socket` must be a valid, owned, connected socket — confirm.
        #[cfg(windows)]
        unsafe fn tcp_stream_from_raw_socket(
            &self,
            socket: std::os::windows::io::RawSocket,
        ) -> Result<TcpStream>;
        /// Creates a UDP socket for `laddrs`, which `UdpSocket::bind_with` has already resolved.
        fn udp_bind(&self, laddrs: &[SocketAddr]) -> Result<UdpSocket>;
        /// Wraps an existing UDP socket descriptor.
        ///
        /// # Safety
        ///
        /// Driver-defined; presumably `fd` must be a valid, owned UDP socket — confirm.
        #[cfg(unix)]
        unsafe fn udp_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<UdpSocket>;
        /// Windows counterpart of `udp_from_raw_fd`.
        ///
        /// # Safety
        ///
        /// Driver-defined; presumably `socket` must be a valid, owned UDP socket — confirm.
        #[cfg(windows)]
        unsafe fn udp_from_raw_socket(
            &self,
            socket: std::os::windows::io::RawSocket,
        ) -> Result<UdpSocket>;
        /// Creates a Unix-domain listener at `path` (used by `UnixListener::bind_with`).
        #[cfg(unix)]
        #[cfg_attr(docsrs, doc(cfg(unix)))]
        fn unix_listen(&self, path: &std::path::Path) -> Result<crate::net::unix::UnixListener>;
        /// Starts a Unix-domain connection to `path`; `UnixStream::connect_with` then polls
        /// `DriverUnixStream::poll_ready` on the result before returning it.
        #[cfg(unix)]
        #[cfg_attr(docsrs, doc(cfg(unix)))]
        fn unix_connect(&self, path: &std::path::Path) -> Result<crate::net::unix::UnixStream>;
    }

    /// Driver-side implementation of a listening TCP socket; boxed inside `TcpListener`.
    pub trait DriverTcpListener: Sync + Send {
        /// Local address of the listener, as reported by the driver.
        fn local_addr(&self) -> Result<SocketAddr>;
        /// Current IP time-to-live value.
        fn ttl(&self) -> Result<u32>;
        /// Sets the IP time-to-live value.
        fn set_ttl(&self, ttl: u32) -> Result<()>;
        /// Polls for the next accepted connection together with the peer's address.
        ///
        /// `TcpListener::accept` awaits this directly. The `Stream` impl of `TcpListener` treats
        /// an error of kind `ErrorKind::BrokenPipe` as end of stream and yields any other error
        /// as an item.
        fn poll_next(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
    }

    /// Driver-side implementation of a connected TCP stream; one instance is shared (through an
    /// `Arc`) by every clone of a `TcpStream`.
    ///
    /// `Debug` is a supertrait because `TcpStream` derives `Debug` over the boxed trait object.
    pub trait DriverTcpStream: Sync + Send + Debug {
        /// Local address of the stream, as reported by the driver.
        fn local_addr(&self) -> Result<SocketAddr>;
        /// Address of the connected peer, as reported by the driver.
        fn peer_addr(&self) -> Result<SocketAddr>;
        /// Current IP time-to-live value.
        fn ttl(&self) -> Result<u32>;
        /// Sets the IP time-to-live value.
        fn set_ttl(&self, ttl: u32) -> Result<()>;
        /// Whether `TCP_NODELAY` is enabled.
        fn nodelay(&self) -> Result<bool>;
        /// Enables or disables `TCP_NODELAY`.
        fn set_nodelay(&self, nodelay: bool) -> Result<()>;
        /// Shuts down the half (or halves) selected by `how`; the `AsyncWrite::poll_close` impl
        /// of `TcpStream` calls this with `Shutdown::Both`.
        fn shutdown(&self, how: Shutdown) -> Result<()>;
        /// Backs `AsyncRead::poll_read` for `TcpStream`.
        fn poll_read(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>>;
        /// Backs `AsyncWrite::poll_write` for `TcpStream`.
        fn poll_write(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
        /// Polled by `TcpStream::connect_with` right after `Driver::tcp_connect`; `Ok(())` means
        /// the stream is returned to the caller, while an error makes `connect_with` move on to
        /// the next candidate address.
        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
    }

    /// Driver-side implementation of a UDP socket; one instance is shared (through an `Arc`)
    /// by every clone of a `UdpSocket`.
    ///
    /// NOTE(review): the option accessors presumably mirror the `std::net::UdpSocket` methods
    /// of the same names — confirm against each driver.
    pub trait DriverUdpSocket: Sync + Send {
        fn shutdown(&self, how: Shutdown) -> Result<()>;
        fn local_addr(&self) -> Result<SocketAddr>;
        fn peer_addr(&self) -> Result<SocketAddr>;
        fn ttl(&self) -> Result<u32>;
        fn set_ttl(&self, ttl: u32) -> Result<()>;
        fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> Result<()>;
        fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()>;
        fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> Result<()>;
        fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()>;
        fn set_multicast_loop_v4(&self, on: bool) -> Result<()>;
        fn set_multicast_loop_v6(&self, on: bool) -> Result<()>;
        fn multicast_loop_v4(&self) -> Result<bool>;
        fn multicast_loop_v6(&self) -> Result<bool>;
        fn set_broadcast(&self, on: bool) -> Result<()>;
        fn broadcast(&self) -> Result<bool>;
        /// Backs `UdpSocket::recv_from`: receives into `buf`, returning the byte count and the
        /// sender's address.
        fn poll_recv_from(
            &self,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<(usize, SocketAddr)>>;
        /// Backs `UdpSocket::send_to`, which calls it once per resolved target address until
        /// one attempt succeeds.
        fn poll_send_to(
            &self,
            cx: &mut Context<'_>,
            buf: &[u8],
            peer: SocketAddr,
        ) -> Poll<Result<usize>>;
    }
}
/// Owned handle to a listening TCP socket provided by a network driver.
///
/// Dereferences to the driver's [`syscall::DriverTcpListener`], so driver methods such as
/// `local_addr` or `set_ttl` can be called on the handle directly.
pub struct TcpListener(Box<dyn syscall::DriverTcpListener>);

impl<T> From<T> for TcpListener
where
    T: syscall::DriverTcpListener + 'static,
{
    /// Boxes a concrete driver listener into a `TcpListener` handle.
    fn from(driver_listener: T) -> Self {
        TcpListener(Box::new(driver_listener))
    }
}

impl Deref for TcpListener {
    type Target = dyn syscall::DriverTcpListener;

    fn deref(&self) -> &Self::Target {
        // `&Box<dyn _>` deref-coerces to `&dyn _` at the return site.
        &self.0
    }
}
impl TcpListener {
    /// Returns the driver-level listener that this handle forwards to.
    pub fn as_raw_ptr(&self) -> &dyn syscall::DriverTcpListener {
        &**self
    }

    /// Waits for the next inbound connection and returns the stream with the peer's address.
    pub async fn accept(&self) -> Result<(TcpStream, SocketAddr)> {
        let listener = self.as_raw_ptr();
        poll_fn(move |cx| listener.poll_next(cx)).await
    }

    /// Binds a listener to `laddrs` using the globally registered network driver.
    ///
    /// # Panics
    ///
    /// Panics if no driver has been registered with `register_network_driver`.
    pub async fn bind<L: ToSocketAddrs>(laddrs: L) -> Result<Self> {
        let driver = get_network_driver();
        Self::bind_with(laddrs, driver).await
    }

    /// Resolves `laddrs` and asks `driver` to create a listener for the resolved addresses.
    ///
    /// # Errors
    ///
    /// Returns the resolution error, or whatever `driver.tcp_listen` reports.
    pub async fn bind_with<L: ToSocketAddrs>(
        laddrs: L,
        driver: &dyn syscall::Driver,
    ) -> Result<Self> {
        let resolved: Vec<SocketAddr> = laddrs.to_socket_addrs()?.collect();
        driver.tcp_listen(&resolved)
    }

    /// Wraps an existing listening socket descriptor using `driver`.
    ///
    /// # Safety
    ///
    /// The caller must satisfy `driver.tcp_listener_from_raw_fd`'s contract. NOTE(review):
    /// presumably `fd` must be a valid, open listening socket owned by the caller — confirm.
    #[cfg(unix)]
    pub unsafe fn from_raw_fd_with(
        fd: std::os::fd::RawFd,
        driver: &dyn syscall::Driver,
    ) -> Result<Self> {
        driver.tcp_listener_from_raw_fd(fd)
    }

    /// Same as [`TcpListener::from_raw_fd_with`] with the globally registered driver.
    ///
    /// # Safety
    ///
    /// See [`TcpListener::from_raw_fd_with`].
    ///
    /// # Panics
    ///
    /// Panics if no driver has been registered with `register_network_driver`.
    #[cfg(unix)]
    pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
        let driver = get_network_driver();
        Self::from_raw_fd_with(fd, driver)
    }

    /// Wraps an existing listening socket handle using `driver`.
    ///
    /// # Safety
    ///
    /// The caller must satisfy `driver.tcp_listener_from_raw_socket`'s contract.
    /// NOTE(review): presumably the socket must be valid, listening and owned — confirm.
    #[cfg(windows)]
    pub unsafe fn from_raw_socket_with(
        fd: std::os::windows::io::RawSocket,
        driver: &dyn syscall::Driver,
    ) -> Result<Self> {
        driver.tcp_listener_from_raw_socket(fd)
    }

    /// Same as [`TcpListener::from_raw_socket_with`] with the globally registered driver.
    ///
    /// # Safety
    ///
    /// See [`TcpListener::from_raw_socket_with`].
    #[cfg(windows)]
    pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
        let driver = get_network_driver();
        Self::from_raw_socket_with(fd, driver)
    }
}
impl Stream for TcpListener {
    type Item = Result<TcpStream>;

    /// Yields each accepted stream, discarding the peer address.
    ///
    /// A driver error of kind [`ErrorKind::BrokenPipe`] ends the stream (`None`); any other
    /// error is yielded as `Some(Err(_))`.
    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let listener = self.as_raw_ptr();
        listener.poll_next(cx).map(|accepted| match accepted {
            Ok((stream, _peer)) => Some(Ok(stream)),
            Err(err) if err.kind() == ErrorKind::BrokenPipe => None,
            Err(err) => Some(Err(err)),
        })
    }
}
/// Shared handle to a connected TCP stream provided by a network driver.
///
/// Cloning is cheap: every clone points at the same driver stream. The handle dereferences to
/// [`syscall::DriverTcpStream`], so driver methods can be called on it directly.
#[derive(Debug, Clone)]
pub struct TcpStream(Arc<Box<dyn syscall::DriverTcpStream>>);

impl<T> From<T> for TcpStream
where
    T: syscall::DriverTcpStream + 'static,
{
    /// Boxes a concrete driver stream into a shareable `TcpStream` handle.
    fn from(driver_stream: T) -> Self {
        let boxed: Box<dyn syscall::DriverTcpStream> = Box::new(driver_stream);
        TcpStream(Arc::new(boxed))
    }
}

impl Deref for TcpStream {
    type Target = dyn syscall::DriverTcpStream;

    fn deref(&self) -> &Self::Target {
        // `&Arc<Box<dyn _>>` deref-coerces through both layers to `&dyn _`.
        &self.0
    }
}
impl TcpStream {
pub fn as_raw_ptr(&self) -> &dyn syscall::DriverTcpStream {
&**self.0
}
pub async fn connect<R: ToSocketAddrs>(raddrs: R) -> Result<Self> {
Self::connect_with(raddrs, get_network_driver()).await
}
pub async fn connect_with<R: ToSocketAddrs>(
raddrs: R,
driver: &dyn syscall::Driver,
) -> Result<Self> {
let mut last_error = None;
for raddr in raddrs.to_socket_addrs()?.collect::<Vec<_>>() {
match driver.tcp_connect(&raddr) {
Ok(stream) => {
match poll_fn(|cx| stream.poll_ready(cx)).await {
Ok(()) => {
return Ok(stream);
}
Err(err) => {
last_error = Some(err);
}
}
}
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap())
}
#[cfg(unix)]
pub unsafe fn from_raw_fd_with(
fd: std::os::fd::RawFd,
driver: &dyn syscall::Driver,
) -> Result<Self> {
driver.tcp_stream_from_raw_fd(fd)
}
#[cfg(unix)]
pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
Self::from_raw_fd_with(fd, get_network_driver())
}
#[cfg(windows)]
pub unsafe fn from_raw_socket_with(
fd: std::os::windows::io::RawSocket,
driver: &dyn syscall::Driver,
) -> Result<Self> {
driver.tcp_stream_from_raw_socket(fd)
}
#[cfg(windows)]
pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
Self::from_raw_socket_with(fd, get_network_driver())
}
}
impl AsyncRead for TcpStream {
    /// Reads by delegating to the shared driver stream's `poll_read`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        let driver_stream = self.as_raw_ptr();
        driver_stream.poll_read(cx, buf)
    }
}
impl AsyncWrite for TcpStream {
    /// Writes by delegating to the shared driver stream's `poll_write`.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let driver_stream = self.as_raw_ptr();
        driver_stream.poll_write(cx, buf)
    }

    /// Always ready: this handle keeps no write buffer of its own.
    fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down both directions of the driver stream and reports the result.
    ///
    /// NOTE(review): all clones of this handle share one driver stream, so closing one clone
    /// also shuts down reading for the others.
    fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(self.as_raw_ptr().shutdown(Shutdown::Both))
    }
}
/// Shared handle to a UDP socket provided by a network driver.
///
/// Cloning is cheap: every clone points at the same driver socket. The handle dereferences to
/// [`syscall::DriverUdpSocket`], so driver methods can be called on it directly.
#[derive(Clone)]
pub struct UdpSocket(Arc<Box<dyn syscall::DriverUdpSocket>>);

impl<T> From<T> for UdpSocket
where
    T: syscall::DriverUdpSocket + 'static,
{
    /// Boxes a concrete driver socket into a shareable `UdpSocket` handle.
    fn from(driver_socket: T) -> Self {
        let boxed: Box<dyn syscall::DriverUdpSocket> = Box::new(driver_socket);
        UdpSocket(Arc::new(boxed))
    }
}

impl Deref for UdpSocket {
    type Target = dyn syscall::DriverUdpSocket;

    fn deref(&self) -> &Self::Target {
        // `&Arc<Box<dyn _>>` deref-coerces through both layers to `&dyn _`.
        &self.0
    }
}
impl UdpSocket {
pub fn as_raw_ptr(&self) -> &dyn syscall::DriverUdpSocket {
&**self.0
}
pub async fn bind<L: ToSocketAddrs>(laddrs: L) -> Result<Self> {
Self::bind_with(laddrs, get_network_driver()).await
}
pub async fn bind_with<L: ToSocketAddrs>(
laddrs: L,
driver: &dyn syscall::Driver,
) -> Result<Self> {
let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();
driver.udp_bind(&laddrs)
}
#[cfg(unix)]
pub unsafe fn from_raw_fd_with(
fd: std::os::fd::RawFd,
driver: &dyn syscall::Driver,
) -> Result<Self> {
driver.udp_from_raw_fd(fd)
}
#[cfg(unix)]
pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
Self::from_raw_fd_with(fd, get_network_driver())
}
#[cfg(windows)]
pub unsafe fn from_raw_socket_with(
fd: std::os::windows::io::RawSocket,
driver: &dyn syscall::Driver,
) -> Result<Self> {
driver.udp_from_raw_socket(fd)
}
#[cfg(windows)]
pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
Self::from_raw_socket_with(fd, get_network_driver())
}
pub async fn send_to<Buf: AsRef<[u8]>, A: ToSocketAddrs>(
&self,
buf: Buf,
target: A,
) -> Result<usize> {
let mut last_error = None;
let buf = buf.as_ref();
for raddr in target.to_socket_addrs()? {
match poll_fn(|cx| self.poll_send_to(cx, buf, raddr)).await {
Ok(send_size) => return Ok(send_size),
Err(err) => {
last_error = Some(err);
}
}
}
Err(last_error.unwrap())
}
pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
poll_fn(|cx| self.poll_recv_from(cx, buf)).await
}
}
/// Unix-domain socket handles that forward every operation to a network driver.
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub mod unix {
    use super::syscall::unix::*;
    use super::*;
    use std::path::Path;

    /// Owned handle to a listening Unix-domain socket provided by a network driver.
    ///
    /// Dereferences to the driver's [`DriverUnixListener`].
    pub struct UnixListener(Box<dyn DriverUnixListener>);

    impl<T> From<T> for UnixListener
    where
        T: DriverUnixListener + 'static,
    {
        /// Boxes a concrete driver listener into a `UnixListener` handle.
        fn from(driver_listener: T) -> Self {
            UnixListener(Box::new(driver_listener))
        }
    }

    impl Deref for UnixListener {
        type Target = dyn DriverUnixListener;

        fn deref(&self) -> &Self::Target {
            // `&Box<dyn _>` deref-coerces to `&dyn _` at the return site.
            &self.0
        }
    }

    impl UnixListener {
        /// Returns the driver-level listener this handle forwards to.
        pub fn as_raw_ptr(&self) -> &dyn DriverUnixListener {
            &**self
        }

        /// Waits for the next inbound connection and returns the stream with the peer's address.
        pub async fn accept(&self) -> Result<(UnixStream, std::os::unix::net::SocketAddr)> {
            let listener = self.as_raw_ptr();
            poll_fn(move |cx| listener.poll_next(cx)).await
        }

        /// Binds a listener at `path` using the globally registered network driver.
        ///
        /// # Panics
        ///
        /// Panics if no driver has been registered with `register_network_driver`.
        pub async fn bind<P: AsRef<Path>>(path: P) -> Result<Self> {
            let driver = get_network_driver();
            Self::bind_with(path, driver).await
        }

        /// Binds a listener at `path` using `driver`.
        pub async fn bind_with<P: AsRef<Path>>(
            path: P,
            driver: &dyn syscall::Driver,
        ) -> Result<Self> {
            let path: &Path = path.as_ref();
            driver.unix_listen(path)
        }
    }

    impl Stream for UnixListener {
        type Item = Result<UnixStream>;

        /// Yields each accepted stream, discarding the peer address.
        ///
        /// A driver error of kind [`ErrorKind::BrokenPipe`] ends the stream (`None`); any other
        /// error is yielded as `Some(Err(_))`.
        fn poll_next(
            self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            self.as_raw_ptr()
                .poll_next(cx)
                .map(|accepted| match accepted {
                    Ok((stream, _peer)) => Some(Ok(stream)),
                    Err(err) if err.kind() == ErrorKind::BrokenPipe => None,
                    Err(err) => Some(Err(err)),
                })
        }
    }

    /// Shared handle to a connected Unix-domain stream provided by a network driver.
    ///
    /// Cloning is cheap: every clone points at the same driver stream.
    #[derive(Clone)]
    pub struct UnixStream(Arc<Box<dyn DriverUnixStream>>);

    impl<T> From<T> for UnixStream
    where
        T: DriverUnixStream + 'static,
    {
        /// Boxes a concrete driver stream into a shareable `UnixStream` handle.
        fn from(driver_stream: T) -> Self {
            let boxed: Box<dyn DriverUnixStream> = Box::new(driver_stream);
            UnixStream(Arc::new(boxed))
        }
    }

    impl Deref for UnixStream {
        type Target = dyn DriverUnixStream;

        fn deref(&self) -> &Self::Target {
            // `&Arc<Box<dyn _>>` deref-coerces through both layers to `&dyn _`.
            &self.0
        }
    }

    impl UnixStream {
        /// Returns the driver-level stream this handle forwards to.
        pub fn as_raw_ptr(&self) -> &dyn DriverUnixStream {
            &**self
        }

        /// Connects to `path` using the globally registered network driver.
        ///
        /// # Panics
        ///
        /// Panics if no driver has been registered with `register_network_driver`.
        pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
            let driver = get_network_driver();
            Self::connect_with(path, driver).await
        }

        /// Connects to `path` using `driver`, returning the stream once its `poll_ready` succeeds.
        pub async fn connect_with<P: AsRef<Path>>(
            path: P,
            driver: &dyn syscall::Driver,
        ) -> Result<Self> {
            let connected = driver.unix_connect(path.as_ref())?;
            // Hand the stream out only after the driver reports it ready.
            poll_fn(|cx| connected.poll_ready(cx)).await?;
            Ok(connected)
        }
    }

    impl AsyncRead for UnixStream {
        /// Reads by delegating to the shared driver stream's `poll_read`.
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            let driver_stream = self.as_raw_ptr();
            driver_stream.poll_read(cx, buf)
        }
    }

    impl AsyncWrite for UnixStream {
        /// Writes by delegating to the shared driver stream's `poll_write`.
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let driver_stream = self.as_raw_ptr();
            driver_stream.poll_write(cx, buf)
        }

        /// Always ready: this handle keeps no write buffer of its own.
        fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Shuts down both directions of the shared driver stream and reports the result.
        fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(self.as_raw_ptr().shutdown(Shutdown::Both))
        }
    }
}
/// Process-wide network driver; set once by `register_network_driver`.
static NETWORK_DRIVER: OnceLock<Box<dyn syscall::Driver>> = OnceLock::new();

/// Returns the globally registered network driver.
///
/// # Panics
///
/// Panics with "Call register_network_driver first." if no driver has been registered yet.
pub fn get_network_driver() -> &'static dyn syscall::Driver {
    let registered = NETWORK_DRIVER
        .get()
        .expect("Call register_network_driver first.");
    &**registered
}
/// Installs `driver` as the process-wide network driver.
///
/// Constructors that take no explicit driver (for example `TcpListener::bind`,
/// `TcpStream::connect` and `UdpSocket::bind`) fetch this driver through `get_network_driver`.
///
/// # Panics
///
/// Panics if a driver has already been registered; the global driver can be set only once.
pub fn register_network_driver<E: syscall::Driver + 'static>(driver: E) {
    if NETWORK_DRIVER.set(Box::new(driver)).is_err() {
        // The message names this function so users can find the offending second call.
        panic!("Multiple calls to register_network_driver are not permitted!!!");
    }
}