use std::fmt;
use std::future::Future;
use std::io::{self, IoSlice};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{self, Poll};
use heph::actor;
#[cfg(target_os = "linux")]
use log::warn;
use mio::{net, Interest};
use socket2::{SockAddr, SockRef};
use crate::bytes::{Bytes, BytesVectored, MaybeUninitSlice};
use crate::net::convert_address;
use crate::{self as rt, Bound};
/// Marker type for the mode of a [`UdpSocket`] that is not connected to a
/// remote address.
///
/// This is the default mode of `UdpSocket` and the mode produced by
/// [`UdpSocket::bind`]. The enum has no variants, so no value of it can ever
/// exist; it is only used at the type level.
// No `Debug` implementation: there can never be a value to format.
#[allow(missing_debug_implementations)]
// The empty enum is deliberate, it is a type-level marker only.
#[allow(clippy::empty_enum)]
pub enum Unconnected {}
/// Marker type for the mode of a [`UdpSocket`] that is connected to a remote
/// address.
///
/// A socket in this mode is produced by [`UdpSocket::connect`]. The enum has
/// no variants, so no value of it can ever exist; it is only used at the type
/// level.
// No `Debug` implementation: there can never be a value to format.
#[allow(missing_debug_implementations)]
// The empty enum is deliberate, it is a type-level marker only.
#[allow(clippy::empty_enum)]
pub enum Connected {}
/// A UDP socket, parameterised over its mode `M`.
///
/// The mode is a type-level marker, either [`Unconnected`] (the default) or
/// [`Connected`], and determines which of the send/receive/peek methods are
/// available on the socket.
pub struct UdpSocket<M = Unconnected> {
/// The underlying `mio` UDP socket.
socket: net::UdpSocket,
/// Zero-sized marker recording the mode `M`; carries no data.
mode: PhantomData<M>,
}
impl UdpSocket {
    /// Creates a UDP socket bound to the `local` address.
    ///
    /// The socket is registered with the runtime of `ctx` for both readable
    /// and writable interest and is returned in [`Unconnected`] mode.
    ///
    /// On Linux, when the runtime reports a CPU, the socket's CPU affinity is
    /// set to that CPU. A failure to do so is only logged as a warning and
    /// does not fail the call.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the socket or registering it with the
    /// runtime fails.
    pub fn bind<M, RT>(
        ctx: &mut actor::Context<M, RT>,
        local: SocketAddr,
    ) -> io::Result<UdpSocket<Unconnected>>
    where
        RT: rt::Access,
    {
        let interests = Interest::READABLE | Interest::WRITABLE;
        let mut socket = net::UdpSocket::bind(local)?;
        ctx.runtime().register(&mut socket, interests)?;

        // Setting the affinity is best effort: an error is logged, not returned.
        #[cfg(target_os = "linux")]
        if let Some(Err(err)) = ctx
            .runtime_ref()
            .cpu()
            .map(|cpu| SockRef::from(&socket).set_cpu_affinity(cpu))
        {
            warn!("failed to set CPU affinity on UdpSocket: {}", err);
        }

        Ok(UdpSocket {
            socket,
            mode: PhantomData,
        })
    }
}
impl<M> UdpSocket<M> {
    /// Connects the socket to the `remote` address.
    ///
    /// Consumes the socket and, on success, returns it in [`Connected`] mode.
    ///
    /// # Errors
    ///
    /// Returns the error of the underlying `connect` call; the socket is
    /// dropped in that case.
    pub fn connect(self, remote: SocketAddr) -> io::Result<UdpSocket<Connected>> {
        let UdpSocket { socket, .. } = self;
        socket.connect(remote)?;
        Ok(UdpSocket {
            socket,
            mode: PhantomData,
        })
    }

    /// Returns the local address the socket is bound to.
    pub fn local_addr(&mut self) -> io::Result<SocketAddr> {
        let socket = &self.socket;
        socket.local_addr()
    }

