use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::platform::sys::{set_nonblocking, Interest};
use crate::reactor::source::{next_token, IoSource};
use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
pub struct UdpSocket {
source: IoSource,
}
impl UdpSocket {
pub fn bind(addr: SocketAddr) -> io::Result<Self> {
let fd = create_udp_socket(addr)?;
bind_socket(fd, addr)?;
set_nonblocking(fd)?;
let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
Ok(Self { source })
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
raw_local_addr(self.source.raw())
}
pub fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> SendToFuture<'a> {
SendToFuture {
socket: self,
buf,
target,
}
}
pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFromFuture<'a> {
RecvFromFuture { socket: self, buf }
}
}
impl Drop for UdpSocket {
    /// Releases the socket's file descriptor.
    fn drop(&mut self) {
        // NOTE(review): this runs before the `source` field itself is dropped.
        // If `IoSource` deregisters from the reactor in its own `Drop`, it will
        // see an already-closed fd. Confirm this ordering is intended.
        // SAFETY: the descriptor was created by `UdpSocket::bind`, and (assuming
        // `IoSource` does not close it) this is the only place it is closed.
        unsafe {
            libc::close(self.source.raw());
        }
    }
}
/// Future returned by [`UdpSocket::send_to`].
///
/// Resolves to the number of bytes sent, or to the first non-`WouldBlock`
/// error reported by the send or by the reactor.
#[must_use = "futures do nothing unless polled or awaited"]
pub struct SendToFuture<'a> {
    socket: &'a UdpSocket,
    buf: &'a [u8],
    target: SocketAddr,
}

impl<'a> Future for SendToFuture<'a> {
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // All fields are references or `Copy`, so the future is `Unpin`.
        let this = self.get_mut();
        match try_send_to(this.socket.source.raw(), this.buf, this.target) {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                // NOTE(review): a fresh readiness future is created and dropped
                // on every poll. This assumes the reactor keeps the waker
                // registered after that future is dropped; confirm against `IoSource`.
                match Pin::new(&mut this.socket.source.writable()).poll(cx) {
                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                    Poll::Ready(Ok(())) => {
                        // The readiness future completed, so it registered no
                        // waker. Returning a bare `Pending` here would leave the
                        // task suspended with nothing to wake it. Ask the
                        // executor to poll again so the send is retried.
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            result => Poll::Ready(result),
        }
    }
}
/// Future returned by [`UdpSocket::recv_from`].
///
/// Resolves to `(bytes_received, sender_address)`, or to the first
/// non-`WouldBlock` error reported by the receive or by the reactor.
#[must_use = "futures do nothing unless polled or awaited"]
pub struct RecvFromFuture<'a> {
    socket: &'a UdpSocket,
    buf: &'a mut [u8],
}

