use std::io;
use std::net::SocketAddr;
use std::task::{Context, Poll};
use std::boxed::Box;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::ReadBuf;
use tokio::net::{TcpListener, TcpStream, UdpSocket};
/// Abstraction over a datagram socket that can be driven asynchronously.
///
/// This file implements it for tokio's [`UdpSocket`] and for
/// `Arc<UdpSocket>`; other implementors only need to provide the three
/// operations below.
pub trait AsyncDgramSock {
    /// Polls an attempt to send `data` to the address `dest`.
    ///
    /// `cx` is the task context handed to the underlying poll operation.
    /// On `Poll::Ready(Ok(n))`, `n` is presumably the number of bytes that
    /// were sent (that is what the tokio-backed implementations report).
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        data: &[u8],
        dest: &SocketAddr,
    ) -> Poll<io::Result<usize>>;

    /// Returns a boxed, `Send` future that borrows the socket and resolves
    /// to an `io::Result<()>`.
    ///
    /// The name suggests the future completes once the socket may be read
    /// from; the exact readiness semantics are up to the implementor.
    fn readable(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>>;

    /// Tries to receive a datagram into `buf` without waiting.
    ///
    /// On success returns a `(usize, SocketAddr)` pair — presumably the
    /// number of bytes received and the address of the sender.
    fn try_recv_buf_from(&self, buf: &mut ReadBuf<'_>) -> io::Result<(usize, SocketAddr)>;
}
/// Implements [`AsyncDgramSock`] for tokio's [`UdpSocket`] by forwarding each
/// operation to the inherent method of the same name.
impl AsyncDgramSock for UdpSocket {
    /// Forwards to `UdpSocket::poll_send_to`, passing the destination by
    /// value as that method expects.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        data: &[u8],
        dest: &SocketAddr,
    ) -> Poll<io::Result<usize>> {
        let target = *dest;
        // Fully-qualified path: this names the socket's inherent method, not
        // the trait method being defined here.
        UdpSocket::poll_send_to(self, cx, data, target)
    }

    /// Boxes and pins the future returned by `UdpSocket::readable`.
    fn readable(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
        let ready = UdpSocket::readable(self);
        Box::pin(ready)
    }

    /// Forwards to `UdpSocket::try_recv_buf_from`.
    fn try_recv_buf_from(&self, buf: &mut ReadBuf<'_>) -> io::Result<(usize, SocketAddr)> {
        let received = UdpSocket::try_recv_buf_from(self, buf);
        received
    }
}
/// Implements [`AsyncDgramSock`] for a shared (`Arc`-wrapped) tokio
/// [`UdpSocket`] by dereferencing to the inner socket and forwarding each
/// operation to its inherent method.
impl AsyncDgramSock for Arc<UdpSocket> {
    /// Forwards to `UdpSocket::poll_send_to` on the wrapped socket.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        data: &[u8],
        dest: &SocketAddr,
    ) -> Poll<io::Result<usize>> {
        // Deref-coerce `&Arc<UdpSocket>` to `&UdpSocket` explicitly so the
        // call below unambiguously targets the socket's inherent method.
        let sock: &UdpSocket = self;
        UdpSocket::poll_send_to(sock, cx, data, *dest)
    }

    /// Boxes and pins the future returned by `UdpSocket::readable` on the
    /// wrapped socket.
    fn readable(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
        let sock: &UdpSocket = self;
        Box::pin(UdpSocket::readable(sock))
    }

    /// Forwards to `UdpSocket::try_recv_buf_from` on the wrapped socket.
    fn try_recv_buf_from(&self, buf: &mut ReadBuf<'_>) -> io::Result<(usize, SocketAddr)> {
        let sock: &UdpSocket = self;
        UdpSocket::try_recv_buf_from(sock, buf)
    }
}
/// Abstraction over a listener whose accept operation can be polled.
///
/// A successful poll does not hand back the stream directly: it yields a
/// future ([`AsyncAccept::Future`]) that resolves to the stream, together
/// with a socket address (presumably that of the connecting peer).
pub trait AsyncAccept {
    /// Error produced when the returned future fails to yield a stream.
    type Error;

    /// Stream type produced once the returned future resolves successfully.
    type StreamType;

    /// Future returned by a successful [`poll_accept`](Self::poll_accept);
    /// it resolves to `Result<Self::StreamType, Self::Error>`.
    type Future: Future<Output = Result<Self::StreamType, Self::Error>>;

    /// Polls for an incoming connection.
    ///
    /// `cx` is the task context handed to the underlying poll operation.
    /// On `Poll::Ready(Ok(..))` the pair holds the stream-producing future
    /// and the associated socket address; I/O failures of the accept itself
    /// are reported as `Poll::Ready(Err(..))`.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(Self::Future, SocketAddr)>>;
}
/// Implements [`AsyncAccept`] for tokio's [`TcpListener`].
///
/// The accepted [`TcpStream`] needs no further setup, so it is wrapped in an
/// already-completed future.
impl AsyncAccept for TcpListener {
    type Error = io::Error;
    type StreamType = TcpStream;
    type Future = std::future::Ready<Result<Self::StreamType, io::Error>>;

    /// Forwards to `TcpListener::poll_accept` and, on success, wraps the
    /// accepted stream in `std::future::ready(Ok(..))` next to the address
    /// that was returned with it. Errors and `Pending` pass through as-is.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(Self::Future, SocketAddr)>> {
        // Fully-qualified path: this names the listener's inherent method,
        // not the trait method being defined here.
        match TcpListener::poll_accept(self, cx) {
            Poll::Ready(Ok((stream, addr))) => {
                let accepted = std::future::ready(Ok(stream));
                Poll::Ready(Ok((accepted, addr)))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}