use std::io;
use std::net::SocketAddr;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use mio::{Interest, Token};
use super::waker_to_ptr;
use crate::io::IoHandle;
/// Async UDP socket backed by a `mio` socket and driven by the runtime's `IoHandle`.
///
/// Registration with the `IoHandle` is lazy: constructors leave the socket
/// unregistered, and the first `poll_*` / async operation registers it.
pub struct UdpSocket {
/// Underlying mio socket; it owns the file descriptor (closed when this field drops).
inner: mio::net::UdpSocket,
/// Handle used to register, update the waker of, and deregister this socket.
io: IoHandle,
/// Token returned by `IoHandle::register`; `None` until the first registration.
token: Option<Token>,
/// Task identity (from `waker_to_ptr`) whose waker is currently registered, used to
/// skip redundant `set_waker` calls. Null while unregistered.
// NOTE(review): a raw-pointer field also makes the type `!Send`/`!Sync`; presumably
// intended for a single-threaded runtime — confirm.
registered_task: *mut u8,
}
impl UdpSocket {
    /// Binds a new UDP socket to `addr`.
    ///
    /// The socket is not registered with the `IoHandle` until the first
    /// `poll_*` or async operation; the `try_*` methods never register it.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the underlying bind.
    pub fn bind(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
        let inner = mio::net::UdpSocket::bind(addr)?;
        Ok(Self::from_mio(inner, io))
    }

    /// Wraps an already-open mio socket in a not-yet-registered `UdpSocket`.
    fn from_mio(inner: mio::net::UdpSocket, io: IoHandle) -> Self {
        Self {
            inner,
            io,
            token: None,
            registered_task: std::ptr::null_mut(),
        }
    }

