use std::{
fmt::{Debug, Display},
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
sync::Arc,
task::Poll,
};
use futures::{AsyncRead, AsyncWrite};
use rasi_syscall::{global_network, Handle, Network};
use crate::utils::cancelable_would_block;
pub struct TcpListener {
sys_socket: rasi_syscall::Handle,
syscall: &'static dyn Network,
}
impl TcpListener {
    /// Creates a listener bound to `laddrs`, using the supplied network driver.
    ///
    /// All addresses produced by `laddrs` are resolved up front and handed to
    /// the driver in a single bind call.
    ///
    /// # Errors
    ///
    /// Returns an error if `laddrs` cannot be resolved, or whatever error the
    /// driver reports while binding.
    pub async fn bind_with<A: ToSocketAddrs>(
        laddrs: A,
        syscall: &'static dyn Network,
    ) -> io::Result<Self> {
        let candidates: Vec<SocketAddr> = laddrs.to_socket_addrs()?.collect();

        let sys_socket = cancelable_would_block(|cx| {
            syscall.tcp_listener_bind(cx.waker().clone(), &candidates)
        })
        .await?;

        Ok(Self { sys_socket, syscall })
    }

    /// Creates a listener bound to `laddrs`, using the global network driver.
    ///
    /// # Errors
    ///
    /// Same as [`TcpListener::bind_with`].
    pub async fn bind<A: ToSocketAddrs>(laddrs: A) -> io::Result<Self> {
        Self::bind_with(laddrs, global_network()).await
    }

    /// Waits for the next inbound connection.
    ///
    /// Returns the connected stream, which uses the same driver as this
    /// listener, together with the peer address reported by the driver.
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        let syscall = self.syscall;

        let accepted = cancelable_would_block(|cx| {
            syscall.tcp_listener_accept(cx.waker().clone(), &self.sys_socket)
        })
        .await;

        accepted.map(|(sys_socket, raddr)| (TcpStream::new(sys_socket, syscall), raddr))
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.syscall.tcp_listener_local_addr(&self.sys_socket)
    }

    /// Returns the IP time-to-live value of the listening socket.
    pub fn ttl(&self) -> io::Result<u32> {
        self.syscall.tcp_listener_ttl(&self.sys_socket)
    }

    /// Sets the IP time-to-live value of the listening socket.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.syscall.tcp_listener_set_ttl(&self.sys_socket, ttl)
    }
}
/// A connected TCP stream implementing `futures`' `AsyncRead` / `AsyncWrite`.
pub struct TcpStream {
    /// Driver-level handle of the connected socket.
    sys_socket: Handle,
    /// Network driver used to perform all operations on `sys_socket`.
    syscall: &'static dyn Network,
    /// Handle returned by the driver the last time a write reported `Pending`.
    /// NOTE(review): presumably dropping it cancels the outstanding write; confirm
    /// against rasi_syscall.
    write_cancel_handle: Option<Handle>,
    /// Handle returned by the driver the last time a read reported `Pending`.
    /// NOTE(review): presumably dropping it cancels the outstanding read; confirm
    /// against rasi_syscall.
    read_cancel_handle: Option<Handle>,
}
impl Debug for TcpStream {
    /// Shows the socket handle and the address of the network driver.
    ///
    /// The driver is rendered as a pointer address, not by its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let driver_ptr = self.syscall as *const dyn Network;
        let driver_repr = format!("{:?}", driver_ptr);