    /// Retrieves the pending error of the socket, if any, as reported by the
    /// underlying socket's `take_error`.
    pub fn take_error(&mut self) -> io::Result<Option<io::Error>> {
        let socket = &self.socket;
        socket.take_error()
    }
}
impl UdpSocket<Unconnected> {
pub fn try_send_to(&mut self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
self.socket.send_to(buf, target)
}
pub fn send_to<'a, 'b>(&'a mut self, buf: &'b [u8], target: SocketAddr) -> SendTo<'a, 'b> {
SendTo {
socket: self,
buf,
target,
}
}
pub fn try_send_to_vectored(
&mut self,
bufs: &[IoSlice<'_>],
target: SocketAddr,
) -> io::Result<usize> {
SockRef::from(&self.socket).send_to_vectored(bufs, &target.into())
}
pub fn send_to_vectored<'a, 'b>(
&'a mut self,
bufs: &'b mut [IoSlice<'b>],
target: SocketAddr,
) -> SendToVectored<'a, 'b> {
SendToVectored {
socket: self,
bufs,
target: target.into(),
}
}
pub fn try_recv_from<B>(&mut self, mut buf: B) -> io::Result<(usize, SocketAddr)>
where
B: Bytes,
{
debug_assert!(
buf.has_spare_capacity(),
"called `UdpSocket::try_recv_from` with an empty buffer"
);
SockRef::from(&self.socket)
.recv_from(buf.as_bytes())
.and_then(|(read, address)| {
unsafe { buf.update_length(read) }
let address = convert_address(address)?;
Ok((read, address))
})
}
pub fn recv_from<B>(&mut self, buf: B) -> RecvFrom<'_, B>
where
B: Bytes,
{
RecvFrom { socket: self, buf }
}
pub fn try_recv_from_vectored<B>(&mut self, mut bufs: B) -> io::Result<(usize, SocketAddr)>
where
B: BytesVectored,
{
debug_assert!(
bufs.has_spare_capacity(),
"called `UdpSocket::try_recv_from` with empty buffers"
);
let res = SockRef::from(&self.socket)
.recv_from_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()));
match res {
Ok((read, _, address)) => {
unsafe { bufs.update_lengths(read) }
let address = convert_address(address)?;
Ok((read, address))
}
Err(err) => Err(err),
}
}
pub fn recv_from_vectored<B>(&mut self, bufs: B) -> RecvFromVectored<'_, B>
where
B: BytesVectored,
{
RecvFromVectored { socket: self, bufs }
}
pub fn try_peek_from<B>(&mut self, mut buf: B) -> io::Result<(usize, SocketAddr)>
where
B: Bytes,
{
debug_assert!(
buf.has_spare_capacity(),
"called `UdpSocket::try_peek_from` with an empty buffer"
);
SockRef::from(&self.socket)
.peek_from(buf.as_bytes())
.and_then(|(read, address)| {
unsafe { buf.update_length(read) }
let address = convert_address(address)?;
Ok((read, address))
})
}
pub fn peek_from<B>(&mut self, buf: B) -> PeekFrom<'_, B>
where
B: Bytes,
{
PeekFrom { socket: self, buf }
}
pub fn try_peek_from_vectored<B>(&mut self, mut bufs: B) -> io::Result<(usize, SocketAddr)>
where
B: BytesVectored,
{
debug_assert!(
bufs.has_spare_capacity(),
"called `UdpSocket::try_peek_from_vectored` with empty buffers"
);
let res = SockRef::from(&self.socket).recv_from_vectored_with_flags(
MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()),
libc::MSG_PEEK,
);
match res {
Ok((read, _, address)) => {
unsafe { bufs.update_lengths(read) }
let address = convert_address(address)?;
Ok((read, address))
}
Err(err) => Err(err),
}
}
pub fn peek_from_vectored<B>(&mut self, bufs: B) -> PeekFromVectored<'_, B>
where
B: BytesVectored,
{
PeekFromVectored { socket: self, bufs }
}
}
/// The [`Future`] returned by [`UdpSocket::send_to`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendTo<'a, 'b> {
/// Socket the bytes are sent on.
socket: &'a mut UdpSocket<Unconnected>,
/// Bytes to send.
buf: &'b [u8],
/// Address the bytes are sent to.
target: SocketAddr,
}
impl<'a, 'b> Future for SendTo<'a, 'b> {
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
#[rustfmt::skip]
let SendTo { socket, buf, target } = Pin::into_inner(self);
try_io!(socket.try_send_to(buf, *target))
}
}
/// The [`Future`] returned by [`UdpSocket::send_to_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendToVectored<'a, 'b> {
/// Socket the bytes are sent on.
socket: &'a mut UdpSocket<Unconnected>,
/// Buffers holding the bytes to send.
bufs: &'b mut [IoSlice<'b>],
/// Address the bytes are sent to, already converted to a `socket2` address.
target: SockAddr,
}
impl<'a, 'b> Future for SendToVectored<'a, 'b> {
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
#[rustfmt::skip]
let SendToVectored { socket, bufs, target } = Pin::into_inner(self);
try_io!(SockRef::from(&socket.socket).send_to_vectored(bufs, target))
}
}
/// The [`Future`] returned by [`UdpSocket::recv_from`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvFrom<'a, B> {
/// Socket the data is received from.
socket: &'a mut UdpSocket<Unconnected>,
/// Buffer the received bytes are written into.
buf: B,
}
impl<'a, B> Future for RecvFrom<'a, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<(usize, SocketAddr)>;

    /// Attempts the receive via [`UdpSocket::try_recv_from`], lending the
    /// buffer to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_recv_from(&mut this.buf))
    }
}
/// The [`Future`] returned by [`UdpSocket::recv_from_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvFromVectored<'a, B> {
/// Socket the data is received from.
socket: &'a mut UdpSocket<Unconnected>,
/// Buffers the received bytes are written into.
bufs: B,
}
impl<'a, B> Future for RecvFromVectored<'a, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<(usize, SocketAddr)>;

    /// Attempts the receive via [`UdpSocket::try_recv_from_vectored`], lending
    /// the buffers to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_recv_from_vectored(&mut this.bufs))
    }
}
/// The [`Future`] returned by [`UdpSocket::peek_from`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PeekFrom<'a, B> {
/// Socket the data is peeked from.
socket: &'a mut UdpSocket<Unconnected>,
/// Buffer the peeked bytes are written into.
buf: B,
}
impl<'a, B> Future for PeekFrom<'a, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<(usize, SocketAddr)>;

    /// Attempts the peek via [`UdpSocket::try_peek_from`], lending the buffer
    /// to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_peek_from(&mut this.buf))
    }
}
/// The [`Future`] returned by [`UdpSocket::peek_from_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PeekFromVectored<'a, B> {
/// Socket the data is peeked from.
socket: &'a mut UdpSocket<Unconnected>,
/// Buffers the peeked bytes are written into.
bufs: B,
}
impl<'a, B> Future for PeekFromVectored<'a, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<(usize, SocketAddr)>;

    /// Attempts the peek via [`UdpSocket::try_peek_from_vectored`], lending
    /// the buffers to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_peek_from_vectored(&mut this.bufs))
    }
}
impl UdpSocket<Connected> {
pub fn try_send(&mut self, buf: &[u8]) -> io::Result<usize> {
self.socket.send(buf)
}
pub fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> Send<'a, 'b> {
Send { socket: self, buf }
}
pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
SockRef::from(&self.socket).send_vectored(bufs)
}
pub fn send_vectored<'a, 'b>(
&'a mut self,
bufs: &'b mut [IoSlice<'b>],
) -> SendVectored<'a, 'b> {
SendVectored { socket: self, bufs }
}
pub fn try_recv<B>(&mut self, mut buf: B) -> io::Result<usize>
where
B: Bytes,
{
debug_assert!(
buf.has_spare_capacity(),
"called `UdpSocket::try_recv` with an empty buffer"
);
SockRef::from(&self.socket)
.recv(buf.as_bytes())
.map(|read| {
unsafe { buf.update_length(read) }
read
})
}
pub fn recv<B>(&mut self, buf: B) -> Recv<'_, B>
where
B: Bytes,
{
Recv { socket: self, buf }
}
pub fn try_recv_vectored<B>(&mut self, mut bufs: B) -> io::Result<usize>
where
B: BytesVectored,
{
debug_assert!(
bufs.has_spare_capacity(),
"called `UdpSocket::try_recv_vectored` with empty buffers"
);
let res = SockRef::from(&self.socket)
.recv_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()));
match res {
Ok((read, _)) => {
unsafe { bufs.update_lengths(read) }
Ok(read)
}
Err(err) => Err(err),
}
}
pub fn recv_vectored<B>(&mut self, bufs: B) -> RecvVectored<'_, B>
where
B: BytesVectored,
{
RecvVectored { socket: self, bufs }
}
pub fn try_peek<B>(&mut self, mut buf: B) -> io::Result<usize>
where
B: Bytes,
{
debug_assert!(
buf.has_spare_capacity(),
"called `UdpSocket::try_peek` with an empty buffer"
);
SockRef::from(&self.socket)
.peek(buf.as_bytes())
.map(|read| {
unsafe { buf.update_length(read) }
read
})
}
pub fn peek<B>(&mut self, buf: B) -> Peek<'_, B>
where
B: Bytes,
{
Peek { socket: self, buf }
}
pub fn try_peek_vectored<B>(&mut self, mut bufs: B) -> io::Result<usize>
where
B: BytesVectored,
{
debug_assert!(
bufs.has_spare_capacity(),
"called `UdpSocket::try_peek_vectored` with empty buffers"
);
let res = SockRef::from(&self.socket).recv_vectored_with_flags(
MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()),
libc::MSG_PEEK,
);
match res {
Ok((read, _)) => {
unsafe { bufs.update_lengths(read) }
Ok(read)
}
Err(err) => Err(err),
}
}
pub fn peek_vectored<B>(&mut self, bufs: B) -> PeekVectored<'_, B>
where
B: BytesVectored,
{
PeekVectored { socket: self, bufs }
}
}
/// The [`Future`] returned by [`UdpSocket::send`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Send<'a, 'b> {
/// Connected socket the bytes are sent on.
socket: &'a mut UdpSocket<Connected>,
/// Bytes to send.
buf: &'b [u8],
}
impl<'a, 'b> Future for Send<'a, 'b> {
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
let Send { socket, buf } = Pin::into_inner(self);
try_io!(socket.try_send(*buf))
}
}
/// The [`Future`] returned by [`UdpSocket::send_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendVectored<'a, 'b> {
/// Connected socket the bytes are sent on.
socket: &'a mut UdpSocket<Connected>,
/// Buffers holding the bytes to send.
bufs: &'b mut [IoSlice<'b>],
}
impl<'a, 'b> Future for SendVectored<'a, 'b> {
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
let SendVectored { socket, bufs } = Pin::into_inner(self);
try_io!(socket.try_send_vectored(*bufs))
}
}
/// The [`Future`] returned by [`UdpSocket::recv`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Recv<'a, B> {
/// Connected socket the data is received from.
socket: &'a mut UdpSocket<Connected>,
/// Buffer the received bytes are written into.
buf: B,
}
impl<'a, B> Future for Recv<'a, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<usize>;

    /// Attempts the receive via [`UdpSocket::try_recv`], lending the buffer to
    /// it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_recv(&mut this.buf))
    }
}
/// The [`Future`] returned by [`UdpSocket::peek_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PeekVectored<'a, B> {
/// Connected socket the data is peeked from.
socket: &'a mut UdpSocket<Connected>,
/// Buffers the peeked bytes are written into.
bufs: B,
}
impl<'a, B> Future for PeekVectored<'a, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<usize>;

    /// Attempts the peek via [`UdpSocket::try_peek_vectored`], lending the
    /// buffers to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_peek_vectored(&mut this.bufs))
    }
}
/// The [`Future`] returned by [`UdpSocket::peek`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Peek<'a, B> {
/// Connected socket the data is peeked from.
socket: &'a mut UdpSocket<Connected>,
/// Buffer the peeked bytes are written into.
buf: B,
}
impl<'a, B> Future for Peek<'a, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<usize>;

    /// Attempts the peek via [`UdpSocket::try_peek`], lending the buffer to it
    /// by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_peek(&mut this.buf))
    }
}
/// The [`Future`] returned by [`UdpSocket::recv_vectored`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvVectored<'a, B> {
/// Connected socket the data is received from.
socket: &'a mut UdpSocket<Connected>,
/// Buffers the received bytes are written into.
bufs: B,
}
impl<'a, B> Future for RecvVectored<'a, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<usize>;

    /// Attempts the receive via [`UdpSocket::try_recv_vectored`], lending the
    /// buffers to it by mutable reference.
    ///
    /// The `try_io!` macro (defined elsewhere in the crate) turns the
    /// `io::Result` of the attempt into the returned [`Poll`]. The task
    /// context is not used here.
    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        try_io!(this.socket.try_recv_vectored(&mut this.bufs))
    }
}
impl<M> fmt::Debug for UdpSocket<M> {
    /// Formats the socket by delegating to the `Debug` implementation of the
    /// underlying socket; the mode marker is not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.socket, f)
    }
}
impl<M, RT: rt::Access> Bound<RT> for UdpSocket<M> {
    type Error = io::Error;

    /// Re-registers the socket with the runtime of `ctx`, using the same
    /// readable and writable interest as [`UdpSocket::bind`].
    fn bind_to<Msg>(&mut self, ctx: &mut actor::Context<Msg, RT>) -> io::Result<()> {
        let interests = Interest::READABLE | Interest::WRITABLE;
        ctx.runtime().reregister(&mut self.socket, interests)
    }
}