use std::{
future::Future,
io::Result,
net::{SocketAddr, ToSocketAddrs, UdpSocket as StdUdpSocket},
os::fd::AsRawFd,
pin::Pin,
task::{Context, Poll},
};
use io_uring::{opcode, types};
use libc::{iovec, msghdr};
use crate::reactor::{Reactor, ReactorIo};
use super::sock_addr::CSockAddr;
/// A UDP socket whose `recv_from` / `send_to` operations return futures that are
/// submitted through this crate's io_uring reactor.
///
/// Wraps a bound [`std::net::UdpSocket`]; the std socket only provides the file
/// descriptor, and the actual I/O is issued as io_uring `RecvMsg` / `SendMsg` operations.
#[derive(Debug)]
pub struct UdpSocket {
    inner: StdUdpSocket,
}
impl UdpSocket {
    /// Binds a new socket to `addr` by delegating to [`std::net::UdpSocket::bind`].
    ///
    /// # Errors
    /// Returns any error produced by the underlying std `bind`.
    pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        let sock = StdUdpSocket::bind(addr)?;
        Ok(Self { inner: sock })
    }

    /// Returns a future that receives one datagram into `buf` with an io_uring `RecvMsg`.
    ///
    /// The future resolves to the byte count reported by the completed operation and the
    /// sender's address.
    ///
    /// NOTE(review): if the future is dropped while the operation is still in flight, the
    /// kernel may still write into `buf`; whether `ReactorIo` cancels or waits on drop is
    /// not visible here — confirm.
    pub fn recv_from<'a, 'b>(&'a mut self, buf: &'b mut [u8]) -> RecvFrom<'a, 'b> {
        RecvFrom {
            sock: &self.inner,
            io: Reactor::new_io(),
            msg: MsgState::zeroed(),
            buf,
        }
    }

    /// Returns a future that sends `buf` as one datagram to `target` with an io_uring
    /// `SendMsg`, resolving to the byte count reported by the completed operation.
    ///
    /// `target` is resolved on the first poll and only its first address is used.
    ///
    /// # Errors
    /// The future completes with an error (rather than panicking) if `target` fails to
    /// resolve or resolves to no addresses, or if the send itself fails.
    pub fn send_to<'a, 'b, A: ToSocketAddrs>(
        &'a self,
        buf: &'b [u8],
        target: A,
    ) -> SendTo<'a, 'b, A> {
        SendTo {
            sock: &self.inner,
            dst: target,
            resolved: false,
            io: Reactor::new_io(),
            msg: MsgState::zeroed(),
            buf,
        }
    }
}

/// The `msghdr`, `iovec` and socket address handed to the kernel for one
/// `RecvMsg` / `SendMsg` operation.
///
/// `hdr` points at `iov` and at `csock.addr`, and the kernel keeps a pointer to `hdr`, so
/// none of these may move while an operation is in flight. They live behind a `Box`
/// owned by the future: `Pin` only forbids moves for `!Unpin` types, and these futures
/// carry no `PhantomPinned`, so they may be moved between polls. Moving the future
/// moves only the `Box` pointer, not the data the kernel sees.
struct MsgState {
    hdr: msghdr,
    iov: iovec,
    csock: CSockAddr,
}

impl MsgState {
    /// Allocates an all-zero state on the heap.
    fn zeroed() -> Box<Self> {
        // SAFETY: `msghdr` and `iovec` are plain C structs for which all-zero bytes are a
        // valid value (null pointers, zero lengths). `CSockAddr` is assumed to be
        // zero-valid as well — it is only read after being filled in.
        // NOTE(review): confirm `CSockAddr` has no non-nullable fields.
        Box::new(unsafe { std::mem::zeroed() })
    }
}

/// Resolves `dst` to its first socket address.
///
/// Reports an empty resolution with the same `InvalidInput` error that
/// `std::net::UdpSocket::send_to` uses.
fn first_socket_addr<A: ToSocketAddrs>(dst: &A) -> Result<SocketAddr> {
    dst.to_socket_addrs()?.next().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "no addresses to send data to",
        )
    })
}

/// Future returned by [`UdpSocket::recv_from`].
pub struct RecvFrom<'a, 'b> {
    sock: &'a StdUdpSocket,
    io: ReactorIo,
    /// Kernel-visible message header, kept at a stable heap address.
    msg: Box<MsgState>,
    buf: &'b mut [u8],
}