        let mut out = f.debug_struct("TcpStream");
        out.field("socket", &self.sys_socket);
        out.field("syscall", &driver_repr);
        out.finish()
    }
}
impl TcpStream {
fn new(sys_socket: Handle, syscall: &'static dyn Network) -> Self {
Self {
sys_socket,
syscall,
write_cancel_handle: None,
read_cancel_handle: None,
}
}
pub async fn connect_with<A: ToSocketAddrs>(
raddrs: A,
syscall: &'static dyn Network,
) -> io::Result<Self> {
let raddrs = raddrs.to_socket_addrs()?.collect::<Vec<_>>();
let sys_socket =
cancelable_would_block(|cx| syscall.tcp_stream_connect(cx.waker().clone(), &raddrs))
.await?;
Ok(Self::new(sys_socket, syscall))
}
pub async fn connect<A: ToSocketAddrs>(raddrs: A) -> io::Result<Self> {
Self::connect_with(raddrs, global_network()).await
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.syscall.tcp_stream_local_addr(&self.sys_socket)
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.syscall.tcp_stream_remote_addr(&self.sys_socket)
}
pub fn ttl(&self) -> io::Result<u32> {
self.syscall.tcp_stream_ttl(&self.sys_socket)
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.syscall.tcp_stream_set_ttl(&self.sys_socket, ttl)
}
pub fn nodelay(&self) -> io::Result<bool> {
self.syscall.tcp_stream_nodelay(&self.sys_socket)
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.syscall
.tcp_stream_set_nodelay(&self.sys_socket, nodelay)
}
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
self.syscall.tcp_stream_shutdown(&self.sys_socket, how)
}
pub fn split(self) -> (TcpStreamRead, TcpStreamWrite) {
let sys_socket = Arc::new(self.sys_socket);
(
TcpStreamRead {
sys_socket: sys_socket.clone(),
syscall: self.syscall,
read_cancel_handle: self.read_cancel_handle,
},
TcpStreamWrite {
sys_socket,
syscall: self.syscall,
write_cancel_handle: self.write_cancel_handle,
},
)
}
}
impl AsyncRead for TcpStream {
    /// Asks the driver for a read into `buf`.
    ///
    /// When the driver cannot complete the read yet, its cancel handle is kept
    /// in `read_cancel_handle` and `Pending` is returned.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let waker = cx.waker().clone();
        let outcome = self.syscall.tcp_stream_read(waker, &self.sys_socket, buf);

        match outcome {
            rasi_syscall::CancelablePoll::Pending(cancel) => {
                self.read_cancel_handle = Some(cancel);
                Poll::Pending
            }
            rasi_syscall::CancelablePoll::Ready(result) => Poll::Ready(result),
        }
    }
}
impl AsyncWrite for TcpStream {
    /// Asks the driver to write `buf`.
    ///
    /// When the driver cannot complete the write yet, its cancel handle is kept
    /// in `write_cancel_handle` and `Pending` is returned.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let waker = cx.waker().clone();
        let outcome = self.syscall.tcp_stream_write(waker, &self.sys_socket, buf);

        match outcome {
            rasi_syscall::CancelablePoll::Pending(cancel) => {
                self.write_cancel_handle = Some(cancel);
                Poll::Pending
            }
            rasi_syscall::CancelablePoll::Ready(result) => Poll::Ready(result),
        }
    }

    /// Always ready: this wrapper holds no write buffer of its own.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down both directions of the connection and completes immediately.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(self.shutdown(std::net::Shutdown::Both))
    }
}
/// The read half of a [`TcpStream`], produced by [`TcpStream::split`].
pub struct TcpStreamRead {
    /// Socket handle shared with the matching `TcpStreamWrite`.
    sys_socket: Arc<Handle>,
    /// Network driver used to perform reads.
    syscall: &'static dyn Network,
    /// Handle returned by the driver the last time a read reported `Pending`.
    read_cancel_handle: Option<Handle>,
}
impl AsyncRead for TcpStreamRead {
    /// Asks the driver for a read into `buf` on the shared socket.
    ///
    /// When the driver cannot complete the read yet, its cancel handle is kept
    /// in `read_cancel_handle` and `Pending` is returned.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let waker = cx.waker().clone();
        let outcome = self.syscall.tcp_stream_read(waker, &self.sys_socket, buf);

