use async_trait::async_trait;
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
/// Abstraction over TCP networking: binding a listener and opening an
/// outbound connection.
///
/// Implementors must be `Clone`. The trait is declared with
/// `#[async_trait(?Send)]`, so the futures returned by its async methods are
/// not required to be `Send`.
#[async_trait(?Send)]
pub trait NetworkProvider: Clone {
/// Stream type returned by [`connect`](Self::connect); it is also the stream
/// type that [`Self::TcpListener`] must hand out.
type TcpStream: AsyncRead + AsyncWrite + Unpin + 'static;
/// Listener type returned by [`bind`](Self::bind). Its associated
/// `TcpStream` is constrained to equal this provider's `TcpStream`.
type TcpListener: TcpListenerTrait<TcpStream = Self::TcpStream> + 'static;
/// Binds a listener to `addr`.
///
/// How the `addr` string is interpreted is left to the implementation.
///
/// # Errors
///
/// Returns an [`io::Error`] when the implementation cannot create the listener.
async fn bind(&self, addr: &str) -> io::Result<Self::TcpListener>;
/// Opens a stream to `addr`.
///
/// How the `addr` string is interpreted is left to the implementation.
///
/// # Errors
///
/// Returns an [`io::Error`] when the implementation cannot establish the stream.
async fn connect(&self, addr: &str) -> io::Result<Self::TcpStream>;
}
/// Listener half of the networking abstraction: hands out incoming streams
/// and reports the address it is bound to.
///
/// Declared with `#[async_trait(?Send)]`, so the future returned by
/// [`accept`](Self::accept) is not required to be `Send`.
#[async_trait(?Send)]
pub trait TcpListenerTrait {
/// Stream type produced by [`accept`](Self::accept).
type TcpStream: AsyncRead + AsyncWrite + Unpin + 'static;
/// Waits for an incoming stream.
///
/// On success returns the stream together with a `String`; the Tokio
/// implementation in this file fills it with the peer socket address
/// rendered via `to_string()`.
///
/// # Errors
///
/// Returns an [`io::Error`] when the implementation fails to accept.
async fn accept(&self) -> io::Result<(Self::TcpStream, String)>;
/// Returns the listener's local address as a `String`.
///
/// # Errors
///
/// Returns an [`io::Error`] when the address cannot be obtained.
fn local_addr(&self) -> io::Result<String>;
}
/// [`NetworkProvider`] implementation backed by Tokio's TCP types.
///
/// This is a stateless unit struct, so every value is interchangeable:
/// [`TokioNetworkProvider::new`], [`Default::default`] and `clone()` all
/// produce the same thing.
#[derive(Debug, Clone, Default)]
pub struct TokioNetworkProvider;

impl TokioNetworkProvider {
    /// Creates a new provider.
    ///
    /// Declared `const` so the provider can also initialise `const` and
    /// `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
/// Tokio-backed [`NetworkProvider`]: both operations forward `addr` straight
/// to the corresponding `tokio::net` constructor.
#[async_trait(?Send)]
impl NetworkProvider for TokioNetworkProvider {
    type TcpStream = tokio::net::TcpStream;
    type TcpListener = TokioTcpListener;

    /// Binds a `tokio::net::TcpListener` to `addr` and wraps it in a
    /// [`TokioTcpListener`].
    ///
    /// # Errors
    ///
    /// Propagates the [`io::Error`] reported by Tokio's `bind`.
    async fn bind(&self, addr: &str) -> io::Result<Self::TcpListener> {
        tokio::net::TcpListener::bind(addr)
            .await
            .map(|inner| TokioTcpListener { inner })
    }

    /// Connects a `tokio::net::TcpStream` to `addr`.
    ///
    /// # Errors
    ///
    /// Propagates the [`io::Error`] reported by Tokio's `connect`.
    async fn connect(&self, addr: &str) -> io::Result<Self::TcpStream> {
        let stream = tokio::net::TcpStream::connect(addr).await?;
        Ok(stream)
    }
}
/// Wrapper around `tokio::net::TcpListener`; the listener type that
/// [`TokioNetworkProvider`] returns from `bind`.
#[derive(Debug)]
pub struct TokioTcpListener {
// The wrapped Tokio listener; private, so values are only built inside this module.
inner: tokio::net::TcpListener,
}
/// [`TcpListenerTrait`] for the Tokio listener wrapper: delegates to the inner
/// `tokio::net::TcpListener` and renders socket addresses as strings.
#[async_trait(?Send)]
impl TcpListenerTrait for TokioTcpListener {
    type TcpStream = tokio::net::TcpStream;

    /// Accepts the next incoming connection.
    ///
    /// Returns the stream together with the peer address formatted via
    /// `to_string()`.
    ///
    /// # Errors
    ///
    /// Propagates the [`io::Error`] reported by the inner listener's `accept`.
    async fn accept(&self) -> io::Result<(Self::TcpStream, String)> {
        self.inner
            .accept()
            .await
            .map(|(stream, peer)| (stream, peer.to_string()))
    }

    /// Returns the address the inner listener is bound to, formatted via
    /// `to_string()`.
    ///
    /// # Errors
    ///
    /// Propagates the [`io::Error`] reported by the inner listener's `local_addr`.
    fn local_addr(&self) -> io::Result<String> {
        self.inner.local_addr().map(|bound| bound.to_string())
    }
}