impl Future for RecvFrom<'_, '_> {
    type Output = Result<(usize, SocketAddr)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`. Every address given to the kernel lives
        // in the boxed `MsgState` or behind the borrowed `buf`, so correctness does not
        // depend on `Self` staying pinned.
        let this = unsafe { self.get_unchecked_mut() };
        let msg = &mut *this.msg;
        let buf = &mut *this.buf;
        let fd = this.sock.as_raw_fd();
        this.io
            .submit_or_get_result(|| {
                msg.iov.iov_base = buf.as_mut_ptr() as *mut _;
                msg.iov.iov_len = buf.len();
                msg.hdr.msg_iov = &mut msg.iov as *mut _;
                msg.hdr.msg_iovlen = 1;
                msg.hdr.msg_name = &mut msg.csock.addr as *mut _ as *mut _;
                // Capacity of the name buffer; read back after completion as the length
                // of the sender address actually written.
                msg.hdr.msg_namelen = std::mem::size_of_val(&msg.csock.addr) as _;
                (
                    opcode::RecvMsg::new(types::Fd(fd), &mut msg.hdr as *mut _).build(),
                    cx.waker().clone(),
                )
            })
            .map(|res| {
                let received = res?;
                msg.csock.len = msg.hdr.msg_namelen as _;
                let from = <&CSockAddr as TryInto<SocketAddr>>::try_into(&msg.csock)?;
                Ok((received as usize, from))
            })
    }
}

/// Future returned by [`UdpSocket::send_to`].
pub struct SendTo<'a, 'b, A: ToSocketAddrs> {
    sock: &'a StdUdpSocket,
    dst: A,
    /// Whether `dst` has already been resolved into `msg.csock`.
    resolved: bool,
    io: ReactorIo,
    /// Kernel-visible message header, kept at a stable heap address.
    msg: Box<MsgState>,
    buf: &'b [u8],
}

impl<A: ToSocketAddrs> Future for SendTo<'_, '_, A> {
    type Output = Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this` (`dst` is only borrowed). Every address
        // given to the kernel lives in the boxed `MsgState` or behind the borrowed `buf`.
        let this = unsafe { self.get_unchecked_mut() };
        if !this.resolved {
            // Resolve once, before submission, so a bad target completes the future with
            // an error instead of panicking inside the submit callback.
            match first_socket_addr(&this.dst) {
                Ok(addr) => this.msg.csock = addr.into(),
                Err(err) => return Poll::Ready(Err(err)),
            }
            this.resolved = true;
        }
        let msg = &mut *this.msg;
        let buf = this.buf;
        let fd = this.sock.as_raw_fd();
        this.io
            .submit_or_get_result(|| {
                msg.hdr.msg_name = &mut msg.csock.addr as *mut _ as *mut _;
                msg.hdr.msg_namelen = msg.csock.len as _;
                msg.iov.iov_base = buf.as_ptr() as *mut _;
                msg.iov.iov_len = buf.len();
                msg.hdr.msg_iov = &mut msg.iov as *mut _;
                msg.hdr.msg_iovlen = 1;
                (
                    opcode::SendMsg::new(types::Fd(fd), &msg.hdr as *const _).build(),
                    cx.waker().clone(),
                )
            })
            .map(|res| res.map(|sent| sent as usize))
    }
}
#[cfg(test)]
mod tests {
    use super::UdpSocket;
    use crate::task::Executor;
    use std::net::Ipv4Addr;

    /// Sends one datagram over loopback and checks the received payload, its length,
    /// the reported sender address and the reported send size.
    #[test]
    fn send_recv() {
        Executor::block_on(async {
            let payload = 0xdeadbeef_u32.to_le_bytes();
            let tx_sock = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
            // Let the OS pick a free port instead of hard-coding one, so the test cannot
            // collide with other processes or parallel test runs.
            let mut rx_sock = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
            let dst = rx_sock.inner.local_addr().unwrap();
            let src = tx_sock.inner.local_addr().unwrap();
            let task = Executor::spawn(async move {
                let mut buf = [0; 4];
                let (len, from) = rx_sock.recv_from(&mut buf).await.unwrap();
                assert_eq!(len, payload.len());
                assert_eq!(from, src);
                assert_eq!(buf, payload);
            });
            let sent = tx_sock.send_to(&payload, dst).await.unwrap();
            assert_eq!(sent, payload.len());
            task.await;
        });
    }
}