use std::net::SocketAddr;
#[cfg(not(target_os = "windows"))]
use std::os::fd::AsFd;
#[cfg(target_os = "windows")]
use std::os::windows::io::AsSocket;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use crate::addr::{each_addr, ToSocketAddrs};
use crate::udp::{UdpListener, UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
/// Module-private shorthand: a `Result` whose error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
/// Splits a bidirectional stream into separate read and write halves.
///
/// Two flavours are offered:
/// * [`split`](StreamSplit::split) borrows the stream mutably and yields halves
///   whose lifetime is tied to that borrow;
/// * [`into_split`](StreamSplit::into_split) consumes the stream and yields
///   `Send + 'static` halves that are not tied to any borrow.
pub trait StreamSplit {
    /// Owned read half returned by [`into_split`](StreamSplit::into_split).
    type ReaderOwned: AsyncReadExt + Send + Unpin + 'static;
    /// Owned write half returned by [`into_split`](StreamSplit::into_split).
    type WriterOwned: AsyncWriteExt + Send + Unpin + 'static;
    /// Borrowed read half returned by [`split`](StreamSplit::split); lives at most `'a`.
    type ReaderRef<'a>: AsyncReadExt + Send + Unpin
    where
        Self: 'a;
    /// Borrowed write half returned by [`split`](StreamSplit::split); lives at most `'a`.
    type WriterRef<'a>: AsyncWriteExt + Send + Unpin
    where
        Self: 'a;

    /// Splits by mutable borrow; `self` is unusable until both halves are dropped.
    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);

    /// Splits by value, consuming `self`.
    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned);
}
/// Marker trait for a bidirectional stream the rest of the crate can use generically:
/// readable, writable, splittable via [`StreamSplit`], and `Send + Unpin + 'static`.
///
/// The bounds are expressed as a `where Self:` clause, which Rust treats exactly
/// like a supertrait list.
pub trait NetworkStream
where
    Self: StreamSplit + AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
{
}
/// Generates a public newtype `$struct_name` around `$inner_ty` that forwards
/// tokio's `AsyncRead` and `AsyncWrite` traits to the wrapped stream.
///
/// Every poll method re-pins `&mut self.0` with `Pin::new`, so `$inner_ty` must
/// implement `AsyncRead + AsyncWrite + Unpin`.
macro_rules! gen_stream_impl {
    ($struct_name:ident, $inner_ty:ty) => {
        pub struct $struct_name($inner_ty);

        impl $struct_name {
            /// Wraps an already-established stream.
            pub fn new(stream: $inner_ty) -> Self {
                Self(stream)
            }
        }

        impl AsyncRead for $struct_name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $struct_name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
            }

            // Forward vectored writes explicitly. The trait's default only writes
            // the first non-empty buffer and reports `is_write_vectored() == false`,
            // which would hide the inner stream's native scatter/gather support.
            fn poll_write_vectored(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
            }

            fn is_write_vectored(&self) -> bool {
                AsyncWrite::is_write_vectored(&self.0)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_flush(cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
            }
        }
    };
}
// Concrete stream wrappers: `TcpStreamImpl` wraps tokio's `TcpStream`, and
// `UdpStreamImpl` wraps the crate's `UdpStream`. Both forward AsyncRead/AsyncWrite
// to the inner stream via the macro above.
gen_stream_impl!(TcpStreamImpl, TcpStream);
gen_stream_impl!(UdpStreamImpl, UdpStream);
/// Splitting delegates directly to tokio's `TcpStream::split` / `TcpStream::into_split`.
impl StreamSplit for TcpStreamImpl {
    type ReaderRef<'a>
        = ReadHalf<'a>
    where
        Self: 'a;
    type WriterRef<'a>
        = WriteHalf<'a>
    where
        Self: 'a;
    type ReaderOwned = tokio::net::tcp::OwnedReadHalf;
    type WriterOwned = tokio::net::tcp::OwnedWriteHalf;

    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
        let Self(inner) = self;
        inner.split()
    }

    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
        let Self(inner) = self;
        inner.into_split()
    }
}
/// Splitting delegates directly to the crate's `UdpStream::split` / `UdpStream::into_split`.
impl StreamSplit for UdpStreamImpl {
    // `UdpStreamReadHalf` carries no lifetime, so `'a` is unused on the read side.
    type ReaderRef<'a>
        = UdpStreamReadHalf
    where
        Self: 'a;
    type WriterRef<'a>
        = UdpStreamWriteHalf<'a>
    where
        Self: 'a;
    type ReaderOwned = crate::udp::UdpStreamOwnedReadHalf;
    type WriterOwned = crate::udp::UdpStreamOwnedWriteHalf;

    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
        let Self(inner) = self;
        inner.split()
    }

    fn into_split(self) -> (Self::ReaderOwned, Self::WriterOwned) {
        let Self(inner) = self;
        inner.into_split()
    }
}
// Both wrappers already satisfy every bound `NetworkStream` requires: `StreamSplit`
// is implemented above, and `AsyncReadExt`/`AsyncWriteExt` come from tokio's
// blanket impls over `AsyncRead`/`AsyncWrite`. The marker impls are therefore empty.
impl NetworkStream for TcpStreamImpl {}
impl NetworkStream for UdpStreamImpl {}
/// Factory for outbound streams of a particular transport.
pub trait StreamProvider {
    /// Stream type produced by [`from_addr`](StreamProvider::from_addr).
    type Item: NetworkStream;

