use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use irontide_utp::UdpTransport;
/// Future returned by [`TransportListener::accept`].
///
/// It is boxed so that the trait stays usable as `Box<dyn TransportListener>`. The `'a`
/// lifetime lets the future borrow the listener it was created from. It resolves to the
/// accepted stream together with the peer's socket address.
type AcceptFuture<'a> = Pin<
    Box<
        dyn Future<Output = io::Result<(BoxedStream, SocketAddr)>>
            + Send
            + 'a,
    >,
>;
/// Type-erased TCP bind hook.
///
/// Given a local address, it asynchronously produces a boxed [`TransportListener`]. The
/// `Send + Sync` bounds keep `NetworkFactory`, which stores this hook, `Send + Sync` too.
type BindFn = Box<
    dyn Fn(SocketAddr) -> Pin<
            Box<dyn Future<Output = io::Result<Box<dyn TransportListener>>> + Send>,
        > + Send
        + Sync,
>;
/// Type-erased TCP connect hook.
///
/// Given a remote address, it asynchronously opens a connection and returns it as a
/// [`BoxedStream`]. The `Send + Sync` bounds keep `NetworkFactory`, which stores this hook,
/// `Send + Sync` too.
type ConnectFn = Box<
    dyn Fn(SocketAddr) -> Pin<
            Box<dyn Future<Output = io::Result<BoxedStream>> + Send>,
        > + Send
        + Sync,
>;
/// Type-erased UDP bind hook.
///
/// Given a local address, it asynchronously produces a boxed `UdpTransport`. It is optional
/// on a factory and is installed with `NetworkFactory::with_bind_udp`.
type BindUdpFn = Box<
    dyn Fn(SocketAddr) -> Pin<
            Box<dyn Future<Output = io::Result<Box<dyn UdpTransport>>> + Send>,
        > + Send
        + Sync,
>;
/// Type-erased, heap-pinned async byte stream.
///
/// It can hold any `AsyncRead + AsyncWrite + Unpin + Send + 'static` value (see
/// `BoxedStream::new`), so callers handle every concrete stream type through this one type.
pub struct BoxedStream {
/// The concrete stream, erased behind the private `StreamRw` trait object.
inner: Pin<Box<dyn StreamRw + Send>>,
}
impl std::fmt::Debug for BoxedStream {
    /// Renders as `BoxedStream { .. }`.
    ///
    /// The wrapped stream is type-erased and is not required to implement `Debug`, so no
    /// fields are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BoxedStream");
        builder.finish_non_exhaustive()
    }
}
/// Private bundle of the async read/write traits plus `Unpin`.
///
/// It exists only so that a single trait-object type can be stored inside [`BoxedStream`].
trait StreamRw: AsyncRead + AsyncWrite + Unpin {}

