#![cfg_attr(unstable_bool_to_result, feature(bool_to_result))]
#![cfg_attr(unstable_never_type, feature(never_type))]
use std::{
io,
net::{SocketAddr, ToSocketAddrs},
pin::Pin,
task::{Context, Poll},
};
use futures::{Stream, sink::Sink};
use futures_net::driver::{
PollEvented,
sys::{self},
};
use socket2::{Domain, Type};
/// A [`Stream`] of UDP datagrams read from a reactor-registered, non-blocking socket.
///
/// Every datagram is copied into a fresh `[u8; BUF_SIZE]` buffer.
#[derive(Debug)]
pub struct UdpStream<const BUF_SIZE: usize> {
    /// Registered socket that datagrams are received from.
    io: PollEvented<sys::net::UdpSocket>,
}
/// Shared plumbing for types that own a reactor-registered, non-blocking UDP socket.
///
/// Implementors provide construction from a [`PollEvented`] socket, accessors for it, and
/// [`clear_ready`](EventedUdpSocket::clear_ready). Binding, local-address lookup and
/// `WouldBlock` handling are provided on top of those.
pub trait EventedUdpSocket
where
    Self: Sized,
{
    /// Wraps an already-registered socket in the implementing type.
    fn from_evented_socket(evented_socket: PollEvented<sys::net::UdpSocket>) -> io::Result<Self>;

    /// Opens a non-blocking UDP socket with `SO_REUSEADDR` enabled and binds it to `addr`.
    ///
    /// The socket is then registered and wrapped via
    /// [`from_evented_socket`](EventedUdpSocket::from_evented_socket). The address family
    /// (IPv4 or IPv6) is taken from `addr`.
    ///
    /// # Errors
    ///
    /// Propagates any error from socket creation, option setting, binding or registration.
    /// Returns [`io::ErrorKind::Unsupported`] if `SO_REUSEADDR` does not read back as enabled
    /// after being set. On the targets named in the `cfg` below, the same applies to
    /// non-blocking mode.
    fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Derive the domain from `addr`: a hard-coded IPv4 domain cannot bind IPv6 addresses.
        let s2 = socket2::Socket::new(Domain::for_address(addr), Type::DGRAM, None)?;
        let addr = addr.into();
        s2.set_nonblocking(true)?;
        // Read the flag back to confirm it took effect. `bool::ok_or` (nightly
        // `bool_to_result`) turns a `false` read-back into an error.
        #[cfg(any(unix, all(target_os = "wasi", not(target_env = "p1"))))]
        s2.nonblocking()?
            .ok_or(io::Error::from(io::ErrorKind::Unsupported))?;
        s2.set_reuse_address(true)?;
        s2.reuse_address()?
            .ok_or(io::Error::from(io::ErrorKind::Unsupported))?;
        s2.bind(&addr)?;
        let sstd: std::net::UdpSocket = s2.into();
        let evented_socket = PollEvented::new(sys::net::UdpSocket::from_socket(sstd)?);
        Self::from_evented_socket(evented_socket)
    }

    /// Returns the local address the underlying socket is bound to.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        self.as_socket().local_addr()
    }

    /// Borrows the underlying socket.
    fn as_socket(&self) -> &sys::net::UdpSocket;

    /// Mutably borrows the underlying socket.
    fn as_socket_mut(&mut self) -> &mut sys::net::UdpSocket;

    /// Projects a pinned `Self` onto its pinned [`PollEvented`] socket.
    fn as_evented_socket_pin(self: Pin<&mut Self>) -> Pin<&mut PollEvented<sys::net::UdpSocket>>;

    /// Clears the readiness this type waits on.
    ///
    /// The `!` success type means this can never return `Ready(Ok(_))`. Implementations are
    /// expected to return `Pending` once readiness is cleared, and `Ready(Err(_))` if clearing
    /// fails.
    fn clear_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<!>>;

    /// Converts an I/O error into a poll result.
    ///
    /// `WouldBlock` is expected on a non-blocking socket and is handed to
    /// [`clear_ready`](EventedUdpSocket::clear_ready). Any other error is returned as
    /// `Ready(Err(_))`.
    fn unblock(
        self: Pin<&mut Self>,
        error: io::Result<!>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<!>> {
        // Irrefutable: `Ok(!)` is uninhabited.
        let Err(error) = error;
        match error.kind() {
            io::ErrorKind::WouldBlock => self.clear_ready(cx),
            _ => Poll::Ready(Err(error)),
        }
    }
}
/// Wires [`UdpStream`] into [`EventedUdpSocket`]; a stream only waits on read readiness.
impl<const BUF_SIZE: usize> EventedUdpSocket for UdpStream<BUF_SIZE> {
    /// Wraps `socket` as-is; this never fails.
    fn from_evented_socket(socket: PollEvented<sys::net::UdpSocket>) -> io::Result<Self> {
        let stream = Self { io: socket };
        Ok(stream)
    }

