use std::{
future::Future,
io::{self, IoSlice},
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpListener, TcpStream},
};
/// A source of incoming connections that can be polled for the next accepted stream.
pub trait Listener {
    /// Stream type produced for every accepted connection.
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Polls once for the next accepted connection.
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>>;

    /// Returns a future that resolves to the next accepted connection.
    ///
    /// The returned future mutably borrows the listener for as long as it is alive.
    fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
    where
        Self: Sized,
    {
        let listener = self;
        ListenerAcceptFut { listener }
    }
}
impl Listener for TcpListener {
type Io = TcpStream;
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
Self::poll_accept(self, cx).map_ok(|(stream, _)| stream)
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
    type Io = tokio::net::UnixStream;

    /// Accepts a Unix-domain connection and drops the peer's address.
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
        // Path syntax selects tokio's inherent `UnixListener::poll_accept` (inherent items win
        // in path lookup); method-call syntax could resolve to this trait method and recurse.
        let polled = tokio::net::UnixListener::poll_accept(self, cx);
        polled.map(|accepted| accepted.map(|(stream, _peer_addr)| stream))
    }
}
/// Future returned by [`Listener::accept`]; resolves to the next accepted connection.
///
/// It holds an exclusive borrow of the listener for as long as it exists. Like every
/// future it is lazy, so it is marked `#[must_use]`: dropping it without awaiting or
/// polling it accepts nothing.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ListenerAcceptFut<'a, L> {
    /// Listener whose `poll_accept` is driven when this future is polled.
    listener: &'a mut L,
}
impl<L: Listener> Future for ListenerAcceptFut<'_, L> {
    type Output = io::Result<L::Io>;

    /// Forwards each poll directly to the borrowed listener's `poll_accept`;
    /// the future itself keeps no state between polls.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `&mut L`, which is always `Unpin`, so unpinning is sound.
        let this = Pin::get_mut(self);
        this.listener.poll_accept(cx)
    }
}
/// One of two listener types, dispatched at runtime.
///
/// Accepting from it yields an [`EitherStream`] tagged with the same variant.
#[derive(Debug)]
pub enum EitherListener<T, U> {
    /// The `Tcp` alternative, wrapping a listener of type `T`.
    Tcp(T),
    /// The `Unix` alternative, wrapping a listener of type `U`.
    Unix(U),
}
impl<T, U> Listener for EitherListener<T, U>
where
T: Listener,
U: Listener,
{
type Io = EitherStream<T::Io, U::Io>;
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
match self {
Self::Tcp(listener) => listener.poll_accept(cx).map_ok(Self::Io::Tcp),
Self::Unix(listener) => listener.poll_accept(cx).map_ok(Self::Io::Unix),
}
}
}
/// A connection from one of two stream types.
///
/// Reads and writes are forwarded to whichever variant is held.
#[derive(Debug)]
pub enum EitherStream<T, U> {
    /// The `Tcp` alternative, wrapping a stream of type `T`.
    Tcp(T),
    /// The `Unix` alternative, wrapping a stream of type `U`.
    Unix(U),
}
impl<T, U> AsyncRead for EitherStream<T, U>
where
T: AsyncRead + Unpin,
U: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl<T, U> AsyncWrite for EitherStream<T, U>
where
T: AsyncWrite + Unpin,
U: AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
Self::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
Self::Tcp(stream) => stream.is_write_vectored(),
Self::Unix(stream) => stream.is_write_vectored(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}