use std::{
future::Future,
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::{
BufferRef, impl_raw_fd,
op::{RecvFlags, RecvFromMultiResult, RecvMsgMultiResult},
};
use futures_util::Stream;
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
use crate::{MSG_NOSIGNAL, Socket, ToSocketAddrsAsync};
/// A UDP socket.
///
/// Wraps the crate-level [`Socket`] in its single `inner` field; the methods on this type
/// forward to that value.
///
/// NOTE(review): the type derives `Clone`; whether a clone shares or duplicates the
/// underlying OS socket depends on `Socket`'s own `Clone` impl — confirm there.
#[derive(Debug, Clone)]
pub struct UdpSocket {
// The crate socket that performs every operation on behalf of this wrapper.
inner: Socket,
}
impl UdpSocket {
/// Creates a UDP socket bound to `addr`.
///
/// For every address that `each_addr` hands to the closure, a datagram socket
/// (`Type::DGRAM`, `Protocol::UDP`) is created in that address's domain and then bound
/// to it.
///
/// NOTE(review): how `each_addr` iterates and which error it reports when several
/// addresses fail is defined by that helper — confirm there.
///
/// # Errors
///
/// Propagates any error from socket creation or from the bind call.
pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
    super::each_addr(addr, |candidate| async move {
        let target = SockAddr::from(candidate);
        let inner = Socket::new(target.domain(), Type::DGRAM, Some(Protocol::UDP)).await?;
        inner.bind(&target).await?;
        Ok(Self { inner })
    })
    .await
}
/// Connects this socket to the remote address `addr`.
///
/// Each address that `each_addr` passes to the closure is converted to a `SockAddr` and
/// given to the inner socket's `connect`, which is called without awaiting.
///
/// # Errors
///
/// Returns whatever error `each_addr` reports for the attempted address(es).
pub async fn connect(&self, addr: impl ToSocketAddrsAsync) -> io::Result<()> {
    super::each_addr(addr, |candidate| async move {
        let target = SockAddr::from(candidate);
        self.inner.connect(&target)
    })
    .await
}
/// Builds a [`UdpSocket`] from a standard-library `std::net::UdpSocket`.
///
/// The std socket is converted into a `socket2` socket and then handed to
/// `Socket::from_socket2`.
///
/// # Errors
///
/// Returns the error produced by `Socket::from_socket2`, if any.
pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
    let raw = Socket2::from(socket);
    let inner = Socket::from_socket2(raw)?;
    Ok(Self { inner })
}
/// Closes the socket, consuming `self`.
///
/// Returns the future produced by the inner socket's `close`; the caller must await it
/// to observe the `io::Result`.
pub fn close(self) -> impl Future<Output = io::Result<()>> {
self.inner.close()
}
/// Returns the address of the remote peer as reported by the inner socket.
///
/// # Errors
///
/// Propagates the error from the inner `peer_addr` call.
///
/// # Panics
///
/// Panics if the reported address cannot be represented as a [`SocketAddr`].
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
    let remote = self.inner.peer_addr()?;
    Ok(remote.as_socket().expect("should be SocketAddr"))
}
/// Returns the local address of this socket as reported by the inner socket.
///
/// # Errors
///
/// Propagates the error from the inner `local_addr` call.
///
/// # Panics
///
/// Panics if the reported address cannot be represented as a [`SocketAddr`].
pub fn local_addr(&self) -> io::Result<SocketAddr> {
    let bound = self.inner.local_addr()?;
    Ok(bound.as_socket().expect("should be SocketAddr"))
}
/// Forwards to the inner socket's `sock_nonempty` and returns its answer unchanged.
///
/// NOTE(review): presumably `Some(true)` means data is still pending on the socket and
/// `None` means the information is unavailable — confirm against `Socket::sock_nonempty`.
pub fn sock_nonempty(&self) -> Option<bool> {
self.inner.sock_nonempty()
}
/// Receives data into `buffer`, using empty receive flags.
///
/// Returns the inner socket's `BufResult`, which carries the buffer back together with
/// the result (the `usize` is presumably the number of bytes received — confirm).
pub async fn recv<T: IoBufMut>(&self, buffer: T) -> BufResult<usize, T> {
self.inner.recv(buffer, RecvFlags::empty()).await
}
/// Receives data into the vectored `buffer`, using empty receive flags.
///
/// Returns the inner socket's `BufResult`, which carries the buffer back together with
/// the result (the `usize` is presumably the number of bytes received — confirm).
pub async fn recv_vectored<T: IoVectoredBufMut>(&self, buffer: T) -> BufResult<usize, T> {
self.inner.recv_vectored(buffer, RecvFlags::empty()).await
}
/// Receives data into a driver-provided [`BufferRef`], using empty receive flags.
///
/// `len` is forwarded unchanged to the inner socket's `recv_managed`.
///
/// NOTE(review): presumably `len` caps the amount read and `Ok(None)` signals that no
/// data or buffer was available — confirm against `Socket::recv_managed`.
pub async fn recv_managed(&self, len: usize) -> io::Result<Option<BufferRef>> {
self.inner.recv_managed(len, RecvFlags::empty()).await
}
/// Returns a stream of received buffers, using empty receive flags.
///
/// `len` is forwarded unchanged to the inner socket's `recv_multi`; each stream item is
/// an `io::Result<BufferRef>`.
///
/// NOTE(review): presumably this is a multishot receive — confirm against
/// `Socket::recv_multi`.
pub fn recv_multi(&self, len: usize) -> impl Stream<Item = io::Result<BufferRef>> {
self.inner.recv_multi(len, RecvFlags::empty())
}
/// Sends the contents of `buffer` on this socket.
///
/// The crate's `MSG_NOSIGNAL` constant is passed as the flags argument. Returns the inner
/// socket's `BufResult`, which hands the buffer back alongside the result (the `usize` is
/// presumably the number of bytes sent — confirm).
pub async fn send<T: IoBuf>(&self, buffer: T) -> BufResult<usize, T> {
self.inner.send(buffer, MSG_NOSIGNAL).await
}
/// Sends the contents of the vectored `buffer` on this socket.
///
/// The crate's `MSG_NOSIGNAL` constant is passed as the flags argument. Returns the inner
/// socket's `BufResult`, which hands the buffer back alongside the result (the `usize` is
/// presumably the number of bytes sent — confirm).
pub async fn send_vectored<T: IoVectoredBuf>(&self, buffer: T) -> BufResult<usize, T> {
self.inner.send_vectored(buffer, MSG_NOSIGNAL).await
}
/// Receives a datagram into `buffer` and reports who sent it.
///
/// Calls the inner socket's `recv_from` with empty receive flags, then converts the
/// optional `SockAddr` in a successful result into a [`SocketAddr`].
///
/// # Panics
///
/// Panics on a successful receive if the inner socket supplied no source address, or if
/// that address cannot be represented as a [`SocketAddr`].
pub async fn recv_from<T: IoBufMut>(&self, buffer: T) -> BufResult<(usize, SocketAddr), T> {
    let received = self.inner.recv_from(buffer, RecvFlags::empty()).await;
    received.map_res(|(len, source)| {
        let source = source.expect("should have addr");
        (len, source.as_socket().expect("should be SocketAddr"))
    })
}
/// Receives a datagram into the vectored `buffer` and reports who sent it.
///
/// Calls the inner socket's `recv_from_vectored` with empty receive flags, then converts
/// the optional `SockAddr` in a successful result into a [`SocketAddr`].
///
/// # Panics
///
/// Panics on a successful receive if the inner socket supplied no source address, or if
/// that address cannot be represented as a [`SocketAddr`].
pub async fn recv_from_vectored<T: IoVectoredBufMut>(
    &self,
    buffer: T,
) -> BufResult<(usize, SocketAddr), T> {
    let received = self
        .inner
        .recv_from_vectored(buffer, RecvFlags::empty())
        .await;
    received.map_res(|(len, source)| {
        let source = source.expect("should have addr");
        (len, source.as_socket().expect("should be SocketAddr"))
    })
}
/// Receives a datagram into a driver-provided [`BufferRef`] and reports who sent it.
///
/// `len` is forwarded unchanged to the inner socket's `recv_from_managed`, which is called
/// with empty receive flags. When the inner call yields `Some((buffer, addr))`, the
/// optional `SockAddr` is converted into a [`SocketAddr`]; `None` is passed through as
/// `Ok(None)`.
///
/// NOTE(review): what `None` signifies (e.g. no buffer available) is defined by
/// `Socket::recv_from_managed` — confirm there.
///
/// # Errors
///
/// Propagates any error from the inner receive call.
///
/// # Panics
///
/// Panics if a received datagram comes without a source address, or if that address
/// cannot be represented as a [`SocketAddr`].
pub async fn recv_from_managed(
    &self,
    len: usize,
) -> io::Result<Option<(BufferRef, SocketAddr)>> {
    let res = self
        .inner
        .recv_from_managed(len, RecvFlags::empty())
        .await?;
    // `Option::map` replaces the hand-written `Some => Some / None => None` match.
    Ok(res.map(|(buffer, addr)| {
        let addr = addr
            .expect("should have addr")
            .as_socket()
            .expect("should be SocketAddr");
        (buffer, addr)
    }))
}
/// Returns a stream of `RecvFromMultiResult` items, using empty receive flags.
///
/// Forwards directly to the inner socket's `recv_from_multi`.
///
/// NOTE(review): presumably each item bundles a received datagram with its source
/// address — confirm against `RecvFromMultiResult` in `compio_driver`.
pub fn recv_from_multi(&self) -> impl Stream<Item = io::Result<RecvFromMultiResult>> {
self.inner.recv_from_multi(RecvFlags::empty())
}
/// Receives a datagram into `buffer` plus ancillary data into `control`, and reports who
/// sent it.
///
/// Calls the inner socket's `recv_msg` with empty receive flags. On success the two
/// `usize` values from the inner call are passed through unchanged and the optional
/// `SockAddr` is converted into a [`SocketAddr`].
///
/// NOTE(review): the two `usize` values are presumably the payload length and the
/// control-data length, in that order — confirm against `Socket::recv_msg`.
///
/// # Panics
///
/// Panics on a successful receive if the inner socket supplied no source address, or if
/// that address cannot be represented as a [`SocketAddr`].
pub async fn recv_msg<T: IoBufMut, C: IoBufMut>(
    &self,
    buffer: T,
    control: C,
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
    let received = self
        .inner
        .recv_msg(buffer, control, RecvFlags::empty())
        .await;
    received.map_res(|(first, second, source)| {
        let source = source.expect("should have addr");
        (
            first,
            second,
            source.as_socket().expect("should be SocketAddr"),
        )
    })
}
/// Receives a datagram into the vectored `buffer` plus ancillary data into `control`, and
/// reports who sent it.
///
/// Calls the inner socket's `recv_msg_vectored` with empty receive flags. On success the
/// two `usize` values from the inner call are passed through unchanged and the optional
/// `SockAddr` is converted into a [`SocketAddr`].
///
/// NOTE(review): the two `usize` values are presumably the payload length and the
/// control-data length, in that order — confirm against `Socket::recv_msg_vectored`.
///
/// # Panics
///
/// Panics on a successful receive if the inner socket supplied no source address, or if
/// that address cannot be represented as a [`SocketAddr`].
pub async fn recv_msg_vectored<T: IoVectoredBufMut, C: IoBufMut>(
    &self,
    buffer: T,
    control: C,
) -> BufResult<(usize, usize, SocketAddr), (T, C)> {
    let received = self
        .inner
        .recv_msg_vectored(buffer, control, RecvFlags::empty())
        .await;
    received.map_res(|(first, second, source)| {
        let source = source.expect("should have addr");
        (
            first,
            second,
            source.as_socket().expect("should be SocketAddr"),
        )
    })
}
/// Receives a datagram into a driver-provided [`BufferRef`] plus ancillary data into
/// `control`, and reports who sent it.
///
/// `len` and `control` are forwarded unchanged to the inner socket's `recv_msg_managed`,
/// which is called with empty receive flags. When the inner call yields
/// `Some((buffer, control, addr))`, the optional `SockAddr` is converted into a
/// [`SocketAddr`]; `None` is passed through as `Ok(None)`.
///
/// NOTE(review): what `None` signifies (e.g. no buffer available) is defined by
/// `Socket::recv_msg_managed` — confirm there.
///
/// # Errors
///
/// Propagates any error from the inner receive call.
///
/// # Panics
///
/// Panics if a received datagram comes without a source address, or if that address
/// cannot be represented as a [`SocketAddr`].
pub async fn recv_msg_managed<C: IoBufMut>(
    &self,
    len: usize,
    control: C,
) -> io::Result<Option<(BufferRef, C, SocketAddr)>> {
    let res = self
        .inner
        .recv_msg_managed(len, control, RecvFlags::empty())
        .await?;
    // `Option::map` replaces the hand-written `Some => Some / None => None` match.
    Ok(res.map(|(buffer, control, addr)| {
        let addr = addr
            .expect("should have addr")
            .as_socket()
            .expect("should be SocketAddr");
        (buffer, control, addr)
    }))
}
/// Returns a stream of `RecvMsgMultiResult` items, using empty receive flags.
///
/// `control_len` is forwarded unchanged to the inner socket's `recv_msg_multi`.
///
/// NOTE(review): presumably `control_len` sizes the ancillary-data area reserved per
/// message — confirm against `Socket::recv_msg_multi`.
pub fn recv_msg_multi(
&self,
control_len: usize,
) -> impl Stream<Item = io::Result<RecvMsgMultiResult>> {
self.inner.recv_msg_multi(control_len, RecvFlags::empty())
}
/// Sends the contents of `buffer` to `addr`.
///
/// Address resolution and buffer hand-back are handled by `first_addr_buf`; the closure
/// converts the address it is given to a `SockAddr` and calls the inner socket's
/// `send_to` with the crate's `MSG_NOSIGNAL` flags.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used —
/// confirm against `first_addr_buf`.
pub async fn send_to<T: IoBuf>(
    &self,
    buffer: T,
    addr: impl ToSocketAddrsAsync,
) -> BufResult<usize, T> {
    super::first_addr_buf(addr, buffer, |dest, payload| async move {
        let target = SockAddr::from(dest);
        self.inner.send_to(payload, &target, MSG_NOSIGNAL).await
    })
    .await
}
/// Sends the contents of the vectored `buffer` to `addr`.
///
/// Address resolution and buffer hand-back are handled by `first_addr_buf`; the closure
/// converts the address it is given to a `SockAddr` and calls the inner socket's
/// `send_to_vectored` with the crate's `MSG_NOSIGNAL` flags.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used —
/// confirm against `first_addr_buf`.
pub async fn send_to_vectored<T: IoVectoredBuf>(
    &self,
    buffer: T,
    addr: impl ToSocketAddrsAsync,
) -> BufResult<usize, T> {
    super::first_addr_buf(addr, buffer, |dest, payload| async move {
        let target = SockAddr::from(dest);
        self.inner
            .send_to_vectored(payload, &target, MSG_NOSIGNAL)
            .await
    })
    .await
}
/// Sends the contents of `buffer` together with ancillary data `control` to `addr`.
///
/// `first_addr_buf` receives the address and the `(buffer, control)` pair; the closure
/// converts the address it is given to a `SockAddr` and calls the inner socket's
/// `send_msg` with `Some(&address)` and the crate's `MSG_NOSIGNAL` flags.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used —
/// confirm against `first_addr_buf`.
pub async fn send_msg<T: IoBuf, C: IoBuf>(
    &self,
    buffer: T,
    control: C,
    addr: impl ToSocketAddrsAsync,
) -> BufResult<usize, (T, C)> {
    let buffers = (buffer, control);
    super::first_addr_buf(addr, buffers, |dest, (payload, ancillary)| async move {
        let target = SockAddr::from(dest);
        self.inner
            .send_msg(payload, ancillary, Some(&target), MSG_NOSIGNAL)
            .await
    })
    .await
}
/// Sends the contents of the vectored `buffer` together with ancillary data `control` to
/// `addr`.
///
/// `first_addr_buf` receives the address and the `(buffer, control)` pair; the closure
/// converts the address it is given to a `SockAddr` and calls the inner socket's
/// `send_msg_vectored` with `Some(&address)` and the crate's `MSG_NOSIGNAL` flags.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used —
/// confirm against `first_addr_buf`.
pub async fn send_msg_vectored<T: IoVectoredBuf, C: IoBuf>(
    &self,
    buffer: T,
    control: C,
    addr: impl ToSocketAddrsAsync,
) -> BufResult<usize, (T, C)> {
    let buffers = (buffer, control);
    super::first_addr_buf(addr, buffers, |dest, (payload, ancillary)| async move {
        let target = SockAddr::from(dest);
        self.inner
            .send_msg_vectored(payload, ancillary, Some(&target), MSG_NOSIGNAL)
            .await
    })
    .await
}
/// Sends the contents of `buf` through the inner socket's `send_zerocopy`, passing the
/// crate's `MSG_NOSIGNAL` flags.
///
/// Instead of returning the buffer directly, the `BufResult` carries a future whose
/// output is the buffer `T`.
///
/// NOTE(review): presumably that future resolves once the kernel no longer references
/// the buffer — confirm against `Socket::send_zerocopy`.
pub async fn send_zerocopy<T: IoBuf>(
&self,
buf: T,
) -> BufResult<usize, impl Future<Output = T> + use<T>> {
self.inner.send_zerocopy(buf, MSG_NOSIGNAL).await
}
/// Sends the contents of the vectored `buf` through the inner socket's
/// `send_zerocopy_vectored`, passing the crate's `MSG_NOSIGNAL` flags.
///
/// Instead of returning the buffer directly, the `BufResult` carries a future whose
/// output is the buffer `T`.
///
/// NOTE(review): presumably that future resolves once the kernel no longer references
/// the buffer — confirm against `Socket::send_zerocopy_vectored`.
pub async fn send_zerocopy_vectored<T: IoVectoredBuf>(
&self,
buf: T,
) -> BufResult<usize, impl Future<Output = T> + use<T>> {
self.inner.send_zerocopy_vectored(buf, MSG_NOSIGNAL).await
}
/// Sends the contents of `buffer` to `addr` through the inner socket's
/// `send_to_zerocopy`, passing the crate's `MSG_NOSIGNAL` flags.
///
/// Address resolution is delegated to `first_addr_buf_zerocopy`. The `BufResult` carries
/// a future whose output is the buffer `T` rather than the buffer itself.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used,
/// and the returned future presumably resolves once the kernel has released the buffer —
/// confirm both against the helper and `Socket::send_to_zerocopy`.
pub async fn send_to_zerocopy<A: ToSocketAddrsAsync, T: IoBuf>(
&self,
buffer: T,
addr: A,
) -> BufResult<usize, impl Future<Output = T> + use<A, T>> {
super::first_addr_buf_zerocopy(addr, buffer, |addr, buffer| async move {
self.inner
// The address is converted with `into()`; the target type is inferred from the callee.
.send_to_zerocopy(buffer, &addr.into(), MSG_NOSIGNAL)
.await
})
.await
}
/// Sends the contents of the vectored `buffer` to `addr` through the inner socket's
/// `send_to_zerocopy_vectored`, passing the crate's `MSG_NOSIGNAL` flags.
///
/// Address resolution is delegated to `first_addr_buf_zerocopy`. The `BufResult` carries
/// a future whose output is the buffer `T` rather than the buffer itself.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used,
/// and the returned future presumably resolves once the kernel has released the buffer —
/// confirm both against the helper and `Socket::send_to_zerocopy_vectored`.
pub async fn send_to_zerocopy_vectored<A: ToSocketAddrsAsync, T: IoVectoredBuf>(
&self,
buffer: T,
addr: A,
) -> BufResult<usize, impl Future<Output = T> + use<A, T>> {
super::first_addr_buf_zerocopy(addr, buffer, |addr, buffer| async move {
self.inner
// The address is converted with `into()`; the target type is inferred from the callee.
.send_to_zerocopy_vectored(buffer, &addr.into(), MSG_NOSIGNAL)
.await
})
.await
}
/// Sends the contents of `buffer` together with ancillary data `control` to `addr`
/// through the inner socket's `send_msg_zerocopy`, passing the crate's `MSG_NOSIGNAL`
/// flags.
///
/// Address resolution is delegated to `first_addr_buf_zerocopy`, which is given the
/// `(buffer, control)` pair. The `BufResult` carries a future whose output is that pair
/// rather than the pair itself.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used,
/// and the returned future presumably resolves once the kernel has released the buffers —
/// confirm both against the helper and `Socket::send_msg_zerocopy`.
pub async fn send_msg_zerocopy<A: ToSocketAddrsAsync, T: IoBuf, C: IoBuf>(
&self,
buffer: T,
control: C,
addr: A,
) -> BufResult<usize, impl Future<Output = (T, C)> + use<A, T, C>> {
// `b` is the payload buffer and `c` the control buffer, destructured from the pair.
super::first_addr_buf_zerocopy(addr, (buffer, control), |addr, (b, c)| async move {
self.inner
.send_msg_zerocopy(b, c, Some(&addr.into()), MSG_NOSIGNAL)
.await
})
.await
}
/// Sends the contents of the vectored `buffer` together with ancillary data `control` to
/// `addr` through the inner socket's `send_msg_zerocopy_vectored`, passing the crate's
/// `MSG_NOSIGNAL` flags.
///
/// Address resolution is delegated to `first_addr_buf_zerocopy`, which is given the
/// `(buffer, control)` pair. The `BufResult` carries a future whose output is that pair
/// rather than the pair itself.
///
/// NOTE(review): judging by the helper's name, only the first resolved address is used,
/// and the returned future presumably resolves once the kernel has released the buffers —
/// confirm both against the helper and `Socket::send_msg_zerocopy_vectored`.
pub async fn send_msg_zerocopy_vectored<A: ToSocketAddrsAsync, T: IoVectoredBuf, C: IoBuf>(
&self,
buffer: T,
control: C,
addr: A,
) -> BufResult<usize, impl Future<Output = (T, C)> + use<A, T, C>> {
// `b` is the payload buffer and `c` the control buffer, destructured from the pair.
super::first_addr_buf_zerocopy(addr, (buffer, control), |addr, (b, c)| async move {
self.inner
.send_msg_zerocopy_vectored(b, c, Some(&addr.into()), MSG_NOSIGNAL)
.await
})
.await
}
/// Reads this socket's broadcast setting.
///
/// Forwards to `broadcast()` on the `socket` field of the inner socket.
/// NOTE(review): presumably this is the `SO_BROADCAST` option — confirm in the socket2 docs.
pub fn broadcast(&self) -> io::Result<bool> {
self.inner.socket.broadcast()
}
/// Sets this socket's broadcast setting to `on`.
///
/// Forwards to `set_broadcast(on)` on the `socket` field of the inner socket.
/// NOTE(review): presumably this is the `SO_BROADCAST` option — confirm in the socket2 docs.
pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
self.inner.socket.set_broadcast(on)
}
/// Reads the IPv4 multicast-loop setting.
///
/// Forwards to `multicast_loop_v4()` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_MULTICAST_LOOP` — confirm in the socket2 docs.
pub fn multicast_loop_v4(&self) -> io::Result<bool> {
self.inner.socket.multicast_loop_v4()
}
/// Sets the IPv4 multicast-loop setting to `on`.
///
/// Forwards to `set_multicast_loop_v4(on)` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_MULTICAST_LOOP` — confirm in the socket2 docs.
pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
self.inner.socket.set_multicast_loop_v4(on)
}
/// Reads the IPv4 multicast TTL setting.
///
/// Forwards to `multicast_ttl_v4()` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_MULTICAST_TTL` — confirm in the socket2 docs.
pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
self.inner.socket.multicast_ttl_v4()
}
/// Sets the IPv4 multicast TTL setting to `ttl`.
///
/// Forwards to `set_multicast_ttl_v4(ttl)` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_MULTICAST_TTL` — confirm in the socket2 docs.
pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
self.inner.socket.set_multicast_ttl_v4(ttl)
}
/// Reads the IPv6 multicast-loop setting.
///
/// Forwards to `multicast_loop_v6()` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IPV6_MULTICAST_LOOP` — confirm in the socket2 docs.
pub fn multicast_loop_v6(&self) -> io::Result<bool> {
self.inner.socket.multicast_loop_v6()
}
/// Sets the IPv6 multicast-loop setting to `on`.
///
/// Forwards to `set_multicast_loop_v6(on)` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IPV6_MULTICAST_LOOP` — confirm in the socket2 docs.
pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
self.inner.socket.set_multicast_loop_v6(on)
}
/// Reads the IPv6 traffic-class setting.
///
/// Forwards to `tclass_v6()` on the `socket` field of the inner socket. Only compiled on
/// the target operating systems listed in the `cfg` attribute below.
/// NOTE(review): presumably `IPV6_TCLASS` — confirm in the socket2 docs.
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
target_os = "cygwin",
))]
pub fn tclass_v6(&self) -> io::Result<u32> {
self.inner.socket.tclass_v6()
}
/// Sets the IPv6 traffic-class setting to `tclass`.
///
/// Forwards to `set_tclass_v6(tclass)` on the `socket` field of the inner socket. Only
/// compiled on the target operating systems listed in the `cfg` attribute below.
/// NOTE(review): presumably `IPV6_TCLASS` — confirm in the socket2 docs.
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
target_os = "cygwin",
))]
pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
self.inner.socket.set_tclass_v6(tclass)
}
/// Reads the IPv4 TTL setting.
///
/// Forwards to `ttl_v4()` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_TTL` — confirm in the socket2 docs.
pub fn ttl_v4(&self) -> io::Result<u32> {
self.inner.socket.ttl_v4()
}
/// Sets the IPv4 TTL setting to `ttl`.
///
/// Forwards to `set_ttl_v4(ttl)` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_TTL` — confirm in the socket2 docs.
pub fn set_ttl_v4(&self, ttl: u32) -> io::Result<()> {
self.inner.socket.set_ttl_v4(ttl)
}
/// Reads the IPv4 type-of-service setting.
///
/// Forwards to `tos_v4()` on the `socket` field of the inner socket. Not compiled on the
/// target operating systems excluded by the `cfg` attribute below.
/// NOTE(review): presumably `IP_TOS` — confirm in the socket2 docs.
#[cfg(not(any(
target_os = "fuchsia",
target_os = "redox",
target_os = "solaris",
target_os = "illumos",
target_os = "haiku"
)))]
pub fn tos_v4(&self) -> io::Result<u32> {
self.inner.socket.tos_v4()
}
/// Sets the IPv4 type-of-service setting to `tos`.
///
/// Forwards to `set_tos_v4(tos)` on the `socket` field of the inner socket. Not compiled
/// on the target operating systems excluded by the `cfg` attribute below.
/// NOTE(review): presumably `IP_TOS` — confirm in the socket2 docs.
#[cfg(not(any(
target_os = "fuchsia",
target_os = "redox",
target_os = "solaris",
target_os = "illumos",
target_os = "haiku"
)))]
pub fn set_tos_v4(&self, tos: u32) -> io::Result<()> {
self.inner.socket.set_tos_v4(tos)
}
/// Returns the device this socket is bound to, as raw bytes, if any.
///
/// Forwards to `device()` on the `socket` field of the inner socket. Only compiled for
/// Android, Fuchsia and Linux targets.
/// NOTE(review): presumably `SO_BINDTODEVICE`, with `None` meaning "not bound to a
/// device" — confirm in the socket2 docs.
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux",))]
pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
self.inner.socket.device()
}
/// Binds this socket to the network device named by `interface`.
///
/// Forwards to `bind_device(interface)` on the `socket` field of the inner socket. Only
/// compiled for Android, Fuchsia and Linux targets.
/// NOTE(review): presumably `SO_BINDTODEVICE`, with `None` removing an existing binding —
/// confirm in the socket2 docs.
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
self.inner.socket.bind_device(interface)
}
/// Joins the IPv4 multicast group `multiaddr` on the local interface address `interface`.
///
/// Forwards to `join_multicast_v4` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_ADD_MEMBERSHIP` — confirm in the socket2 docs.
pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.inner.socket.join_multicast_v4(multiaddr, interface)
}
/// Joins the IPv6 multicast group `multiaddr` on the interface numbered `interface`.
///
/// Forwards to `join_multicast_v6` on the `socket` field of the inner socket.
/// NOTE(review): `interface` is presumably an interface index, and the option presumably
/// `IPV6_ADD_MEMBERSHIP` — confirm in the socket2 docs.
pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.inner.socket.join_multicast_v6(multiaddr, interface)
}
/// Leaves the IPv4 multicast group `multiaddr` on the local interface address `interface`.
///
/// Forwards to `leave_multicast_v4` on the `socket` field of the inner socket.
/// NOTE(review): presumably `IP_DROP_MEMBERSHIP` — confirm in the socket2 docs.
pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.inner.socket.leave_multicast_v4(multiaddr, interface)
}
/// Leaves the IPv6 multicast group `multiaddr` on the interface numbered `interface`.
///
/// Forwards to `leave_multicast_v6` on the `socket` field of the inner socket.
/// NOTE(review): `interface` is presumably an interface index, and the option presumably
/// `IPV6_DROP_MEMBERSHIP` — confirm in the socket2 docs.
pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.inner.socket.leave_multicast_v6(multiaddr, interface)
}
/// Retrieves the pending error on this socket, if any.
///
/// Forwards to `take_error()` on the `socket` field of the inner socket. The outer
/// `io::Result` reports whether the query itself succeeded; the inner `Option` is the
/// retrieved error.
/// NOTE(review): presumably `SO_ERROR`, which also clears the pending error — confirm in
/// the socket2 docs.
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
self.inner.socket.take_error()
}
/// Reads a raw socket option identified by `level` and `name` as a value of type `T`.
///
/// Forwards directly to the inner socket's `get_socket_option`.
///
/// # Safety
///
/// The requirements are those of the inner `Socket::get_socket_option`, which is not
/// visible here. NOTE(review): presumably `T` must match the representation the OS uses
/// for this option — confirm against that method's documentation.
pub unsafe fn get_socket_option<T: Copy>(&self, level: i32, name: i32) -> io::Result<T> {
// SAFETY: the caller of this `unsafe fn` is responsible for upholding the inner method's contract.
unsafe { self.inner.get_socket_option(level, name) }
}
/// Writes `value` to the raw socket option identified by `level` and `name`.
///
/// Forwards directly to the inner socket's `set_socket_option`.
///
/// # Safety
///
/// The requirements are those of the inner `Socket::set_socket_option`, which is not
/// visible here. NOTE(review): presumably `T` must match the representation the OS
/// expects for this option — confirm against that method's documentation.
pub unsafe fn set_socket_option<T: Copy>(
&self,
level: i32,
name: i32,
value: &T,
) -> io::Result<()> {
// SAFETY: the caller of this `unsafe fn` is responsible for upholding the inner method's contract.
unsafe { self.inner.set_socket_option(level, name, value) }
}
}
// Invokes compio_driver's `impl_raw_fd!` for `UdpSocket`, naming `socket2::Socket` as the
// underlying type reached through the `inner` field and then its `socket` field.
// NOTE(review): presumably this generates the raw fd / raw socket conversion trait impls —
// confirm against the macro definition.
impl_raw_fd!(UdpSocket, socket2::Socket, inner, socket);