use std::io::Error as IoError;
#[cfg(any(feature = "async-std", all(target_os = "linux", feature = "server", feature = "tokio")))]
use std::net::UdpSocket as StdUdpSocket;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use cfg_if::cfg_if;
#[cfg(feature = "server")]
use log::{debug, trace};
use thiserror::Error;
use crate::bytes::{ByteBuffer, ByteBufferMut, DynamicByteBuffer};
cfg_if! {
if #[cfg(feature = "tokio")] {
use tokio::net::UdpSocket as TokioSocket;
} else if #[cfg(feature = "async-std")] {
use async_io::Async;
}
}
cfg_if! {
if #[cfg(all(target_os = "linux", feature = "server"))] {
use std::io::ErrorKind;
use socket2::{Domain, Protocol, Socket as S2Socket, Type};
}
}
#[derive(Error, Debug)]
#[error("asynchronous socket IO error: {}", source.to_string())]
pub struct SocketError {
source: IoError,
}
impl SocketError {
#[inline]
fn new_socket_error(source: IoError) -> Self {
SocketError {
source,
}
}
}
/// UDP socket wrapper whose inner socket type is chosen by the enabled
/// runtime feature (`tokio` or `async-std`).
///
/// NOTE(review): if both `tokio` and `async-std` are enabled at once, both
/// `sock` fields below are compiled in, which is a duplicate field name —
/// presumably the two features are meant to be mutually exclusive; confirm.
pub struct Socket {
/// Tokio UDP socket; present when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
sock: TokioSocket,
/// Standard-library UDP socket wrapped in `async_io::Async`; present when
/// the `async-std` feature is enabled.
#[cfg(feature = "async-std")]
sock: Async<StdUdpSocket>,
}
impl Socket {
    /// Returns the wildcard local address (unspecified IP, port 0) that has
    /// the same address family as `peer`.
    ///
    /// The implicit local address must follow the peer's family: a socket
    /// bound to IPv4 `0.0.0.0` cannot be connected to an IPv6 peer, so always
    /// defaulting to IPv4 made every IPv6 connection fail.
    ///
    /// Deliberately not feature-gated: it only depends on `std`.
    fn default_local_for(peer: &SocketAddr) -> SocketAddr {
        match peer {
            SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
            // Fully qualified so no additional file-level import is required.
            SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
        }
    }

    /// Creates a UDP socket connected to `peer` (tokio backend).
    ///
    /// Binds to `local` when it is given; otherwise binds to the wildcard
    /// address of `peer`'s address family with port 0.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if binding or connecting fails.
    #[cfg(feature = "tokio")]
    pub async fn new(peer: SocketAddr, local: Option<SocketAddr>) -> Result<Self, SocketError> {
        let local_addr = local.unwrap_or_else(|| Self::default_local_for(&peer));
        let sock = TokioSocket::bind(local_addr).await.map_err(SocketError::new_socket_error)?;
        sock.connect(peer).await.map_err(SocketError::new_socket_error)?;
        Ok(Self { sock })
    }

    /// Creates a UDP socket connected to `peer` (async-std backend).
    ///
    /// Binds to `local` when it is given; otherwise binds to the wildcard
    /// address of `peer`'s address family with port 0. Binding and connecting
    /// are done on the standard-library socket before it is wrapped in
    /// `Async`.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if binding, connecting or wrapping the socket
    /// fails.
    #[cfg(feature = "async-std")]
    pub async fn new(peer: SocketAddr, local: Option<SocketAddr>) -> Result<Self, SocketError> {
        let local_addr = local.unwrap_or_else(|| Self::default_local_for(&peer));
        let sock = StdUdpSocket::bind(local_addr).map_err(SocketError::new_socket_error)?;
        sock.connect(peer).map_err(SocketError::new_socket_error)?;
        Ok(Self {
            sock: Async::new(sock).map_err(SocketError::new_socket_error)?,
        })
    }

    /// Creates an unconnected UDP socket bound to `local` (tokio backend).
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if binding fails.
    #[cfg(all(feature = "tokio", feature = "server"))]
    pub async fn bind(local: SocketAddr) -> Result<Self, SocketError> {
        let sock = TokioSocket::bind(local).await.map_err(SocketError::new_socket_error)?;
        Ok(Self { sock })
    }

    /// Creates an unconnected UDP socket bound to `local` (async-std backend).
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if binding or wrapping the socket fails.
    #[cfg(all(feature = "async-std", feature = "server"))]
    pub async fn bind(local: SocketAddr) -> Result<Self, SocketError> {
        let sock = StdUdpSocket::bind(local).map_err(SocketError::new_socket_error)?;
        Ok(Self {
            sock: Async::new(sock).map_err(SocketError::new_socket_error)?,
        })
    }