impl<'a> Future for RecvFromFuture<'a> {
    type Output = io::Result<(usize, SocketAddr)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // All fields are references, so the future is `Unpin`.
        let this = self.get_mut();
        let fd = this.socket.source.raw();
        match try_recv_from(fd, &mut *this.buf) {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                // NOTE(review): a fresh readiness future is created and dropped
                // on every poll. This assumes the reactor keeps the waker
                // registered after that future is dropped; confirm against `IoSource`.
                match Pin::new(&mut this.socket.source.readable()).poll(cx) {
                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                    Poll::Ready(Ok(())) => {
                        // The readiness future completed, so it registered no
                        // waker. Returning a bare `Pending` here would leave the
                        // task suspended with nothing to wake it. Ask the
                        // executor to poll again so the receive is retried.
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            result => Poll::Ready(result),
        }
    }
}
/// Opens an unbound `SOCK_DGRAM` socket whose address family matches `addr`.
///
/// Returns the raw descriptor. The caller is responsible for closing it.
fn create_udp_socket(addr: SocketAddr) -> io::Result<i32> {
    let domain = if addr.is_ipv4() {
        libc::AF_INET
    } else {
        libc::AF_INET6
    };
    // SAFETY: `socket(2)` takes only integer arguments.
    match unsafe { libc::socket(domain, libc::SOCK_DGRAM, 0) } {
        -1 => Err(io::Error::last_os_error()),
        fd => Ok(fd),
    }
}
/// Binds the socket `fd` to `addr` with `bind(2)`.
///
/// # Errors
///
/// Returns the OS error reported by `bind(2)`.
fn bind_socket(fd: i32, addr: SocketAddr) -> io::Result<()> {
    let (sa_ptr, sa_len) = socketaddr_to_raw(addr);
    // SAFETY: `sa_ptr`/`sa_len` come from `socketaddr_to_raw(addr)` and are
    // still live, because they are reclaimed only below.
    let rc = unsafe { libc::bind(fd, sa_ptr, sa_len) };
    // Read errno *before* reclaiming the sockaddr. The reclaim may run arbitrary
    // code (for example, freeing memory), and errno(3) allows even successful
    // library calls to overwrite errno.
    let bind_err = (rc == -1).then(io::Error::last_os_error);
    // SAFETY: `sa_ptr` was produced by `socketaddr_to_raw` for this same `addr`
    // and is reclaimed exactly once.
    unsafe { reclaim_raw_sockaddr(sa_ptr, addr) };
    match bind_err {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Attempts one `sendto(2)` of `buf` to `target` on `fd`.
///
/// Returns the number of bytes sent. On a non-blocking socket that cannot
/// accept the datagram, the error kind is `WouldBlock`.
///
/// # Errors
///
/// Returns the OS error reported by `sendto(2)`.
fn try_send_to(fd: i32, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
    let (sa_ptr, sa_len) = socketaddr_to_raw(target);
    // SAFETY: `buf` is a valid slice for `buf.len()` bytes, and
    // `sa_ptr`/`sa_len` describe the sockaddr built for `target`.
    let sent = unsafe {
        libc::sendto(
            fd,
            buf.as_ptr() as *const libc::c_void,
            buf.len(),
            0,
            sa_ptr,
            sa_len,
        )
    };
    // Capture errno immediately. `reclaim_raw_sockaddr` below may run code
    // that overwrites it (errno(3): successful calls may change errno).
    let result = if sent == -1 {
        Err(io::Error::last_os_error())
    } else {
        // `sendto` returns either -1 or a non-negative byte count, so this cast
        // cannot wrap.
        Ok(sent as usize)
    };
    // SAFETY: `sa_ptr` was produced by `socketaddr_to_raw` for this same
    // `target` and is reclaimed exactly once.
    unsafe { reclaim_raw_sockaddr(sa_ptr, target) };
    result
}
/// Attempts one `recvfrom(2)` on `fd` into `buf`.
///
/// Returns the number of bytes stored in `buf` and the sender's address. On a
/// non-blocking socket with no datagram queued, the error kind is `WouldBlock`.
fn try_recv_from(fd: i32, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
    // `sockaddr_in6` is at least as large as `sockaddr_in`, so it can hold a
    // sender of either address family.
    let mut peer: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
    let mut peer_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
    // SAFETY: `buf` is writable for `buf.len()` bytes, and `peer`/`peer_len`
    // describe a writable sockaddr buffer of the advertised size.
    let received = unsafe {
        libc::recvfrom(
            fd,
            buf.as_mut_ptr().cast(),
            buf.len(),
            0,
            (&mut peer as *mut libc::sockaddr_in6).cast(),
            &mut peer_len,
        )
    };
    if received == -1 {
        return Err(io::Error::last_os_error());
    }
    sockaddr_to_socketaddr(&peer, peer_len).map(|sender| (received as usize, sender))
}
/// Queries the local address `fd` is bound to, via `getsockname(2)`.
fn raw_local_addr(fd: i32) -> io::Result<SocketAddr> {
    // Sized for IPv6, which is also large enough for an IPv4 address.
    let mut local: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
    let mut local_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
    let sockaddr_ptr = (&mut local as *mut libc::sockaddr_in6).cast::<libc::sockaddr>();
    // SAFETY: `sockaddr_ptr`/`local_len` describe a writable buffer of the advertised size.
    if unsafe { libc::getsockname(fd, sockaddr_ptr, &mut local_len) } == -1 {
        Err(io::Error::last_os_error())
    } else {
        sockaddr_to_socketaddr(&local, local_len)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;

    /// Binds a fresh socket to an OS-assigned port on 127.0.0.1, panicking on failure.
    fn loopback() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap()
    }

    #[test]
    fn bind_and_local_addr() {
        let sock = UdpSocket::bind("127.0.0.1:0".parse().unwrap()).expect("bind failed");
        let addr = sock.local_addr().expect("local_addr failed");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[test]
    fn send_to_and_recv_from() {
        block_on_with_spawn(async {
            let receiver = loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = loopback();

            let payload = b"ping";
            let sent = sender.send_to(payload, recv_addr).await.unwrap();
            assert_eq!(sent, payload.len());

            let mut inbox = [0u8; 16];
            let (got, origin) = receiver.recv_from(&mut inbox).await.unwrap();
            assert_eq!(got, payload.len());
            assert_eq!(&inbox[..got], payload);
            assert_eq!(origin.ip(), sender.local_addr().unwrap().ip());
        });
    }

    #[test]
    fn udp_echo_round_trip() {
        block_on_with_spawn(async {
            let server = loopback();
            let server_addr = server.local_addr().unwrap();
            let client = loopback();

            client.send_to(b"hello", server_addr).await.unwrap();

            let mut request = [0u8; 16];
            let (len, peer) = server.recv_from(&mut request).await.unwrap();
            server.send_to(&request[..len], peer).await.unwrap();

            let mut response = [0u8; 16];
            let (resp_len, _) = client.recv_from(&mut response).await.unwrap();
            assert_eq!(&response[..resp_len], b"hello");
        });
    }

    #[test]
    fn udp_bind_port_zero_gets_assigned() {
        let port = loopback().local_addr().unwrap().port();
        assert!(port > 1024);
    }

    #[test]
    fn udp_send_returns_correct_byte_count() {
        block_on_with_spawn(async {
            let receiver = loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = loopback();

            let payload = b"test123";
            let sent = sender.send_to(payload, recv_addr).await.unwrap();
            assert_eq!(sent, payload.len());
        });
    }

    #[test]
    fn udp_recv_from_returns_sender_ip() {
        block_on_with_spawn(async {
            let receiver = loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = loopback();
            let sender_addr = sender.local_addr().unwrap();

            sender.send_to(b"hi", recv_addr).await.unwrap();

            let mut inbox = [0u8; 8];
            let (_, origin) = receiver.recv_from(&mut inbox).await.unwrap();
            assert_eq!(origin.ip(), sender_addr.ip());
        });
    }

    #[test]
    fn udp_multiple_datagrams_sequential() {
        block_on_with_spawn(async {
            let receiver = loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = loopback();

            for seq in 0u8..5 {
                let datagram = [seq];
                sender.send_to(&datagram, recv_addr).await.unwrap();

                let mut inbox = [0u8; 4];
                let (len, _) = receiver.recv_from(&mut inbox).await.unwrap();
                assert_eq!(len, 1);
                assert_eq!(inbox[0], seq);
            }
        });
    }

    #[test]
    fn udp_large_datagram_fits_buf() {
        block_on_with_spawn(async {
            let receiver = loopback();
            let recv_addr = receiver.local_addr().unwrap();
            let sender = loopback();

            let payload = [42u8; 1024];
            let sent = sender.send_to(&payload, recv_addr).await.unwrap();
            assert_eq!(sent, 1024);

            let mut inbox = [0u8; 1024];
            let (got, _) = receiver.recv_from(&mut inbox).await.unwrap();
            assert_eq!(got, 1024);
            assert_eq!(inbox, [42u8; 1024]);
        });
    }

    #[test]
    fn udp_two_sockets_cross_exchange() {
        block_on_with_spawn(async {
            let a = loopback();
            let b = loopback();
            let a_addr = a.local_addr().unwrap();
            let b_addr = b.local_addr().unwrap();

            a.send_to(b"from_a", b_addr).await.unwrap();
            let mut at_b = [0u8; 8];
            let (len_b, origin_b) = b.recv_from(&mut at_b).await.unwrap();
            assert_eq!(&at_b[..len_b], b"from_a");
            assert_eq!(origin_b.ip(), a_addr.ip());

            b.send_to(b"from_b", a_addr).await.unwrap();
            let mut at_a = [0u8; 8];
            let (len_a, origin_a) = a.recv_from(&mut at_a).await.unwrap();
            assert_eq!(&at_a[..len_a], b"from_b");
            assert_eq!(origin_a.ip(), b_addr.ip());
        });
    }
}