        match outcome {
            rasi_syscall::CancelablePoll::Pending(cancel) => {
                self.read_cancel_handle = Some(cancel);
                Poll::Pending
            }
            rasi_syscall::CancelablePoll::Ready(result) => Poll::Ready(result),
        }
    }
}
/// The write half of a [`TcpStream`], produced by [`TcpStream::split`].
pub struct TcpStreamWrite {
    /// Socket handle shared with the matching `TcpStreamRead`.
    sys_socket: Arc<Handle>,
    /// Network driver used to perform writes.
    syscall: &'static dyn Network,
    /// Handle returned by the driver the last time a write reported `Pending`.
    write_cancel_handle: Option<Handle>,
}
impl AsyncWrite for TcpStreamWrite {
    /// Asks the driver to write `buf` on the shared socket.
    ///
    /// When the driver cannot complete the write yet, its cancel handle is kept
    /// in `write_cancel_handle` and `Pending` is returned.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self
            .syscall
            .tcp_stream_write(cx.waker().clone(), &self.sys_socket, buf)
        {
            rasi_syscall::CancelablePoll::Ready(r) => Poll::Ready(r),
            rasi_syscall::CancelablePoll::Pending(write_cancel_handle) => {
                self.write_cancel_handle = Some(write_cancel_handle);
                Poll::Pending
            }
        }
    }

    /// Always ready: this wrapper holds no write buffer of its own.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down only the write direction of the connection.
    ///
    /// This half shares its socket (via `Arc`) with the matching
    /// `TcpStreamRead`. Shutting down `Both` here would also kill the reader,
    /// so only `Shutdown::Write` is issued. The peer sees EOF, and the read half
    /// can keep receiving until the peer closes its side.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        self.syscall
            .tcp_stream_shutdown(&self.sys_socket, std::net::Shutdown::Write)?;
        Poll::Ready(Ok(()))
    }
}
pub struct UdpSocket {
sys_socket: rasi_syscall::Handle,
syscall: &'static dyn Network,
}
impl Display for UdpSocket {
    /// Renders as `UdpSocket(<local_addr result>)`.
    ///
    /// The full `io::Result` from `local_addr` is Debug-formatted, so an error
    /// shows up inline instead of failing the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let local = self.local_addr();
        write!(f, "UdpSocket({:?})", local)
    }
}
impl UdpSocket {
    /// Binds a UDP socket to `laddrs`, using the supplied network driver.
    ///
    /// All addresses produced by `laddrs` are resolved up front and handed to
    /// the driver in a single bind call.
    ///
    /// # Errors
    ///
    /// Returns an error if `laddrs` cannot be resolved, or whatever error the
    /// driver reports while binding.
    pub async fn bind_with<A: ToSocketAddrs>(
        laddrs: A,
        syscall: &'static dyn Network,
    ) -> io::Result<Self> {
        let candidates: Vec<SocketAddr> = laddrs.to_socket_addrs()?.collect();

        let sys_socket = cancelable_would_block(|cx| {
            syscall.udp_bind(cx.waker().clone(), &candidates)
        })
        .await?;

        Ok(Self { sys_socket, syscall })
    }

    /// Binds a UDP socket to `laddrs`, using the global network driver.
    ///
    /// # Errors
    ///
    /// Same as [`UdpSocket::bind_with`].
    pub async fn bind<A: ToSocketAddrs>(laddrs: A) -> io::Result<Self> {
        Self::bind_with(laddrs, global_network()).await
    }

