use std::{
future::Future,
io,
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Duration,
};
use virtual_mio::InterestHandler;
use virtual_net::{
NetworkError, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualTcpBoundSocket,
VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, net_error_into_io_err,
};
use wasmer_types::MemorySize;
use wasmer_wasix_types::wasi::{Addressfamily, Errno, Rights, SockProto, Sockoption, Socktype};
use crate::{VirtualTaskManager, net::net_error_into_wasi_err};
/// Distinguishes the parts of an HTTP exchange that a socket inode can expose.
/// NOTE(review): not referenced in the visible part of this file; the meaning is inferred
/// from the variant names only — confirm against its users.
#[derive(Debug)]
pub enum InodeHttpSocketType {
Request,
Response,
Headers,
}
/// Boxed, pinned future built by `InodeSocket::connect`; resolves to the connected TCP
/// socket or a WASI error.
type TcpConnectFuture<'a> =
Pin<Box<dyn Future<Output = Result<Box<dyn VirtualTcpSocket + Sync>, Errno>> + 'a>>;
/// Options recorded on a socket before it is turned into a live backend socket
/// (and, for bound TCP / remote sockets, while it waits to be connected).
#[derive(Debug)]
pub struct SocketProperties {
/// Address family; `bind` rejects addresses of the other IP version.
pub family: Addressfamily,
/// Socket type; selects TCP (stream) or UDP (datagram) when binding.
pub ty: Socktype,
/// Protocol the socket was created with.
pub pt: SockProto,
/// Passed to the backend when binding/listening.
pub only_v6: bool,
/// Passed to the backend when binding/listening.
pub reuse_port: bool,
/// Passed to the backend when binding/listening.
pub reuse_addr: bool,
/// Applied to the TCP stream after `connect` when set; `None` leaves it untouched.
pub no_delay: Option<bool>,
/// Applied to the TCP stream after `connect` when set; `None` leaves it untouched.
pub keep_alive: Option<bool>,
/// Applied to the TCP stream after `connect` when set; `None` leaves it untouched.
pub dont_route: Option<bool>,
/// Recorded send buffer size.
/// NOTE(review): not applied to the connected stream in the visible code — confirm.
pub send_buf_size: Option<usize>,
/// Recorded receive buffer size.
/// NOTE(review): not applied to the connected stream in the visible code — confirm.
pub recv_buf_size: Option<usize>,
/// Carried over to the `TcpStream` created by `connect`.
pub write_timeout: Option<Duration>,
/// Carried over to the `TcpStream` created by `connect`.
pub read_timeout: Option<Duration>,
/// Used by `listen` as its timeout and as the listener's accept timeout (30s if unset).
pub accept_timeout: Option<Duration>,
/// Timeout `connect` uses when the caller does not supply one (30s if unset).
pub connect_timeout: Option<Duration>,
/// Interest handler registered before a backend socket exists. It is moved onto the bound
/// or connected socket and is deliberately not copied by `snapshot_for_bound`.
pub handler: Option<Box<dyn InterestHandler + Send + Sync>>,
}
impl SocketProperties {
    /// Copies every recorded option into a new value for a bound socket.
    ///
    /// The interest handler is not cloneable, so the copy starts without one; the caller is
    /// responsible for moving the live handler over separately.
    fn snapshot_for_bound(&self) -> Self {
        // All fields other than `handler` are `Copy`, so functional record update copies them.
        Self {
            handler: None,
            ..*self
        }
    }
}
/// The concrete state a socket inode is currently in.
#[derive(Debug)]
pub enum InodeSocketKind {
/// Created but not yet bound or connected. `addr` is the local address recorded by `bind`
/// (cleared again if the backend bind fails).
PreSocket {
props: SocketProperties,
addr: Option<SocketAddr>,
},
/// ICMP socket.
Icmp(Box<dyn VirtualIcmpSocket + Sync>),
/// Raw socket.
Raw(Box<dyn VirtualRawSocket + Sync>),
/// Listening TCP socket with its accept timeout.
TcpListener {
socket: Box<dyn VirtualTcpListener + Sync>,
accept_timeout: Option<Duration>,
},
/// Connected TCP stream with its per-stream timeouts.
TcpStream {
socket: Box<dyn VirtualTcpSocket + Sync>,
write_timeout: Option<Duration>,
read_timeout: Option<Duration>,
},
/// TCP socket bound to a local address but neither listening nor connected yet; `props`
/// (including the interest handler moved from the pre-socket) are applied when it connects.
BoundTcp {
socket: Box<dyn VirtualTcpBoundSocket + Sync>,
props: SocketProperties,
},
/// Bound UDP socket; `peer` is the default destination recorded by `connect`.
UdpSocket {
socket: Box<dyn VirtualUdpSocket + Sync>,
peer: Option<SocketAddr>,
},
/// Socket whose addresses and options are tracked here; most operations on it (e.g.
/// `set_ttl`, `connect`) only update these recorded fields.
/// NOTE(review): presumably a socket managed outside this process — confirm.
RemoteSocket {
props: SocketProperties,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ttl: u32,
multicast_ttl: u32,
is_dead: bool,
},
}
/// Socket options understood by this module, obtained from the WASIX `Sockoption` via
/// `TryFrom` (unrecognised options are rejected during that conversion).
pub enum WasiSocketOption {
Noop,
ReusePort,
ReuseAddr,
NoDelay,
DontRoute,
OnlyV6,
Broadcast,
MulticastLoopV4,
MulticastLoopV6,
Promiscuous,
Listening,
LastError,
KeepAlive,
Linger,
OobInline,
RecvBufSize,
SendBufSize,
RecvLowat,
SendLowat,
RecvTimeout,
SendTimeout,
ConnectTimeout,
AcceptTimeout,
Ttl,
MulticastTtlV4,
Type,
Proto,
}
impl TryFrom<Sockoption> for WasiSocketOption {
type Error = Errno;
fn try_from(opt: Sockoption) -> Result<Self, Self::Error> {
use WasiSocketOption::*;
match opt {
Sockoption::Noop => Ok(Noop),
Sockoption::ReusePort => Ok(ReusePort),
Sockoption::ReuseAddr => Ok(ReuseAddr),
Sockoption::NoDelay => Ok(NoDelay),
Sockoption::DontRoute => Ok(DontRoute),
Sockoption::OnlyV6 => Ok(OnlyV6),
Sockoption::Broadcast => Ok(Broadcast),
Sockoption::MulticastLoopV4 => Ok(MulticastLoopV4),
Sockoption::MulticastLoopV6 => Ok(MulticastLoopV6),
Sockoption::Promiscuous => Ok(Promiscuous),
Sockoption::Listening => Ok(Listening),
Sockoption::LastError => Ok(LastError),
Sockoption::KeepAlive => Ok(KeepAlive),
Sockoption::Linger => Ok(Linger),
Sockoption::OobInline => Ok(OobInline),
Sockoption::RecvBufSize => Ok(RecvBufSize),
Sockoption::SendBufSize => Ok(SendBufSize),
Sockoption::RecvLowat => Ok(RecvLowat),
Sockoption::SendLowat => Ok(SendLowat),
Sockoption::RecvTimeout => Ok(RecvTimeout),
Sockoption::SendTimeout => Ok(SendTimeout),
Sockoption::ConnectTimeout => Ok(ConnectTimeout),
Sockoption::AcceptTimeout => Ok(AcceptTimeout),
Sockoption::Ttl => Ok(Ttl),
Sockoption::MulticastTtlV4 => Ok(MulticastTtlV4),
Sockoption::Type => Ok(Type),
Sockoption::Proto => Ok(Proto),
_ => Err(Errno::Inval),
}
}
}
/// Coarse connection state reported by `InodeSocket::status`.
#[derive(Debug)]
pub enum WasiSocketStatus {
Opening,
Opened,
Closed,
Failed,
}
/// Selects which timeout (or the linger duration) a get/set time-option call refers to.
/// Convertible to and from the journal's `SocketOptTimeType`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TimeType {
ReadTimeout,
WriteTimeout,
AcceptTimeout,
ConnectTimeout,
BindTimeout,
Linger,
}
impl From<TimeType> for wasmer_journal::SocketOptTimeType {
fn from(value: TimeType) -> Self {
match value {
TimeType::ReadTimeout => Self::ReadTimeout,
TimeType::WriteTimeout => Self::WriteTimeout,
TimeType::AcceptTimeout => Self::AcceptTimeout,
TimeType::ConnectTimeout => Self::ConnectTimeout,
TimeType::BindTimeout => Self::BindTimeout,
TimeType::Linger => Self::Linger,
}
}
}
impl From<wasmer_journal::SocketOptTimeType> for TimeType {
fn from(value: wasmer_journal::SocketOptTimeType) -> Self {
use wasmer_journal::SocketOptTimeType;
match value {
SocketOptTimeType::ReadTimeout => TimeType::ReadTimeout,
SocketOptTimeType::WriteTimeout => TimeType::WriteTimeout,
SocketOptTimeType::AcceptTimeout => TimeType::AcceptTimeout,
SocketOptTimeType::ConnectTimeout => TimeType::ConnectTimeout,
SocketOptTimeType::BindTimeout => TimeType::BindTimeout,
SocketOptTimeType::Linger => TimeType::Linger,
}
}
}
/// Mutable socket state; always accessed through the `RwLock` in `InodeSocketInner`.
#[derive(Debug)]
pub(crate) struct InodeSocketProtected {
pub kind: InodeSocketKind,
}
/// Shared interior of an `InodeSocket`; all state access goes through `protected`.
#[derive(Debug)]
pub(crate) struct InodeSocketInner {
pub protected: RwLock<InodeSocketProtected>,
}
/// Handle to a socket inode. Clones share the same state through an `Arc`.
#[derive(Debug, Clone)]
pub struct InodeSocket {
pub(crate) inner: Arc<InodeSocketInner>,
}
impl InodeSocket {
/// Wraps `kind` in a fresh, lock-protected, reference-counted socket handle.
pub fn new(kind: InodeSocketKind) -> Self {
    let protected = RwLock::new(InodeSocketProtected { kind });
    let inner = Arc::new(InodeSocketInner { protected });
    Self { inner }
}
/// Polls read readiness by delegating to the protected socket state under the write lock.
pub fn poll_read_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
    self.inner.protected.write().unwrap().poll_read_ready(cx)
}
/// Polls write readiness by delegating to the protected socket state under the write lock.
pub fn poll_write_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
    self.inner.protected.write().unwrap().poll_write_ready(cx)
}
/// Implicitly binds an unbound datagram socket to the unspecified address of its family,
/// port 0.
///
/// Returns `Ok(None)` without doing anything unless the socket is a datagram pre-socket.
/// The bind is bounded by the `BindTimeout` option, defaulting to 30 seconds.
///
/// # Errors
/// `Errno::Notsup` for non-IP families; otherwise whatever the bind itself reports.
pub async fn auto_bind_udp(
    &self,
    tasks: &dyn VirtualTaskManager,
    net: &dyn VirtualNetworking,
) -> Result<Option<InodeSocket>, Errno> {
    let bind_timeout = match self.opt_time(TimeType::BindTimeout) {
        Ok(Some(configured)) => configured,
        _ => Duration::from_secs(30),
    };

    // Read the family under a short-lived lock; it is released before any `.await`.
    let family = {
        let guard = self.inner.protected.read().unwrap();
        match &guard.kind {
            InodeSocketKind::PreSocket { props, .. } if props.ty == Socktype::Dgram => {
                props.family
            }
            _ => return Ok(None),
        }
    };

    let unspecified = match family {
        Addressfamily::Inet4 => IpAddr::from(Ipv4Addr::UNSPECIFIED),
        Addressfamily::Inet6 => IpAddr::from(Ipv6Addr::UNSPECIFIED),
        _ => return Err(Errno::Notsup),
    };
    self.bind_internal(tasks, net, SocketAddr::new(unspecified, 0), bind_timeout)
        .await
}
/// Binds the socket to `set_addr`.
///
/// The backend bind is bounded by the `BindTimeout` option, falling back to 30 seconds.
/// In the visible `opt_time`, no socket kind stores a bind timeout, so this currently always
/// resolves to the fallback. On success a replacement socket object may be returned
/// (`Ok(Some(..))`); `Ok(None)` means the existing object was updated in place.
pub async fn bind(
    &self,
    tasks: &dyn VirtualTaskManager,
    net: &dyn VirtualNetworking,
    set_addr: SocketAddr,
) -> Result<Option<InodeSocket>, Errno> {
    const DEFAULT_BIND_TIMEOUT: Duration = Duration::from_secs(30);
    let bind_timeout = match self.opt_time(TimeType::BindTimeout) {
        Ok(Some(configured)) => configured,
        Ok(None) | Err(_) => DEFAULT_BIND_TIMEOUT,
    };
    self.bind_internal(tasks, net, set_addr, bind_timeout).await
}
async fn bind_internal(
&self,
tasks: &dyn VirtualTaskManager,
net: &dyn VirtualNetworking,
set_addr: SocketAddr,
timeout: Duration,
) -> Result<Option<InodeSocket>, Errno> {
enum PendingBind {
Tcp {
addr: SocketAddr,
props: SocketProperties,
},
Udp {
addr: SocketAddr,
reuse_port: bool,
reuse_addr: bool,
},
}
let bind = {
let mut inner = self.inner.protected.write().unwrap();
match &mut inner.kind {
InodeSocketKind::PreSocket { props, addr, .. } => {
match props.family {
Addressfamily::Inet4 => {
if !set_addr.is_ipv4() {
tracing::debug!(
"IP address is the wrong type IPv4 ({set_addr}) vs IPv6 family"
);
return Err(Errno::Inval);
}
}
Addressfamily::Inet6 => {
if !set_addr.is_ipv6() {
tracing::debug!(
"IP address is the wrong type IPv6 ({set_addr}) vs IPv4 family"
);
return Err(Errno::Inval);
}
}
_ => {
return Err(Errno::Notsup);
}
}
addr.replace(set_addr);
let addr = (*addr).unwrap();
match props.ty {
Socktype::Stream => PendingBind::Tcp {
addr,
props: props.snapshot_for_bound(),
},
Socktype::Dgram => PendingBind::Udp {
addr,
reuse_port: props.reuse_port,
reuse_addr: props.reuse_addr,
},
_ => return Err(Errno::Inval),
}
}
InodeSocketKind::RemoteSocket {
props,
local_addr: addr,
..
} => {
match props.family {
Addressfamily::Inet4 => {
if !set_addr.is_ipv4() {
tracing::debug!(
"IP address is the wrong type IPv4 ({set_addr}) vs IPv6 family"
);
return Err(Errno::Inval);
}
}
Addressfamily::Inet6 => {
if !set_addr.is_ipv6() {
tracing::debug!(
"IP address is the wrong type IPv6 ({set_addr}) vs IPv4 family"
);
return Err(Errno::Inval);
}
}
_ => {
return Err(Errno::Notsup);
}
}
*addr = set_addr;
let addr = *addr;
match props.ty {
Socktype::Stream => {
return Ok(None);
}
Socktype::Dgram => PendingBind::Udp {
addr,
reuse_port: props.reuse_port,
reuse_addr: props.reuse_addr,
},
_ => return Err(Errno::Inval),
}
}
InodeSocketKind::BoundTcp { .. } => return Err(Errno::Inval),
_ => return Err(Errno::Notsup),
}
};
match bind {
PendingBind::Tcp { addr, mut props } => {
tokio::select! {
socket = net.bind_tcp(
addr,
props.only_v6,
props.reuse_port,
props.reuse_addr,
) => {
match socket {
Ok(socket) => {
props.handler = {
let mut inner = self.inner.protected.write().unwrap();
match &mut inner.kind {
InodeSocketKind::PreSocket { props: live_props, .. } => {
live_props.handler.take()
}
_ => return Err(Errno::Inval),
}
};
Ok(Some(InodeSocket::new(InodeSocketKind::BoundTcp { socket, props })))
}
Err(NetworkError::Unsupported) => {
tracing::debug!(
"bind_tcp unsupported by networking backend; leaving socket unbound"
);
let mut inner = self.inner.protected.write().unwrap();
if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind {
addr.take();
}
Err(Errno::Notsup)
}
Err(err) => {
let mut inner = self.inner.protected.write().unwrap();
if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind {
addr.take();
}
Err(net_error_into_wasi_err(err))
}
}
},
_ = tasks.sleep_now(timeout) => {
let mut inner = self.inner.protected.write().unwrap();
if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind {
addr.take();
}
Err(Errno::Timedout)
}
}
}
PendingBind::Udp {
addr,
reuse_port,
reuse_addr,
} => {
tokio::select! {
socket = net.bind_udp(addr, reuse_port, reuse_addr) => {
match socket {
Ok(socket) => Ok(Some(InodeSocket::new(InodeSocketKind::UdpSocket {
socket,
peer: None,
}))),
Err(err) => {
let mut inner = self.inner.protected.write().unwrap();
if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind {
addr.take();
}
Err(net_error_into_wasi_err(err))
}
}
},
_ = tasks.sleep_now(timeout) => {
let mut inner = self.inner.protected.write().unwrap();
if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind {
addr.take();
}
Err(Errno::Timedout)
}
}
}
}
}
/// Turns this socket into a TCP listener.
///
/// On success, returns `Ok(Some(listener))` with a new `TcpListener` socket for the caller
/// to install. Pre-sockets must already have a local address recorded (by `bind`). Bound TCP
/// sockets are switched to listening synchronously through the backend.
///
/// The `AcceptTimeout` option (30 seconds if unavailable) serves two roles:
/// * it bounds how long the backend `listen_tcp` call may take;
/// * it becomes the new listener's `accept_timeout`.
///
/// `_backlog` is currently ignored.
///
/// # Errors
/// * `Errno::Inval` if a pre-socket has no local address.
/// * `Errno::Notsup` for non-stream sockets and for socket kinds that cannot listen
///   (including one that is already listening).
/// * `Errno::Timedout` if `listen_tcp` does not complete within the timeout.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub async fn listen(
&self,
tasks: &dyn VirtualTaskManager,
net: &dyn VirtualNetworking,
_backlog: usize,
) -> Result<Option<InodeSocket>, Errno> {
let timeout = self
.opt_time(TimeType::AcceptTimeout)
.ok()
.flatten()
.unwrap_or(Duration::from_secs(30));
// Build the backend listen future under the lock; it is awaited only after the lock
// has been released (see the `drop(inner)` calls below).
let socket = {
let mut inner = self.inner.protected.write().unwrap();
match &mut inner.kind {
InodeSocketKind::PreSocket { props, addr, .. } => match props.ty {
Socktype::Stream => {
if addr.is_none() {
tracing::warn!("wasi[?]::sock_listen - failed - address not set");
return Err(Errno::Inval);
}
let addr = *addr.as_ref().unwrap();
let only_v6 = props.only_v6;
let reuse_port = props.reuse_port;
let reuse_addr = props.reuse_addr;
// Release the lock before the backend future is awaited.
drop(inner);
net.listen_tcp(addr, only_v6, reuse_port, reuse_addr)
}
ty => {
tracing::warn!(
"wasi[?]::sock_listen - failed - not supported(pre-socket:{:?})",
ty
);
return Err(Errno::Notsup);
}
},
InodeSocketKind::RemoteSocket {
props,
local_addr: addr,
..
} => match props.ty {
Socktype::Stream => {
let addr = *addr;
let only_v6 = props.only_v6;
let reuse_port = props.reuse_port;
let reuse_addr = props.reuse_addr;
// Release the lock before the backend future is awaited.
drop(inner);
net.listen_tcp(addr, only_v6, reuse_port, reuse_addr)
}
ty => {
tracing::warn!(
"wasi[?]::sock_listen - failed - not supported(remote-socket:{:?})",
ty
);
return Err(Errno::Notsup);
}
},
// Already bound: switch to listening synchronously.
// NOTE(review): this path is not bounded by `timeout`.
InodeSocketKind::BoundTcp { socket, .. } => {
return Ok(Some(InodeSocket::new(InodeSocketKind::TcpListener {
socket: socket.listen().map_err(net_error_into_wasi_err)?,
accept_timeout: Some(timeout),
})));
}
InodeSocketKind::Icmp(_) => {
tracing::warn!("wasi[?]::sock_listen - failed - not supported(icmp)");
return Err(Errno::Notsup);
}
InodeSocketKind::Raw(_) => {
tracing::warn!("wasi[?]::sock_listen - failed - not supported(raw)");
return Err(Errno::Notsup);
}
InodeSocketKind::TcpListener { .. } => {
tracing::warn!(
"wasi[?]::sock_listen - failed - already listening (tcp-listener)"
);
return Err(Errno::Notsup);
}
InodeSocketKind::TcpStream { .. } => {
tracing::warn!("wasi[?]::sock_listen - failed - not supported(tcp-stream)");
return Err(Errno::Notsup);
}
InodeSocketKind::UdpSocket { .. } => {
tracing::warn!("wasi[?]::sock_listen - failed - not supported(udp-socket)");
return Err(Errno::Notsup);
}
}
};
// Race the backend listen against the timeout.
tokio::select! {
socket = socket => {
let socket = socket.map_err(net_error_into_wasi_err)?;
Ok(Some(InodeSocket::new(InodeSocketKind::TcpListener {
socket,
accept_timeout: Some(timeout),
})))
},
_ = tasks.sleep_now(timeout) => Err(Errno::Timedout)
}
}
/// Accepts the next incoming connection on a listening TCP socket.
///
/// Returns the accepted stream together with the peer address.
/// * With `nonblocking` set, a listener with no pending connection yields `Errno::Again`.
/// * Otherwise, on the first `WouldBlock` the task's waker is registered as the listener's
///   handler and the accept is retried; after that the future stays pending until woken.
/// * When `timeout` is `Some`, the wait is bounded and `Errno::Timedout` is returned on
///   expiry.
///
/// NOTE(review): the listener's own `accept_timeout` field is not read here; callers
/// presumably pass it in as `timeout` — confirm.
///
/// # Errors
/// * `Errno::Notconn` for a pre-socket; `Errno::Notsup` for any other non-listener kind.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub async fn accept(
&self,
tasks: &dyn VirtualTaskManager,
nonblocking: bool,
timeout: Option<Duration>,
) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr), Errno> {
// Future that repeatedly tries `try_accept` under the socket lock.
struct SocketAccepter<'a> {
sock: &'a InodeSocket,
nonblocking: bool,
// Set once a waker handler has been installed on the listener.
handler_registered: bool,
}
impl Drop for SocketAccepter<'_> {
// Unregister the waker handler installed by `poll`, if any.
fn drop(&mut self) {
if self.handler_registered {
let mut inner = self.sock.inner.protected.write().unwrap();
inner.remove_handler();
}
}
}
impl Future for SocketAccepter<'_> {
type Output = Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr), Errno>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
loop {
let mut inner = self.sock.inner.protected.write().unwrap();
return match &mut inner.kind {
InodeSocketKind::TcpListener { socket, .. } => match socket.try_accept() {
Ok((child, addr)) => Poll::Ready(Ok((child, addr))),
Err(NetworkError::WouldBlock) if self.nonblocking => {
Poll::Ready(Err(Errno::Again))
}
// First time nothing is pending: install the waker, release the lock and
// retry immediately in case a connection arrived before the handler was set.
// NOTE(review): only this first waker is registered; later polls with a
// different waker do not refresh it.
Err(NetworkError::WouldBlock) if !self.handler_registered => {
let res = socket.set_handler(cx.waker().into());
if let Err(err) = res {
return Poll::Ready(Err(net_error_into_wasi_err(err)));
}
drop(inner);
self.handler_registered = true;
continue;
}
Err(NetworkError::WouldBlock) => Poll::Pending,
Err(err) => Poll::Ready(Err(net_error_into_wasi_err(err))),
},
InodeSocketKind::PreSocket { .. } => Poll::Ready(Err(Errno::Notconn)),
_ => Poll::Ready(Err(Errno::Notsup)),
};
}
}
}
let acceptor = SocketAccepter {
sock: self,
nonblocking,
handler_registered: false,
};
if let Some(timeout) = timeout {
tokio::select! {
res = acceptor => res,
_ = tasks.sleep_now(timeout) => Err(Errno::Timedout)
}
} else {
acceptor.await
}
}
/// Closes the socket.
///
/// Only TCP streams have anything to shut down in the backend; every other live kind is a
/// no-op here.
///
/// # Errors
/// * `Errno::Notconn` for a pre-socket.
/// * Backend errors from closing a TCP stream, converted with `net_error_into_wasi_err`.
pub fn close(&self) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::PreSocket { .. } => Err(Errno::Notconn),
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.close().map_err(net_error_into_wasi_err)?;
            Ok(())
        }
        InodeSocketKind::TcpListener { .. }
        | InodeSocketKind::BoundTcp { .. }
        | InodeSocketKind::Icmp(_)
        | InodeSocketKind::UdpSocket { .. }
        | InodeSocketKind::Raw(_)
        | InodeSocketKind::RemoteSocket { .. } => Ok(()),
    }
}
/// Connects this socket to `peer`.
///
/// Behaviour depends on the socket kind:
/// * Stream pre-sockets and stream bound TCP sockets start a TCP connection. On success a new
///   `TcpStream` socket is returned (`Ok(Some(..))`). It carries over the read/write timeouts
///   and the interest handler stored in the socket's properties.
/// * A pre-socket without a local address connects from the unspecified address of `peer`'s
///   IP version, port 0.
/// * UDP sockets and remote sockets only record `peer` and return `Ok(None)`.
///
/// The connect is bounded by `timeout`, else the socket's `ConnectTimeout` option, else 30
/// seconds. Unless `nonblocking` is set, the wait also covers the new stream becoming
/// write-ready. `no_delay`, `keep_alive` and `dont_route` are applied when set; failures to
/// apply them are ignored.
///
/// NOTE(review): the interest handler is taken out of the properties before connecting. If
/// the connect fails or times out, the handler is dropped rather than restored.
///
/// # Errors
/// * `Errno::Inval` for datagram pre-sockets / bound sockets; `Errno::Notsup` for other
///   unsupported kinds or types.
/// * `Errno::Timedout` when the timeout expires.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub async fn connect(
&mut self,
tasks: &dyn VirtualTaskManager,
net: &dyn VirtualNetworking,
peer: SocketAddr,
timeout: Option<std::time::Duration>,
nonblocking: bool,
) -> Result<Option<InodeSocket>, Errno> {
let new_write_timeout;
let new_read_timeout;
let timeout = timeout
.or_else(|| self.opt_time(TimeType::ConnectTimeout).ok().flatten())
.unwrap_or(Duration::from_secs(30));
let handler;
// Build the connect future under the lock; it is awaited after the lock is released.
let connect: TcpConnectFuture<'_> = {
let mut inner = self.inner.protected.write().unwrap();
match &mut inner.kind {
InodeSocketKind::PreSocket { props, addr, .. } => {
handler = props.handler.take();
new_write_timeout = props.write_timeout;
new_read_timeout = props.read_timeout;
match props.ty {
Socktype::Stream => {
let no_delay = props.no_delay;
let keep_alive = props.keep_alive;
let dont_route = props.dont_route;
// Use the bound address if any, else an unspecified one matching the peer.
let addr = match addr {
Some(a) => *a,
None => {
let ip = match peer.is_ipv4() {
true => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
false => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
};
SocketAddr::new(ip, 0)
}
};
Box::pin(async move {
let mut ret = net
.connect_tcp(addr, peer)
.await
.map_err(net_error_into_wasi_err)?;
// Best-effort: option failures are ignored.
if let Some(no_delay) = no_delay {
ret.set_nodelay(no_delay).ok();
}
if let Some(keep_alive) = keep_alive {
ret.set_keepalive(keep_alive).ok();
}
if let Some(dont_route) = dont_route {
ret.set_dontroute(dont_route).ok();
}
if !nonblocking {
futures::future::poll_fn(|cx| ret.poll_write_ready(cx))
.await
.map_err(net_error_into_wasi_err)?;
}
Ok(ret)
})
}
Socktype::Dgram => return Err(Errno::Inval),
_ => return Err(Errno::Notsup),
}
}
InodeSocketKind::BoundTcp { socket, props } => {
handler = props.handler.take();
new_write_timeout = props.write_timeout;
new_read_timeout = props.read_timeout;
match props.ty {
Socktype::Stream => {
let no_delay = props.no_delay;
let keep_alive = props.keep_alive;
let dont_route = props.dont_route;
// The bound socket starts the connection synchronously (under the lock);
// only the write-readiness wait below is asynchronous.
let mut ret = socket.connect(peer).map_err(net_error_into_wasi_err)?;
// Best-effort: option failures are ignored.
if let Some(no_delay) = no_delay {
ret.set_nodelay(no_delay).ok();
}
if let Some(keep_alive) = keep_alive {
ret.set_keepalive(keep_alive).ok();
}
if let Some(dont_route) = dont_route {
ret.set_dontroute(dont_route).ok();
}
Box::pin(async move {
if !nonblocking {
futures::future::poll_fn(|cx| ret.poll_write_ready(cx))
.await
.map_err(net_error_into_wasi_err)?;
}
Ok(ret)
})
}
Socktype::Dgram => return Err(Errno::Inval),
_ => return Err(Errno::Notsup),
}
}
// Connectionless: just remember the default destination.
InodeSocketKind::UdpSocket {
peer: target_peer, ..
} => {
target_peer.replace(peer);
return Ok(None);
}
InodeSocketKind::RemoteSocket { peer_addr, .. } => {
*peer_addr = peer;
return Ok(None);
}
_ => return Err(Errno::Notsup),
}
};
let mut socket = tokio::select! {
res = connect => res?,
_ = tasks.sleep_now(timeout) => return Err(Errno::Timedout)
};
// Hand the previously stored interest handler to the connected stream.
if let Some(handler) = handler {
socket
.set_handler(handler)
.map_err(net_error_into_wasi_err)?;
}
let socket = InodeSocket::new(InodeSocketKind::TcpStream {
socket,
write_timeout: new_write_timeout,
read_timeout: new_read_timeout,
});
Ok(Some(socket))
}
pub fn status(&self) -> Result<WasiSocketStatus, Errno> {
let inner = self.inner.protected.read().unwrap();
Ok(match &inner.kind {
InodeSocketKind::PreSocket { .. } => WasiSocketStatus::Opening,
InodeSocketKind::BoundTcp { .. } => WasiSocketStatus::Opened,
InodeSocketKind::TcpListener { .. } => WasiSocketStatus::Opened,
InodeSocketKind::TcpStream { socket, .. } => match socket.status() {
Ok(virtual_net::SocketStatus::Opening) => WasiSocketStatus::Opening,
Ok(virtual_net::SocketStatus::Opened) => WasiSocketStatus::Opened,
Ok(virtual_net::SocketStatus::Closed) => WasiSocketStatus::Closed,
Ok(virtual_net::SocketStatus::Failed) => WasiSocketStatus::Failed,
Err(_) => WasiSocketStatus::Failed,
},
InodeSocketKind::UdpSocket { .. } => WasiSocketStatus::Opened,
InodeSocketKind::RemoteSocket { is_dead, .. } => match is_dead {
true => WasiSocketStatus::Closed,
false => WasiSocketStatus::Opened,
},
_ => WasiSocketStatus::Failed,
})
}
/// Returns the socket's pending error, or `Errno::Success` when none is known.
///
/// TCP streams query the backend. Dead remote sockets report `Errno::Connreset`. All other
/// kinds report success. `status()` is consulted first so any error it returns is
/// propagated.
pub fn last_error(&self) -> Result<Errno, Errno> {
    self.status()?;
    let guard = self.inner.protected.read().unwrap();
    if let InodeSocketKind::TcpStream { socket, .. } = &guard.kind {
        let pending = socket.last_error().map_err(net_error_into_wasi_err)?;
        return Ok(pending.map_or(Errno::Success, net_error_into_wasi_err));
    }
    if let InodeSocketKind::RemoteSocket { is_dead: true, .. } = &guard.kind {
        return Ok(Errno::Connreset);
    }
    Ok(Errno::Success)
}
/// Returns the socket's local address.
///
/// An unbound pre-socket reports the unspecified address of its family with port 0.
/// Live sockets ask the backend, and remote sockets return their recorded address.
///
/// # Errors
/// * `Errno::Inval` for an unbound pre-socket with a non-IP family.
/// * `Errno::Notsup` for raw sockets.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn addr_local(&self) -> Result<SocketAddr, Errno> {
    let guard = self.inner.protected.read().unwrap();
    let local = match &guard.kind {
        InodeSocketKind::PreSocket {
            addr: Some(bound), ..
        } => *bound,
        InodeSocketKind::PreSocket {
            props, addr: None, ..
        } => {
            let unspecified: IpAddr = match props.family {
                Addressfamily::Inet4 => Ipv4Addr::UNSPECIFIED.into(),
                Addressfamily::Inet6 => Ipv6Addr::UNSPECIFIED.into(),
                _ => return Err(Errno::Inval),
            };
            SocketAddr::new(unspecified, 0)
        }
        InodeSocketKind::Icmp(sock) => sock.addr_local().map_err(net_error_into_wasi_err)?,
        InodeSocketKind::TcpListener { socket, .. } => {
            socket.addr_local().map_err(net_error_into_wasi_err)?
        }
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.addr_local().map_err(net_error_into_wasi_err)?
        }
        InodeSocketKind::BoundTcp { socket, .. } => {
            socket.addr_local().map_err(net_error_into_wasi_err)?
        }
        InodeSocketKind::UdpSocket { socket, .. } => {
            socket.addr_local().map_err(net_error_into_wasi_err)?
        }
        InodeSocketKind::RemoteSocket { local_addr, .. } => *local_addr,
        InodeSocketKind::Raw(_) => return Err(Errno::Notsup),
    };
    Ok(local)
}
/// Returns the address of the connected peer.
///
/// * Pre-sockets report the unspecified address of their family with port 0.
/// * UDP sockets prefer the peer recorded by `connect`, then the backend's view.
/// * Remote sockets return their recorded peer.
///
/// # Errors
/// * `Errno::Notconn` for bound-but-unconnected TCP sockets and unconnected UDP sockets.
/// * `Errno::Inval` for pre-sockets with a non-IP family.
/// * `Errno::Notsup` for ICMP, raw and listening sockets.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn addr_peer(&self) -> Result<SocketAddr, Errno> {
    let guard = self.inner.protected.read().unwrap();
    match &guard.kind {
        InodeSocketKind::PreSocket { props, .. } => {
            let unspecified = match props.family {
                Addressfamily::Inet4 => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                Addressfamily::Inet6 => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
                _ => return Err(Errno::Inval),
            };
            Ok(SocketAddr::new(unspecified, 0))
        }
        InodeSocketKind::BoundTcp { .. } => Err(Errno::Notconn),
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.addr_peer().map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::UdpSocket { socket, peer } => match peer {
            Some(connected) => Ok(*connected),
            None => socket
                .addr_peer()
                .map_err(net_error_into_wasi_err)?
                .ok_or(Errno::Notconn),
        },
        InodeSocketKind::RemoteSocket { peer_addr, .. } => Ok(*peer_addr),
        InodeSocketKind::Icmp(_)
        | InodeSocketKind::Raw(_)
        | InodeSocketKind::TcpListener { .. } => Err(Errno::Notsup),
    }
}
pub fn set_opt_flag(&mut self, option: WasiSocketOption, val: bool) -> Result<(), Errno> {
let mut inner = self.inner.protected.write().unwrap();
match &mut inner.kind {
InodeSocketKind::PreSocket { props, .. }
| InodeSocketKind::BoundTcp { props, .. }
| InodeSocketKind::RemoteSocket { props, .. } => {
match option {
WasiSocketOption::OnlyV6 => props.only_v6 = val,
WasiSocketOption::ReusePort => props.reuse_port = val,
WasiSocketOption::ReuseAddr => props.reuse_addr = val,
WasiSocketOption::NoDelay => props.no_delay = Some(val),
WasiSocketOption::KeepAlive => props.keep_alive = Some(val),
WasiSocketOption::DontRoute => props.dont_route = Some(val),
_ => return Err(Errno::Inval),
};
}
InodeSocketKind::Raw(sock) => match option {
WasiSocketOption::Promiscuous => {
sock.set_promiscuous(val).map_err(net_error_into_wasi_err)?
}
_ => return Err(Errno::Inval),
},
InodeSocketKind::TcpStream { socket, .. } => match option {
WasiSocketOption::NoDelay => {
socket.set_nodelay(val).map_err(net_error_into_wasi_err)?
}
WasiSocketOption::KeepAlive => {
socket.set_keepalive(val).map_err(net_error_into_wasi_err)?
}
WasiSocketOption::DontRoute => {
socket.set_dontroute(val).map_err(net_error_into_wasi_err)?
}
_ => return Err(Errno::Inval),
},
InodeSocketKind::TcpListener { .. } => return Err(Errno::Inval),
InodeSocketKind::UdpSocket { socket, .. } => match option {
WasiSocketOption::Broadcast => {
socket.set_broadcast(val).map_err(net_error_into_wasi_err)?
}
WasiSocketOption::MulticastLoopV4 => socket
.set_multicast_loop_v4(val)
.map_err(net_error_into_wasi_err)?,
WasiSocketOption::MulticastLoopV6 => socket
.set_multicast_loop_v6(val)
.map_err(net_error_into_wasi_err)?,
_ => return Err(Errno::Inval),
},
_ => return Err(Errno::Notsup),
}
Ok(())
}
/// Reads a boolean socket option.
///
/// Sockets that are not yet connected (pre-socket, bound TCP, remote socket) read the value
/// from their recorded properties; options that were never set report `false`. Live sockets
/// query the backend.
///
/// # Errors
/// * `Errno::Inval` if `option` is not a boolean option for this kind of socket.
/// * `Errno::Notsup` for socket kinds that expose no boolean options here.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn get_opt_flag(&self, option: WasiSocketOption) -> Result<bool, Errno> {
    let mut inner = self.inner.protected.write().unwrap();
    Ok(match &mut inner.kind {
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => match option {
            WasiSocketOption::OnlyV6 => props.only_v6,
            WasiSocketOption::ReusePort => props.reuse_port,
            WasiSocketOption::ReuseAddr => props.reuse_addr,
            WasiSocketOption::NoDelay => props.no_delay.unwrap_or_default(),
            WasiSocketOption::KeepAlive => props.keep_alive.unwrap_or_default(),
            // `set_opt_flag` records DontRoute for these kinds, so it must be readable back.
            WasiSocketOption::DontRoute => props.dont_route.unwrap_or_default(),
            _ => return Err(Errno::Inval),
        },
        InodeSocketKind::Raw(sock) => match option {
            WasiSocketOption::Promiscuous => {
                sock.promiscuous().map_err(net_error_into_wasi_err)?
            }
            _ => return Err(Errno::Inval),
        },
        InodeSocketKind::TcpStream { socket, .. } => match option {
            WasiSocketOption::NoDelay => socket.nodelay().map_err(net_error_into_wasi_err)?,
            WasiSocketOption::KeepAlive => {
                socket.keepalive().map_err(net_error_into_wasi_err)?
            }
            WasiSocketOption::DontRoute => {
                socket.dontroute().map_err(net_error_into_wasi_err)?
            }
            _ => return Err(Errno::Inval),
        },
        InodeSocketKind::UdpSocket { socket, .. } => match option {
            WasiSocketOption::Broadcast => {
                socket.broadcast().map_err(net_error_into_wasi_err)?
            }
            WasiSocketOption::MulticastLoopV4 => socket
                .multicast_loop_v4()
                .map_err(net_error_into_wasi_err)?,
            WasiSocketOption::MulticastLoopV6 => socket
                .multicast_loop_v6()
                .map_err(net_error_into_wasi_err)?,
            _ => return Err(Errno::Inval),
        },
        _ => return Err(Errno::Notsup),
    })
}
/// Sets the send buffer size.
///
/// TCP streams forward the size to the backend. Not-yet-connected sockets (pre-socket,
/// bound TCP, remote socket) only record it. Other kinds return `Errno::Notsup`.
pub fn set_send_buf_size(&mut self, size: usize) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    if let InodeSocketKind::TcpStream { socket, .. } = &mut guard.kind {
        socket
            .set_send_buf_size(size)
            .map_err(net_error_into_wasi_err)?;
        return Ok(());
    }
    match &mut guard.kind {
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => {
            props.send_buf_size = Some(size);
            Ok(())
        }
        _ => Err(Errno::Notsup),
    }
}
/// Returns the send buffer size.
///
/// TCP streams ask the backend. Not-yet-connected sockets report the recorded size, or 0 if
/// none was set. Other kinds return `Errno::Notsup`.
pub fn send_buf_size(&self) -> Result<usize, Errno> {
    let guard = self.inner.protected.read().unwrap();
    match &guard.kind {
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.send_buf_size().map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => Ok(props.send_buf_size.unwrap_or(0)),
        InodeSocketKind::Icmp(_)
        | InodeSocketKind::Raw(_)
        | InodeSocketKind::TcpListener { .. }
        | InodeSocketKind::UdpSocket { .. } => Err(Errno::Notsup),
    }
}
/// Sets the receive buffer size.
///
/// TCP streams forward the size to the backend. Not-yet-connected sockets (pre-socket,
/// bound TCP, remote socket) only record it. Other kinds return `Errno::Notsup`.
pub fn set_recv_buf_size(&mut self, size: usize) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    if let InodeSocketKind::TcpStream { socket, .. } = &mut guard.kind {
        socket
            .set_recv_buf_size(size)
            .map_err(net_error_into_wasi_err)?;
        return Ok(());
    }
    match &mut guard.kind {
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => {
            props.recv_buf_size = Some(size);
            Ok(())
        }
        _ => Err(Errno::Notsup),
    }
}
/// Returns the receive buffer size.
///
/// TCP streams ask the backend. Not-yet-connected sockets report the recorded size, or 0 if
/// none was set. Other kinds return `Errno::Notsup`.
pub fn recv_buf_size(&self) -> Result<usize, Errno> {
    let guard = self.inner.protected.read().unwrap();
    match &guard.kind {
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.recv_buf_size().map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => Ok(props.recv_buf_size.unwrap_or(0)),
        InodeSocketKind::Icmp(_)
        | InodeSocketKind::Raw(_)
        | InodeSocketKind::TcpListener { .. }
        | InodeSocketKind::UdpSocket { .. } => Err(Errno::Notsup),
    }
}
/// Sets the linger duration (`None` disables lingering).
///
/// TCP streams forward the value to the backend, and remote sockets accept it as a no-op.
///
/// # Errors
/// * `Errno::Io` for pre-sockets and bound TCP sockets.
/// * `Errno::Notsup` for all other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn set_linger(&mut self, linger: Option<std::time::Duration>) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::TcpStream { socket, .. } => {
            socket.set_linger(linger).map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::RemoteSocket { .. } => Ok(()),
        InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => {
            Err(Errno::Io)
        }
        InodeSocketKind::Icmp(_)
        | InodeSocketKind::Raw(_)
        | InodeSocketKind::TcpListener { .. }
        | InodeSocketKind::UdpSocket { .. } => Err(Errno::Notsup),
    }
}
/// Returns the linger duration of a TCP stream.
///
/// # Errors
/// * `Errno::Io` for pre-sockets and bound TCP sockets.
/// * `Errno::Notsup` for every other non-stream kind.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn linger(&self) -> Result<Option<std::time::Duration>, Errno> {
    let guard = self.inner.protected.read().unwrap();
    let InodeSocketKind::TcpStream { socket, .. } = &guard.kind else {
        return Err(match &guard.kind {
            InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => Errno::Io,
            _ => Errno::Notsup,
        });
    };
    socket.linger().map_err(net_error_into_wasi_err)
}
/// Sets one of the socket's timeouts (`None` clears it).
///
/// Supported time types by kind:
/// * TCP streams: read and write timeouts.
/// * Listeners: the accept timeout.
/// * Not-yet-connected sockets (pre-socket, bound TCP, remote socket): the connect, accept,
///   read and write timeouts, recorded in their properties.
///
/// # Errors
/// * `Errno::Inval` if `ty` is not supported for this kind of socket. This now also covers
///   `BindTimeout`/`Linger` on properties-backed sockets, which previously reported `Io`;
///   the change makes them consistent with `opt_time` and the other arms.
/// * `Errno::Notsup` for socket kinds without configurable timeouts.
pub fn set_opt_time(
    &self,
    ty: TimeType,
    timeout: Option<std::time::Duration>,
) -> Result<(), Errno> {
    let mut inner = self.inner.protected.write().unwrap();
    match &mut inner.kind {
        InodeSocketKind::TcpStream {
            write_timeout,
            read_timeout,
            ..
        } => {
            match ty {
                TimeType::WriteTimeout => *write_timeout = timeout,
                TimeType::ReadTimeout => *read_timeout = timeout,
                _ => return Err(Errno::Inval),
            }
            Ok(())
        }
        InodeSocketKind::TcpListener { accept_timeout, .. } => {
            match ty {
                TimeType::AcceptTimeout => *accept_timeout = timeout,
                _ => return Err(Errno::Inval),
            }
            Ok(())
        }
        InodeSocketKind::PreSocket { props, .. }
        | InodeSocketKind::BoundTcp { props, .. }
        | InodeSocketKind::RemoteSocket { props, .. } => {
            match ty {
                TimeType::ConnectTimeout => props.connect_timeout = timeout,
                TimeType::AcceptTimeout => props.accept_timeout = timeout,
                TimeType::ReadTimeout => props.read_timeout = timeout,
                TimeType::WriteTimeout => props.write_timeout = timeout,
                // No backing field exists for these; reject them the same way `opt_time`
                // and the other arms of this function do.
                TimeType::BindTimeout | TimeType::Linger => return Err(Errno::Inval),
            }
            Ok(())
        }
        _ => Err(Errno::Notsup),
    }
}
/// Returns one of the socket's timeouts (`None` when unset).
///
/// Supported time types by kind:
/// * TCP streams: read/write timeouts.
/// * Listeners: the accept timeout.
/// * Not-yet-connected sockets: the recorded connect/accept/read/write timeouts.
///
/// # Errors
/// * `Errno::Inval` for time types that are not tracked for this kind of socket.
/// * `Errno::Notsup` for ICMP, raw and UDP sockets.
pub fn opt_time(&self, ty: TimeType) -> Result<Option<std::time::Duration>, Errno> {
    let guard = self.inner.protected.read().unwrap();
    match (&guard.kind, ty) {
        (InodeSocketKind::TcpStream { read_timeout, .. }, TimeType::ReadTimeout) => {
            Ok(*read_timeout)
        }
        (InodeSocketKind::TcpStream { write_timeout, .. }, TimeType::WriteTimeout) => {
            Ok(*write_timeout)
        }
        (InodeSocketKind::TcpStream { .. }, _) => Err(Errno::Inval),
        (InodeSocketKind::TcpListener { accept_timeout, .. }, TimeType::AcceptTimeout) => {
            Ok(*accept_timeout)
        }
        (InodeSocketKind::TcpListener { .. }, _) => Err(Errno::Inval),
        (
            InodeSocketKind::PreSocket { props, .. }
            | InodeSocketKind::BoundTcp { props, .. }
            | InodeSocketKind::RemoteSocket { props, .. },
            _,
        ) => match ty {
            TimeType::ConnectTimeout => Ok(props.connect_timeout),
            TimeType::AcceptTimeout => Ok(props.accept_timeout),
            TimeType::ReadTimeout => Ok(props.read_timeout),
            TimeType::WriteTimeout => Ok(props.write_timeout),
            TimeType::BindTimeout | TimeType::Linger => Err(Errno::Inval),
        },
        _ => Err(Errno::Notsup),
    }
}
/// Sets the IP time-to-live.
///
/// Bound TCP, TCP stream and UDP sockets forward the value to the backend. Remote sockets
/// only record it.
///
/// # Errors
/// * `Errno::Io` for pre-sockets.
/// * `Errno::Notsup` for other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn set_ttl(&self, ttl: u32) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    let outcome = match &mut guard.kind {
        InodeSocketKind::RemoteSocket { ttl: stored, .. } => {
            *stored = ttl;
            return Ok(());
        }
        InodeSocketKind::BoundTcp { socket, .. } => socket.set_ttl(ttl),
        InodeSocketKind::TcpStream { socket, .. } => socket.set_ttl(ttl),
        InodeSocketKind::UdpSocket { socket, .. } => socket.set_ttl(ttl),
        InodeSocketKind::PreSocket { .. } => return Err(Errno::Io),
        _ => return Err(Errno::Notsup),
    };
    outcome.map_err(net_error_into_wasi_err)
}
/// Returns the IP time-to-live.
///
/// Bound TCP, TCP stream and UDP sockets ask the backend. Remote sockets return their
/// recorded value.
///
/// # Errors
/// * `Errno::Io` for pre-sockets.
/// * `Errno::Notsup` for other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn ttl(&self) -> Result<u32, Errno> {
    let guard = self.inner.protected.read().unwrap();
    let queried = match &guard.kind {
        InodeSocketKind::RemoteSocket { ttl, .. } => return Ok(*ttl),
        InodeSocketKind::BoundTcp { socket, .. } => socket.ttl(),
        InodeSocketKind::TcpStream { socket, .. } => socket.ttl(),
        InodeSocketKind::UdpSocket { socket, .. } => socket.ttl(),
        InodeSocketKind::PreSocket { .. } => return Err(Errno::Io),
        _ => return Err(Errno::Notsup),
    };
    queried.map_err(net_error_into_wasi_err)
}
/// Sets the IPv4 multicast time-to-live.
///
/// UDP sockets forward the value to the backend. Remote sockets only record it.
///
/// # Errors
/// * `Errno::Io` for pre-sockets and bound TCP sockets.
/// * `Errno::Notsup` for other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn set_multicast_ttl_v4(&self, ttl: u32) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::RemoteSocket { multicast_ttl, .. } => {
            *multicast_ttl = ttl;
            Ok(())
        }
        InodeSocketKind::UdpSocket { socket, .. } => socket
            .set_multicast_ttl_v4(ttl)
            .map_err(net_error_into_wasi_err),
        InodeSocketKind::PreSocket { .. } | InodeSocketKind::BoundTcp { .. } => {
            Err(Errno::Io)
        }
        _ => Err(Errno::Notsup),
    }
}
/// Returns the IPv4 multicast time-to-live.
///
/// UDP sockets ask the backend. Remote sockets return their recorded value.
///
/// # Errors
/// * `Errno::Io` for pre-sockets and bound TCP sockets.
/// * `Errno::Notsup` for other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn multicast_ttl_v4(&self) -> Result<u32, Errno> {
    let guard = self.inner.protected.read().unwrap();
    match &guard.kind {
        InodeSocketKind::RemoteSocket { multicast_ttl, .. } => Ok(*multicast_ttl),
        InodeSocketKind::UdpSocket { socket, .. } => {
            socket.multicast_ttl_v4().map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::PreSocket { .. } | InodeSocketKind::BoundTcp { .. } => {
            Err(Errno::Io)
        }
        _ => Err(Errno::Notsup),
    }
}
/// Joins the IPv4 multicast group `multiaddr` on interface `iface`.
///
/// UDP sockets forward the request to the backend. Remote sockets accept it as a no-op.
///
/// # Errors
/// * `Errno::Io` for pre-sockets and bound TCP sockets.
/// * `Errno::Notsup` for other kinds.
/// * Backend errors converted with `net_error_into_wasi_err`.
pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::RemoteSocket { .. } => Ok(()),
        InodeSocketKind::UdpSocket { socket, .. } => socket
            .join_multicast_v4(multiaddr, iface)
            .map_err(net_error_into_wasi_err),
        InodeSocketKind::PreSocket { .. } | InodeSocketKind::BoundTcp { .. } => {
            Err(Errno::Io)
        }
        _ => Err(Errno::Notsup),
    }
}
/// Leaves the IPv4 multicast group `multiaddr` on interface `iface`.
///
/// Only UDP sockets perform a real leave. A `RemoteSocket` accepts the request
/// as a no-op. `BoundTcp` and `PreSocket` sockets yield `Errno::Io`, and every
/// other kind yields `Errno::Notsup`.
pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::UdpSocket { socket, .. } => {
            let left = socket.leave_multicast_v4(multiaddr, iface);
            left.map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::RemoteSocket { .. } => Ok(()),
        InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => Err(Errno::Io),
        _ => Err(Errno::Notsup),
    }
}
/// Joins the IPv6 multicast group `multiaddr` on the interface with index `iface`.
///
/// Only UDP sockets perform a real join. A `RemoteSocket` accepts the request
/// as a no-op. `BoundTcp` and `PreSocket` sockets yield `Errno::Io`, and every
/// other kind yields `Errno::Notsup`.
pub fn join_multicast_v6(&self, multiaddr: Ipv6Addr, iface: u32) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::UdpSocket { socket, .. } => {
            let joined = socket.join_multicast_v6(multiaddr, iface);
            joined.map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::RemoteSocket { .. } => Ok(()),
        InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => Err(Errno::Io),
        _ => Err(Errno::Notsup),
    }
}
/// Leaves the IPv6 multicast group `multiaddr` on the interface with index `iface`.
///
/// Only UDP sockets perform a real leave. A `RemoteSocket` accepts the request
/// as a no-op. `BoundTcp` and `PreSocket` sockets yield `Errno::Io`, and every
/// other kind yields `Errno::Notsup`.
///
/// NOTE(review): unlike the other multicast setters this takes `&mut self`,
/// although it only mutates state behind the interior lock. Kept as-is so
/// existing callers (which may bind a `mut` receiver) are unaffected.
pub fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    match &mut guard.kind {
        InodeSocketKind::UdpSocket { socket, .. } => {
            let left = socket.leave_multicast_v6(multiaddr, iface);
            left.map_err(net_error_into_wasi_err)
        }
        InodeSocketKind::RemoteSocket { .. } => Ok(()),
        InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => Err(Errno::Io),
        _ => Err(Errno::Notsup),
    }
}
/// Sends `buf` on a connected socket and returns the number of bytes accepted.
///
/// Socket kinds:
/// - Raw sockets and TCP streams send directly.
/// - A UDP socket sends to its connected peer. Without a peer, it fails with the
///   error mapped from `NetworkError::NotConnected`.
/// - A `PreSocket` yields `Errno::Notconn`.
/// - A `RemoteSocket` reports the whole buffer as sent, or `Errno::Connreset`
///   once it is dead.
/// - Any other kind yields `Errno::Notsup`.
///
/// When the send would block:
/// - A non-blocking call returns `Errno::Again`.
/// - Otherwise the task's waker is installed as the socket's handler and the
///   send is retried when woken.
/// - If `timeout` is set, waiting longer than that returns `Errno::Timedout`.
pub async fn send(
    &self,
    tasks: &dyn VirtualTaskManager,
    buf: &[u8],
    timeout: Option<Duration>,
    nonblocking: bool,
) -> Result<usize, Errno> {
    // Future that retries `try_send` on the shared socket state until it settles.
    struct SendFuture<'a, 'b> {
        inner: &'a InodeSocketInner,
        payload: &'b [u8],
        nonblocking: bool,
        // True once this future has installed its waker as the socket handler.
        waker_installed: bool,
    }

    impl Drop for SendFuture<'_, '_> {
        fn drop(&mut self) {
            // Only tear down a handler this future installed itself.
            if !self.waker_installed {
                return;
            }
            self.inner.protected.write().unwrap().remove_handler();
        }
    }

    impl Future for SendFuture<'_, '_> {
        type Output = Result<usize, Errno>;

        fn poll(
            mut self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Self::Output> {
            loop {
                let mut guard = self.inner.protected.write().unwrap();
                let attempt = match &mut guard.kind {
                    InodeSocketKind::Raw(socket) => socket.try_send(self.payload),
                    InodeSocketKind::TcpStream { socket, .. } => socket.try_send(self.payload),
                    InodeSocketKind::UdpSocket { socket, peer } => match peer {
                        Some(peer) => socket.try_send_to(self.payload, *peer),
                        None => Err(NetworkError::NotConnected),
                    },
                    InodeSocketKind::PreSocket { .. } => {
                        return Poll::Ready(Err(Errno::Notconn));
                    }
                    InodeSocketKind::RemoteSocket { is_dead, .. } => {
                        let outcome = if *is_dead {
                            Err(Errno::Connreset)
                        } else {
                            Ok(self.payload.len())
                        };
                        return Poll::Ready(outcome);
                    }
                    _ => return Poll::Ready(Err(Errno::Notsup)),
                };

                match attempt {
                    Ok(sent) => return Poll::Ready(Ok(sent)),
                    Err(NetworkError::WouldBlock) => {
                        if self.nonblocking {
                            return Poll::Ready(Err(Errno::Again));
                        }
                        if self.waker_installed {
                            return Poll::Pending;
                        }
                        // First time the send would block: install the waker,
                        // release the lock and make one more attempt before parking.
                        guard
                            .set_handler(cx.waker().into())
                            .map_err(net_error_into_wasi_err)?;
                        drop(guard);
                        self.waker_installed = true;
                    }
                    Err(err) => return Poll::Ready(Err(net_error_into_wasi_err(err))),
                }
            }
        }
    }

    let sender = SendFuture {
        inner: &self.inner,
        payload: buf,
        nonblocking,
        waker_installed: false,
    };
    match timeout {
        Some(limit) => {
            tokio::select! {
                res = sender => res,
                _ = tasks.sleep_now(limit) => Err(Errno::Timedout)
            }
        }
        None => sender.await,
    }
}
/// Sends `buf` to `addr` and returns the number of bytes accepted.
///
/// Socket kinds:
/// - ICMP and UDP sockets send to `addr`.
/// - A TCP stream ignores `addr` and sends on its existing connection.
/// - A `PreSocket` yields `Errno::Notconn`.
/// - A `RemoteSocket` reports the whole buffer as sent, or `Errno::Connreset`
///   once it is dead.
/// - Any other kind yields `Errno::Notsup`.
///
/// Blocking behaviour matches `send`:
/// - A non-blocking call returns `Errno::Again` instead of waiting.
/// - Otherwise the task's waker is installed on the first would-block.
/// - An expired `timeout` returns `Errno::Timedout`.
///
/// The `M` type parameter is not used by the body.
pub async fn send_to<M: MemorySize>(
    &self,
    tasks: &dyn VirtualTaskManager,
    buf: &[u8],
    addr: SocketAddr,
    timeout: Option<Duration>,
    nonblocking: bool,
) -> Result<usize, Errno> {
    // Future that retries the send on the shared socket state until it settles.
    struct SendToFuture<'a, 'b> {
        inner: &'a InodeSocketInner,
        payload: &'b [u8],
        target: SocketAddr,
        nonblocking: bool,
        // True once this future has installed its waker as the socket handler.
        waker_installed: bool,
    }

    impl Drop for SendToFuture<'_, '_> {
        fn drop(&mut self) {
            // Only tear down a handler this future installed itself.
            if !self.waker_installed {
                return;
            }
            self.inner.protected.write().unwrap().remove_handler();
        }
    }

    impl Future for SendToFuture<'_, '_> {
        type Output = Result<usize, Errno>;

        fn poll(
            mut self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Self::Output> {
            loop {
                let mut guard = self.inner.protected.write().unwrap();
                let attempt = match &mut guard.kind {
                    InodeSocketKind::Icmp(socket) => socket.try_send_to(self.payload, self.target),
                    InodeSocketKind::TcpStream { socket, .. } => socket.try_send(self.payload),
                    InodeSocketKind::UdpSocket { socket, .. } => {
                        socket.try_send_to(self.payload, self.target)
                    }
                    InodeSocketKind::PreSocket { .. } => {
                        return Poll::Ready(Err(Errno::Notconn));
                    }
                    InodeSocketKind::RemoteSocket { is_dead, .. } => {
                        let outcome = if *is_dead {
                            Err(Errno::Connreset)
                        } else {
                            Ok(self.payload.len())
                        };
                        return Poll::Ready(outcome);
                    }
                    _ => return Poll::Ready(Err(Errno::Notsup)),
                };

                match attempt {
                    Ok(sent) => return Poll::Ready(Ok(sent)),
                    Err(NetworkError::WouldBlock) => {
                        if self.nonblocking {
                            return Poll::Ready(Err(Errno::Again));
                        }
                        if self.waker_installed {
                            return Poll::Pending;
                        }
                        // First time the send would block: install the waker,
                        // release the lock and make one more attempt before parking.
                        guard
                            .set_handler(cx.waker().into())
                            .map_err(net_error_into_wasi_err)?;
                        self.waker_installed = true;
                        drop(guard);
                    }
                    Err(err) => return Poll::Ready(Err(net_error_into_wasi_err(err))),
                }
            }
        }
    }

    let sender = SendToFuture {
        inner: &self.inner,
        payload: buf,
        target: addr,
        nonblocking,
        waker_installed: false,
    };
    match timeout {
        Some(limit) => {
            tokio::select! {
                res = sender => res,
                _ = tasks.sleep_now(limit) => Err(Errno::Timedout)
            }
        }
        None => sender.await,
    }
}
/// Receives data from a connected socket into `buf` and returns the number of bytes read.
///
/// With `peek` set, the data is read without being consumed.
///
/// Socket kinds:
/// - Raw sockets and TCP streams read directly.
/// - A UDP socket with a connected peer only returns datagrams from that peer;
///   without a peer, it returns the next datagram from anyone.
/// - A dead `RemoteSocket` returns `Ok(0)`.
/// - A live `RemoteSocket` returns `Errno::Again` for non-blocking calls;
///   otherwise it stays pending.
/// - A `PreSocket` yields `Errno::Notconn`.
/// - Any other kind yields `Errno::Notsup`.
///
/// When the receive would block:
/// - A non-blocking call returns `Errno::Again`.
/// - Otherwise the task's waker is installed on the socket and the read is retried.
/// - An expired `timeout` returns `Errno::Timedout`.
pub async fn recv(
    &self,
    tasks: &dyn VirtualTaskManager,
    buf: &mut [MaybeUninit<u8>],
    timeout: Option<Duration>,
    nonblocking: bool,
    peek: bool,
) -> Result<usize, Errno> {
    // Future that retries the receive on the shared socket state until it settles.
    struct SocketReceiver<'a, 'b> {
        inner: &'a InodeSocketInner,
        data: &'b mut [MaybeUninit<u8>],
        nonblocking: bool,
        peek: bool,
        // True once this future has installed its waker as the socket handler.
        handler_registered: bool,
    }
    impl Drop for SocketReceiver<'_, '_> {
        fn drop(&mut self) {
            // Only tear down a handler this future installed itself.
            if self.handler_registered {
                let mut inner = self.inner.protected.write().unwrap();
                inner.remove_handler();
            }
        }
    }
    impl Future for SocketReceiver<'_, '_> {
        type Output = Result<usize, Errno>;
        fn poll(
            mut self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Self::Output> {
            loop {
                let peek = self.peek;
                let mut inner = self.inner.protected.write().unwrap();
                let res = match &mut inner.kind {
                    InodeSocketKind::Raw(socket) => socket.try_recv(self.data, peek),
                    InodeSocketKind::TcpStream { socket, .. } => {
                        socket.try_recv(self.data, peek)
                    }
                    InodeSocketKind::UdpSocket { socket, peer } => match peer {
                        Some(peer) => {
                            try_recv_from_connected_udp(socket.as_mut(), self.data, peek, peer)
                                .map(|(amt, _)| amt)
                        }
                        None => socket.try_recv_from(self.data, peek).map(|(amt, _)| amt),
                    },
                    InodeSocketKind::RemoteSocket { is_dead, .. } => {
                        return if *is_dead {
                            Poll::Ready(Ok(0))
                        } else if self.nonblocking {
                            // A live remote socket has no local data to deliver.
                            // A non-blocking caller must get EAGAIN instead of a
                            // future that can never resolve.
                            Poll::Ready(Err(Errno::Again))
                        } else {
                            // NOTE(review): no waker is registered here, so a
                            // blocking receive only completes via `timeout`.
                            Poll::Pending
                        };
                    }
                    InodeSocketKind::PreSocket { .. } => {
                        return Poll::Ready(Err(Errno::Notconn));
                    }
                    _ => return Poll::Ready(Err(Errno::Notsup)),
                };
                return match res {
                    Ok(amt) => Poll::Ready(Ok(amt)),
                    Err(NetworkError::WouldBlock) if self.nonblocking => {
                        Poll::Ready(Err(Errno::Again))
                    }
                    Err(NetworkError::WouldBlock) if !self.handler_registered => {
                        // Install the waker, release the lock and retry once so a
                        // wake-up racing with the registration is not missed.
                        inner
                            .set_handler(cx.waker().into())
                            .map_err(net_error_into_wasi_err)?;
                        self.handler_registered = true;
                        drop(inner);
                        continue;
                    }
                    Err(NetworkError::WouldBlock) => Poll::Pending,
                    Err(err) => Poll::Ready(Err(net_error_into_wasi_err(err))),
                };
            }
        }
    }
    let poller = SocketReceiver {
        inner: &self.inner,
        data: buf,
        nonblocking,
        peek,
        handler_registered: false,
    };
    if let Some(timeout) = timeout {
        tokio::select! {
            res = poller => res,
            _ = tasks.sleep_now(timeout) => Err(Errno::Timedout)
        }
    } else {
        poller.await
    }
}
/// Receives a datagram into `buf` and returns the byte count together with the sender's address.
///
/// With `peek` set, the datagram is read without being consumed.
///
/// Socket kinds:
/// - ICMP sockets and unconnected UDP sockets return the next datagram from anyone.
/// - A UDP socket with a connected peer only returns datagrams from that peer.
/// - A dead `RemoteSocket` returns `(0, peer_addr)`.
/// - A live `RemoteSocket` returns `Errno::Again` for non-blocking calls;
///   otherwise it stays pending.
/// - A `PreSocket` yields `Errno::Notconn`.
/// - Any other kind yields `Errno::Notsup`.
///
/// Blocking behaviour matches `recv`:
/// - A non-blocking call returns `Errno::Again`.
/// - Otherwise the task's waker is installed on the first would-block.
/// - An expired `timeout` returns `Errno::Timedout`.
pub async fn recv_from(
    &self,
    tasks: &dyn VirtualTaskManager,
    buf: &mut [MaybeUninit<u8>],
    timeout: Option<Duration>,
    nonblocking: bool,
    peek: bool,
) -> Result<(usize, SocketAddr), Errno> {
    // Future that retries the receive on the shared socket state until it settles.
    struct SocketReceiver<'a, 'b> {
        inner: &'a InodeSocketInner,
        data: &'b mut [MaybeUninit<u8>],
        nonblocking: bool,
        peek: bool,
        // True once this future has installed its waker as the socket handler.
        handler_registered: bool,
    }
    impl Drop for SocketReceiver<'_, '_> {
        fn drop(&mut self) {
            // Only tear down a handler this future installed itself.
            if self.handler_registered {
                let mut inner = self.inner.protected.write().unwrap();
                inner.remove_handler();
            }
        }
    }
    impl Future for SocketReceiver<'_, '_> {
        type Output = Result<(usize, SocketAddr), Errno>;
        fn poll(
            mut self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Self::Output> {
            loop {
                let peek = self.peek;
                // Take the lock per attempt (as `recv`/`send` do) so it is
                // released before retrying after the waker has been installed.
                let mut inner = self.inner.protected.write().unwrap();
                let res = match &mut inner.kind {
                    InodeSocketKind::Icmp(socket) => socket.try_recv_from(self.data, peek),
                    InodeSocketKind::UdpSocket { socket, peer } => match peer {
                        Some(peer) => {
                            try_recv_from_connected_udp(socket.as_mut(), self.data, peek, peer)
                        }
                        None => socket.try_recv_from(self.data, peek),
                    },
                    InodeSocketKind::RemoteSocket {
                        is_dead, peer_addr, ..
                    } => {
                        return if *is_dead {
                            Poll::Ready(Ok((0, *peer_addr)))
                        } else if self.nonblocking {
                            // A live remote socket has no local data to deliver.
                            // A non-blocking caller must get EAGAIN instead of a
                            // future that can never resolve.
                            Poll::Ready(Err(Errno::Again))
                        } else {
                            // NOTE(review): no waker is registered here, so a
                            // blocking receive only completes via `timeout`.
                            Poll::Pending
                        };
                    }
                    InodeSocketKind::PreSocket { .. } => {
                        return Poll::Ready(Err(Errno::Notconn));
                    }
                    _ => return Poll::Ready(Err(Errno::Notsup)),
                };
                return match res {
                    Ok((amt, addr)) => Poll::Ready(Ok((amt, addr))),
                    Err(NetworkError::WouldBlock) if self.nonblocking => {
                        Poll::Ready(Err(Errno::Again))
                    }
                    Err(NetworkError::WouldBlock) if !self.handler_registered => {
                        // Install the waker, release the lock and retry once so a
                        // wake-up racing with the registration is not missed.
                        inner
                            .set_handler(cx.waker().into())
                            .map_err(net_error_into_wasi_err)?;
                        self.handler_registered = true;
                        drop(inner);
                        continue;
                    }
                    Err(NetworkError::WouldBlock) => Poll::Pending,
                    Err(err) => Poll::Ready(Err(net_error_into_wasi_err(err))),
                };
            }
        }
    }
    let poller = SocketReceiver {
        inner: &self.inner,
        data: buf,
        nonblocking,
        peek,
        handler_registered: false,
    };
    if let Some(timeout) = timeout {
        tokio::select! {
            res = poller => res,
            _ = tasks.sleep_now(timeout) => Err(Errno::Timedout)
        }
    } else {
        poller.await
    }
}
/// Shuts down the read half, the write half, or both halves of a TCP stream.
///
/// A `RemoteSocket` accepts the request as a no-op. `BoundTcp` and `PreSocket`
/// sockets yield `Errno::Notconn`, and every other kind yields `Errno::Notsup`.
///
/// NOTE(review): takes `&mut self` although only the interior lock is mutated.
/// Kept as-is so existing callers (which may bind a `mut` receiver) are unaffected.
pub fn shutdown(&mut self, how: std::net::Shutdown) -> Result<(), Errno> {
    let mut guard = self.inner.protected.write().unwrap();
    let stream = match &mut guard.kind {
        InodeSocketKind::TcpStream { socket, .. } => socket,
        InodeSocketKind::RemoteSocket { .. } => return Ok(()),
        InodeSocketKind::BoundTcp { .. } | InodeSocketKind::PreSocket { .. } => {
            return Err(Errno::Notconn);
        }
        _ => return Err(Errno::Notsup),
    };
    stream.shutdown(how).map_err(net_error_into_wasi_err)
}
/// Reports whether the socket is currently in a writable state.
///
/// Returns `true` for TCP streams, bound TCP sockets, UDP sockets and raw
/// sockets, and for a `RemoteSocket` that is not dead. Every other kind
/// returns `false`.
///
/// This never waits for the lock. If the state is exclusively locked (e.g. by an
/// in-flight send), or the lock is poisoned, it answers `false`. Only a read
/// lock is needed, so concurrent readers such as `ttl()` no longer make it
/// spuriously report "not writable".
pub async fn can_write(&self) -> bool {
    let Ok(guard) = self.inner.protected.try_read() else {
        return false;
    };
    match &guard.kind {
        InodeSocketKind::RemoteSocket { is_dead, .. } => !*is_dead,
        kind => matches!(
            kind,
            InodeSocketKind::TcpStream { .. }
                | InodeSocketKind::BoundTcp { .. }
                | InodeSocketKind::UdpSocket { .. }
                | InodeSocketKind::Raw(..)
        ),
    }
}
/// Returns `true` when the socket is datagram-oriented.
///
/// A `RemoteSocket` answers from its recorded socket type. Otherwise only a
/// `UdpSocket` counts as a datagram socket.
pub fn is_dgram(&self) -> bool {
    let guard = self.inner.protected.read().unwrap();
    if let InodeSocketKind::RemoteSocket { props, .. } = &guard.kind {
        return props.ty == Socktype::Dgram;
    }
    matches!(guard.kind, InodeSocketKind::UdpSocket { .. })
}
}
/// Drops queued datagrams until the head of the queue comes from `peer`.
///
/// The head is peeked into a one-byte scratch buffer:
/// - If its sender is `peer`, or the queue is empty (`WouldBlock`), this returns `Ok(())`.
/// - Otherwise the datagram is consumed with a non-peeking read and the queue is
///   checked again.
///
/// A `WouldBlock` from that consuming read is ignored. Any other error is returned.
fn discard_non_matching_udp_datagrams(
    socket: &mut dyn VirtualUdpSocket,
    peer: &SocketAddr,
) -> Result<(), NetworkError> {
    let mut scratch = [MaybeUninit::<u8>::uninit()];
    loop {
        let sender = match socket.try_recv_from(&mut scratch, true) {
            Ok((_, from)) => from,
            Err(NetworkError::WouldBlock) => return Ok(()),
            Err(err) => return Err(err),
        };
        if sender == *peer {
            return Ok(());
        }
        if let Err(err) = socket.try_recv_from(&mut scratch, false) {
            if !matches!(err, NetworkError::WouldBlock) {
                return Err(err);
            }
        }
    }
}
/// Receives the next datagram from `peer` on a UDP socket that other hosts may still reach.
///
/// Foreign datagrams at the head of the queue are discarded first. Then the
/// socket is read into `data`, honouring `peek`, until a datagram from `peer`
/// is found.
///
/// Datagrams from any other sender are dropped:
/// - A non-peeking read has already consumed them.
/// - After a peek, they are consumed explicitly, ignoring a `WouldBlock` from
///   that read.
///
/// Errors from the socket, including `WouldBlock` on the main read, are returned.
fn try_recv_from_connected_udp(
    socket: &mut dyn VirtualUdpSocket,
    data: &mut [MaybeUninit<u8>],
    peek: bool,
    peer: &SocketAddr,
) -> Result<(usize, SocketAddr), NetworkError> {
    discard_non_matching_udp_datagrams(socket, peer)?;
    loop {
        let (amt, from) = socket.try_recv_from(data, peek)?;
        if from == *peer {
            return Ok((amt, from));
        }
        if !peek {
            // The foreign datagram was already consumed by the read above.
            continue;
        }
        // A peek leaves the foreign datagram queued, so consume it now.
        match socket.try_recv_from(data, false) {
            Ok(_) | Err(NetworkError::WouldBlock) => {}
            Err(err) => return Err(err),
        }
    }
}
// Readiness polling and interest-handler management for every socket kind.
impl InodeSocketProtected {
    /// Detaches any interest handler from the socket.
    ///
    /// Live virtual sockets are asked to drop their handler. The `PreSocket`,
    /// `BoundTcp` and `RemoteSocket` kinds keep their handler in
    /// `props.handler`, which is cleared here.
    pub fn remove_handler(&mut self) {
        match &mut self.kind {
            InodeSocketKind::TcpListener { socket, .. } => socket.remove_handler(),
            InodeSocketKind::TcpStream { socket, .. } => socket.remove_handler(),
            InodeSocketKind::UdpSocket { socket, .. } => socket.remove_handler(),
            InodeSocketKind::Raw(socket) => socket.remove_handler(),
            InodeSocketKind::Icmp(socket) => socket.remove_handler(),
            InodeSocketKind::PreSocket { props, .. } => {
                props.handler.take();
            }
            InodeSocketKind::BoundTcp { props, .. } => {
                props.handler.take();
            }
            InodeSocketKind::RemoteSocket { props, .. } => {
                props.handler.take();
            }
        }
    }
    /// Polls whether the socket has data ready to be read.
    ///
    /// Listener, stream, raw, ICMP and unconnected UDP sockets delegate to the
    /// underlying virtual socket. `BoundTcp` and `PreSocket` are never readable
    /// (`Pending`). A `RemoteSocket` is readable with `0` only once it is dead.
    /// Network errors are converted into `io::Error`.
    pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        match &mut self.kind {
            InodeSocketKind::TcpListener { socket, .. } => socket.poll_read_ready(cx),
            InodeSocketKind::TcpStream { socket, .. } => socket.poll_read_ready(cx),
            // A connected UDP socket is only "readable" when the head of its
            // queue comes from `peer`; datagrams from other senders are drained
            // first so they cannot cause spurious wake-ups.
            InodeSocketKind::UdpSocket {
                socket,
                peer: Some(peer),
            } => loop {
                if let Err(err) = discard_non_matching_udp_datagrams(socket.as_mut(), peer) {
                    break Poll::Ready(Err(err));
                }
                match socket.poll_read_ready(cx) {
                    Poll::Pending => break Poll::Pending,
                    Poll::Ready(Err(err)) => break Poll::Ready(Err(err)),
                    Poll::Ready(Ok(n)) => {
                        // `n` is whatever the underlying socket reported; it is
                        // not recomputed for the peer's datagram.
                        let mut peek_buf = [MaybeUninit::<u8>::uninit()];
                        match socket.try_recv_from(&mut peek_buf, true) {
                            Ok((_, addr)) if addr == *peer => break Poll::Ready(Ok(n)),
                            // A foreign datagram arrived meanwhile; loop so the
                            // drain step above removes it.
                            Ok(_) => continue,
                            // NOTE(review): this returns `Pending` right after the
                            // underlying socket reported `Ready`. Whether a waker is
                            // still registered at this point depends on that
                            // socket's `poll_read_ready` implementation; confirm a
                            // wake-up cannot be lost here.
                            Err(NetworkError::WouldBlock) => break Poll::Pending,
                            Err(err) => break Poll::Ready(Err(err)),
                        }
                    }
                }
            },
            InodeSocketKind::UdpSocket { socket, peer: None } => socket.poll_read_ready(cx),
            InodeSocketKind::Raw(socket) => socket.poll_read_ready(cx),
            InodeSocketKind::Icmp(socket) => socket.poll_read_ready(cx),
            InodeSocketKind::BoundTcp { .. } => Poll::Pending,
            InodeSocketKind::PreSocket { .. } => Poll::Pending,
            InodeSocketKind::RemoteSocket { is_dead, .. } => match is_dead {
                true => Poll::Ready(Ok(0)),
                false => Poll::Pending,
            },
        }
        .map_err(net_error_into_io_err)
    }
    /// Polls whether the socket can accept more outgoing data.
    ///
    /// Live virtual sockets delegate to the underlying socket. `BoundTcp` is
    /// reported as immediately ready with a value of `1`. `PreSocket` is never
    /// ready. A `RemoteSocket` is ready with `0` only once it is dead. Network
    /// errors are converted into `io::Error`.
    pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        match &mut self.kind {
            InodeSocketKind::TcpListener { socket, .. } => socket.poll_write_ready(cx),
            InodeSocketKind::TcpStream { socket, .. } => socket.poll_write_ready(cx),
            InodeSocketKind::UdpSocket { socket, .. } => socket.poll_write_ready(cx),
            InodeSocketKind::Raw(socket) => socket.poll_write_ready(cx),
            InodeSocketKind::Icmp(socket) => socket.poll_write_ready(cx),
            InodeSocketKind::BoundTcp { .. } => Poll::Ready(Ok(1)),
            InodeSocketKind::PreSocket { .. } => Poll::Pending,
            InodeSocketKind::RemoteSocket { is_dead, .. } => match is_dead {
                true => Poll::Ready(Ok(0)),
                false => Poll::Pending,
            },
        }
        .map_err(net_error_into_io_err)
    }
    /// Installs `handler` to be notified of socket interest events.
    ///
    /// Live virtual sockets receive the handler directly and may return an
    /// error. The `PreSocket`, `BoundTcp` and `RemoteSocket` kinds store it in
    /// `props.handler`, replacing any previous handler, and always succeed.
    pub fn set_handler(
        &mut self,
        handler: Box<dyn InterestHandler + Send + Sync>,
    ) -> virtual_net::Result<()> {
        match &mut self.kind {
            InodeSocketKind::TcpListener { socket, .. } => socket.set_handler(handler),
            InodeSocketKind::TcpStream { socket, .. } => socket.set_handler(handler),
            InodeSocketKind::UdpSocket { socket, .. } => socket.set_handler(handler),
            InodeSocketKind::Raw(socket) => socket.set_handler(handler),
            InodeSocketKind::Icmp(socket) => socket.set_handler(handler),
            InodeSocketKind::PreSocket { props, .. }
            | InodeSocketKind::BoundTcp { props, .. }
            | InodeSocketKind::RemoteSocket { props, .. } => {
                props.handler.replace(handler);
                Ok(())
            }
        }
    }
}
// NOTE(review): `dead_code` is allowed because this helper may have no callers
// in some build configurations; confirm before removing the attribute.
#[allow(dead_code)]
/// Returns every right a socket file descriptor can hold.
///
/// This covers the fd flag/stat, read/write and poll rights plus every `SOCK_*` right.
pub(crate) fn all_socket_rights() -> Rights {
    const SOCKET_RIGHTS: [Rights; 16] = [
        Rights::FD_FDSTAT_SET_FLAGS,
        Rights::FD_FILESTAT_GET,
        Rights::FD_READ,
        Rights::FD_WRITE,
        Rights::POLL_FD_READWRITE,
        Rights::SOCK_SHUTDOWN,
        Rights::SOCK_CONNECT,
        Rights::SOCK_LISTEN,
        Rights::SOCK_BIND,
        Rights::SOCK_ACCEPT,
        Rights::SOCK_RECV,
        Rights::SOCK_SEND,
        Rights::SOCK_ADDR_LOCAL,
        Rights::SOCK_ADDR_REMOTE,
        Rights::SOCK_RECV_FROM,
        Rights::SOCK_SEND_TO,
    ];
    SOCKET_RIGHTS
        .iter()
        .fold(Rights::empty(), |acc, right| acc.union(*right))
}
#[cfg(test)]
mod tests {
    use super::{InodeSocket, InodeSocketKind, SocketProperties, WasiSocketStatus};
    // `pending`, `VirtualNetworking` and `Errno` are only needed by the
    // `sys`-gated bind-timeout test. Gate them so builds without `sys` do not
    // warn about unused imports.
    #[cfg(feature = "sys")]
    use std::future::pending;
    use std::{
        mem::MaybeUninit,
        net::{Ipv4Addr, Shutdown, SocketAddr},
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll},
        time::Duration,
    };
    use virtual_mio::InterestHandler;
    #[cfg(feature = "sys")]
    use virtual_net::VirtualNetworking;
    use virtual_net::{
        NetworkError, Result as NetResult, SocketStatus, VirtualConnectedSocket, VirtualIoSource,
        VirtualSocket, VirtualTcpBoundSocket, VirtualTcpSocket,
    };
    #[cfg(feature = "sys")]
    use wasmer_wasix_types::wasi::Errno;
    use wasmer_wasix_types::wasi::{Addressfamily, SockProto, Socktype};

    /// Properties of a fresh IPv4 TCP stream socket with every option unset.
    fn tcp_v4_props() -> SocketProperties {
        SocketProperties {
            family: Addressfamily::Inet4,
            ty: Socktype::Stream,
            pt: SockProto::Tcp,
            only_v6: false,
            reuse_port: false,
            reuse_addr: false,
            no_delay: None,
            keep_alive: None,
            dont_route: None,
            send_buf_size: None,
            recv_buf_size: None,
            write_timeout: None,
            read_timeout: None,
            accept_timeout: None,
            connect_timeout: None,
            handler: None,
        }
    }

    /// TCP stream mock that counts readiness polls and exposes a shared status cell.
    #[derive(Debug)]
    struct MockTcpSocket {
        read_calls: Arc<AtomicUsize>,
        write_calls: Arc<AtomicUsize>,
        status: Arc<AtomicUsize>,
    }

    const MOCK_STATUS_OPENING: usize = 0;
    const MOCK_STATUS_OPENED: usize = 1;

    /// Maps the mock's numeric status cell onto a `SocketStatus`.
    fn decode_mock_status(value: usize) -> SocketStatus {
        match value {
            MOCK_STATUS_OPENED => SocketStatus::Opened,
            _ => SocketStatus::Opening,
        }
    }

    impl VirtualIoSource for MockTcpSocket {
        fn remove_handler(&mut self) {}
        fn poll_read_ready(&mut self, _cx: &mut Context<'_>) -> Poll<NetResult<usize>> {
            self.read_calls.fetch_add(1, Ordering::Relaxed);
            Poll::Ready(Ok(3))
        }
        fn poll_write_ready(&mut self, _cx: &mut Context<'_>) -> Poll<NetResult<usize>> {
            self.write_calls.fetch_add(1, Ordering::Relaxed);
            // Polling for writability also marks the mock connection as opened.
            self.status.store(MOCK_STATUS_OPENED, Ordering::Relaxed);
            Poll::Ready(Ok(7))
        }
    }

    impl VirtualSocket for MockTcpSocket {
        fn set_ttl(&mut self, _ttl: u32) -> NetResult<()> {
            Ok(())
        }
        fn ttl(&self) -> NetResult<u32> {
            Ok(64)
        }
        fn addr_local(&self) -> NetResult<SocketAddr> {
            Ok(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
        }
        fn status(&self) -> NetResult<SocketStatus> {
            Ok(decode_mock_status(self.status.load(Ordering::Relaxed)))
        }
        fn set_handler(
            &mut self,
            _handler: Box<dyn InterestHandler + Send + Sync>,
        ) -> NetResult<()> {
            Ok(())
        }
    }

    impl VirtualConnectedSocket for MockTcpSocket {
        fn set_linger(&mut self, _linger: Option<Duration>) -> NetResult<()> {
            Ok(())
        }
        fn linger(&self) -> NetResult<Option<Duration>> {
            Ok(None)
        }
        fn try_send(&mut self, _data: &[u8]) -> NetResult<usize> {
            Err(NetworkError::Unsupported)
        }
        fn try_flush(&mut self) -> NetResult<()> {
            Err(NetworkError::Unsupported)
        }
        fn close(&mut self) -> NetResult<()> {
            Ok(())
        }
        fn try_recv(&mut self, _buf: &mut [MaybeUninit<u8>], _peek: bool) -> NetResult<usize> {
            Err(NetworkError::Unsupported)
        }
    }

    impl VirtualTcpSocket for MockTcpSocket {
        fn set_recv_buf_size(&mut self, _size: usize) -> NetResult<()> {
            Ok(())
        }
        fn recv_buf_size(&self) -> NetResult<usize> {
            Ok(0)
        }
        fn set_send_buf_size(&mut self, _size: usize) -> NetResult<()> {
            Ok(())
        }
        fn send_buf_size(&self) -> NetResult<usize> {
            Ok(0)
        }
        fn set_nodelay(&mut self, _reuse: bool) -> NetResult<()> {
            Ok(())
        }
        fn nodelay(&self) -> NetResult<bool> {
            Ok(true)
        }
        fn set_keepalive(&mut self, _keepalive: bool) -> NetResult<()> {
            Ok(())
        }
        fn keepalive(&self) -> NetResult<bool> {
            Ok(false)
        }
        fn set_dontroute(&mut self, _keepalive: bool) -> NetResult<()> {
            Ok(())
        }
        fn dontroute(&self) -> NetResult<bool> {
            Ok(false)
        }
        fn addr_peer(&self) -> NetResult<SocketAddr> {
            Ok(SocketAddr::from((Ipv4Addr::LOCALHOST, 80)))
        }
        fn shutdown(&mut self, _how: Shutdown) -> NetResult<()> {
            Ok(())
        }
        fn is_closed(&self) -> bool {
            false
        }
    }

    /// Bound TCP socket mock whose TTL lives in a shared cell the test can inspect.
    #[derive(Debug)]
    struct MockTcpBoundSocket {
        ttl: Arc<AtomicUsize>,
    }

    impl VirtualTcpBoundSocket for MockTcpBoundSocket {
        fn addr_local(&self) -> NetResult<SocketAddr> {
            Ok(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
        }
        fn listen(&mut self) -> NetResult<Box<dyn virtual_net::VirtualTcpListener + Sync>> {
            Err(NetworkError::Unsupported)
        }
        fn connect(
            &mut self,
            _peer: SocketAddr,
        ) -> NetResult<Box<dyn virtual_net::VirtualTcpSocket + Sync>> {
            Err(NetworkError::Unsupported)
        }
        fn set_ttl(&mut self, ttl: u32) -> NetResult<()> {
            self.ttl.store(ttl as usize, Ordering::Relaxed);
            Ok(())
        }
        fn ttl(&self) -> NetResult<u32> {
            Ok(self.ttl.load(Ordering::Relaxed) as u32)
        }
    }

    /// Networking backend whose `bind_tcp` never completes, used to exercise bind timeouts.
    #[cfg(feature = "sys")]
    #[derive(Debug)]
    struct PendingBindNetworking;

    #[cfg(feature = "sys")]
    #[async_trait::async_trait]
    impl VirtualNetworking for PendingBindNetworking {
        async fn bind_tcp(
            &self,
            _addr: SocketAddr,
            _only_v6: bool,
            _reuse_port: bool,
            _reuse_addr: bool,
        ) -> NetResult<Box<dyn VirtualTcpBoundSocket + Sync>> {
            pending::<()>().await;
            unreachable!("pending bind_tcp future should never complete")
        }
    }

    #[test]
    fn inode_socket_poll_write_ready_uses_write_path() {
        let read_calls = Arc::new(AtomicUsize::new(0));
        let write_calls = Arc::new(AtomicUsize::new(0));
        let status = Arc::new(AtomicUsize::new(MOCK_STATUS_OPENED));
        let mut inode = InodeSocket::new(InodeSocketKind::TcpStream {
            socket: Box::new(MockTcpSocket {
                read_calls: read_calls.clone(),
                write_calls: write_calls.clone(),
                status,
            }),
            write_timeout: None,
            read_timeout: None,
        });
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let ready = Pin::new(&mut inode).poll_write_ready(&mut cx);
        assert!(matches!(ready, Poll::Ready(Ok(7))));
        assert_eq!(read_calls.load(Ordering::Relaxed), 0);
        assert_eq!(write_calls.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn inode_socket_status_tracks_tcp_socket_status() {
        let status = Arc::new(AtomicUsize::new(MOCK_STATUS_OPENING));
        let inode = InodeSocket::new(InodeSocketKind::TcpStream {
            socket: Box::new(MockTcpSocket {
                read_calls: Arc::new(AtomicUsize::new(0)),
                write_calls: Arc::new(AtomicUsize::new(0)),
                status: status.clone(),
            }),
            write_timeout: None,
            read_timeout: None,
        });
        assert!(matches!(inode.status().unwrap(), WasiSocketStatus::Opening));
        status.store(MOCK_STATUS_OPENED, Ordering::Relaxed);
        assert!(matches!(inode.status().unwrap(), WasiSocketStatus::Opened));
    }

    #[test]
    fn inode_socket_bound_tcp_forwards_ttl() {
        let ttl = Arc::new(AtomicUsize::new(64));
        let inode = InodeSocket::new(InodeSocketKind::BoundTcp {
            socket: Box::new(MockTcpBoundSocket { ttl: ttl.clone() }),
            props: tcp_v4_props(),
        });
        inode.set_ttl(42).unwrap();
        assert_eq!(inode.ttl().unwrap(), 42);
        assert_eq!(ttl.load(Ordering::Relaxed), 42);
    }

    #[cfg(feature = "sys")]
    #[tokio::test(flavor = "current_thread")]
    async fn inode_socket_tcp_bind_respects_bind_timeout() {
        let inode = InodeSocket::new(InodeSocketKind::PreSocket {
            props: tcp_v4_props(),
            addr: None,
        });
        let tasks = crate::runtime::task_manager::tokio::TokioTaskManager::default();
        let net = PendingBindNetworking;
        let err = inode
            .bind_internal(
                &tasks,
                &net,
                SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
                Duration::from_millis(10),
            )
            .await
            .unwrap_err();
        assert_eq!(err, Errno::Timedout);
    }
}