    /// Returns the local address this socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.local_addr()
    }

    /// Connects the socket to `addr`, setting the peer used by `send`/`recv`
    /// and their `poll_*`/`try_*` variants.
    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
        self.inner.connect(addr)
    }

    /// Returns the address of the connected peer.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner.peer_addr()
    }

    /// Enables or disables `SO_BROADCAST`.
    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
        self.inner.set_broadcast(on)
    }

    /// Returns the current `SO_BROADCAST` setting.
    pub fn broadcast(&self) -> io::Result<bool> {
        self.inner.broadcast()
    }

    /// Enables or disables IPv4 multicast loopback.
    pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
        self.inner.set_multicast_loop_v4(on)
    }

    /// Returns the IPv4 multicast loopback setting.
    pub fn multicast_loop_v4(&self) -> io::Result<bool> {
        self.inner.multicast_loop_v4()
    }

    /// Sets the IPv4 multicast TTL.
    pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_multicast_ttl_v4(ttl)
    }

    /// Returns the IPv4 multicast TTL.
    pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
        self.inner.multicast_ttl_v4()
    }

    /// Sets the unicast TTL (`IP_TTL`).
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_ttl(ttl)
    }

    /// Returns the unicast TTL (`IP_TTL`).
    pub fn ttl(&self) -> io::Result<u32> {
        self.inner.ttl()
    }

    /// Joins the IPv4 multicast group `multiaddr` on the interface at `interface`.
    pub fn join_multicast_v4(
        &self,
        multiaddr: &std::net::Ipv4Addr,
        interface: &std::net::Ipv4Addr,
    ) -> io::Result<()> {
        self.inner.join_multicast_v4(multiaddr, interface)
    }

    /// Leaves the IPv4 multicast group `multiaddr` on the interface at `interface`.
    pub fn leave_multicast_v4(
        &self,
        multiaddr: &std::net::Ipv4Addr,
        interface: &std::net::Ipv4Addr,
    ) -> io::Result<()> {
        self.inner.leave_multicast_v4(multiaddr, interface)
    }

    /// Joins the IPv6 multicast group `multiaddr` on interface index `interface`.
    pub fn join_multicast_v6(
        &self,
        multiaddr: &std::net::Ipv6Addr,
        interface: u32,
    ) -> io::Result<()> {
        self.inner.join_multicast_v6(multiaddr, interface)
    }

    /// Leaves the IPv6 multicast group `multiaddr` on interface index `interface`.
    pub fn leave_multicast_v6(
        &self,
        multiaddr: &std::net::Ipv6Addr,
        interface: u32,
    ) -> io::Result<()> {
        self.inner.leave_multicast_v6(multiaddr, interface)
    }

    /// Takes and clears the pending socket error (`SO_ERROR`), if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        socket2::SockRef::from(&self.inner).take_error()
    }

    /// Returns the send buffer size (`SO_SNDBUF`).
    pub fn send_buffer_size(&self) -> io::Result<usize> {
        socket2::SockRef::from(&self.inner).send_buffer_size()
    }

    /// Sets the send buffer size (`SO_SNDBUF`).
    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
        socket2::SockRef::from(&self.inner).set_send_buffer_size(size)
    }

    /// Returns the receive buffer size (`SO_RCVBUF`).
    pub fn recv_buffer_size(&self) -> io::Result<usize> {
        socket2::SockRef::from(&self.inner).recv_buffer_size()
    }

    /// Sets the receive buffer size (`SO_RCVBUF`).
    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
        socket2::SockRef::from(&self.inner).set_recv_buffer_size(size)
    }

    /// Converts a standard-library UDP socket into a runtime-driven `UdpSocket`.
    ///
    /// The socket is switched to non-blocking mode first. mio requires
    /// non-blocking sockets; a blocking one would stall the event loop inside
    /// `recv`/`peek` instead of returning `WouldBlock`.
    ///
    /// # Errors
    ///
    /// Returns the error from `set_nonblocking` if the mode cannot be changed.
    pub fn from_std(socket: std::net::UdpSocket, io: IoHandle) -> io::Result<Self> {
        socket.set_nonblocking(true)?;
        Ok(Self::from_mio(mio::net::UdpSocket::from_std(socket), io))
    }

    /// Converts back into a standard-library UDP socket.
    ///
    /// If registered, the socket is first deregistered (best effort: a
    /// deregistration error is ignored). The returned socket is still in
    /// non-blocking mode; call `set_nonblocking(false)` if blocking I/O is wanted.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn into_std(mut self) -> io::Result<std::net::UdpSocket> {
        if let Some(token) = self.token.take() {
            // NOTE(review): the safety contract of `IoHandle::deregister` is defined
            // elsewhere; presumably `token` must be the one issued for `inner`.
            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
        }
        let fd = self.inner.as_raw_fd();
        // Suppress `Drop` for the whole struct (including `inner`'s destructor,
        // which would close `fd`) so ownership of `fd` can move to the std socket.
        std::mem::forget(self);
        // SAFETY: `fd` is an open socket previously owned by `inner`, whose
        // destructor was suppressed above, so the new std socket is its sole owner.
        Ok(unsafe { std::net::UdpSocket::from_raw_fd(fd) })
    }

    /// Attempts to send on a connected socket without registering or waiting.
    ///
    /// Returns an error of kind `WouldBlock` if the socket is not ready.
    pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> {
        self.inner.send(buf)
    }

    /// Attempts to receive on a connected socket without registering or waiting.
    ///
    /// Returns an error of kind `WouldBlock` if no datagram is available.
    pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.recv(buf)
    }

    /// Attempts to send to `target` without registering or waiting.
    ///
    /// Returns an error of kind `WouldBlock` if the socket is not ready.
    pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        self.inner.send_to(buf, target)
    }

    /// Attempts to receive a datagram and its source without registering or waiting.
    ///
    /// Returns an error of kind `WouldBlock` if no datagram is available.
    pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.inner.recv_from(buf)
    }

    /// Registers on first use, or swaps in the current task's waker when a
    /// different task (per `waker_to_ptr`) is now polling this socket.
    #[inline(always)]
    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
        let task_ptr = waker_to_ptr(cx);
        if let Some(token) = self.token {
            if task_ptr != self.registered_task {
                self.io.set_waker(token, cx.waker().clone());
                self.registered_task = task_ptr;
            }
            return Ok(());
        }
        self.do_register(task_ptr, cx.waker().clone())
    }

    /// Slow path: registers for both read and write interest with `waker`.
    #[cold]
    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
        let interest = Interest::READABLE | Interest::WRITABLE;
        let token = self.io.register(&mut self.inner, interest, waker)?;
        self.token = Some(token);
        self.registered_task = task_ptr;
        Ok(())
    }

    /// Shared body of every readiness-driven operation.
    ///
    /// Ensures the current task's waker is registered *before* attempting `op`.
    /// mio readiness is edge-triggered, so presumably an event arriving right after
    /// a `WouldBlock` must already find the waker in place. `WouldBlock` maps to
    /// `Poll::Pending`; every other outcome, success or error, is returned as-is.
    fn poll_io<T>(
        &mut self,
        cx: &Context<'_>,
        op: impl FnOnce(&mio::net::UdpSocket) -> io::Result<T>,
    ) -> Poll<io::Result<T>> {
        if let Err(e) = self.ensure_registered(cx) {
            return Poll::Ready(Err(e));
        }
        match op(&self.inner) {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            other => Poll::Ready(other),
        }
    }

    /// Polls sending `buf` to `target`.
    ///
    /// Registers `cx`'s waker and returns `Poll::Pending` if the socket is not writable.
    pub fn poll_send_to(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        self.get_mut().poll_io(cx, |sock| sock.send_to(buf, target))
    }

    /// Polls receiving a datagram into `buf`, yielding its length and source.
    ///
    /// Registers `cx`'s waker and returns `Poll::Pending` if no datagram is available.
    pub fn poll_recv_from(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        self.get_mut().poll_io(cx, |sock| sock.recv_from(buf))
    }

    /// Sends `buf` to `target`, waiting until the socket accepts it.
    pub async fn send_to(&mut self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_send_to(cx, buf, target)).await
    }

    /// Receives one datagram into `buf`, waiting until one is available.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_recv_from(cx, buf)).await
    }

    /// Polls sending `buf` to the connected peer.
    ///
    /// Registers `cx`'s waker and returns `Poll::Pending` if the socket is not writable.
    pub fn poll_send(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_mut().poll_io(cx, |sock| sock.send(buf))
    }

    /// Polls receiving a datagram from the connected peer into `buf`.
    ///
    /// Registers `cx`'s waker and returns `Poll::Pending` if no datagram is available.
    pub fn poll_recv(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        self.get_mut().poll_io(cx, |sock| sock.recv(buf))
    }

    /// Sends `buf` to the connected peer, waiting until the socket accepts it.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_send(cx, buf)).await
    }

    /// Receives one datagram from the connected peer into `buf`.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_recv(cx, buf)).await
    }

    /// Copies the next datagram into `buf` without removing it from the queue,
    /// and returns its length and source.
    pub async fn peek_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        std::future::poll_fn(|cx| self.poll_io(cx, |sock| sock.peek_from(buf))).await
    }

    /// Copies the next datagram from the connected peer into `buf` without
    /// removing it from the queue.
    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| self.poll_io(cx, |sock| sock.peek(buf))).await
    }
}
impl std::fmt::Debug for UdpSocket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UdpSocket")
.field("fd", &self.inner.as_raw_fd())
.field("registered", &self.token.is_some())
.finish()
}
}
impl AsFd for UdpSocket {
    /// Borrows the descriptor owned by the wrapped mio socket.
    fn as_fd(&self) -> BorrowedFd<'_> {
        let sock = &self.inner;
        AsFd::as_fd(sock)
    }
}
impl AsRawFd for UdpSocket {
    /// Returns the wrapped mio socket's raw descriptor; ownership is not transferred.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl Drop for UdpSocket {
    /// Deregisters the socket if it was ever registered.
    ///
    /// Any deregistration error is discarded, since drop has nobody to report it to.
    /// The descriptor itself is closed afterwards, when `inner` is dropped.
    fn drop(&mut self) {
        let Some(token) = self.token.take() else {
            return;
        };
        // NOTE(review): the safety contract of `IoHandle::deregister` lives elsewhere;
        // presumably `token` must be the one issued when `inner` was registered.
        let _ = unsafe { self.io.deregister(&mut self.inner, token) };
    }
}
#[cfg(test)]
// Runtime-level tests for `UdpSocket`. Each test:
// - runs a `block_on` future that spawns both ends of an exchange as separate tasks;
// - sleeps 500ms so those tasks can run;
// - checks a shared `done` flag afterwards to confirm the receiving side completed.
mod tests {
use super::*;
use crate::{Runtime, spawn_boxed};
use nexus_rt::WorldBuilder;
use std::cell::Cell;
use std::rc::Rc;
use std::time::Duration;
// One-way datagram: a receiver task awaits `recv_from` while a sender task sends "test".
#[test]
fn udp_send_recv() {
let wb = WorldBuilder::new();
let mut world = wb.build();
let mut rt = Runtime::new(&mut world);
let done = Rc::new(Cell::new(false));
let done2 = done.clone();
rt.block_on(async move {
let recv_sock =
UdpSocket::bind("127.0.0.1:0".parse().unwrap(), crate::context::io()).unwrap();
// Port 0 asks the OS for an ephemeral port; read it back for the sender.
let recv_addr = recv_sock.local_addr().unwrap();
let flag = done2;
spawn_boxed(async move {
let mut sock = recv_sock;
let mut buf = [0u8; 64];
let (n, _from) = sock.recv_from(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"test");
flag.set(true);
});
spawn_boxed(async move {
// Short delay so the receiver is presumably already waiting in `recv_from`.
crate::context::sleep(Duration::from_millis(10)).await;
let mut sock =
UdpSocket::bind("127.0.0.1:0".parse().unwrap(), crate::context::io()).unwrap();
sock.send_to(b"test", recv_addr).await.unwrap();
});
// Keeps `block_on` alive long enough for both spawned tasks to finish.
crate::context::sleep(Duration::from_millis(500)).await;
});
assert!(done.get(), "UDP recv never completed");
}
// Round trip: the server echoes the first datagram back to its sender; the client checks the echo.
#[test]
fn udp_echo() {
let wb = WorldBuilder::new();
let mut world = wb.build();
let mut rt = Runtime::new(&mut world);
let done = Rc::new(Cell::new(false));
let done2 = done.clone();
rt.block_on(async move {
let server_sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap(), crate::context::io())
.expect("bind failed");
let server_addr = server_sock.local_addr().unwrap();
spawn_boxed(async move {
let mut server = server_sock;
let mut buf = [0u8; 64];
let (n, peer) = server.recv_from(&mut buf).await.unwrap();
server.send_to(&buf[..n], peer).await.unwrap();
});
let flag = done2;
spawn_boxed(async move {
crate::context::sleep(Duration::from_millis(10)).await;
let client_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let mut client = UdpSocket::bind(client_addr, crate::context::io()).unwrap();
client.send_to(b"hello udp", server_addr).await.unwrap();
let mut buf = [0u8; 64];
let (n, _from) = client.recv_from(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello udp");
flag.set(true);
});
crate::context::sleep(Duration::from_millis(500)).await;
});
assert!(done.get(), "UDP echo never completed");
}
// Connected mode: both sockets `connect` to each other and use `send`/`recv`.
#[test]
fn udp_connected() {
let wb = WorldBuilder::new();
let mut world = wb.build();
let mut rt = Runtime::new(&mut world);
let done = Rc::new(Cell::new(false));
let done2 = done.clone();
rt.block_on(async move {
// The same `io` handle is passed to both binds.
let io = crate::context::io();
// Both sockets are bound before any task runs, so `b` already exists when `a` sends
// (the datagram is presumably queued until `b` calls `recv`).
let a_sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap(), io).unwrap();
let b_sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap(), io).unwrap();
let a_addr = a_sock.local_addr().unwrap();
let b_addr = b_sock.local_addr().unwrap();
spawn_boxed(async move {
let mut a = a_sock;
a.connect(b_addr).unwrap();
a.send(b"connected").await.unwrap();
});
let flag = done2;
spawn_boxed(async move {
crate::context::sleep(Duration::from_millis(10)).await;
let mut b = b_sock;
b.connect(a_addr).unwrap();
let mut buf = [0u8; 64];
let n = b.recv(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"connected");
flag.set(true);
});
crate::context::sleep(Duration::from_millis(500)).await;
});
assert!(done.get(), "UDP connected exchange never completed");
}
}