    /// Returns the local address this socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.syscall.udp_local_addr(&self.sys_socket)
    }

    /// Sends `buf` as a datagram to `addr` and returns the number of bytes sent.
    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> io::Result<usize> {
        let socket = &self.sys_socket;
        let driver = self.syscall;
        cancelable_would_block(|cx| driver.udp_send_to(cx.waker().clone(), socket, buf, addr))
            .await
    }

    /// Receives one datagram into `buf`.
    ///
    /// Returns the number of bytes read and the sender's address.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let socket = &self.sys_socket;
        let driver = self.syscall;
        cancelable_would_block(|cx| driver.udp_recv_from(cx.waker().clone(), socket, buf))
            .await
    }

    /// Returns whether `SO_BROADCAST` is enabled on this socket.
    pub fn broadcast(&self) -> io::Result<bool> {
        self.syscall.udp_broadcast(&self.sys_socket)
    }

    /// Enables or disables `SO_BROADCAST` on this socket.
    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
        self.syscall.udp_set_broadcast(&self.sys_socket, on)
    }

    /// Returns the IP time-to-live value of this socket.
    pub fn ttl(&self) -> io::Result<u32> {
        self.syscall.udp_ttl(&self.sys_socket)
    }

    /// Sets the IP time-to-live value of this socket.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.syscall.udp_set_ttl(&self.sys_socket, ttl)
    }

    /// Joins the IPv4 multicast group `multiaddr` on the interface `interface`.
    pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
        self.syscall
            .udp_join_multicast_v4(&self.sys_socket, multiaddr, interface)
    }

    /// Joins the IPv6 multicast group `multiaddr` on the interface with index `interface`.
    ///
    /// Note: unlike `leave_multicast_v6`, this takes `multiaddr` by value.
    pub fn join_multicast_v6(&self, multiaddr: Ipv6Addr, interface: u32) -> io::Result<()> {
        let group = &multiaddr;
        self.syscall
            .udp_join_multicast_v6(&self.sys_socket, group, interface)
    }

    /// Leaves the IPv4 multicast group `multiaddr` on the interface `interface`.
    pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
        self.syscall
            .udp_leave_multicast_v4(&self.sys_socket, multiaddr, interface)
    }

    /// Leaves the IPv6 multicast group `multiaddr` on the interface with index `interface`.
    pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
        self.syscall
            .udp_leave_multicast_v6(&self.sys_socket, multiaddr, interface)
    }
}
#[cfg(all(unix, feature = "unix_socket"))]
mod unix {
    //! Unix domain socket wrappers built on the same `rasi_syscall::Network`
    //! driver abstraction as the TCP/UDP types.

    use std::{fmt::Debug, io, os::unix::net::SocketAddr, path::Path, task::Poll};

    use futures::{AsyncRead, AsyncWrite};
    use rasi_syscall::{global_network, Handle, Network};

    use crate::utils::cancelable_would_block;

    /// A Unix domain socket server that accepts incoming stream connections.
    pub struct UnixListener {
        /// Driver-level handle of the listening socket.
        sys_socket: Handle,
        /// Network driver used to perform all operations on `sys_socket`.
        syscall: &'static dyn Network,
    }

    impl UnixListener {
        /// Binds a listener to the filesystem `path`, using the supplied network driver.
        ///
        /// # Errors
        ///
        /// Returns whatever error the driver reports while binding.
        pub async fn bind_with<P: AsRef<Path>>(
            path: P,
            syscall: &'static dyn Network,
        ) -> io::Result<Self> {
            let socket_path: &Path = path.as_ref();

            let sys_socket = cancelable_would_block(move |cx| {
                syscall.unix_listener_bind(cx.waker().clone(), socket_path)
            })
            .await?;

            Ok(Self { sys_socket, syscall })
        }

        /// Binds a listener to the filesystem `path`, using the global network driver.
        ///
        /// # Errors
        ///
        /// Same as [`UnixListener::bind_with`].
        pub async fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            Self::bind_with(path, global_network()).await
        }

        /// Waits for the next inbound connection.
        ///
        /// Returns the connected stream, which uses the same driver as this
        /// listener, together with the peer address reported by the driver.
        pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
            let syscall = self.syscall;

            let accepted = cancelable_would_block(|cx| {
                syscall.unix_listener_accept(cx.waker().clone(), &self.sys_socket)
            })
            .await;