    fn as_socket(&self) -> &sys::net::UdpSocket {
        self.io.get_ref()
    }

    fn as_socket_mut(&mut self) -> &mut sys::net::UdpSocket {
        self.io.get_mut()
    }

    fn as_evented_socket_pin(self: Pin<&mut Self>) -> Pin<&mut PollEvented<sys::net::UdpSocket>> {
        Pin::new(&mut self.get_mut().io)
    }

    /// Clears read readiness and reports `Pending`. A failure to clear is reported as
    /// `Ready(Err(_))`.
    fn clear_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<!>> {
        if let Err(e) = self.as_evented_socket_pin().clear_read_ready(cx) {
            Poll::Ready(Err(e))
        } else {
            Poll::Pending
        }
    }
}
impl<const BUF_SIZE: usize> Stream for UdpStream<BUF_SIZE> {
    /// Datagram buffer, number of valid bytes in it, and the sender's address.
    type Item = io::Result<([u8; BUF_SIZE], usize, SocketAddr)>;

    /// Waits for read readiness, then receives a single datagram.
    ///
    /// If the readiness is not readable, or an error is `WouldBlock`, readiness is cleared
    /// via `clear_ready`. Any other error is yielded as `Some(Err(_))`. No path returns
    /// `Ready(None)`, so the stream never terminates on its own.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = self.as_mut().as_evented_socket_pin().poll_read_ready(cx);
        let readiness = match polled {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return self.unblock(Err(e), cx).map_ok(|x| x).map(Some),
            Poll::Ready(Ok(readiness)) => readiness,
        };
        if !readiness.is_readable() {
            return self.clear_ready(cx).map_ok(|x| x).map(Some);
        }
        // At most `BUF_SIZE` bytes land in the buffer; the returned length says how many are valid.
        let mut datagram = [0u8; BUF_SIZE];
        let received = self.as_socket().recv_from(&mut datagram);
        match received {
            Ok((len, sender)) => Poll::Ready(Some(Ok((datagram, len, sender)))),
            Err(e) => self.unblock(Err(e), cx).map_ok(|x| x).map(Some),
        }
    }
}
/// A [`Sink`] that sends `(payload, destination)` pairs as individual UDP datagrams over a
/// reactor-registered, non-blocking socket.
#[derive(Debug)]
pub struct UdpSink {
    /// Registered socket that datagrams are sent from.
    io: PollEvented<sys::net::UdpSocket>,
}
/// Wires [`UdpSink`] into [`EventedUdpSocket`]; a sink only waits on write readiness.
impl EventedUdpSocket for UdpSink {
    /// Wraps `socket` as-is; this never fails.
    fn from_evented_socket(socket: PollEvented<sys::net::UdpSocket>) -> io::Result<Self> {
        let sink = Self { io: socket };
        Ok(sink)
    }

    fn as_socket(&self) -> &sys::net::UdpSocket {
        self.io.get_ref()
    }

    fn as_socket_mut(&mut self) -> &mut sys::net::UdpSocket {
        self.io.get_mut()
    }

    fn as_evented_socket_pin(self: Pin<&mut Self>) -> Pin<&mut PollEvented<sys::net::UdpSocket>> {
        Pin::new(&mut self.get_mut().io)
    }