// Blanket impl: every type satisfying the bounds is automatically a `StreamRw`.
impl<T> StreamRw for T where T: AsyncRead + AsyncWrite + Unpin {}
impl BoxedStream {
pub fn new<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(stream: S) -> Self {
Self {
inner: Box::pin(stream),
}
}
}
impl AsyncRead for BoxedStream {
    /// Delegates the read to the wrapped stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `BoxedStream` is `Unpin`, so the pin can be dropped to reach the inner field.
        let this = self.get_mut();
        this.inner.as_mut().poll_read(cx, buf)
    }
}
/// Forwards every write-side operation to the wrapped stream.
///
/// The vectored-write methods are forwarded explicitly. Tokio's default
/// `poll_write_vectored` writes only the first non-empty buffer, and the default
/// `is_write_vectored` returns `false`. Relying on those defaults would hide the inner
/// stream's scatter/gather support (e.g. `TcpStream`'s) behind this wrapper.
impl AsyncWrite for BoxedStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.inner.as_mut().poll_write(cx, buf)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.inner.as_mut().poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.inner.as_mut().poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.inner.as_mut().poll_shutdown(cx)
    }
}
// Explicit for clarity only. The sole field is a `Pin<Box<_>>`, which is always `Unpin`,
// so the auto-trait impl would apply anyway. `Unpin` lets the `poll_*` methods reach
// `self.inner` through `Pin<&mut Self>` without an `unsafe` pin projection.
impl Unpin for BoxedStream {}
/// A bound listener that hands out inbound connections as [`BoxedStream`]s.
///
/// `NetworkFactory::bind_tcp` returns it as `Box<dyn TransportListener>`. Implementors must
/// be `Send + Sync`.
pub trait TransportListener: Send + Sync {
/// Waits for the next inbound connection.
///
/// The returned future resolves to the accepted stream and the peer's socket address, or
/// to whatever I/O error the implementation reports.
fn accept(&mut self) -> AcceptFuture<'_>;
/// Returns the local address this listener is bound to.
///
/// NOTE(review): for a listener bound to port 0 this is expected to report the assigned
/// port. The tests rely on that for the tokio implementation; confirm for other impls.
fn local_addr(&self) -> io::Result<SocketAddr>;
}
/// [`TransportListener`] backed by a real tokio [`TcpListener`].
///
/// Accepted connections get `SO_LINGER` set to zero on a best-effort basis.
#[derive(Debug)]
pub struct TokioListener(pub TcpListener);
impl TransportListener for TokioListener {
fn accept(&mut self) -> AcceptFuture<'_> {
Box::pin(async move {
let (stream, addr) = self.0.accept().await?;
#[allow(deprecated)]
let _ = stream.set_linger(Some(std::time::Duration::ZERO));
Ok((BoxedStream::new(stream), addr))
})
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.0.local_addr()
}
}
/// Pluggable source of network endpoints: TCP listeners and connections, plus optional
/// UDP transports.
///
/// Use [`NetworkFactory::tokio`] for real sockets, or [`NetworkFactory::new`] to supply
/// custom hooks.
pub struct NetworkFactory {
    /// Hook invoked by `bind_tcp`.
    bind_tcp: BindFn,
    /// Hook invoked by `connect_tcp`.
    connect_tcp: ConnectFn,
    /// Hook invoked by `bind_udp`; `None` makes it fail with `ErrorKind::Unsupported`.
    bind_udp: Option<BindUdpFn>,
    /// Caller-supplied flag returned by `is_simulated`; `tokio()` sets it to `false`.
    is_simulated: bool,
}

// Hand-written because the stored hooks are closures without `Debug`. Only the observable
// configuration is printed, matching the non-exhaustive style of `BoxedStream`'s impl.
impl std::fmt::Debug for NetworkFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NetworkFactory")
            .field("is_simulated", &self.is_simulated)
            .field("has_bind_udp", &self.bind_udp.is_some())
            .finish_non_exhaustive()
    }
}
impl NetworkFactory {
    /// Builds a factory from caller-supplied TCP bind and connect hooks.
    ///
    /// No UDP hook is installed; add one with [`NetworkFactory::with_bind_udp`].
    /// `is_simulated` is stored as given and returned by [`NetworkFactory::is_simulated`].
    #[must_use]
    pub fn new(bind_tcp: BindFn, connect_tcp: ConnectFn, is_simulated: bool) -> Self {
        let bind_udp = None;
        Self {
            bind_tcp,
            connect_tcp,
            bind_udp,
            is_simulated,
        }
    }

    /// Installs (or replaces) the UDP bind hook and returns the updated factory.
    #[must_use]
    pub fn with_bind_udp(self, bind_udp: BindUdpFn) -> Self {
        Self {
            bind_udp: Some(bind_udp),
            ..self
        }
    }

    /// Factory backed by tokio's real TCP sockets.
    ///
    /// Both connected and accepted streams get `SO_LINGER` set to zero (best effort). No UDP
    /// hook is installed, and [`NetworkFactory::is_simulated`] returns `false`.
    #[must_use]
    pub fn tokio() -> Self {
        Self::new(
            Box::new(|addr| {
                Box::pin(async move {
                    let listener = TcpListener::bind(addr).await?;
                    Ok(Box::new(TokioListener(listener)) as Box<dyn TransportListener>)
                })
            }),
            Box::new(|addr| {
                Box::pin(async move {
                    let stream = TcpStream::connect(addr).await?;
                    // Abortive close on drop (RST, unsent data discarded); best effort.
                    // `set_linger` is deprecated in recent tokio, hence the allowance.
                    #[allow(deprecated)]
                    let _ = stream.set_linger(Some(std::time::Duration::ZERO));
                    Ok(BoxedStream::new(stream))
                })
            }),
            false,
        )
    }