            accepted.map(|(sys_socket, raddr)| (UnixStream::new(sys_socket, syscall), raddr))
        }

        /// Returns the local address this listener is bound to.
        pub fn local_addr(&self) -> io::Result<SocketAddr> {
            self.syscall.unix_listener_local_addr(&self.sys_socket)
        }
    }

    /// A connected Unix domain stream implementing `futures`' `AsyncRead` / `AsyncWrite`.
    pub struct UnixStream {
        /// Driver-level handle of the connected socket.
        sys_socket: Handle,
        /// Network driver used to perform all operations on `sys_socket`.
        syscall: &'static dyn Network,
        /// Handle returned by the driver the last time a write reported `Pending`.
        write_cancel_handle: Option<Handle>,
        /// Handle returned by the driver the last time a read reported `Pending`.
        read_cancel_handle: Option<Handle>,
    }

    impl Debug for UnixStream {
        /// Shows the socket handle and the address of the network driver.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let driver_ptr = self.syscall as *const dyn Network;
            let driver_repr = format!("{:?}", driver_ptr);

            let mut out = f.debug_struct("UnixStream");
            out.field("socket", &self.sys_socket);
            out.field("syscall", &driver_repr);
            out.finish()
        }
    }

    impl UnixStream {
        /// Wraps an already-connected driver handle; no operation is pending yet.
        fn new(sys_socket: Handle, syscall: &'static dyn Network) -> Self {
            UnixStream {
                sys_socket,
                syscall,
                write_cancel_handle: None,
                read_cancel_handle: None,
            }
        }

        /// Connects to the socket at `path`, using the supplied network driver.
        ///
        /// # Errors
        ///
        /// Returns whatever error the driver reports while connecting.
        pub async fn connect_with<P: AsRef<Path>>(
            path: P,
            syscall: &'static dyn Network,
        ) -> io::Result<Self> {
            let socket_path: &Path = path.as_ref();

            let connected = cancelable_would_block(|cx| {
                syscall.unix_stream_connect(cx.waker().clone(), socket_path)
            })
            .await;

            connected.map(|sys_socket| Self::new(sys_socket, syscall))
        }

        /// Connects to the socket at `path`, using the global network driver.
        ///
        /// # Errors
        ///
        /// Same as [`UnixStream::connect_with`].
        pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            Self::connect_with(path, global_network()).await
        }

        /// Returns the local address of this connection.
        pub fn local_addr(&self) -> io::Result<SocketAddr> {
            self.syscall.unix_stream_local_addr(&self.sys_socket)
        }

        /// Returns the peer address of this connection.
        pub fn peer_addr(&self) -> io::Result<SocketAddr> {
            self.syscall.unix_stream_peer_addr(&self.sys_socket)
        }
    }

    impl AsyncRead for UnixStream {
        /// Asks the driver for a read into `buf`.
        ///
        /// When the driver cannot complete the read yet, its cancel handle is
        /// kept in `read_cancel_handle` and `Pending` is returned.
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<io::Result<usize>> {
            let waker = cx.waker().clone();
            let outcome = self.syscall.unix_stream_read(waker, &self.sys_socket, buf);

            match outcome {
                rasi_syscall::CancelablePoll::Pending(cancel) => {
                    self.read_cancel_handle = Some(cancel);
                    Poll::Pending
                }
                rasi_syscall::CancelablePoll::Ready(result) => Poll::Ready(result),
            }
        }
    }

    impl AsyncWrite for UnixStream {
        /// Asks the driver to write `buf`.
        ///
        /// When the driver cannot complete the write yet, its cancel handle is
        /// kept in `write_cancel_handle` and `Pending` is returned.
        fn poll_write(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let waker = cx.waker().clone();
            let outcome = self.syscall.unix_stream_write(waker, &self.sys_socket, buf);

            match outcome {
                rasi_syscall::CancelablePoll::Pending(cancel) => {
                    self.write_cancel_handle = Some(cancel);
                    Poll::Pending
                }
                rasi_syscall::CancelablePoll::Ready(result) => Poll::Ready(result),
            }
        }

        /// Always ready: this wrapper holds no write buffer of its own.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Shuts down both directions of the connection and completes immediately.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            let how = std::net::Shutdown::Both;
            self.syscall.unix_stream_shutdown(&self.sys_socket, how)?;
            Poll::Ready(Ok(()))
        }
    }
}
#[cfg(all(unix, feature = "unix_socket"))]
pub use unix::*;