use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpStream, UnixStream};
use crate::sdk::xinet::SocketAddr;
/// A connected byte stream backed by either a Unix domain socket or a TCP
/// socket.
///
/// Values are created with `Stream::connect` and can be divided into
/// independently owned read and write halves with `Stream::into_split`.
pub enum Stream {
/// Stream over a Unix domain socket.
Unix(UnixStream),
/// Stream over a TCP connection.
Tcp(TcpStream),
}
impl Stream {
    /// Opens a connection to `addr`, choosing the transport from the address
    /// variant.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` produced by the underlying
    /// `UnixStream::connect` / `TcpStream::connect` call.
    pub async fn connect(addr: &SocketAddr) -> io::Result<Self> {
        match addr {
            SocketAddr::Unix(socket_path) => {
                UnixStream::connect(socket_path).await.map(Self::Unix)
            }
            SocketAddr::Tcp(tcp_addr) => TcpStream::connect(tcp_addr).await.map(Self::Tcp),
        }
    }

    /// Consumes the stream and returns its owned read and write halves, each
    /// wrapped in the enum variant matching the original transport.
    pub fn into_split(self) -> (StreamReadHalf, StreamWriteHalf) {
        match self {
            Self::Unix(unix) => {
                let (read_half, write_half) = unix.into_split();
                (
                    StreamReadHalf::Unix(read_half),
                    StreamWriteHalf::Unix(write_half),
                )
            }
            Self::Tcp(tcp) => {
                let (read_half, write_half) = tcp.into_split();
                (
                    StreamReadHalf::Tcp(read_half),
                    StreamWriteHalf::Tcp(write_half),
                )
            }
        }
    }
}
/// Owned read half of a `Stream`, as returned by `Stream::into_split`.
pub enum StreamReadHalf {
/// Read half of a Unix domain socket stream.
Unix(tokio::net::unix::OwnedReadHalf),
/// Read half of a TCP stream.
Tcp(tokio::net::tcp::OwnedReadHalf),
}
/// Owned write half of a `Stream`, as returned by `Stream::into_split`.
pub enum StreamWriteHalf {
/// Write half of a Unix domain socket stream.
Unix(tokio::net::unix::OwnedWriteHalf),
/// Write half of a TCP stream.
Tcp(tokio::net::tcp::OwnedWriteHalf),
}
impl AsyncRead for StreamReadHalf {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<io::Result<()>> {
match self.get_mut() {
StreamReadHalf::Unix(s) => std::pin::Pin::new(s).poll_read(cx, buf),
StreamReadHalf::Tcp(s) => std::pin::Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for StreamWriteHalf {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<io::Result<usize>> {
match self.get_mut() {
StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_write(cx, buf),
StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
match self.get_mut() {
StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_flush(cx),
StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
match self.get_mut() {
StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_shutdown(cx),
StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_shutdown(cx),
}
}
}
/// Relays bytes between a client and a backend in both directions at once.
///
/// Data read from `client_read` is written to `backend_write`, and data read
/// from `backend_read` is written to `client_write`. Both directions run
/// concurrently and the function completes only after both have finished.
/// A direction finishes when its reader reports end-of-stream, or when a read
/// or write fails.
///
/// # Returns
///
/// `Ok((client_to_backend_bytes, backend_to_client_bytes))`. Each count is
/// the number of bytes fully written to the destination, including the bytes
/// forwarded before an I/O error ended that direction.
///
/// # Errors
///
/// I/O errors end the affected direction but are deliberately not
/// propagated; this function currently always returns `Ok`.
pub async fn forward_bidirectional<R1, W1, R2, W2>(
    mut client_read: R1,
    mut client_write: W1,
    mut backend_read: R2,
    mut backend_write: W2,
) -> io::Result<(u64, u64)>
where
    R1: AsyncRead + Unpin,
    W1: AsyncWrite + Unpin,
    R2: AsyncRead + Unpin,
    W2: AsyncWrite + Unpin,
{
    // Borrow rather than move the halves so that all four stay alive until
    // both directions are done.
    let client_to_backend = forward_one_direction(&mut client_read, &mut backend_write);
    let backend_to_client = forward_one_direction(&mut backend_read, &mut client_write);
    let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
    Ok((c2b, b2c))
}

/// Copies bytes from `reader` to `writer` until end-of-stream or the first
/// I/O error, and returns how many bytes were fully written.
async fn forward_one_direction<R, W>(reader: &mut R, writer: &mut W) -> u64
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    const BUF_SIZE: usize = 8192;

    let mut buf = [0u8; BUF_SIZE];
    let mut total = 0u64;
    loop {
        let n = match reader.read(&mut buf).await {
            // End-of-stream and read failure both mean no more data will
            // arrive from this side.
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };
        if writer.write_all(&buf[..n]).await.is_err() {
            // The destination is already broken, so a shutdown attempt would
            // be pointless; report what was delivered so far.
            return total;
        }
        total += n as u64;
    }
    // Half-close the destination so the peer sees end-of-stream even when the
    // source ended with an error; otherwise the opposite direction could wait
    // forever. Best effort: a shutdown failure is ignored.
    let _ = writer.shutdown().await;
    total
}
/// Reports whether a connection to `addr` can currently be established.
///
/// A fresh connection is attempted and dropped immediately; only the success
/// or failure of the attempt is reported.
pub async fn is_socket_available(addr: &SocketAddr) -> bool {
    let probe = match addr {
        SocketAddr::Unix(socket_path) => {
            tokio::net::UnixStream::connect(socket_path).await.map(drop)
        }
        SocketAddr::Tcp(tcp_addr) => tokio::net::TcpStream::connect(tcp_addr).await.map(drop),
    };
    probe.is_ok()
}
/// Repeatedly probes `addr` until a connection succeeds or `timeout_secs`
/// seconds have elapsed.
///
/// Probes are made with `is_socket_available`, with a 100 ms pause after each
/// failed attempt. The time limit covers the whole retry loop, including a
/// connection attempt that is still in flight, so the call does not run
/// noticeably past the requested timeout even when a single attempt is slow.
///
/// Returns `true` if a probe succeeded within the limit, `false` otherwise.
pub async fn wait_for_socket(addr: &SocketAddr, timeout_secs: u64) -> bool {
    let limit = std::time::Duration::from_secs(timeout_secs);
    let retry_interval = std::time::Duration::from_millis(100);

    // Completes only once a probe succeeds; the surrounding timeout is what
    // bounds the total wait.
    let probe_until_ready = async {
        while !is_socket_available(addr).await {
            tokio::time::sleep(retry_interval).await;
        }
    };

    tokio::time::timeout(limit, probe_until_ready).await.is_ok()
}