zinit 0.3.9

Process supervisor with dependency management
Documentation
//! Connection handling for xinet
//!
//! Provides bidirectional data forwarding between two sockets.

use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpStream, UnixStream};

use crate::sdk::xinet::SocketAddr;

/// A connected stream over either a Unix domain socket or TCP.
///
/// Created with [`Stream::connect`]; use [`Stream::into_split`] to obtain
/// independently owned read and write halves.
#[derive(Debug)]
pub enum Stream {
    /// Connection over a Unix domain socket.
    Unix(UnixStream),
    /// Connection over TCP.
    Tcp(TcpStream),
}

impl Stream {
    /// Open a connection to `addr`, using a Unix domain socket or TCP
    /// depending on which kind of address it is.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by the underlying `connect` call.
    pub async fn connect(addr: &SocketAddr) -> io::Result<Self> {
        let connected = match addr {
            SocketAddr::Unix(path) => Stream::Unix(UnixStream::connect(path).await?),
            SocketAddr::Tcp(target) => Stream::Tcp(TcpStream::connect(target).await?),
        };
        Ok(connected)
    }

    /// Consume the stream and return its owned read and write halves,
    /// wrapped in the transport-agnostic half types.
    pub fn into_split(self) -> (StreamReadHalf, StreamWriteHalf) {
        match self {
            Stream::Tcp(tcp) => {
                let halves = tcp.into_split();
                (StreamReadHalf::Tcp(halves.0), StreamWriteHalf::Tcp(halves.1))
            }
            Stream::Unix(unix) => {
                let halves = unix.into_split();
                (StreamReadHalf::Unix(halves.0), StreamWriteHalf::Unix(halves.1))
            }
        }
    }
}

/// Owned read half of a [`Stream`], as produced by [`Stream::into_split`].
#[derive(Debug)]
pub enum StreamReadHalf {
    /// Read half of a Unix domain socket connection.
    Unix(tokio::net::unix::OwnedReadHalf),
    /// Read half of a TCP connection.
    Tcp(tokio::net::tcp::OwnedReadHalf),
}

/// Owned write half of a [`Stream`], as produced by [`Stream::into_split`].
#[derive(Debug)]
pub enum StreamWriteHalf {
    /// Write half of a Unix domain socket connection.
    Unix(tokio::net::unix::OwnedWriteHalf),
    /// Write half of a TCP connection.
    Tcp(tokio::net::tcp::OwnedWriteHalf),
}

impl AsyncRead for StreamReadHalf {
    /// Delegate the read to whichever concrete read half is wrapped.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        use std::pin::Pin;

        // The wrapped halves are `Unpin` (required by `Pin::new`), so we can
        // take a plain `&mut` out of the pin and re-pin the inner value.
        match &mut *self {
            StreamReadHalf::Tcp(inner) => Pin::new(inner).poll_read(cx, buf),
            StreamReadHalf::Unix(inner) => Pin::new(inner).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for StreamWriteHalf {
    /// Delegate the write to whichever concrete write half is wrapped.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        match self.get_mut() {
            StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_write(cx, buf),
            StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_write(cx, buf),
        }
    }

    /// Forward vectored writes to the inner half.
    ///
    /// Without this override the trait's default would be used, which only
    /// writes the first non-empty buffer and so hides the inner socket's
    /// vectored-write support.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> std::task::Poll<io::Result<usize>> {
        match self.get_mut() {
            StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_write_vectored(cx, bufs),
            StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    /// Report the inner half's vectored-write capability.
    ///
    /// The trait default returns `false` unconditionally.
    fn is_write_vectored(&self) -> bool {
        match self {
            StreamWriteHalf::Unix(s) => s.is_write_vectored(),
            StreamWriteHalf::Tcp(s) => s.is_write_vectored(),
        }
    }

    /// Delegate the flush to the inner half.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        match self.get_mut() {
            StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_flush(cx),
            StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_flush(cx),
        }
    }

    /// Delegate the write-side shutdown to the inner half.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        match self.get_mut() {
            StreamWriteHalf::Unix(s) => std::pin::Pin::new(s).poll_shutdown(cx),
            StreamWriteHalf::Tcp(s) => std::pin::Pin::new(s).poll_shutdown(cx),
        }
    }
}

/// Forward data bidirectionally between two streams.
///
/// Copies bytes client → backend and backend → client concurrently. When one
/// direction reaches EOF, the write half on the other side of that direction
/// is shut down. The opposite direction keeps running, so a peer can still
/// send its reply after the other side has finished sending.
///
/// Returns `Ok((client_to_backend, backend_to_client))` byte counts once both
/// directions have reached EOF.
///
/// # Errors
///
/// Returns the first I/O error from either direction. At that point the other
/// direction is stopped rather than left running. Otherwise it could block
/// indefinitely on a peer that was never sent EOF.
pub async fn forward_bidirectional<R1, W1, R2, W2>(
    client_read: R1,
    client_write: W1,
    backend_read: R2,
    backend_write: W2,
) -> io::Result<(u64, u64)>
where
    R1: AsyncRead + Unpin,
    W1: AsyncWrite + Unpin,
    R2: AsyncRead + Unpin,
    W2: AsyncWrite + Unpin,
{
    // `try_join!` polls both directions concurrently and resolves to the first
    // `Err`, dropping the unfinished direction together with the halves it owns.
    tokio::try_join!(
        copy_then_shutdown(client_read, backend_write),
        copy_then_shutdown(backend_read, client_write)
    )
}

/// Copy everything from `reader` into `writer` until EOF, then shut down
/// `writer` so its peer observes EOF. Returns the number of bytes copied.
async fn copy_then_shutdown<R, W>(mut reader: R, mut writer: W) -> io::Result<u64>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let mut buf = [0u8; 8192];
    let mut total = 0u64;
    loop {
        let n = reader.read(&mut buf).await?;
        if n == 0 {
            break;
        }
        writer.write_all(&buf[..n]).await?;
        total += n as u64;
    }
    writer.shutdown().await?;
    Ok(total)
}

/// Check if a socket address is available (can connect to it)
pub async fn is_socket_available(addr: &SocketAddr) -> bool {
    match addr {
        SocketAddr::Unix(path) => tokio::net::UnixStream::connect(path).await.is_ok(),
        SocketAddr::Tcp(addr) => tokio::net::TcpStream::connect(addr).await.is_ok(),
    }
}

/// Wait for a socket to become available, giving up after `timeout_secs` seconds.
///
/// Retries [`is_socket_available`] every 100 ms. Returns `true` as soon as one
/// attempt succeeds and `false` if the deadline passes first.
///
/// The deadline also bounds any connection attempt still in flight when it
/// expires. A single slow `connect` (e.g. TCP to an unresponsive host)
/// therefore cannot stretch the wait beyond `timeout_secs`.
pub async fn wait_for_socket(addr: &SocketAddr, timeout_secs: u64) -> bool {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100);

    let poll_until_available = async {
        while !is_socket_available(addr).await {
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    };

    tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), poll_until_available)
        .await
        .is_ok()
}