    /// Opens `count` UDP sockets that are all bound to `local` with
    /// `SO_REUSEPORT` set (tokio backend, Linux only).
    ///
    /// Returns an empty vector when `count` is 0.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] with `ErrorKind::InvalidInput` when
    /// `local.port()` is 0, or with the underlying I/O error as soon as any
    /// socket fails to be created, configured, bound or registered; sockets
    /// opened before the failure are dropped.
    ///
    /// # Panics
    ///
    /// NOTE(review): tokio documents that `UdpSocket::from_std` panics when it
    /// is not called from within a runtime with IO enabled — this function is
    /// synchronous, so callers need to be inside a tokio runtime; confirm.
    #[cfg(all(target_os = "linux", feature = "server", feature = "tokio"))]
    pub fn bind_reuse_port(local: SocketAddr, count: usize) -> Result<Vec<Self>, SocketError> {
        if local.port() == 0 {
            return Err(SocketError::new_socket_error(IoError::new(ErrorKind::InvalidInput, "SO_REUSEPORT requires port > 0")));
        }
        let domain = if local.is_ipv4() {
            Domain::IPV4
        } else {
            Domain::IPV6
        };
        let mut sockets = Vec::with_capacity(count);
        for _ in 0..count {
            let s2 = S2Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)).map_err(SocketError::new_socket_error)?;
            // The option has to be set before `bind` on every socket sharing the port.
            s2.set_reuse_port(true).map_err(SocketError::new_socket_error)?;
            s2.bind(&local.into()).map_err(SocketError::new_socket_error)?;
            // Put the socket in non-blocking mode before handing it to tokio.
            s2.set_nonblocking(true).map_err(SocketError::new_socket_error)?;
            let std_sock: StdUdpSocket = s2.into();
            let tok_sock = TokioSocket::from_std(std_sock).map_err(SocketError::new_socket_error)?;
            sockets.push(Socket { sock: tok_sock });
        }
        Ok(sockets)
    }

    /// Opens `count` UDP sockets that are all bound to `local` with
    /// `SO_REUSEPORT` set (async-std backend, Linux only).
    ///
    /// Returns an empty vector when `count` is 0.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] with `ErrorKind::InvalidInput` when
    /// `local.port()` is 0, or with the underlying I/O error as soon as any
    /// socket fails to be created, configured, bound or wrapped; sockets
    /// opened before the failure are dropped.
    #[cfg(all(target_os = "linux", feature = "server", feature = "async-std"))]
    pub fn bind_reuse_port(local: SocketAddr, count: usize) -> Result<Vec<Self>, SocketError> {
        if local.port() == 0 {
            return Err(SocketError::new_socket_error(IoError::new(ErrorKind::InvalidInput, "SO_REUSEPORT requires port > 0")));
        }
        let domain = if local.is_ipv4() {
            Domain::IPV4
        } else {
            Domain::IPV6
        };
        let mut sockets = Vec::with_capacity(count);
        for _ in 0..count {
            let s2 = S2Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)).map_err(SocketError::new_socket_error)?;
            // The option has to be set before `bind` on every socket sharing the port.
            s2.set_reuse_port(true).map_err(SocketError::new_socket_error)?;
            s2.bind(&local.into()).map_err(SocketError::new_socket_error)?;
            let std_sock: StdUdpSocket = s2.into();
            // NOTE(review): unlike the tokio variant there is no explicit
            // `set_nonblocking` call; presumably `Async::new` switches the
            // socket to non-blocking mode itself — confirm against async-io.
            sockets.push(Socket {
                sock: Async::new(std_sock).map_err(SocketError::new_socket_error)?,
            });
        }
        Ok(sockets)
    }

    /// Sends the bytes of `data.slice()` on the socket and returns the number
    /// of bytes reported as sent.
    ///
    /// The body is the same for both runtimes, so a single definition serves
    /// the `tokio` and the `async-std` builds.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying send fails.
    #[cfg(any(feature = "tokio", feature = "async-std"))]
    pub async fn send(&self, data: DynamicByteBuffer) -> Result<usize, SocketError> {
        self.sock.send(data.slice()).await.map_err(SocketError::new_socket_error)
    }

    /// Receives one datagram into `buf.slice_mut()` and returns
    /// `buf.rebuffer_end(n)`, where `n` is the byte count reported by the
    /// socket.
    ///
    /// NOTE(review): presumably `rebuffer_end` trims the buffer to the
    /// received length — confirm against `DynamicByteBuffer`.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying receive fails.
    #[cfg(any(feature = "tokio", feature = "async-std"))]
    pub async fn recv(&self, buf: DynamicByteBuffer) -> Result<DynamicByteBuffer, SocketError> {
        let res = self.sock.recv(buf.slice_mut()).await.map_err(SocketError::new_socket_error)?;
        Ok(buf.rebuffer_end(res))
    }

    /// Sends the bytes of `data.slice()` to `target` and returns the number of
    /// bytes reported as sent.
    ///
    /// Logs the outcome: `trace` on success, `debug` on error and
    /// additionally `debug` when fewer bytes than the payload length were
    /// reported as sent.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying send fails.
    #[cfg(all(feature = "server", any(feature = "tokio", feature = "async-std")))]
    pub async fn send_to(&self, data: DynamicByteBuffer, target: SocketAddr) -> Result<usize, SocketError> {
        let len = data.slice().len();
        match self.sock.send_to(data.slice(), target).await {
            Ok(sent) => {
                if sent < len {
                    debug!("socket: send_to partial write: {sent} of {len} bytes sent to {target}");
                }
                trace!("socket: send_to {len} bytes to {target} → ok ({sent} sent)");
                Ok(sent)
            }
            Err(e) => {
                debug!("socket: send_to {len} bytes to {target} → error: {e}");
                Err(SocketError::new_socket_error(e))
            }
        }
    }

    /// Receives one datagram into `buf.slice_mut()` and returns
    /// `buf.rebuffer_end(n)` together with the sender's address, where `n` is
    /// the byte count reported by the socket.
    ///
    /// Logs the outcome: `trace` on success, `debug` on error.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError`] if the underlying receive fails.
    #[cfg(all(feature = "server", any(feature = "tokio", feature = "async-std")))]
    pub async fn recv_from(&self, buf: DynamicByteBuffer) -> Result<(DynamicByteBuffer, SocketAddr), SocketError> {
        match self.sock.recv_from(buf.slice_mut()).await {
            Ok((res, addr)) => {
                trace!("socket: recv_from {res} bytes from {addr}");
                Ok((buf.rebuffer_end(res), addr))
            }
            Err(e) => {
                debug!("socket: recv_from error: {e}");
                Err(SocketError::new_socket_error(e))
            }
        }
    }
}