    /// Opens a stream to `addr`. The returned future is `Send`.
    fn from_addr<A>(addr: A) -> impl std::future::Future<Output = Result<Self::Item>> + Send
    where
        A: ToSocketAddrs + Send;
}
pub struct TcpStreamProvider;
impl StreamProvider for TcpStreamProvider {
type Item = TcpStreamImpl;
async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
}
}
pub struct UdpStreamProvider;
impl StreamProvider for UdpStreamProvider {
type Item = UdpStreamImpl;
async fn from_addr<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
}
}
/// Factory that binds a listener of a particular transport to a local address.
pub trait ListenerProvider {
    /// Listener type produced by [`bind`](ListenerProvider::bind).
    type Listener: StreamAccept + 'static;

    /// Binds to `addr`. The returned future is `Send`.
    fn bind<A>(addr: A) -> impl std::future::Future<Output = Result<Self::Listener>> + Send
    where
        A: ToSocketAddrs + Send;
}
/// Source of inbound streams: each `accept` call yields one new stream together
/// with the remote address reported for it.
pub trait StreamAccept {
    /// Stream type handed out per accepted connection.
    type Item: NetworkStream;

    /// Waits for the next inbound stream. The returned future is `Send`.
    fn accept(
        &self,
    ) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
}
/// [`ListenerProvider`] that binds tokio TCP listeners.
pub struct TcpListenerProvider;
/// Wrapper around tokio's `TcpListener` whose `accept` yields [`TcpStreamImpl`] values.
pub struct TcpListenerImpl(TcpListener);
impl StreamAccept for TcpListenerImpl {
type Item = TcpStreamImpl;
async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
let (stream, addr) = self.0.accept().await?;
Ok((TcpStreamImpl::new(stream), addr))
}
}
impl ListenerProvider for TcpListenerProvider {
    type Listener = TcpListenerImpl;

    /// Binds by handing `TcpListener::bind` to `each_addr`, then wraps the listener.
    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
        let listener = each_addr(addr, TcpListener::bind).await?;
        Ok(TcpListenerImpl(listener))
    }
}
/// [`ListenerProvider`] that binds the crate's UDP-based `UdpListener`.
pub struct UdpListenerProvider;
/// Wrapper around the crate's `UdpListener` whose `accept` yields [`UdpStreamImpl`] values.
pub struct UdpListenerImpl(UdpListener);
impl StreamAccept for UdpListenerImpl {
type Item = UdpStreamImpl;
async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
let (stream, addr) = self.0.accept().await?;
Ok((UdpStreamImpl::new(stream), addr))
}
}
impl ListenerProvider for UdpListenerProvider {
    type Listener = UdpListenerImpl;

    /// Binds with `UdpListener::bind` and wraps the result.
    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
        let listener = UdpListener::bind(addr).await?;
        Ok(UdpListenerImpl(listener))
    }
}
/// Keepalive idle time passed to socket2's `TcpKeepalive::with_time` by `set_tcp_keep_alive`.
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(20);
/// Keepalive probe interval passed to socket2's `TcpKeepalive::with_interval` by `set_tcp_keep_alive`.
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
/// Enables TCP keepalive on `stream` (any fd-backed socket), using
/// `TCP_KEEPALIVE_TIME` as the idle time and `TCP_KEEPALIVE_INTERVAL` as the
/// probe interval.
///
/// # Errors
/// Returns the OS error reported by socket2 when the option cannot be set.
#[cfg(not(target_os = "windows"))]
pub fn set_tcp_keep_alive<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
    let keepalive = socket2::TcpKeepalive::new()
        .with_time(TCP_KEEPALIVE_TIME)
        .with_interval(TCP_KEEPALIVE_INTERVAL);
    socket2::SockRef::from(stream).set_tcp_keepalive(&keepalive)
}
/// Enables TCP keepalive on `stream` (any Windows socket), using
/// `TCP_KEEPALIVE_TIME` as the idle time and `TCP_KEEPALIVE_INTERVAL` as the
/// probe interval.
///
/// # Errors
/// Returns the OS error reported by socket2 when the option cannot be set.
#[cfg(target_os = "windows")]
pub fn set_tcp_keep_alive<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
    let keepalive = socket2::TcpKeepalive::new()
        .with_time(TCP_KEEPALIVE_TIME)
        .with_interval(TCP_KEEPALIVE_INTERVAL);
    socket2::SockRef::from(stream).set_tcp_keepalive(&keepalive)
}
/// Turns on `TCP_NODELAY` (disables Nagle's algorithm) for an fd-backed socket.
///
/// # Errors
/// Returns the OS error reported by socket2 when the option cannot be set.
#[cfg(not(target_os = "windows"))]
pub fn set_tcp_nodelay<S: AsFd>(stream: &S) -> std::result::Result<(), std::io::Error> {
    socket2::SockRef::from(stream).set_tcp_nodelay(true)
}
/// Turns on `TCP_NODELAY` (disables Nagle's algorithm) for a Windows socket.
///
/// # Errors
/// Returns the OS error reported by socket2 when the option cannot be set.
#[cfg(target_os = "windows")]
pub fn set_tcp_nodelay<S: AsSocket>(stream: &S) -> std::result::Result<(), std::io::Error> {
    socket2::SockRef::from(stream).set_tcp_nodelay(true)
}
/// Resolves `addr` and returns the first socket address it yields.
///
/// # Errors
/// Propagates any resolution error. Returns `ErrorKind::InvalidInput` when
/// resolution succeeds but produces no addresses.
pub async fn got_one_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
    let first = addr.to_socket_addrs().await?.next();
    match first {
        Some(found) => Ok(found),
        None => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "could not resolve to any addresses",
        )),
    }
}