    /// Binds a TCP listener on `addr` through the installed bind hook.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the bind hook produces.
    pub async fn bind_tcp(&self, addr: SocketAddr) -> io::Result<Box<dyn TransportListener>> {
        let bind = &self.bind_tcp;
        bind(addr).await
    }

    /// Opens a TCP connection to `addr` through the installed connect hook.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the connect hook produces.
    pub async fn connect_tcp(&self, addr: SocketAddr) -> io::Result<BoxedStream> {
        let connect = &self.connect_tcp;
        connect(addr).await
    }

    /// Returns the flag supplied at construction (`false` for [`NetworkFactory::tokio`]).
    #[must_use]
    pub fn is_simulated(&self) -> bool {
        self.is_simulated
    }

    /// Returns `true` once a UDP bind hook has been installed.
    #[must_use]
    pub fn has_bind_udp(&self) -> bool {
        self.bind_udp.is_some()
    }

    /// Binds a UDP transport on `addr` through the installed UDP hook.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Unsupported` when no UDP hook is installed; otherwise propagates
    /// the hook's error.
    pub async fn bind_udp(&self, addr: SocketAddr) -> io::Result<Box<dyn UdpTransport>> {
        let Some(bind) = self.bind_udp.as_ref() else {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "factory has no UDP bind installed",
            ));
        };
        bind(addr).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[test]
    fn tokio_factory_creation() {
        let _factory = NetworkFactory::tokio();
    }

    #[test]
    fn tokio_factory_is_not_simulated() {
        let factory = NetworkFactory::tokio();
        assert!(!factory.is_simulated());
    }

    #[tokio::test]
    async fn tokio_bind_and_accept() {
        let factory = NetworkFactory::tokio();
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = factory.bind_tcp(addr).await.unwrap();
        let local = listener.local_addr().unwrap();
        assert_ne!(local.port(), 0);
    }

    #[tokio::test]
    async fn tokio_connect_to_listener() {
        let factory = NetworkFactory::tokio();
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = factory.bind_tcp(addr).await.unwrap();
        let local = listener.local_addr().unwrap();
        let accept_handle = tokio::spawn(async move { listener.accept().await.unwrap() });
        let mut client = factory.connect_tcp(local).await.unwrap();
        client.write_all(b"hello").await.unwrap();
        let (mut server_stream, peer_addr) = accept_handle.await.unwrap();
        assert_eq!(
            peer_addr.ip(),
            "127.0.0.1".parse::<std::net::IpAddr>().unwrap()
        );
        let mut buf = [0u8; 5];
        server_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
    }

    #[test]
    fn custom_factory_is_simulated() {
        let factory = NetworkFactory::new(
            Box::new(|_addr| {
                Box::pin(async move { Err(io::Error::new(io::ErrorKind::Unsupported, "stub")) })
            }),
            Box::new(|_addr| {
                Box::pin(async move { Err(io::Error::new(io::ErrorKind::Unsupported, "stub")) })
            }),
            true,
        );
        assert!(factory.is_simulated());
    }

    // Covers the fallback branch of `bind_udp` when no UDP hook was installed.
    #[tokio::test]
    async fn bind_udp_without_hook_is_unsupported() {
        let factory = NetworkFactory::tokio();
        assert!(!factory.has_bind_udp());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        // `Box<dyn UdpTransport>` is not known to be `Debug`, so avoid `unwrap_err`.
        match factory.bind_udp(addr).await {
            Err(e) => assert_eq!(e.kind(), io::ErrorKind::Unsupported),
            Ok(_) => panic!("expected Unsupported when no UDP hook is installed"),
        }
    }

    // Verifies that `with_bind_udp` installs the hook and that `bind_udp` delegates to it.
    #[tokio::test]
    async fn with_bind_udp_installs_hook() {
        let factory = NetworkFactory::tokio().with_bind_udp(Box::new(|_addr| {
            Box::pin(async move { Err(io::Error::new(io::ErrorKind::AddrInUse, "stub")) })
        }));
        assert!(factory.has_bind_udp());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        match factory.bind_udp(addr).await {
            Err(e) => assert_eq!(e.kind(), io::ErrorKind::AddrInUse),
            Ok(_) => panic!("expected the stub hook's error"),
        }
    }
}