use core::future::Future;
use core::net::SocketAddr;
use std::sync::Arc;
use cap_net_ext::AddressFamily;
use io_lifetimes::AsSocketlike as _;
use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
use rustix::io::Errno;
use rustix::net::connect;
use tracing::debug;
use crate::p3::bindings::sockets::types::{ErrorCode, IpAddressFamily, IpSocketAddress};
use crate::runtime::with_ambient_tokio_runtime;
use crate::sockets::util::{
get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, receive_buffer_size,
send_buffer_size, set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit,
udp_bind, udp_disconnect, udp_socket,
};
use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddressFamily};
/// Bind/connect progression of a [`UdpSocket`].
///
/// The methods of `UdpSocket` consult this state to decide whether an
/// operation is currently permitted.
enum UdpState {
/// State assigned by `UdpSocket::new`; neither `bind` nor `connect` has
/// succeeded yet.
Default,
/// Entered after a successful `bind`, and after a successful disconnect
/// (either `disconnect` itself or the disconnect step of a re-`connect`).
Bound,
/// Entered after a successful `connect`; carries the address that was
/// passed to `connect`.
Connected(SocketAddr),
}
/// A UDP socket backed by a tokio socket, plus the bookkeeping needed to
/// validate operations against the socket's bind/connect state.
pub struct UdpSocket {
/// The underlying tokio socket. It is reference-counted so that the futures
/// returned by `send`, `send_to` and `receive` can hold their own handle
/// instead of borrowing `self`.
socket: Arc<tokio::net::UdpSocket>,
/// Current position in the bind/connect progression.
udp_state: UdpState,
/// Address family the socket was created with; used to validate addresses
/// and to select family-specific socket options.
family: SocketAddressFamily,
}
impl UdpSocket {
pub fn new(family: AddressFamily) -> std::io::Result<Self> {
let fd = udp_socket(family)?;
let socket_address_family = match family {
AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
AddressFamily::Ipv6 => {
rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
SocketAddressFamily::Ipv6
}
};
let socket = with_ambient_tokio_runtime(|| {
tokio::net::UdpSocket::try_from(unsafe {
std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
})
})?;
Ok(Self {
socket: Arc::new(socket),
udp_state: UdpState::Default,
family: socket_address_family,
})
}
pub fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
if !matches!(self.udp_state, UdpState::Default) {
return Err(ErrorCode::InvalidState);
}
if !is_valid_address_family(addr.ip(), self.family) {
return Err(ErrorCode::InvalidArgument);
}
udp_bind(&self.socket, addr)?;
self.udp_state = UdpState::Bound;
Ok(())
}
pub fn disconnect(&mut self) -> Result<(), ErrorCode> {
if !matches!(self.udp_state, UdpState::Connected(..)) {
return Err(ErrorCode::InvalidState);
}
udp_disconnect(&self.socket)?;
self.udp_state = UdpState::Bound;
Ok(())
}
pub fn connect(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
return Err(ErrorCode::InvalidArgument);
}
if let UdpState::Connected(..) = self.udp_state {
udp_disconnect(&self.socket)?;
self.udp_state = UdpState::Bound;
}
connect(&self.socket, &addr).map_err(|error| match error {
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
debug!("UDP connect returned EINPROGRESS, which should never happen");
ErrorCode::Unknown
}
err => err.into(),
})?;
self.udp_state = UdpState::Connected(addr);
Ok(())
}
pub fn send(&self, buf: Vec<u8>) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
let socket = if let UdpState::Connected(..) = self.udp_state {
Ok(Arc::clone(&self.socket))
} else {
Err(ErrorCode::InvalidArgument)
};
async move {
let socket = socket?;
send(&socket, &buf).await
}
}
pub fn send_to(
&self,
buf: Vec<u8>,
addr: SocketAddr,
) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
enum Mode {
Send(Arc<tokio::net::UdpSocket>),
SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
}
let socket = match &self.udp_state {
UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)),
UdpState::Connected(caddr) if addr == *caddr => {
Ok(Mode::Send(Arc::clone(&self.socket)))
}
UdpState::Connected(..) => Err(ErrorCode::InvalidArgument),
};
async move {
match socket? {
Mode::Send(socket) => send(&socket, &buf).await,
Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
}
}
}
pub fn receive(
&self,
) -> impl Future<Output = Result<(Vec<u8>, IpSocketAddress), ErrorCode>> + use<> {
enum Mode {
Recv(Arc<tokio::net::UdpSocket>, IpSocketAddress),
RecvFrom(Arc<tokio::net::UdpSocket>),
}
let socket = match self.udp_state {
UdpState::Default => Err(ErrorCode::InvalidState),
UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr.into())),
};
async move {
let socket = socket?;
let mut buf = vec![0; MAX_UDP_DATAGRAM_SIZE];
let (n, addr) = match socket {
Mode::Recv(socket, addr) => {
let n = socket.recv(&mut buf).await?;
(n, addr)
}
Mode::RecvFrom(socket) => {
let (n, addr) = socket.recv_from(&mut buf).await?;
(n, addr.into())
}
};
buf.truncate(n);
Ok((buf, addr))
}
}
pub fn local_address(&self) -> Result<IpSocketAddress, ErrorCode> {
if matches!(self.udp_state, UdpState::Default) {
return Err(ErrorCode::InvalidState);
}
let addr = self
.socket
.as_socketlike_view::<std::net::UdpSocket>()
.local_addr()?;
Ok(addr.into())
}
pub fn remote_address(&self) -> Result<IpSocketAddress, ErrorCode> {
if !matches!(self.udp_state, UdpState::Connected(..)) {
return Err(ErrorCode::InvalidState);
}
let addr = self
.socket
.as_socketlike_view::<std::net::UdpSocket>()
.peer_addr()?;
Ok(addr.into())
}
pub fn address_family(&self) -> IpAddressFamily {
match self.family {
SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
}
}
pub fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
let n = get_unicast_hop_limit(&self.socket, self.family)?;
Ok(n)
}
pub fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
set_unicast_hop_limit(&self.socket, self.family, value)?;
Ok(())
}
pub fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
let n = receive_buffer_size(&self.socket)?;
Ok(n)
}
pub fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
set_receive_buffer_size(&self.socket, value)?;
Ok(())
}
pub fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
let n = send_buffer_size(&self.socket)?;
Ok(n)
}
pub fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
set_send_buffer_size(&self.socket, value)?;
Ok(())
}
}
/// Sends `buf` as a single datagram on a connected socket.
///
/// # Errors
///
/// Returns the converted I/O error if the send fails, or `Unknown` if the
/// socket reports having sent a different number of bytes than `buf.len()`.
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    // A datagram is expected to go out whole; any other count is treated as
    // a failure.
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
/// Sends `buf` as a single datagram to `addr`.
///
/// # Errors
///
/// Returns the converted I/O error if the send fails, or `Unknown` if the
/// socket reports having sent a different number of bytes than `buf.len()`.
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    // A datagram is expected to go out whole; any other count is treated as
    // a failure.
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}