use crate::runtime::with_ambient_tokio_runtime;
use crate::sockets::util::{
ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
};
use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
use cap_net_ext::AddressFamily;
use io_lifetimes::AsSocketlike as _;
use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
use rustix::io::Errno;
use rustix::net::connect;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::debug;
/// Lifecycle state of a UDP socket.
///
/// `bind` moves `Default` to `BindStarted`, and `finish_bind` then moves it to `Bound`.
/// A connect moves the socket to `Connected`, and `disconnect` returns it to `Bound`.
enum UdpState {
    /// Freshly created: neither bound nor connected.
    Default,

    /// `bind` succeeded; waiting for `finish_bind`.
    BindStarted,

    /// Bound to a local address with no fixed remote peer.
    Bound,

    /// Connected to the remote address carried in the payload.
    #[cfg_attr(not(feature = "p3"), expect(dead_code, reason = "p2 has its own way of managing sending/receiving"))]
    Connected(SocketAddr),
}
/// Host-side state backing a WASI UDP socket resource.
pub struct UdpSocket {
    /// The OS socket, registered with tokio. It is reference-counted so that in-flight
    /// send/receive futures can keep it alive independently of `self`.
    socket: Arc<tokio::net::UdpSocket>,
    /// Where the socket currently is in its bind/connect lifecycle.
    udp_state: UdpState,
    /// Address family chosen at construction time.
    family: SocketAddressFamily,
    /// Optional address-check hook. It is only stored and exposed through accessors here;
    /// no code in this type invokes it.
    socket_addr_check: Option<SocketAddrCheck>,
}
impl UdpSocket {
    /// Creates a new, unbound UDP socket of the given address family.
    ///
    /// IPv6 sockets are created with `IPV6_V6ONLY` enabled.
    ///
    /// # Errors
    ///
    /// Propagates failures from the `check_allowed_udp` permission check, from socket
    /// creation, from setting `IPV6_V6ONLY`, and from registering the socket with tokio.
    pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
        cx.allowed_network_uses.check_allowed_udp()?;
        let fd = udp_socket(family)?;

        let socket_address_family = match family {
            AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
            AddressFamily::Ipv6 => {
                // Restrict the socket to IPv6 traffic only.
                rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
                SocketAddressFamily::Ipv6
            }
        };

