use async_trait::async_trait;
use std::net::SocketAddr;
/// Produces bound UDP sockets for one particular async runtime.
///
/// Each implementation pairs a runtime-specific [`UdpSocket`] type with the
/// error type that runtime reports for I/O failures.
#[async_trait]
pub trait UdpSocketFactory: Sized {
    /// Error reported when a socket cannot be bound.
    type Error: std::error::Error;
    /// Concrete socket type handed back by [`UdpSocketFactory::bind`].
    type Socket: UdpSocket;

    /// Binds a fresh socket to the local address `addr`.
    async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error>;
}
/// Runtime-agnostic view of an async UDP socket.
///
/// Every method mirrors the identically named operation on the underlying
/// runtime's socket type and reports failures through [`UdpSocket::Error`].
#[async_trait]
pub trait UdpSocket: Sized {
    /// Error reported by the underlying runtime for socket operations.
    type Error: std::error::Error;

    /// Allows this socket to send datagrams to broadcast addresses.
    async fn enable_broadcast(&mut self) -> Result<(), Self::Error>;

    /// Sets the default remote peer used by [`UdpSocket::send`] and
    /// [`UdpSocket::recv`].
    async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error>;

    /// Sends `buf` to the connected peer, returning the number of bytes sent.
    async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;

    /// Sends `buf` to `addr`, returning the number of bytes sent.
    async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error>;

    /// Receives one datagram from the connected peer into `buf`, returning the
    /// number of bytes written.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;

    /// Receives one datagram into `buf`, returning the number of bytes written
    /// together with the sender's address.
    async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>;
}
/// Socket factory to use when the caller has no runtime preference.
///
/// When both the `tokio` and `async-std` features are enabled, tokio is chosen.
#[cfg(feature = "tokio")]
pub type DefaultSocketFactory = TokioSocketFactory;

/// Socket factory to use when the caller has no runtime preference.
///
/// Selected only when `async-std` is enabled and `tokio` is not.
#[cfg(all(not(feature = "tokio"), feature = "async-std"))]
pub type DefaultSocketFactory = AsyncStdSocketFactory;
/// [`UdpSocketFactory`] that binds `tokio::net::UdpSocket`s.
///
/// Stateless: construct with [`TokioSocketFactory::new`] or `Default`.
#[cfg(feature = "tokio")]
#[derive(Debug, Default, Clone, Copy)]
pub struct TokioSocketFactory;

#[cfg(feature = "tokio")]
impl TokioSocketFactory {
    /// Creates a new (stateless) tokio socket factory.
    pub fn new() -> TokioSocketFactory {
        TokioSocketFactory
    }
}
/// Binds sockets through tokio's networking stack.
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocketFactory for TokioSocketFactory {
    type Socket = tokio::net::UdpSocket;
    type Error = tokio::io::Error;

    /// Binds a tokio UDP socket to `addr`, forwarding tokio's I/O error unchanged.
    async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error> {
        let socket = tokio::net::UdpSocket::bind(addr).await?;
        Ok(socket)
    }
}
/// Adapts tokio's UDP socket to the runtime-agnostic [`UdpSocket`] trait.
///
/// Every method forwards to tokio's inherent method of the same name. Calls are
/// written as `tokio::net::UdpSocket::name(self, ..)` on purpose: path syntax
/// resolves inherent methods first. Method-call syntax (`self.name(..)`) on a
/// `&mut Self` receiver could instead select this trait's `&mut self` method
/// and recurse forever.
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    type Error = tokio::io::Error;

    async fn enable_broadcast(&mut self) -> Result<(), Self::Error> {
        // tokio's `set_broadcast` is synchronous, so there is nothing to await.
        tokio::net::UdpSocket::set_broadcast(self, true)
    }

    async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error> {
        tokio::net::UdpSocket::connect(self, addr).await
    }

    async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let written = tokio::net::UdpSocket::send(self, buf).await?;
        Ok(written)
    }

    async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
        let written = tokio::net::UdpSocket::send_to(self, buf, addr).await?;
        Ok(written)
    }

    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let read = tokio::net::UdpSocket::recv(self, buf).await?;
        Ok(read)
    }

    async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
        let (read, peer) = tokio::net::UdpSocket::recv_from(self, buf).await?;
        Ok((read, peer))
    }
}
/// [`UdpSocketFactory`] that binds `async_std::net::UdpSocket`s.
///
/// Stateless: construct with [`AsyncStdSocketFactory::new`] or `Default`.
#[cfg(feature = "async-std")]
#[derive(Debug, Default, Clone, Copy)]
pub struct AsyncStdSocketFactory;

#[cfg(feature = "async-std")]
impl AsyncStdSocketFactory {
    /// Creates a new (stateless) async-std socket factory.
    pub fn new() -> AsyncStdSocketFactory {
        AsyncStdSocketFactory
    }
}
/// Binds sockets through async-std's networking stack.
#[cfg(feature = "async-std")]
#[async_trait]
impl UdpSocketFactory for AsyncStdSocketFactory {
    type Socket = async_std::net::UdpSocket;
    type Error = async_std::io::Error;

    /// Binds an async-std UDP socket to `addr`, forwarding async-std's I/O error unchanged.
    async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error> {
        let socket = async_std::net::UdpSocket::bind(addr).await?;
        Ok(socket)
    }
}
/// Adapts async-std's UDP socket to the runtime-agnostic [`UdpSocket`] trait.
///
/// Every method forwards to async-std's inherent method of the same name via
/// `Self::name(self, ..)` path syntax. Path syntax resolves inherent methods
/// first, whereas `self.name(..)` could pick this trait's `&mut self` method
/// and recurse.
#[cfg(feature = "async-std")]
#[async_trait]
impl UdpSocket for async_std::net::UdpSocket {
    type Error = async_std::io::Error;

    async fn enable_broadcast(&mut self) -> Result<(), Self::Error> {
        // async-std's `set_broadcast` is synchronous, so there is nothing to await.
        Self::set_broadcast(self, true)
    }

    async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error> {
        Self::connect(self, addr).await
    }

    async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        Self::send(self, buf).await
    }

    async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
        Self::send_to(self, buf, addr).await
    }

    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        Self::recv(self, buf).await
    }

    async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
        // async-std already reports the sender as a `std::net::SocketAddr`
        // (`async_std::net::SocketAddr` is a re-export), so it is returned as-is.
        // The previous `to_socket_addrs().unwrap().next().unwrap()` round-trip
        // was an identity conversion that only added two panic paths.
        Self::recv_from(self, buf).await
    }
}