    /// Clears write readiness and reports `Pending`. A failure to clear is reported as
    /// `Ready(Err(_))`.
    fn clear_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<!>> {
        if let Err(e) = self.as_evented_socket_pin().clear_write_ready(cx) {
            Poll::Ready(Err(e))
        } else {
            Poll::Pending
        }
    }
}
impl<A: ToSocketAddrs> Sink<(&[u8], &A)> for UdpSink {
    type Error = io::Error;

    /// Resolves once the socket reports write readiness.
    ///
    /// A non-writable readiness, or a `WouldBlock` error from the reactor, is handed to
    /// `clear_ready`. Any other error is returned as `Ready(Err(_))`.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let polled = self.as_mut().as_evented_socket_pin().poll_write_ready(cx);
        let readiness = match polled {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return self.unblock(Err(e), cx).map_ok(|x| x),
            Poll::Ready(Ok(readiness)) => readiness,
        };
        if readiness.is_writable() {
            Poll::Ready(Ok(()))
        } else {
            self.clear_ready(cx).map_ok(|x| x)
        }
    }

    /// Immediately sends the payload as one datagram to the first address the destination
    /// resolves to.
    ///
    /// Nothing is buffered, so `poll_flush` and `poll_close` only wait for write readiness.
    ///
    /// Note: `to_socket_addrs` runs synchronously here and may block the current thread
    /// while resolving a hostname.
    ///
    /// # Errors
    ///
    /// - Any resolution error from `to_socket_addrs`, or `InvalidInput` if it yields no address.
    /// - Any error from `send_to`, including `WouldBlock`, which is not retried.
    /// - An `Other` error if fewer bytes than the payload length were sent.
    fn start_send(self: Pin<&mut Self>, item: (&[u8], &A)) -> Result<(), Self::Error> {
        let (msg, target) = item;
        let Some(dest) = target.to_socket_addrs()?.next() else {
            return Err(io::Error::from(io::ErrorKind::InvalidInput));
        };
        let sent = self.as_socket().send_to(msg, &dest)?;
        if sent == msg.len() {
            Ok(())
        } else {
            Err(io::Error::other(format!(
                "{} bytes sent but message was {} bytes",
                sent,
                msg.len()
            )))
        }
    }

    /// Nothing is buffered, so flushing only waits for write readiness.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as futures::Sink<(&[u8], &A)>>::poll_ready(self, cx)
    }

    /// Closing is a flush; the socket itself is released on drop.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as futures::Sink<(&[u8], &A)>>::poll_flush(self, cx)
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};
    use super::*;
    use futures::{SinkExt, StreamExt};
    use futures_net::runtime::Runtime;

    /// `127.0.0.1` with port 0, so the OS assigns a free port on bind.
    fn ephemeral_loopback() -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
    }

    /// A second socket can bind an address already held by the first, because `bind`
    /// enables `SO_REUSEADDR`.
    #[futures_net::test]
    async fn non_blocking() {
        let first = UdpStream::<32>::bind(ephemeral_loopback()).expect("first connection");
        let taken = first.local_addr().expect("bound port");
        let _second = UdpStream::<32>::bind(taken).expect("second connection");
    }

    /// A 17-byte datagram received through an 8-byte stream buffer is cut to its first 8
    /// bytes. This relies on the OS truncating oversized datagrams rather than erroring.
    #[futures_net::test]
    async fn truncated_next() {
        let any_port = ephemeral_loopback();
        let original_msg = b"udp loopback test";
        let mut receiver = UdpStream::<8>::bind(any_port).expect("receiver");
        let rec_addr = receiver.local_addr().expect("bound port");
        let mut sender = UdpSink::bind(any_port).expect("sender");

        let rec = async {
            let (msg, len, _sent_by) = receiver
                .next()
                .await
                .expect("a message")
                .expect("a valid message");
            assert_eq!(len, 8);
            assert_eq!(msg, original_msg[..8]);
        };
        let send = async move {
            sender
                .send((original_msg, &rec_addr))
                .await
                .expect("send msg");
        };
        futures::join!(rec, send);
    }
}