use core::future::Future;
use core::ops::Deref;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use wasmtime::component::{HasData, ResourceTable};
mod tcp;
mod udp;
pub(crate) mod util;
#[cfg(feature = "p3")]
pub(crate) use tcp::NonInheritedOptions;
pub use tcp::TcpSocket;
pub use udp::UdpSocket;
/// Marker type whose [`HasData`] implementation names
/// [`WasiSocketsCtxView`] as its associated data.
pub struct WasiSockets;
impl HasData for WasiSockets {
// The data borrowed for `'a` is a view over the sockets context and the
// resource table.
type Data<'a> = WasiSocketsCtxView<'a>;
}
// NOTE(review): presumably the pending-connection queue length applied when a
// TCP socket starts listening — confirm at the use site.
/// Default TCP backlog value (128).
pub(crate) const DEFAULT_TCP_BACKLOG: u32 = 128;
/// Upper bound used for the size of a UDP datagram: `u16::MAX` (65535) bytes.
pub(crate) const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize;
/// Socket-related context: the address permission check plus the flags for
/// which network uses are allowed.
///
/// `Default` is derived, so each field takes its own type's default.
#[derive(Clone, Default)]
pub struct WasiSocketsCtx {
/// Async predicate consulted to decide whether a socket address may be used.
pub(crate) socket_addr_check: SocketAddrCheck,
/// Flags for which network uses (IP name lookup, UDP, TCP) are allowed.
pub(crate) allowed_network_uses: AllowedNetworkUses,
}
/// Mutable borrows of a [`WasiSocketsCtx`] and a [`ResourceTable`], bundled
/// together for the lifetime `'a`.
pub struct WasiSocketsCtxView<'a> {
/// The sockets context.
pub ctx: &'a mut WasiSocketsCtx,
/// The resource table.
pub table: &'a mut ResourceTable,
}
/// Implemented by types that can hand out a [`WasiSocketsCtxView`].
///
/// Implementors must be `Send`.
pub trait WasiSocketsView: Send {
/// Returns a view whose borrows are tied to the mutable borrow of `self`.
fn sockets(&mut self) -> WasiSocketsCtxView<'_>;
}
/// Per-feature switches recording which kinds of network use are allowed.
#[derive(Copy, Clone)]
pub(crate) struct AllowedNetworkUses {
    /// Whether IP name lookup is allowed.
    pub(crate) ip_name_lookup: bool,
    /// Whether UDP is allowed; read by [`AllowedNetworkUses::check_allowed_udp`].
    pub(crate) udp: bool,
    /// Whether TCP is allowed; read by [`AllowedNetworkUses::check_allowed_tcp`].
    pub(crate) tcp: bool,
}

impl Default for AllowedNetworkUses {
    /// TCP and UDP start out allowed; IP name lookup starts out disallowed.
    fn default() -> Self {
        Self {
            tcp: true,
            udp: true,
            ip_name_lookup: false,
        }
    }
}

impl AllowedNetworkUses {
    /// Builds the `PermissionDenied` I/O error shared by the checks below.
    fn permission_denied(message: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::PermissionDenied, message)
    }

    /// Succeeds when UDP is allowed.
    ///
    /// # Errors
    ///
    /// Returns a `PermissionDenied` error when the `udp` flag is `false`.
    pub(crate) fn check_allowed_udp(&self) -> std::io::Result<()> {
        if self.udp {
            Ok(())
        } else {
            Err(Self::permission_denied("UDP is not allowed"))
        }
    }

    /// Succeeds when TCP is allowed.
    ///
    /// # Errors
    ///
    /// Returns a `PermissionDenied` error when the `tcp` flag is `false`.
    pub(crate) fn check_allowed_tcp(&self) -> std::io::Result<()> {
        if self.tcp {
            Ok(())
        } else {
            Err(Self::permission_denied("TCP is not allowed"))
        }
    }
}
/// Cloneable handle to an async predicate that decides whether a socket
/// address may be used for a given [`SocketAddrUse`].
///
/// The predicate is stored behind an [`Arc`], so clones share one closure.
#[derive(Clone)]
pub(crate) struct SocketAddrCheck(
    Arc<
        dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
            + Send
            + Sync,
    >,
);

impl SocketAddrCheck {
    /// Wraps `f` as the address check.
    ///
    /// `f` receives the address and the intended use, and returns a boxed
    /// future resolving to `true` (permit) or `false` (reject).
    pub(crate) fn new(
        f: impl Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
            + Send
            + Sync
            + 'static,
    ) -> Self {
        let shared = Arc::new(f);
        Self(shared)
    }

    /// Runs the predicate for `addr` and `reason` and maps its verdict to a
    /// `Result`.
    ///
    /// # Errors
    ///
    /// Returns a `PermissionDenied` I/O error when the predicate resolves to
    /// `false`.
    pub(crate) async fn check(
        &self,
        addr: SocketAddr,
        reason: SocketAddrUse,
    ) -> std::io::Result<()> {
        let permitted = (self.0)(addr, reason).await;
        if !permitted {
            return Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "An address was not permitted by the socket address check.",
            ));
        }
        Ok(())
    }
}

/// Exposes the wrapped predicate so the check can be invoked as a function.
impl Deref for SocketAddrCheck {
    type Target = dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
        + Send
        + Sync;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl Default for SocketAddrCheck {
    /// The default check rejects every address, whatever the use.
    fn default() -> Self {
        Self::new(|_, _| Box::pin(async { false }))
    }
}

/// The operation a socket address is being checked for; passed to the
/// [`SocketAddrCheck`] predicate alongside the address.
// NOTE(review): the per-variant descriptions are read off the variant names;
// the call sites that pick a variant are outside this part of the file.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrUse {
    /// TCP bind.
    TcpBind,
    /// TCP connect.
    TcpConnect,
    /// UDP bind.
    UdpBind,
    /// UDP connect.
    UdpConnect,
    /// UDP outgoing datagram.
    UdpOutgoingDatagram,
}
/// IP address family: IPv4 or IPv6.
#[derive(Copy, Clone, Eq, PartialEq)]
pub(crate) enum SocketAddressFamily {
/// IPv4.
Ipv4,
/// IPv6.
Ipv6,
}