        let socket = with_ambient_tokio_runtime(|| {
            // SAFETY: `into_raw_socketlike` consumes `fd`, transferring ownership of the raw
            // handle. The new `std::net::UdpSocket` is therefore its sole owner and the handle
            // is neither closed twice nor used after close.
            let std_socket =
                unsafe { std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike()) };
            // NOTE(review): tokio expects the std socket to already be non-blocking. This
            // presumably relies on `udp_socket` creating it that way — confirm.
            tokio::net::UdpSocket::try_from(std_socket)
        })?;

        Ok(Self {
            socket: Arc::new(socket),
            udp_state: UdpState::Default,
            family: socket_address_family,
            socket_addr_check: None,
        })
    }

    /// Starts binding the socket to the local address `addr`.
    ///
    /// On success the socket enters `BindStarted`; call [`Self::finish_bind`] to complete it.
    ///
    /// # Errors
    ///
    /// - `InvalidState` if the socket has already been bound or connected.
    /// - `InvalidArgument` if `addr` does not match the socket's address family.
    /// - Any error reported by `udp_bind`.
    pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
        if !matches!(self.udp_state, UdpState::Default) {
            return Err(ErrorCode::InvalidState);
        }
        if !is_valid_address_family(addr.ip(), self.family) {
            return Err(ErrorCode::InvalidArgument);
        }
        udp_bind(&self.socket, addr)?;
        self.udp_state = UdpState::BindStarted;
        Ok(())
    }

    /// Completes a bind started by [`Self::bind`], moving the socket to `Bound`.
    ///
    /// # Errors
    ///
    /// `NotInProgress` if no bind is pending.
    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
        match self.udp_state {
            UdpState::BindStarted => {
                self.udp_state = UdpState::Bound;
                Ok(())
            }
            _ => Err(ErrorCode::NotInProgress),
        }
    }

    /// Returns `true` if the socket is connected to a remote peer.
    pub(crate) fn is_connected(&self) -> bool {
        matches!(self.udp_state, UdpState::Connected(..))
    }

    /// Returns `true` once the socket is bound (including when it is connected).
    pub(crate) fn is_bound(&self) -> bool {
        matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
    }

    /// Disconnects a connected socket, returning it to the `Bound` state.
    ///
    /// # Errors
    ///
    /// - `InvalidState` if the socket is not connected.
    /// - Any error reported by `udp_disconnect`.
    pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
        if !self.is_connected() {
            return Err(ErrorCode::InvalidState);
        }
        udp_disconnect(&self.socket)?;
        self.udp_state = UdpState::Bound;
        Ok(())
    }

    /// WASIp2 connect: the socket must already be bound or connected.
    ///
    /// # Errors
    ///
    /// - `InvalidState` if the socket is neither bound nor connected.
    /// - Otherwise, as for the shared connect logic.
    pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
        match self.udp_state {
            UdpState::Bound | UdpState::Connected(_) => {}
            _ => return Err(ErrorCode::InvalidState),
        }
        self.connect_common(addr)
    }

    /// WASIp3 connect: additionally permitted on a socket that was never explicitly bound.
    ///
    /// # Errors
    ///
    /// - `InvalidState` while a bind is in progress.
    /// - Otherwise, as for the shared connect logic.
    #[cfg(feature = "p3")]
    pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
        match self.udp_state {
            UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
            _ => return Err(ErrorCode::InvalidState),
        }
        self.connect_common(addr)
    }

    /// Shared (re)connect logic for both p2 and p3.
    ///
    /// Validates `addr`, disconnects first if already connected, then connects.
    fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
        if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
            return Err(ErrorCode::InvalidArgument);
        }

        // Disconnect as a separate step. If the new connect below fails, `udp_state`
        // (now `Bound`) still describes the socket accurately.
        if let UdpState::Connected(..) = self.udp_state {
            udp_disconnect(&self.socket)?;
            self.udp_state = UdpState::Bound;
        }

        connect(&self.socket, &addr).map_err(|error| match error {
            Errno::AFNOSUPPORT => ErrorCode::InvalidArgument,
            Errno::INPROGRESS => {
                debug!("UDP connect returned EINPROGRESS, which should never happen");
                ErrorCode::Unknown
            }
            err => err.into(),
        })?;
        self.udp_state = UdpState::Connected(addr);
        Ok(())
    }

    /// WASIp3 send of one datagram.
    ///
    /// State checks, address validation and any implicit bind happen eagerly, when this is
    /// called. The returned future only performs the actual transmission.
    ///
    /// - When connected, `addr` must be `None` or equal to the connected address.
    /// - When unconnected, `addr` is required and must be a valid remote address for this
    ///   socket's family.
    /// - A socket in the `Default` state is implicitly bound (to `implicit_bind_addr`) and
    ///   moves to `Bound`.
    ///
    /// # Errors
    ///
    /// - `InvalidState` while a bind is in progress.
    /// - `InvalidArgument` if `addr` is missing, invalid, or conflicts with the connected
    ///   peer.
    /// - Any error from the implicit bind or from the send itself.
    #[cfg(feature = "p3")]
    pub(crate) fn send_p3(
        &mut self,
        buf: Vec<u8>,
        addr: Option<SocketAddr>,
    ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
        enum Mode {
            Send(Arc<tokio::net::UdpSocket>),
            SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
        }

        let mut socket = match (&self.udp_state, addr) {
            (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
            (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
            (UdpState::Default | UdpState::Bound, Some(addr)) => {
                // Apply the same remote-address validation that `connect_common` uses.
                // Validating here, before the implicit bind below, means an invalid
                // destination never leaves the socket bound as a side effect.
                if is_valid_address_family(addr.ip(), self.family) && is_valid_remote_address(addr)
                {
                    Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
                } else {
                    Err(ErrorCode::InvalidArgument)
                }
            }
            (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
            (UdpState::Connected(caddr), Some(addr)) => {
                if addr == *caddr {
                    Ok(Mode::Send(Arc::clone(&self.socket)))
                } else {
                    Err(ErrorCode::InvalidArgument)
                }
            }
        };

        // Implicitly bind a never-bound socket so that later `local_address` and
        // `receive_p3` calls see it as `Bound`.
        if socket.is_ok() && matches!(self.udp_state, UdpState::Default) {
            let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
            match udp_bind(&self.socket, implicit_addr) {
                Ok(()) => self.udp_state = UdpState::Bound,
                Err(e) => socket = Err(e),
            }
        }

        async move {
            match socket? {
                Mode::Send(socket) => send(&socket, &buf).await,
                Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
            }
        }
    }

    /// WASIp3 receive of one datagram.
    ///
    /// Resolves to the payload together with its source address. When connected, the source
    /// address is the connected peer.
    ///
    /// # Errors
    ///
    /// - `InvalidState` (when the future is awaited) if the socket is not yet bound.
    /// - Any receive error.
    #[cfg(feature = "p3")]
    pub(crate) fn receive_p3(
        &self,
    ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
        enum Mode {
            Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
            RecvFrom(Arc<tokio::net::UdpSocket>),
        }

        let socket = match self.udp_state {
            UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
            UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
            UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
        };

        async move {
            let socket = socket?;
            let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
            let (n, addr) = match socket {
                Mode::Recv(socket, addr) => {
                    let n = socket.recv(&mut buf).await?;
                    (n, addr)
                }
                Mode::RecvFrom(socket) => socket.recv_from(&mut buf).await?,
            };
            // Keep only the bytes actually received.
            buf.truncate(n);
            Ok((buf, addr))
        }
    }

    /// Returns the local address the socket is bound to.
    ///
    /// # Errors
    ///
    /// `InvalidState` before a bind has completed; otherwise any OS error from querying the
    /// address.
    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
        if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
            return Err(ErrorCode::InvalidState);
        }
        let addr = self
            .socket
            .as_socketlike_view::<std::net::UdpSocket>()
            .local_addr()?;
        Ok(addr)
    }

    /// Returns the remote peer address, as reported by the OS.
    ///
    /// # Errors
    ///
    /// `InvalidState` if the socket is not connected; otherwise any OS error from querying
    /// the address.
    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
        if !matches!(self.udp_state, UdpState::Connected(..)) {
            return Err(ErrorCode::InvalidState);
        }
        let addr = self
            .socket
            .as_socketlike_view::<std::net::UdpSocket>()
            .peer_addr()?;
        Ok(addr)
    }

    /// Returns the address family fixed at construction.
    pub(crate) fn address_family(&self) -> SocketAddressFamily {
        self.family
    }

    /// Reads the unicast hop limit (TTL) for this socket's family.
    pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
        let n = get_unicast_hop_limit(&self.socket, self.family)?;
        Ok(n)
    }

    /// Sets the unicast hop limit (TTL) for this socket's family.
    pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
        set_unicast_hop_limit(&self.socket, self.family, value)?;
        Ok(())
    }

    /// Reads the receive buffer size.
    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
        let n = receive_buffer_size(&self.socket)?;
        Ok(n)
    }

    /// Sets the receive buffer size.
    pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
        set_receive_buffer_size(&self.socket, value)?;
        Ok(())
    }

    /// Reads the send buffer size.
    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
        let n = send_buffer_size(&self.socket)?;
        Ok(n)
    }

    /// Sets the send buffer size.
    pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
        set_send_buffer_size(&self.socket, value)?;
        Ok(())
    }

    /// Borrows the shared tokio socket.
    pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
        &self.socket
    }

    /// Returns the stored address-check hook, if any.
    pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
        self.socket_addr_check.as_ref()
    }

    /// Replaces the stored address-check hook.
    pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
        self.socket_addr_check = check;
    }
}
/// Sends `buf` as one datagram to the socket's connected peer.
///
/// A short write, where fewer bytes than `buf.len()` are sent, is reported as
/// `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
/// Sends `buf` as one datagram to the explicit destination `addr`.
///
/// A short write, where fewer bytes than `buf.len()` are sent, is reported as
/// `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    match socket.send_to(buf, addr).await? {
        sent if sent == buf.len() => Ok(()),
        _ => Err(ErrorCode::Unknown),
    }
}